Skip to content
This repository has been archived by the owner on Sep 1, 2023. It is now read-only.

Commit

Permalink
Improve client streaming ergonomics (connectrpc#227)
Browse files Browse the repository at this point in the history
Rather than having handlers implementing client streaming methods call a
method to send the response, make things more like unary: they should
just return `(response, error)`.
  • Loading branch information
akshayjshah authored May 27, 2022
1 parent b79148b commit 159c801
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 35 deletions.
6 changes: 3 additions & 3 deletions cmd/protoc-gen-connect-go/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ func generateUnimplementedServerImplementation(g *protogen.GeneratedFile, servic
g.P()
for _, method := range service.Methods {
g.P("func (", names.UnimplementedServer, ") ", serverSignature(g, method), "{")
if method.Desc.IsStreamingServer() || method.Desc.IsStreamingClient() {
if method.Desc.IsStreamingServer() {
g.P("return ", connectPackage.Ident("NewError"), "(",
connectPackage.Ident("CodeUnimplemented"), ", ", errorsPackage.Ident("New"),
`("`, method.Desc.FullName(), ` is not implemented"))`)
Expand Down Expand Up @@ -387,8 +387,8 @@ func serverSignatureParams(g *protogen.GeneratedFile, method *protogen.Method, n
// client streaming
return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ", " +
streamName + "*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStream")) +
"[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" +
") error"
"[" + g.QualifiedGoIdent(method.Input.GoIdent) + "]" +
") (*" + g.QualifiedGoIdent(connectPackage.Ident("Response")) + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "] ,error)"
}
if method.Desc.IsStreamingServer() {
// server streaming
Expand Down
10 changes: 5 additions & 5 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -607,24 +607,24 @@ func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.Fa

func (p pingServer) Sum(
ctx context.Context,
stream *connect.ClientStream[pingv1.SumRequest, pingv1.SumResponse],
) error {
stream *connect.ClientStream[pingv1.SumRequest],
) (*connect.Response[pingv1.SumResponse], error) {
if p.checkMetadata {
if err := expectMetadata(stream.RequestHeader(), "header", clientHeader, headerValue); err != nil {
return err
return nil, err
}
}
var sum int64
for stream.Receive() {
sum += stream.Msg().Number
}
if stream.Err() != nil {
return stream.Err()
return nil, stream.Err()
}
response := connect.NewResponse(&pingv1.SumResponse{Sum: sum})
response.Header().Set(handlerHeader, headerValue)
response.Trailer().Set(handlerTrailer, trailerValue)
return stream.SendAndClose(response)
return response, nil
}

func (p pingServer) CountUp(
Expand Down
22 changes: 17 additions & 5 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,29 @@ func NewUnaryHandler[Req, Res any](
// NewClientStreamHandler constructs a Handler for a client streaming procedure.
func NewClientStreamHandler[Req, Res any](
procedure string,
implementation func(context.Context, *ClientStream[Req, Res]) error,
implementation func(context.Context, *ClientStream[Req]) (*Response[Res], error),
options ...HandlerOption,
) *Handler {
return newStreamHandler(
procedure,
StreamTypeClient,
func(ctx context.Context, sender Sender, receiver Receiver) {
stream := &ClientStream[Req, Res]{sender: sender, receiver: receiver}
err := implementation(ctx, stream)
_ = receiver.Close()
_ = sender.Close(err)
stream := &ClientStream[Req]{receiver: receiver}
res, err := implementation(ctx, stream)
if err != nil {
_ = receiver.Close()
_ = sender.Close(err)
return
}
if err := receiver.Close(); err != nil {
_ = sender.Close(err)
return
}
mergeHeaders(sender.Header(), res.header)
if trailer, ok := sender.Trailer(); ok {
mergeHeaders(trailer, res.trailer)
}
_ = sender.Close(sender.Send(res.Msg))
},
options...,
)
Expand Down
24 changes: 5 additions & 19 deletions handler_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,14 @@ import (
//
// It's constructed as part of Handler invocation, but doesn't currently have
// an exported constructor.
type ClientStream[Req, Res any] struct {
sender Sender
type ClientStream[Req any] struct {
receiver Receiver
msg Req
err error
}

// RequestHeader returns the headers received from the client.
func (c *ClientStream[Req, Res]) RequestHeader() http.Header {
func (c *ClientStream[Req]) RequestHeader() http.Header {
return c.receiver.Header()
}

Expand All @@ -41,7 +40,7 @@ func (c *ClientStream[Req, Res]) RequestHeader() http.Header {
// either by reaching the end or by encountering an unexpected error. After
// Receive returns false, the Err method will return any unexpected error
// encountered.
func (c *ClientStream[Req, Res]) Receive() bool {
func (c *ClientStream[Req]) Receive() bool {
if c.err != nil {
return false
}
Expand All @@ -52,31 +51,18 @@ func (c *ClientStream[Req, Res]) Receive() bool {
// Msg returns the most recent message unmarshaled by a call to Receive. The
// returned message points to data that will be overwritten by the next call to
// Receive.
func (c *ClientStream[Req, Res]) Msg() *Req {
func (c *ClientStream[Req]) Msg() *Req {
return &c.msg
}

// Err returns the first non-EOF error that was encountered by Receive.
func (c *ClientStream[Req, Res]) Err() error {
func (c *ClientStream[Req]) Err() error {
if c.err == nil || errors.Is(c.err, io.EOF) {
return nil
}
return c.err
}

// SendAndClose closes the receive side of the stream, then sends a response
// back to the client.
func (c *ClientStream[Req, Res]) SendAndClose(envelope *Response[Res]) error {
if err := c.receiver.Close(); err != nil {
return err
}
mergeHeaders(c.sender.Header(), envelope.header)
if trailer, ok := c.sender.Trailer(); ok {
mergeHeaders(trailer, envelope.trailer)
}
return c.sender.Send(envelope.Msg)
}

// ServerStream is the handler's view of a server streaming RPC.
//
// It's constructed as part of Handler invocation, but doesn't currently have
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 159c801

Please sign in to comment.