Skip to content

Commit

Permalink
[COR-1114] Fix token validity check logic to use exp field in access …
Browse files Browse the repository at this point in the history
…token (#330)

* Add logs for token

* add logs

* Fixing the validity check logic for token

* nit

* nit

* Adding in memory token source provider

* nit

* changed Valid method to log and ignore parseDateClaim error

* nit

* Fix unit tests

* lint

* fix unit tests
  • Loading branch information
pmahindrakar-oss authored Jul 2, 2024
1 parent 721c9a8 commit fa58ba1
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 92 deletions.
56 changes: 45 additions & 11 deletions flyteidl/clients/go/admin/auth_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"google.golang.org/grpc/status"

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/logger"
)
Expand All @@ -33,13 +34,9 @@ func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.T
return fmt.Errorf("failed to initialized token source provider. Err: %w", err)
}

authorizationMetadataKey := cfg.AuthorizationHeader
if len(authorizationMetadataKey) == 0 {
clientMetadata, err := authMetadataClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{})
if err != nil {
return fmt.Errorf("failed to fetch client metadata. Error: %v", err)
}
authorizationMetadataKey = clientMetadata.AuthorizationMetadataKey
authorizationMetadataKey, err := getAuthMetadataKey(ctx, cfg, authMetadataClient)
if err != nil {
return err
}

tokenSource, err := tokenSourceProvider.GetTokenSource(ctx)
Expand All @@ -58,6 +55,40 @@ func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.T
return nil
}

// getAuthMetadataKey return the authorization metadata key used for api calls.
func getAuthMetadataKey(ctx context.Context, cfg *Config, authMetadataClient service.AuthMetadataServiceClient) (string, error) {
authorizationMetadataKey := cfg.AuthorizationHeader
if len(authorizationMetadataKey) == 0 {
clientMetadata, err := authMetadataClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{})
if err != nil {
return "", fmt.Errorf("failed to fetch client metadata. Error: %v", err)
}
authorizationMetadataKey = clientMetadata.AuthorizationMetadataKey
}
return authorizationMetadataKey, nil
}

// MaterializeInMemoryCredentials initializes the perRPCCredentials with the token source containing in memory cached token.
// This path doesn't perform the token refresh and only build the cred source with cached token.
func MaterializeInMemoryCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache,
perRPCCredentials *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) error {
authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg, proxyCredentialsFuture)
if err != nil {
return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err)
}
authorizationMetadataKey, err := getAuthMetadataKey(ctx, cfg, authMetadataClient)
if err != nil {
return err
}
tokenSource, err := NewInMemoryTokenSourceProvider(tokenCache).GetTokenSource(ctx)
if err != nil {
return fmt.Errorf("failed to get token source. Error: %w", err)
}
wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, authorizationMetadataKey)
perRPCCredentials.Store(wrappedTokenSource)
return nil
}

func GetProxyTokenSource(ctx context.Context, cfg *Config) (oauth2.TokenSource, error) {
tokenSourceProvider, err := NewExternalTokenSourceProvider(cfg.ProxyCommand)
if err != nil {
Expand Down Expand Up @@ -145,9 +176,11 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut
// If there is already a token in the cache (e.g. key-ring), we should use it immediately...
t, _ := tokenCache.GetToken()
if t != nil {
err := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)
if err != nil {
return fmt.Errorf("failed to materialize credentials. Error: %v", err)
if isValid := utils.Valid(t); isValid {
err := MaterializeInMemoryCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)
if err != nil {
return fmt.Errorf("failed to materialize credentials. Error: %v", err)
}
}
}

Expand Down Expand Up @@ -186,7 +219,8 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut
return err
}

return invoker(ctx, method, req, reply, cc, opts...)
err = invoker(ctx, method, req, reply, cc, opts...)

}
}
}
Expand Down
11 changes: 4 additions & 7 deletions flyteidl/clients/go/admin/auth_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@ package admin

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
Expand All @@ -24,6 +23,7 @@ import (

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache/mocks"
adminMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -136,15 +136,12 @@ func newAuthMetadataServer(t testing.TB, grpcPort int, httpPort int, impl servic
}

func Test_newAuthInterceptor(t *testing.T) {
plan, _ := os.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(20*time.Minute))
t.Run("Other Error", func(t *testing.T) {
f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()
mockTokenCache := &mocks.TokenCache{}
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
mockTokenCache.OnGetTokenMatch().Return(tokenData, nil)
interceptor := NewAuthInterceptor(&Config{}, mockTokenCache, f, p)
otherError := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return status.New(codes.Canceled, "").Err()
Expand Down
23 changes: 7 additions & 16 deletions flyteidl/clients/go/admin/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ package admin

import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"testing"
"time"

Expand All @@ -24,6 +21,7 @@ import (
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/oauth"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/pkce"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/tokenorchestrator"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -231,15 +229,11 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) {
RedirectUri: "http://localhost:54545/callback",
}
http.DefaultServeMux = http.NewServeMux()
plan, _ := os.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
tokenData.Expiry = time.Now().Add(time.Minute)
tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(time.Minute))
t.Run("cache hit", func(t *testing.T) {
mockTokenCache := new(cachemocks.TokenCache)
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
mockTokenCache.OnGetTokenMatch().Return(tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
mockAuthClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(metadata, nil)
mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil)
Expand All @@ -249,11 +243,11 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) {
assert.NotNil(t, dialOption)
assert.Nil(t, err)
})
tokenData.Expiry = time.Now().Add(-time.Minute)
t.Run("cache miss auth failure", func(t *testing.T) {
tokenData = utils.GenTokenWithCustomExpiry(t, time.Now().Add(-time.Minute))
mockTokenCache := new(cachemocks.TokenCache)
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
mockTokenCache.OnGetTokenMatch().Return(tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
mockTokenCache.On("Lock").Return()
mockTokenCache.On("Unlock").Return()
Expand Down Expand Up @@ -284,14 +278,11 @@ func Test_getPkceAuthTokenSource(t *testing.T) {
mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil)

t.Run("cached token expired", func(t *testing.T) {
plan, _ := ioutil.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(-time.Minute))

// populate the cache
tokenCache := cache.NewTokenCacheInMemoryProvider()
assert.NoError(t, tokenCache.SaveToken(&tokenData))
assert.NoError(t, tokenCache.SaveToken(tokenData))

baseOrchestrator := tokenorchestrator.BaseTokenOrchestrator{
ClientConfig: &oauth.Config{
Expand Down
47 changes: 39 additions & 8 deletions flyteidl/clients/go/admin/token_source_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/externalprocess"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/pkce"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/tokenorchestrator"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/logger"
)
Expand Down Expand Up @@ -228,29 +229,36 @@ func (s *customTokenSource) Token() (*oauth2.Token, error) {
s.mu.Lock()
defer s.mu.Unlock()

if token, err := s.tokenCache.GetToken(); err == nil && token.Valid() {
return token, nil
token, err := s.tokenCache.GetToken()
if err != nil {
logger.Warnf(s.ctx, "failed to get token from cache: %v", err)
} else {
if isValid := utils.Valid(token); isValid {
logger.Infof(context.Background(), "retrieved token from cache with expiry %v", token.Expiry)
return token, nil
}
}

totalAttempts := s.cfg.MaxRetries + 1 // Add one for initial request attempt
backoff := wait.Backoff{
Duration: s.cfg.PerRetryTimeout.Duration,
Steps: totalAttempts,
}
var token *oauth2.Token
err := retry.OnError(backoff, func(err error) bool {

err = retry.OnError(backoff, func(err error) bool {
return err != nil
}, func() (err error) {
token, err = s.new.Token()
if err != nil {
logger.Infof(s.ctx, "failed to get token: %w", err)
return fmt.Errorf("failed to get token: %w", err)
logger.Infof(s.ctx, "failed to get new token: %w", err)
return fmt.Errorf("failed to get new token: %w", err)
}
logger.Infof(context.Background(), "Fetched new token with expiry %v", token.Expiry)
return nil
})
if err != nil {
logger.Warnf(s.ctx, "failed to get token: %v", err)
return nil, fmt.Errorf("failed to get token: %w", err)
logger.Warnf(s.ctx, "failed to get new token: %v", err)
return nil, fmt.Errorf("failed to get new token: %w", err)
}

logger.Infof(s.ctx, "retrieved token with expiry %v", token.Expiry)
Expand All @@ -263,6 +271,29 @@ func (s *customTokenSource) Token() (*oauth2.Token, error) {
return token, nil
}

type InMemoryTokenSourceProvider struct {
tokenCache cache.TokenCache
}

func NewInMemoryTokenSourceProvider(tokenCache cache.TokenCache) TokenSourceProvider {
return InMemoryTokenSourceProvider{tokenCache: tokenCache}
}

func (i InMemoryTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
return GetInMemoryAuthTokenSource(ctx, i.tokenCache)
}

// GetInMemoryAuthTokenSource Returns the token source with cached token
func GetInMemoryAuthTokenSource(ctx context.Context, tokenCache cache.TokenCache) (oauth2.TokenSource, error) {
authToken, err := tokenCache.GetToken()
if err != nil {
return nil, err
}
return &pkce.SimpleTokenSource{
CachedToken: authToken,
}, nil
}

type DeviceFlowTokenSourceProvider struct {
tokenOrchestrator deviceflow.TokenOrchestrator
}
Expand Down
25 changes: 13 additions & 12 deletions flyteidl/clients/go/admin/token_source_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

tokenCacheMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache/mocks"
adminMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
)

Expand Down Expand Up @@ -88,9 +89,9 @@ func TestCustomTokenSource_Token(t *testing.T) {
minuteAgo := time.Now().Add(-time.Minute)
hourAhead := time.Now().Add(time.Hour)
twoHourAhead := time.Now().Add(2 * time.Hour)
invalidToken := oauth2.Token{AccessToken: "foo", Expiry: minuteAgo}
validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead}
newToken := oauth2.Token{AccessToken: "foo", Expiry: twoHourAhead}
invalidToken := utils.GenTokenWithCustomExpiry(t, minuteAgo)
validToken := utils.GenTokenWithCustomExpiry(t, hourAhead)
newToken := utils.GenTokenWithCustomExpiry(t, twoHourAhead)

tests := []struct {
name string
Expand All @@ -101,24 +102,24 @@ func TestCustomTokenSource_Token(t *testing.T) {
{
name: "no cached token",
token: nil,
newToken: &newToken,
expectedToken: &newToken,
newToken: newToken,
expectedToken: newToken,
},
{
name: "cached token valid",
token: &validToken,
token: validToken,
newToken: nil,
expectedToken: &validToken,
expectedToken: validToken,
},
{
name: "cached token expired",
token: &invalidToken,
newToken: &newToken,
expectedToken: &newToken,
token: invalidToken,
newToken: newToken,
expectedToken: newToken,
},
{
name: "failed new token",
token: &invalidToken,
token: invalidToken,
newToken: nil,
expectedToken: nil,
},
Expand All @@ -138,7 +139,7 @@ func TestCustomTokenSource_Token(t *testing.T) {
assert.True(t, ok)

mockSource := &adminMocks.TokenSource{}
if test.token != &validToken {
if test.token != validToken {
if test.newToken != nil {
mockSource.OnToken().Return(test.newToken, nil)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/oauth"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -52,7 +53,8 @@ func (t BaseTokenOrchestrator) FetchTokenFromCacheOrRefreshIt(ctx context.Contex
return nil, err
}

if token.Valid() {
if isValid := utils.Valid(token); isValid {
logger.Infof(context.Background(), "retrieved token from cache with expiry %v", token.Expiry)
return token, nil
}

Expand Down
Loading

0 comments on commit fa58ba1

Please sign in to comment.