Skip to content

Commit

Permalink
[feat] ecr-credential-provider support to authenticate public registries
Browse files Browse the repository at this point in the history
  • Loading branch information
mmerkes committed May 7, 2023
1 parent 21acaa9 commit e694c92
Show file tree
Hide file tree
Showing 3 changed files with 263 additions and 26 deletions.
133 changes: 111 additions & 22 deletions cmd/ecr-credential-provider/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,36 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ecr"
"github.com/aws/aws-sdk-go/service/ecrpublic"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/klog/v2"
"k8s.io/kubelet/pkg/apis/credentialprovider/v1"
)

const ecrPublicRegion string = "us-east-1"

var ecrPattern = regexp.MustCompile(`^(\d{12})\.dkr\.ecr(\-fips)?\.([a-zA-Z0-9][a-zA-Z0-9-_]*)\.(amazonaws\.com(\.cn)?|sc2s\.sgov\.gov|c2s\.ic\.gov)$`)

// public.ecr.aws/v2e8m4u0/example-registry
var ecrPublicPattern = regexp.MustCompile(`public.ecr.aws/([a-z0-9][a-z0-9-_]*)/([a-z0-9][a-z0-9-_./]*)`)

// ECR abstracts the calls we make to aws-sdk for testing purposes
type ECR interface {
GetAuthorizationToken(input *ecr.GetAuthorizationTokenInput) (*ecr.GetAuthorizationTokenOutput, error)
}

// ECRPublic abstracts the calls we make to aws-sdk for testing purposes
type ECRPublic interface {
GetAuthorizationToken(input *ecrpublic.GetAuthorizationTokenInput) (*ecrpublic.GetAuthorizationTokenOutput, error)
}

type ecrPlugin struct {
ecr ECR
ecr ECR
ecrPublic ECRPublic
}

func defaultECRProvider(region string, registryID string) (*ecr.ECR, error) {
func defaultECRProvider(region string) (*ecr.ECR, error) {
sess, err := session.NewSessionWithOptions(session.Options{
Config: aws.Config{Region: aws.String(region)},
SharedConfigState: session.SharedConfigEnable,
Expand All @@ -59,40 +71,108 @@ func defaultECRProvider(region string, registryID string) (*ecr.ECR, error) {
return ecr.New(sess), nil
}

func (e *ecrPlugin) GetCredentials(ctx context.Context, image string, args []string) (*v1.CredentialProviderResponse, error) {
registryID, region, registry, err := parseRepoURL(image)
func publicECRProvider() (*ecrpublic.ECRPublic, error) {
// ECR public registries are only in one region and only accessible from regions
// in the "aws" partition.
sess, err := session.NewSessionWithOptions(session.Options{
Config: aws.Config{Region: aws.String(ecrPublicRegion)},
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
return nil, err
}

return ecrpublic.New(sess), nil
}

type credData struct {
registry string
authToken *string
expiresAt *time.Time
}

func (e *ecrPlugin) getPublicCredsData() (credData, error) {
klog.Infof("Getting creds for public registry")
var err error

if e.ecrPublic == nil {
e.ecrPublic, err = publicECRProvider()
}
if err != nil {
return credData{}, err
}

output, err := e.ecrPublic.GetAuthorizationToken(&ecrpublic.GetAuthorizationTokenInput{})
if err != nil {
return credData{}, err
}

if output == nil {
return credData{}, errors.New("response output from ECR was nil")
}

if output.AuthorizationData == nil {
return credData{}, errors.New("authorization data was empty")
}

return credData{
registry: "public.ecr.aws",
authToken: output.AuthorizationData.AuthorizationToken,
expiresAt: output.AuthorizationData.ExpiresAt,
}, nil
}

func (e *ecrPlugin) getPrivateCredsData(image string) (credData, error) {
klog.Infof("Getting creds for private registry %s", image)
registryID, region, registry, err := parseRepoURL(image)
if err != nil {
return credData{}, err
}

if e.ecr == nil {
e.ecr, err = defaultECRProvider(region, registryID)
e.ecr, err = defaultECRProvider(region)
if err != nil {
return nil, err
return credData{}, err
}
}

output, err := e.ecr.GetAuthorizationToken(&ecr.GetAuthorizationTokenInput{
RegistryIds: []*string{aws.String(registryID)},
})
if err != nil {
return nil, err
return credData{}, err
}

if output == nil {
return nil, errors.New("response output from ECR was nil")
return credData{}, errors.New("response output from ECR was nil")
}

if len(output.AuthorizationData) == 0 {
return nil, errors.New("authorization data was empty")
return credData{}, errors.New("authorization data was empty")
}

return credData{
registry: registry,
authToken: output.AuthorizationData[0].AuthorizationToken,
expiresAt: output.AuthorizationData[0].ExpiresAt,
}, nil
}

func (e *ecrPlugin) GetCredentials(ctx context.Context, image string, args []string) (*v1.CredentialProviderResponse, error) {
var creds credData
var err error

if ecrPublicPattern.MatchString(image) {
creds, err = e.getPublicCredsData()
} else {
creds, err = e.getPrivateCredsData(image)
}

data := output.AuthorizationData[0]
if data.AuthorizationToken == nil {
if creds.authToken == nil {
return nil, errors.New("authorization token in response was nil")
}

decodedToken, err := base64.StdEncoding.DecodeString(aws.StringValue(data.AuthorizationToken))
decodedToken, err := base64.StdEncoding.DecodeString(aws.StringValue(creds.authToken))
if err != nil {
return nil, err
}
Expand All @@ -102,13 +182,13 @@ func (e *ecrPlugin) GetCredentials(ctx context.Context, image string, args []str
return nil, errors.New("error parsing username and password from authorization token")
}

cacheDuration := getCacheDuration(data.ExpiresAt)
cacheDuration := getCacheDuration(creds.expiresAt)

return &v1.CredentialProviderResponse{
CacheKeyType: v1.RegistryPluginCacheKeyType,
CacheDuration: cacheDuration,
Auth: map[string]v1.AuthConfig{
registry: {
creds.registry: {
Username: parts[0],
Password: parts[1],
},
Expand All @@ -135,24 +215,33 @@ func getCacheDuration(expiresAt *time.Time) *metav1.Duration {
return cacheDuration
}

// parseRepoURL parses and splits the registry URL
// returns (registryID, region, registry).
// <registryID>.dkr.ecr(-fips).<region>.amazonaws.com(.cn)
func parseRepoURL(image string) (string, string, string, error) {
func getRegistryFromImage(image string) (string, error) {
if !strings.Contains(image, "https://") {
image = "https://" + image
}
parsed, err := url.Parse(image)
if err != nil {
return "", "", "", fmt.Errorf("error parsing image %s: %v", image, err)
return "", fmt.Errorf("error parsing image %s: %v", image, err)
}

return parsed.Hostname(), nil
}

// parseRepoURL parses and splits the registry URL
// returns (registryID, region, registry).
// <registryID>.dkr.ecr(-fips).<region>.amazonaws.com(.cn)
func parseRepoURL(image string) (string, string, string, error) {
registry, err := getRegistryFromImage(image)
if err != nil {
return "", "", "", err
}

splitURL := ecrPattern.FindStringSubmatch(parsed.Hostname())
splitURL := ecrPattern.FindStringSubmatch(registry)
if len(splitURL) < 4 {
return "", "", "", fmt.Errorf("%s is not a valid ECR repository URL", parsed.Hostname())
return "", "", "", fmt.Errorf("%s is not a valid ECR repository URL", registry)
}

return splitURL[1], splitURL[3], parsed.Hostname(), nil
return splitURL[1], splitURL[3], registry, nil
}

func main() {
Expand Down
111 changes: 107 additions & 4 deletions cmd/ecr-credential-provider/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ecr"
"github.com/aws/aws-sdk-go/service/ecrpublic"
"github.com/golang/mock/gomock"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/cloud-provider-aws/pkg/providers/v2/mocks"
"k8s.io/kubelet/pkg/apis/credentialprovider/v1"
)

func generateGetAuthorizationTokenOutput(user string, password string, proxy string, expiration *time.Time) *ecr.GetAuthorizationTokenOutput {
func generatePrivateGetAuthorizationTokenOutput(user string, password string, proxy string, expiration *time.Time) *ecr.GetAuthorizationTokenOutput {
creds := []byte(fmt.Sprintf("%s:%s", user, password))
data := &ecr.AuthorizationData{
AuthorizationToken: aws.String(base64.StdEncoding.EncodeToString(creds)),
Expand All @@ -58,7 +59,7 @@ func generateResponse(registry string, username string, password string) *v1.Cre
}
}

func Test_GetCredentials(t *testing.T) {
func Test_GetCredentials_Private(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

Expand All @@ -76,7 +77,7 @@ func Test_GetCredentials(t *testing.T) {
{
name: "success",
image: "123456789123.dkr.ecr.us-west-2.amazonaws.com",
getAuthorizationTokenOutput: generateGetAuthorizationTokenOutput("user", "pass", "", nil),
getAuthorizationTokenOutput: generatePrivateGetAuthorizationTokenOutput("user", "pass", "", nil),
response: generateResponse("123456789123.dkr.ecr.us-west-2.amazonaws.com", "user", "pass"),
},
{
Expand Down Expand Up @@ -148,6 +149,108 @@ func Test_GetCredentials(t *testing.T) {
}
}

func generatePublicGetAuthorizationTokenOutput(user string, password string, proxy string, expiration *time.Time) *ecrpublic.GetAuthorizationTokenOutput {
creds := []byte(fmt.Sprintf("%s:%s", user, password))
data := &ecrpublic.AuthorizationData{
AuthorizationToken: aws.String(base64.StdEncoding.EncodeToString(creds)),
ExpiresAt: expiration,
}
output := &ecrpublic.GetAuthorizationTokenOutput{
AuthorizationData: data,
}
return output
}

func Test_GetCredentials_Public(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockECRPublic := mocks.NewMockECRPublic(ctrl)

testcases := []struct {
name string
image string
args []string
getAuthorizationTokenOutput *ecrpublic.GetAuthorizationTokenOutput
getAuthorizationTokenError error
response *v1.CredentialProviderResponse
expectedError error
}{
{
name: "success",
image: "public.ecr.aws",
getAuthorizationTokenOutput: generatePublicGetAuthorizationTokenOutput("user", "pass", "", nil),
response: generateResponse("public.ecr.aws", "user", "pass"),
},
{
name: "empty authorization data",
image: "public.ecr.aws",
getAuthorizationTokenOutput: &ecrpublic.GetAuthorizationTokenOutput{},
getAuthorizationTokenError: nil,
expectedError: errors.New("authorization data was empty"),
},
{
name: "nil response",
image: "public.ecr.aws",
getAuthorizationTokenOutput: nil,
getAuthorizationTokenError: nil,
expectedError: errors.New("response output from ECR was nil"),
},
{
name: "empty authorization token",
image: "public.ecr.aws",
getAuthorizationTokenOutput: &ecrpublic.GetAuthorizationTokenOutput{AuthorizationData: &ecrpublic.AuthorizationData{}},
getAuthorizationTokenError: nil,
expectedError: errors.New("authorization token in response was nil"),
},
{
name: "invalid authorization token",
image: "public.ecr.aws",
getAuthorizationTokenOutput: nil,
getAuthorizationTokenError: errors.New("getAuthorizationToken failed"),
expectedError: errors.New("getAuthorizationToken failed"),
},
{
name: "invalid authorization token",
image: "public.ecr.aws",
getAuthorizationTokenOutput: &ecrpublic.GetAuthorizationTokenOutput{
AuthorizationData: &ecrpublic.AuthorizationData{
AuthorizationToken: aws.String(base64.StdEncoding.EncodeToString([]byte(fmt.Sprint("foo")))),
},
},
getAuthorizationTokenError: nil,
expectedError: errors.New("error parsing username and password from authorization token"),
},
}

for _, testcase := range testcases {
t.Run(testcase.name, func(t *testing.T) {
p := &ecrPlugin{ecrPublic: mockECRPublic}
mockECRPublic.EXPECT().GetAuthorizationToken(gomock.Any()).Return(testcase.getAuthorizationTokenOutput, testcase.getAuthorizationTokenError)

creds, err := p.GetCredentials(context.TODO(), testcase.image, testcase.args)

if testcase.expectedError != nil && (testcase.expectedError.Error() != err.Error()) {
t.Fatalf("expected %s, got %s", testcase.expectedError.Error(), err.Error())
}

if testcase.expectedError == nil {
if creds.CacheKeyType != testcase.response.CacheKeyType {
t.Fatalf("Unexpected CacheKeyType. Expected: %s, got: %s", testcase.response.CacheKeyType, creds.CacheKeyType)
}

if creds.Auth[testcase.image] != testcase.response.Auth[testcase.image] {
t.Fatalf("Unexpected Auth. Expected: %s, got: %s", testcase.response.Auth[testcase.image], creds.Auth[testcase.image])
}

if creds.CacheDuration.Duration != testcase.response.CacheDuration.Duration {
t.Fatalf("Unexpected CacheDuration. Expected: %s, got: %s", testcase.response.CacheDuration.Duration, creds.CacheDuration.Duration)
}
}
})
}
}

func Test_ParseURL(t *testing.T) {
testcases := []struct {
name string
Expand All @@ -159,7 +262,7 @@ func Test_ParseURL(t *testing.T) {
}{
{
name: "success",
image: "123456789123.dkr.ecr.us-west-2.amazonaws.com",
image: "123456789123.dkr.ecr.us-west-2.amazonaws.com:v1.28",
registryID: "123456789123",
region: "us-west-2",
registry: "123456789123.dkr.ecr.us-west-2.amazonaws.com",
Expand Down
Loading

0 comments on commit e694c92

Please sign in to comment.