diff --git a/client.go b/client.go index 839b1c9c..2777b390 100644 --- a/client.go +++ b/client.go @@ -69,25 +69,25 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien // once at client creation. unarySpec := config.newSpec(StreamTypeUnary) unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { - sender, receiver := protocolClient.NewStream(ctx, unarySpec, request.Header()) + conn := protocolClient.NewConn(ctx, unarySpec, request.Header()) // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the // full error from the server-side. - if err := sender.Send(request.Any()); err != nil && !errors.Is(err, io.EOF) { - _ = sender.Close(err) - _ = receiver.Close() + if err := conn.Send(request.Any()); err != nil && !errors.Is(err, io.EOF) { + _ = conn.CloseRequest() + _ = conn.CloseResponse() return nil, err } - if err := sender.Close(nil); err != nil { - _ = receiver.Close() + if err := conn.CloseRequest(); err != nil { + _ = conn.CloseResponse() return nil, err } - response, err := receiveUnaryResponse[Res](receiver) + response, err := receiveUnaryResponse[Res](conn) if err != nil { - _ = receiver.Close() + _ = conn.CloseResponse() return nil, err } - return response, receiver.Close() + return response, conn.CloseResponse() }) if interceptor := config.Interceptor; interceptor != nil { unaryFunc = interceptor.WrapUnary(unaryFunc) @@ -123,8 +123,7 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo if c.err != nil { return &ClientStreamForClient[Req, Res]{err: c.err} } - sender, receiver := c.newStream(ctx, StreamTypeClient) - return &ClientStreamForClient[Req, Res]{sender: sender, receiver: receiver} + return &ClientStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeClient)} } // CallServerStream calls a server streaming procedure. @@ -132,20 +131,20 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques if c.err != nil { return nil, c.err } - sender, receiver := c.newStream(ctx, StreamTypeServer) - mergeHeaders(sender.Header(), request.header) + conn := c.newConn(ctx, StreamTypeServer) + mergeHeaders(conn.RequestHeader(), request.header) // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the // full error from the server-side. - if err := sender.Send(request.Msg); err != nil && !errors.Is(err, io.EOF) { - _ = sender.Close(err) - _ = receiver.Close() + if err := conn.Send(request.Msg); err != nil && !errors.Is(err, io.EOF) { + _ = conn.CloseRequest() + _ = conn.CloseResponse() return nil, err } - if err := sender.Close(nil); err != nil { + if err := conn.CloseRequest(); err != nil { return nil, err } - return &ServerStreamForClient[Res]{receiver: receiver}, nil + return &ServerStreamForClient[Res]{conn: conn}, nil } // CallBidiStream calls a bidirectional streaming procedure. @@ -153,22 +152,19 @@ func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForCli if c.err != nil { return &BidiStreamForClient[Req, Res]{err: c.err} } - sender, receiver := c.newStream(ctx, StreamTypeBidi) - return &BidiStreamForClient[Req, Res]{sender: sender, receiver: receiver} + return &BidiStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeBidi)} } -func (c *Client[Req, Res]) newStream(ctx context.Context, streamType StreamType) (Sender, Receiver) { - if interceptor := c.config.Interceptor; interceptor != nil { - ctx = interceptor.WrapStreamContext(ctx) +func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType) StreamingClientConn { + newConn := func(ctx context.Context, spec Spec) StreamingClientConn { + header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing + c.protocolClient.WriteRequestHeader(streamType, header) + return c.protocolClient.NewConn(ctx, spec, header) } - header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing - c.protocolClient.WriteRequestHeader(streamType, header) - sender, receiver := c.protocolClient.NewStream(ctx, c.config.newSpec(streamType), header) if interceptor := c.config.Interceptor; interceptor != nil { - sender = interceptor.WrapStreamSender(ctx, sender) - receiver = interceptor.WrapStreamReceiver(ctx, receiver) + newConn = interceptor.WrapStreamingClient(newConn) } - return sender, receiver + return newConn(ctx, c.config.newSpec(streamType)) } type clientConfig struct { diff --git a/client_stream.go b/client_stream.go index 6318c815..49cbea12 100644 --- a/client_stream.go +++ b/client_stream.go @@ -25,8 +25,7 @@ import ( // It's returned from Client.CallClientStream, but doesn't currently have an // exported constructor function. type ClientStreamForClient[Req, Res any] struct { - sender Sender - receiver Receiver + conn StreamingClientConn // Error from client construction. If non-nil, return for all calls. err error } @@ -37,7 +36,7 @@ func (c *ClientStreamForClient[Req, Res]) RequestHeader() http.Header { if c.err != nil { return http.Header{} } - return c.sender.Header() + return c.conn.RequestHeader() } // Send a message to the server. The first call to Send also sends the request @@ -50,7 +49,7 @@ func (c *ClientStreamForClient[Req, Res]) Send(request *Req) error { if c.err != nil { return c.err } - return c.sender.Send(request) + return c.conn.Send(request) } // CloseAndReceive closes the send side of the stream and waits for the @@ -59,18 +58,16 @@ func (c *ClientStreamForClient[Req, Res]) CloseAndReceive() (*Response[Res], err if c.err != nil { return nil, c.err } - if err := c.sender.Close(nil); err != nil { + if err := c.conn.CloseRequest(); err != nil { + _ = c.conn.CloseResponse() return nil, err } - response, err := receiveUnaryResponse[Res](c.receiver) + response, err := receiveUnaryResponse[Res](c.conn) if err != nil { - _ = c.receiver.Close() + _ = c.conn.CloseResponse() return nil, err } - if err := c.receiver.Close(); err != nil { - return nil, err - } - return response, nil + return response, c.conn.CloseResponse() } // ServerStreamForClient is the client's view of a server streaming RPC. @@ -78,11 +75,11 @@ func (c *ClientStreamForClient[Req, Res]) CloseAndReceive() (*Response[Res], err // It's returned from Client.CallServerStream, but doesn't currently have an // exported constructor function. type ServerStreamForClient[Res any] struct { - receiver Receiver - msg Res + conn StreamingClientConn + msg Res // Error from client construction. If non-nil, return for all calls. constructErr error - // Error from Receive(). + // Error from conn.Receive(). receiveErr error } @@ -95,7 +92,7 @@ func (s *ServerStreamForClient[Res]) Receive() bool { if s.constructErr != nil || s.receiveErr != nil { return false } - s.receiveErr = s.receiver.Receive(&s.msg) + s.receiveErr = s.conn.Receive(&s.msg) return s.receiveErr == nil } @@ -123,7 +120,7 @@ func (s *ServerStreamForClient[Res]) ResponseHeader() http.Header { if s.constructErr != nil { return http.Header{} } - return s.receiver.Header() + return s.conn.ResponseHeader() } // ResponseTrailer returns the trailers received from the server. Trailers @@ -132,10 +129,7 @@ func (s *ServerStreamForClient[Res]) ResponseTrailer() http.Header { if s.constructErr != nil { return http.Header{} } - if trailer, ok := s.receiver.Trailer(); ok { - return trailer - } - return make(http.Header) + return s.conn.ResponseTrailer() } // Close the receive side of the stream. @@ -143,7 +137,7 @@ func (s *ServerStreamForClient[Res]) Close() error { if s.constructErr != nil { return s.constructErr } - return s.receiver.Close() + return s.conn.CloseResponse() } // BidiStreamForClient is the client's view of a bidirectional streaming RPC. @@ -151,8 +145,7 @@ func (s *ServerStreamForClient[Res]) Close() error { // It's returned from Client.CallBidiStream, but doesn't currently have an // exported constructor function. type BidiStreamForClient[Req, Res any] struct { - sender Sender - receiver Receiver + conn StreamingClientConn // Error from client construction. If non-nil, return for all calls. err error } @@ -163,7 +156,7 @@ func (b *BidiStreamForClient[Req, Res]) RequestHeader() http.Header { if b.err != nil { return http.Header{} } - return b.sender.Header() + return b.conn.RequestHeader() } // Send a message to the server. The first call to Send also sends the request @@ -176,15 +169,15 @@ func (b *BidiStreamForClient[Req, Res]) Send(msg *Req) error { if b.err != nil { return b.err } - return b.sender.Send(msg) + return b.conn.Send(msg) } -// CloseSend closes the send side of the stream. -func (b *BidiStreamForClient[Req, Res]) CloseSend() error { +// CloseRequest closes the send side of the stream. +func (b *BidiStreamForClient[Req, Res]) CloseRequest() error { if b.err != nil { return b.err } - return b.sender.Close(nil) + return b.conn.CloseRequest() } // Receive a message. When the server is done sending messages and no other @@ -194,18 +187,18 @@ func (b *BidiStreamForClient[Req, Res]) Receive() (*Res, error) { return nil, b.err } var msg Res - if err := b.receiver.Receive(&msg); err != nil { + if err := b.conn.Receive(&msg); err != nil { return nil, err } return &msg, nil } -// CloseReceive closes the receive side of the stream. -func (b *BidiStreamForClient[Req, Res]) CloseReceive() error { +// CloseResponse closes the receive side of the stream. +func (b *BidiStreamForClient[Req, Res]) CloseResponse() error { if b.err != nil { return b.err } - return b.receiver.Close() + return b.conn.CloseResponse() } // ResponseHeader returns the headers received from the server. It blocks until @@ -214,7 +207,7 @@ func (b *BidiStreamForClient[Req, Res]) ResponseHeader() http.Header { if b.err != nil { return http.Header{} } - return b.receiver.Header() + return b.conn.ResponseHeader() } // ResponseTrailer returns the trailers received from the server. Trailers @@ -223,8 +216,5 @@ func (b *BidiStreamForClient[Req, Res]) ResponseTrailer() http.Header { if b.err != nil { return http.Header{} } - if trailer, ok := b.receiver.Trailer(); ok { - return trailer - } - return make(http.Header) + return b.conn.ResponseTrailer() } diff --git a/client_stream_test.go b/client_stream_test.go index 50d98eca..6a2e9f9b 100644 --- a/client_stream_test.go +++ b/client_stream_test.go @@ -57,8 +57,8 @@ func TestBidiStreamForClient_NoPanics(t *testing.T) { verifyHeaders(t, bidiStream.ResponseHeader()) verifyHeaders(t, bidiStream.ResponseTrailer()) assert.ErrorIs(t, bidiStream.Send(&pingv1.CumSumRequest{}), initErr) - assert.ErrorIs(t, bidiStream.CloseReceive(), initErr) - assert.ErrorIs(t, bidiStream.CloseSend(), initErr) + assert.ErrorIs(t, bidiStream.CloseRequest(), initErr) + assert.ErrorIs(t, bidiStream.CloseResponse(), initErr) } func verifyHeaders(t *testing.T, headers http.Header) { diff --git a/code.go b/code.go index 852f88e4..96a47a63 100644 --- a/code.go +++ b/code.go @@ -142,10 +142,12 @@ func (c Code) String() string { return fmt.Sprintf("code_%d", c) } +// MarshalText implements encoding.TextMarshaler. func (c Code) MarshalText() ([]byte, error) { return []byte(c.String()), nil } +// UnmarshalText implements encoding.TextUnmarshaler. func (c *Code) UnmarshalText(data []byte) error { dataStr := string(data) switch dataStr { diff --git a/connect.go b/connect.go index 21311528..ded7e68d 100644 --- a/connect.go +++ b/connect.go @@ -49,51 +49,67 @@ const ( StreamTypeBidi = StreamTypeClient | StreamTypeServer ) -// Sender is the writable side of a bidirectional stream of messages. Sender -// implementations do not need to be safe for concurrent use. +// StreamingHandlerConn is the server's view of a bidirectional message +// exchange. Interceptors for streaming RPCs may wrap StreamingHandlerConns. // -// Sender implementations provided by this module guarantee that all returned -// errors can be cast to *Error using errors.As. The Close method of Sender -// implementations provided by this module automatically adds the appropriate -// codes when passed context.DeadlineExceeded or context.Canceled. +// Like the standard library's http.ResponseWriter, StreamingHandlerConns write +// response headers to the network with the first call to Send. Any subsequent +// mutations are effectively no-ops. Handlers may mutate response trailers at +// any time before returning. When the client has finished sending data, +// Receive returns an error wrapping io.EOF. Handlers should check for this +// using the standard library's errors.Is. // -// Like the standard library's http.ResponseWriter, both client- and -// handler-side Senders write headers to the network with the first call to -// Send. Any subsequent mutations to the headers are effectively no-ops. +// StreamingHandlerConn implementations provided by this module guarantee that +// all returned errors can be cast to *Error using the standard library's +// errors.As. // -// Handler-side Senders may mutate trailers until calling Close, when the -// trailers are written to the network. Clients may not send trailers, since -// the gRPC, gRPC-Web, and Connect protocols all forbid it. -// -// Once servers return an error, they're not interested in receiving additional -// messages and clients should stop sending them. Client-side Senders indicate -// this by returning a wrapped io.EOF from Send. Clients should check for this -// condition with the standard library's errors.Is and call the receiver's -// Receive method to unmarshal the error. -type Sender interface { - Send(any) error - Close(error) error - +// StreamingHandlerConn implementations do not need to be safe for concurrent use. +type StreamingHandlerConn interface { Spec() Spec - Header() http.Header - Trailer() (http.Header, bool) + + Receive(any) error + RequestHeader() http.Header + + Send(any) error + ResponseHeader() http.Header + ResponseTrailer() http.Header } -// Receiver is the readable side of a bidirectional stream of messages. -// Receiver implementations do not need to be safe for concurrent use. +// StreamingClientConn is the client's view of a bidirectional message exchange. +// Interceptors for streaming RPCs may wrap StreamingClientConns. // -// Receiver implementations provided by this module guarantee that all returned -// errors can be cast to *Error using errors.As. +// StreamingClientConns write request headers to the network with the first +// call to Send. Any subsequent mutations are effectively no-ops. When the +// server is done sending data, the StreamingClientConn's Receive method +// returns an error wrapping io.EOF. Clients should check for this using the +// standard library's errors.Is. If the server encounters an error during +// processing, subsequent calls to the StreamingClientConn's Send method will +// return an error wrapping io.EOF; clients may then call Receive to unmarshal +// the error. // -// Only client-side Receivers may read trailers. -type Receiver interface { - Receive(any) error - Close() error - +// StreamingClientConn implementations provided by this module guarantee that +// all returned errors can be cast to *Error using the standard library's +// errors.As. +// +// In order to support bidirectional streaming RPCs, all StreamingClientConn +// implementations must support limited concurrent use. See the comments on +// each group of methods for details. +type StreamingClientConn interface { + // Spec must be safe to call concurrently with all other methods. Spec() Spec - Header() http.Header - // Trailers are populated only after Receive returns an error. - Trailer() (http.Header, bool) + + // Send, RequestHeader, and CloseRequest may race with each other, but must + // be safe to call concurrently with all other methods. + Send(any) error + RequestHeader() http.Header + CloseRequest() error + + // Receive, ResponseHeader, ResponseTrailer, and CloseResponse may race with + // each other, but must be safe to call concurrently with all other methods. + Receive(any) error + ResponseHeader() http.Header + ResponseTrailer() http.Header + CloseResponse() error } // Request is a wrapper around a generated request message. It provides @@ -225,28 +241,34 @@ type Spec struct { IsClient bool // otherwise we're in a handler } -// receiveUnaryResponse unmarshals a message from a Receiver, then envelopes -// the message and attaches the Receiver's headers and trailers. It attempts to -// consume the Receiver and isn't appropriate when receiving multiple messages. -func receiveUnaryResponse[T any](receiver Receiver) (*Response[T], error) { +// handlerConnCloser extends HandlerConn with a method for handlers to +// terminate the message exchange (and optionally send an error to the client). +type handlerConnCloser interface { + StreamingHandlerConn + + Close(error) error +} + +// receiveUnaryResponse unmarshals a message from a StreamingClientConn, then +// envelopes the message and attaches headers and trailers. It attempts to +// consume the response stream and isn't appropriate when receiving multiple +// messages. +func receiveUnaryResponse[T any](conn StreamingClientConn) (*Response[T], error) { var msg T - if err := receiver.Receive(&msg); err != nil { + if err := conn.Receive(&msg); err != nil { return nil, err } // In a well-formed stream, the response message may be followed by a block // of in-stream trailers or HTTP trailers. To ensure that we receive the // trailers, try to read another message from the stream. - if err := receiver.Receive(new(T)); err == nil { + if err := conn.Receive(new(T)); err == nil { return nil, NewError(CodeUnknown, errors.New("unary stream has multiple messages")) } else if err != nil && !errors.Is(err, io.EOF) { return nil, NewError(CodeUnknown, err) } - response := &Response[T]{ - Msg: &msg, - header: receiver.Header(), - } - if trailer, ok := receiver.Trailer(); ok { - response.trailer = trailer - } - return response, nil + return &Response[T]{ + Msg: &msg, + header: conn.ResponseHeader(), + trailer: conn.ResponseTrailer(), + }, nil } diff --git a/connect_ext_test.go b/connect_ext_test.go index 43016957..6a0708fb 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -203,7 +203,7 @@ func TestServer(t *testing.T) { err := stream.Send(&pingv1.CumSumRequest{Number: n}) assert.Nil(t, err, assert.Sprintf("send error #%d", i)) } - assert.Nil(t, stream.CloseSend()) + assert.Nil(t, stream.CloseRequest()) }() go func() { defer wg.Done() @@ -215,7 +215,7 @@ func TestServer(t *testing.T) { assert.Nil(t, err) got = append(got, msg.Sum) } - assert.Nil(t, stream.CloseReceive()) + assert.Nil(t, stream.CloseResponse()) }() wg.Wait() assert.Equal(t, got, expect) @@ -246,11 +246,11 @@ func TestServer(t *testing.T) { } // Deliberately closing with calling Send to test the behavior of Receive. // This test case is based on the grpc interop tests. - assert.Nil(t, stream.CloseSend()) + assert.Nil(t, stream.CloseRequest()) response, err := stream.Receive() assert.Nil(t, response) assert.True(t, errors.Is(err, io.EOF)) - assert.Nil(t, stream.CloseReceive()) // clean-up the stream + assert.Nil(t, stream.CloseResponse()) // clean-up the stream }) t.Run("cumsum_cancel_after_first_response", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -510,6 +510,13 @@ func TestGRPCMissingTrailersError(t *testing.T) { ) } + assertNilOrEOF := func(t *testing.T, err error) { + t.Helper() + if err != nil { + assert.ErrorIs(t, err, io.EOF) + } + } + t.Run("ping", func(t *testing.T) { t.Parallel() request := connect.NewRequest(&pingv1.PingRequest{Number: 1, Text: "foobar"}) @@ -520,7 +527,7 @@ func TestGRPCMissingTrailersError(t *testing.T) { t.Parallel() stream := client.Sum(context.Background()) err := stream.Send(&pingv1.SumRequest{Number: 1}) - assert.Nil(t, err) + assertNilOrEOF(t, err) _, err = stream.CloseAndReceive() assertErrorNoTrailers(t, err) }) @@ -534,20 +541,19 @@ func TestGRPCMissingTrailersError(t *testing.T) { t.Run("cumsum", func(t *testing.T) { t.Parallel() stream := client.CumSum(context.Background()) - err := stream.Send(&pingv1.CumSumRequest{Number: 10}) - assert.Nil(t, err) - _, err = stream.Receive() + assertNilOrEOF(t, stream.Send(&pingv1.CumSumRequest{Number: 10})) + _, err := stream.Receive() assertErrorNoTrailers(t, err) - assert.Nil(t, stream.CloseReceive()) + assert.Nil(t, stream.CloseResponse()) }) t.Run("cumsum_empty_stream", func(t *testing.T) { t.Parallel() stream := client.CumSum(context.Background()) - assert.Nil(t, stream.CloseSend()) + assert.Nil(t, stream.CloseRequest()) response, err := stream.Receive() assert.Nil(t, response) assertErrorNoTrailers(t, err) - assert.Nil(t, stream.CloseReceive()) + assert.Nil(t, stream.CloseResponse()) }) } @@ -579,7 +585,7 @@ func TestBidiRequiresHTTP2(t *testing.T) { ) stream := client.CumSum(context.Background()) assert.Nil(t, stream.Send(&pingv1.CumSumRequest{})) - assert.Nil(t, stream.CloseSend()) + assert.Nil(t, stream.CloseRequest()) _, err := stream.Receive() assert.NotNil(t, err) var connectErr *connect.Error @@ -949,7 +955,7 @@ func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSu assert.ErrorIs(tb, err, io.EOF) assert.Equal(tb, connect.CodeOf(err), connect.CodeUnknown) } - assert.Nil(tb, stream.CloseSend()) + assert.Nil(tb, stream.CloseRequest()) _, err := stream.Receive() assert.NotNil(tb, err) // should be 505 assert.True( @@ -957,7 +963,7 @@ func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSu strings.Contains(err.Error(), "HTTP status 505"), assert.Sprintf("expected 505, got %v", err), ) - assert.Nil(tb, stream.CloseReceive()) + assert.Nil(tb, stream.CloseResponse()) } func expectClientHeader(check bool, req connect.AnyRequest) error { diff --git a/handler.go b/handler.go index 6b48bf72..d16a3988 100644 --- a/handler.go +++ b/handler.go @@ -27,8 +27,7 @@ import ( // standard library's compress/gzip. type Handler struct { spec Spec - interceptor Interceptor - implementation func(context.Context, Sender, Receiver, error /* client-visible */) + implementation StreamingHandlerFunc protocolHandlers []protocolHandler acceptPost string // Accept-Post header } @@ -39,82 +38,44 @@ func NewUnaryHandler[Req, Res any]( unary func(context.Context, *Request[Req]) (*Response[Res], error), options ...HandlerOption, ) *Handler { + // Wrap the strongly-typed implementation so we can apply interceptors. + untyped := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + typed, ok := request.(*Request[Req]) + if !ok { + return nil, errorf(CodeInternal, "unexpected handler request type %T", request) + } + return unary(ctx, typed) + }) config := newHandlerConfig(procedure, options) - // Given a (possibly failed) stream, how should we call the unary function? - implementation := func(ctx context.Context, sender Sender, receiver Receiver, clientVisibleError error) { - defer receiver.Close() - - var request *Request[Req] - if clientVisibleError != nil { - // The protocol implementation failed to establish a stream. To make the - // resulting error visible to the interceptor stack, we still want to - // call the wrapped unary Func. To do that safely, we need a useful - // Message struct. (Note that we do *not* actually calling the handler's - // implementation.) - request = &Request[Req]{ - Msg: new(Req), - spec: receiver.Spec(), - header: receiver.Header(), - } - } else { - var msg Req - if err := receiver.Receive(&msg); err != nil { - // Interceptors should see this error too. Just as above, they need a - // useful Message. - clientVisibleError = err - request = &Request[Req]{ - Msg: new(Req), - spec: receiver.Spec(), - header: receiver.Header(), - } - } else { - request = &Request[Req]{ - Msg: &msg, - spec: receiver.Spec(), - header: receiver.Header(), - } - } + if interceptor := config.Interceptor; interceptor != nil { + untyped = interceptor.WrapUnary(untyped) + } + // Given a stream, how should we call the unary function? + implementation := func(ctx context.Context, conn StreamingHandlerConn) error { + var msg Req + if err := conn.Receive(&msg); err != nil { + return err } - - untyped := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { - if clientVisibleError != nil { - // We've already encountered an error, short-circuit before calling the - // handler's implementation. - return nil, clientVisibleError - } - if err := ctx.Err(); err != nil { - return nil, err - } - typed, ok := request.(*Request[Req]) - if !ok { - return nil, errorf(CodeInternal, "unexpected handler request type %T", request) - } - res, err := unary(ctx, typed) - if err != nil { - return nil, err - } - return res, nil - }) - if interceptor := config.Interceptor; interceptor != nil { - untyped = interceptor.WrapUnary(untyped) + request := &Request[Req]{ + Msg: &msg, + spec: conn.Spec(), + header: conn.RequestHeader(), } - response, err := untyped(ctx, request) if err != nil { - _ = sender.Close(err) - return - } - mergeHeaders(sender.Header(), response.Header()) - if trailers, ok := sender.Trailer(); ok { - mergeHeaders(trailers, response.Trailer()) + return err } - _ = sender.Close(sender.Send(response.Any())) + mergeHeaders(conn.ResponseHeader(), response.Header()) + mergeHeaders(conn.ResponseTrailer(), response.Trailer()) + return conn.Send(response.Any()) } protocolHandlers := config.newProtocolHandlers(StreamTypeUnary) return &Handler{ spec: config.newSpec(StreamTypeUnary), - interceptor: nil, // already applied implementation: implementation, protocolHandlers: protocolHandlers, acceptPost: sortedAcceptPostValue(protocolHandlers), @@ -130,23 +91,15 @@ func NewClientStreamHandler[Req, Res any]( return newStreamHandler( procedure, StreamTypeClient, - func(ctx context.Context, sender Sender, receiver Receiver) { - stream := &ClientStream[Req]{receiver: receiver} + func(ctx context.Context, conn StreamingHandlerConn) error { + stream := &ClientStream[Req]{conn: conn} res, err := implementation(ctx, stream) if err != nil { - _ = receiver.Close() - _ = sender.Close(err) - return + return err } - if err := receiver.Close(); err != nil { - _ = sender.Close(err) - return - } - mergeHeaders(sender.Header(), res.header) - if trailer, ok := sender.Trailer(); ok { - mergeHeaders(trailer, res.trailer) - } - _ = sender.Close(sender.Send(res.Msg)) + mergeHeaders(conn.ResponseHeader(), res.header) + mergeHeaders(conn.ResponseTrailer(), res.trailer) + return conn.Send(res.Msg) }, options..., ) @@ -161,25 +114,20 @@ func NewServerStreamHandler[Req, Res any]( return newStreamHandler( procedure, StreamTypeServer, - func(ctx context.Context, sender Sender, receiver Receiver) { - stream := &ServerStream[Res]{sender: sender} + func(ctx context.Context, conn StreamingHandlerConn) error { var msg Req - if err := receiver.Receive(&msg); err != nil { - _ = receiver.Close() - _ = sender.Close(err) - return - } - if err := receiver.Close(); err != nil { - _ = sender.Close(err) - return - } - request := &Request[Req]{ - Msg: &msg, - spec: receiver.Spec(), - header: receiver.Header(), + if err := conn.Receive(&msg); err != nil { + return err } - err := implementation(ctx, request, stream) - _ = sender.Close(err) + return implementation( + ctx, + &Request[Req]{ + Msg: &msg, + spec: conn.Spec(), + header: conn.RequestHeader(), + }, + &ServerStream[Res]{conn: conn}, + ) }, options..., ) @@ -194,11 +142,11 @@ func NewBidiStreamHandler[Req, Res any]( return newStreamHandler( procedure, StreamTypeBidi, - func(ctx context.Context, sender Sender, receiver Receiver) { - stream := &BidiStream[Req, Res]{sender: sender, receiver: receiver} - err := implementation(ctx, stream) - _ = receiver.Close() - _ = sender.Close(err) + func(ctx context.Context, conn StreamingHandlerConn) error { + return implementation( + ctx, + &BidiStream[Req, Res]{conn: conn}, + ) }, options..., ) @@ -223,6 +171,7 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re return } + // Find our implementation of the RPC protocol in use. contentType := request.Header.Get("Content-Type") var protocolHandler protocolHandler for _, handler := range h.protocolHandlers { @@ -236,6 +185,8 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re responseWriter.WriteHeader(http.StatusUnsupportedMediaType) return } + + // Establish a stream and serve the RPC. ctx, cancel, timeoutErr := protocolHandler.SetTimeout(request) if timeoutErr != nil { ctx = request.Context() @@ -243,27 +194,20 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re if cancel != nil { defer cancel() } - if ic := h.interceptor; ic != nil { - ctx = ic.WrapStreamContext(ctx) - } - // Most errors returned from protocolHandler.NewStream are caused by - // invalid requests. For example, the client may have specified an invalid - // timeout or an unavailable codec. We'd like those errors to be visible to - // the interceptor chain, so we're going to capture them here and pass them - // to the implementation. - sender, receiver, clientVisibleError := protocolHandler.NewStream( + connCloser, ok := protocolHandler.NewConn( responseWriter, request.WithContext(ctx), ) - if timeoutErr != nil { - clientVisibleError = timeoutErr + if !ok { + // Failed to create stream, usually because client used an unknown + // compression algorithm. Nothing further to do. + return } - if interceptor := h.interceptor; interceptor != nil { - // Unary interceptors were handled in NewUnaryHandler. - sender = interceptor.WrapStreamSender(ctx, sender) - receiver = interceptor.WrapStreamReceiver(ctx, receiver) + if timeoutErr != nil { + _ = connCloser.Close(timeoutErr) + return } - h.implementation(ctx, sender, receiver, clientVisibleError) + _ = connCloser.Close(h.implementation(ctx, connCloser)) } type handlerConfig struct { @@ -335,22 +279,17 @@ func (c *handlerConfig) newProtocolHandlers(streamType StreamType) []protocolHan func newStreamHandler( procedure string, streamType StreamType, - implementation func(context.Context, Sender, Receiver), + implementation StreamingHandlerFunc, options ...HandlerOption, ) *Handler { config := newHandlerConfig(procedure, options) + if ic := config.Interceptor; ic != nil { + implementation = ic.WrapStreamingHandler(implementation) + } protocolHandlers := config.newProtocolHandlers(streamType) return &Handler{ - spec: config.newSpec(streamType), - interceptor: config.Interceptor, - implementation: func(ctx context.Context, sender Sender, receiver Receiver, clientVisibleErr error) { - if clientVisibleErr != nil { - _ = receiver.Close() - _ = sender.Close(clientVisibleErr) - return - } - implementation(ctx, sender, receiver) - }, + spec: config.newSpec(streamType), + implementation: implementation, protocolHandlers: protocolHandlers, acceptPost: sortedAcceptPostValue(protocolHandlers), } diff --git a/handler_stream.go b/handler_stream.go index 517a4d9e..993ea4bb 100644 --- a/handler_stream.go +++ b/handler_stream.go @@ -25,14 +25,14 @@ import ( // It's constructed as part of Handler invocation, but doesn't currently have // an exported constructor. type ClientStream[Req any] struct { - receiver Receiver - msg Req - err error + conn StreamingHandlerConn + msg Req + err error } // RequestHeader returns the headers received from the client. func (c *ClientStream[Req]) RequestHeader() http.Header { - return c.receiver.Header() + return c.conn.RequestHeader() } // Receive advances the stream to the next message, which will then be @@ -44,7 +44,7 @@ func (c *ClientStream[Req]) Receive() bool { if c.err != nil { return false } - c.err = c.receiver.Receive(&c.msg) + c.err = c.conn.Receive(&c.msg) return c.err == nil } @@ -68,28 +68,25 @@ func (c *ClientStream[Req]) Err() error { // It's constructed as part of Handler invocation, but doesn't currently have // an exported constructor. type ServerStream[Res any] struct { - sender Sender + conn StreamingHandlerConn } // ResponseHeader returns the response headers. Headers are sent with the first // call to Send. func (s *ServerStream[Res]) ResponseHeader() http.Header { - return s.sender.Header() + return s.conn.ResponseHeader() } // ResponseTrailer returns the response trailers. Handlers may write to the // response trailers at any time before returning. func (s *ServerStream[Res]) ResponseTrailer() http.Header { - if trailers, ok := s.sender.Trailer(); ok { - return trailers - } - return make(http.Header) + return s.conn.ResponseTrailer() } // Send a message to the client. The first call to Send also sends the response // headers. func (s *ServerStream[Res]) Send(msg *Res) error { - return s.sender.Send(msg) + return s.conn.Send(msg) } // BidiStream is the handler's view of a bidirectional streaming RPC. @@ -97,20 +94,19 @@ func (s *ServerStream[Res]) Send(msg *Res) error { // It's constructed as part of Handler invocation, but doesn't currently have // an exported constructor. type BidiStream[Req, Res any] struct { - sender Sender - receiver Receiver + conn StreamingHandlerConn } // RequestHeader returns the headers received from the client. func (b *BidiStream[Req, Res]) RequestHeader() http.Header { - return b.receiver.Header() + return b.conn.RequestHeader() } // Receive a message. When the client is done sending messages, Receive will // return an error that wraps io.EOF. func (b *BidiStream[Req, Res]) Receive() (*Req, error) { var req Req - if err := b.receiver.Receive(&req); err != nil { + if err := b.conn.Receive(&req); err != nil { return nil, err } return &req, nil @@ -119,20 +115,17 @@ func (b *BidiStream[Req, Res]) Receive() (*Req, error) { // ResponseHeader returns the response headers. Headers are sent with the first // call to Send. func (b *BidiStream[Req, Res]) ResponseHeader() http.Header { - return b.sender.Header() + return b.conn.ResponseHeader() } // ResponseTrailer returns the response trailers. Handlers may write to the // response trailers at any time before returning. func (b *BidiStream[Req, Res]) ResponseTrailer() http.Header { - if trailers, ok := b.sender.Trailer(); ok { - return trailers - } - return make(http.Header) + return b.conn.ResponseTrailer() } // Send a message to the client. The first call to Send also sends the response // headers. func (b *BidiStream[Req, Res]) Send(msg *Res) error { - return b.sender.Send(msg) + return b.conn.Send(msg) } diff --git a/interceptor.go b/interceptor.go index 5dd0be45..d481b696 100644 --- a/interceptor.go +++ b/interceptor.go @@ -18,71 +18,50 @@ import ( "context" ) -// UnaryFunc is the generic signature of a unary RPC. Interceptors wrap Funcs. +// UnaryFunc is the generic signature of a unary RPC. Interceptors may wrap +// Funcs. // // The type of the request and response structs depend on the codec being used. // When using Protobuf, request.Any() and response.Any() will always be // proto.Message implementations. type UnaryFunc func(context.Context, AnyRequest) (AnyResponse, error) +// StreamingClientFunc is the generic signature of a streaming RPC from the client's +// perspective. Interceptors may wrap StreamingClientFuncs. +type StreamingClientFunc func(context.Context, Spec) StreamingClientConn + +// StreamingHandlerFunc is the generic signature of a streaming RPC from the +// handler's perspective. Interceptors may wrap StreamingHandlerFuncs. +type StreamingHandlerFunc func(context.Context, StreamingHandlerConn) error + // An Interceptor adds logic to a generated handler or client, like the // decorators or middleware you may have seen in other libraries. Interceptors -// may replace the context, mutate the request, mutate the response, handle the -// returned error, retry, recover from panics, emit logs and metrics, or do -// nearly anything else. +// may replace the context, mutate requests and responses, handle errors, +// retry, recover from panics, emit logs and metrics, or do nearly anything +// else. +// +// The returned functions must be safe to call concurrently. type Interceptor interface { - // WrapUnary adds logic to a unary procedure. The returned UnaryFunc must be safe - // to call concurrently. WrapUnary(UnaryFunc) UnaryFunc - - // WrapStreamContext, WrapStreamSender, and WrapStreamReceiver work together - // to add logic to streaming procedures. Stream interceptors work in phases. - // First, each interceptor may wrap the request context. Then, the connect - // runtime constructs a (Sender, Receiver) pair. Finally, each interceptor - // may wrap the Sender and/or Receiver. For example, the flow within a - // Handler looks like this: - // - // func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // ctx := r.Context() - // if ic := h.interceptor; ic != nil { - // ctx = ic.WrapStreamContext(ctx) - // } - // sender, receiver := h.newStream(w, r.WithContext(ctx)) - // if ic := h.interceptor; ic != nil { - // sender = ic.WrapStreamSender(ctx, sender) - // receiver = ic.WrapStreamReceiver(ctx, receiver) - // } - // h.serveStream(sender, receiver) - // } - // - // Sender and Receiver implementations don't need to be safe for concurrent - // use. - WrapStreamContext(context.Context) context.Context - WrapStreamSender(context.Context, Sender) Sender - WrapStreamReceiver(context.Context, Receiver) Receiver + WrapStreamingClient(StreamingClientFunc) StreamingClientFunc + WrapStreamingHandler(StreamingHandlerFunc) StreamingHandlerFunc } // UnaryInterceptorFunc is a simple Interceptor implementation that only -// wraps unary RPCs. It has no effect on client, server, or bidirectional -// streaming RPCs. +// wraps unary RPCs. It has no effect on streaming RPCs. type UnaryInterceptorFunc func(UnaryFunc) UnaryFunc // WrapUnary implements Interceptor by applying the interceptor function. func (f UnaryInterceptorFunc) WrapUnary(next UnaryFunc) UnaryFunc { return f(next) } -// WrapStreamContext implements Interceptor with a no-op. -func (f UnaryInterceptorFunc) WrapStreamContext(ctx context.Context) context.Context { - return ctx -} - -// WrapStreamSender implements Interceptor with a no-op. -func (f UnaryInterceptorFunc) WrapStreamSender(_ context.Context, sender Sender) Sender { - return sender +// WrapStreamingClient implements Interceptor with a no-op. +func (f UnaryInterceptorFunc) WrapStreamingClient(next StreamingClientFunc) StreamingClientFunc { + return next } -// WrapStreamReceiver implements Interceptor with a no-op. -func (f UnaryInterceptorFunc) WrapStreamReceiver(_ context.Context, receiver Receiver) Receiver { - return receiver +// WrapStreamingHandler implements Interceptor with a no-op. +func (f UnaryInterceptorFunc) WrapStreamingHandler(next StreamingHandlerFunc) StreamingHandlerFunc { + return next } // A chain composes multiple interceptors into one. @@ -90,8 +69,6 @@ type chain struct { interceptors []Interceptor } -var _ Interceptor = (*chain)(nil) - // newChain composes multiple interceptors into one. func newChain(interceptors []Interceptor) *chain { // We usually wrap in reverse order to have the first interceptor from @@ -113,33 +90,16 @@ func (c *chain) WrapUnary(next UnaryFunc) UnaryFunc { return next } -func (c *chain) WrapStreamContext(ctx context.Context) context.Context { +func (c *chain) WrapStreamingClient(next StreamingClientFunc) StreamingClientFunc { for _, interceptor := range c.interceptors { - ctx = interceptor.WrapStreamContext(ctx) + next = interceptor.WrapStreamingClient(next) } - return ctx -} - -func (c *chain) WrapStreamSender(ctx context.Context, sender Sender) Sender { - if sender.Spec().IsClient { - for _, interceptor := range c.interceptors { - sender = interceptor.WrapStreamSender(ctx, sender) - } - return sender - } - // When we're wrapping senders on the handler side, we need to wrap in the - // opposite order. See TestOnionOrderingEndToEnd. - for i := len(c.interceptors) - 1; i >= 0; i-- { - if interceptor := c.interceptors[i]; interceptor != nil { - sender = interceptor.WrapStreamSender(ctx, sender) - } - } - return sender + return next } -func (c *chain) WrapStreamReceiver(ctx context.Context, receiver Receiver) Receiver { +func (c *chain) WrapStreamingHandler(next StreamingHandlerFunc) StreamingHandlerFunc { for _, interceptor := range c.interceptors { - receiver = interceptor.WrapStreamReceiver(ctx, receiver) + next = interceptor.WrapStreamingHandler(next) } - return receiver + return next } diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index 62a3544c..81ab3d96 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -18,7 +18,6 @@ import ( "context" "net/http" "net/http/httptest" - "strings" "testing" "github.com/bufbuild/connect-go" @@ -27,68 +26,6 @@ import ( "github.com/bufbuild/connect-go/internal/gen/connect/ping/v1/pingv1connect" ) -func TestClientStreamErrors(t *testing.T) { - t.Parallel() - _, err := pingv1connect. - NewPingServiceClient(http.DefaultClient, "INVALID_URL"). - Ping(context.Background(), nil) - assert.NotNil(t, err) - assert.Match(t, err.Error(), "missing scheme") - // We don't even get to calling methods on the client, so there's no question - // of interceptors running. Once we're calling methods on the client, all - // errors are visible to interceptors. -} - -func TestHandlerStreamErrors(t *testing.T) { - t.Parallel() - // If we receive an HTTP request and send a response, interceptors should - // fire - even if we can't successfully set up a stream. (This is different - // from clients, where stream creation fails before any HTTP request is - // issued.) - var called bool - reset := func() { - called = false - } - mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler( - pingServer{}, - connect.WithInterceptors(&assertCalledInterceptor{&called}), - )) - server := httptest.NewServer(mux) - defer server.Close() - - t.Run("unary", func(t *testing.T) { // nolint:paralleltest - defer reset() - request, err := http.NewRequest( - http.MethodPost, - server.URL+"/connect.ping.v1.PingService/Ping", - strings.NewReader(""), - ) - assert.Nil(t, err) - request.Header.Set("Content-Type", "application/grpc+proto") - request.Header.Set("Grpc-Timeout", "INVALID") - res, err := server.Client().Do(request) - assert.Nil(t, err) - assert.Equal(t, res.StatusCode, http.StatusOK) - assert.True(t, called) - }) - t.Run("stream", func(t *testing.T) { // nolint:paralleltest - defer reset() - request, err := http.NewRequest( - http.MethodPost, - server.URL+"/connect.ping.v1.PingService/CountUp", - strings.NewReader(""), - ) - assert.Nil(t, err) - request.Header.Set("Content-Type", "application/grpc+proto") - request.Header.Set("Grpc-Timeout", "INVALID") - res, err := server.Client().Do(request) - assert.Nil(t, err) - assert.Equal(t, res.StatusCode, http.StatusOK) - assert.True(t, called) - }) -} - func TestOnionOrderingEndToEnd(t *testing.T) { t.Parallel() // Helper function: returns a function that asserts that there's some value @@ -188,10 +125,18 @@ func TestOnionOrderingEndToEnd(t *testing.T) { server.URL, clientOnion, ) + _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) assert.Nil(t, err) - _, err = client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) + + responses, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) assert.Nil(t, err) + var sum int64 + for responses.Receive() { + sum += responses.Msg().Number + } + assert.Equal(t, sum, 55) + assert.Nil(t, responses.Close()) } // headerInterceptor makes it easier to write interceptors that inspect or @@ -236,84 +181,63 @@ func (h *headerInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc } } -func (h *headerInterceptor) WrapStreamContext(ctx context.Context) context.Context { - return ctx -} - -// WrapStreamSender implements Interceptor. Depending on whether it's operating -// on a client or handler, it wraps the sender with the request- or -// response-inspecting function. -func (h *headerInterceptor) WrapStreamSender(ctx context.Context, sender connect.Sender) connect.Sender { - if sender.Spec().IsClient { - return &headerInspectingSender{Sender: sender, inspect: h.inspectRequestHeader} +func (h *headerInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { + return &headerInspectingClientConn{ + StreamingClientConn: next(ctx, spec), + inspectRequestHeader: h.inspectRequestHeader, + inspectResponseHeader: h.inspectResponseHeader, + } } - return &headerInspectingSender{Sender: sender, inspect: h.inspectResponseHeader} } -// WrapStreamReceiver implements Interceptor. Depending on whether it's -// operating on a client or handler, it wraps the sender with the response- or -// request-inspecting function. -func (h *headerInterceptor) WrapStreamReceiver(ctx context.Context, receiver connect.Receiver) connect.Receiver { - if receiver.Spec().IsClient { - return &headerInspectingReceiver{Receiver: receiver, inspect: h.inspectResponseHeader} +func (h *headerInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return func(ctx context.Context, conn connect.StreamingHandlerConn) error { + h.inspectRequestHeader(conn.Spec(), conn.RequestHeader()) + return next(ctx, &headerInspectingHandlerConn{ + StreamingHandlerConn: conn, + inspectResponseHeader: h.inspectResponseHeader, + }) } - return &headerInspectingReceiver{Receiver: receiver, inspect: h.inspectRequestHeader} } -type headerInspectingSender struct { - connect.Sender +type headerInspectingHandlerConn struct { + connect.StreamingHandlerConn - called bool // senders don't need to be thread-safe - inspect func(connect.Spec, http.Header) + inspectedResponse bool + inspectResponseHeader func(connect.Spec, http.Header) } -func (s *headerInspectingSender) Send(m any) error { - if !s.called { - s.inspect(s.Spec(), s.Header()) - s.called = true +func (hc *headerInspectingHandlerConn) Send(msg any) error { + if !hc.inspectedResponse { + hc.inspectResponseHeader(hc.Spec(), hc.ResponseHeader()) + hc.inspectedResponse = true } - return s.Sender.Send(m) + return hc.StreamingHandlerConn.Send(msg) } -type headerInspectingReceiver struct { - connect.Receiver +type headerInspectingClientConn struct { + connect.StreamingClientConn - called bool // receivers don't need to be thread-safe - inspect func(connect.Spec, http.Header) + inspectedRequest bool + inspectRequestHeader func(connect.Spec, http.Header) + inspectedResponse bool + inspectResponseHeader func(connect.Spec, http.Header) } -func (r *headerInspectingReceiver) Receive(m any) error { - if !r.called { - r.inspect(r.Spec(), r.Header()) - r.called = true - } - if err := r.Receiver.Receive(m); err != nil { - return err +func (cc *headerInspectingClientConn) Send(msg any) error { + if !cc.inspectedRequest { + cc.inspectRequestHeader(cc.Spec(), cc.RequestHeader()) + cc.inspectedRequest = true } - return nil -} - -type assertCalledInterceptor struct { - called *bool + return cc.StreamingClientConn.Send(msg) } -func (i *assertCalledInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { - return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { - *i.called = true - return next(ctx, req) +func (cc *headerInspectingClientConn) Receive(msg any) error { + err := cc.StreamingClientConn.Receive(msg) + if !cc.inspectedResponse { + cc.inspectResponseHeader(cc.Spec(), cc.ResponseHeader()) + cc.inspectedResponse = true } -} - -func (i *assertCalledInterceptor) WrapStreamContext(ctx context.Context) context.Context { - return ctx -} - -func (i *assertCalledInterceptor) WrapStreamSender(_ context.Context, sender connect.Sender) connect.Sender { - *i.called = true - return sender -} - -func (i *assertCalledInterceptor) WrapStreamReceiver(_ context.Context, receiver connect.Receiver) connect.Receiver { - *i.called = true - return receiver + return err } diff --git a/protocol.go b/protocol.go index 7bc58641..ce21fae2 100644 --- a/protocol.go +++ b/protocol.go @@ -83,19 +83,8 @@ type protocolHandler interface { // request's context, a nil cancellation function, and a nil error. SetTimeout(*http.Request) (context.Context, context.CancelFunc, error) - // NewStream constructs a Sender and Receiver for the message exchange. - // - // Implementations may decide whether the returned error should be sent to - // the client. (For example, it's helpful to send the client a list of - // supported compressors if they use an unknown compressor.) - // - // In either case, any returned error is passed through the full interceptor - // stack. - // - // TODO: Implementations _must_ return a usable Sender and Receiver, even if - // they're also returning an error. If we ever export this interface, we - // should revert https://github.com/bufbuild/connect-go/pull/290. - NewStream(http.ResponseWriter, *http.Request) (Sender, Receiver, error) + // NewConn constructs a HandlerConn for the message exchange. + NewConn(http.ResponseWriter, *http.Request) (handlerConnCloser, bool) } // ClientParams are the arguments provided to a Protocol's NewClient method, @@ -122,85 +111,84 @@ type protocolClient interface { // WriteRequestHeader writes any protocol-specific request headers. WriteRequestHeader(StreamType, http.Header) - // NewStream constructs a Sender and Receiver for the message exchange. + // NewConn constructs a StreamingClientConn for the message exchange. // // Implementations should assume that the supplied HTTP headers have already // been populated by WriteRequestHeader. When constructing a stream for a // unary call, implementations may assume that the Sender's Send and Close // methods return before the Receiver's Receive or Close methods are called. - NewStream(context.Context, Spec, http.Header) (Sender, Receiver) + NewConn(context.Context, Spec, http.Header) StreamingClientConn } -// errorTranslatingSender wraps a Sender to ensure that we always return coded -// errors to clients and write coded errors to the network. +// errorTranslatingHandlerConnCloser wraps a handlerConnCloser to ensure that +// we always return coded errors to users and write coded errors to the +// network. // -// This is used in protocol implementations. -type errorTranslatingSender struct { - Sender +// It's used in protocol implementations. +type errorTranslatingHandlerConnCloser struct { + handlerConnCloser toWire func(error) error fromWire func(error) error } -func (s *errorTranslatingSender) Send(msg any) error { - return s.fromWire(s.Sender.Send(msg)) +func (hc *errorTranslatingHandlerConnCloser) Send(msg any) error { + return hc.fromWire(hc.handlerConnCloser.Send(msg)) } -func (s *errorTranslatingSender) Close(err error) error { - sendErr := s.Sender.Close(s.toWire(err)) - return s.fromWire(sendErr) +func (hc *errorTranslatingHandlerConnCloser) Receive(msg any) error { + return hc.fromWire(hc.handlerConnCloser.Receive(msg)) } -// errorTranslatingReceiver wraps a Receiver to make sure that we always return -// coded errors from clients. +func (hc *errorTranslatingHandlerConnCloser) Close(err error) error { + closeErr := hc.handlerConnCloser.Close(hc.toWire(err)) + return hc.fromWire(closeErr) +} + +// errorTranslatingClientConn wraps a StreamingClientConn to make sure that we always +// return coded errors from clients. // -// This is used in protocol implementations. -type errorTranslatingReceiver struct { - Receiver +// It's used in protocol implementations. +type errorTranslatingClientConn struct { + StreamingClientConn fromWire func(error) error } -func (r *errorTranslatingReceiver) Receive(msg any) error { - if err := r.Receiver.Receive(msg); err != nil { - return r.fromWire(err) - } - return nil +func (cc *errorTranslatingClientConn) Send(msg any) error { + return cc.fromWire(cc.StreamingClientConn.Send(msg)) } -func (r *errorTranslatingReceiver) Close() error { - return r.fromWire(r.Receiver.Close()) +func (cc *errorTranslatingClientConn) Receive(msg any) error { + return cc.fromWire(cc.StreamingClientConn.Receive(msg)) } -// wrapHandlerStreamWithCodedErrors ensures that we (1) automatically code +func (cc *errorTranslatingClientConn) CloseRequest() error { + return cc.fromWire(cc.StreamingClientConn.CloseRequest()) +} + +func (cc *errorTranslatingClientConn) CloseResponse() error { + return cc.fromWire(cc.StreamingClientConn.CloseResponse()) +} + +// wrapHandlerConnWithCodedErrors ensures that we (1) automatically code // context-related errors correctly when writing them to the network, and (2) // return *Errors from all exported APIs. -func wrapHandlerStreamWithCodedErrors(sender Sender, receiver Receiver) (Sender, Receiver) { - wrappedSender := &errorTranslatingSender{ - Sender: sender, - toWire: wrapIfContextError, - fromWire: wrapIfUncoded, - } - wrappedReceiver := &errorTranslatingReceiver{ - Receiver: receiver, - fromWire: wrapIfUncoded, +func wrapHandlerConnWithCodedErrors(conn handlerConnCloser) handlerConnCloser { + return &errorTranslatingHandlerConnCloser{ + handlerConnCloser: conn, + toWire: wrapIfContextError, + fromWire: wrapIfUncoded, } - return wrappedSender, wrappedReceiver } -// wrapClientStreamWithCodedErrors ensures that we always return *Errors from +// wrapClientConnWithCodedErrors ensures that we always return *Errors from // public APIs. -func wrapClientStreamWithCodedErrors(sender Sender, receiver Receiver) (Sender, Receiver) { - wrappedSender := &errorTranslatingSender{ - Sender: sender, - toWire: func(err error) error { return err }, // no-op - fromWire: wrapIfUncoded, - } - wrappedReceiver := &errorTranslatingReceiver{ - Receiver: receiver, - fromWire: wrapIfUncoded, +func wrapClientConnWithCodedErrors(conn StreamingClientConn) StreamingClientConn { + return &errorTranslatingClientConn{ + StreamingClientConn: conn, + fromWire: wrapIfUncoded, } - return wrappedSender, wrappedReceiver } func sortedAcceptPostValue(handlers []protocolHandler) string { diff --git a/protocol_connect.go b/protocol_connect.go index 18267f4d..a6da61f0 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -101,10 +101,10 @@ func (*connectHandler) SetTimeout(request *http.Request) (context.Context, conte return ctx, cancel, nil } -func (h *connectHandler) NewStream( +func (h *connectHandler) NewConn( responseWriter http.ResponseWriter, request *http.Request, -) (Sender, Receiver, error) { +) (handlerConnCloser, bool) { // We need to parse metadata before entering the interceptor stack; we'll // send the error to the client later on. var contentEncoding, acceptEncoding string @@ -148,36 +148,36 @@ func (h *connectHandler) NewStream( request.Header.Get(headerContentType), ) codec := h.Codecs.Get(codecName) // handler.go guarantees this is not nil - var sender Sender = &connectUnaryHandlerSender{ - spec: h.Spec, - responseWriter: responseWriter, - trailer: make(http.Header), - marshaler: connectUnaryMarshaler{ - writer: responseWriter, - codec: codec, - compressMinBytes: h.CompressMinBytes, - compressionName: responseCompression, - compressionPool: h.CompressionPools.Get(responseCompression), - bufferPool: h.BufferPool, - header: responseWriter.Header(), - }, - } - var receiver Receiver = &connectUnaryHandlerReceiver{ - spec: h.Spec, - request: request, - unmarshaler: connectUnaryUnmarshaler{ - reader: request.Body, - codec: codec, - compressionPool: h.CompressionPools.Get(requestCompression), - bufferPool: h.BufferPool, - readMaxBytes: h.ReadMaxBytes, - }, - } - if h.Spec.StreamType != StreamTypeUnary { - sender = &connectStreamingHandlerSender{ - spec: h.Spec, - writer: responseWriter, - trailer: make(http.Header), + + var conn handlerConnCloser + if h.Spec.StreamType == StreamTypeUnary { + conn = &connectUnaryHandlerConn{ + spec: h.Spec, + request: request, + responseWriter: responseWriter, + marshaler: connectUnaryMarshaler{ + writer: responseWriter, + codec: codec, + compressMinBytes: h.CompressMinBytes, + compressionName: responseCompression, + compressionPool: h.CompressionPools.Get(responseCompression), + bufferPool: h.BufferPool, + header: responseWriter.Header(), + }, + unmarshaler: connectUnaryUnmarshaler{ + reader: request.Body, + codec: codec, + compressionPool: h.CompressionPools.Get(requestCompression), + bufferPool: h.BufferPool, + readMaxBytes: h.ReadMaxBytes, + }, + responseTrailer: make(http.Header), + } + } else { + conn = &connectStreamingHandlerConn{ + spec: h.Spec, + request: request, + responseWriter: responseWriter, marshaler: connectStreamingMarshaler{ envelopeWriter: envelopeWriter{ writer: responseWriter, @@ -187,10 +187,6 @@ func (h *connectHandler) NewStream( bufferPool: h.BufferPool, }, }, - } - receiver = &connectStreamingHandlerReceiver{ - spec: h.Spec, - request: request, unmarshaler: connectStreamingUnmarshaler{ envelopeReader: envelopeReader{ reader: request.Body, @@ -200,19 +196,18 @@ func (h *connectHandler) NewStream( readMaxBytes: h.ReadMaxBytes, }, }, + responseTrailer: make(http.Header), } } - sender, receiver = wrapHandlerStreamWithCodedErrors(sender, receiver) + conn = wrapHandlerConnWithCodedErrors(conn) // We can't return failed as-is: a nil *Error is non-nil when returned as an // error interface. if failed != nil { - // Negotiation failed, so we can't establish a stream. To make the - // request's HTTP trailers visible to interceptors, we should try to read - // the body to EOF. - _ = discard(request.Body) - return sender, receiver, failed + // Negotiation failed, so we can't establish a stream. + _ = conn.Close(failed) + return nil, false } - return sender, receiver, nil + return conn, true } type connectClient struct { @@ -246,11 +241,11 @@ func (c *connectClient) WriteRequestHeader(streamType StreamType, header http.He } } -func (c *connectClient) NewStream( +func (c *connectClient) NewConn( ctx context.Context, spec Spec, header http.Header, -) (Sender, Receiver) { +) StreamingClientConn { if deadline, ok := ctx.Deadline(); ok { millis := int64(time.Until(deadline) / time.Millisecond) if millis > 0 { @@ -261,13 +256,14 @@ func (c *connectClient) NewStream( } } duplexCall := newDuplexHTTPCall(ctx, c.HTTPClient, c.URL, spec, header) - var sender Sender - var receiver Receiver + var conn StreamingClientConn if spec.StreamType == StreamTypeUnary { - unarySender := &connectClientSender{ - spec: spec, - duplexCall: duplexCall, - marshaler: &connectUnaryMarshaler{ + unaryConn := &connectUnaryClientConn{ + spec: spec, + duplexCall: duplexCall, + compressionPools: c.CompressionPools, + bufferPool: c.BufferPool, + marshaler: connectUnaryMarshaler{ writer: duplexCall, codec: c.Codec, compressMinBytes: c.CompressMinBytes, @@ -276,29 +272,25 @@ func (c *connectClient) NewStream( bufferPool: c.BufferPool, header: duplexCall.Header(), }, - } - sender = unarySender - unaryReceiver := &connectUnaryClientReceiver{ - spec: spec, - duplexCall: duplexCall, - compressionPools: c.CompressionPools, - bufferPool: c.BufferPool, - header: make(http.Header), - trailer: make(http.Header), unmarshaler: connectUnaryUnmarshaler{ reader: duplexCall, codec: c.Codec, bufferPool: c.BufferPool, readMaxBytes: c.ReadMaxBytes, }, + responseHeader: make(http.Header), + responseTrailer: make(http.Header), } - receiver = unaryReceiver - duplexCall.SetValidateResponse(unaryReceiver.validateResponse) + conn = unaryConn + duplexCall.SetValidateResponse(unaryConn.validateResponse) } else { - streamingSender := &connectClientSender{ - spec: spec, - duplexCall: duplexCall, - marshaler: &connectStreamingMarshaler{ + streamingConn := &connectStreamingClientConn{ + spec: spec, + duplexCall: duplexCall, + compressionPools: c.CompressionPools, + bufferPool: c.BufferPool, + codec: c.Codec, + marshaler: connectStreamingMarshaler{ envelopeWriter: envelopeWriter{ writer: duplexCall, codec: c.Codec, @@ -307,16 +299,6 @@ func (c *connectClient) NewStream( bufferPool: c.BufferPool, }, }, - } - sender = streamingSender - streamingReceiver := &connectStreamingClientReceiver{ - spec: spec, - bufferPool: c.BufferPool, - compressionPools: c.CompressionPools, - codec: c.Codec, - header: make(http.Header), - trailer: make(http.Header), - duplexCall: duplexCall, unmarshaler: connectStreamingUnmarshaler{ envelopeReader: envelopeReader{ reader: duplexCall, @@ -325,317 +307,253 @@ func (c *connectClient) NewStream( readMaxBytes: c.ReadMaxBytes, }, }, + responseHeader: make(http.Header), + responseTrailer: make(http.Header), } - receiver = streamingReceiver - duplexCall.SetValidateResponse(streamingReceiver.validateResponse) + conn = streamingConn + duplexCall.SetValidateResponse(streamingConn.validateResponse) } - return wrapClientStreamWithCodedErrors(sender, receiver) + return wrapClientConnWithCodedErrors(conn) } -// connectClientSender works equally well for unary and streaming, since it can -// use either marshaler. -type connectClientSender struct { - spec Spec - duplexCall *duplexHTTPCall - marshaler interface{ Marshal(any) *Error } -} - -func (s *connectClientSender) Spec() Spec { - return s.spec -} - -func (s *connectClientSender) Header() http.Header { - return s.duplexCall.Header() +type connectUnaryClientConn struct { + spec Spec + duplexCall *duplexHTTPCall + compressionPools readOnlyCompressionPools + bufferPool *bufferPool + marshaler connectUnaryMarshaler + unmarshaler connectUnaryUnmarshaler + responseHeader http.Header + responseTrailer http.Header } -func (s *connectClientSender) Trailer() (http.Header, bool) { - return nil, false +func (cc *connectUnaryClientConn) Spec() Spec { + return cc.spec } -func (s *connectClientSender) Send(message any) error { - if err := s.marshaler.Marshal(message); err != nil { +func (cc *connectUnaryClientConn) Send(msg any) error { + if err := cc.marshaler.Marshal(msg); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error } -func (s *connectClientSender) Close(error) error { - return s.duplexCall.CloseWrite() +func (cc *connectUnaryClientConn) RequestHeader() http.Header { + return cc.duplexCall.Header() } -type connectStreamingClientReceiver struct { - spec Spec - bufferPool *bufferPool - compressionPools readOnlyCompressionPools - codec Codec - header http.Header - trailer http.Header - duplexCall *duplexHTTPCall - unmarshaler connectStreamingUnmarshaler +func (cc *connectUnaryClientConn) CloseRequest() error { + return cc.duplexCall.CloseWrite() } -func (r *connectStreamingClientReceiver) Spec() Spec { - return r.spec +func (cc *connectUnaryClientConn) Receive(msg any) error { + cc.duplexCall.BlockUntilResponseReady() + if err := cc.unmarshaler.Unmarshal(msg); err != nil { + return err + } + return nil // must be a literal nil: nil *Error is a non-nil error } -func (r *connectStreamingClientReceiver) Header() http.Header { - r.duplexCall.BlockUntilResponseReady() - return r.header +func (cc *connectUnaryClientConn) ResponseHeader() http.Header { + cc.duplexCall.BlockUntilResponseReady() + return cc.responseHeader } -func (r *connectStreamingClientReceiver) Trailer() (http.Header, bool) { - r.duplexCall.BlockUntilResponseReady() - return r.trailer, true +func (cc *connectUnaryClientConn) ResponseTrailer() http.Header { + cc.duplexCall.BlockUntilResponseReady() + return cc.responseTrailer } -func (r *connectStreamingClientReceiver) Receive(message any) error { - r.duplexCall.BlockUntilResponseReady() - err := r.unmarshaler.Unmarshal(message) - if err == nil { - return nil - } - // See if the server sent an explicit error in the end-of-stream message. - mergeHeaders(r.trailer, r.unmarshaler.Trailer()) - if serverErr := r.unmarshaler.EndStreamError(); serverErr != nil { - // This is expected from a protocol perspective, but receiving an - // end-of-stream message means that we're _not_ getting a regular message. - // For users to realize that the stream has ended, Receive must return an - // error. - serverErr.meta = r.header.Clone() - mergeHeaders(serverErr.meta, r.trailer) - r.duplexCall.SetError(serverErr) - return serverErr - } - // There's no error in the trailers, so this was probably an error - // converting the bytes to a message, an error reading from the network, or - // just an EOF. We're going to return it to the user, but we also want to - // setResponseError so Send errors out. - r.duplexCall.SetError(err) - return err -} - -func (r *connectStreamingClientReceiver) Close() error { - return r.duplexCall.CloseRead() +func (cc *connectUnaryClientConn) CloseResponse() error { + return cc.duplexCall.CloseRead() } -// validateResponse is called by duplexHTTPCall in a separate goroutine. -func (r *connectStreamingClientReceiver) validateResponse(response *http.Response) *Error { - if response.StatusCode != http.StatusOK { - return errorf(connectHTTPToCode(response.StatusCode), "HTTP status %v", response.Status) +func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Error { + for k, v := range response.Header { + if !strings.HasPrefix(k, connectUnaryTrailerPrefix) { + cc.responseHeader[k] = v + continue + } + cc.responseTrailer[strings.TrimPrefix(k, connectUnaryTrailerPrefix)] = v } - compression := response.Header.Get(connectStreamingHeaderCompression) + compression := response.Header.Get(connectUnaryHeaderCompression) if compression != "" && compression != compressionIdentity && - !r.compressionPools.Contains(compression) { + !cc.compressionPools.Contains(compression) { return errorf( CodeInternal, "unknown encoding %q: accepted encodings are %v", compression, - r.compressionPools.CommaSeparatedNames(), + cc.compressionPools.CommaSeparatedNames(), + ) + } + if response.StatusCode != http.StatusOK { + unmarshaler := connectUnaryUnmarshaler{ + reader: response.Body, + compressionPool: cc.compressionPools.Get(compression), + bufferPool: cc.bufferPool, + } + var serverErr Error + if err := unmarshaler.UnmarshalFunc( + (*connectWireError)(&serverErr), + json.Unmarshal, + ); err == nil { + serverErr.meta = cc.responseHeader.Clone() + mergeHeaders(serverErr.meta, cc.responseTrailer) + return &serverErr + } + return NewError( + connectHTTPToCode(response.StatusCode), + errors.New(response.Status), ) } - r.unmarshaler.compressionPool = r.compressionPools.Get(compression) - mergeHeaders(r.header, response.Header) + cc.unmarshaler.compressionPool = cc.compressionPools.Get(compression) return nil } -type connectStreamingHandlerSender struct { - spec Spec - marshaler connectStreamingMarshaler - writer http.ResponseWriter - trailer http.Header +type connectStreamingClientConn struct { + spec Spec + duplexCall *duplexHTTPCall + compressionPools readOnlyCompressionPools + bufferPool *bufferPool + codec Codec + marshaler connectStreamingMarshaler + unmarshaler connectStreamingUnmarshaler + responseHeader http.Header + responseTrailer http.Header } -func (s *connectStreamingHandlerSender) Send(message any) error { - defer flushResponseWriter(s.writer) - if err := s.marshaler.Marshal(message); err != nil { - return err - } - return nil // must be a literal nil: nil *Error is a non-nil error +func (cc *connectStreamingClientConn) Spec() Spec { + return cc.spec } -func (s *connectStreamingHandlerSender) Close(err error) error { - defer flushResponseWriter(s.writer) - if err := s.marshaler.MarshalEndStream(err, s.trailer); err != nil { +func (cc *connectStreamingClientConn) Send(msg any) error { + if err := cc.marshaler.Marshal(msg); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error } -func (s *connectStreamingHandlerSender) Spec() Spec { - return s.spec +func (cc *connectStreamingClientConn) RequestHeader() http.Header { + return cc.duplexCall.Header() } -func (s *connectStreamingHandlerSender) Header() http.Header { - return s.writer.Header() +func (cc *connectStreamingClientConn) CloseRequest() error { + return cc.duplexCall.CloseWrite() } -func (s *connectStreamingHandlerSender) Trailer() (http.Header, bool) { - return s.trailer, true -} - -type connectStreamingHandlerReceiver struct { - spec Spec - unmarshaler connectStreamingUnmarshaler - request *http.Request -} - -func (r *connectStreamingHandlerReceiver) Receive(message any) error { - if err := r.unmarshaler.Unmarshal(message); err != nil { - // Clients may not send end-of-stream metadata, so we don't need to handle - // errSpecialEnvelope. - return err +func (cc *connectStreamingClientConn) Receive(msg any) error { + cc.duplexCall.BlockUntilResponseReady() + err := cc.unmarshaler.Unmarshal(msg) + if err == nil { + return nil } - return nil // must be a literal nil: nil *Error is a non-nil error -} - -func (r *connectStreamingHandlerReceiver) Close() error { - // We don't want to copy unread portions of the body to /dev/null here: if - // the client hasn't closed the request body, we'll block until the server - // timeout kicks in. This could happen because the client is malicious, but - // a well-intentioned client may just not expect the server to be returning - // an error for a streaming RPC. Better to accept that we can't always reuse - // TCP connections. - if err := r.request.Body.Close(); err != nil { - if connectErr, ok := asError(err); ok { - return connectErr - } - return NewError(CodeUnknown, err) + // See if the server sent an explicit error in the end-of-stream message. + mergeHeaders(cc.responseTrailer, cc.unmarshaler.Trailer()) + if serverErr := cc.unmarshaler.EndStreamError(); serverErr != nil { + // This is expected from a protocol perspective, but receiving an + // end-of-stream message means that we're _not_ getting a regular message. + // For users to realize that the stream has ended, Receive must return an + // error. + serverErr.meta = cc.responseHeader.Clone() + mergeHeaders(serverErr.meta, cc.responseTrailer) + cc.duplexCall.SetError(serverErr) + return serverErr } - return nil -} - -func (r *connectStreamingHandlerReceiver) Spec() Spec { - return r.spec -} - -func (r *connectStreamingHandlerReceiver) Header() http.Header { - return r.request.Header -} - -func (r *connectStreamingHandlerReceiver) Trailer() (http.Header, bool) { - return nil, false -} - -type connectUnaryClientReceiver struct { - spec Spec - duplexCall *duplexHTTPCall - compressionPools readOnlyCompressionPools - bufferPool *bufferPool - - header http.Header - trailer http.Header - unmarshaler connectUnaryUnmarshaler -} - -func (r *connectUnaryClientReceiver) Spec() Spec { - return r.spec -} - -func (r *connectUnaryClientReceiver) Header() http.Header { - r.duplexCall.BlockUntilResponseReady() - return r.header + // There's no error in the trailers, so this was probably an error + // converting the bytes to a message, an error reading from the network, or + // just an EOF. We're going to return it to the user, but we also want to + // setResponseError so Send errors out. + cc.duplexCall.SetError(err) + return err } -func (r *connectUnaryClientReceiver) Trailer() (http.Header, bool) { - r.duplexCall.BlockUntilResponseReady() - return r.trailer, true +func (cc *connectStreamingClientConn) ResponseHeader() http.Header { + cc.duplexCall.BlockUntilResponseReady() + return cc.responseHeader } -func (r *connectUnaryClientReceiver) Receive(message any) error { - r.duplexCall.BlockUntilResponseReady() - if err := r.unmarshaler.Unmarshal(message); err != nil { - return err - } - return nil +func (cc *connectStreamingClientConn) ResponseTrailer() http.Header { + cc.duplexCall.BlockUntilResponseReady() + return cc.responseTrailer } -func (r *connectUnaryClientReceiver) Close() error { - return r.duplexCall.CloseRead() +func (cc *connectStreamingClientConn) CloseResponse() error { + return cc.duplexCall.CloseRead() } -func (r *connectUnaryClientReceiver) validateResponse(response *http.Response) *Error { - for k, v := range response.Header { - if !strings.HasPrefix(k, connectUnaryTrailerPrefix) { - r.header[k] = v - continue - } - r.trailer[strings.TrimPrefix(k, connectUnaryTrailerPrefix)] = v +func (cc *connectStreamingClientConn) validateResponse(response *http.Response) *Error { + if response.StatusCode != http.StatusOK { + return errorf(connectHTTPToCode(response.StatusCode), "HTTP status %v", response.Status) } - compression := response.Header.Get(connectUnaryHeaderCompression) + compression := response.Header.Get(connectStreamingHeaderCompression) if compression != "" && compression != compressionIdentity && - !r.compressionPools.Contains(compression) { + !cc.compressionPools.Contains(compression) { return errorf( CodeInternal, "unknown encoding %q: accepted encodings are %v", compression, - r.compressionPools.CommaSeparatedNames(), - ) - } - if response.StatusCode != http.StatusOK { - unmarshaler := connectUnaryUnmarshaler{ - reader: response.Body, - compressionPool: r.compressionPools.Get(compression), - bufferPool: r.bufferPool, - } - var serverErr Error - if err := unmarshaler.UnmarshalFunc( - (*connectWireError)(&serverErr), - json.Unmarshal, - ); err == nil { - serverErr.meta = r.header.Clone() - mergeHeaders(serverErr.meta, r.trailer) - return &serverErr - } - return NewError( - connectHTTPToCode(response.StatusCode), - errors.New(response.Status), + cc.compressionPools.CommaSeparatedNames(), ) } - r.unmarshaler.compressionPool = r.compressionPools.Get(compression) + cc.unmarshaler.compressionPool = cc.compressionPools.Get(compression) + mergeHeaders(cc.responseHeader, response.Header) return nil } -type connectUnaryHandlerSender struct { - spec Spec - responseWriter http.ResponseWriter - marshaler connectUnaryMarshaler - trailer http.Header - wroteBody bool +type connectUnaryHandlerConn struct { + spec Spec + request *http.Request + responseWriter http.ResponseWriter + marshaler connectUnaryMarshaler + unmarshaler connectUnaryUnmarshaler + responseTrailer http.Header + wroteBody bool } -func (s *connectUnaryHandlerSender) Spec() Spec { - return s.spec +func (hc *connectUnaryHandlerConn) Spec() Spec { + return hc.spec } -func (s *connectUnaryHandlerSender) Header() http.Header { - return s.responseWriter.Header() +func (hc *connectUnaryHandlerConn) Receive(msg any) error { + if err := hc.unmarshaler.Unmarshal(msg); err != nil { + return err + } + return nil // must be a literal nil: nil *Error is a non-nil error } -func (s *connectUnaryHandlerSender) Trailer() (http.Header, bool) { - return s.trailer, true +func (hc *connectUnaryHandlerConn) RequestHeader() http.Header { + return hc.request.Header } -func (s *connectUnaryHandlerSender) Send(message any) error { - s.wroteBody = true - s.writeHeader(nil /* error */) - if err := s.marshaler.Marshal(message); err != nil { +func (hc *connectUnaryHandlerConn) Send(msg any) error { + hc.wroteBody = true + hc.writeResponseHeader(nil /* error */) + if err := hc.marshaler.Marshal(msg); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error } -func (s *connectUnaryHandlerSender) Close(err error) error { - if !s.wroteBody { - s.writeHeader(err) +func (hc *connectUnaryHandlerConn) ResponseHeader() http.Header { + return hc.responseWriter.Header() +} + +func (hc *connectUnaryHandlerConn) ResponseTrailer() http.Header { + return hc.responseTrailer +} + +func (hc *connectUnaryHandlerConn) Close(err error) error { + if !hc.wroteBody { + hc.writeResponseHeader(err) } if err == nil { - return nil + return hc.request.Body.Close() } // In unary Connect, errors always use application/json. - s.responseWriter.Header().Set(headerContentType, connectUnaryContentTypeJSON) - s.responseWriter.WriteHeader(connectCodeToHTTP(CodeOf(err))) + hc.responseWriter.Header().Set(headerContentType, connectUnaryContentTypeJSON) + hc.responseWriter.WriteHeader(connectCodeToHTTP(CodeOf(err))) var wire *connectWireError if connectErr, ok := asError(err); ok { wire = (*connectWireError)(connectErr) @@ -644,51 +562,89 @@ func (s *connectUnaryHandlerSender) Close(err error) error { } data, marshalErr := json.Marshal(wire) if marshalErr != nil { + _ = hc.request.Body.Close() return errorf(CodeInternal, "marshal error: %w", err) } - _, writeErr := s.responseWriter.Write(data) - return writeErr + if _, writeErr := hc.responseWriter.Write(data); writeErr != nil { + _ = hc.request.Body.Close() + return writeErr + } + return hc.request.Body.Close() } -func (s *connectUnaryHandlerSender) writeHeader(err error) { - header := s.responseWriter.Header() +func (hc *connectUnaryHandlerConn) writeResponseHeader(err error) { + header := hc.responseWriter.Header() if err != nil { if connectErr, ok := asError(err); ok { mergeHeaders(header, connectErr.meta) } } - for k, v := range s.trailer { + for k, v := range hc.responseTrailer { header[connectUnaryTrailerPrefix+k] = v } } -type connectUnaryHandlerReceiver struct { - spec Spec - request *http.Request - unmarshaler connectUnaryUnmarshaler +type connectStreamingHandlerConn struct { + spec Spec + request *http.Request + responseWriter http.ResponseWriter + marshaler connectStreamingMarshaler + unmarshaler connectStreamingUnmarshaler + responseTrailer http.Header } -func (r *connectUnaryHandlerReceiver) Spec() Spec { - return r.spec +func (hc *connectStreamingHandlerConn) Spec() Spec { + return hc.spec } -func (r *connectUnaryHandlerReceiver) Header() http.Header { - return r.request.Header +func (hc *connectStreamingHandlerConn) Receive(msg any) error { + if err := hc.unmarshaler.Unmarshal(msg); err != nil { + // Clients may not send end-of-stream metadata, so we don't need to handle + // errSpecialEnvelope. + return err + } + return nil // must be a literal nil: nil *Error is a non-nil error } -func (r *connectUnaryHandlerReceiver) Trailer() (http.Header, bool) { - return nil, false +func (hc *connectStreamingHandlerConn) RequestHeader() http.Header { + return hc.request.Header } -func (r *connectUnaryHandlerReceiver) Receive(message any) error { - if err := r.unmarshaler.Unmarshal(message); err != nil { +func (hc *connectStreamingHandlerConn) Send(msg any) error { + defer flushResponseWriter(hc.responseWriter) + if err := hc.marshaler.Marshal(msg); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error } -func (r *connectUnaryHandlerReceiver) Close() error { - return r.request.Body.Close() +func (hc *connectStreamingHandlerConn) ResponseHeader() http.Header { + return hc.responseWriter.Header() +} + +func (hc *connectStreamingHandlerConn) ResponseTrailer() http.Header { + return hc.responseTrailer +} + +func (hc *connectStreamingHandlerConn) Close(err error) error { + defer flushResponseWriter(hc.responseWriter) + if err := hc.marshaler.MarshalEndStream(err, hc.responseTrailer); err != nil { + _ = hc.request.Body.Close() + return err + } + // We don't want to copy unread portions of the body to /dev/null here: if + // the client hasn't closed the request body, we'll block until the server + // timeout kicks in. This could happen because the client is malicious, but + // a well-intentioned client may just not expect the server to be returning + // an error for a streaming RPC. Better to accept that we can't always reuse + // TCP connections. + if err := hc.request.Body.Close(); err != nil { + if connectErr, ok := asError(err); ok { + return connectErr + } + return NewError(CodeUnknown, err) + } + return nil // must be a literal nil: nil *Error is a non-nil error } type connectStreamingMarshaler struct { diff --git a/protocol_grpc.go b/protocol_grpc.go index 3f138756..13f8a5c8 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -133,10 +133,10 @@ func (*grpcHandler) SetTimeout(request *http.Request) (context.Context, context. return ctx, cancel, nil } -func (g *grpcHandler) NewStream( +func (g *grpcHandler) NewConn( responseWriter http.ResponseWriter, request *http.Request, -) (Sender, Receiver, error) { +) (handlerConnCloser, bool) { // We need to parse metadata before entering the interceptor stack; we'll // send the error to the client later on. requestCompression, responseCompression, failed := negotiateCompression( @@ -160,27 +160,42 @@ func (g *grpcHandler) NewStream( } codecName := grpcCodecFromContentType(g.web, request.Header.Get(headerContentType)) - sender, receiver := wrapHandlerStreamWithCodedErrors(newGRPCHandlerStream( - g.Spec, - g.web, - responseWriter, - request, - g.CompressMinBytes, - g.Codecs.Get(codecName), // handler.go guarantees that this is not nil - g.Codecs.Protobuf(), // for errors - g.CompressionPools.Get(requestCompression), - g.CompressionPools.Get(responseCompression), - g.BufferPool, - g.ReadMaxBytes, - )) + codec := g.Codecs.Get(codecName) // handler.go guarantees this is not nil + conn := wrapHandlerConnWithCodedErrors(&grpcHandlerConn{ + spec: g.Spec, + web: g.web, + bufferPool: g.BufferPool, + protobuf: g.Codecs.Protobuf(), // for errors + marshaler: grpcMarshaler{ + envelopeWriter: envelopeWriter{ + writer: responseWriter, + compressionPool: g.CompressionPools.Get(responseCompression), + codec: codec, + compressMinBytes: g.CompressMinBytes, + bufferPool: g.BufferPool, + }, + }, + responseWriter: responseWriter, + responseHeader: make(http.Header), + responseTrailer: make(http.Header), + request: request, + unmarshaler: grpcUnmarshaler{ + envelopeReader: envelopeReader{ + reader: request.Body, + codec: codec, + compressionPool: g.CompressionPools.Get(requestCompression), + bufferPool: g.BufferPool, + readMaxBytes: g.ReadMaxBytes, + }, + web: g.web, + }, + }) if failed != nil { - // Negotiation failed, so we can't establish a stream. To make the - // request's HTTP trailers visible to interceptors, we should try to read - // the body to EOF. - _ = discard(request.Body) - return sender, receiver, failed + // Negotiation failed, so we can't establish a stream. + _ = conn.Close(failed) + return nil, false } - return sender, receiver, nil + return conn, true } type grpcClient struct { @@ -211,11 +226,11 @@ func (g *grpcClient) WriteRequestHeader(_ StreamType, header http.Header) { } } -func (g *grpcClient) NewStream( +func (g *grpcClient) NewConn( ctx context.Context, spec Spec, header http.Header, -) (Sender, Receiver) { +) StreamingClientConn { if deadline, ok := ctx.Deadline(); ok { if encodedDeadline, err := grpcEncodeTimeout(time.Until(deadline)); err == nil { // Tests verify that the error in encodeTimeout is unreachable, so we @@ -230,9 +245,12 @@ func (g *grpcClient) NewStream( spec, header, ) - sender := &grpcClientSender{ - spec: spec, - duplexCall: duplexCall, + conn := &grpcClientConn{ + spec: spec, + duplexCall: duplexCall, + compressionPools: g.CompressionPools, + bufferPool: g.BufferPool, + protobuf: g.Protobuf, marshaler: grpcMarshaler{ envelopeWriter: envelopeWriter{ writer: duplexCall, @@ -242,135 +260,84 @@ func (g *grpcClient) NewStream( bufferPool: g.BufferPool, }, }, + unmarshaler: grpcUnmarshaler{ + envelopeReader: envelopeReader{ + reader: duplexCall, + codec: g.Codec, + bufferPool: g.BufferPool, + readMaxBytes: g.ReadMaxBytes, + }, + }, + responseHeader: make(http.Header), + responseTrailer: make(http.Header), } - var receiver Receiver + duplexCall.SetValidateResponse(conn.validateResponse) if g.web { - webReceiver := &grpcClientReceiver{ - spec: spec, - bufferPool: g.BufferPool, - compressionPools: g.CompressionPools, - protobuf: g.Protobuf, - header: make(http.Header), - trailer: make(http.Header), - duplexCall: duplexCall, - unmarshaler: grpcUnmarshaler{ - web: true, - envelopeReader: envelopeReader{ - reader: duplexCall, - codec: g.Codec, - bufferPool: g.BufferPool, - readMaxBytes: g.ReadMaxBytes, - }, - }, - readTrailers: func(unmarshaler *grpcUnmarshaler, _ *duplexHTTPCall) http.Header { - return unmarshaler.WebTrailer() - }, + conn.unmarshaler.web = true + conn.readTrailers = func(unmarshaler *grpcUnmarshaler, _ *duplexHTTPCall) http.Header { + return unmarshaler.WebTrailer() } - receiver = webReceiver - duplexCall.SetValidateResponse(webReceiver.validateResponse) } else { - grpcReceiver := &grpcClientReceiver{ - spec: spec, - bufferPool: g.BufferPool, - compressionPools: g.CompressionPools, - protobuf: g.Protobuf, - header: make(http.Header), - trailer: make(http.Header), - duplexCall: duplexCall, - unmarshaler: grpcUnmarshaler{ - web: false, - envelopeReader: envelopeReader{ - reader: duplexCall, - codec: g.Codec, - bufferPool: g.BufferPool, - readMaxBytes: g.ReadMaxBytes, - }, - }, - readTrailers: func(_ *grpcUnmarshaler, call *duplexHTTPCall) http.Header { - // To access HTTP trailers, we need to read the body to EOF. - _ = discard(call) - return call.ResponseTrailer() - }, + conn.readTrailers = func(_ *grpcUnmarshaler, call *duplexHTTPCall) http.Header { + // To access HTTP trailers, we need to read the body to EOF. + _ = discard(call) + return call.ResponseTrailer() } - receiver = grpcReceiver - duplexCall.SetValidateResponse(grpcReceiver.validateResponse) } - return wrapClientStreamWithCodedErrors(sender, receiver) -} - -// grpcClientSender works for both gRPC and gRPC-Web. From our perspective, the -// protocols differ only in how trailers are sent, and clients aren't allowed -// to send trailers. -type grpcClientSender struct { - spec Spec - duplexCall *duplexHTTPCall - marshaler grpcMarshaler -} - -func (s *grpcClientSender) Spec() Spec { - return s.spec + return wrapClientConnWithCodedErrors(conn) } -func (s *grpcClientSender) Header() http.Header { - return s.duplexCall.Header() -} - -func (s *grpcClientSender) Trailer() (http.Header, bool) { - return nil, false -} - -func (s *grpcClientSender) Send(message any) error { - if err := s.marshaler.Marshal(message); err != nil { - return err - } - return nil // must be a literal nil: nil *Error is a non-nil error -} - -func (s *grpcClientSender) Close(_ error) error { - return s.duplexCall.CloseWrite() -} - -type grpcClientReceiver struct { +// grpcClientConn works for both gRPC and gRPC-Web. +type grpcClientConn struct { spec Spec + duplexCall *duplexHTTPCall compressionPools readOnlyCompressionPools bufferPool *bufferPool protobuf Codec // for errors - header http.Header - trailer http.Header - duplexCall *duplexHTTPCall + marshaler grpcMarshaler unmarshaler grpcUnmarshaler + responseHeader http.Header + responseTrailer http.Header readTrailers func(*grpcUnmarshaler, *duplexHTTPCall) http.Header } -func (r *grpcClientReceiver) Spec() Spec { - return r.spec +func (cc *grpcClientConn) Spec() Spec { + return cc.spec } -func (r *grpcClientReceiver) Header() http.Header { - r.duplexCall.BlockUntilResponseReady() - return r.header +func (cc *grpcClientConn) Send(msg any) error { + if err := cc.marshaler.Marshal(msg); err != nil { + return err + } + return nil // must be a literal nil: nil *Error is a non-nil error } -func (r *grpcClientReceiver) Trailer() (http.Header, bool) { - r.duplexCall.BlockUntilResponseReady() - return r.trailer, true +func (cc *grpcClientConn) RequestHeader() http.Header { + return cc.duplexCall.Header() } -func (r *grpcClientReceiver) Receive(message any) error { - r.duplexCall.BlockUntilResponseReady() - err := r.unmarshaler.Unmarshal(message) +func (cc *grpcClientConn) CloseRequest() error { + return cc.duplexCall.CloseWrite() +} + +func (cc *grpcClientConn) Receive(msg any) error { + cc.duplexCall.BlockUntilResponseReady() + err := cc.unmarshaler.Unmarshal(msg) if err == nil { return nil } - if r.header.Get(grpcHeaderStatus) != "" { + if cc.responseHeader.Get(grpcHeaderStatus) != "" { // We got what gRPC calls a trailers-only response, which puts the trailing // metadata (including errors) into HTTP headers. validateResponse has // already extracted the error. return err } // See if the server sent an explicit error in the HTTP or gRPC-Web trailers. - mergeHeaders(r.trailer, r.readTrailers(&r.unmarshaler, r.duplexCall)) - serverErr := grpcErrorFromTrailer(r.bufferPool, r.protobuf, r.trailer) + mergeHeaders( + cc.responseTrailer, + cc.readTrailers(&cc.unmarshaler, cc.duplexCall), + ) + serverErr := grpcErrorFromTrailer(cc.bufferPool, cc.protobuf, cc.responseTrailer) if serverErr != nil && (errors.Is(err, io.EOF) || !errors.Is(serverErr, errTrailersWithoutGRPCStatus)) { // We've either: // - Cleanly read until the end of the response body and *not* received @@ -380,76 +347,125 @@ func (r *grpcClientReceiver) Receive(message any) error { // This is expected from a protocol perspective, but receiving trailers // means that we're _not_ getting a message. For users to realize that // the stream has ended, Receive must return an error. - serverErr.meta = r.Header().Clone() - mergeHeaders(serverErr.meta, r.trailer) - r.duplexCall.SetError(serverErr) + serverErr.meta = cc.responseHeader.Clone() + mergeHeaders(serverErr.meta, cc.responseTrailer) + cc.duplexCall.SetError(serverErr) return serverErr } // This was probably an error converting the bytes to a message or an error // reading from the network. We're going to return it to the // user, but we also want to setResponseError so Send errors out. - r.duplexCall.SetError(err) + cc.duplexCall.SetError(err) return err } -func (r *grpcClientReceiver) Close() error { - return r.duplexCall.CloseRead() +func (cc *grpcClientConn) ResponseHeader() http.Header { + cc.duplexCall.BlockUntilResponseReady() + return cc.responseHeader } -// validateResponse is called by duplexHTTPCall in a separate goroutine. -func (r *grpcClientReceiver) validateResponse(response *http.Response) *Error { +func (cc *grpcClientConn) ResponseTrailer() http.Header { + cc.duplexCall.BlockUntilResponseReady() + return cc.responseTrailer +} + +func (cc *grpcClientConn) CloseResponse() error { + return cc.duplexCall.CloseRead() +} + +func (cc *grpcClientConn) validateResponse(response *http.Response) *Error { if err := grpcValidateResponse( response, - r.header, - r.trailer, - r.compressionPools, - r.bufferPool, - r.protobuf, + cc.responseHeader, + cc.responseTrailer, + cc.compressionPools, + cc.bufferPool, + cc.protobuf, ); err != nil { return err } compression := response.Header.Get(grpcHeaderCompression) - r.unmarshaler.envelopeReader.compressionPool = r.compressionPools.Get(compression) + cc.unmarshaler.envelopeReader.compressionPool = cc.compressionPools.Get(compression) return nil } -type grpcHandlerSender struct { - spec Spec - web bool - marshaler grpcMarshaler - protobuf Codec // for errors - writer http.ResponseWriter - header http.Header - trailer http.Header - wroteToBody bool - bufferPool *bufferPool +type grpcHandlerConn struct { + spec Spec + web bool + bufferPool *bufferPool + protobuf Codec // for errors + marshaler grpcMarshaler + responseWriter http.ResponseWriter + responseHeader http.Header + responseTrailer http.Header + wroteToBody bool + request *http.Request + unmarshaler grpcUnmarshaler } -func (hs *grpcHandlerSender) Send(message any) error { - defer flushResponseWriter(hs.writer) - if !hs.wroteToBody { - mergeHeaders(hs.writer.Header(), hs.header) - hs.wroteToBody = true +func (hc *grpcHandlerConn) Spec() Spec { + return hc.spec +} + +func (hc *grpcHandlerConn) Receive(msg any) error { + if err := hc.unmarshaler.Unmarshal(msg); err != nil { + return err // already coded } - if err := hs.marshaler.Marshal(message); err != nil { + return nil // must be a literal nil: nil *Error is a non-nil error +} + +func (hc *grpcHandlerConn) RequestHeader() http.Header { + return hc.request.Header +} + +func (hc *grpcHandlerConn) Send(msg any) error { + defer flushResponseWriter(hc.responseWriter) + if !hc.wroteToBody { + mergeHeaders(hc.responseWriter.Header(), hc.responseHeader) + hc.wroteToBody = true + } + if err := hc.marshaler.Marshal(msg); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error } -func (hs *grpcHandlerSender) Close(err error) error { - defer flushResponseWriter(hs.writer) +func (hc *grpcHandlerConn) ResponseHeader() http.Header { + return hc.responseHeader +} + +func (hc *grpcHandlerConn) ResponseTrailer() http.Header { + return hc.responseTrailer +} + +func (hc *grpcHandlerConn) Close(err error) (retErr error) { // nolint:nonamedreturns + defer func() { + // We don't want to copy unread portions of the body to /dev/null here: if + // the client hasn't closed the request body, we'll block until the server + // timeout kicks in. This could happen because the client is malicious, but + // a well-intentioned client may just not expect the server to be returning + // an error for a streaming RPC. Better to accept that we can't always reuse + // TCP connections. + closeErr := hc.request.Body.Close() + if retErr == nil { + retErr = closeErr + } + }() + defer flushResponseWriter(hc.responseWriter) // If we haven't written the headers yet, do so. - if !hs.wroteToBody { - mergeHeaders(hs.writer.Header(), hs.header) + if !hc.wroteToBody { + mergeHeaders(hc.responseWriter.Header(), hc.responseHeader) } // gRPC always sends the error's code, message, details, and metadata as // trailing metadata. The Connect protocol doesn't do this, so we don't want // to mutate the trailers map that the user sees. - mergedTrailers := make(http.Header, len(hs.trailer)+2) // always make space for status & message - mergeHeaders(mergedTrailers, hs.trailer) - grpcErrorToTrailer(hs.bufferPool, mergedTrailers, hs.protobuf, err) - if hs.web && !hs.wroteToBody { + mergedTrailers := make( + http.Header, + len(hc.responseTrailer)+2, // always make space for status & message + ) + mergeHeaders(mergedTrailers, hc.responseTrailer) + grpcErrorToTrailer(hc.bufferPool, mergedTrailers, hc.protobuf, err) + if hc.web && !hc.wroteToBody { // We're using gRPC-Web and we haven't yet written to the body. Since we're // not sending any response messages, the gRPC specification calls this a // "trailers-only" response. Under those circumstances, the gRPC-Web spec @@ -457,13 +473,13 @@ func (hs *grpcHandlerSender) Close(err error) error { // instead. Envoy is the canonical implementation of the gRPC-Web protocol, // so we emulate Envoy's behavior and put the trailing metadata in the HTTP // headers. - mergeHeaders(hs.writer.Header(), mergedTrailers) + mergeHeaders(hc.responseWriter.Header(), mergedTrailers) return nil } - if hs.web { + if hc.web { // We're using gRPC-Web and we've already sent the headers, so we write // trailing metadata to the HTTP body. - if err := hs.marshaler.MarshalWebTrailers(mergedTrailers); err != nil { + if err := hc.marshaler.MarshalWebTrailers(mergedTrailers); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -482,107 +498,12 @@ func (hs *grpcHandlerSender) Close(err error) error { // logic breaks Envoy's gRPC-Web translation. for key, values := range mergedTrailers { for _, value := range values { - hs.writer.Header().Add(http.TrailerPrefix+key, value) + hc.responseWriter.Header().Add(http.TrailerPrefix+key, value) } } return nil } -func (hs *grpcHandlerSender) Spec() Spec { - return hs.spec -} - -func (hs *grpcHandlerSender) Header() http.Header { - return hs.header -} - -func (hs *grpcHandlerSender) Trailer() (http.Header, bool) { - return hs.trailer, true -} - -type grpcHandlerReceiver struct { - spec Spec - unmarshaler grpcUnmarshaler - request *http.Request -} - -func newGRPCHandlerStream( - spec Spec, - web bool, - responseWriter http.ResponseWriter, - request *http.Request, - compressMinBytes int, - codec Codec, - protobuf Codec, // for errors - requestCompressionPools *compressionPool, - responseCompressionPools *compressionPool, - bufferPool *bufferPool, - readMaxBytes int, -) (*grpcHandlerSender, *grpcHandlerReceiver) { - sender := &grpcHandlerSender{ - spec: spec, - web: web, - marshaler: grpcMarshaler{ - envelopeWriter: envelopeWriter{ - writer: responseWriter, - compressionPool: responseCompressionPools, - codec: codec, - compressMinBytes: compressMinBytes, - bufferPool: bufferPool, - }, - }, - protobuf: protobuf, - writer: responseWriter, - header: make(http.Header), - trailer: make(http.Header), - bufferPool: bufferPool, - } - receiver := &grpcHandlerReceiver{ - spec: spec, - unmarshaler: grpcUnmarshaler{ - envelopeReader: envelopeReader{ - reader: request.Body, - codec: codec, - compressionPool: requestCompressionPools, - bufferPool: bufferPool, - readMaxBytes: readMaxBytes, - }, - web: web, - }, - request: request, - } - return sender, receiver -} - -func (hr *grpcHandlerReceiver) Receive(message any) error { - if err := hr.unmarshaler.Unmarshal(message); err != nil { - return err // already coded - } - return nil // must be a literal nil: nil *Error is a non-nil error -} - -func (hr *grpcHandlerReceiver) Close() error { - // We don't want to copy unread portions of the body to /dev/null here: if - // the client hasn't closed the request body, we'll block until the server - // timeout kicks in. This could happen because the client is malicious, but - // a well-intentioned client may just not expect the server to be returning - // an error for a streaming RPC. Better to accept that we can't always reuse - // TCP connections. - return hr.request.Body.Close() -} - -func (hr *grpcHandlerReceiver) Spec() Spec { - return hr.spec -} - -func (hr *grpcHandlerReceiver) Header() http.Header { - return hr.request.Header -} - -func (hr *grpcHandlerReceiver) Trailer() (http.Header, bool) { - return nil, false -} - type grpcMarshaler struct { envelopeWriter } diff --git a/protocol_grpc_test.go b/protocol_grpc_test.go index 7b8fa264..04b480f1 100644 --- a/protocol_grpc_test.go +++ b/protocol_grpc_test.go @@ -19,6 +19,7 @@ import ( "math" "net/http" "net/http/httptest" + "strings" "testing" "testing/quick" "time" @@ -30,12 +31,21 @@ import ( func TestGRPCHandlerSender(t *testing.T) { t.Parallel() - newSender := func(web bool) *grpcHandlerSender { + newConn := func(web bool) *grpcHandlerConn { responseWriter := httptest.NewRecorder() protobufCodec := &protoBinaryCodec{} bufferPool := newBufferPool() - return &grpcHandlerSender{ - web: web, + request, err := http.NewRequest( + http.MethodPost, + "https://demo.example.com", + strings.NewReader(""), + ) + assert.Nil(t, err) + return &grpcHandlerConn{ + spec: Spec{}, + web: web, + bufferPool: bufferPool, + protobuf: protobufCodec, marshaler: grpcMarshaler{ envelopeWriter: envelopeWriter{ writer: responseWriter, @@ -43,37 +53,40 @@ func TestGRPCHandlerSender(t *testing.T) { bufferPool: bufferPool, }, }, - protobuf: protobufCodec, - writer: responseWriter, - header: make(http.Header), - trailer: make(http.Header), - bufferPool: bufferPool, + responseWriter: responseWriter, + responseHeader: make(http.Header), + responseTrailer: make(http.Header), + request: request, + unmarshaler: grpcUnmarshaler{ + envelopeReader: envelopeReader{ + reader: request.Body, + codec: protobufCodec, + bufferPool: bufferPool, + }, + }, } } t.Run("web", func(t *testing.T) { t.Parallel() - testGRPCHandlerSenderMetadata(t, newSender(true)) + testGRPCHandlerConnMetadata(t, newConn(true)) }) t.Run("http2", func(t *testing.T) { t.Parallel() - testGRPCHandlerSenderMetadata(t, newSender(false)) + testGRPCHandlerConnMetadata(t, newConn(false)) }) } -func testGRPCHandlerSenderMetadata(t *testing.T, sender Sender) { +func testGRPCHandlerConnMetadata(t *testing.T, conn handlerConnCloser) { // Closing the sender shouldn't unpredictably mutate user-visible headers or // trailers. t.Helper() - expectHeaders := sender.Header().Clone() - originalTrailers, hasOriginalTrailers := sender.Trailer() - assert.True(t, hasOriginalTrailers) - expectTrailers := originalTrailers.Clone() - sender.Close(NewError(CodeUnavailable, errors.New("oh no"))) - if diff := cmp.Diff(expectHeaders, sender.Header()); diff != "" { + expectHeaders := conn.ResponseHeader().Clone() + expectTrailers := conn.ResponseTrailer().Clone() + conn.Close(NewError(CodeUnavailable, errors.New("oh no"))) + if diff := cmp.Diff(expectHeaders, conn.ResponseHeader()); diff != "" { t.Errorf("headers changed:\n%s", diff) } - gotTrailers, hasGotTrailers := sender.Trailer() - assert.True(t, hasGotTrailers) + gotTrailers := conn.ResponseTrailer() if diff := cmp.Diff(expectTrailers, gotTrailers); diff != "" { t.Errorf("trailers changed:\n%s", diff) }