From d9375b4347483043a416ee51f37118e5f7b9cf79 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Thu, 1 Aug 2024 16:41:44 +0200 Subject: [PATCH] Check parts' last modified time for the grace period --- lib/events/api.go | 3 +++ lib/events/azsessions/azsessions.go | 21 +++++++++++++++++---- lib/events/complete.go | 10 ++++++++++ lib/events/eventstest/uploader.go | 25 +++++++++++++++++-------- lib/events/filesessions/filestream.go | 14 +++++++++++--- lib/events/gcssessions/gcsstream.go | 3 ++- lib/events/s3sessions/s3stream.go | 8 ++++---- 7 files changed, 64 insertions(+), 20 deletions(-) diff --git a/lib/events/api.go b/lib/events/api.go index 8ac05d415ddde..ad6fa69ac17a9 100644 --- a/lib/events/api.go +++ b/lib/events/api.go @@ -877,6 +877,9 @@ type StreamPart struct { Number int64 // ETag is a part e-tag ETag string + // LastModified is the time of last modification of this part (if + // available). + LastModified time.Time } // StreamUpload represents stream multipart upload diff --git a/lib/events/azsessions/azsessions.go b/lib/events/azsessions/azsessions.go index 85aa13d39084d..527f024670825 100644 --- a/lib/events/azsessions/azsessions.go +++ b/lib/events/azsessions/azsessions.go @@ -28,6 +28,7 @@ import ( "slices" "strconv" "strings" + "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" @@ -450,7 +451,8 @@ func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, pa // our parts are just over 5 MiB (events.MinUploadPartSizeBytes) so we can // upload them in one shot - if _, err := cErr(partBlob.Upload(ctx, streaming.NopCloser(partBody), nil)); err != nil { + response, err := cErr(partBlob.Upload(ctx, streaming.NopCloser(partBody), nil)) + if err != nil { return nil, trace.Wrap(err) } h.log.WithFields(logrus.Fields{ @@ -459,7 +461,11 @@ func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, pa fieldPartNumber: partNumber, }).Debug("Uploaded part.") - return &events.StreamPart{Number: partNumber}, nil + var lastModified time.Time + if response.LastModified != nil { + lastModified = *response.LastModified + } + return &events.StreamPart{Number: partNumber, LastModified: lastModified}, nil } // ListParts implements [events.MultipartUploader]. @@ -492,8 +498,15 @@ func (h *Handler) ListParts(ctx context.Context, upload events.StreamUpload) ([] if err != nil { continue } - - parts = append(parts, events.StreamPart{Number: partNumber}) + var lastModified time.Time + if b.Properties != nil && + b.Properties.LastModified != nil { + lastModified = *b.Properties.LastModified + } + parts = append(parts, events.StreamPart{ + Number: partNumber, + LastModified: lastModified, + }) } } diff --git a/lib/events/complete.go b/lib/events/complete.go index c97643eab2ef9..62e610df1d7ce 100644 --- a/lib/events/complete.go +++ b/lib/events/complete.go @@ -263,6 +263,16 @@ func (u *UploadCompleter) CheckUploads(ctx context.Context) error { } return trace.Wrap(err, "listing parts") } + var lastModified time.Time + for _, part := range parts { + if part.LastModified.After(lastModified) { + lastModified = part.LastModified + } + } + if u.cfg.Clock.Since(lastModified) <= gracePeriod { + log.Debug("Found incomplete upload with recently uploaded part, skipping.") + continue + } log.Debugf("upload has %d parts", len(parts)) diff --git a/lib/events/eventstest/uploader.go b/lib/events/eventstest/uploader.go index 63a30cd242684..f31c9fdca5633 100644 --- a/lib/events/eventstest/uploader.go +++ b/lib/events/eventstest/uploader.go @@ -63,16 +63,21 @@ type MemoryUpload struct { // id is the upload ID id string // parts is the upload parts - parts map[int64][]byte + parts map[int64]part // sessionID is the session ID associated with the upload sessionID session.ID - //completed specifies upload as completed + // completed specifies upload as completed completed bool // Initiated contains the timestamp of when the upload // was initiated, not always initialized Initiated time.Time } +type part struct { + data []byte + lastModified time.Time +} + func (m *MemoryUploader) trySendEvent(event events.UploadEvent) { if m.eventsC == nil { return @@ -105,7 +110,7 @@ func (m *MemoryUploader) CreateUpload(ctx context.Context, sessionID session.ID) m.uploads[upload.ID] = &MemoryUpload{ id: upload.ID, sessionID: sessionID, - parts: make(map[int64][]byte), + parts: make(map[int64]part), Initiated: upload.Initiated, } return upload, nil @@ -127,11 +132,11 @@ func (m *MemoryUploader) CompleteUpload(ctx context.Context, upload events.Strea partsSet := make(map[int64]bool, len(parts)) for _, part := range parts { partsSet[part.Number] = true - data, ok := up.parts[part.Number] + upPart, ok := up.parts[part.Number] if !ok { return trace.NotFound("part %v has not been uploaded", part.Number) } - result = append(result, data...) + result = append(result, upPart.data...) } // exclude parts that are not requested to be completed for number := range up.parts { @@ -157,8 +162,12 @@ func (m *MemoryUploader) UploadPart(ctx context.Context, upload events.StreamUpl if !ok { return nil, trace.NotFound("upload %q is not found", upload.ID) } - up.parts[partNumber] = data - return &events.StreamPart{Number: partNumber}, nil + lastModified := m.Clock.Now() + up.parts[partNumber] = part{ + data: data, + lastModified: lastModified, + } + return &events.StreamPart{Number: partNumber, LastModified: lastModified}, nil } // ListUploads lists uploads that have been initiated but not completed with @@ -199,7 +208,7 @@ func (m *MemoryUploader) GetParts(uploadID string) ([][]byte, error) { return partNumbers[i] < partNumbers[j] }) for _, partNumber := range partNumbers { - sortedParts = append(sortedParts, up.parts[partNumber]) + sortedParts = append(sortedParts, up.parts[partNumber].data) } return sortedParts, nil } diff --git a/lib/events/filesessions/filestream.go b/lib/events/filesessions/filestream.go index c9c5ff5ccd855..0e0399ce89835 100644 --- a/lib/events/filesessions/filestream.go +++ b/lib/events/filesessions/filestream.go @@ -124,12 +124,19 @@ func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, pa } // Rename reservation to part file. - err = os.Rename(reservationPath, h.partPath(upload, partNumber)) + partPath := h.partPath(upload, partNumber) + err = os.Rename(reservationPath, partPath) if err != nil { return nil, trace.ConvertSystemError(err) } - return &events.StreamPart{Number: partNumber}, nil + var lastModified time.Time + fi, err := os.Stat(partPath) + if err == nil { + lastModified = fi.ModTime() + } + + return &events.StreamPart{Number: partNumber, LastModified: lastModified}, nil } // CompleteUpload completes the upload @@ -254,7 +261,8 @@ func (h *Handler) ListParts(ctx context.Context, upload events.StreamUpload) ([] return nil } parts = append(parts, events.StreamPart{ - Number: part, + Number: part, + LastModified: info.ModTime(), }) return nil }) diff --git a/lib/events/gcssessions/gcsstream.go b/lib/events/gcssessions/gcsstream.go index f18487fed85e9..f51a5df111b22 100644 --- a/lib/events/gcssessions/gcsstream.go +++ b/lib/events/gcssessions/gcsstream.go @@ -99,7 +99,7 @@ func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, pa if err != nil { return nil, convertGCSError(err) } - return &events.StreamPart{Number: partNumber}, nil + return &events.StreamPart{Number: partNumber, LastModified: writer.Attrs().Created}, nil } // CompleteUpload completes the upload @@ -249,6 +249,7 @@ func (h *Handler) ListParts(ctx context.Context, upload events.StreamUpload) ([] if err != nil { return nil, trace.Wrap(err) } + part.LastModified = attrs.Updated parts = append(parts, *part) } return parts, nil diff --git a/lib/events/s3sessions/s3stream.go b/lib/events/s3sessions/s3stream.go index c855ca564180f..9318ceb11cf3b 100644 --- a/lib/events/s3sessions/s3stream.go +++ b/lib/events/s3sessions/s3stream.go @@ -105,7 +105,7 @@ func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, pa } log.Infof("Uploaded part %v in %v", partNumber, time.Since(start)) - return &events.StreamPart{ETag: aws.StringValue(resp.ETag), Number: partNumber}, nil + return &events.StreamPart{ETag: aws.StringValue(resp.ETag), Number: partNumber, LastModified: time.Now()}, nil } func (h *Handler) abortUpload(ctx context.Context, upload events.StreamUpload) error { @@ -205,10 +205,10 @@ func (h *Handler) ListParts(ctx context.Context, upload events.StreamUpload) ([] return nil, awsutils.ConvertS3Error(err) } for _, part := range re.Parts { - parts = append(parts, events.StreamPart{ - Number: aws.Int64Value(part.PartNumber), - ETag: aws.StringValue(part.ETag), + Number: aws.Int64Value(part.PartNumber), + ETag: aws.StringValue(part.ETag), + LastModified: aws.TimeValue(part.LastModified), }) } if !aws.BoolValue(re.IsTruncated) {