Pool buffers per handler and client (#192)
bufdev authored Apr 9, 2022
1 parent 28baade commit 7021f08
Showing 13 changed files with 155 additions and 57 deletions.
12 changes: 7 additions & 5 deletions bench_test.go
@@ -56,22 +56,23 @@ func BenchmarkConnect(b *testing.B) {
connect.WithGzipRequests(),
)
assert.Nil(b, err)
twoMiB := strings.Repeat("a", 2*1024*1024)
b.ResetTimer()

b.Run("unary", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, _ = client.Ping(
context.Background(),
connect.NewRequest(&pingv1.PingRequest{Number: 42}),
connect.NewRequest(&pingv1.PingRequest{Text: twoMiB}),
)
}
})
})
}

type ping struct {
Number int `json:"number"`
Text string `json:"text"`
}

func BenchmarkREST(b *testing.B) {
@@ -118,23 +119,24 @@ func BenchmarkREST(b *testing.B) {
server.EnableHTTP2 = true
server.StartTLS()
defer server.Close()
twoMiB := strings.Repeat("a", 2*1024*1024)
b.ResetTimer()

b.Run("unary", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
unaryRESTIteration(b, server.Client(), server.URL)
unaryRESTIteration(b, server.Client(), server.URL, twoMiB)
}
})
})
}

func unaryRESTIteration(b *testing.B, client *http.Client, url string) {
func unaryRESTIteration(b *testing.B, client *http.Client, url string, text string) {
b.Helper()
rawRequestBody := bytes.NewBuffer(nil)
compressedRequestBody := gzip.NewWriter(rawRequestBody)
encoder := json.NewEncoder(compressedRequestBody)
if err := encoder.Encode(&ping{42}); err != nil {
if err := encoder.Encode(&ping{text}); err != nil {
b.Fatalf("marshal request: %v", err)
}
compressedRequestBody.Close()
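
Both benchmarks now send a 2 MiB payload per request rather than a small integer, so per-request allocation changes from the new buffer pool show up clearly under the standard Go tooling, e.g. go test -bench=. -benchmem (a suggested invocation, not part of this change).
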
54 changes: 54 additions & 0 deletions buffer_pool.go
@@ -0,0 +1,54 @@
// Copyright 2021-2022 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package connect

import (
"bytes"
"sync"
)

const (
initialBufferSize = 512
maxRecycleBufferSize = 8 * 1024 * 1024 // if >8MiB, don't hold onto a buffer
)

type bufferPool struct {
sync.Pool
}

func newBufferPool() *bufferPool {
return &bufferPool{
Pool: sync.Pool{
New: func() any {
return bytes.NewBuffer(make([]byte, 0, initialBufferSize))
},
},
}
}

func (b *bufferPool) Get() *bytes.Buffer {
if buf, ok := b.Pool.Get().(*bytes.Buffer); ok {
return buf
}
return bytes.NewBuffer(make([]byte, 0, initialBufferSize))
}

func (b *bufferPool) Put(buffer *bytes.Buffer) {
if buffer.Cap() > maxRecycleBufferSize {
return
}
buffer.Reset()
b.Pool.Put(buffer)
}
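
The rest of the commit uses the pool through matched Get/Put pairs (see percentEncodeSlow and percentDecodeSlow below). A minimal sketch of the pattern inside the connect package, with the payload and variable names invented for illustration:

    pool := newBufferPool()
    buffer := pool.Get()   // a recycled *bytes.Buffer, or a fresh one with 512 bytes of capacity
    defer pool.Put(buffer) // Put resets the buffer; buffers whose capacity exceeds 8 MiB are dropped instead of recycled
    buffer.WriteString("encode or marshal into the pooled buffer here")
    payload := buffer.String() // copy the contents out before Put hands the buffer back to the pool
    _ = payload
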
3 changes: 3 additions & 0 deletions client.go
@@ -54,6 +54,7 @@ func NewClient[Req, Res any](
CompressMinBytes: config.CompressMinBytes,
HTTPClient: httpClient,
URL: url,
BufferPool: config.BufferPool,
})
if protocolErr != nil {
return nil, protocolErr
@@ -170,13 +171,15 @@ type clientConfiguration struct {
CompressionPools map[string]*compressionPool
Codec Codec
RequestCompressionName string
BufferPool *bufferPool
}

func newClientConfiguration(url string, options []ClientOption) (*clientConfiguration, *Error) {
protoPath := extractProtobufPath(url)
config := clientConfiguration{
Procedure: protoPath,
CompressionPools: make(map[string]*compressionPool),
BufferPool: newBufferPool(),
}
WithProtoBinaryCodec().applyToClient(&config)
WithGzip().applyToClient(&config)
3 changes: 3 additions & 0 deletions handler.go
@@ -249,6 +249,7 @@ type handlerConfiguration struct {
Procedure string
HandleGRPC bool
HandleGRPCWeb bool
BufferPool *bufferPool
}

func newHandlerConfiguration(procedure string, options []HandlerOption) *handlerConfiguration {
@@ -259,6 +260,7 @@ func newHandlerConfiguration(procedure string, options []HandlerOption) *handlerConfiguration {
Codecs: make(map[string]Codec),
HandleGRPC: true,
HandleGRPCWeb: true,
BufferPool: newBufferPool(),
}
WithProtoBinaryCodec().applyToHandler(&config)
WithProtoJSONCodec().applyToHandler(&config)
@@ -293,6 +295,7 @@ func (c *handlerConfiguration) newProtocolHandlers(streamType StreamType) []protocolHandler {
Codecs: codecs,
CompressionPools: compressors,
CompressMinBytes: c.CompressMinBytes,
BufferPool: c.BufferPool,
}))
}
return handlers
21 changes: 10 additions & 11 deletions header.go
@@ -15,7 +15,6 @@
package connect

import (
"bytes"
"encoding/base64"
"fmt"
"net/http"
@@ -58,22 +57,22 @@ func DecodeBinaryHeader(data string) ([]byte, error) {
// References:
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#responses
// https://datatracker.ietf.org/doc/html/rfc3986#section-2.1
func percentEncode(msg string) string {
func percentEncode(bufferPool *bufferPool, msg string) string {
for i := 0; i < len(msg); i++ {
// Characters that need to be escaped are defined in gRPC's HTTP/2 spec.
// They're different from the generic set defined in RFC 3986.
if c := msg[i]; c < ' ' || c > '~' || c == '%' {
return percentEncodeSlow(msg, i)
return percentEncodeSlow(bufferPool, msg, i)
}
}
return msg
}

// msg needs some percent-escaping. Bytes before offset don't require
// percent-encoding, so they can be copied to the output as-is.
func percentEncodeSlow(msg string, offset int) string {
// OPT: easy opportunity to pool buffers
out := bytes.NewBuffer(make([]byte, 0, len(msg)))
func percentEncodeSlow(bufferPool *bufferPool, msg string, offset int) string {
out := bufferPool.Get()
defer bufferPool.Put(out)
out.WriteString(msg[:offset])
for i := offset; i < len(msg); i++ {
c := msg[i]
@@ -86,20 +85,20 @@ func percentEncodeSlow(msg string, offset int) string {
return out.String()
}

func percentDecode(encoded string) string {
func percentDecode(bufferPool *bufferPool, encoded string) string {
for i := 0; i < len(encoded); i++ {
if c := encoded[i]; c == '%' && i+2 < len(encoded) {
return percentDecodeSlow(encoded, i)
return percentDecodeSlow(bufferPool, encoded, i)
}
}
return encoded
}

// Similar to percentEncodeSlow: encoded is percent-encoded, and needs to be
// decoded byte-by-byte starting at offset.
func percentDecodeSlow(encoded string, offset int) string {
// OPT: easy opportunity to pool buffers
out := bytes.NewBuffer(make([]byte, 0, len(encoded)))
func percentDecodeSlow(bufferPool *bufferPool, encoded string, offset int) string {
out := bufferPool.Get()
defer bufferPool.Put(out)
out.WriteString(encoded[:offset])
for i := offset; i < len(encoded); i++ {
c := encoded[i]
10 changes: 6 additions & 4 deletions header_test.go
@@ -42,12 +42,13 @@ func TestBinaryEncodingQuick(t *testing.T) {

func TestPercentEncodingQuick(t *testing.T) {
t.Parallel()
pool := newBufferPool()
roundtrip := func(input string) bool {
if !utf8.ValidString(input) {
return true
}
encoded := percentEncode(input)
decoded := percentDecode(encoded)
encoded := percentEncode(pool, input)
decoded := percentDecode(pool, encoded)
return decoded == input
}
if err := quick.Check(roundtrip, nil /* config */); err != nil {
@@ -57,11 +58,12 @@ func TestPercentEncodingQuick(t *testing.T) {

func TestPercentEncoding(t *testing.T) {
t.Parallel()
pool := newBufferPool()
roundtrip := func(input string) {
assert.True(t, utf8.ValidString(input), assert.Sprintf("input invalid UTF-8"))
encoded := percentEncode(input)
encoded := percentEncode(pool, input)
t.Logf("%q encoded as %q", input, encoded)
decoded := percentDecode(encoded)
decoded := percentDecode(pool, encoded)
assert.Equal(t, decoded, input)
}

3 changes: 2 additions & 1 deletion protocol.go
@@ -49,6 +49,7 @@ type protocolHandlerParams struct {
Codecs readOnlyCodecs
CompressionPools readOnlyCompressionPools
CompressMinBytes int
BufferPool *bufferPool
}

// Handler is the server side of a protocol. HTTP handlers typically support
@@ -93,7 +94,7 @@ type protocolClientParams struct {
CompressMinBytes int
HTTPClient HTTPClient
URL string

BufferPool *bufferPool
// The gRPC family of protocols always needs access to a Protobuf codec to
// marshal and unmarshal errors.
Protobuf Codec
7 changes: 7 additions & 0 deletions protocol_grpc.go
@@ -37,6 +37,7 @@ func (g *protocolGRPC) NewHandler(params *protocolHandlerParams) protocolHandler
compressionPools: params.CompressionPools,
minCompressBytes: params.CompressMinBytes,
accept: acceptPostValue(g.web, params.Codecs),
bufferPool: params.BufferPool,
}
}

@@ -59,6 +60,7 @@ func (g *protocolGRPC) NewClient(params *protocolClientParams) (protocolClient,
minCompressBytes: params.CompressMinBytes,
httpClient: params.HTTPClient,
procedureURL: params.URL,
bufferPool: params.BufferPool,
}, nil
}

@@ -69,6 +71,7 @@ type grpcHandler struct {
compressionPools readOnlyCompressionPools
minCompressBytes int
accept string
bufferPool *bufferPool
}

func (g *grpcHandler) ShouldHandleMethod(method string) bool {
@@ -173,6 +176,7 @@ func (g *grpcHandler) NewStream(
g.codecs.Protobuf(), // for errors
g.compressionPools.Get(requestCompression),
g.compressionPools.Get(responseCompression),
g.bufferPool,
))
// We can't return failed as-is: a nil *Error is non-nil when returned as an
// error interface.
@@ -212,6 +216,7 @@ type grpcClient struct {
httpClient HTTPClient
procedureURL string
wrapErrorInterceptor Interceptor
bufferPool *bufferPool
}

func (g *grpcClient) WriteRequestHeader(header http.Header) {
@@ -262,6 +267,7 @@ func (g *grpcClient) NewStream(
compressionPool: g.compressionPools.Get(g.compressionName),
codec: g.codec,
compressMinBytes: g.minCompressBytes,
bufferPool: g.bufferPool,
},
header: header,
trailer: make(http.Header),
@@ -270,6 +276,7 @@ func (g *grpcClient) NewStream(
responseHeader: make(http.Header),
responseTrailer: make(http.Header),
compressionPools: g.compressionPools,
bufferPool: g.bufferPool,
responseReady: make(chan struct{}),
}
return g.wrapStream(&clientSender{duplex}, &clientReceiver{duplex})
10 changes: 6 additions & 4 deletions protocol_grpc_client_stream.go
@@ -88,6 +88,7 @@ type duplexClientStream struct {
responseReady chan struct{}
unmarshaler unmarshaler
compressionPools readOnlyCompressionPools
bufferPool *bufferPool

errMu sync.Mutex
requestErr error
@@ -185,7 +186,7 @@ func (cs *duplexClientStream) Receive(message any) error {
if errors.Is(err, errGotWebTrailers) {
mergeHeaders(cs.responseTrailer, cs.unmarshaler.WebTrailer())
}
if serverErr := extractError(cs.protobuf, cs.responseTrailer); serverErr != nil {
if serverErr := extractError(cs.bufferPool, cs.protobuf, cs.responseTrailer); serverErr != nil {
// This is expected from a protocol perspective, but receiving trailers
// means that we're _not_ getting a message. For users to realize that
// the stream has ended, Receive must return an error.
@@ -312,7 +313,7 @@ func (cs *duplexClientStream) makeRequest(prepared chan struct{}) {
// DATA frames have been sent on the stream - isn't standard HTTP/2
// semantics, so net/http doesn't know anything about it. To us, then, these
// trailers-only responses actually appear as headers-only responses.
if err := extractError(cs.protobuf, res.Header); err != nil {
if err := extractError(cs.bufferPool, cs.protobuf, res.Header); err != nil {
// Per the specification, only the HTTP status code and Content-Type should
// be treated as headers. The rest should be treated as trailing metadata.
if contentType := res.Header.Get("Content-Type"); contentType != "" {
@@ -338,6 +339,7 @@ func (cs *duplexClientStream) makeRequest(prepared chan struct{}) {
codec: cs.codec,
compressionPool: cs.compressionPools.Get(compression),
web: cs.web,
bufferPool: cs.bufferPool,
}
}

@@ -395,7 +397,7 @@ func (cs *duplexClientStream) getRequestOrResponseError() error {
// binary Protobuf format, even if the messages in the request/response stream
// use a different codec. Consequently, this function needs a Protobuf codec to
// unmarshal error information in the headers.
func extractError(protobuf Codec, trailer http.Header) *Error {
func extractError(bufferPool *bufferPool, protobuf Codec, trailer http.Header) *Error {
codeHeader := trailer.Get("Grpc-Status")
if codeHeader == "" || codeHeader == "0" {
return nil
@@ -405,7 +407,7 @@ func extractError(protobuf Codec, trailer http.Header) *Error {
if err != nil {
return errorf(CodeUnknown, "gRPC protocol error: got invalid error code %q", codeHeader)
}
message := percentDecode(trailer.Get("Grpc-Message"))
message := percentDecode(bufferPool, trailer.Get("Grpc-Message"))
retErr := NewError(Code(code), errors.New(message))

detailsBinaryEncoded := trailer.Get("Grpc-Status-Details-Bin")
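
For context, extractError is what converts gRPC error trailers back into a *Error on the client. A rough sketch of the input it consumes, based only on the lines shown above; passing nil for the Protobuf codec is an assumption that only holds while Grpc-Status-Details-Bin is absent, since that codec is needed solely to decode error details:

    pool := newBufferPool()
    trailer := make(http.Header)
    trailer.Set("Grpc-Status", "5") // 5 is NOT_FOUND in the gRPC status numbering
    trailer.Set("Grpc-Message", percentEncode(pool, "no such ping"))
    if connectErr := extractError(pool, nil /* Protobuf codec */, trailer); connectErr != nil {
        // connectErr carries the parsed status code and the percent-decoded message.
    }
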