Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose peer information to handlers and clients #364

Merged
merged 2 commits into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
// once at client creation.
unarySpec := config.newSpec(StreamTypeUnary)
unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) {
conn := protocolClient.NewConn(ctx, unarySpec, request.Header())
conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header())
// Send always returns an io.EOF unless the error is from the client-side.
// We want the user to continue to call Receive in those cases to get the
// full error from the server-side.
Expand All @@ -94,9 +94,11 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
unaryFunc = interceptor.WrapUnary(unaryFunc)
}
client.callUnary = func(ctx context.Context, request *Request[Req]) (*Response[Res], error) {
// To make the specification and RPC headers visible to the full interceptor
// chain (as though they were supplied by the caller), we'll add them here.
// To make the specification, peer, and RPC headers visible to the full
// interceptor chain (as though they were supplied by the caller), we'll
// add them here.
request.spec = unarySpec
request.peer = client.protocolClient.Peer()
protocolClient.WriteRequestHeader(StreamTypeUnary, request.Header())
response, err := unaryFunc(ctx, request)
if err != nil {
Expand Down
88 changes: 88 additions & 0 deletions client_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"

"github.com/bufbuild/connect-go"
Expand Down Expand Up @@ -68,3 +69,90 @@ func TestNewClient_InitFailure(t *testing.T) {
validateExpectedError(t, err)
})
}

func TestClientPeer(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}))
server := httptest.NewUnstartedServer(mux)
server.EnableHTTP2 = true
server.StartTLS()
t.Cleanup(server.Close)

run := func(t *testing.T, opts ...connect.ClientOption) {
t.Helper()
client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL,
connect.WithClientOptions(opts...),
connect.WithInterceptors(&assertPeerInterceptor{t}),
)
ctx := context.Background()
// unary
_, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{}))
assert.Nil(t, err)
// client streaming
clientStream := client.Sum(ctx)
t.Cleanup(func() {
_, closeErr := clientStream.CloseAndReceive()
assert.Nil(t, closeErr)
})
assert.NotNil(t, clientStream.Peer().Addr)
err = clientStream.Send(&pingv1.SumRequest{})
assert.Nil(t, err)
// server streaming
serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{}))
t.Cleanup(func() {
assert.Nil(t, serverStream.Close())
})
assert.Nil(t, err)
// bidi streaming
bidiStream := client.CumSum(ctx)
t.Cleanup(func() {
assert.Nil(t, bidiStream.CloseRequest())
assert.Nil(t, bidiStream.CloseResponse())
})
assert.NotNil(t, bidiStream.Peer().Addr)
err = bidiStream.Send(&pingv1.CumSumRequest{})
assert.Nil(t, err)
}

t.Run("connect", func(t *testing.T) {
t.Parallel()
run(t)
})
t.Run("grpc", func(t *testing.T) {
t.Parallel()
run(t, connect.WithGRPC())
})
t.Run("grpcweb", func(t *testing.T) {
t.Parallel()
run(t, connect.WithGRPCWeb())
})
}

type assertPeerInterceptor struct {
tb testing.TB
}

func (a *assertPeerInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
assert.NotZero(a.tb, req.Peer().Addr)
return next(ctx, req)
}
}

func (a *assertPeerInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
conn := next(ctx, spec)
assert.NotZero(a.tb, conn.Peer().Addr)
return conn
}
}

func (a *assertPeerInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
assert.NotZero(a.tb, conn.Peer().Addr)
return next(ctx, conn)
}
}
20 changes: 20 additions & 0 deletions client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ type ClientStreamForClient[Req, Res any] struct {
err error
}

// Spec returns the specification for the RPC.
func (c *ClientStreamForClient[_, _]) Spec() Spec {
return c.conn.Spec()
}

// Peer describes the server for the RPC.
func (c *ClientStreamForClient[_, _]) Peer() Peer {
return c.conn.Peer()
}

// RequestHeader returns the request headers. Headers are sent to the server with the
// first call to Send.
func (c *ClientStreamForClient[Req, Res]) RequestHeader() http.Header {
Expand Down Expand Up @@ -164,6 +174,16 @@ type BidiStreamForClient[Req, Res any] struct {
err error
}

// Spec returns the specification for the RPC.
func (b *BidiStreamForClient[_, _]) Spec() Spec {
return b.conn.Spec()
}

// Peer describes the server for the RPC.
func (b *BidiStreamForClient[_, _]) Peer() Peer {
return b.conn.Peer()
}

// RequestHeader returns the request headers. Headers are sent with the first
// call to Send.
func (b *BidiStreamForClient[Req, Res]) RequestHeader() http.Header {
Expand Down
27 changes: 26 additions & 1 deletion connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"errors"
"io"
"net/http"
"net/url"
)

// Version is the semantic version of the connect module.
Expand Down Expand Up @@ -68,6 +69,7 @@ const (
// StreamingHandlerConn implementations do not need to be safe for concurrent use.
type StreamingHandlerConn interface {
Spec() Spec
Peer() Peer

Receive(any) error
RequestHeader() http.Header
Expand Down Expand Up @@ -97,8 +99,9 @@ type StreamingHandlerConn interface {
// implementations must support limited concurrent use. See the comments on
// each group of methods for details.
type StreamingClientConn interface {
// Spec must be safe to call concurrently with all other methods.
// Spec and Peer must be safe to call concurrently with all other methods.
Spec() Spec
Peer() Peer

// Send, RequestHeader, and CloseRequest may race with each other, but must
// be safe to call concurrently with all other methods.
Expand All @@ -121,6 +124,7 @@ type Request[T any] struct {
Msg *T

spec Spec
peer Peer
header http.Header
}

Expand All @@ -144,6 +148,11 @@ func (r *Request[_]) Spec() Spec {
return r.spec
}

// Peer describes the other party for this RPC.
func (r *Request[_]) Peer() Peer {
return r.peer
}

// Header returns the HTTP headers for this request.
func (r *Request[_]) Header() http.Header {
if r.header == nil {
Expand All @@ -164,6 +173,7 @@ func (r *Request[_]) internalOnly() {}
type AnyRequest interface {
Any() any
Spec() Spec
Peer() Peer
Header() http.Header

internalOnly()
Expand Down Expand Up @@ -243,6 +253,21 @@ type Spec struct {
IsClient bool // otherwise we're in a handler
}

// Peer describes the other party to an RPC. When accessed client-side, Addr
// contains the host or host:port from the server's URL. When accessed
// server-side, Addr contains the client's address in IP:port format.
type Peer struct {
Addr string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gRPC exposes this as net.Addr (which includes a Network string: https://pkg.go.dev/google.golang.org/grpc/peer#Peer). Do you think we won't need to provide that additional data?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly? We can always expose a Network attribute down the road too. I'm not sure how to reliably detect non-TCP transports through the net/http abstractions - for example, unix domain sockets are sort of the wild west, and I have no idea what HTTP/3 will end up looking like.

}

func newPeerFromURL(s string) Peer {
u, err := url.Parse(s)
if err != nil {
return Peer{}
}
return Peer{Addr: u.Host}
}

// handlerConnCloser extends HandlerConn with a method for handlers to
// terminate the message exchange (and optionally send an error to the client).
type handlerConnCloser interface {
Expand Down
15 changes: 15 additions & 0 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1431,6 +1431,9 @@ func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.Pi
if err := expectClientHeader(p.checkMetadata, request); err != nil {
return nil, err
}
if request.Peer().Addr == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
response := connect.NewResponse(
&pingv1.PingResponse{
Number: request.Msg.Number,
Expand All @@ -1446,6 +1449,9 @@ func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.Fa
if err := expectClientHeader(p.checkMetadata, request); err != nil {
return nil, err
}
if request.Peer().Addr == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
err := connect.NewError(connect.Code(request.Msg.Code), errors.New(errorMessage))
err.Meta().Set(handlerHeader, headerValue)
err.Meta().Set(handlerTrailer, trailerValue)
Expand All @@ -1461,6 +1467,9 @@ func (p pingServer) Sum(
return nil, err
}
}
if stream.Peer().Addr == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
var sum int64
for stream.Receive() {
sum += stream.Msg().Number
Expand All @@ -1482,6 +1491,9 @@ func (p pingServer) CountUp(
if err := expectClientHeader(p.checkMetadata, request); err != nil {
return err
}
if request.Peer().Addr == "" {
return connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
if request.Msg.Number <= 0 {
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf(
"number must be positive: got %v",
Expand All @@ -1508,6 +1520,9 @@ func (p pingServer) CumSum(
return err
}
}
if stream.Peer().Addr == "" {
return connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
stream.ResponseHeader().Set(handlerHeader, headerValue)
stream.ResponseTrailer().Set(handlerTrailer, trailerValue)
for {
Expand Down
2 changes: 2 additions & 0 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func NewUnaryHandler[Req, Res any](
request := &Request[Req]{
Msg: &msg,
spec: conn.Spec(),
peer: conn.Peer(),
header: conn.RequestHeader(),
}
response, err := untyped(ctx, request)
Expand Down Expand Up @@ -124,6 +125,7 @@ func NewServerStreamHandler[Req, Res any](
&Request[Req]{
Msg: &msg,
spec: conn.Spec(),
peer: conn.Peer(),
header: conn.RequestHeader(),
},
&ServerStream[Res]{conn: conn},
Expand Down
20 changes: 20 additions & 0 deletions handler_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ type ClientStream[Req any] struct {
err error
}

// Spec returns the specification for the RPC.
func (c *ClientStream[_]) Spec() Spec {
return c.conn.Spec()
}

// Peer describes the client for this RPC.
func (c *ClientStream[_]) Peer() Peer {
return c.conn.Peer()
}

// RequestHeader returns the headers received from the client.
func (c *ClientStream[Req]) RequestHeader() http.Header {
return c.conn.RequestHeader()
Expand Down Expand Up @@ -111,6 +121,16 @@ type BidiStream[Req, Res any] struct {
conn StreamingHandlerConn
}

// Spec returns the specification for the RPC.
func (b *BidiStream[_, _]) Spec() Spec {
return b.conn.Spec()
}

// Peer describes the client for this RPC.
func (b *BidiStream[_, _]) Peer() Peer {
return b.conn.Peer()
}

// RequestHeader returns the headers received from the client.
func (b *BidiStream[Req, Res]) RequestHeader() http.Header {
return b.conn.RequestHeader()
Expand Down
3 changes: 3 additions & 0 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ type protocolClientParams struct {
// Client is the client side of a protocol. HTTP clients typically use a single
// protocol, codec, and compressor to send requests.
type protocolClient interface {
// Peer describes the server for the RPC.
Peer() Peer

// WriteRequestHeader writes any protocol-specific request headers.
WriteRequestHeader(StreamType, http.Header)

Expand Down
Loading