diff --git a/spanner/read.go b/spanner/read.go
index 2752b0ec93e0..eefd44b4843a 100644
--- a/spanner/read.go
+++ b/spanner/read.go
@@ -64,6 +64,9 @@ func stream(
 		rpc,
 		nil,
 		nil,
+		func(err error) error {
+			return err
+		},
 		setTimestamp,
 		release,
 	)
@@ -79,6 +82,7 @@ func streamWithReplaceSessionFunc(
 	rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error),
 	replaceSession func(ctx context.Context) error,
 	setTransactionID func(transactionID),
+	updateTxState func(err error) error,
 	setTimestamp func(time.Time),
 	release func(error),
 ) *RowIterator {
@@ -89,6 +93,7 @@ func streamWithReplaceSessionFunc(
 		streamd:          newResumableStreamDecoder(ctx, logger, rpc, replaceSession),
 		rowd:             &partialResultSetDecoder{},
 		setTransactionID: setTransactionID,
+		updateTxState:    updateTxState,
 		setTimestamp:     setTimestamp,
 		release:          release,
 		cancel:           cancel,
@@ -127,6 +132,7 @@ type RowIterator struct {
 	streamd          *resumableStreamDecoder
 	rowd             *partialResultSetDecoder
 	setTransactionID func(transactionID)
+	updateTxState    func(err error) error
 	setTimestamp     func(time.Time)
 	release          func(error)
 	cancel           func()
@@ -214,7 +220,7 @@ func (r *RowIterator) Next() (*Row, error) {
 		return row, nil
 	}
 	if err := r.streamd.lastErr(); err != nil {
-		r.err = ToSpannerError(err)
+		r.err = r.updateTxState(ToSpannerError(err))
 	} else if !r.rowd.done() {
 		r.err = errEarlyReadEnd()
 	} else {
diff --git a/spanner/transaction.go b/spanner/transaction.go
index 9e3e107d1065..a33d03628685 100644
--- a/spanner/transaction.go
+++ b/spanner/transaction.go
@@ -18,6 +18,7 @@ package spanner
 
 import (
 	"context"
+	"fmt"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -63,6 +64,12 @@ type txReadOnly struct {
 	// operations.
 	txReadEnv
 
+	// updateTxStateFunc is a function that updates the state of the current
+	// transaction based on the given error. This function is by default a no-op,
+	// but is overridden for read/write transactions to set the state to txAborted
+	// if Spanner aborts the transaction.
+	updateTxStateFunc func(err error) error
+
 	// Atomic. Only needed for DML statements, but used forall.
 	sequenceNumber int64
 
@@ -98,6 +105,13 @@ type txReadOnly struct {
 	otConfig *openTelemetryConfig
 }
 
+func (t *txReadOnly) updateTxState(err error) error {
+	if t.updateTxStateFunc == nil {
+		return err
+	}
+	return t.updateTxStateFunc(err)
+}
+
 // TransactionOptions provides options for a transaction.
 type TransactionOptions struct {
 	CommitOptions CommitOptions
@@ -323,7 +337,7 @@ func (t *txReadOnly) ReadWithOptions(ctx context.Context, table string, keys Key
 					t.setTransactionID(nil)
 					return client, errInlineBeginTransactionFailed()
 				}
-				return client, err
+				return client, t.updateTxState(err)
 			}
 			md, err := client.Header()
 			if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
@@ -338,6 +352,9 @@ func (t *txReadOnly) ReadWithOptions(ctx context.Context, table string, keys Key
 		},
 		t.replaceSessionFunc,
 		setTransactionID,
+		func(err error) error {
+			return t.updateTxState(err)
+		},
 		t.setTimestamp,
 		t.release,
 	)
@@ -607,7 +624,7 @@ func (t *txReadOnly) query(ctx context.Context, statement Statement, options Que
 					t.setTransactionID(nil)
 					return client, errInlineBeginTransactionFailed()
 				}
-				return client, err
+				return client, t.updateTxState(err)
 			}
 			md, err := client.Header()
 			if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
@@ -622,6 +639,9 @@ func (t *txReadOnly) query(ctx context.Context, statement Statement, options Que
 		},
 		t.replaceSessionFunc,
 		setTransactionID,
+		func(err error) error {
+			return t.updateTxState(err)
+		},
 		t.setTimestamp,
 		t.release)
 }
@@ -673,6 +693,8 @@ const (
 	txActive
 	// transaction is closed, cannot be used anymore.
 	txClosed
+	// transaction was aborted by Spanner and should be retried.
+	txAborted
 )
 
 // errRtsUnavailable returns error for read transaction's read timestamp being
@@ -1216,7 +1238,7 @@ func (t *ReadWriteTransaction) update(ctx context.Context, stmt Statement, opts
 			t.setTransactionID(nil)
 			return 0, errInlineBeginTransactionFailed()
 		}
-		return 0, ToSpannerError(err)
+		return 0, t.txReadOnly.updateTxState(ToSpannerError(err))
 	}
 	if hasInlineBeginTransaction {
 		if resultSet != nil && resultSet.GetMetadata() != nil && resultSet.GetMetadata().GetTransaction() != nil &&
@@ -1325,7 +1347,7 @@ func (t *ReadWriteTransaction) batchUpdateWithOptions(ctx context.Context, stmts
 			t.setTransactionID(nil)
 			return nil, errInlineBeginTransactionFailed()
 		}
-		return nil, ToSpannerError(err)
+		return nil, t.txReadOnly.updateTxState(ToSpannerError(err))
 	}
 
 	haveTransactionID := false
@@ -1348,7 +1370,7 @@ func (t *ReadWriteTransaction) batchUpdateWithOptions(ctx context.Context, stmts
 		return counts, errInlineBeginTransactionFailed()
 	}
 	if resp.Status != nil && resp.Status.Code != 0 {
-		return counts, spannerErrorf(codes.Code(uint32(resp.Status.Code)), resp.Status.Message)
+		return counts, t.txReadOnly.updateTxState(spannerErrorf(codes.Code(uint32(resp.Status.Code)), resp.Status.Message))
 	}
 	return counts, nil
 }
@@ -1666,7 +1688,7 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions
 		trace.TracePrintf(ctx, nil, "Error in recording GFE Latency through OpenTelemetry. Error: %v", metricErr)
 	}
 	if e != nil {
-		return resp, toSpannerErrorWithCommitInfo(e, true)
+		return resp, t.txReadOnly.updateTxState(toSpannerErrorWithCommitInfo(e, true))
 	}
 	if tstamp := res.GetCommitTimestamp(); tstamp != nil {
 		resp.CommitTs = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))
@@ -1758,6 +1780,7 @@ type ReadWriteStmtBasedTransaction struct {
 	// ReadWriteTransaction contains methods for performing transactional reads.
 	ReadWriteTransaction
 
+	client  *Client
 	options TransactionOptions
 }
 
@@ -1783,23 +1806,36 @@ func NewReadWriteStmtBasedTransaction(ctx context.Context, c *Client) (*ReadWrit
 // used by the transaction will not be returned to the pool and cause a session
 // leak.
 //
+// If the transaction is aborted by Spanner, call ResetForRetry before the
+// retry attempt. ResetForRetry returns a new transaction that should be used
+// for the retry attempt. The transaction that is returned by that method is
+// assigned a higher priority than the previous transaction, making it less
+// probable to be aborted by Spanner again during the retry.
+//
 // NewReadWriteStmtBasedTransactionWithOptions is a configurable version of
 // NewReadWriteStmtBasedTransaction.
 func NewReadWriteStmtBasedTransactionWithOptions(ctx context.Context, c *Client, options TransactionOptions) (*ReadWriteStmtBasedTransaction, error) {
+	return newReadWriteStmtBasedTransactionWithSessionHandle(ctx, c, options, nil)
+}
+
+func newReadWriteStmtBasedTransactionWithSessionHandle(ctx context.Context, c *Client, options TransactionOptions, sh *sessionHandle) (*ReadWriteStmtBasedTransaction, error) {
 	var (
-		sh  *sessionHandle
 		err error
 		t   *ReadWriteStmtBasedTransaction
 	)
-	sh, err = c.idleSessions.take(ctx)
-	if err != nil {
-		// If session retrieval fails, just fail the transaction.
-		return nil, err
+	if sh == nil {
+		sh, err = c.idleSessions.take(ctx)
+		if err != nil {
+			// If session retrieval fails, just fail the transaction.
+			return nil, err
+		}
 	}
 	t = &ReadWriteStmtBasedTransaction{
 		ReadWriteTransaction: ReadWriteTransaction{
 			txReadyOrClosed: make(chan struct{}),
 		},
+		client:  c,
+		options: options,
 	}
 	t.txReadOnly.sp = c.idleSessions
 	t.txReadOnly.sh = sh
@@ -1807,6 +1843,15 @@ func NewReadWriteStmtBasedTransactionWithOptions(ctx context.Context, c *Client,
 	t.txReadOnly.qo = c.qo
 	t.txReadOnly.ro = c.ro
 	t.txReadOnly.disableRouteToLeader = c.disableRouteToLeader
+	t.txReadOnly.updateTxStateFunc = func(err error) error {
+		if ErrCode(err) == codes.Aborted {
+			t.mu.Lock()
+			t.state = txAborted
+			t.mu.Unlock()
+		}
+		return err
+	}
+
 	t.txOpts = c.txo.merge(options)
 	t.ct = c.ct
 	t.otConfig = c.otConfig
@@ -1838,6 +1883,7 @@ func (t *ReadWriteStmtBasedTransaction) CommitWithReturnResp(ctx context.Context
 	}
 	if t.sh != nil {
 		t.sh.recycle()
+		t.sh = nil
 	}
 	return resp, err
 }
@@ -1848,7 +1894,28 @@ func (t *ReadWriteStmtBasedTransaction) Rollback(ctx context.Context) {
 	t.rollback(ctx)
 	if t.sh != nil {
 		t.sh.recycle()
+		t.sh = nil
+	}
+}
+
+// ResetForRetry resets the transaction before a retry. This should be
+// called if the transaction was aborted by Spanner and the application
+// wants to retry the transaction. It returns a new transaction that should
+// be used for the retry attempt.
+// It is recommended to use this method over creating a new transaction,
+// as this method will give the transaction a higher priority and thus a
+// smaller probability of being aborted again by Spanner.
+func (t *ReadWriteStmtBasedTransaction) ResetForRetry(ctx context.Context) (*ReadWriteStmtBasedTransaction, error) {
+	// Read the state under the lock; it is written under t.mu when an
+	// Aborted error is observed.
+	t.mu.Lock()
+	state := t.state
+	t.mu.Unlock()
+	if state != txAborted {
+		return nil, fmt.Errorf("ResetForRetry should only be called on an active transaction that was aborted by Spanner")
 	}
+	// Create a new transaction that re-uses the current session if it is available.
+	return newReadWriteStmtBasedTransactionWithSessionHandle(ctx, t.client, t.options, t.sh)
 }
 
 // writeOnlyTransaction provides the most efficient way of doing write-only
diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go
index 983b9902f0b9..930c9c18142e 100644
--- a/spanner/transaction_test.go
+++ b/spanner/transaction_test.go
@@ -470,8 +470,102 @@ func TestReadWriteStmtBasedTransaction_CommitAborted(t *testing.T) {
 	}
 }
 
+func TestReadWriteStmtBasedTransaction_QueryAborted(t *testing.T) {
+	t.Parallel()
+	rowCount, attempts, err := testReadWriteStmtBasedTransaction(t, map[string]SimulatedExecutionTime{
+		MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}},
+	})
+	if err != nil {
+		t.Fatalf("transaction failed to commit: %v", err)
+	}
+	if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
+		t.Fatalf("Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
+	}
+	if g, w := attempts, 2; g != w {
+		t.Fatalf("number of attempts mismatch:\nGot%d\nWant:%d", g, w)
+	}
+}
+
+func TestReadWriteStmtBasedTransaction_UpdateAborted(t *testing.T) {
+	t.Parallel()
+	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
+		SessionPoolConfig: SessionPoolConfig{
+			// Use a session pool with size 1 to ensure that there are no session leaks.
+			MinOpened: 1,
+			MaxOpened: 1,
+		},
+	})
+	defer teardown()
+	server.TestSpanner.PutExecutionTime(
+		MethodExecuteSql,
+		SimulatedExecutionTime{Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}})
+
+	ctx := context.Background()
+	tx, err := NewReadWriteStmtBasedTransaction(ctx, client)
+	if err != nil {
+		t.Fatal(err)
+	}
+	_, err = tx.Update(ctx, Statement{SQL: UpdateBarSetFoo})
+	if g, w := ErrCode(err), codes.Aborted; g != w {
+		t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
+	}
+	tx, err = tx.ResetForRetry(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	c, err := tx.Update(ctx, Statement{SQL: UpdateBarSetFoo})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if g, w := c, int64(UpdateBarSetFooRowCount); g != w {
+		t.Fatalf("update count mismatch\n Got: %v\nWant: %v", g, w)
+	}
+}
+
+func TestReadWriteStmtBasedTransaction_BatchUpdateAborted(t *testing.T) {
+	t.Parallel()
+	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
+		SessionPoolConfig: SessionPoolConfig{
+			// Use a session pool with size 1 to ensure that there are no session leaks.
+			MinOpened: 1,
+			MaxOpened: 1,
+		},
+	})
+	defer teardown()
+	server.TestSpanner.PutExecutionTime(
+		MethodExecuteBatchDml,
+		SimulatedExecutionTime{Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}})
+
+	ctx := context.Background()
+	tx, err := NewReadWriteStmtBasedTransaction(ctx, client)
+	if err != nil {
+		t.Fatal(err)
+	}
+	_, err = tx.BatchUpdate(ctx, []Statement{{SQL: UpdateBarSetFoo}})
+	if g, w := ErrCode(err), codes.Aborted; g != w {
+		t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
+	}
+	tx, err = tx.ResetForRetry(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	c, err := tx.BatchUpdate(ctx, []Statement{{SQL: UpdateBarSetFoo}})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if g, w := c, []int64{UpdateBarSetFooRowCount}; !reflect.DeepEqual(g, w) {
+		t.Fatalf("update count mismatch\n Got: %v\nWant: %v", g, w)
+	}
+}
+
 func testReadWriteStmtBasedTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime) (rowCount int64, attempts int, err error) {
-	server, client, teardown := setupMockedTestServer(t)
+	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
+		SessionPoolConfig: SessionPoolConfig{
+			// Use a session pool with size 1 to ensure that there are no session leaks.
+			MinOpened: 1,
+			MaxOpened: 1,
+		},
+	})
 	defer teardown()
 	for method, exec := range executionTimes {
 		server.TestSpanner.PutExecutionTime(method, exec)
@@ -500,9 +594,14 @@ func testReadWriteStmtBasedTransaction(t *testing.T, executionTimes map[string]S
 		return rowCount, nil
 	}
 
+	var tx *ReadWriteStmtBasedTransaction
 	for {
 		attempts++
-		tx, err := NewReadWriteStmtBasedTransaction(ctx, client)
+		if attempts > 1 {
+			tx, err = tx.ResetForRetry(ctx)
+		} else {
+			tx, err = NewReadWriteStmtBasedTransaction(ctx, client)
+		}
 		if err != nil {
 			return 0, attempts, fmt.Errorf("failed to begin a transaction: %v", err)
 		}