From 02e246ac0e3692ccd62d0dba9b7da917b5819063 Mon Sep 17 00:00:00 2001 From: Chris Bradfield Date: Thu, 7 Mar 2019 11:24:44 -0800 Subject: [PATCH 1/3] Implement regional STS endpoint support --- cmd/kiam/server.go | 1 + pkg/aws/sts/gateway.go | 40 +++++++++++++++++++----- pkg/aws/sts/gateway_test.go | 62 +++++++++++++++++++++++++++++++++++++ pkg/server/server.go | 10 +++--- 4 files changed, 102 insertions(+), 11 deletions(-) create mode 100644 pkg/aws/sts/gateway_test.go diff --git a/cmd/kiam/server.go b/cmd/kiam/server.go index 7ce11edd..2f6aa986 100644 --- a/cmd/kiam/server.go +++ b/cmd/kiam/server.go @@ -57,6 +57,7 @@ func (o *serverOptions) bind(parser parser) { parser.Flag("session-duration", "Requested session duration for STS Tokens.").Default("15m").DurationVar(&o.SessionDuration) parser.Flag("session-refresh", "How soon STS Tokens should be refreshed before their expiration.").Default("5m").DurationVar(&o.SessionRefresh) parser.Flag("assume-role-arn", "IAM Role to assume before processing requests").Default("").StringVar(&o.AssumeRoleArn) + parser.Flag("region", "AWS Region to use for STS calls").Default("").StringVar(&o.Region) } func (opts *serverCommand) Run() { diff --git a/pkg/aws/sts/gateway.go b/pkg/aws/sts/gateway.go index ef5d8457..af6fa6be 100644 --- a/pkg/aws/sts/gateway.go +++ b/pkg/aws/sts/gateway.go @@ -15,10 +15,13 @@ package sts import ( "context" + "fmt" + "strings" "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" + "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sts" "github.com/prometheus/client_golang/prometheus" @@ -33,15 +36,38 @@ type DefaultSTSGateway struct { session *session.Session } -func DefaultGateway(assumeRoleArn string) *DefaultSTSGateway { - if assumeRoleArn == "" { - return &DefaultSTSGateway{session: session.Must(session.NewSession())} +func DefaultGateway(assumeRoleArn, region string) *DefaultSTSGateway { + config := aws.NewConfig() + if assumeRoleArn != "" { + config.WithCredentials(stscreds.NewCredentials(session.Must(session.NewSession()), assumeRoleArn)) + } + + if region != "" { + config.WithRegion(region).WithEndpointResolver(endpoints.ResolverFunc(endpointFor)) + } + + session := session.Must(session.NewSession(config)) + return &DefaultSTSGateway{session: session} +} + +func endpointFor(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + var url string + + _, exists := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), region) + defaultResolver := endpoints.DefaultResolver() + + // if the region doesn't exist or it is a fips endpoint, fallback to the default resolver + if !exists || strings.HasSuffix(region, "-fips") { + return defaultResolver.EndpointFor(service, region, opts...) + } + + if strings.HasPrefix(region, "cn-") { + url = fmt.Sprintf("https://sts.%s.amazonaws.com.cn", region) } else { - session := session.Must(session.NewSession(&aws.Config{ - Credentials: stscreds.NewCredentials(session.Must(session.NewSession()), assumeRoleArn), - })) - return &DefaultSTSGateway{session: session} + url = fmt.Sprintf("https://sts.%s.amazonaws.com", region) } + + return endpoints.ResolvedEndpoint{URL: url, SigningRegion: region}, nil } func (g *DefaultSTSGateway) Issue(ctx context.Context, roleARN, sessionName string, expiry time.Duration) (*Credentials, error) { diff --git a/pkg/aws/sts/gateway_test.go b/pkg/aws/sts/gateway_test.go new file mode 100644 index 00000000..64dfba44 --- /dev/null +++ b/pkg/aws/sts/gateway_test.go @@ -0,0 +1,62 @@ +// Copyright 2017 uSwitch +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package sts + +import ( + "testing" + + "github.com/aws/aws-sdk-go/service/sts" +) + +func TestRegionalGateway(t *testing.T) { + gateway := DefaultGateway("", "us-west-2") + + config := gateway.session.ClientConfig(sts.EndpointsID) + + if config.SigningRegion != "us-west-2" { + t.Error("Unexpected region. Region was: ", config.SigningRegion) + } + + if config.Endpoint != "https://sts.us-west-2.amazonaws.com" { + t.Error("Unexpected regional endpoint. Endpoint was: ", config.Endpoint) + } +} + +func TestRegionalGatewayCn(t *testing.T) { + gateway := DefaultGateway("", "cn-north-1") + + config := gateway.session.ClientConfig(sts.EndpointsID) + + if config.SigningRegion != "cn-north-1" { + t.Error("Unexpected region. Region was: ", config.SigningRegion) + } + + if config.Endpoint != "https://sts.cn-north-1.amazonaws.com.cn" { + t.Error("Unexpected regional endpoint. Endpoint was: ", config.Endpoint) + } +} + +func TestRegionalGatewayFips(t *testing.T) { + gateway := DefaultGateway("", "us-east-1-fips") + + config := gateway.session.ClientConfig(sts.EndpointsID) + + if config.SigningRegion != "us-east-1" { + t.Error("Unexpected region. Region was: ", config.SigningRegion) + } + + if config.Endpoint != "https://sts-fips.us-east-1.amazonaws.com" { + t.Error("Unexpected regional endpoint. Endpoint was: ", config.Endpoint) + } +} diff --git a/pkg/server/server.go b/pkg/server/server.go index b4216ee7..dc1bddbe 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -18,6 +18,10 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io/ioutil" + "net" + "time" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/grpc-ecosystem/go-grpc-prometheus" log "github.com/sirupsen/logrus" @@ -29,15 +33,12 @@ import ( pb "github.com/uswitch/kiam/proto" "google.golang.org/grpc" "google.golang.org/grpc/credentials" - "io/ioutil" "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes/scheme" typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" "k8s.io/client-go/tools/record" - "net" - "time" ) // Config controls the setup of the gRPC server @@ -54,6 +55,7 @@ type Config struct { ParallelFetcherProcesses int PrefetchBufferSize int AssumeRoleArn string + Region string } // TLSConfig controls TLS @@ -231,7 +233,7 @@ func NewServer(config *Config) (*KiamServer, error) { server.namespaces = k8s.NewNamespaceCache(k8s.NewListWatch(client, k8s.ResourceNamespaces), time.Minute) server.eventRecorder = eventRecorder(client) - stsGateway := sts.DefaultGateway(config.AssumeRoleArn) + stsGateway := sts.DefaultGateway(config.AssumeRoleArn, config.Region) arnResolver, err := newRoleARNResolver(config) if err != nil { return nil, err From 265f141da9d940f7b1cf0931c0b8d0ced038c770 Mon Sep 17 00:00:00 2001 From: Christopher Bradfield Date: Tue, 12 Mar 2019 21:07:50 -0700 Subject: [PATCH 2/3] Add an example to the description and check the DNS resolution of the lookup. Refactor to cache the result so that we don't validate the region on every request. --- cmd/kiam/server.go | 2 +- pkg/aws/sts/gateway.go | 75 +++++++++++++++++++++++++------------ pkg/aws/sts/gateway_test.go | 32 ++++++++++++++-- pkg/server/server.go | 6 ++- 4 files changed, 87 insertions(+), 28 deletions(-) diff --git a/cmd/kiam/server.go b/cmd/kiam/server.go index 2f6aa986..8e181e5f 100644 --- a/cmd/kiam/server.go +++ b/cmd/kiam/server.go @@ -57,7 +57,7 @@ func (o *serverOptions) bind(parser parser) { parser.Flag("session-duration", "Requested session duration for STS Tokens.").Default("15m").DurationVar(&o.SessionDuration) parser.Flag("session-refresh", "How soon STS Tokens should be refreshed before their expiration.").Default("5m").DurationVar(&o.SessionRefresh) parser.Flag("assume-role-arn", "IAM Role to assume before processing requests").Default("").StringVar(&o.AssumeRoleArn) - parser.Flag("region", "AWS Region to use for STS calls").Default("").StringVar(&o.Region) + parser.Flag("region", "AWS Region to use for regional STS calls (e.g. us-west-2). Defaults to the global endpoint.").Default("").StringVar(&o.Region) } func (opts *serverCommand) Run() { diff --git a/pkg/aws/sts/gateway.go b/pkg/aws/sts/gateway.go index af6fa6be..d5f34fc4 100644 --- a/pkg/aws/sts/gateway.go +++ b/pkg/aws/sts/gateway.go @@ -16,6 +16,7 @@ package sts import ( "context" "fmt" + "net" "strings" "time" @@ -32,42 +33,70 @@ type STSGateway interface { Issue(ctx context.Context, role, session string, expiry time.Duration) (*Credentials, error) } -type DefaultSTSGateway struct { - session *session.Session +type regionalResolver struct { + endpoint endpoints.ResolvedEndpoint } -func DefaultGateway(assumeRoleArn, region string) *DefaultSTSGateway { - config := aws.NewConfig() - if assumeRoleArn != "" { - config.WithCredentials(stscreds.NewCredentials(session.Must(session.NewSession()), assumeRoleArn)) - } - - if region != "" { - config.WithRegion(region).WithEndpointResolver(endpoints.ResolverFunc(endpointFor)) - } - - session := session.Must(session.NewSession(config)) - return &DefaultSTSGateway{session: session} +func (r *regionalResolver) EndpointFor(svc, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + return r.endpoint, nil } -func endpointFor(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { - var url string +func newRegionalResolver(region string) (endpoints.Resolver, error) { + var host string - _, exists := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), region) defaultResolver := endpoints.DefaultResolver() - // if the region doesn't exist or it is a fips endpoint, fallback to the default resolver - if !exists || strings.HasSuffix(region, "-fips") { - return defaultResolver.EndpointFor(service, region, opts...) + // if it is a FIPS region, let the default resolver give us a result. + if strings.HasSuffix(region, "-fips") { + endpoint, err := defaultResolver.EndpointFor("sts", region) + if err != nil { + return nil, err + } + return ®ionalResolver{endpoint}, nil + } + + if _, exists := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), region); !exists { + return nil, fmt.Errorf("Invalid region: %s", region) } if strings.HasPrefix(region, "cn-") { - url = fmt.Sprintf("https://sts.%s.amazonaws.com.cn", region) + host = fmt.Sprintf("sts.%s.amazonaws.com.cn", region) } else { - url = fmt.Sprintf("https://sts.%s.amazonaws.com", region) + host = fmt.Sprintf("sts.%s.amazonaws.com", region) } - return endpoints.ResolvedEndpoint{URL: url, SigningRegion: region}, nil + if _, err := net.LookupHost(host); err != nil { + return nil, fmt.Errorf("Regional STS endpoint does not exist: %s", host) + } + + return ®ionalResolver{endpoints.ResolvedEndpoint{ + URL: fmt.Sprintf("https://%s", host), + SigningRegion: region, + }}, nil +} + +type DefaultSTSGateway struct { + session *session.Session + resolver endpoints.Resolver +} + +func DefaultGateway(assumeRoleArn, region string) (*DefaultSTSGateway, error) { + config := aws.NewConfig() + if assumeRoleArn != "" { + config.WithCredentials(stscreds.NewCredentials(session.Must(session.NewSession()), assumeRoleArn)) + } + + if region != "" { + resolver, err := newRegionalResolver(region) + if err != nil { + return nil, err + } + + config.WithRegion(region).WithEndpointResolver(resolver) + } + + session := session.Must(session.NewSession(config)) + return &DefaultSTSGateway{session: session}, nil } func (g *DefaultSTSGateway) Issue(ctx context.Context, roleARN, sessionName string, expiry time.Duration) (*Credentials, error) { diff --git a/pkg/aws/sts/gateway_test.go b/pkg/aws/sts/gateway_test.go index 64dfba44..b33d3767 100644 --- a/pkg/aws/sts/gateway_test.go +++ b/pkg/aws/sts/gateway_test.go @@ -20,7 +20,10 @@ import ( ) func TestRegionalGateway(t *testing.T) { - gateway := DefaultGateway("", "us-west-2") + gateway, err := DefaultGateway("", "us-west-2") + if err != nil { + t.Error(err) + } config := gateway.session.ClientConfig(sts.EndpointsID) @@ -34,7 +37,10 @@ func TestRegionalGateway(t *testing.T) { } func TestRegionalGatewayCn(t *testing.T) { - gateway := DefaultGateway("", "cn-north-1") + gateway, err := DefaultGateway("", "cn-north-1") + if err != nil { + t.Error(err) + } config := gateway.session.ClientConfig(sts.EndpointsID) @@ -48,7 +54,10 @@ func TestRegionalGatewayCn(t *testing.T) { } func TestRegionalGatewayFips(t *testing.T) { - gateway := DefaultGateway("", "us-east-1-fips") + gateway, err := DefaultGateway("", "us-east-1-fips") + if err != nil { + t.Error(err) + } config := gateway.session.ClientConfig(sts.EndpointsID) @@ -60,3 +69,20 @@ func TestRegionalGatewayFips(t *testing.T) { t.Error("Unexpected regional endpoint. Endpoint was: ", config.Endpoint) } } + +func TestDefaultGlobalGateway(t *testing.T) { + gateway, err := DefaultGateway("", "") + if err != nil { + t.Error(err) + } + + config := gateway.session.ClientConfig(sts.EndpointsID) + + if config.SigningRegion != "us-east-1" { + t.Error("Unexpected region. Region was: ", config.SigningRegion) + } + + if config.Endpoint != "https://sts.amazonaws.com" { + t.Error("Unexpected regional endpoint. Endpoint was: ", config.Endpoint) + } +} diff --git a/pkg/server/server.go b/pkg/server/server.go index dc1bddbe..e96f7b25 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -233,7 +233,11 @@ func NewServer(config *Config) (*KiamServer, error) { server.namespaces = k8s.NewNamespaceCache(k8s.NewListWatch(client, k8s.ResourceNamespaces), time.Minute) server.eventRecorder = eventRecorder(client) - stsGateway := sts.DefaultGateway(config.AssumeRoleArn, config.Region) + stsGateway, err := sts.DefaultGateway(config.AssumeRoleArn, config.Region) + if err != nil { + return nil, err + } + arnResolver, err := newRoleARNResolver(config) if err != nil { return nil, err From 5afd7d367d7fa392278f733608e9d0b0a23e6377 Mon Sep 17 00:00:00 2001 From: Christopher Bradfield Date: Tue, 12 Mar 2019 21:18:48 -0700 Subject: [PATCH 3/3] Update the README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 01d4e2c1..b26adfa2 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ We have a [#kiam Slack channel](https://kubernetes.slack.com/messages/CBQLKVABH/ * [Prometheus and StatsD metrics](docs/METRICS.md) * Uses the Kubernetes Events API to record IAM errors against the Pod so that cluster users can more readily diagnose IAM problems (via `kubectl describe pod ...`) * Text and JSON log formats +* Optional regional STS endpoint support ## Overview From the [AWS documentation on IAM roles](http://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles.html):