diff --git a/spanner/request_id_header.go b/spanner/request_id_header.go index 9fcd9377f432..51fa93af8c1a 100644 --- a/spanner/request_id_header.go +++ b/spanner/request_id_header.go @@ -116,7 +116,7 @@ func (r requestID) augmentErrorWithRequestID(err error) error { } } -func gRPCCallOptionsToRequestID(opts []grpc.CallOption) (reqID requestID, found bool) { +func gRPCCallOptionsToRequestID(opts []grpc.CallOption) (md metadata.MD, reqID requestID, found bool) { for _, opt := range opts { hdrOpt, ok := opt.(grpc.HeaderCallOption) if !ok { @@ -126,6 +126,7 @@ func gRPCCallOptionsToRequestID(opts []grpc.CallOption) (reqID requestID, found metadata := hdrOpt.HeaderAddr reqIDs := metadata.Get(xSpannerRequestIDHeader) if len(reqIDs) != 0 && len(reqIDs[0]) != 0 { + md = *metadata reqID = requestID(reqIDs[0]) found = true break @@ -137,7 +138,11 @@ func gRPCCallOptionsToRequestID(opts []grpc.CallOption) (reqID requestID, found func (wr *requestIDHeaderInjector) interceptUnary(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { // It is imperative to search for the requestID before the call // because gRPC's internals will consume the headers. - reqID, foundRequestID := gRPCCallOptionsToRequestID(opts) + metadataWithRequestID, reqID, foundRequestID := gRPCCallOptionsToRequestID(opts) + if foundRequestID { + ctx = metadata.NewOutgoingContext(ctx, metadataWithRequestID) + } + err := invoker(ctx, method, req, reply, cc, opts...) if !foundRequestID { return err @@ -174,7 +179,11 @@ type requestIDHeaderInjector int func (wr *requestIDHeaderInjector) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { // It is imperative to search for the requestID before the call // because gRPC's internals will consume the headers. - reqID, foundRequestID := gRPCCallOptionsToRequestID(opts) + metadataWithRequestID, reqID, foundRequestID := gRPCCallOptionsToRequestID(opts) + if foundRequestID { + ctx = metadata.NewOutgoingContext(ctx, metadataWithRequestID) + } + cs, err := streamer(ctx, desc, cc, method, opts...) if !foundRequestID { return cs, err