Skip to content

Commit

Permalink
Fix server-side timeout handling (#210)
Browse files Browse the repository at this point in the history
* Add test for unary RPC with zero-byte messages

* Add failing test for handler timeout handling

We're parsing timeouts, but not properly propagating them into user
code. Thanks, crosstests!

* Fix timeout handling

Since we know the shape of the Connect protocol, we can simplify the
protocol interfaces and move some shared utility functions into
`protocol.go`. This also fixes server-side timeout handling.

* Keep Accept-Post string manipulation shorter

We're only doing this at startup, so it's okay to make it slow. The code
doesn't get much shorter, but it's arguably more readable.

* Add indirection to constant limit

Move the literal for our discard limit into a constant.
  • Loading branch information
akshayjshah authored May 14, 2022
1 parent 5fd72bc commit 7d6e9a4
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 109 deletions.
38 changes: 37 additions & 1 deletion connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"strings"
"sync"
"testing"
"time"

"connectrpc.com/connect"
"connectrpc.com/connect/internal/assert"
Expand Down Expand Up @@ -59,7 +60,17 @@ func TestServer(t *testing.T) {
assert.Equal(t, response.Header().Get(handlerHeader), headerValue)
assert.Equal(t, response.Trailer().Get(handlerTrailer), trailerValue)
})
t.Run("large ping", func(t *testing.T) {
t.Run("zero_ping", func(t *testing.T) {
request := connect.NewRequest(&pingv1.PingRequest{})
request.Header().Set(clientHeader, headerValue)
response, err := client.Ping(context.Background(), request)
assert.Nil(t, err)
var expect pingv1.PingResponse
assert.Equal(t, response.Msg, &expect)
assert.Equal(t, response.Header().Get(handlerHeader), headerValue)
assert.Equal(t, response.Trailer().Get(handlerTrailer), trailerValue)
})
t.Run("large_ping", func(t *testing.T) {
// Using a large payload splits the request and response over multiple
// packets, ensuring that we're managing HTTP readers and writers
// correctly.
Expand Down Expand Up @@ -361,6 +372,31 @@ func TestHeaderBasic(t *testing.T) {
assert.Equal(t, response.Header().Get(key), hval)
}

func TestTimeoutParsing(t *testing.T) {
t.Parallel()
const timeout = 10 * time.Minute
pingServer := &pluggablePingServer{
ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) {
deadline, ok := ctx.Deadline()
assert.True(t, ok)
remaining := time.Until(deadline)
assert.True(t, remaining > 0)
assert.True(t, remaining <= timeout)
return connect.NewResponse(&pingv1.PingResponse{}), nil
},
}
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer))
server := httptest.NewServer(mux)
defer server.Close()

ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC())
_, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{}))
assert.Nil(t, err)
}

func TestMarshalStatusError(t *testing.T) {
t.Parallel()

Expand Down
61 changes: 32 additions & 29 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type Handler struct {
interceptor Interceptor
implementation func(context.Context, Sender, Receiver, error /* client-visible */)
protocolHandlers []protocolHandler
acceptPost string // Accept-Post header
}

// NewUnaryHandler constructs a Handler for a request-response procedure.
Expand Down Expand Up @@ -101,6 +102,7 @@ func NewUnaryHandler[Req, Res any](
interceptor: nil, // already applied
implementation: implementation,
protocolHandlers: protocolHandlers,
acceptPost: sortedAcceptPostValue(protocolHandlers),
}
}

Expand Down Expand Up @@ -178,31 +180,31 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re
// okay if we can't re-use the connection.
isBidi := (h.spec.StreamType & StreamTypeBidi) == StreamTypeBidi
if isBidi && request.ProtoMajor < 2 {
h.failNegotiation(responseWriter, http.StatusHTTPVersionNotSupported)
responseWriter.WriteHeader(http.StatusHTTPVersionNotSupported)
return
}

methodHandlers := make([]protocolHandler, 0, len(h.protocolHandlers))
for _, protocolHandler := range h.protocolHandlers {
if protocolHandler.ShouldHandleMethod(request.Method) {
methodHandlers = append(methodHandlers, protocolHandler)
}
}
if len(methodHandlers) == 0 {
// The gRPC-HTTP2, gRPC-Web, and Connect protocols are all POST-only.
if request.Method != http.MethodPost {
// grpc-go returns a 500 here, but interoperability with non-gRPC HTTP
// clients is better if we return a 405.
h.failNegotiation(responseWriter, http.StatusMethodNotAllowed)
responseWriter.Header().Set("Allow", http.MethodPost)
responseWriter.WriteHeader(http.StatusMethodNotAllowed)
return
}

// TODO: for GETs, we should parse the Accept header and offer each handler
// each content-type.
contentType := request.Header.Get("Content-Type")
for _, protocolHandler := range methodHandlers {
if !protocolHandler.ShouldHandleContentType(contentType) {
for _, protocolHandler := range h.protocolHandlers {
if _, ok := protocolHandler.ContentTypes()[contentType]; !ok {
continue
}
ctx := request.Context()
ctx, cancel, timeoutErr := protocolHandler.SetTimeout(request)
if timeoutErr != nil {
ctx = request.Context()
}
if cancel != nil {
defer cancel()
}
if ic := h.interceptor; ic != nil {
ctx = ic.WrapStreamContext(ctx)
}
Expand All @@ -211,11 +213,17 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re
// timeout or an unavailable codec. We'd like those errors to be visible to
// the interceptor chain, so we're going to capture them here and pass them
// to the implementation.
sender, receiver, clientVisibleError := protocolHandler.NewStream(responseWriter, request.WithContext(ctx))
// If NewStream errored and the protocol doesn't want the error sent to
// the client, sender and/or receiver may be nil. We still want the
// error to be seen by interceptors, so we provide no-op Sender and
// Receiver implementations.
sender, receiver, clientVisibleError := protocolHandler.NewStream(
responseWriter,
request.WithContext(ctx),
)
if timeoutErr != nil {
clientVisibleError = timeoutErr
}
// If NewStream or SetTimeout errored and the protocol doesn't want the
// error sent to the client, sender and/or receiver may be nil. We still
// want the error to be seen by interceptors, so we provide no-op Sender
// and Receiver implementations.
if clientVisibleError != nil && sender == nil {
sender = newNopSender(h.spec, responseWriter.Header(), make(http.Header))
}
Expand All @@ -230,15 +238,8 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re
h.implementation(ctx, sender, receiver, clientVisibleError)
return
}
h.failNegotiation(responseWriter, http.StatusUnsupportedMediaType)
}

func (h *Handler) failNegotiation(w http.ResponseWriter, code int) {
// None of the registered protocols is able to serve the request.
for _, ph := range h.protocolHandlers {
ph.WriteAccept(w.Header())
}
w.WriteHeader(code)
responseWriter.Header().Set("Accept-Post", h.acceptPost)
responseWriter.WriteHeader(http.StatusUnsupportedMediaType)
}

type handlerConfig struct {
Expand Down Expand Up @@ -308,6 +309,7 @@ func newStreamHandler(
options ...HandlerOption,
) *Handler {
config := newHandlerConfig(procedure, options)
protocolHandlers := config.newProtocolHandlers(streamType)
return &Handler{
spec: config.newSpec(streamType),
interceptor: config.Interceptor,
Expand All @@ -319,6 +321,7 @@ func newStreamHandler(
}
implementation(ctx, sender, receiver)
},
protocolHandlers: config.newProtocolHandlers(streamType),
protocolHandlers: protocolHandlers,
acceptPost: sortedAcceptPostValue(protocolHandlers),
}
}
10 changes: 0 additions & 10 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,6 @@ func percentDecodeSlow(bufferPool *bufferPool, encoded string, offset int) strin
return out.String()
}

// addCommaSeparatedHeader is a helper to produce headers like
// {"Allow": "GET, POST"}.
func addCommaSeparatedHeader(header http.Header, key, value string) {
if prev := header.Get(key); prev != "" {
header.Set(key, prev+", "+value)
} else {
header.Set(key, value)
}
}

func mergeHeaders(into, from http.Header) {
for k, vals := range from {
into[k] = append(into[k], vals...)
Expand Down
59 changes: 47 additions & 12 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@ package connect

import (
"context"
"io"
"net/http"
"sort"
"strings"
)

const discardLimit = 1024 * 1024 * 4 // 4MiB

// A Protocol defines the HTTP semantics to use when sending and receiving
// messages. It ties together codecs, compressors, and net/http to produce
// Senders and Receivers.
Expand Down Expand Up @@ -55,18 +60,17 @@ type protocolHandlerParams struct {
// Handler is the server side of a protocol. HTTP handlers typically support
// multiple protocols, codecs, and compressors.
type protocolHandler interface {
// ShouldHandleMethod and ShouldHandleContentType check whether the protocol
// can serve requests with a given HTTP method and Content-Type. NewStream
// may assume that any checks in ShouldHandleMethod and
// ShouldHandleContentType have passed.
ShouldHandleMethod(string) bool
ShouldHandleContentType(string) bool

// If no protocol can serve a request, each protocol's WriteAccept method has
// a chance to write to the response headers. Protocols should write their
// supported HTTP methods to the Allow header, and they may write their
// supported content-types to the Accept-Post or Accept-Patch headers.
WriteAccept(http.Header)
// ContentTypes is the set of HTTP Content-Types that the protocol can
// handle.
ContentTypes() map[string]struct{}

// ParseTimeout runs before NewStream. Implementations may inspect the HTTP
// request, parse any timeout set by the client, and return a modified
// context and cancellation function.
//
// If the client didn't send a timeout, SetTimeout should return the
// request's context, a nil cancellation function, and a nil error.
SetTimeout(*http.Request) (context.Context, context.CancelFunc, error)

// NewStream constructs a Sender and Receiver for the message exchange.
//
Expand Down Expand Up @@ -155,3 +159,34 @@ func (r *errorTranslatingReceiver) Receive(msg any) error {
func (r *errorTranslatingReceiver) Close() error {
return r.fromWire(r.Receiver.Close())
}

func sortedAcceptPostValue(handlers []protocolHandler) string {
contentTypes := make(map[string]struct{})
for _, handler := range handlers {
for contentType := range handler.ContentTypes() {
contentTypes[contentType] = struct{}{}
}
}
accept := make([]string, 0, len(contentTypes))
for ct := range contentTypes {
accept = append(accept, ct)
}
sort.Strings(accept)
return strings.Join(accept, ", ")
}

func isCommaOrSpace(c rune) bool {
return c == ',' || c == ' '
}

func discard(reader io.Reader) error {
if lr, ok := reader.(*io.LimitedReader); ok {
_, err := io.Copy(io.Discard, lr)
return err
}
// We don't want to get stuck throwing data away forever, so limit how much
// we're willing to do here.
lr := &io.LimitedReader{R: reader, N: discardLimit}
_, err := io.Copy(io.Discard, lr)
return err
}
51 changes: 26 additions & 25 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,24 @@ type protocolGRPC struct {

// NewHandler implements protocol, so it must return an interface.
func (g *protocolGRPC) NewHandler(params *protocolHandlerParams) protocolHandler {
bare, prefix := typeDefaultGRPC, typeDefaultGRPCPrefix
if g.web {
bare, prefix = typeWebGRPC, typeWebGRPCPrefix
}
contentTypes := make(map[string]struct{})
for _, name := range params.Codecs.Names() {
contentTypes[prefix+name] = struct{}{}
}
if params.Codecs.Get(codecNameProto) != nil {
contentTypes[bare] = struct{}{}
}
return &grpcHandler{
spec: params.Spec,
web: g.web,
codecs: params.Codecs,
compressionPools: params.CompressionPools,
minCompressBytes: params.CompressMinBytes,
accept: acceptPostValue(g.web, params.Codecs),
accept: contentTypes,
bufferPool: params.BufferPool,
}
}
Expand Down Expand Up @@ -70,25 +81,26 @@ type grpcHandler struct {
codecs readOnlyCodecs
compressionPools readOnlyCompressionPools
minCompressBytes int
accept string
accept map[string]struct{}
bufferPool *bufferPool
}

func (g *grpcHandler) ShouldHandleMethod(method string) bool {
return method == http.MethodPost
func (g *grpcHandler) ContentTypes() map[string]struct{} {
return g.accept
}

func (g *grpcHandler) ShouldHandleContentType(contentType string) bool {
codecName := codecFromContentType(g.web, contentType)
if codecName == "" {
return false // not a gRPC content-type
func (g *grpcHandler) SetTimeout(request *http.Request) (context.Context, context.CancelFunc, error) {
timeout, err := parseTimeout(request.Header.Get("Grpc-Timeout"))
if err != nil && !errors.Is(err, errNoTimeout) {
// Errors here indicate that the client sent an invalid timeout header, so
// the error text is safe to send back.
return nil, nil, NewError(CodeInvalidArgument, err)
} else if err != nil {
// err wraps errNoTimeout, nothing to do.
return request.Context(), nil, nil
}
return g.codecs.Get(codecName) != nil
}

func (g *grpcHandler) WriteAccept(header http.Header) {
addCommaSeparatedHeader(header, "Allow", http.MethodPost)
addCommaSeparatedHeader(header, "Accept-Post", g.accept)
ctx, cancel := context.WithTimeout(request.Context(), timeout)
return ctx, cancel, nil
}

func (g *grpcHandler) NewStream(
Expand All @@ -105,17 +117,6 @@ func (g *grpcHandler) NewStream(
// will send the error to the client.
var failed *Error

timeout, err := parseTimeout(request.Header.Get("Grpc-Timeout"))
if err != nil && !errors.Is(err, errNoTimeout) {
// Errors here indicate that the client sent an invalid timeout header, so
// the error text is safe to send back.
failed = NewError(CodeInvalidArgument, err)
} else if err == nil {
ctx, cancel := context.WithTimeout(request.Context(), timeout)
defer cancel()
request = request.WithContext(ctx)
} // else err wraps errNoTimeout, nothing to do

requestCompression := compressionIdentity
if msgEncoding := request.Header.Get("Grpc-Encoding"); msgEncoding != "" && msgEncoding != compressionIdentity {
// We default to identity, so we only care if the client sends something
Expand Down
Loading

0 comments on commit 7d6e9a4

Please sign in to comment.