Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix server-side timeout handling #210

Merged
merged 5 commits into from
May 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"strings"
"sync"
"testing"
"time"

"github.com/bufbuild/connect-go"
"github.com/bufbuild/connect-go/internal/assert"
Expand Down Expand Up @@ -59,7 +60,17 @@ func TestServer(t *testing.T) {
assert.Equal(t, response.Header().Get(handlerHeader), headerValue)
assert.Equal(t, response.Trailer().Get(handlerTrailer), trailerValue)
})
t.Run("large ping", func(t *testing.T) {
t.Run("zero_ping", func(t *testing.T) {
request := connect.NewRequest(&pingv1.PingRequest{})
request.Header().Set(clientHeader, headerValue)
response, err := client.Ping(context.Background(), request)
assert.Nil(t, err)
var expect pingv1.PingResponse
assert.Equal(t, response.Msg, &expect)
assert.Equal(t, response.Header().Get(handlerHeader), headerValue)
assert.Equal(t, response.Trailer().Get(handlerTrailer), trailerValue)
})
t.Run("large_ping", func(t *testing.T) {
// Using a large payload splits the request and response over multiple
// packets, ensuring that we're managing HTTP readers and writers
// correctly.
Expand Down Expand Up @@ -361,6 +372,31 @@ func TestHeaderBasic(t *testing.T) {
assert.Equal(t, response.Header().Get(key), hval)
}

func TestTimeoutParsing(t *testing.T) {
t.Parallel()
const timeout = 10 * time.Minute
pingServer := &pluggablePingServer{
ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) {
deadline, ok := ctx.Deadline()
assert.True(t, ok)
remaining := time.Until(deadline)
assert.True(t, remaining > 0)
assert.True(t, remaining <= timeout)
return connect.NewResponse(&pingv1.PingResponse{}), nil
},
}
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer))
server := httptest.NewServer(mux)
defer server.Close()

ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC())
_, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{}))
assert.Nil(t, err)
}

func TestMarshalStatusError(t *testing.T) {
t.Parallel()

Expand Down
61 changes: 32 additions & 29 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type Handler struct {
interceptor Interceptor
implementation func(context.Context, Sender, Receiver, error /* client-visible */)
protocolHandlers []protocolHandler
acceptPost string // Accept-Post header
}

// NewUnaryHandler constructs a Handler for a request-response procedure.
Expand Down Expand Up @@ -101,6 +102,7 @@ func NewUnaryHandler[Req, Res any](
interceptor: nil, // already applied
implementation: implementation,
protocolHandlers: protocolHandlers,
acceptPost: sortedAcceptPostValue(protocolHandlers),
}
}

Expand Down Expand Up @@ -178,31 +180,31 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re
// okay if we can't re-use the connection.
isBidi := (h.spec.StreamType & StreamTypeBidi) == StreamTypeBidi
if isBidi && request.ProtoMajor < 2 {
h.failNegotiation(responseWriter, http.StatusHTTPVersionNotSupported)
responseWriter.WriteHeader(http.StatusHTTPVersionNotSupported)
return
}

methodHandlers := make([]protocolHandler, 0, len(h.protocolHandlers))
for _, protocolHandler := range h.protocolHandlers {
if protocolHandler.ShouldHandleMethod(request.Method) {
methodHandlers = append(methodHandlers, protocolHandler)
}
}
if len(methodHandlers) == 0 {
// The gRPC-HTTP2, gRPC-Web, and Connect protocols are all POST-only.
if request.Method != http.MethodPost {
// grpc-go returns a 500 here, but interoperability with non-gRPC HTTP
// clients is better if we return a 405.
h.failNegotiation(responseWriter, http.StatusMethodNotAllowed)
responseWriter.Header().Set("Allow", http.MethodPost)
responseWriter.WriteHeader(http.StatusMethodNotAllowed)
return
}

// TODO: for GETs, we should parse the Accept header and offer each handler
// each content-type.
contentType := request.Header.Get("Content-Type")
for _, protocolHandler := range methodHandlers {
if !protocolHandler.ShouldHandleContentType(contentType) {
for _, protocolHandler := range h.protocolHandlers {
if _, ok := protocolHandler.ContentTypes()[contentType]; !ok {
continue
}
ctx := request.Context()
ctx, cancel, timeoutErr := protocolHandler.SetTimeout(request)
if timeoutErr != nil {
ctx = request.Context()
}
if cancel != nil {
defer cancel()
}
if ic := h.interceptor; ic != nil {
ctx = ic.WrapStreamContext(ctx)
}
Expand All @@ -211,11 +213,17 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re
// timeout or an unavailable codec. We'd like those errors to be visible to
// the interceptor chain, so we're going to capture them here and pass them
// to the implementation.
sender, receiver, clientVisibleError := protocolHandler.NewStream(responseWriter, request.WithContext(ctx))
// If NewStream errored and the protocol doesn't want the error sent to
// the client, sender and/or receiver may be nil. We still want the
// error to be seen by interceptors, so we provide no-op Sender and
// Receiver implementations.
sender, receiver, clientVisibleError := protocolHandler.NewStream(
responseWriter,
request.WithContext(ctx),
)
if timeoutErr != nil {
clientVisibleError = timeoutErr
}
// If NewStream or SetTimeout errored and the protocol doesn't want the
// error sent to the client, sender and/or receiver may be nil. We still
// want the error to be seen by interceptors, so we provide no-op Sender
// and Receiver implementations.
if clientVisibleError != nil && sender == nil {
sender = newNopSender(h.spec, responseWriter.Header(), make(http.Header))
}
Expand All @@ -230,15 +238,8 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re
h.implementation(ctx, sender, receiver, clientVisibleError)
return
}
h.failNegotiation(responseWriter, http.StatusUnsupportedMediaType)
}

func (h *Handler) failNegotiation(w http.ResponseWriter, code int) {
// None of the registered protocols is able to serve the request.
for _, ph := range h.protocolHandlers {
ph.WriteAccept(w.Header())
}
w.WriteHeader(code)
responseWriter.Header().Set("Accept-Post", h.acceptPost)
responseWriter.WriteHeader(http.StatusUnsupportedMediaType)
}

type handlerConfig struct {
Expand Down Expand Up @@ -308,6 +309,7 @@ func newStreamHandler(
options ...HandlerOption,
) *Handler {
config := newHandlerConfig(procedure, options)
protocolHandlers := config.newProtocolHandlers(streamType)
return &Handler{
spec: config.newSpec(streamType),
interceptor: config.Interceptor,
Expand All @@ -319,6 +321,7 @@ func newStreamHandler(
}
implementation(ctx, sender, receiver)
},
protocolHandlers: config.newProtocolHandlers(streamType),
protocolHandlers: protocolHandlers,
acceptPost: sortedAcceptPostValue(protocolHandlers),
}
}
10 changes: 0 additions & 10 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,6 @@ func percentDecodeSlow(bufferPool *bufferPool, encoded string, offset int) strin
return out.String()
}

// addCommaSeparatedHeader is a helper to produce headers like
// {"Allow": "GET, POST"}.
func addCommaSeparatedHeader(header http.Header, key, value string) {
if prev := header.Get(key); prev != "" {
header.Set(key, prev+", "+value)
} else {
header.Set(key, value)
}
}

func mergeHeaders(into, from http.Header) {
for k, vals := range from {
into[k] = append(into[k], vals...)
Expand Down
59 changes: 47 additions & 12 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@ package connect

import (
"context"
"io"
"net/http"
"sort"
"strings"
)

const discardLimit = 1024 * 1024 * 4 // 4MiB

// A Protocol defines the HTTP semantics to use when sending and receiving
// messages. It ties together codecs, compressors, and net/http to produce
// Senders and Receivers.
Expand Down Expand Up @@ -55,18 +60,17 @@ type protocolHandlerParams struct {
// Handler is the server side of a protocol. HTTP handlers typically support
// multiple protocols, codecs, and compressors.
type protocolHandler interface {
// ShouldHandleMethod and ShouldHandleContentType check whether the protocol
// can serve requests with a given HTTP method and Content-Type. NewStream
// may assume that any checks in ShouldHandleMethod and
// ShouldHandleContentType have passed.
ShouldHandleMethod(string) bool
ShouldHandleContentType(string) bool

// If no protocol can serve a request, each protocol's WriteAccept method has
// a chance to write to the response headers. Protocols should write their
// supported HTTP methods to the Allow header, and they may write their
// supported content-types to the Accept-Post or Accept-Patch headers.
WriteAccept(http.Header)
// ContentTypes is the set of HTTP Content-Types that the protocol can
// handle.
ContentTypes() map[string]struct{}

// ParseTimeout runs before NewStream. Implementations may inspect the HTTP
// request, parse any timeout set by the client, and return a modified
// context and cancellation function.
//
// If the client didn't send a timeout, SetTimeout should return the
// request's context, a nil cancellation function, and a nil error.
SetTimeout(*http.Request) (context.Context, context.CancelFunc, error)

// NewStream constructs a Sender and Receiver for the message exchange.
//
Expand Down Expand Up @@ -155,3 +159,34 @@ func (r *errorTranslatingReceiver) Receive(msg any) error {
func (r *errorTranslatingReceiver) Close() error {
return r.fromWire(r.Receiver.Close())
}

func sortedAcceptPostValue(handlers []protocolHandler) string {
contentTypes := make(map[string]struct{})
for _, handler := range handlers {
for contentType := range handler.ContentTypes() {
contentTypes[contentType] = struct{}{}
}
}
accept := make([]string, 0, len(contentTypes))
for ct := range contentTypes {
accept = append(accept, ct)
}
sort.Strings(accept)
return strings.Join(accept, ", ")
}

func isCommaOrSpace(c rune) bool {
return c == ',' || c == ' '
}

func discard(reader io.Reader) error {
if lr, ok := reader.(*io.LimitedReader); ok {
_, err := io.Copy(io.Discard, lr)
return err
}
// We don't want to get stuck throwing data away forever, so limit how much
// we're willing to do here.
lr := &io.LimitedReader{R: reader, N: discardLimit}
_, err := io.Copy(io.Discard, lr)
return err
}
51 changes: 26 additions & 25 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,24 @@ type protocolGRPC struct {

// NewHandler implements protocol, so it must return an interface.
func (g *protocolGRPC) NewHandler(params *protocolHandlerParams) protocolHandler {
bare, prefix := typeDefaultGRPC, typeDefaultGRPCPrefix
if g.web {
bare, prefix = typeWebGRPC, typeWebGRPCPrefix
}
contentTypes := make(map[string]struct{})
for _, name := range params.Codecs.Names() {
contentTypes[prefix+name] = struct{}{}
}
if params.Codecs.Get(codecNameProto) != nil {
contentTypes[bare] = struct{}{}
}
return &grpcHandler{
spec: params.Spec,
web: g.web,
codecs: params.Codecs,
compressionPools: params.CompressionPools,
minCompressBytes: params.CompressMinBytes,
accept: acceptPostValue(g.web, params.Codecs),
accept: contentTypes,
bufferPool: params.BufferPool,
}
}
Expand Down Expand Up @@ -70,25 +81,26 @@ type grpcHandler struct {
codecs readOnlyCodecs
compressionPools readOnlyCompressionPools
minCompressBytes int
accept string
accept map[string]struct{}
bufferPool *bufferPool
}

func (g *grpcHandler) ShouldHandleMethod(method string) bool {
return method == http.MethodPost
func (g *grpcHandler) ContentTypes() map[string]struct{} {
return g.accept
}

func (g *grpcHandler) ShouldHandleContentType(contentType string) bool {
codecName := codecFromContentType(g.web, contentType)
if codecName == "" {
return false // not a gRPC content-type
func (g *grpcHandler) SetTimeout(request *http.Request) (context.Context, context.CancelFunc, error) {
timeout, err := parseTimeout(request.Header.Get("Grpc-Timeout"))
if err != nil && !errors.Is(err, errNoTimeout) {
// Errors here indicate that the client sent an invalid timeout header, so
// the error text is safe to send back.
return nil, nil, NewError(CodeInvalidArgument, err)
} else if err != nil {
// err wraps errNoTimeout, nothing to do.
return request.Context(), nil, nil
}
return g.codecs.Get(codecName) != nil
}

func (g *grpcHandler) WriteAccept(header http.Header) {
addCommaSeparatedHeader(header, "Allow", http.MethodPost)
addCommaSeparatedHeader(header, "Accept-Post", g.accept)
ctx, cancel := context.WithTimeout(request.Context(), timeout)
return ctx, cancel, nil
}

func (g *grpcHandler) NewStream(
Expand All @@ -105,17 +117,6 @@ func (g *grpcHandler) NewStream(
// will send the error to the client.
var failed *Error

timeout, err := parseTimeout(request.Header.Get("Grpc-Timeout"))
if err != nil && !errors.Is(err, errNoTimeout) {
// Errors here indicate that the client sent an invalid timeout header, so
// the error text is safe to send back.
failed = NewError(CodeInvalidArgument, err)
} else if err == nil {
ctx, cancel := context.WithTimeout(request.Context(), timeout)
defer cancel()
request = request.WithContext(ctx)
} // else err wraps errNoTimeout, nothing to do

requestCompression := compressionIdentity
if msgEncoding := request.Header.Get("Grpc-Encoding"); msgEncoding != "" && msgEncoding != compressionIdentity {
// We default to identity, so we only care if the client sends something
Expand Down
Loading