Skip to content

Commit

Permalink
Add iteration support to dynamo
Browse files Browse the repository at this point in the history
This builds on top of #52199 by updating the dynamo backend to
implement backend.BackendWithItems. Both GetRange and DeleteRange
were refactored to call Items instead of getAllRecords to unify
logic and vet the implementation of Items.

The custom pagination logic to retrieve items was also replaced with
an iterator that works in a similar fashion to the query paginator
from the AWS SDK. In addition to simplifying the logic, this also
removed some extraneous sorting.
  • Loading branch information
rosstimothy committed Mar 4, 2025
1 parent bc1edf1 commit 7b502b5
Showing 1 changed file with 159 additions and 160 deletions.
319 changes: 159 additions & 160 deletions lib/backend/dynamo/dynamodbbk.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ package dynamo
import (
"context"
"errors"
"iter"
"log/slog"
"net/http"
"sort"
"strconv"
"sync/atomic"
"time"
Expand Down Expand Up @@ -548,63 +548,132 @@ func (b *Backend) Update(ctx context.Context, item backend.Item) (*backend.Lease
return backend.NewLease(item), nil
}

// GetRange returns range of elements
func (b *Backend) GetRange(ctx context.Context, startKey, endKey backend.Key, limit int) (*backend.GetResult, error) {
if startKey.IsZero() {
return nil, trace.BadParameter("missing parameter startKey")
}
if endKey.IsZero() {
return nil, trace.BadParameter("missing parameter endKey")
}
func (b *Backend) queryOutputPages(ctx context.Context, limit int, input *dynamodb.QueryInput) iter.Seq2[*dynamodb.QueryOutput, error] {
if limit <= 0 {
limit = backend.DefaultRangeLimit
}

result, err := b.getAllRecords(ctx, startKey, endKey, limit)
if err != nil {
return nil, trace.Wrap(err)
}
sort.Sort(records(result.records))
values := make([]backend.Item, len(result.records))
for i, r := range result.records {
values[i] = backend.Item{
Key: trimPrefix(r.FullPath),
Value: r.Value,
Revision: r.Revision,
}
if r.Expires != nil {
values[i].Expires = time.Unix(*r.Expires, 0).UTC()
}
if values[i].Revision == "" {
values[i].Revision = backend.BlankRevision
return func(yield func(*dynamodb.QueryOutput, error) bool) {
const defaultPageSize = 1000
var nextToken map[string]types.AttributeValue

totalCount := 0
for {
input.Limit = aws.Int32(int32(min(limit-totalCount, defaultPageSize)))

result, err := b.svc.Query(ctx, input)
if err != nil {
yield(nil, trace.Wrap(err))
return
}

nextToken = result.LastEvaluatedKey
if !yield(result, nil) {
return
}

if nextToken == nil {
return
}

totalCount += len(result.Items)
if totalCount >= limit {
return
}
input.ExclusiveStartKey = nextToken
}
}
return &backend.GetResult{Items: values}, nil
}

func (b *Backend) getAllRecords(ctx context.Context, startKey, endKey backend.Key, limit int) (*getResult, error) {
var result getResult
func (b *Backend) Items(ctx context.Context, params backend.IterateParams) iter.Seq2[backend.Item, error] {
if params.StartKey.IsZero() {
err := trace.BadParameter("missing parameter startKey")
return func(yield func(backend.Item, error) bool) { yield(backend.Item{}, err) }
}
if params.EndKey.IsZero() {
err := trace.BadParameter("missing parameter endKey")
return func(yield func(backend.Item, error) bool) { yield(backend.Item{}, err) }
}

// this code is being extra careful here not to introduce endless loop
// by some unfortunate series of events
for i := 0; i < backend.DefaultRangeLimit/100; i++ {
re, err := b.getRecords(ctx, prependPrefix(startKey), prependPrefix(endKey), limit, result.lastEvaluatedKey)
const (
query = "HashKey = :hashKey AND FullPath BETWEEN :rangeStart AND :rangeEnd"

// filter out expired items, otherwise they might show up in the query
// http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/howitworks-ttl.html
filter = "attribute_not_exists(Expires) OR Expires >= :timestamp"
)

av := map[string]types.AttributeValue{
":rangeStart": &types.AttributeValueMemberS{Value: prependPrefix(params.StartKey)},
":rangeEnd": &types.AttributeValueMemberS{Value: prependPrefix(params.EndKey)},
":timestamp": timeToAttributeValue(b.clock.Now().UTC()),
":hashKey": &types.AttributeValueMemberS{Value: hashKey},
}

input := dynamodb.QueryInput{
KeyConditionExpression: aws.String(query),
TableName: &b.TableName,
ExpressionAttributeValues: av,
FilterExpression: aws.String(filter),
ConsistentRead: aws.Bool(true),
ScanIndexForward: aws.Bool(!params.Descending),
}

return func(yield func(backend.Item, error) bool) {
count := 0
defer func() {
if count >= backend.DefaultRangeLimit {
b.logger.WarnContext(ctx, "Range query hit backend limit. (this is a bug!)", "start_key", params.StartKey, "limit", backend.DefaultRangeLimit)
}
}()

for page, err := range b.queryOutputPages(ctx, params.Limit, &input) {
if err != nil {
yield(backend.Item{}, convertError(err))
return
}

for _, itemAttributes := range page.Items {
var r record
if err := attributevalue.UnmarshalMap(itemAttributes, &r); err != nil {
yield(backend.Item{}, convertError(err))
return
}

item := backend.Item{
Key: trimPrefix(r.FullPath),
Value: r.Value,
Revision: r.Revision,
}
if r.Expires != nil {
item.Expires = time.Unix(*r.Expires, 0).UTC()
}
if item.Revision == "" {
item.Revision = backend.BlankRevision
}

if !yield(item, nil) {
return
}
count++
if params.Limit != backend.NoLimit && count >= params.Limit {
return
}
}
}
}
}

// GetRange returns range of elements
func (b *Backend) GetRange(ctx context.Context, startKey, endKey backend.Key, limit int) (*backend.GetResult, error) {
var result backend.GetResult
for i, err := range b.Items(ctx, backend.IterateParams{StartKey: startKey, EndKey: endKey, Limit: limit}) {
if err != nil {
return nil, trace.Wrap(err)
}
result.records = append(result.records, re.records...)
// If the limit was exceeded or there are no more records to fetch return the current result
// otherwise updated lastEvaluatedKey and proceed with obtaining new records.
if (limit != 0 && len(result.records) >= limit) || len(re.lastEvaluatedKey) == 0 {
if len(result.records) == backend.DefaultRangeLimit {
b.logger.WarnContext(ctx, "Range query hit backend limit. (this is a bug!)", "start_key", startKey, "limit", backend.DefaultRangeLimit)
}
result.lastEvaluatedKey = nil
return &result, nil
}
result.lastEvaluatedKey = re.lastEvaluatedKey
result.Items = append(result.Items, i)
}
return nil, trace.BadParameter("backend entered endless loop")
return &result, nil
}

const (
Expand All @@ -617,45 +686,63 @@ const (

// DeleteRange deletes range of items with keys between startKey and endKey
func (b *Backend) DeleteRange(ctx context.Context, startKey, endKey backend.Key) error {
if startKey.IsZero() {
return trace.BadParameter("missing parameter startKey")
}
if endKey.IsZero() {
return trace.BadParameter("missing parameter endKey")
}
// keep fetching and deleting until no records left,
// keep the very large limit, just in case if someone else
// keeps adding records
for i := 0; i < backend.DefaultRangeLimit/100; i++ {
result, err := b.getRecords(ctx, prependPrefix(startKey), prependPrefix(endKey), batchOperationItemsLimit, nil)
// Attempt to pull all existing items and delete them in batches
// in accordance with the BatchWriteItem limits. There is a hard
// cap on the total number of items that can be deleted in a single
// DeleteRange call to avoid racing with additional records being added.
const maxDeletionOperations = backend.DefaultRangeLimit / 100 / batchOperationItemsLimit
requests := make([]types.WriteRequest, 0, batchOperationItemsLimit)
var deletions int
for item, err := range b.Items(ctx, backend.IterateParams{StartKey: startKey, EndKey: endKey}) {
if err != nil {
return trace.Wrap(err)
}
if len(result.records) == 0 {
return nil

if deletions >= maxDeletionOperations {
break
}
requests := make([]types.WriteRequest, 0, len(result.records))
for _, record := range result.records {
requests = append(requests, types.WriteRequest{
DeleteRequest: &types.DeleteRequest{
Key: map[string]types.AttributeValue{
hashKeyKey: &types.AttributeValueMemberS{Value: hashKey},
fullPathKey: &types.AttributeValueMemberS{Value: record.FullPath},
},

requests = append(requests, types.WriteRequest{
DeleteRequest: &types.DeleteRequest{
Key: map[string]types.AttributeValue{
hashKeyKey: &types.AttributeValueMemberS{Value: hashKey},
fullPathKey: &types.AttributeValueMemberS{Value: prependPrefix(item.Key)},
},
})
},
})

if len(requests) == batchOperationItemsLimit {
if _, err := b.svc.BatchWriteItem(ctx, &dynamodb.BatchWriteItemInput{
RequestItems: map[string][]types.WriteRequest{
b.TableName: requests,
},
}); err != nil {
return trace.Wrap(err)
}

requests = requests[:0]
deletions++
if deletions >= maxDeletionOperations {
break
}
}
input := dynamodb.BatchWriteItemInput{
}

if deletions >= maxDeletionOperations {
return trace.ConnectionProblem(nil, "not all items deleted, too many requests")
}

if len(requests) > 0 {
if _, err := b.svc.BatchWriteItem(ctx, &dynamodb.BatchWriteItemInput{
RequestItems: map[string][]types.WriteRequest{
b.TableName: requests,
},
}

if _, err = b.svc.BatchWriteItem(ctx, &input); err != nil {
}); err != nil {
return trace.Wrap(err)
}
}
return trace.ConnectionProblem(nil, "not all items deleted, too many requests")

return nil
}

// Get returns a single item or not found error
Expand Down Expand Up @@ -961,60 +1048,6 @@ func (b *Backend) createTable(ctx context.Context, tableName *string, rangeKey s
return trace.Wrap(err)
}

// getResult carries a batch of records read from DynamoDB together with the
// pagination key needed to continue reading after that batch.
type getResult struct {
// lastEvaluatedKey is the primary key of the item where the operation stopped, inclusive of the
// previous result set. Use this value to start a new operation, excluding this
// value in the new request.
lastEvaluatedKey map[string]types.AttributeValue
// records are the decoded items; getRecords returns them sorted by
// FullPath with duplicate FullPath entries removed.
records []record
}

// getRecords issues a single DynamoDB query for the records whose FullPath
// lies between startKey and endKey and returns them sorted by FullPath with
// duplicate paths removed.
//
// A positive limit is forwarded to the query as its Limit; a non-positive
// limit leaves the query unbounded. lastEvaluatedKey, when set, is used as the
// query's ExclusiveStartKey so a previous read can be resumed. The
// LastEvaluatedKey reported by the query is passed back in the result so the
// caller can request the next page.
//
// A failure to marshal the expression values is returned through
// convertError; query and decoding failures are wrapped with trace.Wrap.
func (b *Backend) getRecords(ctx context.Context, startKey, endKey string, limit int, lastEvaluatedKey map[string]types.AttributeValue) (*getResult, error) {
	const (
		keyCondition = "HashKey = :hashKey AND FullPath BETWEEN :fullPath AND :rangeEnd"

		// filter out expired items, otherwise they might show up in the query
		// http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/howitworks-ttl.html
		expiryFilter = "attribute_not_exists(Expires) OR Expires >= :timestamp"
	)

	exprValues, err := attributevalue.MarshalMap(map[string]interface{}{
		":fullPath":  startKey,
		":hashKey":   hashKey,
		":timestamp": b.clock.Now().UTC().Unix(),
		":rangeEnd":  endKey,
	})
	if err != nil {
		return nil, convertError(err)
	}

	queryInput := &dynamodb.QueryInput{
		KeyConditionExpression:    aws.String(keyCondition),
		TableName:                 &b.TableName,
		ExpressionAttributeValues: exprValues,
		FilterExpression:          aws.String(expiryFilter),
		ConsistentRead:            aws.Bool(true),
		ExclusiveStartKey:         lastEvaluatedKey,
	}
	if limit > 0 {
		queryInput.Limit = aws.Int32(int32(limit))
	}

	output, err := b.svc.Query(ctx, queryInput)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// Decode every returned item before any ordering or de-duplication.
	var fetched []record
	for _, attrs := range output.Items {
		var rec record
		if err := attributevalue.UnmarshalMap(attrs, &rec); err != nil {
			return nil, trace.Wrap(err)
		}
		fetched = append(fetched, rec)
	}
	sort.Sort(records(fetched))

	return &getResult{
		records:          removeDuplicates(fetched),
		lastEvaluatedKey: output.LastEvaluatedKey,
	}, nil
}

// isExpired returns 'true' if the given object (record) has a TTL and
// it's due.
func (r *record) isExpired(now time.Time) bool {
Expand All @@ -1025,23 +1058,6 @@ func (r *record) isExpired(now time.Time) bool {
return now.UTC().After(expiryDateUTC)
}

// removeDuplicates returns elements with repeated FullPath values dropped.
// The first record seen for each FullPath is kept and the relative order of
// the surviving records is preserved. The result is nil when elements is
// empty.
func removeDuplicates(elements []record) []record {
	// seen is the set of FullPath values already emitted.
	seen := map[string]struct{}{}
	var deduped []record

	for _, rec := range elements {
		if _, dup := seen[rec.FullPath]; dup {
			continue
		}
		seen[rec.FullPath] = struct{}{}
		deduped = append(deduped, rec)
	}
	return deduped
}

const (
modeCreate = iota
modePut
Expand Down Expand Up @@ -1235,23 +1251,6 @@ func convertError(err error) error {
return err
}

// records is a slice of record that implements sort.Interface, ordering
// entries by ascending FullPath.
type records []record

// Len reports the number of records; part of sort.Interface.
func (rs records) Len() int {
	return len(rs)
}

// Swap exchanges the records at positions i and j; part of sort.Interface.
func (rs records) Swap(i, j int) {
	tmp := rs[i]
	rs[i] = rs[j]
	rs[j] = tmp
}

// Less reports whether the record at i has a FullPath that sorts strictly
// before the one at j; part of sort.Interface.
func (rs records) Less(i, j int) bool {
	left, right := rs[i].FullPath, rs[j].FullPath
	return left < right
}

func fullPathToAttributeValueMap(fullPath string) map[string]types.AttributeValue {
return map[string]types.AttributeValue{
hashKeyKey: &types.AttributeValueMemberS{Value: hashKey},
Expand Down

0 comments on commit 7b502b5

Please sign in to comment.