Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pool buffers per handler and client #192

Merged
merged 10 commits into from
Apr 9, 2022
12 changes: 7 additions & 5 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,23 @@ func BenchmarkConnect(b *testing.B) {
connect.WithGzipRequests(),
)
assert.Nil(b, err)
twoMiB := strings.Repeat("a", 2*1024*1024)
b.ResetTimer()

b.Run("unary", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, _ = client.Ping(
context.Background(),
connect.NewRequest(&pingv1.PingRequest{Number: 42}),
connect.NewRequest(&pingv1.PingRequest{Text: twoMiB}),
)
}
})
})
}

type ping struct {
Number int `json:"number"`
Text string `json:"text"`
}

func BenchmarkREST(b *testing.B) {
Expand Down Expand Up @@ -118,23 +119,24 @@ func BenchmarkREST(b *testing.B) {
server.EnableHTTP2 = true
server.StartTLS()
defer server.Close()
twoMiB := strings.Repeat("a", 2*1024*1024)
b.ResetTimer()

b.Run("unary", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
unaryRESTIteration(b, server.Client(), server.URL)
unaryRESTIteration(b, server.Client(), server.URL, twoMiB)
}
})
})
}

func unaryRESTIteration(b *testing.B, client *http.Client, url string) {
func unaryRESTIteration(b *testing.B, client *http.Client, url string, text string) {
b.Helper()
rawRequestBody := bytes.NewBuffer(nil)
compressedRequestBody := gzip.NewWriter(rawRequestBody)
encoder := json.NewEncoder(compressedRequestBody)
if err := encoder.Encode(&ping{42}); err != nil {
if err := encoder.Encode(&ping{text}); err != nil {
b.Fatalf("marshal request: %v", err)
}
compressedRequestBody.Close()
Expand Down
54 changes: 54 additions & 0 deletions buffer_pool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright 2021-2022 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package connect

import (
"bytes"
"sync"
)

const (
initialBufferSize = 512
maxRecycleBufferSize = 8 * 1024 * 1024 // if >8MiB, don't hold onto a buffer
)

type bufferPool struct {
sync.Pool
}

func newBufferPool() *bufferPool {
return &bufferPool{
Pool: sync.Pool{
New: func() any {
return bytes.NewBuffer(make([]byte, 0, initialBufferSize))
},
},
}
}

func (b *bufferPool) Get() *bytes.Buffer {
if buf, ok := b.Pool.Get().(*bytes.Buffer); ok {
return buf
}
return bytes.NewBuffer(make([]byte, 0, initialBufferSize))
}

func (b *bufferPool) Put(buffer *bytes.Buffer) {
if buffer.Cap() > maxRecycleBufferSize {
return
}
buffer.Reset()
b.Pool.Put(buffer)
}
3 changes: 3 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ func NewClient[Req, Res any](
CompressMinBytes: config.CompressMinBytes,
HTTPClient: httpClient,
URL: url,
BufferPool: config.BufferPool,
})
if protocolErr != nil {
return nil, protocolErr
Expand Down Expand Up @@ -170,13 +171,15 @@ type clientConfiguration struct {
CompressionPools map[string]*compressionPool
Codec Codec
RequestCompressionName string
BufferPool *bufferPool
}

func newClientConfiguration(url string, options []ClientOption) (*clientConfiguration, *Error) {
protoPath := extractProtobufPath(url)
config := clientConfiguration{
Procedure: protoPath,
CompressionPools: make(map[string]*compressionPool),
BufferPool: newBufferPool(),
}
WithProtoBinaryCodec().applyToClient(&config)
WithGzip().applyToClient(&config)
Expand Down
3 changes: 3 additions & 0 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ type handlerConfiguration struct {
Procedure string
HandleGRPC bool
HandleGRPCWeb bool
BufferPool *bufferPool
}

func newHandlerConfiguration(procedure string, options []HandlerOption) *handlerConfiguration {
Expand All @@ -259,6 +260,7 @@ func newHandlerConfiguration(procedure string, options []HandlerOption) *handler
Codecs: make(map[string]Codec),
HandleGRPC: true,
HandleGRPCWeb: true,
BufferPool: newBufferPool(),
}
WithProtoBinaryCodec().applyToHandler(&config)
WithProtoJSONCodec().applyToHandler(&config)
Expand Down Expand Up @@ -293,6 +295,7 @@ func (c *handlerConfiguration) newProtocolHandlers(streamType StreamType) []prot
Codecs: codecs,
CompressionPools: compressors,
CompressMinBytes: c.CompressMinBytes,
BufferPool: c.BufferPool,
}))
}
return handlers
Expand Down
21 changes: 10 additions & 11 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package connect

import (
"bytes"
"encoding/base64"
"fmt"
"net/http"
Expand Down Expand Up @@ -58,22 +57,22 @@ func DecodeBinaryHeader(data string) ([]byte, error) {
// References:
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#responses
// https://datatracker.ietf.org/doc/html/rfc3986#section-2.1
func percentEncode(msg string) string {
func percentEncode(bufferPool *bufferPool, msg string) string {
for i := 0; i < len(msg); i++ {
// Characters that need to be escaped are defined in gRPC's HTTP/2 spec.
// They're different from the generic set defined in RFC 3986.
if c := msg[i]; c < ' ' || c > '~' || c == '%' {
return percentEncodeSlow(msg, i)
return percentEncodeSlow(bufferPool, msg, i)
}
}
return msg
}

// msg needs some percent-escaping. Bytes before offset don't require
// percent-encoding, so they can be copied to the output as-is.
func percentEncodeSlow(msg string, offset int) string {
// OPT: easy opportunity to pool buffers
out := bytes.NewBuffer(make([]byte, 0, len(msg)))
func percentEncodeSlow(bufferPool *bufferPool, msg string, offset int) string {
out := bufferPool.Get()
defer bufferPool.Put(out)
out.WriteString(msg[:offset])
for i := offset; i < len(msg); i++ {
c := msg[i]
Expand All @@ -86,20 +85,20 @@ func percentEncodeSlow(msg string, offset int) string {
return out.String()
}

func percentDecode(encoded string) string {
func percentDecode(bufferPool *bufferPool, encoded string) string {
for i := 0; i < len(encoded); i++ {
if c := encoded[i]; c == '%' && i+2 < len(encoded) {
return percentDecodeSlow(encoded, i)
return percentDecodeSlow(bufferPool, encoded, i)
}
}
return encoded
}

// Similar to percentEncodeSlow: encoded is percent-encoded, and needs to be
// decoded byte-by-byte starting at offset.
func percentDecodeSlow(encoded string, offset int) string {
// OPT: easy opportunity to pool buffers
out := bytes.NewBuffer(make([]byte, 0, len(encoded)))
func percentDecodeSlow(bufferPool *bufferPool, encoded string, offset int) string {
out := bufferPool.Get()
defer bufferPool.Put(out)
out.WriteString(encoded[:offset])
for i := offset; i < len(encoded); i++ {
c := encoded[i]
Expand Down
10 changes: 6 additions & 4 deletions header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@ func TestBinaryEncodingQuick(t *testing.T) {

func TestPercentEncodingQuick(t *testing.T) {
t.Parallel()
pool := newBufferPool()
roundtrip := func(input string) bool {
if !utf8.ValidString(input) {
return true
}
encoded := percentEncode(input)
decoded := percentDecode(encoded)
encoded := percentEncode(pool, input)
decoded := percentDecode(pool, encoded)
return decoded == input
}
if err := quick.Check(roundtrip, nil /* config */); err != nil {
Expand All @@ -57,11 +58,12 @@ func TestPercentEncodingQuick(t *testing.T) {

func TestPercentEncoding(t *testing.T) {
t.Parallel()
pool := newBufferPool()
roundtrip := func(input string) {
assert.True(t, utf8.ValidString(input), assert.Sprintf("input invalid UTF-8"))
encoded := percentEncode(input)
encoded := percentEncode(pool, input)
t.Logf("%q encoded as %q", input, encoded)
decoded := percentDecode(encoded)
decoded := percentDecode(pool, encoded)
assert.Equal(t, decoded, input)
}

Expand Down
3 changes: 2 additions & 1 deletion protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type protocolHandlerParams struct {
Codecs readOnlyCodecs
CompressionPools readOnlyCompressionPools
CompressMinBytes int
BufferPool *bufferPool
}

// Handler is the server side of a protocol. HTTP handlers typically support
Expand Down Expand Up @@ -93,7 +94,7 @@ type protocolClientParams struct {
CompressMinBytes int
HTTPClient HTTPClient
URL string

BufferPool *bufferPool
// The gRPC family of protocols always needs access to a Protobuf codec to
// marshal and unmarshal errors.
Protobuf Codec
Expand Down
7 changes: 7 additions & 0 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func (g *protocolGRPC) NewHandler(params *protocolHandlerParams) protocolHandler
compressionPools: params.CompressionPools,
minCompressBytes: params.CompressMinBytes,
accept: acceptPostValue(g.web, params.Codecs),
bufferPool: params.BufferPool,
}
}

Expand All @@ -59,6 +60,7 @@ func (g *protocolGRPC) NewClient(params *protocolClientParams) (protocolClient,
minCompressBytes: params.CompressMinBytes,
httpClient: params.HTTPClient,
procedureURL: params.URL,
bufferPool: params.BufferPool,
}, nil
}

Expand All @@ -69,6 +71,7 @@ type grpcHandler struct {
compressionPools readOnlyCompressionPools
minCompressBytes int
accept string
bufferPool *bufferPool
}

func (g *grpcHandler) ShouldHandleMethod(method string) bool {
Expand Down Expand Up @@ -173,6 +176,7 @@ func (g *grpcHandler) NewStream(
g.codecs.Protobuf(), // for errors
g.compressionPools.Get(requestCompression),
g.compressionPools.Get(responseCompression),
g.bufferPool,
))
// We can't return failed as-is: a nil *Error is non-nil when returned as an
// error interface.
Expand Down Expand Up @@ -212,6 +216,7 @@ type grpcClient struct {
httpClient HTTPClient
procedureURL string
wrapErrorInterceptor Interceptor
bufferPool *bufferPool
}

func (g *grpcClient) WriteRequestHeader(header http.Header) {
Expand Down Expand Up @@ -262,6 +267,7 @@ func (g *grpcClient) NewStream(
compressionPool: g.compressionPools.Get(g.compressionName),
codec: g.codec,
compressMinBytes: g.minCompressBytes,
bufferPool: g.bufferPool,
},
header: header,
trailer: make(http.Header),
Expand All @@ -270,6 +276,7 @@ func (g *grpcClient) NewStream(
responseHeader: make(http.Header),
responseTrailer: make(http.Header),
compressionPools: g.compressionPools,
bufferPool: g.bufferPool,
responseReady: make(chan struct{}),
}
return g.wrapStream(&clientSender{duplex}, &clientReceiver{duplex})
Expand Down
10 changes: 6 additions & 4 deletions protocol_grpc_client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ type duplexClientStream struct {
responseReady chan struct{}
unmarshaler unmarshaler
compressionPools readOnlyCompressionPools
bufferPool *bufferPool

errMu sync.Mutex
requestErr error
Expand Down Expand Up @@ -185,7 +186,7 @@ func (cs *duplexClientStream) Receive(message any) error {
if errors.Is(err, errGotWebTrailers) {
mergeHeaders(cs.responseTrailer, cs.unmarshaler.WebTrailer())
}
if serverErr := extractError(cs.protobuf, cs.responseTrailer); serverErr != nil {
if serverErr := extractError(cs.bufferPool, cs.protobuf, cs.responseTrailer); serverErr != nil {
// This is expected from a protocol perspective, but receiving trailers
// means that we're _not_ getting a message. For users to realize that
// the stream has ended, Receive must return an error.
Expand Down Expand Up @@ -312,7 +313,7 @@ func (cs *duplexClientStream) makeRequest(prepared chan struct{}) {
// DATA frames have been sent on the stream - isn't standard HTTP/2
// semantics, so net/http doesn't know anything about it. To us, then, these
// trailers-only responses actually appear as headers-only responses.
if err := extractError(cs.protobuf, res.Header); err != nil {
if err := extractError(cs.bufferPool, cs.protobuf, res.Header); err != nil {
// Per the specification, only the HTTP status code and Content-Type should
// be treated as headers. The rest should be treated as trailing metadata.
if contentType := res.Header.Get("Content-Type"); contentType != "" {
Expand All @@ -338,6 +339,7 @@ func (cs *duplexClientStream) makeRequest(prepared chan struct{}) {
codec: cs.codec,
compressionPool: cs.compressionPools.Get(compression),
web: cs.web,
bufferPool: cs.bufferPool,
}
}

Expand Down Expand Up @@ -395,7 +397,7 @@ func (cs *duplexClientStream) getRequestOrResponseError() error {
// binary Protobuf format, even if the messages in the request/response stream
// use a different codec. Consequently, this function needs a Protobuf codec to
// unmarshal error information in the headers.
func extractError(protobuf Codec, trailer http.Header) *Error {
func extractError(bufferPool *bufferPool, protobuf Codec, trailer http.Header) *Error {
codeHeader := trailer.Get("Grpc-Status")
if codeHeader == "" || codeHeader == "0" {
return nil
Expand All @@ -405,7 +407,7 @@ func extractError(protobuf Codec, trailer http.Header) *Error {
if err != nil {
return errorf(CodeUnknown, "gRPC protocol error: got invalid error code %q", codeHeader)
}
message := percentDecode(trailer.Get("Grpc-Message"))
message := percentDecode(bufferPool, trailer.Get("Grpc-Message"))
retErr := NewError(Code(code), errors.New(message))

detailsBinaryEncoded := trailer.Get("Grpc-Status-Details-Bin")
Expand Down
Loading