Skip to content

Commit

Permalink
session: add a doNotCommit flag to transaction when StmtCommit fa…
Browse files Browse the repository at this point in the history
…il (#8918) (#8924)
  • Loading branch information
tiancaiamao authored and zz-jason committed Jan 4, 2019
1 parent 5a1ec09 commit 63070fc
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
16 changes: 12 additions & 4 deletions session/session_fail_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,23 @@ import (
)

func (s *testSessionSuite) TestFailStatementCommit(c *C) {
defer gofail.Disable("github.com/pingcap/tidb/session/mockStmtCommitError")

tk := testkit.NewTestKitWithInit(c, s.store)
tk.MustExec("create table t (id int)")
tk.MustExec("begin")
tk.MustExec("insert into t values (1)")

gofail.Enable("github.com/pingcap/tidb/session/mockStmtCommitError", `return(true)`)
_, err := tk.Exec("insert into t values (2)")
_, err := tk.Exec("insert into t values (2),(3),(4),(5)")
c.Assert(err, NotNil)
tk.MustExec("commit")
tk.MustQuery(`select * from t`).Check(testkit.Rows("1"))

gofail.Disable("github.com/pingcap/tidb/session/mockStmtCommitError")

tk.MustQuery("select * from t").Check(testkit.Rows("1"))
tk.MustExec("insert into t values (3)")
tk.MustExec("insert into t values (4)")
_, err = tk.Exec("commit")
c.Assert(err, NotNil)

tk.MustQuery(`select * from t`).Check(testkit.Rows())
}
33 changes: 31 additions & 2 deletions session/txn.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ type TxnState struct {
buf kv.MemBuffer
mutations map[int64]*binlog.TableMutation
dirtyTableOP []dirtyTableOperation

// If doNotCommit is not nil, Commit() will not commit the transaction.
// doNotCommit flag may be set when StmtCommit fail.
doNotCommit error
}

func (st *TxnState) init() {
Expand Down Expand Up @@ -143,19 +147,37 @@ type dirtyTableOperation struct {

// Commit overrides the Transaction interface.
func (st *TxnState) Commit(ctx context.Context) error {
defer st.reset()
if len(st.mutations) != 0 || len(st.dirtyTableOP) != 0 || st.buf.Len() != 0 {
log.Errorf("The code should never run here, TxnState=%#v, mutations=%#v, dirtyTableOP=%#v, buf=%#v something must be wrong: %s",
st,
st.mutations,
st.dirtyTableOP,
st.buf,
debug.Stack())
st.cleanup()
return errors.New("invalid transaction")
}
if st.doNotCommit != nil {
if err1 := st.Transaction.Rollback(); err1 != nil {
log.Error(err1)
}
return errors.Trace(st.doNotCommit)
}
return errors.Trace(st.Transaction.Commit(ctx))
}

// Rollback overrides the Transaction interface.
func (st *TxnState) Rollback() error {
defer st.reset()
return errors.Trace(st.Transaction.Rollback())
}

func (st *TxnState) reset() {
st.doNotCommit = nil
st.cleanup()
st.changeToInvalid()
}

// Get overrides the Transaction interface.
func (st *TxnState) Get(k kv.Key) ([]byte, error) {
val, err := st.buf.Get(k)
Expand Down Expand Up @@ -282,17 +304,24 @@ func (s *session) getTxnFuture(ctx context.Context) *txnFuture {
func (s *session) StmtCommit() error {
defer s.txn.cleanup()
st := &s.txn
var count int
err := kv.WalkMemBuffer(st.buf, func(k kv.Key, v []byte) error {

// gofail: var mockStmtCommitError bool
// if mockStmtCommitError {
// return errors.New("mock stmt commit error")
// count++
// }
if count > 3 {
return errors.New("mock stmt commit error")
}

if len(v) == 0 {
return errors.Trace(st.Transaction.Delete(k))
}
return errors.Trace(st.Transaction.Set(k, v))
})
if err != nil {
st.doNotCommit = err
return errors.Trace(err)
}

Expand Down

0 comments on commit 63070fc

Please sign in to comment.