Skip to content

Commit

Permalink
use aws sdk withcontext variants where possible (#8355)
Browse files Browse the repository at this point in the history
use aws api withcontext variants where possible
  • Loading branch information
rosstimothy authored Oct 6, 2021
1 parent 549b72c commit 8f783cf
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 73 deletions.
8 changes: 4 additions & 4 deletions lib/backend/dynamo/configure.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func SetAutoScaling(ctx context.Context, svc *applicationautoscaling.Application
}

// Define scaling targets. Defines minimum and maximum {read,write} capacity.
if _, err := svc.RegisterScalableTarget(&applicationautoscaling.RegisterScalableTargetInput{
if _, err := svc.RegisterScalableTargetWithContext(ctx, &applicationautoscaling.RegisterScalableTargetInput{
MinCapacity: aws.Int64(params.ReadMinCapacity),
MaxCapacity: aws.Int64(params.ReadMaxCapacity),
ResourceId: aws.String(resourceID),
Expand All @@ -82,7 +82,7 @@ func SetAutoScaling(ctx context.Context, svc *applicationautoscaling.Application
}); err != nil {
return convertError(err)
}
if _, err := svc.RegisterScalableTarget(&applicationautoscaling.RegisterScalableTargetInput{
if _, err := svc.RegisterScalableTargetWithContext(ctx, &applicationautoscaling.RegisterScalableTargetInput{
MinCapacity: aws.Int64(params.WriteMinCapacity),
MaxCapacity: aws.Int64(params.WriteMaxCapacity),
ResourceId: aws.String(resourceID),
Expand All @@ -94,7 +94,7 @@ func SetAutoScaling(ctx context.Context, svc *applicationautoscaling.Application

// Define scaling policy. Defines the ratio of {read,write} consumed capacity to
// provisioned capacity DynamoDB will try and maintain.
if _, err := svc.PutScalingPolicy(&applicationautoscaling.PutScalingPolicyInput{
if _, err := svc.PutScalingPolicyWithContext(ctx, &applicationautoscaling.PutScalingPolicyInput{
PolicyName: aws.String(getReadScalingPolicyName(resourceID)),
PolicyType: aws.String(applicationautoscaling.PolicyTypeTargetTrackingScaling),
ResourceId: aws.String(resourceID),
Expand All @@ -109,7 +109,7 @@ func SetAutoScaling(ctx context.Context, svc *applicationautoscaling.Application
}); err != nil {
return convertError(err)
}
if _, err := svc.PutScalingPolicy(&applicationautoscaling.PutScalingPolicyInput{
if _, err := svc.PutScalingPolicyWithContext(ctx, &applicationautoscaling.PutScalingPolicyInput{
PolicyName: aws.String(getWriteScalingPolicyName(resourceID)),
PolicyType: aws.String(applicationautoscaling.PolicyTypeTargetTrackingScaling),
ResourceId: aws.String(resourceID),
Expand Down
4 changes: 2 additions & 2 deletions lib/backend/dynamo/configure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestContinuousBackups(t *testing.T) {

// Remove table after tests are done.
t.Cleanup(func() {
deleteTable(context.Background(), b.svc, b.Config.TableName)
require.NoError(t, deleteTable(context.Background(), b.svc, b.Config.TableName))
})

// Check status of continuous backups.
Expand All @@ -70,7 +70,7 @@ func TestAutoScaling(t *testing.T) {

// Remove table after tests are done.
t.Cleanup(func() {
deleteTable(context.Background(), b.svc, b.Config.TableName)
require.NoError(t, deleteTable(context.Background(), b.svc, b.Config.TableName))
})

// Check auto scaling values match.
Expand Down
6 changes: 3 additions & 3 deletions lib/backend/dynamo/dynamodbbk.go
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ func (b *Backend) createTable(ctx context.Context, tableName string, rangeKey st
KeySchema: elems,
ProvisionedThroughput: &pThroughput,
}
_, err := b.svc.CreateTable(&c)
_, err := b.svc.CreateTableWithContext(ctx, &c)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -721,7 +721,7 @@ func (b *Backend) getRecords(ctx context.Context, startKey, endKey string, limit
if limit > 0 {
input.Limit = aws.Int64(int64(limit))
}
out, err := b.svc.Query(&input)
out, err := b.svc.QueryWithContext(ctx, &input)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -851,7 +851,7 @@ func (b *Backend) getKey(ctx context.Context, key []byte) (*record, error) {
TableName: aws.String(b.TableName),
ConsistentRead: aws.Bool(true),
}
out, err := b.svc.GetItem(&input)
out, err := b.svc.GetItemWithContext(ctx, &input)
if err != nil || len(out.Item) == 0 {
return nil, trace.NotFound("%q is not found", string(key))
}
Expand Down
67 changes: 33 additions & 34 deletions lib/events/dynamoevents/dynamoevents.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,22 +280,22 @@ func New(ctx context.Context, cfg Config, backend backend.Backend) (*Log, error)
b.svc = dynamodb.New(b.session)

// check if the table exists?
ts, err := b.getTableStatus(b.Tablename)
ts, err := b.getTableStatus(ctx, b.Tablename)
if err != nil {
return nil, trace.Wrap(err)
}
switch ts {
case tableStatusOK:
break
case tableStatusMissing:
err = b.createTable(b.Tablename)
err = b.createTable(ctx, b.Tablename)
case tableStatusNeedsMigration:
return nil, trace.BadParameter("unsupported schema")
}
if err != nil {
return nil, trace.Wrap(err)
}
err = b.turnOnTimeToLive()
err = b.turnOnTimeToLive(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -392,7 +392,7 @@ type migrationTask struct {
// jittered interval until retrying migration again. This allows one server to pull ahead
// and finish or make significant progress on the migration.
func (l *Log) migrateRFD24(ctx context.Context) error {
hasIndexV1, err := l.indexExists(l.Tablename, indexTimeSearch)
hasIndexV1, err := l.indexExists(ctx, l.Tablename, indexTimeSearch)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -424,7 +424,7 @@ func (l *Log) migrateRFD24(ctx context.Context) error {
// Acquire a lock so that only one auth server attempts to perform the migration at any given time.
// If an auth server does in a HA-setup the other auth servers will pick up the migration automatically.
err = backend.RunWhileLocked(ctx, l.backend, rfd24MigrationLock, rfd24MigrationLockTTL, func(ctx context.Context) error {
hasIndexV1, err := l.indexExists(l.Tablename, indexTimeSearch)
hasIndexV1, err := l.indexExists(ctx, l.Tablename, indexTimeSearch)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -442,7 +442,7 @@ func (l *Log) migrateRFD24(ctx context.Context) error {

// Remove the old index, marking migration as complete
log.Info("Removing old DynamoDB index")
err = l.removeV1GSI()
err = l.removeV1GSI(ctx)
if err != nil {
return trace.WrapWithMessage(err, "Migrated all events to v6.2 format successfully but failed to remove old index.")
}
Expand Down Expand Up @@ -654,7 +654,7 @@ func (l *Log) GetSessionChunk(namespace string, sid session.ID, offsetBytes, max
return nil, nil
}

// Returns all events that happen during a session sorted by time
// GetSessionEvents returns all events that happen during a session sorted by time
// (oldest first).
//
// after tells the query to only return events after the specified cursor ID
Expand Down Expand Up @@ -1013,8 +1013,8 @@ func (l *Log) WaitForDelivery(ctx context.Context) error {
return nil
}

func (l *Log) turnOnTimeToLive() error {
status, err := l.svc.DescribeTimeToLive(&dynamodb.DescribeTimeToLiveInput{
func (l *Log) turnOnTimeToLive(ctx context.Context) error {
status, err := l.svc.DescribeTimeToLiveWithContext(ctx, &dynamodb.DescribeTimeToLiveInput{
TableName: aws.String(l.Tablename),
})
if err != nil {
Expand All @@ -1024,7 +1024,7 @@ func (l *Log) turnOnTimeToLive() error {
case dynamodb.TimeToLiveStatusEnabled, dynamodb.TimeToLiveStatusEnabling:
return nil
}
_, err = l.svc.UpdateTimeToLive(&dynamodb.UpdateTimeToLiveInput{
_, err = l.svc.UpdateTimeToLiveWithContext(ctx, &dynamodb.UpdateTimeToLiveInput{
TableName: aws.String(l.Tablename),
TimeToLiveSpecification: &dynamodb.TimeToLiveSpecification{
AttributeName: aws.String(keyExpires),
Expand All @@ -1035,8 +1035,8 @@ func (l *Log) turnOnTimeToLive() error {
}

// getTableStatus checks if a given table exists
func (l *Log) getTableStatus(tableName string) (tableStatus, error) {
_, err := l.svc.DescribeTable(&dynamodb.DescribeTableInput{
func (l *Log) getTableStatus(ctx context.Context, tableName string) (tableStatus, error) {
_, err := l.svc.DescribeTableWithContext(ctx, &dynamodb.DescribeTableInput{
TableName: aws.String(tableName),
})
err = convertError(err)
Expand All @@ -1050,8 +1050,8 @@ func (l *Log) getTableStatus(tableName string) (tableStatus, error) {
}

// indexExists checks if a given index exists on a given table and that it is active or updating.
func (l *Log) indexExists(tableName, indexName string) (bool, error) {
tableDescription, err := l.svc.DescribeTable(&dynamodb.DescribeTableInput{
func (l *Log) indexExists(ctx context.Context, tableName, indexName string) (bool, error) {
tableDescription, err := l.svc.DescribeTableWithContext(ctx, &dynamodb.DescribeTableInput{
TableName: aws.String(tableName),
})
if err != nil {
Expand All @@ -1077,7 +1077,7 @@ func (l *Log) indexExists(tableName, indexName string) (bool, error) {
// - This function must be called before the
// backend is considered initialized and the main Teleport process is started.
func (l *Log) createV2GSI(ctx context.Context) error {
v2Exists, err := l.indexExists(l.Tablename, indexTimeSearchV2)
v2Exists, err := l.indexExists(ctx, l.Tablename, indexTimeSearchV2)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -1121,7 +1121,7 @@ func (l *Log) createV2GSI(ctx context.Context) error {
},
}

if _, err := l.svc.UpdateTable(&c); err != nil {
if _, err := l.svc.UpdateTableWithContext(ctx, &c); err != nil {
return trace.Wrap(convertError(err))
}

Expand All @@ -1131,7 +1131,7 @@ func (l *Log) createV2GSI(ctx context.Context) error {

// Wait until the index is created and active or updating.
for time.Now().Before(endWait) {
indexExists, err := l.indexExists(l.Tablename, indexTimeSearchV2)
indexExists, err := l.indexExists(ctx, l.Tablename, indexTimeSearchV2)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -1159,8 +1159,8 @@ func (l *Log) createV2GSI(ctx context.Context) error {
// Invariants:
// - This function must not be called concurrently with itself.
// - This may only be executed after the post RFD 24 global secondary index has been created.
func (l *Log) removeV1GSI() error {
v1Exists, err := l.indexExists(l.Tablename, indexTimeSearch)
func (l *Log) removeV1GSI(ctx context.Context) error {
v1Exists, err := l.indexExists(ctx, l.Tablename, indexTimeSearch)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -1181,7 +1181,7 @@ func (l *Log) removeV1GSI() error {
},
}

if _, err := l.svc.UpdateTable(&c); err != nil {
if _, err := l.svc.UpdateTableWithContext(ctx, &c); err != nil {
return trace.Wrap(convertError(err))
}

Expand Down Expand Up @@ -1278,7 +1278,7 @@ func (l *Log) migrateMatchingEvents(ctx context.Context, filterExpr string, tran
// Resume the scan at the end of the previous one.
// This processes `DynamoBatchSize*maxMigrationWorkers` events at maximum
// which is why we need to run this multiple times on the dataset.
scanOut, err := l.svc.Scan(c)
scanOut, err := l.svc.ScanWithContext(ctx, c)
if err != nil {
return trace.Wrap(convertError(err))
}
Expand Down Expand Up @@ -1328,7 +1328,7 @@ func (l *Log) migrateMatchingEvents(ctx context.Context, filterExpr string, tran
defer workerBarrier.Done()
amountProcessed := len(batch)

if err := l.uploadBatch(batch); err != nil {
if err := l.uploadBatch(ctx, batch); err != nil {
workerErrors <- trace.Wrap(err)
return
}
Expand Down Expand Up @@ -1363,13 +1363,13 @@ func (l *Log) migrateMatchingEvents(ctx context.Context, filterExpr string, tran
}

// uploadBatch creates or updates a batch of `DynamoBatchSize` events or less in one API call.
func (l *Log) uploadBatch(writeRequests []*dynamodb.WriteRequest) error {
func (l *Log) uploadBatch(ctx context.Context, writeRequests []*dynamodb.WriteRequest) error {
for {
c := &dynamodb.BatchWriteItemInput{
RequestItems: map[string][]*dynamodb.WriteRequest{l.Tablename: writeRequests},
}

out, err := l.svc.BatchWriteItem(c)
out, err := l.svc.BatchWriteItemWithContext(ctx, c)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -1387,7 +1387,7 @@ func (l *Log) uploadBatch(writeRequests []*dynamodb.WriteRequest) error {
// rangeKey is the name of the 'range key' the schema requires.
// currently is always set to "FullPath" (used to be something else, that's
// why it's a parameter for migration purposes)
func (l *Log) createTable(tableName string) error {
func (l *Log) createTable(ctx context.Context, tableName string) error {
provisionedThroughput := dynamodb.ProvisionedThroughput{
ReadCapacityUnits: aws.Int64(l.ReadCapacityUnits),
WriteCapacityUnits: aws.Int64(l.WriteCapacityUnits),
Expand Down Expand Up @@ -1428,12 +1428,12 @@ func (l *Log) createTable(tableName string) error {
},
},
}
_, err := l.svc.CreateTable(&c)
_, err := l.svc.CreateTableWithContext(ctx, &c)
if err != nil {
return trace.Wrap(err)
}
log.Infof("Waiting until table %q is created", tableName)
err = l.svc.WaitUntilTableExists(&dynamodb.DescribeTableInput{
err = l.svc.WaitUntilTableExistsWithContext(ctx, &dynamodb.DescribeTableInput{
TableName: aws.String(tableName),
})
if err == nil {
Expand All @@ -1448,8 +1448,8 @@ func (l *Log) Close() error {
}

// deleteAllItems deletes all items from the database, used in tests
func (l *Log) deleteAllItems() error {
out, err := l.svc.Scan(&dynamodb.ScanInput{TableName: aws.String(l.Tablename)})
func (l *Log) deleteAllItems(ctx context.Context) error {
out, err := l.svc.ScanWithContext(ctx, &dynamodb.ScanInput{TableName: aws.String(l.Tablename)})
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -1473,12 +1473,11 @@ func (l *Log) deleteAllItems() error {
chunk := requests[:top]
requests = requests[top:]

req, _ := l.svc.BatchWriteItemRequest(&dynamodb.BatchWriteItemInput{
_, err := l.svc.BatchWriteItemWithContext(ctx, &dynamodb.BatchWriteItemInput{
RequestItems: map[string][]*dynamodb.WriteRequest{
l.Tablename: chunk,
},
})
err = req.Send()
err = convertError(err)
if err != nil {
return trace.Wrap(err)
Expand All @@ -1489,15 +1488,15 @@ func (l *Log) deleteAllItems() error {
}

// deleteTable deletes DynamoDB table with a given name
func (l *Log) deleteTable(tableName string, wait bool) error {
func (l *Log) deleteTable(ctx context.Context, tableName string, wait bool) error {
tn := aws.String(tableName)
_, err := l.svc.DeleteTable(&dynamodb.DeleteTableInput{TableName: tn})
_, err := l.svc.DeleteTableWithContext(ctx, &dynamodb.DeleteTableInput{TableName: tn})
if err != nil {
return trace.Wrap(err)
}
if wait {
return trace.Wrap(
l.svc.WaitUntilTableNotExists(&dynamodb.DescribeTableInput{TableName: tn}))
l.svc.WaitUntilTableNotExistsWithContext(ctx, &dynamodb.DescribeTableInput{TableName: tn}))
}
return nil
}
Expand Down
6 changes: 3 additions & 3 deletions lib/events/dynamoevents/dynamoevents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ func (s *suiteBase) SetUpSuite(c *check.C) {
}

func (s *suiteBase) SetUpTest(c *check.C) {
err := s.log.deleteAllItems()
err := s.log.deleteAllItems(context.Background())
c.Assert(err, check.IsNil)
}

func (s *suiteBase) TearDownSuite(c *check.C) {
if s.log != nil {
if err := s.log.deleteTable(s.log.Tablename, true); err != nil {
if err := s.log.deleteTable(context.Background(), s.log.Tablename, true); err != nil {
c.Fatalf("Failed to delete table: %#v", trace.DebugReport(err))
}
}
Expand Down Expand Up @@ -161,7 +161,7 @@ func (s *DynamoeventsSuite) TestSessionEventsCRUD(c *check.C) {

// TestIndexExists tests functionality of the `Log.indexExists` function.
func (s *DynamoeventsSuite) TestIndexExists(c *check.C) {
hasIndex, err := s.log.indexExists(s.log.Tablename, indexTimeSearchV2)
hasIndex, err := s.log.indexExists(context.Background(), s.log.Tablename, indexTimeSearchV2)
c.Assert(err, check.IsNil)
c.Assert(hasIndex, check.Equals, true)
}
Expand Down
4 changes: 2 additions & 2 deletions lib/events/gcssessions/gcshandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,15 @@ func composerRun(ctx context.Context, composer *storage.Composer) (*storage.Obje
}

// DefaultNewHandler returns a new handler with default GCS client settings derived from the config
func DefaultNewHandler(cfg Config) (*Handler, error) {
func DefaultNewHandler(ctx context.Context, cfg Config) (*Handler, error) {
var args []option.ClientOption
if len(cfg.Endpoint) != 0 {
args = append(args, option.WithoutAuthentication(), option.WithEndpoint(cfg.Endpoint), option.WithGRPCDialOption(grpc.WithInsecure()))
} else if len(cfg.CredentialsPath) != 0 {
args = append(args, option.WithCredentialsFile(cfg.CredentialsPath))
}

ctx, cancelFunc := context.WithCancel(context.Background())
ctx, cancelFunc := context.WithCancel(ctx)
client, err := storage.NewClient(ctx, args...)
if err != nil {
cancelFunc()
Expand Down
Loading

0 comments on commit 8f783cf

Please sign in to comment.