From b21c5802ffc6686ab3ac64e8b950d9bb4c16db59 Mon Sep 17 00:00:00 2001 From: Matt Merkes Date: Mon, 8 May 2023 05:46:50 +0000 Subject: [PATCH] [feat] ecr-credential-provider support to authenticate public registries --- cmd/ecr-credential-provider/main.go | 100 +++++++++++++++++++-- cmd/ecr-credential-provider/main_test.go | 109 ++++++++++++++++++++++- pkg/providers/v2/mocks/mock_ecr.go | 54 +++++++++-- 3 files changed, 244 insertions(+), 19 deletions(-) diff --git a/cmd/ecr-credential-provider/main.go b/cmd/ecr-credential-provider/main.go index b4e1a4aa51..d000a3b299 100644 --- a/cmd/ecr-credential-provider/main.go +++ b/cmd/ecr-credential-provider/main.go @@ -30,12 +30,16 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ecr" + "github.com/aws/aws-sdk-go/service/ecrpublic" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/klog/v2" "k8s.io/kubelet/pkg/apis/credentialprovider/v1" ) +const ecrPublicRegion string = "us-east-1" +const ecrPublicURL string = "public.ecr.aws" + var ecrPattern = regexp.MustCompile(`^(\d{12})\.dkr\.ecr(\-fips)?\.([a-zA-Z0-9][a-zA-Z0-9-_]*)\.(amazonaws\.com(\.cn)?|sc2s\.sgov\.gov|c2s\.ic\.gov)$`) // ECR abstracts the calls we make to aws-sdk for testing purposes @@ -43,11 +47,17 @@ type ECR interface { GetAuthorizationToken(input *ecr.GetAuthorizationTokenInput) (*ecr.GetAuthorizationTokenOutput, error) } +// ECRPublic abstracts the calls we make to aws-sdk for testing purposes +type ECRPublic interface { + GetAuthorizationToken(input *ecrpublic.GetAuthorizationTokenInput) (*ecrpublic.GetAuthorizationTokenOutput, error) +} + type ecrPlugin struct { - ecr ECR + ecr ECR + ecrPublic ECRPublic } -func defaultECRProvider(region string, registryID string) (*ecr.ECR, error) { +func defaultECRProvider(region string) (*ecr.ECR, error) { sess, err := session.NewSessionWithOptions(session.Options{ Config: aws.Config{Region: aws.String(region)}, SharedConfigState: session.SharedConfigEnable, @@ -59,14 +69,66 @@ func defaultECRProvider(region string, registryID string) (*ecr.ECR, error) { return ecr.New(sess), nil } -func (e *ecrPlugin) GetCredentials(ctx context.Context, image string, args []string) (*v1.CredentialProviderResponse, error) { +func publicECRProvider() (*ecrpublic.ECRPublic, error) { + // ECR public registries are only in one region and only accessible from regions + // in the "aws" partition. + sess, err := session.NewSessionWithOptions(session.Options{ + Config: aws.Config{Region: aws.String(ecrPublicRegion)}, + SharedConfigState: session.SharedConfigEnable, + }) + if err != nil { + return nil, err + } + + return ecrpublic.New(sess), nil +} + +type credsData struct { + registry string + authToken *string + expiresAt *time.Time +} + +func (e *ecrPlugin) getPublicCredsData() (*credsData, error) { + klog.Infof("Getting creds for public registry") + var err error + + if e.ecrPublic == nil { + e.ecrPublic, err = publicECRProvider() + } + if err != nil { + return nil, err + } + + output, err := e.ecrPublic.GetAuthorizationToken(&ecrpublic.GetAuthorizationTokenInput{}) + if err != nil { + return nil, err + } + + if output == nil { + return nil, errors.New("response output from ECR was nil") + } + + if output.AuthorizationData == nil { + return nil, errors.New("authorization data was empty") + } + + return &credsData{ + registry: ecrPublicURL, + authToken: output.AuthorizationData.AuthorizationToken, + expiresAt: output.AuthorizationData.ExpiresAt, + }, nil +} + +func (e *ecrPlugin) getPrivateCredsData(image string) (*credsData, error) { + klog.Infof("Getting creds for private registry %s", image) registryID, region, registry, err := parseRepoURL(image) if err != nil { return nil, err } if e.ecr == nil { - e.ecr, err = defaultECRProvider(region, registryID) + e.ecr, err = defaultECRProvider(region) if err != nil { return nil, err } @@ -87,12 +149,32 @@ func (e *ecrPlugin) GetCredentials(ctx context.Context, image string, args []str return nil, errors.New("authorization data was empty") } - data := output.AuthorizationData[0] - if data.AuthorizationToken == nil { + return &credsData{ + registry: registry, + authToken: output.AuthorizationData[0].AuthorizationToken, + expiresAt: output.AuthorizationData[0].ExpiresAt, + }, nil +} + +func (e *ecrPlugin) GetCredentials(ctx context.Context, image string, args []string) (*v1.CredentialProviderResponse, error) { + var creds *credsData + var err error + + if strings.Contains(image, ecrPublicURL) { + creds, err = e.getPublicCredsData() + } else { + creds, err = e.getPrivateCredsData(image) + } + + if err != nil { + return nil, err + } + + if creds.authToken == nil { return nil, errors.New("authorization token in response was nil") } - decodedToken, err := base64.StdEncoding.DecodeString(aws.StringValue(data.AuthorizationToken)) + decodedToken, err := base64.StdEncoding.DecodeString(aws.StringValue(creds.authToken)) if err != nil { return nil, err } @@ -102,13 +184,13 @@ func (e *ecrPlugin) GetCredentials(ctx context.Context, image string, args []str return nil, errors.New("error parsing username and password from authorization token") } - cacheDuration := getCacheDuration(data.ExpiresAt) + cacheDuration := getCacheDuration(creds.expiresAt) return &v1.CredentialProviderResponse{ CacheKeyType: v1.RegistryPluginCacheKeyType, CacheDuration: cacheDuration, Auth: map[string]v1.AuthConfig{ - registry: { + creds.registry: { Username: parts[0], Password: parts[1], }, diff --git a/cmd/ecr-credential-provider/main_test.go b/cmd/ecr-credential-provider/main_test.go index 6d16c5f971..b52a3821d0 100644 --- a/cmd/ecr-credential-provider/main_test.go +++ b/cmd/ecr-credential-provider/main_test.go @@ -26,13 +26,14 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ecr" + "github.com/aws/aws-sdk-go/service/ecrpublic" "github.com/golang/mock/gomock" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/cloud-provider-aws/pkg/providers/v2/mocks" "k8s.io/kubelet/pkg/apis/credentialprovider/v1" ) -func generateGetAuthorizationTokenOutput(user string, password string, proxy string, expiration *time.Time) *ecr.GetAuthorizationTokenOutput { +func generatePrivateGetAuthorizationTokenOutput(user string, password string, proxy string, expiration *time.Time) *ecr.GetAuthorizationTokenOutput { creds := []byte(fmt.Sprintf("%s:%s", user, password)) data := &ecr.AuthorizationData{ AuthorizationToken: aws.String(base64.StdEncoding.EncodeToString(creds)), @@ -58,7 +59,7 @@ func generateResponse(registry string, username string, password string) *v1.Cre } } -func Test_GetCredentials(t *testing.T) { +func Test_GetCredentials_Private(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -76,7 +77,7 @@ func Test_GetCredentials(t *testing.T) { { name: "success", image: "123456789123.dkr.ecr.us-west-2.amazonaws.com", - getAuthorizationTokenOutput: generateGetAuthorizationTokenOutput("user", "pass", "", nil), + getAuthorizationTokenOutput: generatePrivateGetAuthorizationTokenOutput("user", "pass", "", nil), response: generateResponse("123456789123.dkr.ecr.us-west-2.amazonaws.com", "user", "pass"), }, { @@ -148,6 +149,108 @@ func Test_GetCredentials(t *testing.T) { } } +func generatePublicGetAuthorizationTokenOutput(user string, password string, proxy string, expiration *time.Time) *ecrpublic.GetAuthorizationTokenOutput { + creds := []byte(fmt.Sprintf("%s:%s", user, password)) + data := &ecrpublic.AuthorizationData{ + AuthorizationToken: aws.String(base64.StdEncoding.EncodeToString(creds)), + ExpiresAt: expiration, + } + output := &ecrpublic.GetAuthorizationTokenOutput{ + AuthorizationData: data, + } + return output +} + +func Test_GetCredentials_Public(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockECRPublic := mocks.NewMockECRPublic(ctrl) + + testcases := []struct { + name string + image string + args []string + getAuthorizationTokenOutput *ecrpublic.GetAuthorizationTokenOutput + getAuthorizationTokenError error + response *v1.CredentialProviderResponse + expectedError error + }{ + { + name: "success", + image: "public.ecr.aws", + getAuthorizationTokenOutput: generatePublicGetAuthorizationTokenOutput("user", "pass", "", nil), + response: generateResponse("public.ecr.aws", "user", "pass"), + }, + { + name: "empty authorization data", + image: "public.ecr.aws", + getAuthorizationTokenOutput: &ecrpublic.GetAuthorizationTokenOutput{}, + getAuthorizationTokenError: nil, + expectedError: errors.New("authorization data was empty"), + }, + { + name: "nil response", + image: "public.ecr.aws", + getAuthorizationTokenOutput: nil, + getAuthorizationTokenError: nil, + expectedError: errors.New("response output from ECR was nil"), + }, + { + name: "empty authorization token", + image: "public.ecr.aws", + getAuthorizationTokenOutput: &ecrpublic.GetAuthorizationTokenOutput{AuthorizationData: &ecrpublic.AuthorizationData{}}, + getAuthorizationTokenError: nil, + expectedError: errors.New("authorization token in response was nil"), + }, + { + name: "invalid authorization token", + image: "public.ecr.aws", + getAuthorizationTokenOutput: nil, + getAuthorizationTokenError: errors.New("getAuthorizationToken failed"), + expectedError: errors.New("getAuthorizationToken failed"), + }, + { + name: "invalid authorization token", + image: "public.ecr.aws", + getAuthorizationTokenOutput: &ecrpublic.GetAuthorizationTokenOutput{ + AuthorizationData: &ecrpublic.AuthorizationData{ + AuthorizationToken: aws.String(base64.StdEncoding.EncodeToString([]byte(fmt.Sprint("foo")))), + }, + }, + getAuthorizationTokenError: nil, + expectedError: errors.New("error parsing username and password from authorization token"), + }, + } + + for _, testcase := range testcases { + t.Run(testcase.name, func(t *testing.T) { + p := &ecrPlugin{ecrPublic: mockECRPublic} + mockECRPublic.EXPECT().GetAuthorizationToken(gomock.Any()).Return(testcase.getAuthorizationTokenOutput, testcase.getAuthorizationTokenError) + + creds, err := p.GetCredentials(context.TODO(), testcase.image, testcase.args) + + if testcase.expectedError != nil && (testcase.expectedError.Error() != err.Error()) { + t.Fatalf("expected %s, got %s", testcase.expectedError.Error(), err.Error()) + } + + if testcase.expectedError == nil { + if creds.CacheKeyType != testcase.response.CacheKeyType { + t.Fatalf("Unexpected CacheKeyType. Expected: %s, got: %s", testcase.response.CacheKeyType, creds.CacheKeyType) + } + + if creds.Auth[testcase.image] != testcase.response.Auth[testcase.image] { + t.Fatalf("Unexpected Auth. Expected: %s, got: %s", testcase.response.Auth[testcase.image], creds.Auth[testcase.image]) + } + + if creds.CacheDuration.Duration != testcase.response.CacheDuration.Duration { + t.Fatalf("Unexpected CacheDuration. Expected: %s, got: %s", testcase.response.CacheDuration.Duration, creds.CacheDuration.Duration) + } + } + }) + } +} + func Test_ParseURL(t *testing.T) { testcases := []struct { name string diff --git a/pkg/providers/v2/mocks/mock_ecr.go b/pkg/providers/v2/mocks/mock_ecr.go index a4d6a8b2d3..8a5b7445ab 100644 --- a/pkg/providers/v2/mocks/mock_ecr.go +++ b/pkg/providers/v2/mocks/mock_ecr.go @@ -5,35 +5,37 @@ package mocks import ( + reflect "reflect" + ecr "github.com/aws/aws-sdk-go/service/ecr" + ecrpublic "github.com/aws/aws-sdk-go/service/ecrpublic" gomock "github.com/golang/mock/gomock" - reflect "reflect" ) -// MockECR is a mock of ECR interface +// MockECR is a mock of ECR interface. type MockECR struct { ctrl *gomock.Controller recorder *MockECRMockRecorder } -// MockECRMockRecorder is the mock recorder for MockECR +// MockECRMockRecorder is the mock recorder for MockECR. type MockECRMockRecorder struct { mock *MockECR } -// NewMockECR creates a new mock instance +// NewMockECR creates a new mock instance. func NewMockECR(ctrl *gomock.Controller) *MockECR { mock := &MockECR{ctrl: ctrl} mock.recorder = &MockECRMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockECR) EXPECT() *MockECRMockRecorder { return m.recorder } -// GetAuthorizationToken mocks base method +// GetAuthorizationToken mocks base method. func (m *MockECR) GetAuthorizationToken(input *ecr.GetAuthorizationTokenInput) (*ecr.GetAuthorizationTokenOutput, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAuthorizationToken", input) @@ -42,8 +44,46 @@ func (m *MockECR) GetAuthorizationToken(input *ecr.GetAuthorizationTokenInput) ( return ret0, ret1 } -// GetAuthorizationToken indicates an expected call of GetAuthorizationToken +// GetAuthorizationToken indicates an expected call of GetAuthorizationToken. func (mr *MockECRMockRecorder) GetAuthorizationToken(input interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizationToken", reflect.TypeOf((*MockECR)(nil).GetAuthorizationToken), input) } + +// MockECRPublic is a mock of ECRPublic interface. +type MockECRPublic struct { + ctrl *gomock.Controller + recorder *MockECRPublicMockRecorder +} + +// MockECRPublicMockRecorder is the mock recorder for MockECRPublic. +type MockECRPublicMockRecorder struct { + mock *MockECRPublic +} + +// NewMockECRPublic creates a new mock instance. +func NewMockECRPublic(ctrl *gomock.Controller) *MockECRPublic { + mock := &MockECRPublic{ctrl: ctrl} + mock.recorder = &MockECRPublicMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockECRPublic) EXPECT() *MockECRPublicMockRecorder { + return m.recorder +} + +// GetAuthorizationToken mocks base method. +func (m *MockECRPublic) GetAuthorizationToken(input *ecrpublic.GetAuthorizationTokenInput) (*ecrpublic.GetAuthorizationTokenOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAuthorizationToken", input) + ret0, _ := ret[0].(*ecrpublic.GetAuthorizationTokenOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAuthorizationToken indicates an expected call of GetAuthorizationToken. +func (mr *MockECRPublicMockRecorder) GetAuthorizationToken(input interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizationToken", reflect.TypeOf((*MockECRPublic)(nil).GetAuthorizationToken), input) +}