Skip to content

Commit 3d4e62f

Browse files
fix(storage): fix stream termination in MRD. (#11432)
1. Make CloseSend() call before releasing resource. 2. Drain inbound response from the stream.
1 parent 3ec1119 commit 3d4e62f

File tree

2 files changed

+189
-19
lines changed

2 files changed

+189
-19
lines changed

storage/grpc_client.go

+72-19
Original file line numberDiff line numberDiff line change
@@ -1185,6 +1185,8 @@ func (c *grpcStorageClient) NewMultiRangeDownloader(ctx context.Context, params
11851185
done: false,
11861186
activeTask: 0,
11871187
streamRecreation: false,
1188+
endReceiver: false,
1189+
endSender: false,
11881190
}
11891191

11901192
// streamManager goroutine runs in background where we send message to gcs and process response.
@@ -1195,18 +1197,21 @@ func (c *grpcStorageClient) NewMultiRangeDownloader(ctx context.Context, params
11951197
case <-rr.ctx.Done():
11961198
rr.mu.Lock()
11971199
rr.done = true
1200+
rr.endSender = true
1201+
if rr.stream != nil {
1202+
rr.stream.CloseSend()
1203+
}
11981204
rr.mu.Unlock()
11991205
return
12001206
case <-rr.managerRetry:
1207+
// We are not closing the stream here because it is already closed and is being retried.
12011208
return
12021209
case <-rr.closeManager:
12031210
rr.mu.Lock()
1204-
if len(rr.mp) != 0 {
1205-
for key := range rr.mp {
1206-
rr.mp[key].callback(rr.mp[key].offset, rr.mp[key].totalBytesWritten, fmt.Errorf("stream closed early"))
1207-
delete(rr.mp, key)
1208-
}
1211+
if rr.stream != nil {
1212+
rr.stream.CloseSend()
12091213
}
1214+
rr.endSender = true
12101215
rr.activeTask = 0
12111216
rr.mu.Unlock()
12121217
return
@@ -1255,11 +1260,29 @@ func (c *grpcStorageClient) NewMultiRangeDownloader(ctx context.Context, params
12551260
for {
12561261
select {
12571262
case <-rr.ctx.Done():
1263+
rr.mu.Lock()
1264+
rr.endReceiver = true
12581265
rr.done = true
1266+
if len(rr.mp) != 0 {
1267+
drainInboundReadStream(rr.stream)
1268+
}
1269+
for key := range rr.mp {
1270+
rr.mp[key].callback(rr.mp[key].offset, rr.mp[key].totalBytesWritten, rr.ctx.Err())
1271+
delete(rr.mp, key)
1272+
}
1273+
rr.activeTask = 0
1274+
rr.mu.Unlock()
12591275
return
12601276
case <-rr.receiverRetry:
1277+
// We are not draining the stream here because it is already closed and is being retried.
12611278
return
12621279
case <-rr.closeReceiver:
1280+
rr.mu.Lock()
1281+
if len(rr.mp) != 0 {
1282+
drainInboundReadStream(rr.stream)
1283+
}
1284+
rr.endReceiver = true
1285+
rr.mu.Unlock()
12631286
return
12641287
default:
12651288
// This function reads the data sent for a particular range request and has a callback
@@ -1269,7 +1292,9 @@ func (c *grpcStorageClient) NewMultiRangeDownloader(ctx context.Context, params
12691292
rr.readHandle = resp.GetReadHandle().GetHandle()
12701293
}
12711294
if err == io.EOF {
1272-
err = nil
1295+
rr.mu.Lock()
1296+
rr.endReceiver = true
1297+
rr.mu.Unlock()
12731298
}
12741299
if err != nil {
12751300
// cancel stream and reopen the stream again.
@@ -1341,6 +1366,8 @@ func (c *grpcStorageClient) NewMultiRangeDownloader(ctx context.Context, params
13411366
err = rr.retryStream(err)
13421367
if err != nil {
13431368
rr.mu.Lock()
1369+
rr.endReceiver = true
1370+
rr.endSender = true
13441371
for key := range rr.mp {
13451372
rr.mp[key].callback(rr.mp[key].offset, rr.mp[key].totalBytesWritten, err)
13461373
delete(rr.mp, key)
@@ -1350,6 +1377,10 @@ func (c *grpcStorageClient) NewMultiRangeDownloader(ctx context.Context, params
13501377
rr.mu.Unlock()
13511378
rr.close()
13521379
} else {
1380+
rr.mu.Lock()
1381+
rr.endReceiver = false
1382+
rr.endSender = false
1383+
rr.mu.Unlock()
13531384
// If stream recreation happened successfully lets again start
13541385
// both the goroutine making the whole flow asynchronous again.
13551386
if thread == "receiver" {
@@ -1483,18 +1514,39 @@ func (mr *gRPCBidiReader) wait() {
14831514

14841515
// Close will notify stream manager goroutine that the reader has been closed, if it's still running.
14851516
func (mr *gRPCBidiReader) close() error {
1486-
if mr.cancel != nil {
1487-
mr.cancel()
1488-
}
1517+
mr.closeManager <- true
1518+
mr.closeReceiver <- true
14891519
mr.mu.Lock()
1520+
for key := range mr.mp {
1521+
mr.mp[key].callback(mr.mp[key].offset, mr.mp[key].totalBytesWritten, fmt.Errorf("stream closed early"))
1522+
delete(mr.mp, key)
1523+
}
14901524
mr.done = true
14911525
mr.activeTask = 0
14921526
mr.mu.Unlock()
1493-
mr.closeReceiver <- true
1494-
mr.closeManager <- true
1527+
mr.mu.Lock()
1528+
tryClosing := !(mr.endReceiver && mr.endSender)
1529+
mr.mu.Unlock()
1530+
1531+
for tryClosing {
1532+
mr.mu.Lock()
1533+
tryClosing = !(mr.endReceiver && mr.endSender)
1534+
mr.mu.Unlock()
1535+
}
1536+
defer mr.cancel()
14951537
return nil
14961538
}
14971539

1540+
// drainInboundReadStream calls stream.Recv() repeatedly until an error is returned.
1541+
// drainInboundReadStream always returns a non-nil error. io.EOF indicates all
1542+
// messages were successfully read.
1543+
func drainInboundReadStream(stream storagepb.Storage_BidiReadObjectClient) (err error) {
1544+
for err == nil {
1545+
_, err = stream.Recv()
1546+
}
1547+
return err
1548+
}
1549+
14981550
// getHandle returns the read handle currently stored on the reader.
func (mr *gRPCBidiReader) getHandle() []byte {
	handle := mr.readHandle
	return handle
}
@@ -1925,6 +1977,8 @@ type gRPCBidiReader struct {
19251977
objectSize int64 // always use the mutex when accessing this variable
19261978
retrier func(error, string)
19271979
streamRecreation bool // This helps us identify if stream recreation is in progress or not. If stream recreation gets called from two goroutine then this will stop second one.
1980+
endReceiver bool
1981+
endSender bool
19281982
}
19291983

19301984
// gRPCReader is used by storage.Reader if the experimental option WithGRPCBidiReads is passed.
@@ -2653,11 +2707,11 @@ func bucketContext(ctx context.Context, bucket string) context.Context {
26532707
return gax.InsertMetadataIntoOutgoingContext(ctx, hds...)
26542708
}
26552709

2656-
// drainInboundStream calls stream.Recv() repeatedly until an error is returned.
2710+
// drainInboundWriteStream calls stream.Recv() repeatedly until an error is returned.
26572711
// It returns the last Resource received on the stream, or nil if no Resource
2658-
// was returned. drainInboundStream always returns a non-nil error. io.EOF
2712+
// was returned. drainInboundWriteStream always returns a non-nil error. io.EOF
26592713
// indicates all messages were successfully read.
2660-
func drainInboundStream(stream storagepb.Storage_BidiWriteObjectClient) (object *storagepb.Object, err error) {
2714+
func drainInboundWriteStream(stream storagepb.Storage_BidiWriteObjectClient) (object *storagepb.Object, err error) {
26612715
for err == nil {
26622716
var resp *storagepb.BidiWriteObjectResponse
26632717
resp, err = stream.Recv()
@@ -2737,7 +2791,7 @@ func (s *gRPCOneshotBidiWriteBufferSender) sendBuffer(ctx context.Context, buf [
27372791

27382792
sendErr := s.stream.Send(req)
27392793
if sendErr != nil {
2740-
obj, err = drainInboundStream(s.stream)
2794+
obj, err = drainInboundWriteStream(s.stream)
27412795
s.stream = nil
27422796
if sendErr != io.EOF {
27432797
err = sendErr
@@ -2750,7 +2804,7 @@ func (s *gRPCOneshotBidiWriteBufferSender) sendBuffer(ctx context.Context, buf [
27502804
s.stream.CloseSend()
27512805
// Oneshot uploads only read from the response stream on completion or
27522806
// failure
2753-
obj, err = drainInboundStream(s.stream)
2807+
obj, err = drainInboundWriteStream(s.stream)
27542808
s.stream = nil
27552809
if err == io.EOF {
27562810
err = nil
@@ -2862,7 +2916,7 @@ func (s *gRPCResumableBidiWriteBufferSender) sendBuffer(ctx context.Context, buf
28622916

28632917
sendErr := s.stream.Send(req)
28642918
if sendErr != nil {
2865-
obj, err = drainInboundStream(s.stream)
2919+
obj, err = drainInboundWriteStream(s.stream)
28662920
s.stream = nil
28672921
if err == io.EOF {
28682922
// This is unexpected - we got an error on Send(), but not on Recv().
@@ -2874,7 +2928,7 @@ func (s *gRPCResumableBidiWriteBufferSender) sendBuffer(ctx context.Context, buf
28742928

28752929
if finishWrite {
28762930
s.stream.CloseSend()
2877-
obj, err = drainInboundStream(s.stream)
2931+
obj, err = drainInboundWriteStream(s.stream)
28782932
s.stream = nil
28792933
if err == io.EOF {
28802934
err = nil
@@ -2971,6 +3025,5 @@ func checkCanceled(err error) error {
29713025
if status.Code(err) == codes.Canceled {
29723026
return context.Canceled
29733027
}
2974-
29753028
return err
29763029
}

storage/integration_test.go

+117
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,123 @@ func TestIntegration_MRDWithNonRetriableError(t *testing.T) {
597597
})
598598
}
599599

600+
// Test that context cancellation correctly stops a multi range download before completion.
601+
func TestIntegration_MultiRangeDownloaderContextCancel(t *testing.T) {
602+
multiTransportTest(skipHTTP("gRPC implementation specific test"), t, func(t *testing.T, ctx context.Context, bucket, _ string, client *Client) {
603+
ctx, close := context.WithDeadline(ctx, time.Now().Add(time.Second*30))
604+
defer close()
605+
content := make([]byte, 5<<20)
606+
rand.New(rand.NewSource(0)).Read(content)
607+
objName := "mrdnonretry"
608+
// Upload test data.
609+
obj := client.Bucket(bucket).Object(objName)
610+
if err := writeObject(ctx, obj, "text/plain", content); err != nil {
611+
t.Fatal(err)
612+
}
613+
defer func() {
614+
if err := obj.Delete(ctx); err != nil {
615+
log.Printf("failed to delete test object: %v", err)
616+
}
617+
}()
618+
// Create a multi-range-reader and then cancel the context before completing the reads.
619+
readerCtx, cancel := context.WithCancel(ctx)
620+
reader, err := obj.NewMultiRangeDownloader(readerCtx)
621+
if err != nil {
622+
t.Fatalf("NewMultiRangeDownloader: %v", err)
623+
}
624+
res := make([]multiRangeDownloaderOutput, 3)
625+
callback := func(x, y int64, err error) {
626+
res[0].offset = x
627+
res[0].limit = y
628+
res[0].err = err
629+
}
630+
callback1 := func(x, y int64, err error) {
631+
res[1].offset = x
632+
res[1].limit = y
633+
res[1].err = err
634+
}
635+
callback2 := func(x, y int64, err error) {
636+
res[2].offset = x
637+
res[2].limit = y
638+
res[2].err = err
639+
}
640+
// Add one range on the reader, and then cancel the context.
641+
reader.Add(&res[0].buf, 0, int64(len(content)), callback)
642+
// As context is cancelled remaining ranges would result in context cancelled error or stream is closed errors.
643+
cancel()
644+
reader.Add(&res[1].buf, -10, 0, callback1)
645+
reader.Add(&res[2].buf, 0, 10, callback2)
646+
reader.Wait()
647+
// we can get stream is closed, can't add range error in case process is over before we add the range.
648+
expErr := fmt.Errorf("stream is closed, can't add range")
649+
for i, k := range res {
650+
// if we get nil error for any callback other than first, that should be an error.
651+
if i == 0 && k.err == nil && !bytes.Equal(content, k.buf.Bytes()) {
652+
t.Errorf("Error in read range offset %v, limit %v, got: %v; want: %v",
653+
k.offset, k.limit, len(k.buf.Bytes()), len(content))
654+
}
655+
if k.err == nil && k.err.Error() != expErr.Error() && !errors.Is(err, context.Canceled) && !(status.Code(err) == codes.Canceled) {
656+
t.Fatalf("read range %v to %v: got error %v, want nil, context.Canceled or stream is closed error", k.offset, k.limit, k.err)
657+
}
658+
}
659+
if err = reader.Close(); err != nil {
660+
t.Fatalf("Error while closing reader %v", err)
661+
}
662+
})
663+
}
664+
665+
func TestIntegration_MultiRangeDownloaderSuddenClose(t *testing.T) {
666+
multiTransportTest(skipHTTP("gRPC implementation specific test"), t, func(t *testing.T, ctx context.Context, bucket string, _ string, client *Client) {
667+
content := make([]byte, 5<<20)
668+
rand.New(rand.NewSource(0)).Read(content)
669+
objName := "MultiRangeDownloader"
670+
671+
// Upload test data.
672+
obj := client.Bucket(bucket).Object(objName)
673+
if err := writeObject(ctx, obj, "text/plain", content); err != nil {
674+
t.Fatal(err)
675+
}
676+
defer func() {
677+
if err := obj.Delete(ctx); err != nil {
678+
log.Printf("failed to delete test object: %v", err)
679+
}
680+
}()
681+
reader, err := obj.NewMultiRangeDownloader(ctx)
682+
if err != nil {
683+
t.Fatalf("NewMultiRangeDownloader: %v", err)
684+
}
685+
res := make([]multiRangeDownloaderOutput, 3)
686+
callback := func(x, y int64, err error) {
687+
res[0].offset = x
688+
res[0].limit = y
689+
res[0].err = err
690+
}
691+
callback1 := func(x, y int64, err error) {
692+
res[1].offset = x
693+
res[1].limit = y
694+
res[1].err = err
695+
}
696+
callback2 := func(x, y int64, err error) {
697+
res[2].offset = x
698+
res[2].limit = y
699+
res[2].err = err
700+
}
701+
// Add three ranges on the reader, and then do a sudden close.
702+
reader.Add(&res[0].buf, 0, int64(len(content)), callback)
703+
reader.Close()
704+
reader.Add(&res[1].buf, -10, 0, callback1)
705+
reader.Add(&res[2].buf, 0, 10, callback2)
706+
// we can get stream is closed, can't add range error in case process is over before we add the range.
707+
expErr := fmt.Errorf("stream is closed, can't add range")
708+
expErr2 := fmt.Errorf("stream closed early")
709+
for _, k := range res {
710+
if k.err.Error() != expErr.Error() && k.err.Error() != expErr2.Error() {
711+
t.Fatalf("read range %v to %v: got error %v, want stream closed error", k.offset, k.limit, k.err)
712+
}
713+
}
714+
})
715+
}
716+
600717
// Test in a GCE environment expected to be located in one of:
601718
// - us-west1-a, us-west1-b, us-west-c
602719
//

0 commit comments

Comments
 (0)