diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index d084ace0..8d958328 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -350,7 +350,7 @@ func generateUnimplementedServerImplementation(g *protogen.GeneratedFile, servic g.P() for _, method := range service.Methods { g.P("func (", names.UnimplementedServer, ") ", serverSignature(g, method), "{") - if method.Desc.IsStreamingServer() || method.Desc.IsStreamingClient() { + if method.Desc.IsStreamingServer() { g.P("return ", connectPackage.Ident("NewError"), "(", connectPackage.Ident("CodeUnimplemented"), ", ", errorsPackage.Ident("New"), `("`, method.Desc.FullName(), ` is not implemented"))`) @@ -387,8 +387,8 @@ func serverSignatureParams(g *protogen.GeneratedFile, method *protogen.Method, n // client streaming return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ", " + streamName + "*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStream")) + - "[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + - ") error" + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + "]" + + ") (*" + g.QualifiedGoIdent(connectPackage.Ident("Response")) + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "] ,error)" } if method.Desc.IsStreamingServer() { // server streaming diff --git a/connect_ext_test.go b/connect_ext_test.go index 6f394006..6b8205c0 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -607,11 +607,11 @@ func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.Fa func (p pingServer) Sum( ctx context.Context, - stream *connect.ClientStream[pingv1.SumRequest, pingv1.SumResponse], -) error { + stream *connect.ClientStream[pingv1.SumRequest], +) (*connect.Response[pingv1.SumResponse], error) { if p.checkMetadata { if err := expectMetadata(stream.RequestHeader(), "header", clientHeader, headerValue); err != nil { - return err + return nil, err } } var sum int64 @@ -619,12 +619,12 @@ func (p pingServer) Sum( sum += stream.Msg().Number } if stream.Err() != nil { - return stream.Err() + return nil, stream.Err() } response := connect.NewResponse(&pingv1.SumResponse{Sum: sum}) response.Header().Set(handlerHeader, headerValue) response.Trailer().Set(handlerTrailer, trailerValue) - return stream.SendAndClose(response) + return response, nil } func (p pingServer) CountUp( diff --git a/handler.go b/handler.go index c6427d70..bc3e4c19 100644 --- a/handler.go +++ b/handler.go @@ -111,17 +111,29 @@ func NewUnaryHandler[Req, Res any]( // NewClientStreamHandler constructs a Handler for a client streaming procedure. func NewClientStreamHandler[Req, Res any]( procedure string, - implementation func(context.Context, *ClientStream[Req, Res]) error, + implementation func(context.Context, *ClientStream[Req]) (*Response[Res], error), options ...HandlerOption, ) *Handler { return newStreamHandler( procedure, StreamTypeClient, func(ctx context.Context, sender Sender, receiver Receiver) { - stream := &ClientStream[Req, Res]{sender: sender, receiver: receiver} - err := implementation(ctx, stream) - _ = receiver.Close() - _ = sender.Close(err) + stream := &ClientStream[Req]{receiver: receiver} + res, err := implementation(ctx, stream) + if err != nil { + _ = receiver.Close() + _ = sender.Close(err) + return + } + if err := receiver.Close(); err != nil { + _ = sender.Close(err) + return + } + mergeHeaders(sender.Header(), res.header) + if trailer, ok := sender.Trailer(); ok { + mergeHeaders(trailer, res.trailer) + } + _ = sender.Close(sender.Send(res.Msg)) }, options..., ) diff --git a/handler_stream.go b/handler_stream.go index 9976f501..517a4d9e 100644 --- a/handler_stream.go +++ b/handler_stream.go @@ -24,15 +24,14 @@ import ( // // It's constructed as part of Handler invocation, but doesn't currently have // an exported constructor. -type ClientStream[Req, Res any] struct { - sender Sender +type ClientStream[Req any] struct { receiver Receiver msg Req err error } // RequestHeader returns the headers received from the client. -func (c *ClientStream[Req, Res]) RequestHeader() http.Header { +func (c *ClientStream[Req]) RequestHeader() http.Header { return c.receiver.Header() } @@ -41,7 +40,7 @@ func (c *ClientStream[Req, Res]) RequestHeader() http.Header { // either by reaching the end or by encountering an unexpected error. After // Receive returns false, the Err method will return any unexpected error // encountered. -func (c *ClientStream[Req, Res]) Receive() bool { +func (c *ClientStream[Req]) Receive() bool { if c.err != nil { return false } @@ -52,31 +51,18 @@ func (c *ClientStream[Req, Res]) Receive() bool { // Msg returns the most recent message unmarshaled by a call to Receive. The // returned message points to data that will be overwritten by the next call to // Receive. -func (c *ClientStream[Req, Res]) Msg() *Req { +func (c *ClientStream[Req]) Msg() *Req { return &c.msg } // Err returns the first non-EOF error that was encountered by Receive. -func (c *ClientStream[Req, Res]) Err() error { +func (c *ClientStream[Req]) Err() error { if c.err == nil || errors.Is(c.err, io.EOF) { return nil } return c.err } -// SendAndClose closes the receive side of the stream, then sends a response -// back to the client. -func (c *ClientStream[Req, Res]) SendAndClose(envelope *Response[Res]) error { - if err := c.receiver.Close(); err != nil { - return err - } - mergeHeaders(c.sender.Header(), envelope.header) - if trailer, ok := c.sender.Trailer(); ok { - mergeHeaders(trailer, envelope.trailer) - } - return c.sender.Send(envelope.Msg) -} - // ServerStream is the handler's view of a server streaming RPC. // // It's constructed as part of Handler invocation, but doesn't currently have diff --git a/internal/gen/connect/connect/ping/v1/pingv1connect/ping.connect.go b/internal/gen/connect/connect/ping/v1/pingv1connect/ping.connect.go index 47116268..c74508ec 100644 --- a/internal/gen/connect/connect/ping/v1/pingv1connect/ping.connect.go +++ b/internal/gen/connect/connect/ping/v1/pingv1connect/ping.connect.go @@ -132,7 +132,7 @@ type PingServiceHandler interface { // Fail always fails. Fail(context.Context, *connect_go.Request[v1.FailRequest]) (*connect_go.Response[v1.FailResponse], error) // Sum calculates the sum of the numbers sent on the stream. - Sum(context.Context, *connect_go.ClientStream[v1.SumRequest, v1.SumResponse]) error + Sum(context.Context, *connect_go.ClientStream[v1.SumRequest]) (*connect_go.Response[v1.SumResponse], error) // CountUp returns a stream of the numbers up to the given request. CountUp(context.Context, *connect_go.Request[v1.CountUpRequest], *connect_go.ServerStream[v1.CountUpResponse]) error // CumSum determines the cumulative sum of all the numbers sent on the stream. @@ -185,8 +185,8 @@ func (UnimplementedPingServiceHandler) Fail(context.Context, *connect_go.Request return nil, connect_go.NewError(connect_go.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Fail is not implemented")) } -func (UnimplementedPingServiceHandler) Sum(context.Context, *connect_go.ClientStream[v1.SumRequest, v1.SumResponse]) error { - return connect_go.NewError(connect_go.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Sum is not implemented")) +func (UnimplementedPingServiceHandler) Sum(context.Context, *connect_go.ClientStream[v1.SumRequest]) (*connect_go.Response[v1.SumResponse], error) { + return nil, connect_go.NewError(connect_go.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Sum is not implemented")) } func (UnimplementedPingServiceHandler) CountUp(context.Context, *connect_go.Request[v1.CountUpRequest], *connect_go.ServerStream[v1.CountUpResponse]) error {