Skip to content

Commit 78b0b88

Browse files
committed
consume per-stream inflow window preemptively based on requested read size
1 parent b507112 commit 78b0b88

10 files changed

+442
-92
lines changed

rpc_util.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ type parser struct {
221221
// r is the underlying reader.
222222
// See the comment on recvMsg for the permissible
223223
// error types.
224-
r io.Reader
224+
r transport.FullReader
225225

226226
// The header of a gRPC message. Find more detail
227227
// at http://www.grpc.io/docs/guides/wire.html.
@@ -242,7 +242,7 @@ type parser struct {
242242
// that the underlying io.Reader must not return an incompatible
243243
// error.
244244
func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) {
245-
if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
245+
if _, err := p.r.ReadFull(p.header[:]); err != nil {
246246
return 0, nil, err
247247
}
248248

@@ -258,7 +258,7 @@ func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err erro
258258
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
259259
// of making it for each message:
260260
msg = make([]byte, int(length))
261-
if _, err := io.ReadFull(p.r, msg); err != nil {
261+
if _, err := p.r.ReadFull(msg); err != nil {
262262
if err == io.EOF {
263263
err = io.ErrUnexpectedEOF
264264
}

rpc_util_test.go

+12-4
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,14 @@ import (
4747
"google.golang.org/grpc/transport"
4848
)
4949

50+
type fullReaderForTesting struct {
51+
buf io.Reader
52+
}
53+
54+
func (f *fullReaderForTesting) ReadFull(p []byte) (int, error) {
55+
return io.ReadFull(f.buf, p)
56+
}
57+
5058
func TestSimpleParsing(t *testing.T) {
5159
bigMsg := bytes.Repeat([]byte{'x'}, 1<<24)
5260
for _, test := range []struct {
@@ -65,8 +73,8 @@ func TestSimpleParsing(t *testing.T) {
6573
// Check that messages with length >= 2^24 are parsed.
6674
{append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone},
6775
} {
68-
buf := bytes.NewReader(test.p)
69-
parser := &parser{r: buf}
76+
fullReader := &fullReaderForTesting{buf: bytes.NewReader(test.p)}
77+
parser := &parser{r: fullReader}
7078
pt, b, err := parser.recvMsg(math.MaxInt32)
7179
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
7280
t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
@@ -77,8 +85,8 @@ func TestSimpleParsing(t *testing.T) {
7785
func TestMultipleParsing(t *testing.T) {
7886
// Set a byte stream consists of 3 messages with their headers.
7987
p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'}
80-
b := bytes.NewReader(p)
81-
parser := &parser{r: b}
88+
fullReader := &fullReaderForTesting{buf: bytes.NewReader(p)}
89+
parser := &parser{r: fullReader}
8290

8391
wantRecvs := []struct {
8492
pt payloadFormat

test/end2end_test.go

+37-25
Original file line numberDiff line numberDiff line change
@@ -2243,8 +2243,8 @@ func testCancelNoIO(t *testing.T, e env) {
22432243
// The following tests the gRPC streaming RPC implementations.
22442244
// TODO(zhaoq): Have better coverage on error cases.
22452245
var (
2246-
reqSizes = []int{27182, 8, 1828, 45904}
2247-
respSizes = []int{31415, 9, 2653, 58979}
2246+
pingPongReqSizes = []int{27182, 8, 1828, 45904}
2247+
pingPongRespSizes = []int{31415, 9, 2653, 58979}
22482248
)
22492249

22502250
func TestNoService(t *testing.T) {
@@ -2289,14 +2289,14 @@ func testPingPong(t *testing.T, e env) {
22892289
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
22902290
}
22912291
var index int
2292-
for index < len(reqSizes) {
2292+
for index < len(pingPongReqSizes) {
22932293
respParam := []*testpb.ResponseParameters{
22942294
{
2295-
Size: proto.Int32(int32(respSizes[index])),
2295+
Size: proto.Int32(int32(pingPongRespSizes[index])),
22962296
},
22972297
}
22982298

2299-
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(reqSizes[index]))
2299+
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(pingPongReqSizes[index]))
23002300
if err != nil {
23012301
t.Fatal(err)
23022302
}
@@ -2318,8 +2318,8 @@ func testPingPong(t *testing.T, e env) {
23182318
t.Fatalf("Got the reply of type %d, want %d", pt, testpb.PayloadType_COMPRESSABLE)
23192319
}
23202320
size := len(reply.GetPayload().GetBody())
2321-
if size != int(respSizes[index]) {
2322-
t.Fatalf("Got reply body of length %d, want %d", size, respSizes[index])
2321+
if size != int(pingPongRespSizes[index]) {
2322+
t.Fatalf("Got reply body of length %d, want %d", size, pingPongRespSizes[index])
23232323
}
23242324
index++
23252325
}
@@ -2367,14 +2367,14 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
23672367
t.Errorf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
23682368
}
23692369
var index int
2370-
for index < len(reqSizes) {
2370+
for index < len(pingPongReqSizes) {
23712371
respParam := []*testpb.ResponseParameters{
23722372
{
2373-
Size: proto.Int32(int32(respSizes[index])),
2373+
Size: proto.Int32(int32(pingPongRespSizes[index])),
23742374
},
23752375
}
23762376

2377-
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(reqSizes[index]))
2377+
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(pingPongReqSizes[index]))
23782378
if err != nil {
23792379
t.Fatal(err)
23802380
}
@@ -2405,20 +2405,26 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
24052405
}
24062406

24072407
func TestServerStreaming(t *testing.T) {
2408+
serverRespSizes := [][]int{
2409+
{27182, 8, 1828, 45904},
2410+
{(1 << 21), (1 << 21), (1 << 21), (1 << 21), (1 << 21), (1 << 21), (1 << 21), (1 << 21)},
2411+
}
24082412
defer leakCheck(t)()
2409-
for _, e := range listTestEnv() {
2410-
testServerStreaming(t, e)
2413+
for _, s := range serverRespSizes {
2414+
for _, e := range listTestEnv() {
2415+
testServerStreaming(t, e, s)
2416+
}
24112417
}
24122418
}
24132419

2414-
func testServerStreaming(t *testing.T, e env) {
2420+
func testServerStreaming(t *testing.T, e env, serverRespSizes []int) {
24152421
te := newTest(t, e)
24162422
te.startServer(&testServer{security: e.security})
24172423
defer te.tearDown()
24182424
tc := testpb.NewTestServiceClient(te.clientConn())
24192425

2420-
respParam := make([]*testpb.ResponseParameters, len(respSizes))
2421-
for i, s := range respSizes {
2426+
respParam := make([]*testpb.ResponseParameters, len(serverRespSizes))
2427+
for i, s := range serverRespSizes {
24222428
respParam[i] = &testpb.ResponseParameters{
24232429
Size: proto.Int32(int32(s)),
24242430
}
@@ -2445,17 +2451,17 @@ func testServerStreaming(t *testing.T, e env) {
24452451
t.Fatalf("Got the reply of type %d, want %d", pt, testpb.PayloadType_COMPRESSABLE)
24462452
}
24472453
size := len(reply.GetPayload().GetBody())
2448-
if size != int(respSizes[index]) {
2449-
t.Fatalf("Got reply body of length %d, want %d", size, respSizes[index])
2454+
if size != int(serverRespSizes[index]) {
2455+
t.Fatalf("Got reply body of length %d, want %d", size, serverRespSizes[index])
24502456
}
24512457
index++
24522458
respCnt++
24532459
}
24542460
if rpcStatus != io.EOF {
24552461
t.Fatalf("Failed to finish the server streaming rpc: %v, want <EOF>", rpcStatus)
24562462
}
2457-
if respCnt != len(respSizes) {
2458-
t.Fatalf("Got %d reply, want %d", len(respSizes), respCnt)
2463+
if respCnt != len(serverRespSizes) {
2464+
t.Fatalf("Got %d reply, want %d", len(serverRespSizes), respCnt)
24592465
}
24602466
}
24612467

@@ -2473,8 +2479,8 @@ func testFailedServerStreaming(t *testing.T, e env) {
24732479
defer te.tearDown()
24742480
tc := testpb.NewTestServiceClient(te.clientConn())
24752481

2476-
respParam := make([]*testpb.ResponseParameters, len(respSizes))
2477-
for i, s := range respSizes {
2482+
respParam := make([]*testpb.ResponseParameters, len(pingPongRespSizes))
2483+
for i, s := range pingPongRespSizes {
24782484
respParam[i] = &testpb.ResponseParameters{
24792485
Size: proto.Int32(int32(s)),
24802486
}
@@ -2576,12 +2582,18 @@ func testServerStreamingConcurrent(t *testing.T, e env) {
25762582

25772583
func TestClientStreaming(t *testing.T) {
25782584
defer leakCheck(t)()
2579-
for _, e := range listTestEnv() {
2580-
testClientStreaming(t, e)
2585+
clientReqSizes := [][]int{
2586+
{27182, 8, 1828, 45904},
2587+
{(1 << 21), (1 << 21), (1 << 21), (1 << 21), (1 << 21), (1 << 21), (1 << 21), (1 << 21)},
2588+
}
2589+
for _, s := range clientReqSizes {
2590+
for _, e := range listTestEnv() {
2591+
testClientStreaming(t, e, s)
2592+
}
25812593
}
25822594
}
25832595

2584-
func testClientStreaming(t *testing.T, e env) {
2596+
func testClientStreaming(t *testing.T, e env, clientReqSizes []int) {
25852597
te := newTest(t, e)
25862598
te.startServer(&testServer{security: e.security})
25872599
defer te.tearDown()
@@ -2593,7 +2605,7 @@ func testClientStreaming(t *testing.T, e env) {
25932605
}
25942606

25952607
var sum int
2596-
for _, s := range reqSizes {
2608+
for _, s := range clientReqSizes {
25972609
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(s))
25982610
if err != nil {
25992611
t.Fatal(err)

transport/control.go

+48-5
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ import (
3939
"sync"
4040
"time"
4141

42+
"google.golang.org/grpc/grpclog"
43+
4244
"golang.org/x/net/http2"
4345
)
4446

@@ -58,6 +60,14 @@ const (
5860
defaultServerKeepaliveTime = time.Duration(2 * time.Hour)
5961
defaultServerKeepaliveTimeout = time.Duration(20 * time.Second)
6062
defaultKeepalivePolicyMinTime = time.Duration(5 * time.Minute)
63+
// Put a cap on the max possible window update (this value reached when
64+
// an attempt to read a large message is made).
65+
// 4M is greater than connection window but is arbitrary otherwise.
66+
// Note this must be greater than a stream's incoming window size to have an effect.
67+
maxSingleStreamWindowUpdate = 4194303
68+
69+
// max legal window update
70+
http2MaxWindowUpdate = 2147483647
6171
)
6272

6373
// The following defines various control items which could flow through
@@ -161,34 +171,51 @@ type inFlow struct {
161171
limit uint32
162172

163173
mu sync.Mutex
164-
// pendingData is the overall data which have been received but not been
174+
// PendingData is the overall data which have been received but not been
165175
// consumed by applications.
166176
pendingData uint32
167177
// The amount of data the application has consumed but grpc has not sent
168178
// window update for them. Used to reduce window update frequency.
169179
pendingUpdate uint32
180+
181+
// This is temporary space in the incoming flow control that can be granted at convenient times
182+
// to prevent the sender from stalling for lack flow control space.
183+
// If present, it is paid back when data is consumed from the window.
184+
loanedWindowSpace uint32
170185
}
171186

172187
// onData is invoked when some data frame is received. It updates pendingData.
173188
func (f *inFlow) onData(n uint32) error {
174189
f.mu.Lock()
175190
defer f.mu.Unlock()
176191
f.pendingData += n
177-
if f.pendingData+f.pendingUpdate > f.limit {
178-
return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate, f.limit)
192+
if f.pendingData+f.pendingUpdate > f.limit+f.loanedWindowSpace {
193+
return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate, f.limit+f.loanedWindowSpace)
179194
}
180195
return nil
181196
}
182197

198+
func min(a uint32, b uint32) uint32 {
199+
if a < b {
200+
return a
201+
}
202+
return b
203+
}
204+
183205
// onRead is invoked when the application reads the data. It returns the window size
184206
// to be sent to the peer.
185207
func (f *inFlow) onRead(n uint32) uint32 {
186208
f.mu.Lock()
187209
defer f.mu.Unlock()
188-
if f.pendingData == 0 {
189-
return 0
210+
if n > http2MaxWindowUpdate {
211+
grpclog.Fatalf("potential window update too large. onRead(n) where n is %v; max n is %v", f.pendingUpdate, http2MaxWindowUpdate)
190212
}
191213
f.pendingData -= n
214+
// first use up remaining "loanedWindowSpace", add remaining Read to "pendingUpdate"
215+
windowSpaceDebtPayment := min(n, f.loanedWindowSpace)
216+
f.loanedWindowSpace -= windowSpaceDebtPayment
217+
n -= windowSpaceDebtPayment
218+
192219
f.pendingUpdate += n
193220
if f.pendingUpdate >= f.limit/4 {
194221
wu := f.pendingUpdate
@@ -198,6 +225,22 @@ func (f *inFlow) onRead(n uint32) uint32 {
198225
return 0
199226
}
200227

228+
func (f *inFlow) loanWindowSpace(n uint32) uint32 {
229+
f.mu.Lock()
230+
defer f.mu.Unlock()
231+
if f.loanedWindowSpace > 0 {
232+
grpclog.Fatalf("pre-consuming window space while there is pre-consumed window space still outstanding")
233+
}
234+
f.loanedWindowSpace = n
235+
236+
if f.loanedWindowSpace+f.pendingUpdate >= f.limit/4 {
237+
wu := f.pendingUpdate + f.loanedWindowSpace
238+
f.pendingUpdate = 0
239+
return wu
240+
}
241+
return 0
242+
}
243+
201244
func (f *inFlow) resetPendingData() uint32 {
202245
f.mu.Lock()
203246
defer f.mu.Unlock()

transport/handler_server.go

+15-8
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ import (
5151
"golang.org/x/net/http2"
5252
"google.golang.org/grpc/codes"
5353
"google.golang.org/grpc/credentials"
54+
"google.golang.org/grpc/grpclog"
5455
"google.golang.org/grpc/metadata"
5556
"google.golang.org/grpc/peer"
5657
"google.golang.org/grpc/status"
@@ -273,6 +274,13 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
273274
})
274275
}
275276

277+
type handlerServerStreamReader struct{}
278+
279+
func (*handlerServerStreamReader) Read(_ []byte) (int, error) {
280+
grpclog.Fatalf("handler server streamReader is unimplemented")
281+
return 0, nil
282+
}
283+
276284
func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) {
277285
// With this transport type there will be exactly 1 stream: this HTTP request.
278286

@@ -305,13 +313,13 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
305313
req := ht.req
306314

307315
s := &Stream{
308-
id: 0, // irrelevant
309-
windowHandler: func(int) {}, // nothing
310-
cancel: cancel,
311-
buf: newRecvBuffer(),
312-
st: ht,
313-
method: req.URL.Path,
314-
recvCompress: req.Header.Get("grpc-encoding"),
316+
id: 0, // irrelevant
317+
cancel: cancel,
318+
buf: newRecvBuffer(),
319+
st: ht,
320+
method: req.URL.Path,
321+
recvCompress: req.Header.Get("grpc-encoding"),
322+
streamReader: &handlerServerStreamReader{},
315323
}
316324
pr := &peer.Peer{
317325
Addr: ht.RemoteAddr(),
@@ -322,7 +330,6 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
322330
ctx = metadata.NewContext(ctx, ht.headerMD)
323331
ctx = peer.NewContext(ctx, pr)
324332
s.ctx = newContextWithStream(ctx, s)
325-
s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf}
326333

327334
// readerDone is closed when the Body.Read-ing goroutine exits.
328335
readerDone := make(chan struct{})

0 commit comments

Comments
 (0)