Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose request method of unary requests to clients and server handlers #502

Merged
merged 5 commits into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
unarySpec := config.newSpec(StreamTypeUnary)
unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) {
conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header())
if hasRequestMethod, ok := conn.(clientConnWithRequestMethod); ok {
hasRequestMethod.onSetMethod(request.setRequestMethod)
}
conn.onRequestSend(func(r *http.Request) {
request.setRequestMethod(r.Method)
})
// Send always returns an io.EOF unless the error is from the client-side.
// We want the user to continue to call Receive in those cases to get the
// full error from the server-side.
Expand Down Expand Up @@ -135,15 +135,17 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo
if c.err != nil {
return &ClientStreamForClient[Req, Res]{err: c.err}
}
return &ClientStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeClient)}
return &ClientStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeClient, nil)}
}

// CallServerStream calls a server streaming procedure.
func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Request[Req]) (*ServerStreamForClient[Res], error) {
if c.err != nil {
return nil, c.err
}
conn := c.newConn(ctx, StreamTypeServer)
conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) {
request.method = r.Method
})
request.spec = conn.Spec()
request.peer = conn.Peer()
mergeHeaders(conn.RequestHeader(), request.header)
Expand All @@ -166,14 +168,16 @@ func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForCli
if c.err != nil {
return &BidiStreamForClient[Req, Res]{err: c.err}
}
return &BidiStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeBidi)}
return &BidiStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeBidi, nil)}
}

func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType) StreamingClientConn {
func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, onRequestSend func(r *http.Request)) StreamingClientConn {
newConn := func(ctx context.Context, spec Spec) StreamingClientConn {
header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing
c.protocolClient.WriteRequestHeader(streamType, header)
return c.protocolClient.NewConn(ctx, spec, header)
conn := c.protocolClient.NewConn(ctx, spec, header)
conn.onRequestSend(onRequestSend)
return conn
}
if interceptor := c.config.Interceptor; interceptor != nil {
newConn = interceptor.WrapStreamingClient(newConn)
Expand Down
33 changes: 12 additions & 21 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,17 @@ func (r *Request[_]) Header() http.Header {
return r.header
}

// Method returns the HTTP method for this request. This is nearly always
// POST, but side-effect-free RPCs could be made via a GET.
// HTTPMethod returns the HTTP method for this request. This is nearly always
// POST, but side-effect-free unary RPCs could be made via a GET.
//
// On a newly created request, via NewRequest, this will return "POST" and
// only changes to return to "GET" after it is actually sent to a server
// using GET as the method.
func (r *Request[_]) Method() string {
if r.method == "" {
return http.MethodPost
}
// On a newly created request, via NewRequest, this will return the empty
// string until the actual request is actually sent and the HTTP method
// determined. This means that client interceptor functions will see the
// empty string until *after* they delegate to the handler they wrapped. It
// is even possible for this to return the empty string after such delegation,
// if the request was never actually sent to the server (and thus no
// determination ever made about the HTTP method).
func (r *Request[_]) HTTPMethod() string {
return r.method
}

Expand All @@ -209,7 +210,7 @@ type AnyRequest interface {
Spec() Spec
Peer() Peer
Header() http.Header
Method() string
HTTPMethod() string

internalOnly()
setRequestMethod(string)
Expand Down Expand Up @@ -328,24 +329,14 @@ func newPeerFromURL(url *url.URL, protocol string) Peer {
}
}

// handlerConnCloser extends HandlerConn with a method for handlers to
// handlerConnCloser extends StreamingHandlerConn with a method for handlers to
// terminate the message exchange (and optionally send an error to the client).
type handlerConnCloser interface {
StreamingHandlerConn

Close(error) error
}

type handlerConnWithRequestMethod interface {
StreamingHandlerConn
getMethod() string
}

type clientConnWithRequestMethod interface {
StreamingClientConn
onSetMethod(fn func(string))
}

// receiveUnaryResponse unmarshals a message from a StreamingClientConn, then
// envelopes the message and attaches headers and trailers. It attempts to
// consume the response stream and isn't appropriate when receiving multiple
Expand Down
8 changes: 4 additions & 4 deletions duplex_http_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type duplexHTTPCall struct {
ctx context.Context
httpClient HTTPClient
streamType StreamType
onSetMethod func(method string)
onRequestSend func(*http.Request)
validateResponse func(*http.Response) *Error

// We'll use a pipe as the request body. We hand the read side of the pipe to
Expand Down Expand Up @@ -151,9 +151,6 @@ func (d *duplexHTTPCall) URL() *url.URL {
// SetMethod changes the method of the request before it is sent.
func (d *duplexHTTPCall) SetMethod(method string) {
d.request.Method = method
if d.onSetMethod != nil {
d.onSetMethod(method)
}
}

// Read from the response body. Returns the first error passed to SetError.
Expand Down Expand Up @@ -259,6 +256,9 @@ func (d *duplexHTTPCall) makeRequest() {
// on d.responseReady, so we can't race with them.
defer close(d.responseReady)

if d.onRequestSend != nil {
d.onRequestSend(d.request)
}
// Once we send a message to the server, they send a message back and
// establish the receive side of the stream.
response, err := d.httpClient.Do(d.request) //nolint:bodyclose
Expand Down
5 changes: 3 additions & 2 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ func NewUnaryHandler[Req, Res any](
return err
}
method := http.MethodPost
if hasRequestMethod, ok := conn.(handlerConnWithRequestMethod); ok {
method = hasRequestMethod.getMethod()
if hasRequestMethod, ok := conn.(interface{ getHTTPMethod() string }); ok {
method = hasRequestMethod.getHTTPMethod()
}
request := &Request[Req]{
Msg: &msg,
Expand Down Expand Up @@ -146,6 +146,7 @@ func NewServerStreamHandler[Req, Res any](
spec: conn.Spec(),
peer: conn.Peer(),
header: conn.RequestHeader(),
method: http.MethodPost,
},
&ServerStream[Res]{conn: conn},
)
Expand Down
123 changes: 123 additions & 0 deletions interceptor_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ package connect_test

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"

"github.com/bufbuild/connect-go"
Expand Down Expand Up @@ -66,6 +68,8 @@ func TestOnionOrderingEndToEnd(t *testing.T) {
}
}

var client1, client2, client3, handler1, handler2, handler3 atomic.Int32

// The client and handler interceptor onions are the meat of the test. The
// order of interceptor execution must be the same for unary and streaming
// procedures.
Expand All @@ -79,6 +83,7 @@ func TestOnionOrderingEndToEnd(t *testing.T) {
// intended order clear.
clientOnion := connect.WithInterceptors(
newHeaderInterceptor(
&client1,
// 1 (start). request: should see protocol-related headers
func(_ connect.Spec, h http.Header) {
assert.NotZero(t, h.Get("Content-Type"))
Expand All @@ -87,24 +92,29 @@ func TestOnionOrderingEndToEnd(t *testing.T) {
assertAllPresent,
),
newHeaderInterceptor(
&client2,
newInspector("", "one"), // 2. request: add header "one"
newInspector("three", "four"), // 11. response: check "three", add "four"
),
newHeaderInterceptor(
&client3,
newInspector("one", "two"), // 3. request: check "one", add "two"
newInspector("two", "three"), // 10. response: check "two", add "three"
),
)
handlerOnion := connect.WithInterceptors(
newHeaderInterceptor(
&handler1,
newInspector("two", "three"), // 4. request: check "two", add "three"
newInspector("one", "two"), // 9. response: check "one", add "two"
),
newHeaderInterceptor(
&handler2,
newInspector("three", "four"), // 5. request: check "three", add "four"
newInspector("", "one"), // 8. response: add "one"
),
newHeaderInterceptor(
&handler3,
assertAllPresent, // 6. request: check "one"-"four"
nil, // 7. response: no-op
),
Expand All @@ -129,6 +139,14 @@ func TestOnionOrderingEndToEnd(t *testing.T) {
_, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10}))
assert.Nil(t, err)

// make sure the interceptors were actually invoked
assert.Equal(t, int32(1), client1.Load())
assert.Equal(t, int32(1), client2.Load())
assert.Equal(t, int32(1), client3.Load())
assert.Equal(t, int32(1), handler1.Load())
assert.Equal(t, int32(1), handler2.Load())
assert.Equal(t, int32(1), handler3.Load())

responses, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10}))
assert.Nil(t, err)
var sum int64
Expand All @@ -137,6 +155,14 @@ func TestOnionOrderingEndToEnd(t *testing.T) {
}
assert.Equal(t, sum, 55)
assert.Nil(t, responses.Close())

// make sure the interceptors were invoked again
assert.Equal(t, int32(2), client1.Load())
assert.Equal(t, int32(2), client2.Load())
assert.Equal(t, int32(2), client3.Load())
assert.Equal(t, int32(2), handler1.Load())
assert.Equal(t, int32(2), handler2.Load())
assert.Equal(t, int32(2), handler3.Load())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 appreciate the improvement here!

}

func TestEmptyUnaryInterceptorFunc(t *testing.T) {
Expand Down Expand Up @@ -166,24 +192,75 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) {
assert.Nil(t, countUpStream.Close())
}

func TestInterceptorFuncAccessingHTTPMethod(t *testing.T) {
t.Parallel()
clientChecker := &httpMethodChecker{client: true}
handlerChecker := &httpMethodChecker{}

mux := http.NewServeMux()
mux.Handle(
pingv1connect.NewPingServiceHandler(
pingServer{},
connect.WithInterceptors(handlerChecker),
),
)
server := httptest.NewServer(mux)
defer server.Close()

client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL,
connect.WithInterceptors(clientChecker),
)

pingReq := connect.NewRequest(&pingv1.PingRequest{Number: 10})
assert.Equal(t, "", pingReq.HTTPMethod())
_, err := client.Ping(context.Background(), pingReq)
assert.Nil(t, err)
assert.Equal(t, http.MethodPost, pingReq.HTTPMethod())

// make sure interceptor was invoked
assert.Equal(t, int32(1), clientChecker.count.Load())
assert.Equal(t, int32(1), handlerChecker.count.Load())

countUpReq := connect.NewRequest(&pingv1.CountUpRequest{Number: 10})
assert.Equal(t, "", countUpReq.HTTPMethod())
responses, err := client.CountUp(context.Background(), countUpReq)
assert.Nil(t, err)
var sum int64
for responses.Receive() {
sum += responses.Msg().Number
}
assert.Equal(t, sum, 55)
assert.Nil(t, responses.Close())
assert.Equal(t, http.MethodPost, countUpReq.HTTPMethod())

// make sure interceptor was invoked again
assert.Equal(t, int32(2), clientChecker.count.Load())
assert.Equal(t, int32(2), handlerChecker.count.Load())
}

// headerInterceptor makes it easier to write interceptors that inspect or
// mutate HTTP headers. It applies the same logic to unary and streaming
// procedures, wrapping the send or receive side of the stream as appropriate.
//
// It's useful as a testing harness to make sure that we're chaining
// interceptors in the correct order.
type headerInterceptor struct {
counter *atomic.Int32
inspectRequestHeader func(connect.Spec, http.Header)
inspectResponseHeader func(connect.Spec, http.Header)
}

// newHeaderInterceptor constructs a headerInterceptor. Nil function pointers
// are treated as no-ops.
func newHeaderInterceptor(
counter *atomic.Int32,
inspectRequestHeader func(connect.Spec, http.Header),
inspectResponseHeader func(connect.Spec, http.Header),
) *headerInterceptor {
interceptor := headerInterceptor{
counter: counter,
inspectRequestHeader: inspectRequestHeader,
inspectResponseHeader: inspectResponseHeader,
}
Expand All @@ -198,6 +275,7 @@ func newHeaderInterceptor(

func (h *headerInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
h.counter.Add(1)
h.inspectRequestHeader(req.Spec(), req.Header())
res, err := next(ctx, req)
if err != nil {
Expand All @@ -210,6 +288,7 @@ func (h *headerInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc

func (h *headerInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
h.counter.Add(1)
return &headerInspectingClientConn{
StreamingClientConn: next(ctx, spec),
inspectRequestHeader: h.inspectRequestHeader,
Expand All @@ -220,6 +299,7 @@ func (h *headerInterceptor) WrapStreamingClient(next connect.StreamingClientFunc

func (h *headerInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
h.counter.Add(1)
h.inspectRequestHeader(conn.Spec(), conn.RequestHeader())
return next(ctx, &headerInspectingHandlerConn{
StreamingHandlerConn: conn,
Expand Down Expand Up @@ -268,3 +348,46 @@ func (cc *headerInspectingClientConn) Receive(msg any) error {
}
return err
}

type httpMethodChecker struct {
client bool
count atomic.Int32
}

func (h *httpMethodChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
h.count.Add(1)
if h.client {
// should be blank until after we make request
if req.HTTPMethod() != "" {
return nil, fmt.Errorf("expected blank HTTP method but instead got %q", req.HTTPMethod())
}
} else {
// server interceptors see method from the start
if req.HTTPMethod() != http.MethodPost {
return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod())
}
}
resp, err := unaryFunc(ctx, req)
if req.HTTPMethod() != http.MethodPost {
return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod())
}
return resp, err
}
}

func (h *httpMethodChecker) WrapStreamingClient(clientFunc connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
// method not exposed to streaming interceptor, but that's okay because it's always POST for streams
h.count.Add(1)
return clientFunc(ctx, spec)
}
}

func (h *httpMethodChecker) WrapStreamingHandler(handlerFunc connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
// method not exposed to streaming interceptor, but that's okay because it's always POST for streams
h.count.Add(1)
return handlerFunc(ctx, conn)
}
}
Loading