Skip to content
This repository has been archived by the owner on Mar 5, 2024. It is now read-only.

Implement regional STS endpoint support #229

Merged
merged 3 commits into from
Mar 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ We have a [#kiam Slack channel](https://kubernetes.slack.com/messages/CBQLKVABH/
* [Prometheus and StatsD metrics](docs/METRICS.md)
* Uses the Kubernetes Events API to record IAM errors against the Pod so that cluster users can more readily diagnose IAM problems (via `kubectl describe pod ...`)
* Text and JSON log formats
* Optional regional STS endpoint support

## Overview
From the [AWS documentation on IAM roles](http://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles.html):
Expand Down
1 change: 1 addition & 0 deletions cmd/kiam/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func (o *serverOptions) bind(parser parser) {
parser.Flag("session-duration", "Requested session duration for STS Tokens.").Default("15m").DurationVar(&o.SessionDuration)
parser.Flag("session-refresh", "How soon STS Tokens should be refreshed before their expiration.").Default("5m").DurationVar(&o.SessionRefresh)
parser.Flag("assume-role-arn", "IAM Role to assume before processing requests").Default("").StringVar(&o.AssumeRoleArn)
parser.Flag("region", "AWS Region to use for regional STS calls (e.g. us-west-2). Defaults to the global endpoint.").Default("").StringVar(&o.Region)
}

func (opts *serverCommand) Run() {
Expand Down
73 changes: 64 additions & 9 deletions pkg/aws/sts/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@ package sts

import (
"context"
"fmt"
"net"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/prometheus/client_golang/prometheus"
Expand All @@ -29,19 +33,70 @@ type STSGateway interface {
Issue(ctx context.Context, role, session string, expiry time.Duration) (*Credentials, error)
}

type DefaultSTSGateway struct {
session *session.Session
type regionalResolver struct {
endpoint endpoints.ResolvedEndpoint
}

func (r *regionalResolver) EndpointFor(svc, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
return r.endpoint, nil
}

func DefaultGateway(assumeRoleArn string) *DefaultSTSGateway {
if assumeRoleArn == "" {
return &DefaultSTSGateway{session: session.Must(session.NewSession())}
func newRegionalResolver(region string) (endpoints.Resolver, error) {
var host string

defaultResolver := endpoints.DefaultResolver()

// if it is a FIPS region, let the default resolver give us a result.
if strings.HasSuffix(region, "-fips") {
endpoint, err := defaultResolver.EndpointFor("sts", region)
if err != nil {
return nil, err
}
return &regionalResolver{endpoint}, nil
}

if _, exists := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), region); !exists {
return nil, fmt.Errorf("Invalid region: %s", region)
}

if strings.HasPrefix(region, "cn-") {
host = fmt.Sprintf("sts.%s.amazonaws.com.cn", region)
} else {
session := session.Must(session.NewSession(&aws.Config{
Credentials: stscreds.NewCredentials(session.Must(session.NewSession()), assumeRoleArn),
}))
return &DefaultSTSGateway{session: session}
host = fmt.Sprintf("sts.%s.amazonaws.com", region)
}

if _, err := net.LookupHost(host); err != nil {
return nil, fmt.Errorf("Regional STS endpoint does not exist: %s", host)
}

return &regionalResolver{endpoints.ResolvedEndpoint{
URL: fmt.Sprintf("https://%s", host),
SigningRegion: region,
}}, nil
}

type DefaultSTSGateway struct {
session *session.Session
resolver endpoints.Resolver
}

func DefaultGateway(assumeRoleArn, region string) (*DefaultSTSGateway, error) {
config := aws.NewConfig()
if assumeRoleArn != "" {
config.WithCredentials(stscreds.NewCredentials(session.Must(session.NewSession()), assumeRoleArn))
}

if region != "" {
resolver, err := newRegionalResolver(region)
if err != nil {
return nil, err
}

config.WithRegion(region).WithEndpointResolver(resolver)
}

session := session.Must(session.NewSession(config))
return &DefaultSTSGateway{session: session}, nil
}

func (g *DefaultSTSGateway) Issue(ctx context.Context, roleARN, sessionName string, expiry time.Duration) (*Credentials, error) {
Expand Down
88 changes: 88 additions & 0 deletions pkg/aws/sts/gateway_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright 2017 uSwitch
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sts

import (
"testing"

"github.com/aws/aws-sdk-go/service/sts"
)

func TestRegionalGateway(t *testing.T) {
gateway, err := DefaultGateway("", "us-west-2")
if err != nil {
t.Error(err)
}

config := gateway.session.ClientConfig(sts.EndpointsID)

if config.SigningRegion != "us-west-2" {
t.Error("Unexpected region. Region was: ", config.SigningRegion)
}

if config.Endpoint != "https://sts.us-west-2.amazonaws.com" {
t.Error("Unexpected regional endpoint. Endpoint was: ", config.Endpoint)
}
}

func TestRegionalGatewayCn(t *testing.T) {
gateway, err := DefaultGateway("", "cn-north-1")
if err != nil {
t.Error(err)
}

config := gateway.session.ClientConfig(sts.EndpointsID)

if config.SigningRegion != "cn-north-1" {
t.Error("Unexpected region. Region was: ", config.SigningRegion)
}

if config.Endpoint != "https://sts.cn-north-1.amazonaws.com.cn" {
t.Error("Unexpected regional endpoint. Endpoint was: ", config.Endpoint)
}
}

func TestRegionalGatewayFips(t *testing.T) {
gateway, err := DefaultGateway("", "us-east-1-fips")
if err != nil {
t.Error(err)
}

config := gateway.session.ClientConfig(sts.EndpointsID)

if config.SigningRegion != "us-east-1" {
t.Error("Unexpected region. Region was: ", config.SigningRegion)
}

if config.Endpoint != "https://sts-fips.us-east-1.amazonaws.com" {
t.Error("Unexpected regional endpoint. Endpoint was: ", config.Endpoint)
}
}

func TestDefaultGlobalGateway(t *testing.T) {
gateway, err := DefaultGateway("", "")
if err != nil {
t.Error(err)
}

config := gateway.session.ClientConfig(sts.EndpointsID)

if config.SigningRegion != "us-east-1" {
t.Error("Unexpected region. Region was: ", config.SigningRegion)
}

if config.Endpoint != "https://sts.amazonaws.com" {
t.Error("Unexpected regional endpoint. Endpoint was: ", config.Endpoint)
}
}
14 changes: 10 additions & 4 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"time"

"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/grpc-ecosystem/go-grpc-prometheus"
log "github.com/sirupsen/logrus"
Expand All @@ -29,15 +33,12 @@ import (
pb "github.com/uswitch/kiam/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"io/ioutil"
"k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/kubernetes/scheme"
typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1"
"k8s.io/client-go/tools/record"
"net"
"time"
)

// Config controls the setup of the gRPC server
Expand All @@ -54,6 +55,7 @@ type Config struct {
ParallelFetcherProcesses int
PrefetchBufferSize int
AssumeRoleArn string
Region string
}

// TLSConfig controls TLS
Expand Down Expand Up @@ -231,7 +233,11 @@ func NewServer(config *Config) (*KiamServer, error) {
server.namespaces = k8s.NewNamespaceCache(k8s.NewListWatch(client, k8s.ResourceNamespaces), time.Minute)
server.eventRecorder = eventRecorder(client)

stsGateway := sts.DefaultGateway(config.AssumeRoleArn)
stsGateway, err := sts.DefaultGateway(config.AssumeRoleArn, config.Region)
if err != nil {
return nil, err
}

arnResolver, err := newRoleARNResolver(config)
if err != nil {
return nil, err
Expand Down