Skip to content

Commit

Permalink
Publish gosdk from d25ca33cca7
Browse files Browse the repository at this point in the history
  • Loading branch information
maratori committed Nov 7, 2024
1 parent dd2dd84 commit a7ca996
Show file tree
Hide file tree
Showing 530 changed files with 136,129 additions and 0 deletions.
17 changes: 17 additions & 0 deletions auth/authenticator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package auth

import (
"context"
)

// Authenticator is used in gRPC interceptors.
type Authenticator interface {
// Auth returns a [context.Context] with necessary grpc metadata for authorization.
Auth(context.Context) (context.Context, error)

// HandleError is called with a [context.Context] received from [Authenticator.Auth]
// and an error from a gRPC call if it has the Unauthenticated code.
// If HandleError returns nil, a new auth will be requested to retry the gRPC call.
// If the gRPC call should not be retried, HandleError must return the incoming error.
HandleError(context.Context, error) error
}
106 changes: 106 additions & 0 deletions auth/bearer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package auth

import (
"context"
"errors"
"time"

"google.golang.org/grpc/metadata"
)

const AuthorizationHeader = "Authorization"

type BearerToken struct {
Token string
ExpiresAt time.Time
}

type BearerTokener interface {
// BearerToken returns a token to be used in "Authorization" header.
BearerToken(context.Context) (BearerToken, error)

// HandleError is called with the [BearerToken] received from [BearerTokener.BearerToken]
// and an error from a gRPC call if it has the Unauthenticated code.
// If HandleError returns nil, a new auth will be requested to retry the gRPC call.
// If the gRPC call should not be retried, HandleError must return the incoming error.
HandleError(context.Context, BearerToken, error) error
}

// StaticBearerToken implement [BearerTokener] with constant token.
type StaticBearerToken string

var _ BearerTokener = StaticBearerToken("")

func (t StaticBearerToken) BearerToken(context.Context) (BearerToken, error) {
return BearerToken{
Token: string(t),
ExpiresAt: time.Time{},
}, nil
}

func (t StaticBearerToken) HandleError(_ context.Context, _ BearerToken, err error) error {
return err
}

type AuthenticatorFromBearerTokener struct {
tokener BearerTokener
}

var _ Authenticator = AuthenticatorFromBearerTokener{}

func NewAuthenticatorFromBearerTokener(tokener BearerTokener) AuthenticatorFromBearerTokener {
return AuthenticatorFromBearerTokener{
tokener: tokener,
}
}

func (a AuthenticatorFromBearerTokener) Auth(ctx context.Context) (context.Context, error) {
token, err := a.tokener.BearerToken(ctx)
if err != nil {
return nil, err
}
md, ok := metadata.FromOutgoingContext(ctx)
if !ok {
md = make(metadata.MD)
}
md.Set(AuthorizationHeader, "Bearer "+token.Token)
ctx = metadata.NewOutgoingContext(ctx, md)
ctx = context.WithValue(ctx, ctxKeyBearerToken{}, token)
return ctx, nil
}

func (a AuthenticatorFromBearerTokener) HandleError(ctx context.Context, err error) error {
token, ok := ctx.Value(ctxKeyBearerToken{}).(BearerToken)
if !ok {
return err
}
return a.tokener.HandleError(ctx, token, err)
}

type ctxKeyBearerToken struct{}

type PropagateAuthorizationHeader struct{}

var _ Authenticator = PropagateAuthorizationHeader{}

func NewPropagateAuthorizationHeader() PropagateAuthorizationHeader {
return PropagateAuthorizationHeader{}
}

func (PropagateAuthorizationHeader) Auth(ctx context.Context) (context.Context, error) {
incoming := metadata.ValueFromIncomingContext(ctx, AuthorizationHeader)
if len(incoming) == 0 {
return nil, errors.New("missing authorization header in the incoming metadata")
}
md, ok := metadata.FromOutgoingContext(ctx)
if !ok {
md = make(metadata.MD)
}
md.Set(AuthorizationHeader, incoming...)
ctx = metadata.NewOutgoingContext(ctx, md)
return ctx, nil
}

func (PropagateAuthorizationHeader) HandleError(_ context.Context, err error) error {
return err
}
272 changes: 272 additions & 0 deletions auth/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
package auth

import (
"context"
"errors"
"log/slog"
"math"
"sync"
"time"

"golang.org/x/sync/singleflight"
)

// CachedServiceTokener is a middleware that enhances the functionality of the [BearerTokener] by caching the token.
// It also automatically refreshes the token in the background before it expires, ensuring seamless authentication.
// Recommended parameters from IAM:
// - lifetime 90%
// - initial retry 1 second
// - max retry 1 minute
type CachedServiceTokener struct {
logger *slog.Logger
tokener BearerTokener
lifetime float64
initialRetry time.Duration
retryMultiplier float64
maxRetry time.Duration
ticker *time.Ticker
now func() time.Time
group singleflight.Group

mu sync.RWMutex
cache *BearerToken
refreshAt time.Time
retryCount int
}

var _ BearerTokener = (*CachedServiceTokener)(nil)

func NewCachedServiceTokener(
logger *slog.Logger,
tokener BearerTokener,
lifetime float64,
initialRetry time.Duration,
retryMultiplier float64,
maxRetry time.Duration,
) *CachedServiceTokener {
stoppedTicker := time.NewTicker(time.Minute)
stoppedTicker.Stop()
return &CachedServiceTokener{
logger: logger,
tokener: tokener,
lifetime: lifetime,
initialRetry: initialRetry,
retryMultiplier: retryMultiplier,
maxRetry: maxRetry,
ticker: stoppedTicker,
now: time.Now,
group: singleflight.Group{},

mu: sync.RWMutex{},
cache: nil,
refreshAt: time.Time{},
retryCount: 0,
}
}

func (c *CachedServiceTokener) Run(ctx context.Context) error { //nolint:gocognit
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-c.ticker.C:

token, refreshAt, retryCount := c.getToken()

if token == nil {
continue
}

if token.ExpiresAt.IsZero() {
continue
}

if refreshAt.IsZero() {
continue
}

timeLeft := refreshAt.Sub(c.now())
if timeLeft > 0 {
c.ticker.Reset(timeLeft)
continue
}

c.logger.InfoContext(
ctx,
"Token is about to expire; initiating background token refresh",
slog.Time("expires_at", token.ExpiresAt),
slog.Int("attempt", retryCount+1),
)

_, err := c.requestToken(ctx, true)
if err != nil && !errors.Is(err, context.Canceled) {
var retry time.Duration
if retryCount <= 0 || math.Abs(c.retryMultiplier-1) <= 1e-9 {
retry = c.initialRetry
} else {
mul := math.Pow(c.retryMultiplier, float64(retryCount))
retry = time.Duration(math.Max(float64(c.maxRetry), float64(c.initialRetry)*mul))
}
c.logger.ErrorContext(
ctx,
"Background token refresh failed",
slog.Any("error", err),
slog.Int("attempt", retryCount+1),
slog.Duration("next_attempt", retry),
)
c.ticker.Reset(retry)
}
}
}
}

func (c *CachedServiceTokener) BearerToken(ctx context.Context) (BearerToken, error) {
token, _, _ := c.getToken()
if token != nil {
return *token, nil
}

return c.requestToken(ctx, false)
}

func (c *CachedServiceTokener) HandleError(ctx context.Context, token BearerToken, err error) error {
if err == nil {
return nil
}

c.mu.Lock()
if c.cache != nil && c.cache.Token == token.Token {
c.cache = nil
c.refreshAt = time.Time{}
c.ticker.Stop()
}
c.mu.Unlock()

return c.tokener.HandleError(ctx, token, err)
}

func (c *CachedServiceTokener) getToken() (*BearerToken, time.Time, int) {
c.mu.RLock()
defer c.mu.RUnlock()
return c.cache, c.refreshAt, c.retryCount
}

func (c *CachedServiceTokener) requestToken(ctx context.Context, background bool) (BearerToken, error) {
res, err, _ := c.group.Do("", func() (interface{}, error) {
var refreshAfter time.Duration

now := c.now()
token, err := c.tokener.BearerToken(ctx)
if err != nil {
if background {
c.mu.Lock()
c.retryCount++
c.mu.Unlock()
}
return nil, err
}

if !token.ExpiresAt.IsZero() {
ttl := token.ExpiresAt.Sub(now)
if ttl > 0 {
refreshAfter = time.Duration(float64(ttl) * c.lifetime)
} else {
c.logger.ErrorContext(
ctx,
"Received already expired token",
slog.Time("expires_at", token.ExpiresAt),
)
}
}

c.mu.Lock()
c.cache = &token
if refreshAfter > 0 {
c.refreshAt = now.Add(refreshAfter)
} else {
c.refreshAt = time.Time{}
}
c.retryCount = 0
c.mu.Unlock()

if refreshAfter > 0 {
c.ticker.Reset(refreshAfter)
}

return token, nil
})
if err != nil {
return BearerToken{}, err
}

return res.(BearerToken), nil
}

type CachedBearerTokener struct {
tokener BearerTokener
group singleflight.Group

mu sync.RWMutex
cache *BearerToken
}

var _ BearerTokener = (*CachedBearerTokener)(nil)

func NewCachedBearerTokener(tokener BearerTokener) *CachedBearerTokener {
return &CachedBearerTokener{
tokener: tokener,
group: singleflight.Group{},

mu: sync.RWMutex{},
cache: nil,
}
}

func (c *CachedBearerTokener) BearerToken(ctx context.Context) (BearerToken, error) {
token := c.getToken()
if token != nil {
return *token, nil
}

return c.requestToken(ctx)
}

func (c *CachedBearerTokener) HandleError(ctx context.Context, token BearerToken, err error) error {
if err == nil {
return nil
}

c.mu.Lock()
if c.cache != nil && c.cache.Token == token.Token {
c.cache = nil
}
c.mu.Unlock()

return c.tokener.HandleError(ctx, token, err)
}

func (c *CachedBearerTokener) getToken() *BearerToken {
c.mu.RLock()
defer c.mu.RUnlock()
return c.cache
}

func (c *CachedBearerTokener) requestToken(ctx context.Context) (BearerToken, error) {
res, err, _ := c.group.Do("", func() (interface{}, error) {
token, err := c.tokener.BearerToken(ctx)
if err != nil {
return nil, err
}

c.mu.Lock()
c.cache = &token
c.mu.Unlock()

return token, nil
})
if err != nil {
return BearerToken{}, err
}

return res.(BearerToken), nil
}
Loading

0 comments on commit a7ca996

Please sign in to comment.