diff --git a/objstore.go b/objstore.go index a83dacdf..b1d7d09a 100644 --- a/objstore.go +++ b/objstore.go @@ -232,19 +232,6 @@ func NopCloserWithSize(r io.Reader) io.ReadCloser { return nopCloserWithObjectSize{r} } -type nopSeekerCloserWithObjectSize struct{ io.Reader } - -func (nopSeekerCloserWithObjectSize) Close() error { return nil } -func (n nopSeekerCloserWithObjectSize) ObjectSize() (int64, error) { return TryToGetSize(n.Reader) } - -func (n nopSeekerCloserWithObjectSize) Seek(offset int64, whence int) (int64, error) { - return n.Reader.(io.Seeker).Seek(offset, whence) -} - -func nopSeekerCloserWithSize(r io.Reader) io.ReadSeekCloser { - return nopSeekerCloserWithObjectSize{r} -} - // UploadDir uploads all files in srcdir to the bucket with into a top-level directory // named dstdir. It is a caller responsibility to clean partial upload in case of failure. func UploadDir(ctx context.Context, logger log.Logger, bkt Bucket, srcdir, dstdir string, options ...UploadOption) error { @@ -555,8 +542,9 @@ func (b *metricBucket) Get(ctx context.Context, name string) (io.ReadCloser, err } return nil, err } - return newTimingReadCloser( + return newTimingReader( rc, + true, op, b.opsDuration, b.opsFailures, @@ -577,8 +565,9 @@ func (b *metricBucket) GetRange(ctx context.Context, name string, off, length in } return nil, err } - return newTimingReadCloser( + return newTimingReader( rc, + true, op, b.opsDuration, b.opsFailures, @@ -608,16 +597,9 @@ func (b *metricBucket) Upload(ctx context.Context, name string, r io.Reader) err const op = OpUpload b.ops.WithLabelValues(op).Inc() - _, ok := r.(io.Seeker) - var nopR io.ReadCloser - if ok { - nopR = nopSeekerCloserWithSize(r) - } else { - nopR = NopCloserWithSize(r) - } - - trc := newTimingReadCloser( - nopR, + trc := newTimingReader( + r, + false, op, b.opsDuration, b.opsFailures, @@ -670,12 +652,13 @@ func (b *metricBucket) Name() string { return b.bkt.Name() } -type timingReadSeekCloser struct { - timingReadCloser -} 
+type timingReader struct { + io.Reader + + // closeReader holds whether the wrapped io.Reader should be closed when + // Close() is called on the timingReader. + closeReader bool -type timingReadCloser struct { - io.ReadCloser objSize int64 objSizeErr error @@ -691,14 +674,15 @@ type timingReadCloser struct { transferredBytes *prometheus.HistogramVec } -func newTimingReadCloser(rc io.ReadCloser, op string, dur *prometheus.HistogramVec, failed *prometheus.CounterVec, isFailureExpected IsOpFailureExpectedFunc, fetchedBytes *prometheus.CounterVec, transferredBytes *prometheus.HistogramVec) io.ReadCloser { +func newTimingReader(r io.Reader, closeReader bool, op string, dur *prometheus.HistogramVec, failed *prometheus.CounterVec, isFailureExpected IsOpFailureExpectedFunc, fetchedBytes *prometheus.CounterVec, transferredBytes *prometheus.HistogramVec) io.ReadCloser { // Initialize the metrics with 0. dur.WithLabelValues(op) failed.WithLabelValues(op) - objSize, objSizeErr := TryToGetSize(rc) + objSize, objSizeErr := TryToGetSize(r) - trc := timingReadCloser{ - ReadCloser: rc, + trc := timingReader{ + Reader: r, + closeReader: closeReader, objSize: objSize, objSizeErr: objSizeErr, start: time.Now(), @@ -711,50 +695,79 @@ func newTimingReadCloser(rc io.ReadCloser, op string, dur *prometheus.HistogramV readBytes: 0, } - _, ok := rc.(io.Seeker) - if ok { - return &timingReadSeekCloser{ - timingReadCloser: trc, - } + _, isSeeker := r.(io.Seeker) + _, isReaderAt := r.(io.ReaderAt) + + if isSeeker && isReaderAt { + // The assumption is that in most cases when io.ReaderAt is implemented then + // io.Seeker is implemented too (e.g. os.File). 
+ return &timingReaderSeekerReaderAt{timingReaderSeeker: timingReaderSeeker{timingReader: trc}} + } + if isSeeker { + return &timingReaderSeeker{timingReader: trc} } return &trc } -func (t *timingReadCloser) ObjectSize() (int64, error) { - return t.objSize, t.objSizeErr +func (r *timingReader) ObjectSize() (int64, error) { + return r.objSize, r.objSizeErr } -func (rc *timingReadCloser) Close() error { - err := rc.ReadCloser.Close() - if !rc.alreadyGotErr && err != nil { - rc.failed.WithLabelValues(rc.op).Inc() +func (r *timingReader) Close() error { + var closeErr error + + // Close the wrapped reader if it implements io.Closer, but only if we've been asked to close it. + if closer, ok := r.Reader.(io.Closer); r.closeReader && ok { + closeErr = closer.Close() + + if !r.alreadyGotErr && closeErr != nil { + r.failed.WithLabelValues(r.op).Inc() + r.alreadyGotErr = true + } } - if !rc.alreadyGotErr && err == nil { - rc.duration.WithLabelValues(rc.op).Observe(time.Since(rc.start).Seconds()) - rc.transferredBytes.WithLabelValues(rc.op).Observe(float64(rc.readBytes)) - rc.alreadyGotErr = true + + // Track duration and transferred bytes only if no error occurred. + if !r.alreadyGotErr { + r.duration.WithLabelValues(r.op).Observe(time.Since(r.start).Seconds()) + r.transferredBytes.WithLabelValues(r.op).Observe(float64(r.readBytes)) + + // Trick to avoid tracking metrics multiple times in case Close() gets called again. + r.alreadyGotErr = true } - return err + + return closeErr } -func (rc *timingReadCloser) Read(b []byte) (n int, err error) { - n, err = rc.ReadCloser.Read(b) - if rc.fetchedBytes != nil { - rc.fetchedBytes.WithLabelValues(rc.op).Add(float64(n)) +func (r *timingReader) Read(b []byte) (n int, err error) { + n, err = r.Reader.Read(b) + if r.fetchedBytes != nil { + r.fetchedBytes.WithLabelValues(r.op).Add(float64(n)) } - rc.readBytes += int64(n) + r.readBytes += int64(n) // Report metric just once. 
- if !rc.alreadyGotErr && err != nil && err != io.EOF { - if !rc.isFailureExpected(err) { - rc.failed.WithLabelValues(rc.op).Inc() + if !r.alreadyGotErr && err != nil && err != io.EOF { + if !r.isFailureExpected(err) { + r.failed.WithLabelValues(r.op).Inc() } - rc.alreadyGotErr = true + r.alreadyGotErr = true } return n, err } -func (rsc *timingReadSeekCloser) Seek(offset int64, whence int) (int64, error) { - return (rsc.ReadCloser).(io.Seeker).Seek(offset, whence) +type timingReaderSeeker struct { + timingReader +} + +func (rsc *timingReaderSeeker) Seek(offset int64, whence int) (int64, error) { + return (rsc.Reader).(io.Seeker).Seek(offset, whence) +} + +type timingReaderSeekerReaderAt struct { + timingReaderSeeker +} + +func (rsc *timingReaderSeekerReaderAt) ReadAt(p []byte, off int64) (int, error) { + return (rsc.Reader).(io.ReaderAt).ReadAt(p, off) } diff --git a/objstore_test.go b/objstore_test.go index bb3fbf5b..ababe62c 100644 --- a/objstore_test.go +++ b/objstore_test.go @@ -8,6 +8,7 @@ import ( "context" "io" "os" + "path/filepath" "strings" "testing" @@ -80,6 +81,208 @@ func TestMetricBucket_Multiple_Clients(t *testing.T) { WrapWithMetrics(NewInMemBucket(), reg, "def") } +func TestMetricBucket_UploadShouldPreserveReaderFeatures(t *testing.T) { + tests := map[string]struct { + reader io.Reader + expectedIsSeeker bool + expectedIsReaderAt bool + }{ + "bytes.Reader": { + reader: bytes.NewReader([]byte("1")), + expectedIsSeeker: true, + expectedIsReaderAt: true, + }, + "bytes.Buffer": { + reader: bytes.NewBuffer([]byte("1")), + expectedIsSeeker: false, + expectedIsReaderAt: false, + }, + "os.File": { + reader: func() io.Reader { + // Create a test file. + testFilepath := filepath.Join(t.TempDir(), "test") + testutil.Ok(t, os.WriteFile(testFilepath, []byte("test"), os.ModePerm)) + + // Open the file (it will be used as io.Reader). 
+ file, err := os.Open(testFilepath) + testutil.Ok(t, err) + t.Cleanup(func() { + testutil.Ok(t, file.Close()) + }) + + return file + }(), + expectedIsSeeker: true, + expectedIsReaderAt: true, + }, + } + + for testName, testData := range tests { + t.Run(testName, func(t *testing.T) { + var uploadReader io.Reader + + m := &mockBucket{ + Bucket: WrapWithMetrics(NewInMemBucket(), nil, ""), + upload: func(ctx context.Context, name string, r io.Reader) error { + uploadReader = r + return nil + }, + } + + testutil.Ok(t, m.Upload(context.Background(), "dir/obj1", testData.reader)) + + _, isSeeker := uploadReader.(io.Seeker) + testutil.Equals(t, testData.expectedIsSeeker, isSeeker) + + _, isReaderAt := uploadReader.(io.ReaderAt) + testutil.Equals(t, testData.expectedIsReaderAt, isReaderAt) + }) + } +} + +func TestMetricBucket_ReaderClose(t *testing.T) { + const objPath = "dir/obj1" + + t.Run("Upload() should not close the input Reader", func(t *testing.T) { + closeCalled := false + + reader := &mockReader{ + Reader: bytes.NewBuffer([]byte("test")), + close: func() error { + closeCalled = true + return nil + }, + } + + bucket := WrapWithMetrics(NewInMemBucket(), nil, "") + testutil.Ok(t, bucket.Upload(context.Background(), objPath, reader)) + + // Should not call Close() on the reader. + testutil.Assert(t, !closeCalled) + + // An explicit call to Close() should close it. 
+ testutil.Ok(t, reader.Close()) + testutil.Assert(t, closeCalled) + + testutil.Equals(t, float64(1), promtest.ToFloat64(bucket.ops.WithLabelValues(OpUpload))) + testutil.Equals(t, float64(0), promtest.ToFloat64(bucket.opsFailures.WithLabelValues(OpUpload))) + }) + + t.Run("Get() should return a wrapper io.ReadCloser that correctly Close the wrapped one", func(t *testing.T) { + closeCalled := false + + origReader := &mockReader{ + Reader: bytes.NewBuffer([]byte("test")), + close: func() error { + closeCalled = true + return nil + }, + } + + bucket := WrapWithMetrics(&mockBucket{ + get: func(_ context.Context, _ string) (io.ReadCloser, error) { + return origReader, nil + }, + }, nil, "") + + wrappedReader, err := bucket.Get(context.Background(), objPath) + testutil.Ok(t, err) + testutil.Assert(t, origReader != wrappedReader) + + // Calling Close() to the wrappedReader should close origReader. + testutil.Assert(t, !closeCalled) + testutil.Ok(t, wrappedReader.Close()) + testutil.Assert(t, closeCalled) + + testutil.Equals(t, float64(1), promtest.ToFloat64(bucket.ops.WithLabelValues(OpGet))) + testutil.Equals(t, float64(0), promtest.ToFloat64(bucket.opsFailures.WithLabelValues(OpGet))) + }) + + t.Run("GetRange() should return a wrapper io.ReadCloser that correctly Close the wrapped one", func(t *testing.T) { + closeCalled := false + + origReader := &mockReader{ + Reader: bytes.NewBuffer([]byte("test")), + close: func() error { + closeCalled = true + return nil + }, + } + + bucket := WrapWithMetrics(&mockBucket{ + getRange: func(_ context.Context, _ string, _, _ int64) (io.ReadCloser, error) { + return origReader, nil + }, + }, nil, "") + + wrappedReader, err := bucket.GetRange(context.Background(), objPath, 0, 1) + testutil.Ok(t, err) + testutil.Assert(t, origReader != wrappedReader) + + // Calling Close() to the wrappedReader should close origReader. 
+ testutil.Assert(t, !closeCalled) + testutil.Ok(t, wrappedReader.Close()) + testutil.Assert(t, closeCalled) + + testutil.Equals(t, float64(1), promtest.ToFloat64(bucket.ops.WithLabelValues(OpGetRange))) + testutil.Equals(t, float64(0), promtest.ToFloat64(bucket.opsFailures.WithLabelValues(OpGetRange))) + }) +} + +func TestMetricBucket_ReaderCloseError(t *testing.T) { + origReader := &mockReader{ + Reader: bytes.NewBuffer([]byte("test")), + close: func() error { + return errors.New("mocked error") + }, + } + + t.Run("Get() should track failure if reader Close() returns error", func(t *testing.T) { + bucket := WrapWithMetrics(&mockBucket{ + get: func(ctx context.Context, name string) (io.ReadCloser, error) { + return origReader, nil + }, + }, nil, "") + + testutil.NotOk(t, bucket.Upload(context.Background(), "test", origReader)) + + testutil.Equals(t, float64(1), promtest.ToFloat64(bucket.ops.WithLabelValues(OpUpload))) + testutil.Equals(t, float64(1), promtest.ToFloat64(bucket.opsFailures.WithLabelValues(OpUpload))) + }) + + t.Run("Get() should track failure if reader Close() returns error", func(t *testing.T) { + bucket := WrapWithMetrics(&mockBucket{ + get: func(ctx context.Context, name string) (io.ReadCloser, error) { + return origReader, nil + }, + }, nil, "") + + reader, err := bucket.Get(context.Background(), "test") + testutil.Ok(t, err) + testutil.NotOk(t, reader.Close()) + testutil.NotOk(t, reader.Close()) // Called twice to ensure metrics are not tracked twice. 
+ + testutil.Equals(t, float64(1), promtest.ToFloat64(bucket.ops.WithLabelValues(OpGet))) + testutil.Equals(t, float64(1), promtest.ToFloat64(bucket.opsFailures.WithLabelValues(OpGet))) + }) + + t.Run("GetRange() should track failure if reader Close() returns error", func(t *testing.T) { + bucket := WrapWithMetrics(&mockBucket{ + getRange: func(ctx context.Context, name string, off, length int64) (io.ReadCloser, error) { + return origReader, nil + }, + }, nil, "") + + reader, err := bucket.GetRange(context.Background(), "test", 0, 1) + testutil.Ok(t, err) + testutil.NotOk(t, reader.Close()) + testutil.NotOk(t, reader.Close()) // Called twice to ensure metrics are not tracked twice. + + testutil.Equals(t, float64(1), promtest.ToFloat64(bucket.ops.WithLabelValues(OpGetRange))) + testutil.Equals(t, float64(1), promtest.ToFloat64(bucket.opsFailures.WithLabelValues(OpGetRange))) + }) +} + func TestDownloadUploadDirConcurrency(t *testing.T) { r := prometheus.NewRegistry() m := WrapWithMetrics(NewInMemBucket(), r, "") @@ -206,12 +409,10 @@ func TestDownloadUploadDirConcurrency(t *testing.T) { `), `objstore_bucket_operations_total`)) } -func TestTimingTracingReader(t *testing.T) { +func TestTimingReader(t *testing.T) { m := WrapWithMetrics(NewInMemBucket(), nil, "") r := bytes.NewReader([]byte("hello world")) - - tr := NopCloserWithSize(r) - tr = newTimingReadCloser(tr, "", m.opsDuration, m.opsFailures, func(err error) bool { + tr := newTimingReader(r, true, "", m.opsDuration, m.opsFailures, func(err error) bool { return false }, m.opsFetchedBytes, m.opsTransferredBytes) @@ -230,58 +431,38 @@ func TestTimingTracingReader(t *testing.T) { testutil.Ok(t, err) testutil.Equals(t, int64(11), size) -} - -func TestUploadKeepsSeekerObj(t *testing.T) { - r := prometheus.NewRegistry() - m := seekerTestBucket{ - Bucket: WrapWithMetrics(NewInMemBucket(), r, ""), - } - testutil.Ok(t, m.Upload(context.Background(), "dir/obj1", bytes.NewReader([]byte("1")))) -} + // Given the reader was 
bytes.Reader it should both implement io.Seeker and io.ReaderAt. + _, isSeeker := tr.(io.Seeker) + testutil.Assert(t, isSeeker) -// seekerBucket implements Bucket and checks if io.Reader is still seekable. -type seekerTestBucket struct { - Bucket + _, isReaderAt := tr.(io.ReaderAt) + testutil.Assert(t, isReaderAt) } -func (b seekerTestBucket) Upload(ctx context.Context, name string, r io.Reader) error { - _, ok := r.(io.Seeker) - if !ok { - return errors.New("Reader was supposed to be seekable") - } +func TestTimingReader_ShouldCorrectlyWrapFile(t *testing.T) { + // Create a test file. + testFilepath := filepath.Join(t.TempDir(), "test") + testutil.Ok(t, os.WriteFile(testFilepath, []byte("test"), os.ModePerm)) - return nil -} + // Open the file (it will be used as io.Reader). + file, err := os.Open(testFilepath) + testutil.Ok(t, err) + t.Cleanup(func() { + testutil.Ok(t, file.Close()) + }) -func TestTimingTracingReaderSeeker(t *testing.T) { m := WrapWithMetrics(NewInMemBucket(), nil, "") - r := bytes.NewReader([]byte("hello world")) - - tr := nopSeekerCloserWithSize(r).(io.ReadCloser) - tr = newTimingReadCloser(tr, "", m.opsDuration, m.opsFailures, func(err error) bool { + r := newTimingReader(file, true, "", m.opsDuration, m.opsFailures, func(err error) bool { return false }, m.opsFetchedBytes, m.opsTransferredBytes) - size, err := TryToGetSize(tr) - - testutil.Ok(t, err) - testutil.Equals(t, int64(11), size) - - smallBuf := make([]byte, 4) - n, err := io.ReadFull(tr, smallBuf) - testutil.Ok(t, err) - testutil.Equals(t, 4, n) - - // Verify that size is still the same, after reading 4 bytes. - size, err = TryToGetSize(tr) - - testutil.Ok(t, err) - testutil.Equals(t, int64(11), size) + // It must both implement io.Seeker and io.ReaderAt. 
+ _, isSeeker := r.(io.Seeker) + testutil.Assert(t, isSeeker) - _, ok := tr.(io.Seeker) - testutil.Equals(t, true, ok) + _, isReaderAt := r.(io.ReaderAt) + testutil.Assert(t, isReaderAt) } func TestDownloadDir_CleanUp(t *testing.T) { @@ -316,3 +497,47 @@ func (b unreliableBucket) Get(ctx context.Context, name string) (io.ReadCloser, } return b.Bucket.Get(ctx, name) } + +// mockReader implements io.ReadCloser and allows to mock the functions. +type mockReader struct { + io.Reader + + close func() error +} + +func (r *mockReader) Close() error { + if r.close != nil { + return r.close() + } + return nil +} + +// mockBucket implements Bucket and allows to mock the functions. +type mockBucket struct { + Bucket + + upload func(ctx context.Context, name string, r io.Reader) error + get func(ctx context.Context, name string) (io.ReadCloser, error) + getRange func(ctx context.Context, name string, off, length int64) (io.ReadCloser, error) +} + +func (b *mockBucket) Upload(ctx context.Context, name string, r io.Reader) error { + if b.upload != nil { + return b.upload(ctx, name, r) + } + return errors.New("Upload has not been mocked") +} + +func (b *mockBucket) Get(ctx context.Context, name string) (io.ReadCloser, error) { + if b.get != nil { + return b.get(ctx, name) + } + return nil, errors.New("Get has not been mocked") +} + +func (b *mockBucket) GetRange(ctx context.Context, name string, off, length int64) (io.ReadCloser, error) { + if b.getRange != nil { + return b.getRange(ctx, name, off, length) + } + return nil, errors.New("GetRange has not been mocked") +}