Skip to content

Commit

Permalink
In iterators, allocate new msg for each Receive (#350)
Browse files Browse the repository at this point in the history
  • Loading branch information
akshayjshah authored Aug 17, 2022
1 parent bb13b7a commit 398f155
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 16 deletions.
32 changes: 26 additions & 6 deletions client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,19 @@ func (c *ClientStreamForClient[Req, Res]) CloseAndReceive() (*Response[Res], err
return response, c.conn.CloseResponse()
}

// Conn exposes the underlying StreamingClientConn. This may be useful if
// you'd prefer to wrap the connection in a different high-level API.
func (c *ClientStreamForClient[Req, Res]) Conn() (StreamingClientConn, error) {
return c.conn, c.err
}

// ServerStreamForClient is the client's view of a server streaming RPC.
//
// It's returned from [Client].CallServerStream, but doesn't currently have an
// exported constructor function.
type ServerStreamForClient[Res any] struct {
conn StreamingClientConn
msg Res
msg *Res
// Error from client construction. If non-nil, return for all calls.
constructErr error
// Error from conn.Receive().
Expand All @@ -92,15 +98,17 @@ func (s *ServerStreamForClient[Res]) Receive() bool {
if s.constructErr != nil || s.receiveErr != nil {
return false
}
s.receiveErr = s.conn.Receive(&s.msg)
s.msg = new(Res)
s.receiveErr = s.conn.Receive(s.msg)
return s.receiveErr == nil
}

// Msg returns the most recent message unmarshaled by a call to Receive. The
// returned message points to data that will be overwritten by the next call to
// Receive.
// Msg returns the most recent message unmarshaled by a call to Receive.
func (s *ServerStreamForClient[Res]) Msg() *Res {
return &s.msg
if s.msg == nil {
s.msg = new(Res)
}
return s.msg
}

// Err returns the first non-EOF error that was encountered by Receive.
Expand Down Expand Up @@ -140,6 +148,12 @@ func (s *ServerStreamForClient[Res]) Close() error {
return s.conn.CloseResponse()
}

// Conn exposes the underlying StreamingClientConn. This may be useful if
// you'd prefer to wrap the connection in a different high-level API.
func (s *ServerStreamForClient[Res]) Conn() (StreamingClientConn, error) {
return s.conn, s.constructErr
}

// BidiStreamForClient is the client's view of a bidirectional streaming RPC.
//
// It's returned from [Client].CallBidiStream, but doesn't currently have an
Expand Down Expand Up @@ -218,3 +232,9 @@ func (b *BidiStreamForClient[Req, Res]) ResponseTrailer() http.Header {
}
return b.conn.ResponseTrailer()
}

// Conn exposes the underlying StreamingClientConn. This may be useful if
// you'd prefer to wrap the connection in a different high-level API.
func (b *BidiStreamForClient[Req, Res]) Conn() (StreamingClientConn, error) {
return b.conn, b.err
}
43 changes: 39 additions & 4 deletions client_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package connect

import (
"errors"
"fmt"
"net/http"
"testing"

Expand All @@ -26,12 +27,15 @@ import (
func TestClientStreamForClient_NoPanics(t *testing.T) {
t.Parallel()
initErr := errors.New("client init failure")
cs := &ClientStreamForClient[pingv1.PingRequest, pingv1.PingResponse]{err: initErr}
assert.ErrorIs(t, cs.Send(&pingv1.PingRequest{}), initErr)
verifyHeaders(t, cs.RequestHeader())
res, err := cs.CloseAndReceive()
clientStream := &ClientStreamForClient[pingv1.PingRequest, pingv1.PingResponse]{err: initErr}
assert.ErrorIs(t, clientStream.Send(&pingv1.PingRequest{}), initErr)
verifyHeaders(t, clientStream.RequestHeader())
res, err := clientStream.CloseAndReceive()
assert.Nil(t, res)
assert.ErrorIs(t, err, initErr)
conn, err := clientStream.Conn()
assert.NotNil(t, err)
assert.Nil(t, conn)
}

func TestServerStreamForClient_NoPanics(t *testing.T) {
Expand All @@ -44,6 +48,26 @@ func TestServerStreamForClient_NoPanics(t *testing.T) {
assert.False(t, serverStream.Receive())
verifyHeaders(t, serverStream.ResponseHeader())
verifyHeaders(t, serverStream.ResponseTrailer())
conn, err := serverStream.Conn()
assert.NotNil(t, err)
assert.Nil(t, conn)
}

func TestServerStreamForClient(t *testing.T) {
t.Parallel()
stream := &ServerStreamForClient[pingv1.PingResponse]{conn: &nopStreamingClientConn{}}
// Ensure that each call to Receive allocates a new message. This helps
// vtprotobuf, which doesn't automatically zero messages before unmarshaling
// (see https://github.com/bufbuild/connect-go/issues/345), and it's also
// less error-prone for users.
assert.True(t, stream.Receive())
first := fmt.Sprintf("%p", stream.Msg())
assert.True(t, stream.Receive())
second := fmt.Sprintf("%p", stream.Msg())
assert.NotEqual(t, first, second)
conn, err := stream.Conn()
assert.Nil(t, err)
assert.NotNil(t, conn)
}

func TestBidiStreamForClient_NoPanics(t *testing.T) {
Expand All @@ -59,6 +83,9 @@ func TestBidiStreamForClient_NoPanics(t *testing.T) {
assert.ErrorIs(t, bidiStream.Send(&pingv1.CumSumRequest{}), initErr)
assert.ErrorIs(t, bidiStream.CloseRequest(), initErr)
assert.ErrorIs(t, bidiStream.CloseResponse(), initErr)
conn, err := bidiStream.Conn()
assert.NotNil(t, err)
assert.Nil(t, conn)
}

func verifyHeaders(t *testing.T, headers http.Header) {
Expand All @@ -69,3 +96,11 @@ func verifyHeaders(t *testing.T, headers http.Header) {
headers.Set("a", "b")
headers.Del("a")
}

type nopStreamingClientConn struct {
StreamingClientConn
}

func (c *nopStreamingClientConn) Receive(msg any) error {
return nil
}
32 changes: 26 additions & 6 deletions handler_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
// an exported constructor.
type ClientStream[Req any] struct {
conn StreamingHandlerConn
msg Req
msg *Req
err error
}

Expand All @@ -44,15 +44,17 @@ func (c *ClientStream[Req]) Receive() bool {
if c.err != nil {
return false
}
c.err = c.conn.Receive(&c.msg)
c.msg = new(Req)
c.err = c.conn.Receive(c.msg)
return c.err == nil
}

// Msg returns the most recent message unmarshaled by a call to Receive. The
// returned message points to data that will be overwritten by the next call to
// Receive.
// Msg returns the most recent message unmarshaled by a call to Receive.
func (c *ClientStream[Req]) Msg() *Req {
return &c.msg
if c.msg == nil {
c.msg = new(Req)
}
return c.msg
}

// Err returns the first non-EOF error that was encountered by Receive.
Expand All @@ -63,6 +65,12 @@ func (c *ClientStream[Req]) Err() error {
return c.err
}

// Conn exposes the underlying StreamingHandlerConn. This may be useful if
// you'd prefer to wrap the connection in a different high-level API.
func (c *ClientStream[Req]) Conn() StreamingHandlerConn {
return c.conn
}

// ServerStream is the handler's view of a server streaming RPC.
//
// It's constructed as part of [Handler] invocation, but doesn't currently have
Expand All @@ -89,6 +97,12 @@ func (s *ServerStream[Res]) Send(msg *Res) error {
return s.conn.Send(msg)
}

// Conn exposes the underlying StreamingHandlerConn. This may be useful if
// you'd prefer to wrap the connection in a different high-level API.
func (s *ServerStream[Res]) Conn() StreamingHandlerConn {
return s.conn
}

// BidiStream is the handler's view of a bidirectional streaming RPC.
//
// It's constructed as part of [Handler] invocation, but doesn't currently have
Expand Down Expand Up @@ -129,3 +143,9 @@ func (b *BidiStream[Req, Res]) ResponseTrailer() http.Header {
func (b *BidiStream[Req, Res]) Send(msg *Res) error {
return b.conn.Send(msg)
}

// Conn exposes the underlying StreamingHandlerConn. This may be useful if
// you'd prefer to wrap the connection in a different high-level API.
func (b *BidiStream[Req, Res]) Conn() StreamingHandlerConn {
return b.conn
}
41 changes: 41 additions & 0 deletions handler_stream_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright 2021-2022 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package connect

import (
"fmt"
"testing"

"github.com/bufbuild/connect-go/internal/assert"
pingv1 "github.com/bufbuild/connect-go/internal/gen/connect/ping/v1"
)

func TestClientStream(t *testing.T) {
t.Parallel()
stream := &ClientStream[pingv1.PingRequest]{conn: &nopStreamingHandlerConn{}}
assert.True(t, stream.Receive())
first := fmt.Sprintf("%p", stream.Msg())
assert.True(t, stream.Receive())
second := fmt.Sprintf("%p", stream.Msg())
assert.NotEqual(t, first, second)
}

type nopStreamingHandlerConn struct {
StreamingHandlerConn
}

func (nopStreamingHandlerConn) Receive(msg any) error {
return nil
}

0 comments on commit 398f155

Please sign in to comment.