-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
530 changed files
with
136,129 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
package auth | ||
|
||
import ( | ||
"context" | ||
) | ||
|
||
// Authenticator is used in gRPC interceptors. | ||
type Authenticator interface { | ||
// Auth returns a [context.Context] with necessary grpc metadata for authorization. | ||
Auth(context.Context) (context.Context, error) | ||
|
||
// HandleError is called with a [context.Context] received from [Authenticator.Auth] | ||
// and an error from a gRPC call if it has the Unauthenticated code. | ||
// If HandleError returns nil, a new auth will be requested to retry the gRPC call. | ||
// If the gRPC call should not be retried, HandleError must return the incoming error. | ||
HandleError(context.Context, error) error | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
package auth | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"time" | ||
|
||
"google.golang.org/grpc/metadata" | ||
) | ||
|
||
const AuthorizationHeader = "Authorization" | ||
|
||
type BearerToken struct { | ||
Token string | ||
ExpiresAt time.Time | ||
} | ||
|
||
type BearerTokener interface { | ||
// BearerToken returns a token to be used in "Authorization" header. | ||
BearerToken(context.Context) (BearerToken, error) | ||
|
||
// HandleError is called with the [BearerToken] received from [BearerTokener.BearerToken] | ||
// and an error from a gRPC call if it has the Unauthenticated code. | ||
// If HandleError returns nil, a new auth will be requested to retry the gRPC call. | ||
// If the gRPC call should not be retried, HandleError must return the incoming error. | ||
HandleError(context.Context, BearerToken, error) error | ||
} | ||
|
||
// StaticBearerToken implement [BearerTokener] with constant token. | ||
type StaticBearerToken string | ||
|
||
var _ BearerTokener = StaticBearerToken("") | ||
|
||
func (t StaticBearerToken) BearerToken(context.Context) (BearerToken, error) { | ||
return BearerToken{ | ||
Token: string(t), | ||
ExpiresAt: time.Time{}, | ||
}, nil | ||
} | ||
|
||
func (t StaticBearerToken) HandleError(_ context.Context, _ BearerToken, err error) error { | ||
return err | ||
} | ||
|
||
type AuthenticatorFromBearerTokener struct { | ||
tokener BearerTokener | ||
} | ||
|
||
var _ Authenticator = AuthenticatorFromBearerTokener{} | ||
|
||
func NewAuthenticatorFromBearerTokener(tokener BearerTokener) AuthenticatorFromBearerTokener { | ||
return AuthenticatorFromBearerTokener{ | ||
tokener: tokener, | ||
} | ||
} | ||
|
||
func (a AuthenticatorFromBearerTokener) Auth(ctx context.Context) (context.Context, error) { | ||
token, err := a.tokener.BearerToken(ctx) | ||
if err != nil { | ||
return nil, err | ||
} | ||
md, ok := metadata.FromOutgoingContext(ctx) | ||
if !ok { | ||
md = make(metadata.MD) | ||
} | ||
md.Set(AuthorizationHeader, "Bearer "+token.Token) | ||
ctx = metadata.NewOutgoingContext(ctx, md) | ||
ctx = context.WithValue(ctx, ctxKeyBearerToken{}, token) | ||
return ctx, nil | ||
} | ||
|
||
func (a AuthenticatorFromBearerTokener) HandleError(ctx context.Context, err error) error { | ||
token, ok := ctx.Value(ctxKeyBearerToken{}).(BearerToken) | ||
if !ok { | ||
return err | ||
} | ||
return a.tokener.HandleError(ctx, token, err) | ||
} | ||
|
||
type ctxKeyBearerToken struct{} | ||
|
||
type PropagateAuthorizationHeader struct{} | ||
|
||
var _ Authenticator = PropagateAuthorizationHeader{} | ||
|
||
func NewPropagateAuthorizationHeader() PropagateAuthorizationHeader { | ||
return PropagateAuthorizationHeader{} | ||
} | ||
|
||
func (PropagateAuthorizationHeader) Auth(ctx context.Context) (context.Context, error) { | ||
incoming := metadata.ValueFromIncomingContext(ctx, AuthorizationHeader) | ||
if len(incoming) == 0 { | ||
return nil, errors.New("missing authorization header in the incoming metadata") | ||
} | ||
md, ok := metadata.FromOutgoingContext(ctx) | ||
if !ok { | ||
md = make(metadata.MD) | ||
} | ||
md.Set(AuthorizationHeader, incoming...) | ||
ctx = metadata.NewOutgoingContext(ctx, md) | ||
return ctx, nil | ||
} | ||
|
||
func (PropagateAuthorizationHeader) HandleError(_ context.Context, err error) error { | ||
return err | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,272 @@ | ||
package auth | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"log/slog" | ||
"math" | ||
"sync" | ||
"time" | ||
|
||
"golang.org/x/sync/singleflight" | ||
) | ||
|
||
// CachedServiceTokener is a middleware that enhances the functionality of the [BearerTokener] by caching the token. | ||
// It also automatically refreshes the token in the background before it expires, ensuring seamless authentication. | ||
// Recommended parameters from IAM: | ||
// - lifetime 90% | ||
// - initial retry 1 second | ||
// - max retry 1 minute | ||
type CachedServiceTokener struct { | ||
logger *slog.Logger | ||
tokener BearerTokener | ||
lifetime float64 | ||
initialRetry time.Duration | ||
retryMultiplier float64 | ||
maxRetry time.Duration | ||
ticker *time.Ticker | ||
now func() time.Time | ||
group singleflight.Group | ||
|
||
mu sync.RWMutex | ||
cache *BearerToken | ||
refreshAt time.Time | ||
retryCount int | ||
} | ||
|
||
var _ BearerTokener = (*CachedServiceTokener)(nil) | ||
|
||
func NewCachedServiceTokener( | ||
logger *slog.Logger, | ||
tokener BearerTokener, | ||
lifetime float64, | ||
initialRetry time.Duration, | ||
retryMultiplier float64, | ||
maxRetry time.Duration, | ||
) *CachedServiceTokener { | ||
stoppedTicker := time.NewTicker(time.Minute) | ||
stoppedTicker.Stop() | ||
return &CachedServiceTokener{ | ||
logger: logger, | ||
tokener: tokener, | ||
lifetime: lifetime, | ||
initialRetry: initialRetry, | ||
retryMultiplier: retryMultiplier, | ||
maxRetry: maxRetry, | ||
ticker: stoppedTicker, | ||
now: time.Now, | ||
group: singleflight.Group{}, | ||
|
||
mu: sync.RWMutex{}, | ||
cache: nil, | ||
refreshAt: time.Time{}, | ||
retryCount: 0, | ||
} | ||
} | ||
|
||
func (c *CachedServiceTokener) Run(ctx context.Context) error { //nolint:gocognit | ||
for { | ||
select { | ||
case <-ctx.Done(): | ||
return ctx.Err() | ||
case <-c.ticker.C: | ||
|
||
token, refreshAt, retryCount := c.getToken() | ||
|
||
if token == nil { | ||
continue | ||
} | ||
|
||
if token.ExpiresAt.IsZero() { | ||
continue | ||
} | ||
|
||
if refreshAt.IsZero() { | ||
continue | ||
} | ||
|
||
timeLeft := refreshAt.Sub(c.now()) | ||
if timeLeft > 0 { | ||
c.ticker.Reset(timeLeft) | ||
continue | ||
} | ||
|
||
c.logger.InfoContext( | ||
ctx, | ||
"Token is about to expire; initiating background token refresh", | ||
slog.Time("expires_at", token.ExpiresAt), | ||
slog.Int("attempt", retryCount+1), | ||
) | ||
|
||
_, err := c.requestToken(ctx, true) | ||
if err != nil && !errors.Is(err, context.Canceled) { | ||
var retry time.Duration | ||
if retryCount <= 0 || math.Abs(c.retryMultiplier-1) <= 1e-9 { | ||
retry = c.initialRetry | ||
} else { | ||
mul := math.Pow(c.retryMultiplier, float64(retryCount)) | ||
retry = time.Duration(math.Max(float64(c.maxRetry), float64(c.initialRetry)*mul)) | ||
} | ||
c.logger.ErrorContext( | ||
ctx, | ||
"Background token refresh failed", | ||
slog.Any("error", err), | ||
slog.Int("attempt", retryCount+1), | ||
slog.Duration("next_attempt", retry), | ||
) | ||
c.ticker.Reset(retry) | ||
} | ||
} | ||
} | ||
} | ||
|
||
func (c *CachedServiceTokener) BearerToken(ctx context.Context) (BearerToken, error) { | ||
token, _, _ := c.getToken() | ||
if token != nil { | ||
return *token, nil | ||
} | ||
|
||
return c.requestToken(ctx, false) | ||
} | ||
|
||
func (c *CachedServiceTokener) HandleError(ctx context.Context, token BearerToken, err error) error { | ||
if err == nil { | ||
return nil | ||
} | ||
|
||
c.mu.Lock() | ||
if c.cache != nil && c.cache.Token == token.Token { | ||
c.cache = nil | ||
c.refreshAt = time.Time{} | ||
c.ticker.Stop() | ||
} | ||
c.mu.Unlock() | ||
|
||
return c.tokener.HandleError(ctx, token, err) | ||
} | ||
|
||
func (c *CachedServiceTokener) getToken() (*BearerToken, time.Time, int) { | ||
c.mu.RLock() | ||
defer c.mu.RUnlock() | ||
return c.cache, c.refreshAt, c.retryCount | ||
} | ||
|
||
func (c *CachedServiceTokener) requestToken(ctx context.Context, background bool) (BearerToken, error) { | ||
res, err, _ := c.group.Do("", func() (interface{}, error) { | ||
var refreshAfter time.Duration | ||
|
||
now := c.now() | ||
token, err := c.tokener.BearerToken(ctx) | ||
if err != nil { | ||
if background { | ||
c.mu.Lock() | ||
c.retryCount++ | ||
c.mu.Unlock() | ||
} | ||
return nil, err | ||
} | ||
|
||
if !token.ExpiresAt.IsZero() { | ||
ttl := token.ExpiresAt.Sub(now) | ||
if ttl > 0 { | ||
refreshAfter = time.Duration(float64(ttl) * c.lifetime) | ||
} else { | ||
c.logger.ErrorContext( | ||
ctx, | ||
"Received already expired token", | ||
slog.Time("expires_at", token.ExpiresAt), | ||
) | ||
} | ||
} | ||
|
||
c.mu.Lock() | ||
c.cache = &token | ||
if refreshAfter > 0 { | ||
c.refreshAt = now.Add(refreshAfter) | ||
} else { | ||
c.refreshAt = time.Time{} | ||
} | ||
c.retryCount = 0 | ||
c.mu.Unlock() | ||
|
||
if refreshAfter > 0 { | ||
c.ticker.Reset(refreshAfter) | ||
} | ||
|
||
return token, nil | ||
}) | ||
if err != nil { | ||
return BearerToken{}, err | ||
} | ||
|
||
return res.(BearerToken), nil | ||
} | ||
|
||
type CachedBearerTokener struct { | ||
tokener BearerTokener | ||
group singleflight.Group | ||
|
||
mu sync.RWMutex | ||
cache *BearerToken | ||
} | ||
|
||
var _ BearerTokener = (*CachedBearerTokener)(nil) | ||
|
||
func NewCachedBearerTokener(tokener BearerTokener) *CachedBearerTokener { | ||
return &CachedBearerTokener{ | ||
tokener: tokener, | ||
group: singleflight.Group{}, | ||
|
||
mu: sync.RWMutex{}, | ||
cache: nil, | ||
} | ||
} | ||
|
||
func (c *CachedBearerTokener) BearerToken(ctx context.Context) (BearerToken, error) { | ||
token := c.getToken() | ||
if token != nil { | ||
return *token, nil | ||
} | ||
|
||
return c.requestToken(ctx) | ||
} | ||
|
||
func (c *CachedBearerTokener) HandleError(ctx context.Context, token BearerToken, err error) error { | ||
if err == nil { | ||
return nil | ||
} | ||
|
||
c.mu.Lock() | ||
if c.cache != nil && c.cache.Token == token.Token { | ||
c.cache = nil | ||
} | ||
c.mu.Unlock() | ||
|
||
return c.tokener.HandleError(ctx, token, err) | ||
} | ||
|
||
func (c *CachedBearerTokener) getToken() *BearerToken { | ||
c.mu.RLock() | ||
defer c.mu.RUnlock() | ||
return c.cache | ||
} | ||
|
||
func (c *CachedBearerTokener) requestToken(ctx context.Context) (BearerToken, error) { | ||
res, err, _ := c.group.Do("", func() (interface{}, error) { | ||
token, err := c.tokener.BearerToken(ctx) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
c.mu.Lock() | ||
c.cache = &token | ||
c.mu.Unlock() | ||
|
||
return token, nil | ||
}) | ||
if err != nil { | ||
return BearerToken{}, err | ||
} | ||
|
||
return res.(BearerToken), nil | ||
} |
Oops, something went wrong.