From ffe51755e3515660b62a58dc41de5bc0e5f10e2d Mon Sep 17 00:00:00 2001 From: Akshay Shah Date: Fri, 2 Sep 2022 17:51:16 -0700 Subject: [PATCH] Add support for http.MaxBytesHandler (#355) --- connect_ext_test.go | 89 +++++++++++++++++++++++++++++++++++++++++++++ envelope.go | 8 ++++ maxbytes.go | 32 ++++++++++++++++ maxbytes_go118.go | 32 ++++++++++++++++ option.go | 5 +++ protocol_connect.go | 3 ++ 6 files changed, 169 insertions(+) create mode 100644 maxbytes.go create mode 100644 maxbytes_go118.go diff --git a/connect_ext_test.go b/connect_ext_test.go index 83ced32c..4e4d3bc4 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -951,6 +951,95 @@ func TestHandlerWithReadMaxBytes(t *testing.T) { }) } +func TestHandlerWithHTTPMaxBytes(t *testing.T) { + // This is similar to Connect's own ReadMaxBytes option, but applied to the + // whole stream using the stdlib's http.MaxBytesHandler. + t.Parallel() + const readMaxBytes = 128 + mux := http.NewServeMux() + pingRoute, pingHandler := pingv1connect.NewPingServiceHandler(pingServer{}) + mux.Handle(pingRoute, http.MaxBytesHandler(pingHandler, readMaxBytes)) + run := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { + t.Helper() + t.Run("below_read_max", func(t *testing.T) { + t.Parallel() + _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) + assert.Nil(t, err) + }) + t.Run("just_above_max", func(t *testing.T) { + t.Parallel() + pingRequest := &pingv1.PingRequest{Text: strings.Repeat("a", readMaxBytes*10)} + _, err := client.Ping(context.Background(), connect.NewRequest(pingRequest)) + if compressed { + compressedSize := gzipCompressedSize(t, pingRequest) + assert.True(t, compressedSize < readMaxBytes, assert.Sprintf("expected compressed size %d < %d", compressedSize, readMaxBytes)) + assert.Nil(t, err) + return + } + assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message")) + assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted) + }) + t.Run("read_max_large", func(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skipf("skipping %s test in short mode", t.Name()) + } + pingRequest := &pingv1.PingRequest{Text: strings.Repeat("abcde", 1024*1024)} + if compressed { + expectedSize := gzipCompressedSize(t, pingRequest) + assert.True(t, expectedSize > readMaxBytes, assert.Sprintf("expected compressed size %d > %d", expectedSize, readMaxBytes)) + } + _, err := client.Ping(context.Background(), connect.NewRequest(pingRequest)) + assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message")) + assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted) + }) + } + newHTTP2Server := func(t *testing.T) *httptest.Server { + t.Helper() + server := httptest.NewUnstartedServer(mux) + server.EnableHTTP2 = true + server.StartTLS() + t.Cleanup(server.Close) + return server + } + t.Run("connect", func(t *testing.T) { + t.Parallel() + server := newHTTP2Server(t) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + run(t, client, false) + }) + t.Run("connect_gzip", func(t *testing.T) { + t.Parallel() + server := newHTTP2Server(t) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendGzip()) + run(t, client, true) + }) + t.Run("grpc", func(t *testing.T) { + t.Parallel() + server := newHTTP2Server(t) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) + run(t, client, false) + }) + t.Run("grpc_gzip", func(t *testing.T) { + t.Parallel() + server := newHTTP2Server(t) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC(), connect.WithSendGzip()) + run(t, client, true) + }) + t.Run("grpcweb", func(t *testing.T) { + t.Parallel() + server := newHTTP2Server(t) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb()) + run(t, client, false) + }) + t.Run("grpcweb_gzip", func(t *testing.T) { + t.Parallel() + server := newHTTP2Server(t) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb(), connect.WithSendGzip()) + run(t, client, true) + }) +} + func TestClientWithReadMaxBytes(t *testing.T) { t.Parallel() createServer := func(tb testing.TB, enableCompression bool) *httptest.Server { diff --git a/envelope.go b/envelope.go index 9275f460..0f259817 100644 --- a/envelope.go +++ b/envelope.go @@ -201,6 +201,10 @@ func (r *envelopeReader) Read(env *envelope) *Error { if connectErr, ok := asError(err); ok { return connectErr } + if maxBytesErr := asMaxBytesError(err, "read 5 byte message prefix"); maxBytesErr != nil { + // We're reading from an http.MaxBytesHandler, and we've exceeded the read limit. + return maxBytesErr + } return errorf( CodeInvalidArgument, "protocol error: incomplete envelope: %w", err, @@ -227,6 +231,10 @@ func (r *envelopeReader) Read(env *envelope) *Error { for remaining > 0 { bytesRead, err := io.CopyN(env.Data, r.reader, remaining) if err != nil && !errors.Is(err, io.EOF) { + if maxBytesErr := asMaxBytesError(err, "read %d byte message", size); maxBytesErr != nil { + // We're reading from an http.MaxBytesHandler, and we've exceeded the read limit. + return maxBytesErr + } return errorf(CodeUnknown, "read enveloped message: %w", err) } if errors.Is(err, io.EOF) && bytesRead == 0 { diff --git a/maxbytes.go b/maxbytes.go new file mode 100644 index 00000000..455dba2b --- /dev/null +++ b/maxbytes.go @@ -0,0 +1,32 @@ +// Copyright 2021-2022 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build go1.19 + +package connect + +import ( + "errors" + "fmt" + "net/http" +) + +func asMaxBytesError(err error, tmpl string, args ...any) *Error { + var maxBytesErr *http.MaxBytesError + if ok := errors.As(err, &maxBytesErr); !ok { + return nil + } + prefix := fmt.Sprintf(tmpl, args...) + return errorf(CodeResourceExhausted, "%s: exceeded %d byte http.MaxBytesReader limit", prefix, maxBytesErr.Limit) +} diff --git a/maxbytes_go118.go b/maxbytes_go118.go new file mode 100644 index 00000000..32a3e022 --- /dev/null +++ b/maxbytes_go118.go @@ -0,0 +1,32 @@ +// Copyright 2021-2022 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !go1.19 + +package connect + +import ( + "fmt" + "strings" +) + +func asMaxBytesError(err error, tmpl string, args ...any) *Error { + const expect = "http: request body too large" + text := err.Error() + if text != expect && !strings.HasSuffix(text, ": "+expect) { + return nil + } + prefix := fmt.Sprintf(tmpl, args...) + return errorf(CodeResourceExhausted, "%s: exceeded http.MaxBytesReader limit", prefix) +} diff --git a/option.go b/option.go index b5e52b31..9697245c 100644 --- a/option.go +++ b/option.go @@ -197,6 +197,11 @@ func WithCompressMinBytes(min int) Option { // // Setting WithReadMaxBytes to zero allows any message size. Both clients and // handlers default to allowing any request size. +// +// Handlers may also use [http.MaxBytesHandler] to limit the total size of the +// HTTP request stream (rather than the per-message size). Connect handles +// [http.MaxBytesError] specially, so clients still receive errors with the +// appropriate error code and informative messages. func WithReadMaxBytes(max int) Option { return &readMaxBytesOption{Max: max} } diff --git a/protocol_connect.go b/protocol_connect.go index f7c6d81e..7b845da8 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -781,6 +781,9 @@ func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]by if connectErr, ok := asError(err); ok { return connectErr } + if readMaxBytesErr := asMaxBytesError(err, "read first %d bytes of message", bytesRead); readMaxBytesErr != nil { + return readMaxBytesErr + } return errorf(CodeUnknown, "read message: %w", err) } if u.readMaxBytes > 0 && bytesRead > int64(u.readMaxBytes) {