From a55ead07e78436cdbbd09feee3e39c9e2f58197d Mon Sep 17 00:00:00 2001 From: Jack Yu Date: Wed, 16 Oct 2019 11:37:03 +0800 Subject: [PATCH] session: annotate the previous statement to the error when transaction commit failed (#12087) (#12747) --- executor/adapter.go | 5 +++-- session/session_test.go | 4 ++-- session/tidb.go | 15 +++++++++++---- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/executor/adapter.go b/executor/adapter.go index 223ac925d0854..d11b4611b5b8f 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -124,7 +124,8 @@ func (a *recordSet) NewChunk() *chunk.Chunk { func (a *recordSet) Close() error { err := a.executor.Close() a.stmt.LogSlowQuery(a.txnStartTS, a.lastErr == nil) - a.stmt.Ctx.GetSessionVars().PrevStmt = a.stmt.OriginText() + sessVars := a.stmt.Ctx.GetSessionVars() + sessVars.PrevStmt = FormatSQL(a.stmt.OriginText(), sessVars) a.stmt.logAudit() return errors.Trace(err) } @@ -443,7 +444,7 @@ func (a *ExecStmt) LogSlowQuery(txnTS uint64, succ bool) { Succ: succ, } if _, ok := a.StmtNode.(*ast.CommitStmt); ok { - slowItems.PrevStmt = FormatSQL(sessVars.PrevStmt, sessVars) + slowItems.PrevStmt = sessVars.PrevStmt } if costTime < threshold { logutil.SlowQueryLogger.Debug(sessVars.SlowLogFormat(slowItems)) diff --git a/session/session_test.go b/session/session_test.go index 07927ac0c55ea..f4da72ad9726b 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -1461,12 +1461,12 @@ func (s *testSessionSuite) TestUnique(c *C) { c.Assert(err, NotNil) // Check error type and error message c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue, Commentf("err %v", err)) - c.Assert(err.Error(), Equals, "[kv:1062]Duplicate entry '1' for key 'PRIMARY'") + c.Assert(err.Error(), Equals, "previous statement: insert into test(id, val) values(1, 1);: [kv:1062]Duplicate entry '1' for key 'PRIMARY'") _, err = tk1.Exec("commit") c.Assert(err, NotNil) c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue, Commentf("err %v", err)) - c.Assert(err.Error(), Equals, "[kv:1062]Duplicate entry '2' for key 'val'") + c.Assert(err.Error(), Equals, "previous statement: insert into test(id, val) values(2, 2);: [kv:1062]Duplicate entry '2' for key 'val'") // Test for https://github.com/pingcap/tidb/issues/463 tk.MustExec("drop table test;") diff --git a/session/tidb.go b/session/tidb.go index 677f4557e9a57..79448f8a99be0 100644 --- a/session/tidb.go +++ b/session/tidb.go @@ -156,7 +156,8 @@ func Compile(ctx context.Context, sctx sessionctx.Context, stmtNode ast.StmtNode return stmt, errors.Trace(err) } -func finishStmt(ctx context.Context, sctx sessionctx.Context, se *session, sessVars *variable.SessionVars, meetsErr error) error { +func finishStmt(ctx context.Context, sctx sessionctx.Context, se *session, sessVars *variable.SessionVars, + meetsErr error, sql sqlexec.Statement) error { if meetsErr != nil { if !sessVars.InTxn() { logutil.Logger(context.Background()).Info("rollbackTxn for ddl/autocommit error.") @@ -166,7 +167,13 @@ func finishStmt(ctx context.Context, sctx sessionctx.Context, se *session, sessV } if !sessVars.InTxn() { - return se.CommitTxn(ctx) + if err := se.CommitTxn(ctx); err != nil { + if _, ok := sql.(*executor.ExecStmt).StmtNode.(*ast.CommitStmt); ok { + err = errors.Annotatef(err, "previous statement: %s", se.GetSessionVars().PrevStmt) + } + return err + } + return nil } return checkStmtLimit(ctx, sctx, se, sessVars) @@ -197,7 +204,7 @@ func runStmt(ctx context.Context, sctx sessionctx.Context, s sqlexec.Statement) // then it could include the transaction commit time. if rs == nil { s.(*executor.ExecStmt).LogSlowQuery(origTxnCtx.StartTS, err == nil) - sessVars.PrevStmt = s.OriginText() + sessVars.PrevStmt = executor.FormatSQL(s.OriginText(), sessVars) } }() @@ -225,7 +232,7 @@ func runStmt(ctx context.Context, sctx sessionctx.Context, s sqlexec.Statement) } } - err = finishStmt(ctx, sctx, se, sessVars, err) + err = finishStmt(ctx, sctx, se, sessVars, err, s) if se.txn.pending() { // After run statement finish, txn state is still pending means the // statement never need a Txn(), such as: