Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework streaming interceptors #316

Merged
merged 11 commits into from
Jul 12, 2022
54 changes: 25 additions & 29 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,25 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
// once at client creation.
unarySpec := config.newSpec(StreamTypeUnary)
unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) {
sender, receiver := protocolClient.NewStream(ctx, unarySpec, request.Header())
conn := protocolClient.NewConn(ctx, unarySpec, request.Header())
// Send always returns an io.EOF unless the error is from the client-side.
// We want the user to continue to call Receive in those cases to get the
// full error from the server-side.
if err := sender.Send(request.Any()); err != nil && !errors.Is(err, io.EOF) {
_ = sender.Close(err)
_ = receiver.Close()
if err := conn.Send(request.Any()); err != nil && !errors.Is(err, io.EOF) {
_ = conn.CloseRequest()
_ = conn.CloseResponse()
return nil, err
}
if err := sender.Close(nil); err != nil {
_ = receiver.Close()
if err := conn.CloseRequest(); err != nil {
_ = conn.CloseResponse()
return nil, err
}
response, err := receiveUnaryResponse[Res](receiver)
response, err := receiveUnaryResponse[Res](conn)
if err != nil {
_ = receiver.Close()
_ = conn.CloseResponse()
return nil, err
}
return response, receiver.Close()
return response, conn.CloseResponse()
})
if interceptor := config.Interceptor; interceptor != nil {
unaryFunc = interceptor.WrapUnary(unaryFunc)
Expand Down Expand Up @@ -123,52 +123,48 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo
if c.err != nil {
return &ClientStreamForClient[Req, Res]{err: c.err}
}
sender, receiver := c.newStream(ctx, StreamTypeClient)
return &ClientStreamForClient[Req, Res]{sender: sender, receiver: receiver}
return &ClientStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeClient)}
}

// CallServerStream calls a server streaming procedure.
func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Request[Req]) (*ServerStreamForClient[Res], error) {
if c.err != nil {
return nil, c.err
}
sender, receiver := c.newStream(ctx, StreamTypeServer)
mergeHeaders(sender.Header(), request.header)
conn := c.newConn(ctx, StreamTypeServer)
mergeHeaders(conn.RequestHeader(), request.header)
// Send always returns an io.EOF unless the error is from the client-side.
// We want the user to continue to call Receive in those cases to get the
// full error from the server-side.
if err := sender.Send(request.Msg); err != nil && !errors.Is(err, io.EOF) {
_ = sender.Close(err)
_ = receiver.Close()
if err := conn.Send(request.Msg); err != nil && !errors.Is(err, io.EOF) {
_ = conn.CloseRequest()
_ = conn.CloseResponse()
return nil, err
}
if err := sender.Close(nil); err != nil {
if err := conn.CloseRequest(); err != nil {
return nil, err
}
return &ServerStreamForClient[Res]{receiver: receiver}, nil
return &ServerStreamForClient[Res]{conn: conn}, nil
}

// CallBidiStream calls a bidirectional streaming procedure.
func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForClient[Req, Res] {
if c.err != nil {
return &BidiStreamForClient[Req, Res]{err: c.err}
}
sender, receiver := c.newStream(ctx, StreamTypeBidi)
return &BidiStreamForClient[Req, Res]{sender: sender, receiver: receiver}
return &BidiStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeBidi)}
}

func (c *Client[Req, Res]) newStream(ctx context.Context, streamType StreamType) (Sender, Receiver) {
if interceptor := c.config.Interceptor; interceptor != nil {
ctx = interceptor.WrapStreamContext(ctx)
func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType) ClientConn {
newConn := func(ctx context.Context, spec Spec) ClientConn {
header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing
c.protocolClient.WriteRequestHeader(streamType, header)
return c.protocolClient.NewConn(ctx, spec, header)
}
header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing
c.protocolClient.WriteRequestHeader(streamType, header)
sender, receiver := c.protocolClient.NewStream(ctx, c.config.newSpec(streamType), header)
if interceptor := c.config.Interceptor; interceptor != nil {
sender = interceptor.WrapStreamSender(ctx, sender)
receiver = interceptor.WrapStreamReceiver(ctx, receiver)
newConn = interceptor.WrapStreamingClient(newConn)
}
return sender, receiver
return newConn(ctx, c.config.newSpec(streamType))
}

type clientConfig struct {
Expand Down
64 changes: 27 additions & 37 deletions client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ import (
// It's returned from Client.CallClientStream, but doesn't currently have an
// exported constructor function.
type ClientStreamForClient[Req, Res any] struct {
sender Sender
receiver Receiver
conn ClientConn
// Error from client construction. If non-nil, return for all calls.
err error
}
Expand All @@ -37,7 +36,7 @@ func (c *ClientStreamForClient[Req, Res]) RequestHeader() http.Header {
if c.err != nil {
return http.Header{}
}
return c.sender.Header()
return c.conn.RequestHeader()
}

// Send a message to the server. The first call to Send also sends the request
Expand All @@ -50,7 +49,7 @@ func (c *ClientStreamForClient[Req, Res]) Send(request *Req) error {
if c.err != nil {
return c.err
}
return c.sender.Send(request)
return c.conn.Send(request)
}

// CloseAndReceive closes the send side of the stream and waits for the
Expand All @@ -59,30 +58,28 @@ func (c *ClientStreamForClient[Req, Res]) CloseAndReceive() (*Response[Res], err
if c.err != nil {
return nil, c.err
}
if err := c.sender.Close(nil); err != nil {
if err := c.conn.CloseRequest(); err != nil {
_ = c.conn.CloseResponse()
return nil, err
}
response, err := receiveUnaryResponse[Res](c.receiver)
response, err := receiveUnaryResponse[Res](c.conn)
if err != nil {
_ = c.receiver.Close()
_ = c.conn.CloseResponse()
return nil, err
}
if err := c.receiver.Close(); err != nil {
return nil, err
}
return response, nil
return response, c.conn.CloseResponse()
}

// ServerStreamForClient is the client's view of a server streaming RPC.
//
// It's returned from Client.CallServerStream, but doesn't currently have an
// exported constructor function.
type ServerStreamForClient[Res any] struct {
receiver Receiver
msg Res
conn ClientConn
msg Res
// Error from client construction. If non-nil, return for all calls.
constructErr error
// Error from Receive().
// Error from conn.Receive().
receiveErr error
}

Expand All @@ -95,7 +92,7 @@ func (s *ServerStreamForClient[Res]) Receive() bool {
if s.constructErr != nil || s.receiveErr != nil {
return false
}
s.receiveErr = s.receiver.Receive(&s.msg)
s.receiveErr = s.conn.Receive(&s.msg)
return s.receiveErr == nil
}

Expand Down Expand Up @@ -123,7 +120,7 @@ func (s *ServerStreamForClient[Res]) ResponseHeader() http.Header {
if s.constructErr != nil {
return http.Header{}
}
return s.receiver.Header()
return s.conn.ResponseHeader()
}

// ResponseTrailer returns the trailers received from the server. Trailers
Expand All @@ -132,27 +129,23 @@ func (s *ServerStreamForClient[Res]) ResponseTrailer() http.Header {
if s.constructErr != nil {
return http.Header{}
}
if trailer, ok := s.receiver.Trailer(); ok {
return trailer
}
return make(http.Header)
return s.conn.ResponseTrailer()
}

// Close the receive side of the stream.
func (s *ServerStreamForClient[Res]) Close() error {
if s.constructErr != nil {
return s.constructErr
}
return s.receiver.Close()
return s.conn.CloseResponse()
}

// BidiStreamForClient is the client's view of a bidirectional streaming RPC.
//
// It's returned from Client.CallBidiStream, but doesn't currently have an
// exported constructor function.
type BidiStreamForClient[Req, Res any] struct {
sender Sender
receiver Receiver
conn ClientConn
// Error from client construction. If non-nil, return for all calls.
err error
}
Expand All @@ -163,7 +156,7 @@ func (b *BidiStreamForClient[Req, Res]) RequestHeader() http.Header {
if b.err != nil {
return http.Header{}
}
return b.sender.Header()
return b.conn.RequestHeader()
}

// Send a message to the server. The first call to Send also sends the request
Expand All @@ -176,15 +169,15 @@ func (b *BidiStreamForClient[Req, Res]) Send(msg *Req) error {
if b.err != nil {
return b.err
}
return b.sender.Send(msg)
return b.conn.Send(msg)
}

// CloseSend closes the send side of the stream.
func (b *BidiStreamForClient[Req, Res]) CloseSend() error {
// Close the send side of the stream.
func (b *BidiStreamForClient[Req, Res]) CloseRequest() error {
if b.err != nil {
return b.err
}
return b.sender.Close(nil)
return b.conn.CloseRequest()
}

// Receive a message. When the server is done sending messages and no other
Expand All @@ -194,18 +187,18 @@ func (b *BidiStreamForClient[Req, Res]) Receive() (*Res, error) {
return nil, b.err
}
var msg Res
if err := b.receiver.Receive(&msg); err != nil {
if err := b.conn.Receive(&msg); err != nil {
return nil, err
}
return &msg, nil
}

// CloseReceive closes the receive side of the stream.
func (b *BidiStreamForClient[Req, Res]) CloseReceive() error {
// Close the receive side of the stream.
func (b *BidiStreamForClient[Req, Res]) CloseResponse() error {
if b.err != nil {
return b.err
}
return b.receiver.Close()
return b.conn.CloseResponse()
}

// ResponseHeader returns the headers received from the server. It blocks until
Expand All @@ -214,7 +207,7 @@ func (b *BidiStreamForClient[Req, Res]) ResponseHeader() http.Header {
if b.err != nil {
return http.Header{}
}
return b.receiver.Header()
return b.conn.ResponseHeader()
}

// ResponseTrailer returns the trailers received from the server. Trailers
Expand All @@ -223,8 +216,5 @@ func (b *BidiStreamForClient[Req, Res]) ResponseTrailer() http.Header {
if b.err != nil {
return http.Header{}
}
if trailer, ok := b.receiver.Trailer(); ok {
return trailer
}
return make(http.Header)
return b.conn.ResponseTrailer()
}
4 changes: 2 additions & 2 deletions client_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ func TestBidiStreamForClient_NoPanics(t *testing.T) {
verifyHeaders(t, bidiStream.ResponseHeader())
verifyHeaders(t, bidiStream.ResponseTrailer())
assert.ErrorIs(t, bidiStream.Send(&pingv1.CumSumRequest{}), initErr)
assert.ErrorIs(t, bidiStream.CloseReceive(), initErr)
assert.ErrorIs(t, bidiStream.CloseSend(), initErr)
assert.ErrorIs(t, bidiStream.CloseRequest(), initErr)
assert.ErrorIs(t, bidiStream.CloseResponse(), initErr)
}

func verifyHeaders(t *testing.T, headers http.Header) {
Expand Down
2 changes: 2 additions & 0 deletions code.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,12 @@ func (c Code) String() string {
return fmt.Sprintf("code_%d", c)
}

// MarshalText implements encoding.TextMarshaler.
func (c Code) MarshalText() ([]byte, error) {
return []byte(c.String()), nil
}

// UnmarshalText implements encoding.TextUnmarshaler.
func (c *Code) UnmarshalText(data []byte) error {
dataStr := string(data)
switch dataStr {
Expand Down
Loading