Skip to content

Commit

Permalink
database/sql: do not exhaust connection pool on conn request timeout
Browse files Browse the repository at this point in the history
Previously if a context was canceled while it was waiting for a
connection request, that connection request would leak.

To prevent this remove the pending connection request if the
context is canceled and ensure no connection has been sent on the channel.
This requires a change to how the connection requests are represented in the DB.

Fixes #18995

Change-Id: I9a274b48b8f4f7ca46cdee166faa38f56d030852
Reviewed-on: https://go-review.googlesource.com/36563
Reviewed-by: Russ Cox <[email protected]>
Run-TryBot: Russ Cox <[email protected]>
TryBot-Result: Gobot Gobot <[email protected]>
  • Loading branch information
kardianos authored and rsc committed Feb 9, 2017
1 parent c57d91e commit 4f6d4bb
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 13 deletions.
49 changes: 36 additions & 13 deletions src/database/sql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,9 @@ type DB struct {

mu sync.Mutex // protects following fields
freeConn []*driverConn
connRequests []chan connRequest
numOpen int // number of opened and pending open connections
connRequests map[uint64]chan connRequest
nextRequest uint64 // Next key to use in connRequests.
numOpen int // number of opened and pending open connections
// Used to signal the need for new connections
// a goroutine running connectionOpener() reads on this chan and
// maybeOpenNewConnections sends on the chan (one send per needed connection)
Expand Down Expand Up @@ -572,10 +573,11 @@ func Open(driverName, dataSourceName string) (*DB, error) {
return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
}
db := &DB{
driver: driveri,
dsn: dataSourceName,
openerCh: make(chan struct{}, connectionRequestQueueSize),
lastPut: make(map[*driverConn]string),
driver: driveri,
dsn: dataSourceName,
openerCh: make(chan struct{}, connectionRequestQueueSize),
lastPut: make(map[*driverConn]string),
connRequests: make(map[uint64]chan connRequest),
}
go db.connectionOpener()
return db, nil
Expand Down Expand Up @@ -881,6 +883,14 @@ type connRequest struct {

var errDBClosed = errors.New("sql: database is closed")

// nextRequestKeyLocked returns the next connection request key.
// It is assumed that nextRequest will not overflow.
func (db *DB) nextRequestKeyLocked() uint64 {
next := db.nextRequest
db.nextRequest++
return next
}

// conn returns a newly-opened or cached *driverConn.
func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
db.mu.Lock()
Expand Down Expand Up @@ -918,12 +928,25 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
// Make the connRequest channel. It's buffered so that the
// connectionOpener doesn't block while waiting for the req to be read.
req := make(chan connRequest, 1)
db.connRequests = append(db.connRequests, req)
reqKey := db.nextRequestKeyLocked()
db.connRequests[reqKey] = req
db.mu.Unlock()

// Timeout the connection request with the context.
select {
case <-ctx.Done():
// Remove the connection request and ensure no value has been sent
// on it after removing.
db.mu.Lock()
delete(db.connRequests, reqKey)
select {
default:
case ret, ok := <-req:
if ok {
db.putConnDBLocked(ret.conn, ret.err)
}
}
db.mu.Unlock()
return nil, ctx.Err()
case ret, ok := <-req:
if !ok {
Expand Down Expand Up @@ -1044,12 +1067,12 @@ func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
return false
}
if c := len(db.connRequests); c > 0 {
req := db.connRequests[0]
// This copy is O(n) but in practice faster than a linked list.
// TODO: consider compacting it down less often and
// moving the base instead?
copy(db.connRequests, db.connRequests[1:])
db.connRequests = db.connRequests[:c-1]
var req chan connRequest
var reqKey uint64
for reqKey, req = range db.connRequests {
break
}
delete(db.connRequests, reqKey) // Remove from pending requests.
if err == nil {
dc.inUse = true
}
Expand Down
57 changes: 57 additions & 0 deletions src/database/sql/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,63 @@ func TestQueryNamedArg(t *testing.T) {
}
}

func TestPoolExhaustOnCancel(t *testing.T) {
if testing.Short() {
t.Skip("long test")
}
db := newTestDB(t, "people")
defer closeDB(t, db)

max := 3

db.SetMaxOpenConns(max)

// First saturate the connection pool.
// Then start new requests for a connection that is cancelled after it is requested.

var saturate, saturateDone sync.WaitGroup
saturate.Add(max)
saturateDone.Add(max)

for i := 0; i < max; i++ {
go func() {
saturate.Done()
rows, err := db.Query("WAIT|500ms|SELECT|people|name,photo|")
if err != nil {
t.Fatalf("Query: %v", err)
}
rows.Close()
saturateDone.Done()
}()
}

saturate.Wait()

// Now cancel the request while it is waiting.
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
defer cancel()

for i := 0; i < max; i++ {
ctxReq, cancelReq := context.WithCancel(ctx)
go func() {
time.Sleep(time.Millisecond * 100)
cancelReq()
}()
err := db.PingContext(ctxReq)
if err != context.Canceled {
t.Fatalf("PingContext (Exhaust): %v", err)
}
}

saturateDone.Wait()

// Now try to open a normal connection.
err := db.PingContext(ctx)
if err != nil {
t.Fatalf("PingContext (Normal): %v", err)
}
}

func TestByteOwnership(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
Expand Down

0 comments on commit 4f6d4bb

Please sign in to comment.