Skip to content

Commit 209e466

Browse files
committed
Handle cancellation of GRPC stream. Closes #17
1 parent 8a9ee47 commit 209e466

File tree

5 files changed

+30
-6
lines changed

5 files changed

+30
-6
lines changed

grpc/shared/grpc.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ func (c *GRPCClient) GetSchema(req *proto.GetSchemaRequest) (*proto.GetSchemaRes
2020
return c.client.GetSchema(c.ctx, req)
2121
}
2222

23-
func (c *GRPCClient) Execute(req *proto.ExecuteRequest) (proto.WrapperPlugin_ExecuteClient, error) {
24-
return c.client.Execute(c.ctx, req)
23+
func (c *GRPCClient) Execute(req *proto.ExecuteRequest) (proto.WrapperPlugin_ExecuteClient, context.CancelFunc, error) {
24+
ctx, cancel := context.WithCancel(c.ctx)
25+
client, err := c.client.Execute(ctx, req)
26+
return client, cancel, err
2527
}
2628

2729
func (c *GRPCClient) SetConnectionConfig(req *proto.SetConnectionConfigRequest) (*proto.SetConnectionConfigResponse, error) {

grpc/shared/interface.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ type WrapperPluginServer interface {
2929

3030
type WrapperPluginClient interface {
3131
GetSchema(request *proto.GetSchemaRequest) (*proto.GetSchemaResponse, error)
32-
Execute(req *proto.ExecuteRequest) (proto.WrapperPlugin_ExecuteClient, error)
32+
Execute(req *proto.ExecuteRequest) (proto.WrapperPlugin_ExecuteClient, context.CancelFunc, error)
3333
SetConnectionConfig(req *proto.SetConnectionConfigRequest) (*proto.SetConnectionConfigResponse, error)
3434
}
3535

plugin/errors.go

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package plugin
2+
3+
const contextCancelledError = "rpc error: code = Canceled desc = context canceled"

plugin/hydrate.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,13 @@ func WrapHydrate(hydrateFunc HydrateFunc, shouldIgnoreError ErrorPredicate) Hydr
3636
return func(ctx context.Context, d *QueryData, h *HydrateData) (item interface{}, err error) {
3737
defer func() {
3838
if r := recover(); r != nil {
39-
log.Printf("[WARN] recovered a panic from a wrapped hydrate function: %v\n", r)
40-
err = status.Error(codes.Internal, fmt.Sprintf("hydrate function %s failed with panic %v", helpers.GetFunctionName(hydrateFunc), r))
39+
if helpers.ToError(r).Error() == contextCancelledError {
40+
// if the error was a context cancellation, just trace it - this is not an error
41+
log.Printf("[TRACE] hydrate function %s terminated with a context cancellation", helpers.GetFunctionName(hydrateFunc))
42+
} else {
43+
log.Printf("[WARN] recovered a panic from a wrapped hydrate function: %v\n", r)
44+
err = status.Error(codes.Internal, fmt.Sprintf("hydrate function %s failed with panic %v", helpers.GetFunctionName(hydrateFunc), r))
45+
}
4146
}
4247
}()
4348
// call the underlying get function

plugin/query_data.go

+15-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package plugin
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"log"
78
"sync"
@@ -54,6 +55,9 @@ type QueryData struct {
5455
listWg sync.WaitGroup
5556
// when executing parent child list calls, we cache the parent list result in the query data passed to the child list call
5657
parentItem interface{}
58+
59+
// there was an error streaming to the grpc stream
60+
streamingError error
5761
}
5862

5963
func newQueryData(queryContext *QueryContext, table *Table, stream proto.WrapperPlugin_ExecuteServer, connection *Connection, matrix []map[string]interface{}, connectionManager *connection_manager.Manager) *QueryData {
@@ -261,6 +265,11 @@ func (d *QueryData) verifyCallerIsListCall(callingFunction string) bool {
261265
}
262266

263267
func (d *QueryData) streamLeafListItem(ctx context.Context, item interface{}) {
268+
if d.streamingError != nil {
269+
// if there is streaming error, panic to force exit thread - this will be recovered higher up
270+
panic(d.streamingError)
271+
}
272+
264273
// create rowData, passing matrixItem from context
265274
rd := newRowData(d, item)
266275
rd.matrixItem = GetMatrixItem(ctx)
@@ -283,6 +292,8 @@ func (d *QueryData) streamRows(_ context.Context, rowChan chan *proto.Row) error
283292
for {
284293
// wait for either an item or an error
285294
select {
295+
case <-d.stream.Context().Done():
296+
d.streamingError = errors.New(contextCancelledError)
286297
case err := <-d.errorChan:
287298
log.Printf("[ERROR] streamRows error chan select: %v\n", err)
288299
return err
@@ -295,7 +306,10 @@ func (d *QueryData) streamRows(_ context.Context, rowChan chan *proto.Row) error
295306
return nil
296307
}
297308
if err := d.streamRow(row); err != nil {
298-
log.Printf("[ERROR] stream.Send returned error: %v\n", err)
309+
// if there was an error streaming, store in d.streamingError
310+
// - this is checked by the thread streaming list items and will cause it to terminate
311+
d.streamingError = err
312+
299313
return err
300314
}
301315
}

0 commit comments

Comments
 (0)