Skip to content
This repository has been archived by the owner on Sep 1, 2023. It is now read-only.

Commit

Permalink
Add support for http.MaxBytesHandler (connectrpc#355)
Browse files Browse the repository at this point in the history
  • Loading branch information
akshayjshah authored Sep 3, 2022
1 parent 89b4d66 commit ffe5175
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 0 deletions.
89 changes: 89 additions & 0 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,95 @@ func TestHandlerWithReadMaxBytes(t *testing.T) {
})
}

func TestHandlerWithHTTPMaxBytes(t *testing.T) {
// This is similar to Connect's own ReadMaxBytes option, but applied to the
// whole stream using the stdlib's http.MaxBytesHandler.
t.Parallel()
const readMaxBytes = 128
mux := http.NewServeMux()
pingRoute, pingHandler := pingv1connect.NewPingServiceHandler(pingServer{})
mux.Handle(pingRoute, http.MaxBytesHandler(pingHandler, readMaxBytes))
run := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) {
t.Helper()
t.Run("below_read_max", func(t *testing.T) {
t.Parallel()
_, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{}))
assert.Nil(t, err)
})
t.Run("just_above_max", func(t *testing.T) {
t.Parallel()
pingRequest := &pingv1.PingRequest{Text: strings.Repeat("a", readMaxBytes*10)}
_, err := client.Ping(context.Background(), connect.NewRequest(pingRequest))
if compressed {
compressedSize := gzipCompressedSize(t, pingRequest)
assert.True(t, compressedSize < readMaxBytes, assert.Sprintf("expected compressed size %d < %d", compressedSize, readMaxBytes))
assert.Nil(t, err)
return
}
assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message"))
assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted)
})
t.Run("read_max_large", func(t *testing.T) {
t.Parallel()
if testing.Short() {
t.Skipf("skipping %s test in short mode", t.Name())
}
pingRequest := &pingv1.PingRequest{Text: strings.Repeat("abcde", 1024*1024)}
if compressed {
expectedSize := gzipCompressedSize(t, pingRequest)
assert.True(t, expectedSize > readMaxBytes, assert.Sprintf("expected compressed size %d > %d", expectedSize, readMaxBytes))
}
_, err := client.Ping(context.Background(), connect.NewRequest(pingRequest))
assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message"))
assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted)
})
}
newHTTP2Server := func(t *testing.T) *httptest.Server {
t.Helper()
server := httptest.NewUnstartedServer(mux)
server.EnableHTTP2 = true
server.StartTLS()
t.Cleanup(server.Close)
return server
}
t.Run("connect", func(t *testing.T) {
t.Parallel()
server := newHTTP2Server(t)
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL)
run(t, client, false)
})
t.Run("connect_gzip", func(t *testing.T) {
t.Parallel()
server := newHTTP2Server(t)
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendGzip())
run(t, client, true)
})
t.Run("grpc", func(t *testing.T) {
t.Parallel()
server := newHTTP2Server(t)
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC())
run(t, client, false)
})
t.Run("grpc_gzip", func(t *testing.T) {
t.Parallel()
server := newHTTP2Server(t)
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC(), connect.WithSendGzip())
run(t, client, true)
})
t.Run("grpcweb", func(t *testing.T) {
t.Parallel()
server := newHTTP2Server(t)
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb())
run(t, client, false)
})
t.Run("grpcweb_gzip", func(t *testing.T) {
t.Parallel()
server := newHTTP2Server(t)
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb(), connect.WithSendGzip())
run(t, client, true)
})
}

func TestClientWithReadMaxBytes(t *testing.T) {
t.Parallel()
createServer := func(tb testing.TB, enableCompression bool) *httptest.Server {
Expand Down
8 changes: 8 additions & 0 deletions envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ func (r *envelopeReader) Read(env *envelope) *Error {
if connectErr, ok := asError(err); ok {
return connectErr
}
if maxBytesErr := asMaxBytesError(err, "read 5 byte message prefix"); maxBytesErr != nil {
// We're reading from an http.MaxBytesHandler, and we've exceeded the read limit.
return maxBytesErr
}
return errorf(
CodeInvalidArgument,
"protocol error: incomplete envelope: %w", err,
Expand All @@ -227,6 +231,10 @@ func (r *envelopeReader) Read(env *envelope) *Error {
for remaining > 0 {
bytesRead, err := io.CopyN(env.Data, r.reader, remaining)
if err != nil && !errors.Is(err, io.EOF) {
if maxBytesErr := asMaxBytesError(err, "read %d byte message", size); maxBytesErr != nil {
// We're reading from an http.MaxBytesHandler, and we've exceeded the read limit.
return maxBytesErr
}
return errorf(CodeUnknown, "read enveloped message: %w", err)
}
if errors.Is(err, io.EOF) && bytesRead == 0 {
Expand Down
32 changes: 32 additions & 0 deletions maxbytes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright 2021-2022 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build go1.19

package connect

import (
"errors"
"fmt"
"net/http"
)

func asMaxBytesError(err error, tmpl string, args ...any) *Error {
var maxBytesErr *http.MaxBytesError
if ok := errors.As(err, &maxBytesErr); !ok {
return nil
}
prefix := fmt.Sprintf(tmpl, args...)
return errorf(CodeResourceExhausted, "%s: exceeded %d byte http.MaxBytesReader limit", prefix, maxBytesErr.Limit)
}
32 changes: 32 additions & 0 deletions maxbytes_go118.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright 2021-2022 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build !go1.19

package connect

import (
"fmt"
"strings"
)

func asMaxBytesError(err error, tmpl string, args ...any) *Error {
const expect = "http: request body too large"
text := err.Error()
if text != expect && !strings.HasSuffix(text, ": "+expect) {
return nil
}
prefix := fmt.Sprintf(tmpl, args...)
return errorf(CodeResourceExhausted, "%s: exceeded http.MaxBytesReader limit", prefix)
}
5 changes: 5 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,11 @@ func WithCompressMinBytes(min int) Option {
//
// Setting WithReadMaxBytes to zero allows any message size. Both clients and
// handlers default to allowing any request size.
//
// Handlers may also use [http.MaxBytesHandler] to limit the total size of the
// HTTP request stream (rather than the per-message size). Connect handles
// [http.MaxBytesError] specially, so clients still receive errors with the
// appropriate error code and informative messages.
func WithReadMaxBytes(max int) Option {
return &readMaxBytesOption{Max: max}
}
Expand Down
3 changes: 3 additions & 0 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,9 @@ func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]by
if connectErr, ok := asError(err); ok {
return connectErr
}
if readMaxBytesErr := asMaxBytesError(err, "read first %d bytes of message", bytesRead); readMaxBytesErr != nil {
return readMaxBytesErr
}
return errorf(CodeUnknown, "read message: %w", err)
}
if u.readMaxBytes > 0 && bytesRead > int64(u.readMaxBytes) {
Expand Down

0 comments on commit ffe5175

Please sign in to comment.