From 9e635cac54cb6ffcd5f43cf071e67fd516d73071 Mon Sep 17 00:00:00 2001 From: Logan Hanks Date: Wed, 3 Jul 2019 14:03:25 -0700 Subject: [PATCH] Add xray.AWSSession to install handlers on session (#97) An application has to call xray.AWS for each AWS client it constructs. This creates opportunities for blind spots if someone forgets to configure a new client. The xray.AWSSession installs the same handlers at the Session level. Clients inherit handlers from the session they're created with. As long as the application systematically reuses the same session to create clients, it only needs to install X-Ray handlers once. --- xray/aws.go | 40 +++++++++++------ xray/aws_test.go | 114 ++++++++++++++++++++++++++++++++--------------- 2 files changed, 106 insertions(+), 48 deletions(-) diff --git a/xray/aws.go b/xray/aws.go index 5be4cd89..f1344590 100644 --- a/xray/aws.go +++ b/xray/aws.go @@ -21,6 +21,7 @@ import ( "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-xray-sdk-go/internal/logger" "github.com/aws/aws-xray-sdk-go/resources" ) @@ -137,15 +138,16 @@ var xRayAfterRetryHandler = request.NamedHandler{ }, } -func pushHandlers(c *client.Client) { - c.Handlers.Validate.PushFrontNamed(xRayBeforeValidateHandler) - c.Handlers.Build.PushBackNamed(xRayAfterBuildHandler) - c.Handlers.Sign.PushFrontNamed(xRayBeforeSignHandler) - c.Handlers.Send.PushBackNamed(xRayAfterSendHandler) - c.Handlers.Unmarshal.PushFrontNamed(xRayBeforeUnmarshalHandler) - c.Handlers.Unmarshal.PushBackNamed(xRayAfterUnmarshalHandler) - c.Handlers.Retry.PushFrontNamed(xRayBeforeRetryHandler) - c.Handlers.AfterRetry.PushBackNamed(xRayAfterRetryHandler) +func pushHandlers(handlers *request.Handlers, completionWhitelistFilename string) { + handlers.Validate.PushFrontNamed(xRayBeforeValidateHandler) + handlers.Build.PushBackNamed(xRayAfterBuildHandler) + handlers.Sign.PushFrontNamed(xRayBeforeSignHandler) + handlers.Send.PushBackNamed(xRayAfterSendHandler) + handlers.Unmarshal.PushFrontNamed(xRayBeforeUnmarshalHandler) + handlers.Unmarshal.PushBackNamed(xRayAfterUnmarshalHandler) + handlers.Retry.PushFrontNamed(xRayBeforeRetryHandler) + handlers.AfterRetry.PushBackNamed(xRayAfterRetryHandler) + handlers.Complete.PushFrontNamed(xrayCompleteHandler(completionWhitelistFilename)) } // AWS adds X-Ray tracing to an AWS client. @@ -153,8 +155,7 @@ func AWS(c *client.Client) { if c == nil { panic("Please initialize the provided AWS client before passing to the AWS() method.") } - pushHandlers(c) - c.Handlers.Complete.PushFrontNamed(xrayCompleteHandler("")) + pushHandlers(&c.Handlers, "") } // AWSWithWhitelist allows a custom parameter whitelist JSON file to be defined. @@ -162,8 +163,21 @@ func AWSWithWhitelist(c *client.Client, filename string) { if c == nil { panic("Please initialize the provided AWS client before passing to the AWSWithWhitelist() method.") } - pushHandlers(c) - c.Handlers.Complete.PushFrontNamed(xrayCompleteHandler(filename)) + pushHandlers(&c.Handlers, filename) +} + +// AWSSession adds X-Ray tracing to an AWS session. Clients created under this +// session will inherit X-Ray tracing. +func AWSSession(s *session.Session) *session.Session { + pushHandlers(&s.Handlers, "") + return s +} + +// AWSSessionWithWhitelist allows a custom parameter whitelist JSON file to be +// defined. +func AWSSessionWithWhitelist(s *session.Session, filename string) *session.Session { + pushHandlers(&s.Handlers, filename) + return s } func xrayCompleteHandler(filename string) request.NamedHandler { diff --git a/xray/aws_test.go b/xray/aws_test.go index 3fd1209d..48916600 100644 --- a/xray/aws_test.go +++ b/xray/aws_test.go @@ -14,22 +14,86 @@ import ( "github.com/stretchr/testify/assert" ) -func TestClientSuccessfulConnection(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - b := []byte(`{}`) - w.WriteHeader(http.StatusOK) - w.Write(b) - })) - - svc := lambda.New(session.Must(session.NewSession(&aws.Config{ - Endpoint: aws.String(ts.URL), - Region: aws.String("fake-moon-1"), - Credentials: credentials.NewStaticCredentials("akid", "secret", "noop")}))) +func TestAWS(t *testing.T) { + // Runs a suite of tests against two different methods of registering + // handlers on an AWS client. + + type test func(*testing.T, *lambda.Lambda) + tests := []struct { + name string + test test + failConn bool + }{ + {"failed connection", testClientFailedConnection, true}, + {"successful connection", testClientSuccessfulConnection, false}, + {"without segment", testClientWithoutSegment, false}, + } - ctx, root := BeginSegment(context.Background(), "Test") + onClient := func(s *session.Session) *lambda.Lambda { + svc := lambda.New(s) + AWS(svc.Client) + return svc + } - AWS(svc.Client) + onSession := func(s *session.Session) *lambda.Lambda { + return lambda.New(AWSSession(s)) + } + + const whitelist = "../resources/AWSWhitelist.json" + + onClientWithWhitelist := func(s *session.Session) *lambda.Lambda { + svc := lambda.New(s) + AWSWithWhitelist(svc.Client, whitelist) + return svc + } + + onSessionWithWhitelist := func(s *session.Session) *lambda.Lambda { + return lambda.New(AWSSessionWithWhitelist(s, whitelist)) + } + + type constructor func(*session.Session) *lambda.Lambda + constructors := []struct { + name string + constructor constructor + }{ + {"AWS()", onClient}, + {"AWSSession()", onSession}, + {"AWSWithWhitelist()", onClientWithWhitelist}, + {"AWSSessionWithWhitelist()", onSessionWithWhitelist}, + } + + // Run all combinations of constructors + tests. + for _, cons := range constructors { + t.Run(cons.name, func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + test.test(t, cons.constructor(fakeSession(t, test.failConn))) + }) + } + }) + } +} +func fakeSession(t *testing.T, failConn bool) *session.Session { + cfg := &aws.Config{ + Region: aws.String("fake-moon-1"), + Credentials: credentials.NewStaticCredentials("akid", "secret", "noop"), + } + if !failConn { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b := []byte(`{}`) + w.WriteHeader(http.StatusOK) + w.Write(b) + })) + cfg.Endpoint = aws.String(ts.URL) + } + s, err := session.NewSession(cfg) + assert.NoError(t, err) + return s +} + +func testClientSuccessfulConnection(t *testing.T, svc *lambda.Lambda) { + ctx, root := BeginSegment(context.Background(), "Test") _, err := svc.ListFunctionsWithContext(ctx, &lambda.ListFunctionsInput{}) root.Close(nil) assert.NoError(t, err) @@ -76,15 +140,8 @@ func TestClientSuccessfulConnection(t *testing.T) { } } -func TestClientFailedConnection(t *testing.T) { - svc := lambda.New(session.Must(session.NewSession(&aws.Config{ - Region: aws.String("fake-moon-1"), - Credentials: credentials.NewStaticCredentials("akid", "secret", "noop")}))) - +func testClientFailedConnection(t *testing.T, svc *lambda.Lambda) { ctx, root := BeginSegment(context.Background(), "Test") - - AWS(svc.Client) - _, err := svc.ListFunctionsWithContext(ctx, &lambda.ListFunctionsInput{}) root.Close(nil) assert.Error(t, err) @@ -116,24 +173,11 @@ func TestClientFailedConnection(t *testing.T) { assert.NotEmpty(t, connectSubseg.Subsegments) } -func TestClientWithoutSegment(t *testing.T) { +func testClientWithoutSegment(t *testing.T, svc *lambda.Lambda) { Configure(Config{ContextMissingStrategy: &TestContextMissingStrategy{}}) defer ResetConfig() - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - b := []byte(`{}`) - w.WriteHeader(http.StatusOK) - w.Write(b) - })) - - svc := lambda.New(session.Must(session.NewSession(&aws.Config{ - Endpoint: aws.String(ts.URL), - Region: aws.String("fake-moon-1"), - Credentials: credentials.NewStaticCredentials("akid", "secret", "noop")}))) ctx := context.Background() - - AWS(svc.Client) - _, err := svc.ListFunctionsWithContext(ctx, &lambda.ListFunctionsInput{}) assert.NoError(t, err) }