Skip to content

Commit

Permalink
fix: enforce host checking before exchanging a refresh token (#2069) (#…
Browse files Browse the repository at this point in the history
…2071)

Signed-off-by: Binbin Li <[email protected]>
  • Loading branch information
binbin-li authored Jan 27, 2025
1 parent 96117d9 commit 6d62a55
Show file tree
Hide file tree
Showing 10 changed files with 307 additions and 128 deletions.
247 changes: 124 additions & 123 deletions charts/ratify/README.md

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions charts/ratify/templates/store.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@ spec:
authProvider:
name: azureWorkloadIdentity
clientID: {{ .Values.azureWorkloadIdentity.clientId }}
endpoints:
{{- toYaml .Values.oras.authProviders.azureContainerRegistryEndpoints | nindent 8 }}
{{- end }}
{{- if .Values.oras.authProviders.azureManagedIdentityEnabled }}
authProvider:
name: azureManagedIdentity
clientID: {{ .Values.azureManagedIdentity.clientId }}
endpoints:
{{- toYaml .Values.oras.authProviders.azureContainerRegistryEndpoints | nindent 8 }}
{{- end }}
{{- if .Values.oras.authProviders.k8secretsEnabled }}
authProvider:
Expand Down
1 change: 1 addition & 0 deletions charts/ratify/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ oras:
authProviders:
azureWorkloadIdentityEnabled: false
azureManagedIdentityEnabled: false
azureContainerRegistryEndpoints: []
k8secretsEnabled: false
awsEcrBasicEnabled: false
awsApiOverride:
Expand Down
16 changes: 13 additions & 3 deletions pkg/common/oras/authprovider/azure/azureidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,13 @@ type MIAuthProvider struct {
authClientFactory AuthClientFactory
registryHostGetter RegistryHostGetter
getManagedIdentityToken ManagedIdentityTokenGetter
endpoints []string
}

type azureManagedIdentityAuthProviderConf struct {
Name string `json:"name"`
ClientID string `json:"clientID"`
Name string `json:"name"`
ClientID string `json:"clientID"`
Endpoints []string `json:"endpoints,omitempty"`
}

const (
Expand Down Expand Up @@ -106,9 +108,12 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider
return nil, re.ErrorCodeEnvNotSet.WithDetail("AZURE_CLIENT_ID environment variable is empty").WithComponentType(re.AuthProvider)
}
}

endpoints, err := parseEndpoints(conf.Endpoints)
if err != nil {
return nil, err
return nil, re.ErrorCodeConfigInvalid.WithError(err)
}

// retrieve an AAD Access token
token, err := getManagedIdentityToken(context.Background(), client, azidentity.NewManagedIdentityCredential)
if err != nil {
Expand All @@ -121,6 +126,7 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider
tenantID: tenant,
authClientFactory: &defaultAuthClientFactoryImpl{}, // Concrete implementation
getManagedIdentityToken: &defaultManagedIdentityTokenGetterImpl{}, // Concrete implementation
endpoints: endpoints,
}, nil
}

Expand Down Expand Up @@ -155,6 +161,10 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider)
}

if err := validateHost(artifactHostName, d.endpoints); err != nil {
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithError(err)
}

// need to refresh AAD token if it's expired
if time.Now().Add(time.Minute * 5).After(d.identityToken.ExpiresOn) {
newToken, err := d.getManagedIdentityToken.GetManagedIdentityToken(ctx, d.clientID)
Expand Down
2 changes: 2 additions & 0 deletions pkg/common/oras/authprovider/azure/azureidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) {
authClientFactory: mockAuthClientFactory,
registryHostGetter: mockRegistryHostGetter,
getManagedIdentityToken: mockManagedIdentityTokenGetter,
endpoints: []string{"example.azurecr.io"},
}

// Call Provide method
Expand Down Expand Up @@ -200,6 +201,7 @@ func TestMIAuthProvider_Provide_TokenRefreshFailure(t *testing.T) {
authClientFactory: mockAuthClientFactory,
registryHostGetter: mockRegistryHostGetter,
getManagedIdentityToken: mockManagedIdentityTokenGetter,
endpoints: []string{"example.azurecr.io"},
}

// Call Provide method
Expand Down
16 changes: 14 additions & 2 deletions pkg/common/oras/authprovider/azure/azureworkloadidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,13 @@ type WIAuthProvider struct {
registryHostGetter RegistryHostGetter
getAADAccessToken AADAccessTokenGetter
reportMetrics MetricsReporter
endpoints []string
}

type azureWIAuthProviderConf struct {
Name string `json:"name"`
ClientID string `json:"clientID,omitempty"`
Name string `json:"name"`
ClientID string `json:"clientID,omitempty"`
Endpoints []string `json:"endpoints,omitempty"`
}

const (
Expand Down Expand Up @@ -113,6 +115,11 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider
}
}

endpoints, err := parseEndpoints(conf.Endpoints)
if err != nil {
return nil, re.ErrorCodeConfigInvalid.WithError(err)
}

// retrieve an AAD Access token
token, err := defaultGetAADAccessToken(context.Background(), tenant, clientID, AADResource)
if err != nil {
Expand All @@ -127,6 +134,7 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider
registryHostGetter: &defaultRegistryHostGetterImpl{}, // Concrete implementation
getAADAccessToken: &defaultAADAccessTokenGetterImpl{}, // Concrete implementation
reportMetrics: &defaultMetricsReporterImpl{},
endpoints: endpoints,
}, nil
}

Expand Down Expand Up @@ -157,6 +165,10 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider)
}

if err := validateHost(artifactHostName, d.endpoints); err != nil {
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithError(err)
}

// need to refresh AAD token if it's expired
if time.Now().Add(time.Minute * 5).After(d.aadToken.ExpiresOn) {
newToken, err := d.getAADAccessToken.GetAADAccessToken(ctx, d.tenantID, d.clientID, AADResource)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func TestWIAuthProvider_Provide_Success(t *testing.T) {
registryHostGetter: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
endpoints: []string{"example.azurecr.io"},
}

// Call Provide method
Expand Down Expand Up @@ -126,6 +127,7 @@ func TestWIAuthProvider_Provide_RefreshToken(t *testing.T) {
registryHostGetter: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
endpoints: []string{"example.azurecr.io"},
}

// Call Provide method
Expand Down Expand Up @@ -161,6 +163,7 @@ func TestWIAuthProvider_Provide_AADTokenFailure(t *testing.T) {
registryHostGetter: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
endpoints: []string{"example.azurecr.io"},
}

// Call Provide method
Expand Down Expand Up @@ -238,6 +241,7 @@ func TestWIAuthProvider_Provide_TokenRefresh_Success(t *testing.T) {
registryHostGetter: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
endpoints: []string{"example.azurecr.io"},
}

// Call Provide method
Expand Down Expand Up @@ -273,6 +277,7 @@ func TestWIAuthProvider_Provide_TokenRefreshFailure(t *testing.T) {
registryHostGetter: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
endpoints: []string{"example.azurecr.io"},
}

// Call Provide method
Expand Down
1 change: 1 addition & 0 deletions pkg/common/oras/authprovider/azure/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const (
dockerTokenLoginUsernameGUID = "00000000-0000-0000-0000-000000000000"
AADResource = "https://containerregistry.azure.net/.default"
defaultACRExpiryDuration time.Duration = 3 * time.Hour
defaultACREndpoint = "*.azurecr.io"
)

var logOpt = logger.Option{
Expand Down
53 changes: 53 additions & 0 deletions pkg/common/oras/authprovider/azure/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package azure

import (
"context"
"fmt"
"strings"

"github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
provider "github.com/ratify-project/ratify/pkg/common/oras/authprovider"
Expand Down Expand Up @@ -82,3 +84,54 @@ type defaultRegistryHostGetterImpl struct{}
func (g *defaultRegistryHostGetterImpl) GetRegistryHost(artifact string) (string, error) {
return provider.GetRegistryHostName(artifact)
}

// parseEndpoints checks if the endpoints are valid for auth provider. If no
// endpoints are provided, it defaults to the default ACR endpoint.
// A valid endpoint is either a fully qualified domain name or a wildcard domain
// name folloiwing RFC 1034.
// Valid examples:
// - *.example.com
// - example.com
//
// Invalid examples:
// - *
// - example.*
// - *example.com.*
// - *.
func parseEndpoints(endpoints []string) ([]string, error) {
if len(endpoints) == 0 {
return []string{defaultACREndpoint}, nil
}
for _, endpoint := range endpoints {
switch strings.Count(endpoint, "*") {
case 0:
continue
case 1:
if !strings.HasPrefix(endpoint, "*.") {
return nil, fmt.Errorf("invalid wildcard domain name: %s, it must start with '*.'", endpoint)
}
if len(endpoint) < 3 {
return nil, fmt.Errorf("invalid wildcard domain name: %s, it must have at least one character after '*.'", endpoint)
}
default:
return nil, fmt.Errorf("invalid wildcard domain name: %s, it must have at most one wildcard character", endpoint)
}
}
return endpoints, nil
}

// validateHost checks if the host is matching endpoints supported by the auth
// provider.
func validateHost(host string, endpoints []string) error {
for _, endpoint := range endpoints {
if endpoint[0] == '*' {
if _, zone, ok := strings.Cut(host, "."); ok && zone == endpoint[2:] {
return nil
}
}
if host == endpoint {
return nil
}
}
return fmt.Errorf("the artifact host %s is not in the scope of the store auth provider", host)
}
90 changes: 90 additions & 0 deletions pkg/common/oras/authprovider/azure/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,93 @@ func TestAuthenticationClientWrapper_ExchangeAADAccessTokenForACRRefreshToken(t
_, err := wrapper.ExchangeAADAccessTokenForACRRefreshToken(ctx, grantType, service, options)
assert.Nil(t, err)
}

func TestValidateEndpoints(t *testing.T) {
tests := []struct {
name string
endpoint string
expectedErr bool
}{
{
name: "global wildcard",
endpoint: "*",
expectedErr: true,
},
{
name: "multiple wildcard",
endpoint: "*.example.*",
expectedErr: true,
},
{
name: "no subdomain",
endpoint: "*.",
expectedErr: true,
},
{
name: "full qualified domain",
endpoint: "example.com",
expectedErr: false,
},
{
name: "valid wildcard domain",
endpoint: "*.example.com",
expectedErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := parseEndpoints([]string{tt.endpoint})
if tt.expectedErr != (err != nil) {
t.Fatalf("expected error: %v, got error: %v", tt.expectedErr, err)
}
})
}
}

func TestValidateHost(t *testing.T) {
endpoints := []string{
"*.azurecr.io",
"example.azurecr.io",
}
tests := []struct {
name string
host string
expectedErr bool
}{
{
name: "empty host",
host: "",
expectedErr: true,
},
{
name: "valid host",
host: "example.azurecr.io",
expectedErr: false,
},
{
name: "no subdomain",
host: "azurecr.io",
expectedErr: true,
},
{
name: "multiple subdomains",
host: "example.test.azurecr.io",
expectedErr: true,
},
{
name: "matched host",
host: "test.azurecr.io",
expectedErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateHost(tt.host, endpoints)
if tt.expectedErr != (err != nil) {
t.Fatalf("expected error: %v, got error: %v", tt.expectedErr, err)
}
})
}
}

0 comments on commit 6d62a55

Please sign in to comment.