diff --git a/pkg/kv/BUILD.bazel b/pkg/kv/BUILD.bazel index d26699312688..2cd3c79f0f0a 100644 --- a/pkg/kv/BUILD.bazel +++ b/pkg/kv/BUILD.bazel @@ -55,16 +55,23 @@ go_test( "//pkg/config/zonepb", "//pkg/keys", "//pkg/kv/kvserver", + "//pkg/kv/kvserver/concurrency/lock", "//pkg/kv/kvserver/kvserverbase", "//pkg/roachpb", "//pkg/security", "//pkg/security/securitytest", "//pkg/server", + "//pkg/sql", + "//pkg/sql/catalog/descpb", + "//pkg/sql/rowenc", + "//pkg/sql/sem/tree", "//pkg/storage/enginepb", "//pkg/testutils", "//pkg/testutils/kvclientutils", "//pkg/testutils/serverutils", + "//pkg/testutils/sqlutils", "//pkg/testutils/testcluster", + "//pkg/util/encoding", "//pkg/util/hlc", "//pkg/util/leaktest", "//pkg/util/log", diff --git a/pkg/kv/txn_external_test.go b/pkg/kv/txn_external_test.go index 2f082ac5c4a1..e44e7aad79b1 100644 --- a/pkg/kv/txn_external_test.go +++ b/pkg/kv/txn_external_test.go @@ -18,16 +18,25 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/keys" "github.com/cockroachdb/cockroach/pkg/kv" "github.com/cockroachdb/cockroach/pkg/kv/kvserver" + "github.com/cockroachdb/cockroach/pkg/kv/kvserver/concurrency/lock" "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" + "github.com/cockroachdb/cockroach/pkg/sql/rowenc" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/testutils/kvclientutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" + "github.com/cockroachdb/cockroach/pkg/util/encoding" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -410,3 +419,229 @@ func TestChildTransactionDeadlockDetection(t *testing.T) { t.Fatalf("unexpected outcome: a=%s, b=%s", strA, strB) } } + +// TestChildTxnSelfInteractions is an integration-style test of child +// transaction interactions with parent state. +func TestChildTxnSelfInteractions(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ctx := context.Background() + type testCase struct { + name string + f func(t *testing.T, db *kv.DB) + } + run := func(test testCase) { + t.Run(test.name, func(t *testing.T) { + tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{}) + defer tc.Stopper().Stop(ctx) + db := tc.Server(0).DB() + test.f(t, db) + }) + } + testCases := []testCase{ + { + "child reads parent's write", + func(t *testing.T, db *kv.DB) { + require.NoError(t, db.Txn(ctx, func( + ctx context.Context, txn *kv.Txn, + ) error { + require.NoError(t, txn.Put(ctx, "foo", "bar")) + return txn.ChildTxn(ctx, func( + ctx context.Context, childTxn *kv.Txn, + ) error { + got, err := childTxn.Get(ctx, "foo") + require.NoError(t, err) + gotBytes, err := got.Value.GetBytes() + require.NoError(t, err) + require.Equal(t, string(gotBytes), "bar") + return nil + }) + })) + }, + }, + { + "child write-write conflict gets an error", + func(t *testing.T, db *kv.DB) { + require.Regexp(t, + `descendant transaction \w+ attempted to write over ancestor \w+ at key "foo"`, + db.Txn(ctx, func( + ctx context.Context, txn *kv.Txn, + ) error { + require.NoError(t, txn.Put(ctx, "foo", "bar")) + + // This attempt to write over the parent's intent will fail. + return txn.ChildTxn(ctx, func( + ctx context.Context, childTxn *kv.Txn, + ) error { + err := childTxn.Put(ctx, "foo", "baz") + return err + }) + })) + }, + }, + { + + // In this case the child should pass through the read lock of the parent, + // write successfully, and then it should push the parent which will then + // be forced to refresh. + "child read-write conflict forces parent to get an error (locking)", + func(t *testing.T, db *kv.DB) { + k := roachpb.Key("foo") + require.NoError(t, db.Put(ctx, k, "bar")) + require.Regexp(t, + `cannot forward provisional commit timestamp due to overlapping write`, + db.Txn(ctx, func( + ctx context.Context, txn *kv.Txn, + ) error { + scan := &roachpb.ScanRequest{ + + KeyLocking: lock.Exclusive, + } + scan.Key = k + scan.EndKey = k.PrefixEnd() + b := txn.NewBatch() + b.AddRawRequest(scan) + require.NoError(t, txn.Run(ctx, b)) + + // The below write will succeed but the forwarding of the parent above + // the write's timestamp will fail as it detects that the parent's + // read has been invalidated. + return txn.ChildTxn(ctx, func( + ctx context.Context, childTxn *kv.Txn, + ) error { + err := childTxn.Put(ctx, k, "baz") + require.NoError(t, err) + return nil + }) + })) + }, + }, + } + for _, tc := range testCases { + run(tc) + } +} + +// TestRestartDueToAbortedParentTransactionDetectedByAncestor works by ensuring +// that the auto-retry facilities of the SQL layer interact with the retry +// detected due to a child detecting the parent having been aborted propagates +// properly and leads to an automatic retry. +func TestRestartDueToAbortedParentTransactionDetectedByAncestor(t *testing.T) { + defer leaktest.AfterTest(t)() + + // The basic idea is that after the transaction runs its first write, the + // knob will inject control flow into the test before the prepare of the + // second write. At that point, a separate transaction with a higher priority + // will write to a scratch key and then get blocked on the first write. + // + // Then, a child transaction on behalf of this sql transaction will run and + // attempt to write over the scratch key. Deadlock detection will fire and + // the other transaction will abort the SQL transaction. The child will + // detect this abort (test asserts that this error is received) and will + // turn it into a restart. At that point, the other transaction will be + // allowed to proceed and commit and the sql transaction will be allowed to + // restart and succeed. + const ( + createTable = "CREATE TABLE foo (i INT PRIMARY KEY, j INT)" + firstWrite = "UPSERT INTO foo VALUES (1, 2)" + secondWrite = "INSERT INTO foo VALUES (2, 2)" + txn = "BEGIN;" + + " " + firstWrite + ";" + + // Use prepare because it has a handy testing knob which provides access + // to the underlying transaction. + " PREPARE second AS " + secondWrite + ";" + + " EXECUTE second; " + + " COMMIT;" + ) + blockCh := make(chan chan error, 1) + var interceptedTxn *kv.Txn + tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{ + ServerArgs: base.TestServerArgs{ + Knobs: base.TestingKnobs{ + SQLExecutor: &sql.ExecutorTestingKnobs{ + BeforePrepare: func(ctx context.Context, stmt string, txn *kv.Txn) error { + if stmt == secondWrite { + interceptedTxn = txn + ch := make(chan error) + select { + case blockCh <- ch: // will set to nil after the first pass + return <-ch + default: + } + } + return nil + }, + }, + }, + }, + }) + ctx := context.Background() + defer tc.Stopper().Stop(ctx) + scratch := tc.ScratchRange(t) + + sqlDB := tc.ServerConn(0) + tdb := sqlutils.MakeSQLRunner(sqlDB) + tdb.Exec(t, createTable) + var fooTableID descpb.ID + tdb.QueryRow(t, "SELECT $1::regclass::int", "foo").Scan(&fooTableID) + + db := tc.Server(0).DB() + codec := tc.Server(0).ExecutorConfig().(sql.ExecutorConfig).Codec + + errCh := make(chan error, 1) + go func() { + _, err := tc.ServerConn(0).Exec(txn) + errCh <- err + }() + + unblock := <-blockCh + + // We want to create a deadlock situation where our new transaction has a + // higher priority than the parent (and its descendants). + txnProto := roachpb.MakeTransaction( + "test", nil, roachpb.NormalUserPriority, db.Clock().Now(), + db.Clock().MaxOffset().Nanoseconds()) + txnProto.Priority = interceptedTxn.TestingCloneTxn().Priority + 1 + otherTxn := kv.NewTxnFromProto( + ctx, db, tc.Server(0).NodeID(), + txnProto.ReadTimestamp.UnsafeToClockTimestamp(), kv.RootTxn, &txnProto) + + // Write to scratch with otherTxn, then attempt to write over the write by + // the mainTxn. + require.NoError(t, otherTxn.Put(ctx, scratch, "bar")) + + indexKeyPrefix, err := rowenc.EncodeTableKey( + codec.IndexPrefix(uint32(fooTableID), 1), + tree.NewDInt(1), + encoding.Ascending, + ) + require.NoError(t, err) + key := keys.MakeFamilyKey(indexKeyPrefix, 0) + + otherWriteErrCh := make(chan error) + go func() { + otherWriteErrCh <- otherTxn.Put(ctx, key, 1) + }() + select { + case err := <-otherWriteErrCh: + t.Fatal(err) + case <-time.After(time.Millisecond): + } + + err = interceptedTxn.ChildTxn(ctx, func(_ context.Context, childTxn *kv.Txn) error { + err := childTxn.Put(ctx, scratch, "baz") + require.Truef(t, errors.HasType(err, (*roachpb.AncestorAbortedError)(nil)), "%T", err) + return err + }) + if retry := (*roachpb.TransactionRetryWithProtoRefreshError)(nil); !errors.As(err, &retry) { + t.Fatalf("expected %T, got %T", + (*roachpb.TransactionRetryWithProtoRefreshError)(nil), err) + } else { + require.Equal(t, interceptedTxn.ID(), retry.TxnID) + } + blockCh = nil + unblock <- err + require.NoError(t, <-otherWriteErrCh) + require.NoError(t, otherTxn.Commit(ctx)) + require.NoError(t, <-errCh) +}