Skip to content
This repository has been archived by the owner on Sep 30, 2024. It is now read-only.

DB Backend: report explicit error when transactions are used concurrently #37172

Merged
merged 7 commits into from
Jun 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 95 additions & 14 deletions internal/database/basestore/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"database/sql"
"fmt"
"strings"
"sync"

"github.com/google/uuid"
"github.com/sourcegraph/log"

"github.com/sourcegraph/sourcegraph/internal/database/dbutil"
"github.com/sourcegraph/sourcegraph/lib/errors"
Expand Down Expand Up @@ -43,17 +45,28 @@ var (

// NewHandleWithDB returns a new transactable database handle using the given database connection.
func NewHandleWithDB(db *sql.DB, txOptions sql.TxOptions) TransactableHandle {
return &dbHandle{DB: db, txOptions: txOptions}
return &dbHandle{
DB: db,
logger: log.Scoped("internal", "database"),
txOptions: txOptions,
}
}

// NewHandleWithTx returns a new transactable database handle using the given transaction.
func NewHandleWithTx(tx *sql.Tx, txOptions sql.TxOptions) TransactableHandle {
return &txHandle{Tx: tx, txOptions: txOptions}
return &txHandle{
lockingTx: &lockingTx{
tx: tx,
logger: log.Scoped("internal", "database"),
},
txOptions: txOptions,
}
}

type dbHandle struct {
*sql.DB
txOptions sql.TxOptions
logger log.Logger
}

func (h *dbHandle) InTransaction() bool {
Expand All @@ -65,15 +78,15 @@ func (h *dbHandle) Transact(ctx context.Context) (TransactableHandle, error) {
if err != nil {
return nil, err
}
return &txHandle{Tx: tx, txOptions: h.txOptions}, nil
return &txHandle{lockingTx: &lockingTx{tx: tx, logger: h.logger}, txOptions: h.txOptions}, nil
}

// Done reports misuse: it returns err combined with ErrNotInTransaction and
// performs no commit or rollback. A dbHandle wraps a plain *sql.DB, so there is
// no transaction here to finish.
func (h *dbHandle) Done(err error) error {
return errors.Append(err, ErrNotInTransaction)
}

type txHandle struct {
*sql.Tx
*lockingTx
txOptions sql.TxOptions
}

Expand All @@ -82,23 +95,23 @@ func (h *txHandle) InTransaction() bool {
}

func (h *txHandle) Transact(ctx context.Context) (TransactableHandle, error) {
savepointID, err := newTxSavepoint(ctx, h.Tx)
savepointID, err := newTxSavepoint(ctx, h.lockingTx)
if err != nil {
return nil, err
}

return &savepointHandle{Tx: h.Tx, savepointID: savepointID}, nil
return &savepointHandle{lockingTx: h.lockingTx, savepointID: savepointID}, nil
}

func (h *txHandle) Done(err error) error {
if err == nil {
return h.Tx.Commit()
return h.Commit()
}
return errors.Append(err, h.Tx.Rollback())
return errors.Append(err, h.Rollback())
}

type savepointHandle struct {
*sql.Tx
*lockingTx
savepointID string
}

Expand All @@ -107,21 +120,21 @@ func (h *savepointHandle) InTransaction() bool {
}

func (h *savepointHandle) Transact(ctx context.Context) (TransactableHandle, error) {
savepointID, err := newTxSavepoint(ctx, h.Tx)
savepointID, err := newTxSavepoint(ctx, h.lockingTx)
if err != nil {
return nil, err
}

return &savepointHandle{Tx: h.Tx, savepointID: savepointID}, nil
return &savepointHandle{lockingTx: h.lockingTx, savepointID: savepointID}, nil
}

func (h *savepointHandle) Done(err error) error {
if err == nil {
_, execErr := h.Tx.Exec(fmt.Sprintf(commitSavepointQuery, h.savepointID))
_, execErr := h.ExecContext(context.Background(), fmt.Sprintf(commitSavepointQuery, h.savepointID))
return execErr
}

_, execErr := h.Tx.Exec(fmt.Sprintf(rollbackSavepointQuery, h.savepointID))
_, execErr := h.ExecContext(context.Background(), fmt.Sprintf(rollbackSavepointQuery, h.savepointID))
return errors.Append(err, execErr)
}

Expand All @@ -131,7 +144,7 @@ const (
rollbackSavepointQuery = "ROLLBACK TO %s"
)

func newTxSavepoint(ctx context.Context, tx *sql.Tx) (string, error) {
func newTxSavepoint(ctx context.Context, tx *lockingTx) (string, error) {
savepointID, err := makeSavepointID()
if err != nil {
return "", err
Expand All @@ -153,3 +166,71 @@ func makeSavepointID() (string, error) {

return fmt.Sprintf("sp_%s", strings.ReplaceAll(id.String(), "-", "_")), nil
}

// ErrConcurrentTransactionAccess is the error that lockingTx logs (wrapped with
// a stack trace) when it finds its mutex already held, meaning the same
// transaction is being used by more than one caller at once.
var ErrConcurrentTransactionAccess = errors.New("transaction used concurrently")

// lockingTx wraps a *sql.Tx with a mutex, and reports when a caller tries to
// use the transaction concurrently. Since using a transaction concurrently is
// unsafe, we want to catch these issues. If lockingTx detects that a
// transaction is being used concurrently, it will log an error and attempt to
// serialize the transaction accesses.
//
// NOTE: this is not foolproof. Interleaving savepoints, accessing rows while
// sending another query, etc. will still fail, so the logged error is a
// notification that something needs to be fixed, not a notification that the
// locking successfully prevented an issue. In the future, this will likely be
// upgraded to a hard error. Think of this like the race detector, not a race
// protector.
type lockingTx struct {
// tx is the wrapped transaction that every method delegates to.
tx *sql.Tx
// mu is held for the duration of each call made through this wrapper.
mu sync.Mutex
// logger receives the "transaction used concurrently" error.
logger log.Logger
}

func (t *lockingTx) lock() {
if !t.mu.TryLock() {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

woah, a valid use of the new TryLock()!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have to admit I'm a bit skeptical of using TryLock() followed by Lock() in the same function (and there are caveats around using TryLock). But I also don't see another way around it that doesn't end up doing the same thing under the hood (semaphore, atomic compare and swap, etc.)

Copy link
Member Author

@camdencheek camdencheek Jun 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So...Russ Cox is totally correct, including in this case. It is always incorrect to use a transaction concurrently, including after this PR.

However, my concerns are slightly different than Russ's. I want to do anything I can to prevent a panic in production, which might include implementing some imprecise logic that might avoid a handful of panics.

That said, the more important thing here IMO is that we report on incorrect usage, which this does with basically the same consistency as any race detector. This is what TryLock allows us to do: report when the invariant that a transaction should never be used concurrently is violated. A plain Lock cannot do this.

Copy link
Contributor

@mrnugget mrnugget Jun 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I was just wondering whether we need the TryLock for example or could use a one-word value instead?

type lockingTx struct {
	tx     *sql.Tx

	mu     sync.Mutex
	inUse bool

	logger log.Logger
}

func (t *lockingTx) lock() {
	if t.inUse {
		// For now, log an error, but try to serialize access anyways to try to
		// keep things slightly safer.
		err := errors.WithStack(ErrConcurrentTransactionAccess)
		t.logger.Error("transaction used concurrently", log.Error(err))
	}
	t.mu.Lock()
	t.inUse = true
}

func (t *lockingTx) unlock() {
	t.inUse = false
	t.mu.Unlock()
}

(Edit: ... and that's what I meant with "it all ends up being the same" 😄 because this is not different than what TryLock does somewhere under the hood, but I guess TryLock is now the officially supported version, since I'm not even 100% sure 1-word vars are safe to read concurrently like that.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not safe to read inUse outside the mutex though, right?

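// TryLock failed, so another caller is using this transaction right now.
// For now, log an error (with this caller's stack trace), then block on
// Lock to serialize access, which keeps things slightly safer.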
err := errors.WithStack(ErrConcurrentTransactionAccess)
t.logger.Error("transaction used concurrently", log.Error(err))
t.mu.Lock()
}
}

// unlock releases the mutex acquired by lock.
func (t *lockingTx) unlock() {
t.mu.Unlock()
}

// ExecContext forwards to (*sql.Tx).ExecContext on the wrapped transaction
// while holding the lock, so concurrent callers are detected by lock and then
// serialized.
func (t *lockingTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
t.lock()
defer t.unlock()

return t.tx.ExecContext(ctx, query, args...)
}

// QueryContext forwards to (*sql.Tx).QueryContext on the wrapped transaction
// while holding the lock.
//
// The lock is released when this method returns, so it does not cover the
// caller's later use of the returned *sql.Rows.
func (t *lockingTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
t.lock()
defer t.unlock()

return t.tx.QueryContext(ctx, query, args...)
}

// QueryRowContext forwards to (*sql.Tx).QueryRowContext on the wrapped
// transaction while holding the lock.
//
// The lock is released when this method returns, so it does not cover the
// caller's later use of the returned *sql.Row.
func (t *lockingTx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
t.lock()
defer t.unlock()

return t.tx.QueryRowContext(ctx, query, args...)
}

// Commit commits the wrapped transaction while holding the lock and returns
// the error from (*sql.Tx).Commit.
func (t *lockingTx) Commit() error {
t.lock()
defer t.unlock()

return t.tx.Commit()
}

// Rollback rolls back the wrapped transaction while holding the lock and
// returns the error from (*sql.Tx).Rollback.
func (t *lockingTx) Rollback() error {
t.lock()
defer t.unlock()

return t.tx.Rollback()
}
54 changes: 51 additions & 3 deletions internal/database/basestore/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/keegancsmith/sqlf"
"github.com/sourcegraph/log/logtest"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"

"github.com/sourcegraph/sourcegraph/internal/database/dbtest"
"github.com/sourcegraph/sourcegraph/internal/database/dbutil"
"github.com/sourcegraph/sourcegraph/lib/errors"
)

func TestTransaction(t *testing.T) {
db := dbtest.NewDB(t)
db := dbtest.NewRawDB(t)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated, but we don't need a db with a frontend schema in these tests

setupStoreTest(t, db)
store := testStore(db)

Expand Down Expand Up @@ -61,8 +64,53 @@ func TestTransaction(t *testing.T) {
assertCounts(t, db, map[int]int{1: 42, 3: 44})
}

// TestConcurrentTransactions checks two things: that independent transactions
// can be opened from the same store concurrently, and that concurrent use of a
// single transaction succeeds while being reported through the logger.
func TestConcurrentTransactions(t *testing.T) {
	db := dbtest.NewRawDB(t)
	setupStoreTest(t, db)
	store := testStore(db)
	ctx := context.Background()

	t.Run("creating transactions concurrently does not fail", func(t *testing.T) {
		var g errgroup.Group
		for i := 0; i < 2; i++ {
			g.Go(func() (err error) {
				tx, err := store.Transact(ctx)
				if err != nil {
					return err
				}
				// Commit or roll back depending on the error returned below.
				defer func() { err = tx.Done(err) }()

				return tx.Exec(ctx, sqlf.Sprintf(`select pg_sleep(0.1)`))
			})
		}
		require.NoError(t, g.Wait())
	})

	t.Run("parallel insertion on a single transaction does not fail but logs an error", func(t *testing.T) {
		tx, err := store.Transact(ctx)
		if err != nil {
			t.Fatal(err)
		}
		// Finish the transaction on every exit path (including require
		// failures) so it does not stay open after the subtest returns.
		defer func() {
			if doneErr := tx.Done(nil); doneErr != nil {
				t.Errorf("finishing transaction: %s", doneErr)
			}
		}()

		// Swap in a capturing logger so the concurrent-use report can be
		// inspected after the goroutines finish.
		capturingLogger, export := logtest.Captured(t)
		tx.handle.(*txHandle).logger = capturingLogger

		// Both goroutines share tx; the sleep keeps the first query in flight
		// long enough for the second to find the lock held.
		var g errgroup.Group
		for i := 0; i < 2; i++ {
			g.Go(func() (err error) {
				return tx.Exec(ctx, sqlf.Sprintf(`select pg_sleep(0.1)`))
			})
		}
		err = g.Wait()
		require.NoError(t, err)

		captured := export()
		require.Greater(t, len(captured), 0)
		require.Equal(t, "transaction used concurrently", captured[0].Message)
	})
}

func TestSavepoints(t *testing.T) {
db := dbtest.NewDB(t)
db := dbtest.NewRawDB(t)
setupStoreTest(t, db)

NumSavepointTests := 10
Expand All @@ -88,7 +136,7 @@ func TestSavepoints(t *testing.T) {
}

func TestSetLocal(t *testing.T) {
db := dbtest.NewDB(t)
db := dbtest.NewRawDB(t)
setupStoreTest(t, db)
store := testStore(db)

Expand Down