Skip to content

Commit

Permalink
feat: ✨ transaction can now be placed to the context
Browse files Browse the repository at this point in the history
BREAKING CHANGE: Querier.Conn now have context as an argument
  • Loading branch information
MrEhbr committed Mar 21, 2023
1 parent 069c9e7 commit 61dad71
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 11 deletions.
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,26 @@ func main() {
_, err = q.Exec(ctx, "UPDATE table SET something = 1")
return err
}, conn.StatementTimeout(time.Second))

tx, err := wrapped.Conn(ctx).Begin(ctx)
if err != nil {
log.Fatalf("failed to start transaction: %s", err)
}

// Put a transaction in the context, so that all subsequent calls use the transaction
txCtx := conn.NewTxContext(ctx, tx)
if _, err := wrapped.Exec(txCtx, "UPDATE table SET something = 1"); err != nil {
_ = tx.Rollback(ctx)
log.Fatalf("failed to exec: %s", err)
}
if err := wrapped.Get(txCtx, &count, "SELECT COUNT(*) FROM table"); err != nil {
_ = tx.Rollback(ctx)
log.Fatalf("failed to get: %s", err)
}

if err := tx.Commit(ctx); err != nil {
log.Fatalf("failed to commit transaction: %s", err)
}
}
```

Expand Down
16 changes: 10 additions & 6 deletions conn/querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type Querier interface {
Get(ctx context.Context, dst interface{}, sql string, args ...interface{}) error
Exec(ctx context.Context, sql string, args ...interface{}) (int64, error)
Tx(ctx context.Context, f func(q Querier) error, opts ...TxOption) error
Conn() PgxConn
Conn(ctx context.Context) PgxConn
}

var _ Querier = &wrappedConn{}
Expand All @@ -36,7 +36,7 @@ func WrapConn(conn PgxConn, scanAPI *pgxscan.API) *wrappedConn {
// Before starting, Select resets the destination slice,
// so if it's not empty it will overwrite all existing elements.
func (n *wrappedConn) Select(ctx context.Context, dst interface{}, sql string, args ...interface{}) error {
rows, err := n.conn.Query(ctx, sql, args...)
rows, err := n.Conn(ctx).Query(ctx, sql, args...)
if err != nil {
return err
}
Expand All @@ -48,7 +48,7 @@ func (n *wrappedConn) Select(ctx context.Context, dst interface{}, sql string, a
// otherwise it returns an error.
// It scans data from single row into the destination.
func (n *wrappedConn) Get(ctx context.Context, dst interface{}, sql string, args ...interface{}) error {
rows, err := n.conn.Query(ctx, sql, args...)
rows, err := n.Conn(ctx).Query(ctx, sql, args...)
if err != nil {
return err
}
Expand All @@ -58,7 +58,7 @@ func (n *wrappedConn) Get(ctx context.Context, dst interface{}, sql string, args

// Exec executes a query without returning any rows and return affected rows.
func (n *wrappedConn) Exec(ctx context.Context, sql string, args ...interface{}) (int64, error) {
res, err := n.conn.Exec(ctx, sql, args...)
res, err := n.Conn(ctx).Exec(ctx, sql, args...)
if err != nil {
return 0, err
}
Expand All @@ -73,7 +73,7 @@ func (n *wrappedConn) Tx(ctx context.Context, f func(q Querier) error, opts ...T
for _, o := range opts {
o(txOpts)
}
err := n.conn.BeginFunc(ctx, func(txx pgx.Tx) error {
err := n.Conn(ctx).BeginFunc(ctx, func(txx pgx.Tx) error {
if txOpts.TransactionTimeout > 0 {
if _, err := txx.Exec(ctx, transactionTimeoutQuery, pgx.QuerySimpleProtocol(true), txOpts.TransactionTimeout); err != nil {
return fmt.Errorf("set transaction timeout: %w", err)
Expand All @@ -92,6 +92,10 @@ func (n *wrappedConn) Tx(ctx context.Context, f func(q Querier) error, opts ...T
return err
}

func (n *wrappedConn) Conn() PgxConn {
func (n *wrappedConn) Conn(ctx context.Context) PgxConn {
if conn, ok := TxFromContext(ctx); ok {
return conn
}

return n.conn
}
23 changes: 23 additions & 0 deletions conn/querier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,4 +287,27 @@ func TestQuerier(t *testing.T) {
is.True(pgErr.Code == "57014") // 57014 - query_canceled error code
})
})

t.Run("Conn", func(t *testing.T) {
t.Parallel()
querier := WrapConn(db, pgxscan.DefaultAPI)

t.Run("no transaction in ctx", func(t *testing.T) {
is := is.New(t)
ctx := context.Background()

conn := querier.Conn(ctx)
is.True(conn == db) // must be original database connection
})

t.Run("transaction in ctx", func(t *testing.T) {
is := is.New(t)
ctx := context.Background()
tx, err := querier.Conn(ctx).Begin(ctx)
is.NoErr(err)
ctx = NewTxContext(ctx, tx)
conn := querier.Conn(ctx)
is.True(conn == tx) // must be transaction connection
})
})
}
25 changes: 25 additions & 0 deletions conn/tx_options.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package conn

import (
"context"
"time"

"github.com/jackc/pgx/v4"
)

const (
Expand Down Expand Up @@ -31,3 +34,25 @@ func StatementTimeout(d time.Duration) TxOption {
o.StatementTimeout = d.Milliseconds()
}
}

// Context options for transaction

type txKeyType uint8

const (
txKey txKeyType = 0
)

// NewTxContext returns a new context carrying the transaction connection
func NewTxContext(ctx context.Context, tx pgx.Tx) context.Context {
if tx == nil {
return ctx
}
return context.WithValue(ctx, txKey, tx)
}

// TxFromContext extracts the transaction connection if present.
func TxFromContext(ctx context.Context) (pgx.Tx, bool) {
v, ok := ctx.Value(txKey).(pgx.Tx)
return v, ok
}
2 changes: 1 addition & 1 deletion examples/cluster/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func main() {
}

// If needed, one can access the PgxConn to call pgx methods directly such as SendBatch, CopyFrom ... .
conn := db.Primary().Conn()
conn := db.Primary().Conn(ctx)
_ = conn
// If needed, one can access the primary or a replica explicitly.
primary, replica := db.Primary(), db.Replica()
Expand Down
20 changes: 20 additions & 0 deletions examples/conn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,24 @@ func main() {
_, err = q.Exec(ctx, "UPDATE table SET something = 1")
return err
}, conn.StatementTimeout(time.Second))

tx, err := wrapped.Conn(ctx).Begin(ctx)
if err != nil {
log.Fatalf("failed to start transaction: %s", err)
}

// Put a transaction in the context, so that all subsequent calls use the transaction
txCtx := conn.NewTxContext(ctx, tx)
if _, err := wrapped.Exec(txCtx, "UPDATE table SET something = 1"); err != nil {
_ = tx.Rollback(ctx)
log.Fatalf("failed to exec: %s", err)
}
if err := wrapped.Get(txCtx, &count, "SELECT COUNT(*) FROM table"); err != nil {
_ = tx.Rollback(ctx)
log.Fatalf("failed to get: %s", err)
}

if err := tx.Commit(ctx); err != nil {
log.Fatalf("failed to commit transaction: %s", err)
}
}
2 changes: 1 addition & 1 deletion rules.mk
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ go.bumpdeps:
go.fmt:
@set -e; for dir in $(GOMOD_DIRS); do ( set -e; \
cd $$dir; \
$(GO) run mvdan.cc/gofumpt -extra -w -l -s `go list -f '{{.Dir}}' $(WHAT) | grep -v mocks` \
$(GO) run mvdan.cc/gofumpt -extra -w -l `go list -f '{{.Dir}}' $(WHAT) | grep -v mocks` \
); done

VERIFY_STEPS += go.depaware-check
Expand Down
2 changes: 1 addition & 1 deletion txdb/txdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func (c *txdbCluster) Replica() conn.Querier {

func (c *txdbCluster) beginOnce(ctx context.Context) (pgx.Tx, error) {
if c.tx == nil {
tx, err := c.cluster.Primary().Conn().Begin(ctx)
tx, err := c.cluster.Primary().Conn(ctx).Begin(ctx)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions txdb/txdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,12 @@ func TestIntegration_txdb(t *testing.T) {
t.Parallel()
is := is.New(t)

t1, t2 := sendBatch(db.Primary().Conn())
t1, t2 := sendBatch(db.Primary().Conn(ctx))
is.True(!t1.Equal(t2)) // transaction_timestamp not in transaction must be not equal

txdb := New(db)

t1, t2 = sendBatch(txdb.Primary().Conn())
t1, t2 = sendBatch(txdb.Primary().Conn(ctx))
is.True(t1.Equal(t2)) // transaction_timestamp in transaction must be equal
})
})
Expand Down

0 comments on commit 61dad71

Please sign in to comment.