Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose request method of unary requests to clients and server handlers #502

Merged
merged 5 commits into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
unarySpec := config.newSpec(StreamTypeUnary)
unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) {
conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header())
conn.onRequestSend(func(r *http.Request) {
request.setRequestMethod(r.Method)
})
// Send always returns an io.EOF unless the error is from the client-side.
// We want the user to continue to call Receive in those cases to get the
// full error from the server-side.
Expand Down Expand Up @@ -132,15 +135,17 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo
if c.err != nil {
return &ClientStreamForClient[Req, Res]{err: c.err}
}
return &ClientStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeClient)}
return &ClientStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeClient, nil)}
}

// CallServerStream calls a server streaming procedure.
func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Request[Req]) (*ServerStreamForClient[Res], error) {
if c.err != nil {
return nil, c.err
}
conn := c.newConn(ctx, StreamTypeServer)
conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) {
request.method = r.Method
})
request.spec = conn.Spec()
request.peer = conn.Peer()
mergeHeaders(conn.RequestHeader(), request.header)
Expand All @@ -163,14 +168,16 @@ func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForCli
if c.err != nil {
return &BidiStreamForClient[Req, Res]{err: c.err}
}
return &BidiStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeBidi)}
return &BidiStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeBidi, nil)}
}

func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType) StreamingClientConn {
func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, onRequestSend func(r *http.Request)) StreamingClientConn {
newConn := func(ctx context.Context, spec Spec) StreamingClientConn {
header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing
c.protocolClient.WriteRequestHeader(streamType, header)
return c.protocolClient.NewConn(ctx, spec, header)
conn := c.protocolClient.NewConn(ctx, spec, header)
conn.onRequestSend(onRequestSend)
return conn
}
if interceptor := c.config.Interceptor; interceptor != nil {
newConn = interceptor.WrapStreamingClient(newConn)
Expand Down
25 changes: 15 additions & 10 deletions client_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func TestClientPeer(t *testing.T) {
server.StartTLS()
t.Cleanup(server.Close)

run := func(t *testing.T, opts ...connect.ClientOption) {
run := func(t *testing.T, unaryHTTPMethod string, opts ...connect.ClientOption) {
t.Helper()
client := pingv1connect.NewPingServiceClient(
server.Client(),
Expand All @@ -90,8 +90,10 @@ func TestClientPeer(t *testing.T) {
)
ctx := context.Background()
// unary
_, err := client.Ping(ctx, connect.NewRequest[pingv1.PingRequest](nil))
unaryReq := connect.NewRequest[pingv1.PingRequest](nil)
_, err := client.Ping(ctx, unaryReq)
assert.Nil(t, err)
assert.Equal(t, unaryHTTPMethod, unaryReq.HTTPMethod())
text := strings.Repeat(".", 256)
r, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Text: text}))
assert.Nil(t, err)
Expand Down Expand Up @@ -126,22 +128,22 @@ func TestClientPeer(t *testing.T) {

t.Run("connect", func(t *testing.T) {
t.Parallel()
run(t)
run(t, http.MethodPost)
})
t.Run("connect+get", func(t *testing.T) {
t.Parallel()
run(t,
run(t, http.MethodGet,
connect.WithHTTPGet(),
connect.WithSendGzip(),
)
})
t.Run("grpc", func(t *testing.T) {
t.Parallel()
run(t, connect.WithGRPC())
run(t, http.MethodPost, connect.WithGRPC())
})
t.Run("grpcweb", func(t *testing.T) {
t.Parallel()
run(t, connect.WithGRPCWeb())
run(t, http.MethodPost, connect.WithGRPCWeb())
})
}

Expand All @@ -167,21 +169,24 @@ func TestGetNotModified(t *testing.T) {
)
ctx := context.Background()
// unconditional request
res, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{}))
unaryReq := connect.NewRequest(&pingv1.PingRequest{})
res, err := client.Ping(ctx, unaryReq)
assert.Nil(t, err)
assert.Equal(t, res.Header().Get("Etag"), etag)
assert.Equal(t, res.Header().Values("Vary"), expectVary)
assert.Equal(t, http.MethodGet, unaryReq.HTTPMethod())

conditional := connect.NewRequest(&pingv1.PingRequest{})
conditional.Header().Set("If-None-Match", etag)
_, err = client.Ping(ctx, conditional)
unaryReq = connect.NewRequest(&pingv1.PingRequest{})
unaryReq.Header().Set("If-None-Match", etag)
_, err = client.Ping(ctx, unaryReq)
assert.NotNil(t, err)
assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown)
assert.True(t, connect.IsNotModifiedError(err))
var connectErr *connect.Error
assert.True(t, errors.As(err, &connectErr))
assert.Equal(t, connectErr.Meta().Get("Etag"), etag)
assert.Equal(t, connectErr.Meta().Values("Vary"), expectVary)
assert.Equal(t, http.MethodGet, unaryReq.HTTPMethod())
}

type notModifiedPingServer struct {
Expand Down
24 changes: 23 additions & 1 deletion connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ type Request[T any] struct {
spec Spec
peer Peer
header http.Header
method string
}

// NewRequest wraps a generated request message.
Expand Down Expand Up @@ -172,9 +173,28 @@ func (r *Request[_]) Header() http.Header {
return r.header
}

// HTTPMethod returns the HTTP method for this request. This is nearly always
// POST, but side-effect-free unary RPCs could be made via a GET.
//
// On a newly created request, via NewRequest, this will return the empty
// string until the actual request is actually sent and the HTTP method
// determined. This means that client interceptor functions will see the
// empty string until *after* they delegate to the handler they wrapped. It
// is even possible for this to return the empty string after such delegation,
// if the request was never actually sent to the server (and thus no
// determination ever made about the HTTP method).
func (r *Request[_]) HTTPMethod() string {
return r.method
}

// internalOnly implements AnyRequest.
func (r *Request[_]) internalOnly() {}

// setRequestMethod sets the request method to the given value.
func (r *Request[_]) setRequestMethod(method string) {
r.method = method
}

// AnyRequest is the common method set of every [Request], regardless of type
// parameter. It's used in unary interceptors.
//
Expand All @@ -190,8 +210,10 @@ type AnyRequest interface {
Spec() Spec
Peer() Peer
Header() http.Header
HTTPMethod() string

internalOnly()
setRequestMethod(string)
}

// Response is a wrapper around a generated response message. It provides
Expand Down Expand Up @@ -307,7 +329,7 @@ func newPeerFromURL(url *url.URL, protocol string) Peer {
}
}

// handlerConnCloser extends HandlerConn with a method for handlers to
// handlerConnCloser extends StreamingHandlerConn with a method for handlers to
// terminate the message exchange (and optionally send an error to the client).
type handlerConnCloser interface {
StreamingHandlerConn
Expand Down
4 changes: 4 additions & 0 deletions duplex_http_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type duplexHTTPCall struct {
ctx context.Context
httpClient HTTPClient
streamType StreamType
onRequestSend func(*http.Request)
validateResponse func(*http.Response) *Error

// We'll use a pipe as the request body. We hand the read side of the pipe to
Expand Down Expand Up @@ -255,6 +256,9 @@ func (d *duplexHTTPCall) makeRequest() {
// on d.responseReady, so we can't race with them.
defer close(d.responseReady)

if d.onRequestSend != nil {
d.onRequestSend(d.request)
}
// Once we send a message to the server, they send a message back and
// establish the receive side of the stream.
response, err := d.httpClient.Do(d.request) //nolint:bodyclose
Expand Down
6 changes: 6 additions & 0 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,16 @@ func NewUnaryHandler[Req, Res any](
if err := conn.Receive(&msg); err != nil {
return err
}
method := http.MethodPost
if hasRequestMethod, ok := conn.(interface{ getHTTPMethod() string }); ok {
method = hasRequestMethod.getHTTPMethod()
}
request := &Request[Req]{
Msg: &msg,
spec: conn.Spec(),
peer: conn.Peer(),
header: conn.RequestHeader(),
method: method,
}
response, err := untyped(ctx, request)
if err != nil {
Expand Down Expand Up @@ -141,6 +146,7 @@ func NewServerStreamHandler[Req, Res any](
spec: conn.Spec(),
peer: conn.Peer(),
header: conn.RequestHeader(),
method: http.MethodPost,
},
&ServerStream[Res]{conn: conn},
)
Expand Down
Loading