Skip to content

Commit

Permalink
Merge pull request #603 from mmerkes/pubecr
Browse files Browse the repository at this point in the history
[feat] ecr-credential-provider support to authenticate public registries
  • Loading branch information
k8s-ci-robot authored May 8, 2023
2 parents 52888ea + b21c580 commit b691271
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 19 deletions.
100 changes: 91 additions & 9 deletions cmd/ecr-credential-provider/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,34 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ecr"
"github.com/aws/aws-sdk-go/service/ecrpublic"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/klog/v2"
"k8s.io/kubelet/pkg/apis/credentialprovider/v1"
)

const ecrPublicRegion string = "us-east-1"
const ecrPublicURL string = "public.ecr.aws"

var ecrPattern = regexp.MustCompile(`^(\d{12})\.dkr\.ecr(\-fips)?\.([a-zA-Z0-9][a-zA-Z0-9-_]*)\.(amazonaws\.com(\.cn)?|sc2s\.sgov\.gov|c2s\.ic\.gov)$`)

// ECR abstracts the calls we make to aws-sdk for testing purposes
type ECR interface {
GetAuthorizationToken(input *ecr.GetAuthorizationTokenInput) (*ecr.GetAuthorizationTokenOutput, error)
}

// ECRPublic abstracts the calls we make to aws-sdk for testing purposes
type ECRPublic interface {
GetAuthorizationToken(input *ecrpublic.GetAuthorizationTokenInput) (*ecrpublic.GetAuthorizationTokenOutput, error)
}

type ecrPlugin struct {
ecr ECR
ecr ECR
ecrPublic ECRPublic
}

func defaultECRProvider(region string, registryID string) (*ecr.ECR, error) {
func defaultECRProvider(region string) (*ecr.ECR, error) {
sess, err := session.NewSessionWithOptions(session.Options{
Config: aws.Config{Region: aws.String(region)},
SharedConfigState: session.SharedConfigEnable,
Expand All @@ -59,14 +69,66 @@ func defaultECRProvider(region string, registryID string) (*ecr.ECR, error) {
return ecr.New(sess), nil
}

func (e *ecrPlugin) GetCredentials(ctx context.Context, image string, args []string) (*v1.CredentialProviderResponse, error) {
func publicECRProvider() (*ecrpublic.ECRPublic, error) {
// ECR public registries are only in one region and only accessible from regions
// in the "aws" partition.
sess, err := session.NewSessionWithOptions(session.Options{
Config: aws.Config{Region: aws.String(ecrPublicRegion)},
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
return nil, err
}

return ecrpublic.New(sess), nil
}

type credsData struct {
registry string
authToken *string
expiresAt *time.Time
}

func (e *ecrPlugin) getPublicCredsData() (*credsData, error) {
klog.Infof("Getting creds for public registry")
var err error

if e.ecrPublic == nil {
e.ecrPublic, err = publicECRProvider()
}
if err != nil {
return nil, err
}

output, err := e.ecrPublic.GetAuthorizationToken(&ecrpublic.GetAuthorizationTokenInput{})
if err != nil {
return nil, err
}

if output == nil {
return nil, errors.New("response output from ECR was nil")
}

if output.AuthorizationData == nil {
return nil, errors.New("authorization data was empty")
}

return &credsData{
registry: ecrPublicURL,
authToken: output.AuthorizationData.AuthorizationToken,
expiresAt: output.AuthorizationData.ExpiresAt,
}, nil
}

func (e *ecrPlugin) getPrivateCredsData(image string) (*credsData, error) {
klog.Infof("Getting creds for private registry %s", image)
registryID, region, registry, err := parseRepoURL(image)
if err != nil {
return nil, err
}

if e.ecr == nil {
e.ecr, err = defaultECRProvider(region, registryID)
e.ecr, err = defaultECRProvider(region)
if err != nil {
return nil, err
}
Expand All @@ -87,12 +149,32 @@ func (e *ecrPlugin) GetCredentials(ctx context.Context, image string, args []str
return nil, errors.New("authorization data was empty")
}

data := output.AuthorizationData[0]
if data.AuthorizationToken == nil {
return &credsData{
registry: registry,
authToken: output.AuthorizationData[0].AuthorizationToken,
expiresAt: output.AuthorizationData[0].ExpiresAt,
}, nil
}

func (e *ecrPlugin) GetCredentials(ctx context.Context, image string, args []string) (*v1.CredentialProviderResponse, error) {
var creds *credsData
var err error

if strings.Contains(image, ecrPublicURL) {
creds, err = e.getPublicCredsData()
} else {
creds, err = e.getPrivateCredsData(image)
}

if err != nil {
return nil, err
}

if creds.authToken == nil {
return nil, errors.New("authorization token in response was nil")
}

decodedToken, err := base64.StdEncoding.DecodeString(aws.StringValue(data.AuthorizationToken))
decodedToken, err := base64.StdEncoding.DecodeString(aws.StringValue(creds.authToken))
if err != nil {
return nil, err
}
Expand All @@ -102,13 +184,13 @@ func (e *ecrPlugin) GetCredentials(ctx context.Context, image string, args []str
return nil, errors.New("error parsing username and password from authorization token")
}

cacheDuration := getCacheDuration(data.ExpiresAt)
cacheDuration := getCacheDuration(creds.expiresAt)

return &v1.CredentialProviderResponse{
CacheKeyType: v1.RegistryPluginCacheKeyType,
CacheDuration: cacheDuration,
Auth: map[string]v1.AuthConfig{
registry: {
creds.registry: {
Username: parts[0],
Password: parts[1],
},
Expand Down
109 changes: 106 additions & 3 deletions cmd/ecr-credential-provider/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ecr"
"github.com/aws/aws-sdk-go/service/ecrpublic"
"github.com/golang/mock/gomock"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/cloud-provider-aws/pkg/providers/v2/mocks"
"k8s.io/kubelet/pkg/apis/credentialprovider/v1"
)

func generateGetAuthorizationTokenOutput(user string, password string, proxy string, expiration *time.Time) *ecr.GetAuthorizationTokenOutput {
func generatePrivateGetAuthorizationTokenOutput(user string, password string, proxy string, expiration *time.Time) *ecr.GetAuthorizationTokenOutput {
creds := []byte(fmt.Sprintf("%s:%s", user, password))
data := &ecr.AuthorizationData{
AuthorizationToken: aws.String(base64.StdEncoding.EncodeToString(creds)),
Expand All @@ -58,7 +59,7 @@ func generateResponse(registry string, username string, password string) *v1.Cre
}
}

func Test_GetCredentials(t *testing.T) {
func Test_GetCredentials_Private(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

Expand All @@ -76,7 +77,7 @@ func Test_GetCredentials(t *testing.T) {
{
name: "success",
image: "123456789123.dkr.ecr.us-west-2.amazonaws.com",
getAuthorizationTokenOutput: generateGetAuthorizationTokenOutput("user", "pass", "", nil),
getAuthorizationTokenOutput: generatePrivateGetAuthorizationTokenOutput("user", "pass", "", nil),
response: generateResponse("123456789123.dkr.ecr.us-west-2.amazonaws.com", "user", "pass"),
},
{
Expand Down Expand Up @@ -148,6 +149,108 @@ func Test_GetCredentials(t *testing.T) {
}
}

func generatePublicGetAuthorizationTokenOutput(user string, password string, proxy string, expiration *time.Time) *ecrpublic.GetAuthorizationTokenOutput {
creds := []byte(fmt.Sprintf("%s:%s", user, password))
data := &ecrpublic.AuthorizationData{
AuthorizationToken: aws.String(base64.StdEncoding.EncodeToString(creds)),
ExpiresAt: expiration,
}
output := &ecrpublic.GetAuthorizationTokenOutput{
AuthorizationData: data,
}
return output
}

func Test_GetCredentials_Public(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockECRPublic := mocks.NewMockECRPublic(ctrl)

testcases := []struct {
name string
image string
args []string
getAuthorizationTokenOutput *ecrpublic.GetAuthorizationTokenOutput
getAuthorizationTokenError error
response *v1.CredentialProviderResponse
expectedError error
}{
{
name: "success",
image: "public.ecr.aws",
getAuthorizationTokenOutput: generatePublicGetAuthorizationTokenOutput("user", "pass", "", nil),
response: generateResponse("public.ecr.aws", "user", "pass"),
},
{
name: "empty authorization data",
image: "public.ecr.aws",
getAuthorizationTokenOutput: &ecrpublic.GetAuthorizationTokenOutput{},
getAuthorizationTokenError: nil,
expectedError: errors.New("authorization data was empty"),
},
{
name: "nil response",
image: "public.ecr.aws",
getAuthorizationTokenOutput: nil,
getAuthorizationTokenError: nil,
expectedError: errors.New("response output from ECR was nil"),
},
{
name: "empty authorization token",
image: "public.ecr.aws",
getAuthorizationTokenOutput: &ecrpublic.GetAuthorizationTokenOutput{AuthorizationData: &ecrpublic.AuthorizationData{}},
getAuthorizationTokenError: nil,
expectedError: errors.New("authorization token in response was nil"),
},
{
name: "invalid authorization token",
image: "public.ecr.aws",
getAuthorizationTokenOutput: nil,
getAuthorizationTokenError: errors.New("getAuthorizationToken failed"),
expectedError: errors.New("getAuthorizationToken failed"),
},
{
name: "invalid authorization token",
image: "public.ecr.aws",
getAuthorizationTokenOutput: &ecrpublic.GetAuthorizationTokenOutput{
AuthorizationData: &ecrpublic.AuthorizationData{
AuthorizationToken: aws.String(base64.StdEncoding.EncodeToString([]byte(fmt.Sprint("foo")))),
},
},
getAuthorizationTokenError: nil,
expectedError: errors.New("error parsing username and password from authorization token"),
},
}

for _, testcase := range testcases {
t.Run(testcase.name, func(t *testing.T) {
p := &ecrPlugin{ecrPublic: mockECRPublic}
mockECRPublic.EXPECT().GetAuthorizationToken(gomock.Any()).Return(testcase.getAuthorizationTokenOutput, testcase.getAuthorizationTokenError)

creds, err := p.GetCredentials(context.TODO(), testcase.image, testcase.args)

if testcase.expectedError != nil && (testcase.expectedError.Error() != err.Error()) {
t.Fatalf("expected %s, got %s", testcase.expectedError.Error(), err.Error())
}

if testcase.expectedError == nil {
if creds.CacheKeyType != testcase.response.CacheKeyType {
t.Fatalf("Unexpected CacheKeyType. Expected: %s, got: %s", testcase.response.CacheKeyType, creds.CacheKeyType)
}

if creds.Auth[testcase.image] != testcase.response.Auth[testcase.image] {
t.Fatalf("Unexpected Auth. Expected: %s, got: %s", testcase.response.Auth[testcase.image], creds.Auth[testcase.image])
}

if creds.CacheDuration.Duration != testcase.response.CacheDuration.Duration {
t.Fatalf("Unexpected CacheDuration. Expected: %s, got: %s", testcase.response.CacheDuration.Duration, creds.CacheDuration.Duration)
}
}
})
}
}

func Test_ParseURL(t *testing.T) {
testcases := []struct {
name string
Expand Down
54 changes: 47 additions & 7 deletions pkg/providers/v2/mocks/mock_ecr.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit b691271

Please sign in to comment.