diff --git a/xray/aws.go b/xray/aws.go index 5be4cd89..f1344590 100644 --- a/xray/aws.go +++ b/xray/aws.go @@ -21,6 +21,7 @@ import ( "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-xray-sdk-go/internal/logger" "github.com/aws/aws-xray-sdk-go/resources" ) @@ -137,15 +138,16 @@ var xRayAfterRetryHandler = request.NamedHandler{ }, } -func pushHandlers(c *client.Client) { - c.Handlers.Validate.PushFrontNamed(xRayBeforeValidateHandler) - c.Handlers.Build.PushBackNamed(xRayAfterBuildHandler) - c.Handlers.Sign.PushFrontNamed(xRayBeforeSignHandler) - c.Handlers.Send.PushBackNamed(xRayAfterSendHandler) - c.Handlers.Unmarshal.PushFrontNamed(xRayBeforeUnmarshalHandler) - c.Handlers.Unmarshal.PushBackNamed(xRayAfterUnmarshalHandler) - c.Handlers.Retry.PushFrontNamed(xRayBeforeRetryHandler) - c.Handlers.AfterRetry.PushBackNamed(xRayAfterRetryHandler) +func pushHandlers(handlers *request.Handlers, completionWhitelistFilename string) { + handlers.Validate.PushFrontNamed(xRayBeforeValidateHandler) + handlers.Build.PushBackNamed(xRayAfterBuildHandler) + handlers.Sign.PushFrontNamed(xRayBeforeSignHandler) + handlers.Send.PushBackNamed(xRayAfterSendHandler) + handlers.Unmarshal.PushFrontNamed(xRayBeforeUnmarshalHandler) + handlers.Unmarshal.PushBackNamed(xRayAfterUnmarshalHandler) + handlers.Retry.PushFrontNamed(xRayBeforeRetryHandler) + handlers.AfterRetry.PushBackNamed(xRayAfterRetryHandler) + handlers.Complete.PushFrontNamed(xrayCompleteHandler(completionWhitelistFilename)) } // AWS adds X-Ray tracing to an AWS client. @@ -153,8 +155,7 @@ func AWS(c *client.Client) { if c == nil { panic("Please initialize the provided AWS client before passing to the AWS() method.") } - pushHandlers(c) - c.Handlers.Complete.PushFrontNamed(xrayCompleteHandler("")) + pushHandlers(&c.Handlers, "") } // AWSWithWhitelist allows a custom parameter whitelist JSON file to be defined. @@ -162,8 +163,21 @@ func AWSWithWhitelist(c *client.Client, filename string) { if c == nil { panic("Please initialize the provided AWS client before passing to the AWSWithWhitelist() method.") } - pushHandlers(c) - c.Handlers.Complete.PushFrontNamed(xrayCompleteHandler(filename)) + pushHandlers(&c.Handlers, filename) +} + +// AWSSession adds X-Ray tracing to an AWS session. Clients created under this +// session will inherit X-Ray tracing. +func AWSSession(s *session.Session) *session.Session { + pushHandlers(&s.Handlers, "") + return s +} + +// AWSSessionWithWhitelist allows a custom parameter whitelist JSON file to be +// defined. +func AWSSessionWithWhitelist(s *session.Session, filename string) *session.Session { + pushHandlers(&s.Handlers, filename) + return s } func xrayCompleteHandler(filename string) request.NamedHandler { diff --git a/xray/aws_test.go b/xray/aws_test.go index 3fd1209d..48916600 100644 --- a/xray/aws_test.go +++ b/xray/aws_test.go @@ -14,22 +14,86 @@ import ( "github.com/stretchr/testify/assert" ) -func TestClientSuccessfulConnection(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - b := []byte(`{}`) - w.WriteHeader(http.StatusOK) - w.Write(b) - })) - - svc := lambda.New(session.Must(session.NewSession(&aws.Config{ - Endpoint: aws.String(ts.URL), - Region: aws.String("fake-moon-1"), - Credentials: credentials.NewStaticCredentials("akid", "secret", "noop")}))) +func TestAWS(t *testing.T) { + // Runs a suite of tests against two different methods of registering + // handlers on an AWS client. + + type test func(*testing.T, *lambda.Lambda) + tests := []struct { + name string + test test + failConn bool + }{ + {"failed connection", testClientFailedConnection, true}, + {"successful connection", testClientSuccessfulConnection, false}, + {"without segment", testClientWithoutSegment, false}, + } - ctx, root := BeginSegment(context.Background(), "Test") + onClient := func(s *session.Session) *lambda.Lambda { + svc := lambda.New(s) + AWS(svc.Client) + return svc + } - AWS(svc.Client) + onSession := func(s *session.Session) *lambda.Lambda { + return lambda.New(AWSSession(s)) + } + + const whitelist = "../resources/AWSWhitelist.json" + + onClientWithWhitelist := func(s *session.Session) *lambda.Lambda { + svc := lambda.New(s) + AWSWithWhitelist(svc.Client, whitelist) + return svc + } + + onSessionWithWhitelist := func(s *session.Session) *lambda.Lambda { + return lambda.New(AWSSessionWithWhitelist(s, whitelist)) + } + + type constructor func(*session.Session) *lambda.Lambda + constructors := []struct { + name string + constructor constructor + }{ + {"AWS()", onClient}, + {"AWSSession()", onSession}, + {"AWSWithWhitelist()", onClientWithWhitelist}, + {"AWSSessionWithWhitelist()", onSessionWithWhitelist}, + } + + // Run all combinations of constructors + tests. + for _, cons := range constructors { + t.Run(cons.name, func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + test.test(t, cons.constructor(fakeSession(t, test.failConn))) + }) + } + }) + } +} +func fakeSession(t *testing.T, failConn bool) *session.Session { + cfg := &aws.Config{ + Region: aws.String("fake-moon-1"), + Credentials: credentials.NewStaticCredentials("akid", "secret", "noop"), + } + if !failConn { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b := []byte(`{}`) + w.WriteHeader(http.StatusOK) + w.Write(b) + })) + cfg.Endpoint = aws.String(ts.URL) + } + s, err := session.NewSession(cfg) + assert.NoError(t, err) + return s +} + +func testClientSuccessfulConnection(t *testing.T, svc *lambda.Lambda) { + ctx, root := BeginSegment(context.Background(), "Test") _, err := svc.ListFunctionsWithContext(ctx, &lambda.ListFunctionsInput{}) root.Close(nil) assert.NoError(t, err) @@ -76,15 +140,8 @@ func TestClientSuccessfulConnection(t *testing.T) { } } -func TestClientFailedConnection(t *testing.T) { - svc := lambda.New(session.Must(session.NewSession(&aws.Config{ - Region: aws.String("fake-moon-1"), - Credentials: credentials.NewStaticCredentials("akid", "secret", "noop")}))) - +func testClientFailedConnection(t *testing.T, svc *lambda.Lambda) { ctx, root := BeginSegment(context.Background(), "Test") - - AWS(svc.Client) - _, err := svc.ListFunctionsWithContext(ctx, &lambda.ListFunctionsInput{}) root.Close(nil) assert.Error(t, err) @@ -116,24 +173,11 @@ func TestClientFailedConnection(t *testing.T) { assert.NotEmpty(t, connectSubseg.Subsegments) } -func TestClientWithoutSegment(t *testing.T) { +func testClientWithoutSegment(t *testing.T, svc *lambda.Lambda) { Configure(Config{ContextMissingStrategy: &TestContextMissingStrategy{}}) defer ResetConfig() - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - b := []byte(`{}`) - w.WriteHeader(http.StatusOK) - w.Write(b) - })) - - svc := lambda.New(session.Must(session.NewSession(&aws.Config{ - Endpoint: aws.String(ts.URL), - Region: aws.String("fake-moon-1"), - Credentials: credentials.NewStaticCredentials("akid", "secret", "noop")}))) ctx := context.Background() - - AWS(svc.Client) - _, err := svc.ListFunctionsWithContext(ctx, &lambda.ListFunctionsInput{}) assert.NoError(t, err) }