From b9bc8e75a675b85e57b14635f0d1aca6ac01ba82 Mon Sep 17 00:00:00 2001 From: Garrett Gutierrez Date: Fri, 21 Aug 2020 14:22:17 -0700 Subject: [PATCH] End stream flag bugfix (#3803) --- internal/transport/http2_server.go | 4 ++ test/end2end_test.go | 110 +++++++++++++++++++++++++++-- test/servertester.go | 19 ++++- 3 files changed, 126 insertions(+), 7 deletions(-) diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 04cbedf7945f..3be22fee426c 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -611,6 +611,10 @@ func (t *http2Server) handleData(f *http2.DataFrame) { if !ok { return } + if s.getState() == streamReadDone { + t.closeStream(s, true, http2.ErrCodeStreamClosed, false) + return + } if size > 0 { if err := s.fc.onData(size); err != nil { t.closeStream(s, true, http2.ErrCodeFlowControl, false) diff --git a/test/end2end_test.go b/test/end2end_test.go index d1a2cdf68612..0842dccaad01 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -4535,7 +4535,7 @@ func testClientRequestBodyErrorUnexpectedEOF(t *testing.T, e env) { te.startServer(ts) defer te.tearDown() te.withServerTester(func(st *serverTester) { - st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall") + st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall", false) // Say we have 5 bytes coming, but set END_STREAM flag: st.writeData(1, true, []byte{0, 0, 0, 0, 5}) st.wantAnyFrame() // wait for server to crash (it used to crash) @@ -4559,7 +4559,7 @@ func testClientRequestBodyErrorCloseAfterLength(t *testing.T, e env) { te.startServer(ts) defer te.tearDown() te.withServerTester(func(st *serverTester) { - st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall") + st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall", false) // say we're sending 5 bytes, but then close the connection instead. st.writeData(1, false, []byte{0, 0, 0, 0, 5}) st.cc.Close() @@ -4582,7 +4582,7 @@ func testClientRequestBodyErrorCancel(t *testing.T, e env) { te.startServer(ts) defer te.tearDown() te.withServerTester(func(st *serverTester) { - st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall") + st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall", false) // Say we have 5 bytes coming, but cancel it instead. st.writeRSTStream(1, http2.ErrCodeCancel) st.writeData(1, false, []byte{0, 0, 0, 0, 5}) @@ -4595,7 +4595,7 @@ func testClientRequestBodyErrorCancel(t *testing.T, e env) { } // And now send an uncanceled (but still invalid), just to get a response. - st.writeHeadersGRPC(3, "/grpc.testing.TestService/UnaryCall") + st.writeHeadersGRPC(3, "/grpc.testing.TestService/UnaryCall", false) st.writeData(3, true, []byte{0, 0, 0, 0, 0}) <-gotCall st.wantAnyFrame() @@ -4619,7 +4619,7 @@ func testClientRequestBodyErrorCancelStreamingInput(t *testing.T, e env) { te.startServer(ts) defer te.tearDown() te.withServerTester(func(st *serverTester) { - st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall") + st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall", false) // Say we have 5 bytes coming, but cancel it instead. st.writeData(1, false, []byte{0, 0, 0, 0, 5}) st.writeRSTStream(1, http2.ErrCodeCancel) @@ -4636,6 +4636,106 @@ func testClientRequestBodyErrorCancelStreamingInput(t *testing.T, e env) { }) } +func (s) TestClientInitialHeaderEndStream(t *testing.T) { + for _, e := range listTestEnv() { + if e.httpHandler { + continue + } + testClientInitialHeaderEndStream(t, e) + } +} + +func testClientInitialHeaderEndStream(t *testing.T, e env) { + // To ensure RST_STREAM is sent for illegal data write and not normal stream + // close. + frameCheckingDone := make(chan struct{}) + // To ensure goroutine for test does not end before RPC handler performs error + // checking. + handlerDone := make(chan struct{}) + te := newTest(t, e) + ts := &funcServer{streamingInputCall: func(stream testpb.TestService_StreamingInputCallServer) error { + defer close(handlerDone) + // Block on serverTester receiving RST_STREAM. This ensures server has closed + // stream before stream.Recv(). + <-frameCheckingDone + data, err := stream.Recv() + if err == nil { + t.Errorf("unexpected data received in func server method: '%v'", data) + } else if status.Code(err) != codes.Canceled { + t.Errorf("expected canceled error, instead received '%v'", err) + } + return nil + }} + te.startServer(ts) + defer te.tearDown() + te.withServerTester(func(st *serverTester) { + // Send a headers with END_STREAM flag, but then write data. + st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall", true) + st.writeData(1, false, []byte{0, 0, 0, 0, 0}) + st.wantAnyFrame() + st.wantAnyFrame() + st.wantRSTStream(http2.ErrCodeStreamClosed) + close(frameCheckingDone) + <-handlerDone + }) +} + +func (s) TestClientSendDataAfterCloseSend(t *testing.T) { + for _, e := range listTestEnv() { + if e.httpHandler { + continue + } + testClientSendDataAfterCloseSend(t, e) + } +} + +func testClientSendDataAfterCloseSend(t *testing.T, e env) { + // To ensure RST_STREAM is sent for illegal data write prior to execution of RPC + // handler. + frameCheckingDone := make(chan struct{}) + // To ensure goroutine for test does not end before RPC handler performs error + // checking. + handlerDone := make(chan struct{}) + te := newTest(t, e) + ts := &funcServer{streamingInputCall: func(stream testpb.TestService_StreamingInputCallServer) error { + defer close(handlerDone) + // Block on serverTester receiving RST_STREAM. This ensures server has closed + // stream before stream.Recv(). + <-frameCheckingDone + for { + _, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + if status.Code(err) != codes.Canceled { + t.Errorf("expected canceled error, instead received '%v'", err) + } + break + } + } + if err := stream.SendMsg(nil); err == nil { + t.Error("expected error sending message on stream after stream closed due to illegal data") + } else if status.Code(err) != codes.Internal { + t.Errorf("expected internal error, instead received '%v'", err) + } + return nil + }} + te.startServer(ts) + defer te.tearDown() + te.withServerTester(func(st *serverTester) { + st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall", false) + // Send data with END_STREAM flag, but then write more data. + st.writeData(1, true, []byte{0, 0, 0, 0, 0}) + st.writeData(1, false, []byte{0, 0, 0, 0, 0}) + st.wantAnyFrame() + st.wantAnyFrame() + st.wantRSTStream(http2.ErrCodeStreamClosed) + close(frameCheckingDone) + <-handlerDone + }) +} + func (s) TestClientResourceExhaustedCancelFullDuplex(t *testing.T) { for _, e := range listTestEnv() { if e.httpHandler { diff --git a/test/servertester.go b/test/servertester.go index ff4fa0b3c6bc..9758e8eb6cf8 100644 --- a/test/servertester.go +++ b/test/servertester.go @@ -138,6 +138,21 @@ func (st *serverTester) writeSettingsAck() { } } +func (st *serverTester) wantRSTStream(errCode http2.ErrCode) *http2.RSTStreamFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting an RST frame: %v", err) + } + sf, ok := f.(*http2.RSTStreamFrame) + if !ok { + st.t.Fatalf("got a %T; want *http2.RSTStreamFrame", f) + } + if sf.ErrCode != errCode { + st.t.Fatalf("expected RST error code '%v', got '%v'", errCode.String(), sf.ErrCode.String()) + } + return sf +} + func (st *serverTester) wantSettings() *http2.SettingsFrame { f, err := st.readFrame() if err != nil { @@ -227,7 +242,7 @@ func (st *serverTester) encodeHeader(headers ...string) []byte { return st.headerBuf.Bytes() } -func (st *serverTester) writeHeadersGRPC(streamID uint32, path string) { +func (st *serverTester) writeHeadersGRPC(streamID uint32, path string, endStream bool) { st.writeHeaders(http2.HeadersFrameParam{ StreamID: streamID, BlockFragment: st.encodeHeader( @@ -236,7 +251,7 @@ func (st *serverTester) writeHeadersGRPC(streamID uint32, path string) { "content-type", "application/grpc", "te", "trailers", ), - EndStream: false, + EndStream: endStream, EndHeaders: true, }) }