diff --git a/sdk/storage/azblob/internal/shared/challenge_policy.go b/sdk/storage/azblob/internal/shared/challenge_policy.go index c5d9ca14d95a..e7c8e9213d80 100644 --- a/sdk/storage/azblob/internal/shared/challenge_policy.go +++ b/sdk/storage/azblob/internal/shared/challenge_policy.go @@ -31,10 +31,6 @@ func NewStorageChallengePolicy(cred azcore.TokenCredential) policy.Policy { } func (s *storageAuthorizer) onRequest(req *policy.Request, authNZ func(policy.TokenRequestOptions) error) error { - if len(s.scopes) == 0 || s.tenantID == "" { - // returning nil indicates the bearer token policy should send the request - return nil - } return authNZ(policy.TokenRequestOptions{Scopes: s.scopes}) } diff --git a/sdk/storage/azblob/internal/shared/challenge_policy_test.go b/sdk/storage/azblob/internal/shared/challenge_policy_test.go index 1032519ce619..f666947ef9c7 100644 --- a/sdk/storage/azblob/internal/shared/challenge_policy_test.go +++ b/sdk/storage/azblob/internal/shared/challenge_policy_test.go @@ -13,7 +13,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" "github.com/stretchr/testify/require" - "net/http" "strings" "testing" "time" @@ -25,52 +24,69 @@ func (cf credentialFunc) GetToken(ctx context.Context, options policy.TokenReque return cf(ctx, options) } -func TestChallengePolicy(t *testing.T) { +func TestChallengePolicyStorage(t *testing.T) { accessToken := "***" - storageResource := "https://storage.azure.com" storageScope := "https://storage.azure.com/.default" - challenge := `Bearer authorization_uri="https://login.microsoftonline.com/{tenant}", resource_id="{storageResource}"` + + srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + srv.AppendResponse( + mock.WithStatusCode(200), + ) + authenticated := false + cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { + authenticated = true + require.Equal(t, []string{storageScope}, tro.Scopes) + return azcore.AccessToken{Token: accessToken, ExpiresOn: time.Now().Add(time.Hour)}, nil + }) + p := NewStorageChallengePolicy(cred) + pl := runtime.NewPipeline("", "", + runtime.PipelineOptions{PerRetry: []policy.Policy{p}}, + &policy.ClientOptions{Transport: srv}, + ) + req, err := runtime.NewRequest(context.Background(), "GET", "https://localhost") + require.NoError(t, err) + _, err = pl.Do(req) + require.NoError(t, err) + require.True(t, authenticated, "policy should have authenticated") +} + +func TestChallengePolicyDisk(t *testing.T) { + accessToken := "***" diskResource := "https://disk.azure.com/" diskScope := "https://disk.azure.com//.default" + challenge := `Bearer authorization_uri="https://login.microsoftonline.com/{tenant}", resource_id="{storageResource}"` - for _, test := range []struct { - expectedScope, format, resource string - }{ - {format: challenge, resource: storageResource, expectedScope: storageScope}, - {format: challenge, resource: diskResource, expectedScope: diskScope}, - } { - t.Run("", func(t *testing.T) { - srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) - defer close() - srv.AppendResponse( - mock.WithHeader("WWW-Authenticate", strings.ReplaceAll(test.format, "{storageResource}", test.resource)), - mock.WithStatusCode(401), - ) - srv.AppendResponse(mock.WithPredicate(func(r *http.Request) bool { - if authz := r.Header.Values("Authorization"); len(authz) != 1 || authz[0] != "Bearer "+accessToken { - t.Errorf(`unexpected Authorization "%s"`, authz) - } - return true - })) - srv.AppendResponse() - authenticated := false - cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { - authenticated = true - require.Equal(t, []string{test.expectedScope}, tro.Scopes) - return azcore.AccessToken{Token: accessToken, ExpiresOn: time.Now().Add(time.Hour)}, nil - }) - p := NewStorageChallengePolicy(cred) - pl := runtime.NewPipeline("", "", - runtime.PipelineOptions{PerRetry: []policy.Policy{p}}, - &policy.ClientOptions{Transport: srv}, - ) - req, err := runtime.NewRequest(context.Background(), "GET", "https://localhost") - require.NoError(t, err) - _, err = pl.Do(req) - require.NoError(t, err) - require.True(t, authenticated, "policy should have authenticated") - }) - } + srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", strings.ReplaceAll(challenge, "{storageResource}", diskResource)), + mock.WithStatusCode(401), + ) + srv.AppendResponse( + mock.WithStatusCode(200), + ) + attemptedAuthentication := false + authenticated := false + cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { + if attemptedAuthentication { + authenticated = true + require.Equal(t, []string{diskScope}, tro.Scopes) + return azcore.AccessToken{Token: accessToken, ExpiresOn: time.Now().Add(time.Hour)}, nil + } + attemptedAuthentication = true + return azcore.AccessToken{}, nil + }) + p := NewStorageChallengePolicy(cred) + pl := runtime.NewPipeline("", "", + runtime.PipelineOptions{PerRetry: []policy.Policy{p}}, + &policy.ClientOptions{Transport: srv}, + ) + req, err := runtime.NewRequest(context.Background(), "GET", "https://localhost") + require.NoError(t, err) + _, err = pl.Do(req) + require.NoError(t, err) + require.True(t, authenticated, "policy should have authenticated") } func TestParseTenant(t *testing.T) {