diff --git a/README.md b/README.md index f14bf3d9..6156a1e9 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,31 @@ metadata: iam.amazonaws.com/role: reportingdb-reader ``` +You can control the session name used when assuming the role via an annotation added to the `Pod`, which may be used to further identify the session. For example: + +```yaml +kind: Pod +metadata: + name: foo + namespace: session-name-example + annotations: + iam.amazonaws.com/role: reportingdb-reader + iam.amazonaws.com/session-name: my-session-name +``` + +You can also control the external id used when assuming the role via an annotation added to the `Pod`, which +may be used to avoid [confused deputy scenarios in cross-organisation role assumption](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_create_for-user_externalid.html). For example: + +```yaml +kind: Pod +metadata: + name: foo + namespace: external-id-example + annotations: + iam.amazonaws.com/role: reportingdb-reader + iam.amazonaws.com/external-id: dac7ad46-acab-4ec3-a78e-f3962ecf45d7 +``` + Further, all namespaces must also have an annotation with a regular expression expressing which roles are permitted to be assumed within that namespace. **Without the namespace annotation the pod will be unable to assume any roles.** ```yaml diff --git a/pkg/aws/sts/arn_resolver.go b/pkg/aws/sts/arn_resolver.go index 885e0dba..7d8f90f6 100644 --- a/pkg/aws/sts/arn_resolver.go +++ b/pkg/aws/sts/arn_resolver.go @@ -22,6 +22,11 @@ type Resolver struct { prefix string } +type ResolvedRole struct { + Name string + ARN string +} + // DefaultResolver will add the prefix to any roles which // don't start with arn: func DefaultResolver(prefix string) *Resolver { @@ -29,20 +34,20 @@ func DefaultResolver(prefix string) *Resolver { } // Resolve converts from a role string into the absolute role arn. 
-func (r *Resolver) Resolve(role string) (*RoleIdentity, error) { +func (r *Resolver) Resolve(role string) (*ResolvedRole, error) { if role == "" { return nil, fmt.Errorf("role can't be empty") } if strings.HasPrefix(role, "arn:") { - return &RoleIdentity{ARN: role, Role: roleFromArn(role)}, nil + return &ResolvedRole{ARN: role, Name: roleFromArn(role)}, nil } if strings.HasPrefix(role, "/") { role = strings.TrimPrefix(role, "/") } - return &RoleIdentity{ARN: fmt.Sprintf("%s%s", r.prefix, role), Role: role}, nil + return &ResolvedRole{ARN: fmt.Sprintf("%s%s", r.prefix, role), Name: role}, nil } // arn:aws:iam::account-id:role/role-name-with-path @@ -50,3 +55,7 @@ func roleFromArn(arn string) string { splits := strings.SplitAfterN(arn, ":", 6) return strings.TrimPrefix(splits[5], "role/") } + +func (i *ResolvedRole) Equals(other *ResolvedRole) bool { + return *i == *other +} diff --git a/pkg/aws/sts/arn_resolver_test.go b/pkg/aws/sts/arn_resolver_test.go index 17883de2..dfc24d07 100644 --- a/pkg/aws/sts/arn_resolver_test.go +++ b/pkg/aws/sts/arn_resolver_test.go @@ -17,7 +17,7 @@ import ( "testing" ) -func TestRoleIdentityEquality(t *testing.T) { +func TestResolvedRoleEquality(t *testing.T) { resolver := DefaultResolver("arn:aws:iam::account-id:role/") i1, _ := resolver.Resolve("foo") i2, _ := resolver.Resolve("foo") @@ -34,10 +34,10 @@ func TestRoleIdentityEquality(t *testing.T) { func TestAddsPrefix(t *testing.T) { resolver := DefaultResolver("arn:aws:iam::account-id:role/") - identity, _ := resolver.Resolve("myrole") + resolvedRole, _ := resolver.Resolve("myrole") - if identity.ARN != "arn:aws:iam::account-id:role/myrole" { - t.Error("unexpected role, was:", identity.ARN) + if resolvedRole.ARN != "arn:aws:iam::account-id:role/myrole" { + t.Error("unexpected role, was:", resolvedRole.ARN) } } @@ -52,47 +52,47 @@ func TestReturnsErrorForEmptyRole(t *testing.T) { func TestAddsPrefixWithRoleBeginningWithSlash(t *testing.T) { resolver := 
DefaultResolver("arn:aws:iam::account-id:role/") - identity, _ := resolver.Resolve("/myrole") + resolvedRole, _ := resolver.Resolve("/myrole") - if identity.ARN != "arn:aws:iam::account-id:role/myrole" { - t.Error("unexpected role, was:", identity.ARN) + if resolvedRole.ARN != "arn:aws:iam::account-id:role/myrole" { + t.Error("unexpected role, was:", resolvedRole.ARN) } - if identity.Role != "myrole" { - t.Error("unexpected role, was", identity.Role) + if resolvedRole.Name != "myrole" { + t.Error("unexpected role, was", resolvedRole.Name) } } func TestAddsPrefixWithRoleBeginningWithPathWithoutSlash(t *testing.T) { resolver := DefaultResolver("arn:aws:iam::account-id:role/") - identity, _ := resolver.Resolve("kiam/myrole") + resolvedRole, _ := resolver.Resolve("kiam/myrole") - if identity.ARN != "arn:aws:iam::account-id:role/kiam/myrole" { - t.Error("unexpected role, was:", identity.ARN) + if resolvedRole.ARN != "arn:aws:iam::account-id:role/kiam/myrole" { + t.Error("unexpected role, was:", resolvedRole.ARN) } - if identity.Role != "kiam/myrole" { - t.Error("unexpected role", identity.Role) + if resolvedRole.Name != "kiam/myrole" { + t.Error("unexpected role", resolvedRole.Name) } } func TestAddsPrefixWithRoleBeginningWithSlashPath(t *testing.T) { resolver := DefaultResolver("arn:aws:iam::account-id:role/") - identity, _ := resolver.Resolve("/kiam/myrole") + resolvedRole, _ := resolver.Resolve("/kiam/myrole") - if identity.ARN != "arn:aws:iam::account-id:role/kiam/myrole" { - t.Error("unexpected role, was:", identity.ARN) + if resolvedRole.ARN != "arn:aws:iam::account-id:role/kiam/myrole" { + t.Error("unexpected role, was:", resolvedRole.ARN) } } func TestUsesAbsoluteARN(t *testing.T) { resolver := DefaultResolver("arn:aws:iam::account-id:role/") - identity, _ := resolver.Resolve("arn:aws:iam::some-other-account:role/path-prefix/another-role") + resolvedRole, _ := resolver.Resolve("arn:aws:iam::some-other-account:role/path-prefix/another-role") - if identity.ARN != 
"arn:aws:iam::some-other-account:role/path-prefix/another-role" { - t.Error("unexpected role, was:", identity.ARN) + if resolvedRole.ARN != "arn:aws:iam::some-other-account:role/path-prefix/another-role" { + t.Error("unexpected role, was:", resolvedRole.ARN) } - if identity.Role != "path-prefix/another-role" { - t.Error("expected role to be set, was", identity.Role) + if resolvedRole.Name != "path-prefix/another-role" { + t.Error("expected role to be set, was", resolvedRole.Name) } } diff --git a/pkg/aws/sts/cache.go b/pkg/aws/sts/cache.go index 8f7fbe24..e14bd169 100644 --- a/pkg/aws/sts/cache.go +++ b/pkg/aws/sts/cache.go @@ -16,6 +16,7 @@ package sts import ( "context" "fmt" + "regexp" "time" "github.com/patrickmn/go-cache" @@ -34,11 +35,6 @@ type credentialsCache struct { gateway STSGateway } -type RoleIdentity struct { - Role string - ARN string // Amazon Resource Name for the Role -} - type CachedCredentials struct { Identity *RoleIdentity Credentials *Credentials @@ -56,7 +52,7 @@ func DefaultCache( ) *credentialsCache { c := &credentialsCache{ expiring: make(chan *CachedCredentials, 1), - sessionName: fmt.Sprintf("kiam-%s", sessionName), + sessionName: sessionName, sessionDuration: sessionDuration, cacheTTL: sessionDuration - sessionRefresh, gateway: gateway, @@ -95,7 +91,7 @@ func (c *credentialsCache) Expiring() chan *CachedCredentials { // CredentialsForRole looks for cached credentials or requests them from the STSGateway. Requested credentials // must have their ARN set. 
func (c *credentialsCache) CredentialsForRole(ctx context.Context, identity *RoleIdentity) (*Credentials, error) { - logger := log.WithFields(log.Fields{"pod.iam.role": identity.Role, "pod.iam.roleArn": identity.ARN}) + logger := log.WithFields(identity.LogFields()) item, found := c.cache.Get(identity.String()) if found { @@ -117,7 +113,16 @@ func (c *credentialsCache) CredentialsForRole(ctx context.Context, identity *Rol cacheMiss.Inc() issue := func() (interface{}, error) { - credentials, err := c.gateway.Issue(ctx, identity.ARN, c.sessionName, c.sessionDuration) + sessionName := c.getSessionName(identity) + + stsIssueRequest := &STSIssueRequest{ + RoleARN: identity.Role.ARN, + SessionName: sessionName, + ExternalID: identity.ExternalID, + SessionDuration: c.sessionDuration, + } + + credentials, err := c.gateway.Issue(ctx, stsIssueRequest) if err != nil { errorIssuing.Inc() logger.Errorf("error requesting credentials: %s", err.Error()) @@ -146,10 +151,26 @@ func (c *credentialsCache) CredentialsForRole(ctx context.Context, identity *Rol return cachedCreds.Credentials, nil } -func (i *RoleIdentity) String() string { - return i.ARN +func (c *credentialsCache) getSessionName(identity *RoleIdentity) string { + sessionName := c.sessionName + + if identity.SessionName != "" { + sessionName = identity.SessionName + } + + sessionName = fmt.Sprintf("kiam-%s", sessionName) + return sanitizeSessionName(sessionName) } -func (i *RoleIdentity) Equals(other *RoleIdentity) bool { - return *i == *other +// Ensure the session name meets length requirements and +// also coerce any character that doesn't meet the pattern +// requirements to a hyphen so that we ensure a valid session name. 
+func sanitizeSessionName(sessionName string) string { + sanitize := regexp.MustCompile(`([^\w+=,.@-])`) + + if len(sessionName) > 64 { + sessionName = sessionName[0:63] + } + + return sanitize.ReplaceAllString(sessionName, "-") } diff --git a/pkg/aws/sts/credentials_cache_test.go b/pkg/aws/sts/credentials_cache_test.go index 78d2bb9d..2a89c7c5 100644 --- a/pkg/aws/sts/credentials_cache_test.go +++ b/pkg/aws/sts/credentials_cache_test.go @@ -15,20 +15,26 @@ package sts import ( "context" - "github.com/prometheus/client_golang/prometheus/testutil" "testing" "time" + + "github.com/prometheus/client_golang/prometheus/testutil" ) type stubGateway struct { - c *Credentials - issueCount int - requestedRole string + c *Credentials + issueCount int + requestedRole string + requestedSessionName string + requestedExternalID string } -func (s *stubGateway) Issue(ctx context.Context, roleARN, sessionName string, expiry time.Duration) (*Credentials, error) { +func (s *stubGateway) Issue(ctx context.Context, request *STSIssueRequest) (*Credentials, error) { s.issueCount = s.issueCount + 1 - s.requestedRole = roleARN + s.requestedRole = request.RoleARN + s.requestedSessionName = request.SessionName + s.requestedExternalID = request.ExternalID + return s.c, nil } @@ -37,7 +43,7 @@ func TestRequestsCredentialsFromGatewayWithEmptyCache(t *testing.T) { cache := DefaultCache(stubGateway, "session", 15*time.Minute, 5*time.Minute) ctx := context.Background() - credentialsIdentity := &RoleIdentity{Role: "role", ARN: "arn:account:role"} + credentialsIdentity := &RoleIdentity{Role: ResolvedRole{Name: "role", ARN: "arn:account:role"}} creds, _ := cache.CredentialsForRole(ctx, credentialsIdentity) if creds.Code != "foo" { t.Error("didnt return expected credentials code, was", creds.Code) @@ -55,3 +61,51 @@ func TestRequestsCredentialsFromGatewayWithEmptyCache(t *testing.T) { t.Error("unexpected role, was:", stubGateway.requestedRole) } } + +func TestRequestsCredentialsWithSessionName(t 
*testing.T) { + var tests = []struct { + name string + sessionName string + expectedSessionName string + }{ + {"Default", "testing", "kiam-testing"}, + {"InvalidCharsReplacedWithHyphen", "testing@#&-test%", "kiam-testing@---test-"}, + {"LongNameLimitedTo64Chars", "Unsplvku4rP9A71Zb5DUQtKviVKSENh0GlKxVRPXGvfDyXXXy8OGqTVfc05DCAhKT9oHXU", "kiam-Unsplvku4rP9A71Zb5DUQtKviVKSENh0GlKxVRPXGvfDyXXXy8OGqTVfc0"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stubGateway := &stubGateway{c: &Credentials{Code: "foo"}} + cache := DefaultCache(stubGateway, "session", 15*time.Minute, 5*time.Minute) + ctx := context.Background() + + credentialsIdentity := &RoleIdentity{ + Role: ResolvedRole{Name: "role", ARN: "arn:account:role"}, + SessionName: tt.sessionName, + } + + _, _ = cache.CredentialsForRole(ctx, credentialsIdentity) + + if stubGateway.requestedSessionName != tt.expectedSessionName { + t.Error("unexpected session-name, was:", stubGateway.requestedSessionName) + } + }) + } +} + +func TestRequestsCredentialsWithExternalID(t *testing.T) { + stubGateway := &stubGateway{c: &Credentials{Code: "foo"}} + cache := DefaultCache(stubGateway, "session", 15*time.Minute, 5*time.Minute) + ctx := context.Background() + + credentialsIdentity := &RoleIdentity{ + Role: ResolvedRole{Name: "role", ARN: "arn:account:role"}, + ExternalID: "123456", + } + + _, _ = cache.CredentialsForRole(ctx, credentialsIdentity) + + if stubGateway.requestedExternalID != "123456" { + t.Error("unexpected external-id, was:", stubGateway.requestedExternalID) + } +} diff --git a/pkg/aws/sts/gateway.go b/pkg/aws/sts/gateway.go index 892fb302..e9bd3381 100644 --- a/pkg/aws/sts/gateway.go +++ b/pkg/aws/sts/gateway.go @@ -23,8 +23,15 @@ import ( "github.com/prometheus/client_golang/prometheus" ) +type STSIssueRequest struct { + RoleARN string + SessionName string + ExternalID string + SessionDuration time.Duration +} + type STSGateway interface { - Issue(ctx context.Context, role, session 
string, expiry time.Duration) (*Credentials, error) + Issue(ctx context.Context, request *STSIssueRequest) (*Credentials, error) } type DefaultSTSGateway struct { @@ -35,7 +42,7 @@ func DefaultGateway(config *aws.Config) (*DefaultSTSGateway, error) { return &DefaultSTSGateway{session: session.Must(session.NewSession(config))}, nil } -func (g *DefaultSTSGateway) Issue(ctx context.Context, roleARN, sessionName string, expiry time.Duration) (*Credentials, error) { +func (g *DefaultSTSGateway) Issue(ctx context.Context, request *STSIssueRequest) (*Credentials, error) { timer := prometheus.NewTimer(assumeRole) defer timer.ObserveDuration() @@ -44,10 +51,15 @@ func (g *DefaultSTSGateway) Issue(ctx context.Context, roleARN, sessionName stri svc := sts.New(g.session) in := &sts.AssumeRoleInput{ - DurationSeconds: aws.Int64(int64(expiry.Seconds())), - RoleArn: aws.String(roleARN), - RoleSessionName: aws.String(sessionName), + DurationSeconds: aws.Int64(int64(request.SessionDuration.Seconds())), + RoleArn: aws.String(request.RoleARN), + RoleSessionName: aws.String(request.SessionName), } + + if request.ExternalID != "" { + in.ExternalId = aws.String(request.ExternalID) + } + resp, err := svc.AssumeRoleWithContext(ctx, in) if err != nil { return nil, err diff --git a/pkg/aws/sts/interfaces.go b/pkg/aws/sts/interfaces.go index e9377ebd..51251d04 100644 --- a/pkg/aws/sts/interfaces.go +++ b/pkg/aws/sts/interfaces.go @@ -28,5 +28,5 @@ type CredentialsCache interface { // ARNResolver encapsulates resolution of roles into ARNs. 
type ARNResolver interface { - Resolve(role string) (*RoleIdentity, error) + Resolve(role string) (*ResolvedRole, error) } diff --git a/pkg/aws/sts/role_identity.go b/pkg/aws/sts/role_identity.go new file mode 100644 index 00000000..8e21ed31 --- /dev/null +++ b/pkg/aws/sts/role_identity.go @@ -0,0 +1,37 @@ +package sts + +import ( + "fmt" + + log "github.com/sirupsen/logrus" +) + +type RoleIdentity struct { + Role ResolvedRole + SessionName string + ExternalID string +} + +func NewRoleIdentity(arnResolver ARNResolver, role, sessionName, externalID string) (*RoleIdentity, error) { + resolvedRole, err := arnResolver.Resolve(role) + if err != nil { + return nil, err + } + + return &RoleIdentity{ + Role: *resolvedRole, + SessionName: sessionName, + ExternalID: externalID, + }, nil +} + +func (i *RoleIdentity) String() string { + return fmt.Sprintf("%s|%s|%s", i.Role.ARN, i.SessionName, i.ExternalID) +} + +func (i *RoleIdentity) LogFields() log.Fields { + return log.Fields{ + "pod.iam.role": i.Role, + "pod.iam.roleArn": i.Role.ARN, + } +} diff --git a/pkg/k8s/interfaces.go b/pkg/k8s/interfaces.go index 343db9d7..05415077 100644 --- a/pkg/k8s/interfaces.go +++ b/pkg/k8s/interfaces.go @@ -15,7 +15,9 @@ package k8s import ( "context" - "k8s.io/api/core/v1" + + "github.com/uswitch/kiam/pkg/aws/sts" + v1 "k8s.io/api/core/v1" ) type PodGetter interface { @@ -26,7 +28,7 @@ type PodAnnouncer interface { // Will receive a Pod whenever there's a change/addition for a Pod with a role. 
Pods() <-chan *v1.Pod // Return whether there are still uncompleted pods in the specified role - IsActivePodsForRole(role string) (bool, error) + IsActivePodsForRole(identity *sts.RoleIdentity) (bool, error) } type NamespaceFinder interface { diff --git a/pkg/k8s/pod_cache.go b/pkg/k8s/pod_cache.go index c94f1e26..7bd883a0 100644 --- a/pkg/k8s/pod_cache.go +++ b/pkg/k8s/pod_cache.go @@ -19,7 +19,8 @@ import ( "time" log "github.com/sirupsen/logrus" - "k8s.io/api/core/v1" + "github.com/uswitch/kiam/pkg/aws/sts" + v1 "k8s.io/api/core/v1" "k8s.io/client-go/tools/cache" ) @@ -34,10 +35,10 @@ type PodCache struct { // IP address so that Kiam can identify which role a Pod should assume. It periodically syncs the list of // pods and can announce Pods. When announcing Pods via the channel it will drop events if the buffer // is full- bufferSize determines how many. -func NewPodCache(source cache.ListerWatcher, syncInterval time.Duration, bufferSize int) *PodCache { +func NewPodCache(arnResolver sts.ARNResolver, source cache.ListerWatcher, syncInterval time.Duration, bufferSize int) *PodCache { indexers := cache.Indexers{ - indexPodIP: podIPIndex, - indexPodRole: podRoleIndex, + indexPodIP: podIPIndex, + indexPodRoleIdentity: podRoleIdentityIndex(arnResolver), } pods := make(chan *v1.Pod, bufferSize) podHandler := &podHandler{pods} @@ -70,8 +71,8 @@ func (s *PodCache) Pods() <-chan *v1.Pod { // using the provided role. This is used to identify whether the // role credentials should be maintained. 
Part of the PodAnnouncer // interface -func (s *PodCache) IsActivePodsForRole(role string) (bool, error) { - items, err := s.indexer.ByIndex(indexPodRole, role) +func (s *PodCache) IsActivePodsForRole(identity *sts.RoleIdentity) (bool, error) { + items, err := s.indexer.ByIndex(indexPodRoleIdentity, identity.String()) if err != nil { return false, err } @@ -138,8 +139,8 @@ func (s *PodCache) GetPodByIP(ip string) (*v1.Pod, error) { } const ( - indexPodIP = "byIP" - indexPodRole = "byRole" + indexPodIP = "byIP" + indexPodRoleIdentity = "byRoleIdentity" ) func podIPIndex(obj interface{}) ([]string, error) { @@ -152,14 +153,23 @@ func podIPIndex(obj interface{}) ([]string, error) { return []string{pod.Status.PodIP}, nil } -func podRoleIndex(obj interface{}) ([]string, error) { - pod := obj.(*v1.Pod) - role := PodRole(pod) - if role == "" { - return []string{}, nil - } +func podRoleIdentityIndex(arnResolver sts.ARNResolver) func(obj interface{}) ([]string, error) { + return func(obj interface{}) ([]string, error) { + pod := obj.(*v1.Pod) + role := PodRole(pod) + if role == "" { + return []string{}, nil + } + + sessionName := PodSessionName(pod) + externalID := PodExternalID(pod) + identity, err := sts.NewRoleIdentity(arnResolver, role, sessionName, externalID) + if err != nil { + return nil, err + } - return []string{role}, nil + return []string{identity.String()}, nil + } } // Run starts the controller processing updates. 
Blocks until the cache has synced @@ -180,9 +190,25 @@ func PodRole(pod *v1.Pod) string { return pod.ObjectMeta.Annotations[AnnotationIAMRoleKey] } +// PodSessionName returns the IAM role session-name specified in the annotation for the Pod +func PodSessionName(pod *v1.Pod) string { + return pod.ObjectMeta.Annotations[AnnotationIAMSessionNameKey] +} + +// PodExternalID returns the IAM role external-id specified in the annotation for the Pod +func PodExternalID(pod *v1.Pod) string { + return pod.ObjectMeta.Annotations[AnnotationIAMExternalIDKey] +} + // AnnotationIAMRoleKey is the key for the annotation specifying the IAM Role const AnnotationIAMRoleKey = "iam.amazonaws.com/role" +// AnnotationIAMSessionNameKey is the key for the annotation specifying the session-name +const AnnotationIAMSessionNameKey = "iam.amazonaws.com/session-name" + +// AnnotationIAMExternalIDKey is the key for the annotation specifying the external-id +const AnnotationIAMExternalIDKey = "iam.amazonaws.com/external-id" + type podHandler struct { pods chan<- *v1.Pod } diff --git a/pkg/k8s/pod_cache_test.go b/pkg/k8s/pod_cache_test.go index 01c912d6..8f3ea24b 100644 --- a/pkg/k8s/pod_cache_test.go +++ b/pkg/k8s/pod_cache_test.go @@ -16,11 +16,13 @@ package k8s import ( "context" "fmt" + "testing" + "time" + "github.com/fortytw2/leaktest" + "github.com/uswitch/kiam/pkg/aws/sts" "github.com/uswitch/kiam/pkg/testutil" kt "k8s.io/client-go/tools/cache/testing" - "testing" - "time" ) const bufferSize = 10 @@ -32,7 +34,8 @@ func TestFindsRunningPod(t *testing.T) { defer cancel() source := kt.NewFakeControllerSource() - c := NewPodCache(source, time.Second, bufferSize) + arnResolver := sts.DefaultResolver("arn:account:") + c := NewPodCache(arnResolver, source, time.Second, bufferSize) source.Add(testutil.NewPodWithRole("ns", "name", "192.168.0.1", "Failed", "failed_role")) source.Add(testutil.NewPodWithRole("ns", "name", "192.168.0.1", "Running", "running_role")) c.Run(ctx) @@ -54,24 +57,81 @@ func 
TestFindRoleActive(t *testing.T) { defer cancel() source := kt.NewFakeControllerSource() - c := NewPodCache(source, time.Second, bufferSize) + arnResolver := sts.DefaultResolver("arn:account:") + c := NewPodCache(arnResolver, source, time.Second, bufferSize) source.Add(testutil.NewPodWithRole("ns", "name", "192.168.0.1", "Failed", "failed_role")) source.Modify(testutil.NewPodWithRole("ns", "name", "192.168.0.1", "Failed", "running_role")) source.Modify(testutil.NewPodWithRole("ns", "name", "192.168.0.1", "Running", "running_role")) c.Run(ctx) defer source.Shutdown() - active, _ := c.IsActivePodsForRole("failed_role") + identity, _ := sts.NewRoleIdentity(arnResolver, "failed_role", "", "") + active, _ := c.IsActivePodsForRole(identity) if active { t.Error("expected no active pods in failed_role") } - active, _ = c.IsActivePodsForRole("running_role") + identity, _ = sts.NewRoleIdentity(arnResolver, "running_role", "", "") + active, _ = c.IsActivePodsForRole(identity) if !active { t.Error("expected running pod") } } +func TestFindRoleActiveWithSessionName(t *testing.T) { + defer leaktest.Check(t)() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + source := kt.NewFakeControllerSource() + arnResolver := sts.DefaultResolver("arn:account:") + c := NewPodCache(arnResolver, source, time.Second, bufferSize) + source.Add(testutil.NewPodWithSessionName("ns", "active-reader", "192.168.0.1", "Running", "reader", "active-reader")) + source.Add(testutil.NewPodWithSessionName("ns", "stopped-reader", "192.168.0.2", "Succeeded", "reader", "stopped-reader")) + c.Run(ctx) + defer source.Shutdown() + + identity, _ := sts.NewRoleIdentity(arnResolver, "reader", "active-reader", "") + active, _ := c.IsActivePodsForRole(identity) + if !active { + t.Error("expected running pod for active-reader") + } + + identity, _ = sts.NewRoleIdentity(arnResolver, "reader", "stopped-reader", "") + active, _ = c.IsActivePodsForRole(identity) + if active { + t.Error("expected 
no active pods for stopped-reader") + } +} + +func TestFindRoleActiveWithExternalID(t *testing.T) { + defer leaktest.Check(t)() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + source := kt.NewFakeControllerSource() + arnResolver := sts.DefaultResolver("arn:account:") + c := NewPodCache(arnResolver, source, time.Second, bufferSize) + source.Add(testutil.NewPodWithExternalID("ns", "active-reader", "192.168.0.1", "Running", "reader", "1234")) + source.Add(testutil.NewPodWithExternalID("ns", "stopped-reader", "192.168.0.2", "Succeeded", "reader", "4321")) + c.Run(ctx) + defer source.Shutdown() + + identity, _ := sts.NewRoleIdentity(arnResolver, "reader", "", "1234") + active, _ := c.IsActivePodsForRole(identity) + if !active { + t.Error("expected running pod for active-reader") + } + + identity, _ = sts.NewRoleIdentity(arnResolver, "reader", "", "4321") + active, _ = c.IsActivePodsForRole(identity) + if active { + t.Error("expected no active pods for stopped-reader") + } +} + func BenchmarkFindPodsByIP(b *testing.B) { b.StopTimer() @@ -79,7 +139,8 @@ func BenchmarkFindPodsByIP(b *testing.B) { defer cancel() source := kt.NewFakeControllerSource() - c := NewPodCache(source, time.Second, bufferSize) + arnResolver := sts.DefaultResolver("arn:account:") + c := NewPodCache(arnResolver, source, time.Second, bufferSize) for i := 0; i < 1000; i++ { source.Add(testutil.NewPodWithRole("ns", fmt.Sprintf("name-%d", i), fmt.Sprintf("ip-%d", i), "Running", "foo_role")) } @@ -106,12 +167,14 @@ func BenchmarkIsActiveRole(b *testing.B) { role := i % 100 source.Add(testutil.NewPodWithRole("ns", fmt.Sprintf("name-%d", i), fmt.Sprintf("ip-%d", i), "Running", fmt.Sprintf("role-%d", role))) } - c := NewPodCache(source, time.Second, bufferSize) + arnResolver := sts.DefaultResolver("arn:account:") + c := NewPodCache(arnResolver, source, time.Second, bufferSize) c.Run(ctx) b.StartTimer() for n := 0; n < b.N; n++ { - c.IsActivePodsForRole("role-0") + identity, _ 
:= sts.NewRoleIdentity(arnResolver, "role-0", "", "") + c.IsActivePodsForRole(identity) } } diff --git a/pkg/k8s/testing/stub_finder.go b/pkg/k8s/testing/stub_finder.go index 1e741b41..b0b0f1ec 100644 --- a/pkg/k8s/testing/stub_finder.go +++ b/pkg/k8s/testing/stub_finder.go @@ -15,8 +15,10 @@ package testing import ( "context" + + "github.com/uswitch/kiam/pkg/aws/sts" "github.com/uswitch/kiam/pkg/k8s" - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" ) func NewStubFinder(pod *v1.Pod) *StubFinder { @@ -51,7 +53,7 @@ func (f *stubAnnouncer) Pods() <-chan *v1.Pod { return f.pods } -func (f *stubAnnouncer) IsActivePodsForRole(role string) (bool, error) { +func (f *stubAnnouncer) IsActivePodsForRole(identity *sts.RoleIdentity) (bool, error) { return true, nil } diff --git a/pkg/prefetch/manager.go b/pkg/prefetch/manager.go index 05caa590..2ba6cd65 100644 --- a/pkg/prefetch/manager.go +++ b/pkg/prefetch/manager.go @@ -43,10 +43,15 @@ func (m *CredentialManager) fetchCredentials(ctx context.Context, pod *v1.Pod) { } role := k8s.PodRole(pod) - identity, err := m.arnResolver.Resolve(role) + sessionName := k8s.PodSessionName(pod) + externalID := k8s.PodExternalID(pod) + + identity, err := sts.NewRoleIdentity(m.arnResolver, role, sessionName, externalID) if err != nil { + logger.Errorf("error creating role identity: %s", err.Error()) return } + issued, err := m.fetchCredentialsFromCache(ctx, identity) if err != nil { logger.Errorf("error warming credentials: %s", err.Error()) @@ -81,7 +86,7 @@ func (m *CredentialManager) Run(ctx context.Context, parallelRoutines int) { func (m *CredentialManager) handleExpiring(ctx context.Context, credentials *sts.CachedCredentials) { logger := log.WithFields(sts.CredentialsFields(credentials.Identity, credentials.Credentials)) - active, err := m.IsRoleActive(credentials.Identity.Role) + active, err := m.IsRoleActive(credentials.Identity) if err != nil { logger.Errorf("error checking whether role active: %s", err.Error()) return @@ -99,6 
+104,6 @@ func (m *CredentialManager) handleExpiring(ctx context.Context, credentials *sts } } -func (m *CredentialManager) IsRoleActive(role string) (bool, error) { - return m.announcer.IsActivePodsForRole(role) +func (m *CredentialManager) IsRoleActive(identity *sts.RoleIdentity) (bool, error) { + return m.announcer.IsActivePodsForRole(identity) } diff --git a/pkg/prefetch/manager_test.go b/pkg/prefetch/manager_test.go index 3624a747..049bfe92 100644 --- a/pkg/prefetch/manager_test.go +++ b/pkg/prefetch/manager_test.go @@ -33,7 +33,7 @@ func TestPrefetchRunningPods(t *testing.T) { requestedRoles := make(chan string) announcer := kt.NewStubAnnouncer() cache := testutil.NewStubCredentialsCache(func(identity *sts.RoleIdentity) (*sts.Credentials, error) { - requestedRoles <- identity.Role + requestedRoles <- identity.Role.Name return &sts.Credentials{}, nil }) manager := NewManager(cache, announcer, sts.DefaultResolver("prefix")) @@ -82,3 +82,49 @@ func TestRenewsCredentialsForRunningPod(t *testing.T) { t.Error("fail, didn't re-request expiring credentials in time") } } + +func TestPodSessionName(t *testing.T) { + defer leaktest.Check(t)() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + requested := make(chan *sts.RoleIdentity) + announcer := kt.NewStubAnnouncer() + cache := testutil.NewStubCredentialsCache(func(identity *sts.RoleIdentity) (*sts.Credentials, error) { + requested <- identity + return &sts.Credentials{}, nil + }) + + manager := NewManager(cache, announcer, sts.DefaultResolver("prefix")) + go manager.Run(ctx, 1) + + announcer.Announce(testutil.NewPodWithSessionName("ns", "name", "ip", "Running", "role", "session-name")) + identity := <-requested + if identity.SessionName != "session-name" { + t.Error("should have requested session-name") + } +} + +func TestPodExternalID(t *testing.T) { + defer leaktest.Check(t)() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + requested := make(chan 
*sts.RoleIdentity) + announcer := kt.NewStubAnnouncer() + cache := testutil.NewStubCredentialsCache(func(identity *sts.RoleIdentity) (*sts.Credentials, error) { + requested <- identity + return &sts.Credentials{}, nil + }) + + manager := NewManager(cache, announcer, sts.DefaultResolver("prefix")) + go manager.Run(ctx, 1) + + announcer.Announce(testutil.NewPodWithExternalID("ns", "name", "ip", "Running", "role", "external-id")) + identity := <-requested + if identity.ExternalID != "external-id" { + t.Error("should have requested external-id") + } +} diff --git a/pkg/server/policy.go b/pkg/server/policy.go index a8593068..3db1b4a9 100644 --- a/pkg/server/policy.go +++ b/pkg/server/policy.go @@ -17,9 +17,10 @@ package server import ( "context" "fmt" - v1 "k8s.io/api/core/v1" "regexp" + v1 "k8s.io/api/core/v1" + "github.com/uswitch/kiam/pkg/aws/sts" "github.com/uswitch/kiam/pkg/k8s" ) @@ -81,7 +82,7 @@ func (p *RequestingAnnotatedRolePolicy) IsAllowedAssumeRole(ctx context.Context, return &allowed{}, nil } - return &forbidden{requested: role, annotated: annotatedIdentiy.Role}, nil + return &forbidden{requested: role, annotated: annotatedIdentiy.Name}, nil } type NamespacePermittedRoleNamePolicy struct { diff --git a/pkg/server/server.go b/pkg/server/server.go index 45f9d273..fe77f6f9 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -105,10 +105,14 @@ func (k *KiamServer) GetPodCredentials(ctx context.Context, req *pb.GetPodCreden return nil, ErrPolicyForbidden } - identity, err := k.arnResolver.Resolve(req.Role) + sessionName := k8s.PodSessionName(pod) + externalID := k8s.PodExternalID(pod) + + identity, err := sts.NewRoleIdentity(k.arnResolver, req.Role, sessionName, externalID) if err != nil { return nil, err } + creds, err := k.credentialsProvider.CredentialsForRole(ctx, identity) if err != nil { logger.Errorf("error retrieving credentials: %s", err.Error()) diff --git a/pkg/server/server_builder.go b/pkg/server/server_builder.go index 
fee187fe..2ddf650b 100644 --- a/pkg/server/server_builder.go +++ b/pkg/server/server_builder.go @@ -18,6 +18,9 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "net" + "time" + grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" log "github.com/sirupsen/logrus" "github.com/uswitch/k8sc/official" @@ -33,8 +36,6 @@ import ( "k8s.io/client-go/kubernetes/scheme" typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" "k8s.io/client-go/tools/record" - "net" - "time" ) // KiamServerBuilder helps construct the KiamServer @@ -109,7 +110,12 @@ func (b *KiamServerBuilder) WithKubernetesClient() (*KiamServerBuilder, error) { return nil, err } - podCache := k8s.NewPodCache(k8s.NewListWatch(client, k8s.ResourcePods), b.config.PodSyncInterval, b.config.PrefetchBufferSize) + arnResolver, err := newRoleARNResolver(b.config) + if err != nil { + return nil, err + } + + podCache := k8s.NewPodCache(arnResolver, k8s.NewListWatch(client, k8s.ResourcePods), b.config.PodSyncInterval, b.config.PrefetchBufferSize) nsCache := k8s.NewNamespaceCache(k8s.NewListWatch(client, k8s.ResourceNamespaces), time.Minute) b.WithCaches(podCache, nsCache) diff --git a/pkg/server/server_integration_test.go b/pkg/server/server_integration_test.go index e2805457..35ede73f 100644 --- a/pkg/server/server_integration_test.go +++ b/pkg/server/server_integration_test.go @@ -16,12 +16,14 @@ package server import ( "context" + "testing" + "time" + "github.com/fortytw2/leaktest" + "github.com/uswitch/kiam/pkg/aws/sts" "github.com/uswitch/kiam/pkg/k8s" "google.golang.org/grpc" kt "k8s.io/client-go/tools/cache/testing" - "testing" - "time" ) const ( @@ -120,7 +122,7 @@ func newTestServer(ctx context.Context) (*KiamServer, *kt.FakeControllerSource, source := kt.NewFakeControllerSource() defer source.Shutdown() - podCache := k8s.NewPodCache(source, time.Second, defaultBuffer) + podCache := k8s.NewPodCache(sts.DefaultResolver("arn:account:"), source, time.Second, defaultBuffer) podCache.Run(ctx) namespaceCache := 
k8s.NewNamespaceCache(source, time.Second) namespaceCache.Run(ctx) diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 49224bd4..7929279b 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -3,10 +3,11 @@ package server import ( "context" "fmt" - v1 "k8s.io/api/core/v1" "testing" "time" + v1 "k8s.io/api/core/v1" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/fortytw2/leaktest" "github.com/uswitch/kiam/pkg/aws/sts" @@ -40,7 +41,7 @@ func TestReturnsErrorWhenPodNotFound(t *testing.T) { source := kt.NewFakeControllerSource() defer source.Shutdown() - podCache := k8s.NewPodCache(source, time.Second, defaultBuffer) + podCache := k8s.NewPodCache(sts.DefaultResolver("arn:account:"), source, time.Second, defaultBuffer) server := &KiamServer{pods: podCache} _, err := server.GetPodCredentials(context.Background(), &pb.GetPodCredentialsRequest{}) @@ -60,7 +61,7 @@ func TestReturnsPolicyErrorWhenForbidden(t *testing.T) { defer source.Shutdown() source.Add(testutil.NewPodWithRole("ns", "name", "192.168.0.1", "Running", "running_role")) - podCache := k8s.NewPodCache(source, time.Second, defaultBuffer) + podCache := k8s.NewPodCache(sts.DefaultResolver("arn:account:"), source, time.Second, defaultBuffer) podCache.Run(ctx) server := &KiamServer{pods: podCache, assumePolicy: &forbidPolicy{}, arnResolver: sts.DefaultResolver("prefix")} @@ -80,7 +81,7 @@ func TestReturnsAnnotatedPodRole(t *testing.T) { defer source.Shutdown() source.Add(testutil.NewPodWithRole("ns", "name", "192.168.0.1", "Running", "running_role")) - podCache := k8s.NewPodCache(source, time.Second, defaultBuffer) + podCache := k8s.NewPodCache(sts.DefaultResolver("arn:account:"), source, time.Second, defaultBuffer) podCache.Run(ctx) server := &KiamServer{pods: podCache, assumePolicy: &allowPolicy{}, credentialsProvider: &stubCredentialsProvider{accessKey: "A1234"}} @@ -101,7 +102,7 @@ func TestReturnsErrorFromGetPodRoleWhenPodNotFound(t *testing.T) { defer 
source.Shutdown() source.Add(testutil.NewPodWithRole("ns", "name", "192.168.0.1", "Running", "running_role")) - podCache := k8s.NewPodCache(source, time.Second, defaultBuffer) + podCache := k8s.NewPodCache(sts.DefaultResolver("arn:account:"), source, time.Second, defaultBuffer) podCache.Run(ctx) server := &KiamServer{pods: podCache, assumePolicy: &allowPolicy{}, credentialsProvider: &stubCredentialsProvider{accessKey: "A1234"}} @@ -125,7 +126,7 @@ func TestReturnsCredentials(t *testing.T) { defer source.Shutdown() source.Add(testutil.NewPodWithRole("ns", "name", "192.168.0.1", "Running", roleName)) - podCache := k8s.NewPodCache(source, time.Second, defaultBuffer) + podCache := k8s.NewPodCache(sts.DefaultResolver("arn:account:"), source, time.Second, defaultBuffer) podCache.Run(ctx) server := &KiamServer{pods: podCache, assumePolicy: &allowPolicy{}, credentialsProvider: &stubCredentialsProvider{accessKey: "A1234"}, arnResolver: sts.DefaultResolver("prefix")} @@ -143,11 +144,72 @@ func TestReturnsCredentials(t *testing.T) { } } +func TestGetPodCredentialsWithSessionName(t *testing.T) { + defer leaktest.Check(t)() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const roleName = "role" + const sessionName = "session" + + source := kt.NewFakeControllerSource() + defer source.Shutdown() + source.Add(testutil.NewPodWithSessionName("ns", "name", "192.168.0.1", "Running", roleName, sessionName)) + + credentialsProvider := stubCredentialsProvider{accessKey: "A1234"} + podCache := k8s.NewPodCache(sts.DefaultResolver("arn:account:"), source, time.Second, defaultBuffer) + podCache.Run(ctx) + server := &KiamServer{pods: podCache, assumePolicy: &allowPolicy{}, credentialsProvider: &credentialsProvider, arnResolver: sts.DefaultResolver("prefix")} + + _, err := server.GetPodCredentials(ctx, &pb.GetPodCredentialsRequest{Ip: "192.168.0.1", Role: roleName}) + if err != nil { + t.Error("unexpected error", err) + } + + identity := 
credentialsProvider.requestedIdentity + if identity.SessionName != sessionName { + t.Error("unexpected session-name", identity.SessionName) + } +} + +func TestGetPodCredentialsWithExternalID(t *testing.T) { + defer leaktest.Check(t)() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const roleName = "role" + const externalID = "external-id" + + source := kt.NewFakeControllerSource() + defer source.Shutdown() + source.Add(testutil.NewPodWithExternalID("ns", "name", "192.168.0.1", "Running", roleName, externalID)) + + credentialsProvider := stubCredentialsProvider{accessKey: "A1234"} + podCache := k8s.NewPodCache(sts.DefaultResolver("arn:account:"), source, time.Second, defaultBuffer) + podCache.Run(ctx) + server := &KiamServer{pods: podCache, assumePolicy: &allowPolicy{}, credentialsProvider: &credentialsProvider, arnResolver: sts.DefaultResolver("prefix")} + + _, err := server.GetPodCredentials(ctx, &pb.GetPodCredentialsRequest{Ip: "192.168.0.1", Role: roleName}) + if err != nil { + t.Error("unexpected error", err) + } + + identity := credentialsProvider.requestedIdentity + if identity.ExternalID != externalID { + t.Error("unexpected external-id", identity.ExternalID) + } +} + type stubCredentialsProvider struct { - accessKey string + accessKey string + requestedIdentity *sts.RoleIdentity } func (c *stubCredentialsProvider) CredentialsForRole(ctx context.Context, identity *sts.RoleIdentity) (*sts.Credentials, error) { + c.requestedIdentity = identity + return &sts.Credentials{ AccessKeyId: c.accessKey, }, nil diff --git a/pkg/testutil/kubernetes.go b/pkg/testutil/kubernetes.go index 5a11b0bc..58486231 100644 --- a/pkg/testutil/kubernetes.go +++ b/pkg/testutil/kubernetes.go @@ -17,7 +17,7 @@ import ( "fmt" "time" - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -59,3 +59,15 @@ func NewPodWithRole(namespace, name, ip, phase, role string) *v1.Pod { pod.ObjectMeta.Annotations = 
map[string]string{"iam.amazonaws.com/role": role} return pod } + +func NewPodWithSessionName(namespace, name, ip, phase, role, sessionName string) *v1.Pod { + pod := NewPodWithRole(namespace, name, ip, phase, role) + pod.ObjectMeta.Annotations["iam.amazonaws.com/session-name"] = sessionName + return pod +} + +func NewPodWithExternalID(namespace, name, ip, phase, role, externalID string) *v1.Pod { + pod := NewPodWithRole(namespace, name, ip, phase, role) + pod.ObjectMeta.Annotations["iam.amazonaws.com/external-id"] = externalID + return pod +} diff --git a/proto/service.pb.go b/proto/service.pb.go index 831a0be3..174c8e66 100644 --- a/proto/service.pb.go +++ b/proto/service.pb.go @@ -8,14 +8,15 @@ package kiam import ( context "context" + reflect "reflect" + sync "sync" + proto "github.com/golang/protobuf/proto" grpc "google.golang.org/grpc" codes "google.golang.org/grpc/codes" status "google.golang.org/grpc/status" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" ) const (