Skip to content
This repository has been archived by the owner on Mar 5, 2024. It is now read-only.

Commit

Permalink
Implement regional STS endpoint support (#229)
Browse files Browse the repository at this point in the history
* Implement regional STS endpoint support
* Add an example to the description and check the DNS resolution of the lookup.
* Update the README
  • Loading branch information
cjbradfield authored and pingles committed Mar 13, 2019
1 parent 179d14f commit adf03bb
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 13 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ We have a [#kiam Slack channel](https://kubernetes.slack.com/messages/CBQLKVABH/
* [Prometheus and StatsD metrics](docs/METRICS.md)
* Uses the Kubernetes Events API to record IAM errors against the Pod so that cluster users can more readily diagnose IAM problems (via `kubectl describe pod ...`)
* Text and JSON log formats
* Optional regional STS endpoint support

## Overview
From the [AWS documentation on IAM roles](http://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles.html):
Expand Down
1 change: 1 addition & 0 deletions cmd/kiam/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func (o *serverOptions) bind(parser parser) {
parser.Flag("session-duration", "Requested session duration for STS Tokens.").Default("15m").DurationVar(&o.SessionDuration)
parser.Flag("session-refresh", "How soon STS Tokens should be refreshed before their expiration.").Default("5m").DurationVar(&o.SessionRefresh)
parser.Flag("assume-role-arn", "IAM Role to assume before processing requests").Default("").StringVar(&o.AssumeRoleArn)
parser.Flag("region", "AWS Region to use for regional STS calls (e.g. us-west-2). Defaults to the global endpoint.").Default("").StringVar(&o.Region)
}

func (opts *serverCommand) Run() {
Expand Down
73 changes: 64 additions & 9 deletions pkg/aws/sts/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@ package sts

import (
"context"
"fmt"
"net"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/prometheus/client_golang/prometheus"
Expand All @@ -29,19 +33,70 @@ type STSGateway interface {
Issue(ctx context.Context, role, session string, expiry time.Duration) (*Credentials, error)
}

type DefaultSTSGateway struct {
session *session.Session
type regionalResolver struct {
endpoint endpoints.ResolvedEndpoint
}

func (r *regionalResolver) EndpointFor(svc, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
return r.endpoint, nil
}

func DefaultGateway(assumeRoleArn string) *DefaultSTSGateway {
if assumeRoleArn == "" {
return &DefaultSTSGateway{session: session.Must(session.NewSession())}
func newRegionalResolver(region string) (endpoints.Resolver, error) {
var host string

defaultResolver := endpoints.DefaultResolver()

// if it is a FIPS region, let the default resolver give us a result.
if strings.HasSuffix(region, "-fips") {
endpoint, err := defaultResolver.EndpointFor("sts", region)
if err != nil {
return nil, err
}
return &regionalResolver{endpoint}, nil
}

if _, exists := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), region); !exists {
return nil, fmt.Errorf("Invalid region: %s", region)
}

if strings.HasPrefix(region, "cn-") {
host = fmt.Sprintf("sts.%s.amazonaws.com.cn", region)
} else {
session := session.Must(session.NewSession(&aws.Config{
Credentials: stscreds.NewCredentials(session.Must(session.NewSession()), assumeRoleArn),
}))
return &DefaultSTSGateway{session: session}
host = fmt.Sprintf("sts.%s.amazonaws.com", region)
}

if _, err := net.LookupHost(host); err != nil {
return nil, fmt.Errorf("Regional STS endpoint does not exist: %s", host)
}

return &regionalResolver{endpoints.ResolvedEndpoint{
URL: fmt.Sprintf("https://%s", host),
SigningRegion: region,
}}, nil
}

type DefaultSTSGateway struct {
session *session.Session
resolver endpoints.Resolver
}

func DefaultGateway(assumeRoleArn, region string) (*DefaultSTSGateway, error) {
config := aws.NewConfig()
if assumeRoleArn != "" {
config.WithCredentials(stscreds.NewCredentials(session.Must(session.NewSession()), assumeRoleArn))
}

if region != "" {
resolver, err := newRegionalResolver(region)
if err != nil {
return nil, err
}

config.WithRegion(region).WithEndpointResolver(resolver)
}

session := session.Must(session.NewSession(config))
return &DefaultSTSGateway{session: session}, nil
}

func (g *DefaultSTSGateway) Issue(ctx context.Context, roleARN, sessionName string, expiry time.Duration) (*Credentials, error) {
Expand Down
88 changes: 88 additions & 0 deletions pkg/aws/sts/gateway_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright 2017 uSwitch
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sts

import (
"testing"

"github.com/aws/aws-sdk-go/service/sts"
)

func TestRegionalGateway(t *testing.T) {
gateway, err := DefaultGateway("", "us-west-2")
if err != nil {
t.Error(err)
}

config := gateway.session.ClientConfig(sts.EndpointsID)

if config.SigningRegion != "us-west-2" {
t.Error("Unexpected region. Region was: ", config.SigningRegion)
}

if config.Endpoint != "https://sts.us-west-2.amazonaws.com" {
t.Error("Unexpected regional endpoint. Endpoint was: ", config.Endpoint)
}
}

func TestRegionalGatewayCn(t *testing.T) {
gateway, err := DefaultGateway("", "cn-north-1")
if err != nil {
t.Error(err)
}

config := gateway.session.ClientConfig(sts.EndpointsID)

if config.SigningRegion != "cn-north-1" {
t.Error("Unexpected region. Region was: ", config.SigningRegion)
}

if config.Endpoint != "https://sts.cn-north-1.amazonaws.com.cn" {
t.Error("Unexpected regional endpoint. Endpoint was: ", config.Endpoint)
}
}

func TestRegionalGatewayFips(t *testing.T) {
gateway, err := DefaultGateway("", "us-east-1-fips")
if err != nil {
t.Error(err)
}

config := gateway.session.ClientConfig(sts.EndpointsID)

if config.SigningRegion != "us-east-1" {
t.Error("Unexpected region. Region was: ", config.SigningRegion)
}

if config.Endpoint != "https://sts-fips.us-east-1.amazonaws.com" {
t.Error("Unexpected regional endpoint. Endpoint was: ", config.Endpoint)
}
}

func TestDefaultGlobalGateway(t *testing.T) {
gateway, err := DefaultGateway("", "")
if err != nil {
t.Error(err)
}

config := gateway.session.ClientConfig(sts.EndpointsID)

if config.SigningRegion != "us-east-1" {
t.Error("Unexpected region. Region was: ", config.SigningRegion)
}

if config.Endpoint != "https://sts.amazonaws.com" {
t.Error("Unexpected regional endpoint. Endpoint was: ", config.Endpoint)
}
}
14 changes: 10 additions & 4 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"time"

"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/grpc-ecosystem/go-grpc-prometheus"
log "github.com/sirupsen/logrus"
Expand All @@ -29,15 +33,12 @@ import (
pb "github.com/uswitch/kiam/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"io/ioutil"
"k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/kubernetes/scheme"
typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1"
"k8s.io/client-go/tools/record"
"net"
"time"
)

// Config controls the setup of the gRPC server
Expand All @@ -54,6 +55,7 @@ type Config struct {
ParallelFetcherProcesses int
PrefetchBufferSize int
AssumeRoleArn string
Region string
}

// TLSConfig controls TLS
Expand Down Expand Up @@ -231,7 +233,11 @@ func NewServer(config *Config) (*KiamServer, error) {
server.namespaces = k8s.NewNamespaceCache(k8s.NewListWatch(client, k8s.ResourceNamespaces), time.Minute)
server.eventRecorder = eventRecorder(client)

stsGateway := sts.DefaultGateway(config.AssumeRoleArn)
stsGateway, err := sts.DefaultGateway(config.AssumeRoleArn, config.Region)
if err != nil {
return nil, err
}

arnResolver, err := newRoleARNResolver(config)
if err != nil {
return nil, err
Expand Down

0 comments on commit adf03bb

Please sign in to comment.