Skip to content

Commit

Permalink
[internal-branch.go1.17-vendor] http2: fix Transport connection pool …
Browse files Browse the repository at this point in the history
…TOCTOU max concurrent stream bug

Updates golang/go#49077

Change-Id: I3e02072403f2f40ade4ef931058bbb5892776754
Reviewed-on: https://go-review.googlesource.com/c/net/+/352469
Run-TryBot: Brad Fitzpatrick <[email protected]>
TryBot-Result: Go Bot <[email protected]>
Reviewed-by: Damien Neil <[email protected]>
Trust: Brad Fitzpatrick <[email protected]>
Reviewed-on: https://go-review.googlesource.com/c/net/+/357681
Trust: Damien Neil <[email protected]>
Run-TryBot: Damien Neil <[email protected]>
Reviewed-by: Dmitri Shuralyov <[email protected]>
  • Loading branch information
bradfitz authored and dmitshur committed Oct 29, 2021
1 parent 248c63b commit 7e8f03d
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 6 deletions.
24 changes: 19 additions & 5 deletions client_conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ import (

// ClientConnPool manages a pool of HTTP/2 client connections.
type ClientConnPool interface {
// GetClientConn returns a specific HTTP/2 connection (usually
// a TLS-TCP connection) to an HTTP/2 server. On success, the
// returned ClientConn accounts for the upcoming RoundTrip
// call, so the caller should not omit it. If the caller needs
// to, ClientConn.RoundTrip can be called with a bogus
// new(http.Request) to release the stream reservation.
GetClientConn(req *http.Request, addr string) (*ClientConn, error)
MarkDead(*ClientConn)
}
Expand Down Expand Up @@ -61,7 +67,7 @@ const (
// during the back-and-forth between net/http and x/net/http2 (when the
// net/http.Transport is upgraded to also speak http2), as well as support
// the case where x/net/http2 is being used directly.
func (p *clientConnPool) shouldTraceGetConn(st clientConnIdleState) bool {
func (p *clientConnPool) shouldTraceGetConn(cc *ClientConn) bool {
// If our Transport wasn't made via ConfigureTransport, always
// trace the GetConn hook if provided, because that means the
// http2 package is being used directly and it's the one
Expand All @@ -72,7 +78,9 @@ func (p *clientConnPool) shouldTraceGetConn(st clientConnIdleState) bool {
// Otherwise, only use the GetConn hook if this connection has
// been used previously for other requests. For fresh
// connections, the net/http package does the dialing.
return !st.freshConn
cc.mu.Lock()
defer cc.mu.Unlock()
return cc.nextStreamID == 1
}

func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) {
Expand All @@ -89,8 +97,8 @@ func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMis
for {
p.mu.Lock()
for _, cc := range p.conns[addr] {
if st := cc.idleState(); st.canTakeNewRequest {
if p.shouldTraceGetConn(st) {
if cc.ReserveNewRequest() {
if p.shouldTraceGetConn(cc) {
traceGetConn(req, addr)
}
p.mu.Unlock()
Expand All @@ -108,7 +116,13 @@ func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMis
if shouldRetryDial(call, req) {
continue
}
return call.res, call.err
cc, err := call.res, call.err
if err != nil {
return nil, err
}
if cc.ReserveNewRequest() {
return cc, nil
}
}
}

Expand Down
45 changes: 44 additions & 1 deletion transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ type ClientConn struct {
goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
goAwayDebug string // goAway frame's debug data, retained as a string
streams map[uint32]*clientStream // client-initiated
streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip
nextStreamID uint32
pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams
pings map[[8]byte]chan struct{} // in flight ping data to notification channel
Expand Down Expand Up @@ -778,12 +779,28 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) {

// CanTakeNewRequest reports whether the connection can take a new request,
// meaning it has not been closed or received or sent a GOAWAY.
//
// If the caller is going to immediately make a new request on this
// connection, use ReserveNewRequest instead.
func (cc *ClientConn) CanTakeNewRequest() bool {
cc.mu.Lock()
defer cc.mu.Unlock()
return cc.canTakeNewRequestLocked()
}

// ReserveNewRequest is like CanTakeNewRequest but also reserves a
// concurrent stream in cc. The reservation is decremented on the
// next call to RoundTrip.
func (cc *ClientConn) ReserveNewRequest() bool {
cc.mu.Lock()
defer cc.mu.Unlock()
if st := cc.idleStateLocked(); !st.canTakeNewRequest {
return false
}
cc.streamsReserved++
return true
}

// clientConnIdleState describes the suitability of a client
// connection to initiate a new RoundTrip request.
type clientConnIdleState struct {
Expand All @@ -809,7 +826,7 @@ func (cc *ClientConn) idleStateLocked() (st clientConnIdleState) {
// writing it.
maxConcurrentOkay = true
} else {
maxConcurrentOkay = int64(len(cc.streams)+1) <= int64(cc.maxConcurrentStreams)
maxConcurrentOkay = int64(len(cc.streams)+cc.streamsReserved+1) <= int64(cc.maxConcurrentStreams)
}

st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay &&
Expand Down Expand Up @@ -1029,6 +1046,18 @@ func actualContentLength(req *http.Request) int64 {
return -1
}

func (cc *ClientConn) decrStreamReservations() {
cc.mu.Lock()
defer cc.mu.Unlock()
cc.decrStreamReservationsLocked()
}

func (cc *ClientConn) decrStreamReservationsLocked() {
if cc.streamsReserved > 0 {
cc.streamsReserved--
}
}

func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
resp, _, err := cc.roundTrip(req)
return resp, err
Expand All @@ -1037,6 +1066,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAfterReqBodyWrite bool, err error) {
ctx := req.Context()
if err := checkConnHeaders(req); err != nil {
cc.decrStreamReservations()
return nil, false, err
}
if cc.idleTimer != nil {
Expand All @@ -1045,6 +1075,7 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf

trailers, err := commaSeparatedTrailers(req)
if err != nil {
cc.decrStreamReservations()
return nil, false, err
}
hasTrailers := trailers != ""
Expand All @@ -1058,8 +1089,10 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf
select {
case cc.reqHeaderMu <- struct{}{}:
case <-req.Cancel:
cc.decrStreamReservations()
return nil, false, errRequestCanceled
case <-ctx.Done():
cc.decrStreamReservations()
return nil, false, ctx.Err()
}
reqHeaderMuNeedsUnlock := true
Expand All @@ -1070,6 +1103,11 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf
}()

cc.mu.Lock()
cc.decrStreamReservationsLocked()
if req.URL == nil {
cc.mu.Unlock()
return nil, false, errNilRequestURL
}
if err := cc.awaitOpenSlotForRequest(req); err != nil {
cc.mu.Unlock()
return nil, false, err
Expand Down Expand Up @@ -1522,9 +1560,14 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
}
}

var errNilRequestURL = errors.New("http2: Request.URI is nil")

// requires cc.wmu be held.
func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
cc.hbuf.Reset()
if req.URL == nil {
return nil, errNilRequestURL
}

host := req.Host
if host == "" {
Expand Down
39 changes: 39 additions & 0 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5296,3 +5296,42 @@ func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (
func (p *collectClientsConnPool) MarkDead(cc *ClientConn) {
p.lower.MarkDead(cc)
}

func TestClientConnReservations(t *testing.T) {
cc := &ClientConn{
reqHeaderMu: make(chan struct{}, 1),
streams: make(map[uint32]*clientStream),
maxConcurrentStreams: initialMaxConcurrentStreams,
t: &Transport{},
}
n := 0
for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
n++
}
if n != initialMaxConcurrentStreams {
t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams)
}
if _, err := cc.RoundTrip(new(http.Request)); !errors.Is(err, errNilRequestURL) {
t.Fatalf("RoundTrip error = %v; want errNilRequestURL", err)
}
n2 := 0
for n2 <= 5 && cc.ReserveNewRequest() {
n2++
}
if n2 != 1 {
t.Fatalf("after one RoundTrip, did %v reservations; want 1", n2)
}

// Use up all the reservations
for i := 0; i < n; i++ {
cc.RoundTrip(new(http.Request))
}

n2 = 0
for n2 <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
n2++
}
if n2 != n {
t.Errorf("after reset, reservations = %v; want %v", n2, n)
}
}

0 comments on commit 7e8f03d

Please sign in to comment.