diff --git a/executor/executor.go b/executor/executor.go index 11dd342b4a100..b03af3671ddea 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1173,6 +1173,8 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc := new(stmtctx.StatementContext) sc.TimeZone = vars.Location() sc.MemTracker = memory.NewTracker(s.Text(), vars.MemQuotaQuery) + sc.NowTs = time.Time{} + sc.SysTs = time.Time{} switch config.GetGlobalConfig().OOMAction { case config.OOMActionCancel: sc.MemTracker.SetActionOnExceed(&memory.PanicOnExceed{}) diff --git a/expression/builtin_time.go b/expression/builtin_time.go index ea4fde8971b4e..f2988489d5000 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -1971,7 +1971,11 @@ func (b *builtinCurrentDateSig) Clone() builtinFunc { // See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_curdate func (b *builtinCurrentDateSig) evalTime(row chunk.Row) (d types.Time, isNull bool, err error) { tz := b.ctx.GetSessionVars().Location() - year, month, day := time.Now().In(tz).Date() + var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs + if nowTs.Equal(time.Time{}) { + *nowTs = time.Now() + } + year, month, day := nowTs.In(tz).Date() result := types.Time{ Time: types.FromDate(year, int(month), day, 0, 0, 0, 0), Type: mysql.TypeDate, @@ -2026,7 +2030,11 @@ func (b *builtinCurrentTime0ArgSig) Clone() builtinFunc { func (b *builtinCurrentTime0ArgSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { tz := b.ctx.GetSessionVars().Location() - dur := time.Now().In(tz).Format(types.TimeFormat) + var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs + if nowTs.Equal(time.Time{}) { + *nowTs = time.Now() + } + dur := nowTs.In(tz).Format(types.TimeFormat) res, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, dur, types.MinFsp) if err != nil { return types.Duration{}, true, errors.Trace(err) @@ -2050,7 +2058,11 @@ func (b *builtinCurrentTime1ArgSig) evalDuration(row chunk.Row) (types.Duration, return types.Duration{}, true, errors.Trace(err) } tz := b.ctx.GetSessionVars().Location() - dur := time.Now().In(tz).Format(types.TimeFSPFormat) + var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs + if nowTs.Equal(time.Time{}) { + *nowTs = time.Now() + } + dur := nowTs.In(tz).Format(types.TimeFSPFormat) res, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, dur, int(fsp)) if err != nil { return types.Duration{}, true, errors.Trace(err) @@ -2188,7 +2200,11 @@ func (b *builtinUTCDateSig) Clone() builtinFunc { // evalTime evals UTC_DATE, UTC_DATE(). // See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-date func (b *builtinUTCDateSig) evalTime(row chunk.Row) (types.Time, bool, error) { - year, month, day := time.Now().UTC().Date() + var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs + if nowTs.Equal(time.Time{}) { + *nowTs = time.Now() + } + year, month, day := nowTs.UTC().Date() result := types.Time{ Time: types.FromGoTime(time.Date(year, month, day, 0, 0, 0, 0, time.UTC)), Type: mysql.TypeDate, @@ -2244,8 +2260,12 @@ func (c *utcTimestampFunctionClass) getFunction(ctx sessionctx.Context, args []E return sig, nil } -func evalUTCTimestampWithFsp(fsp int) (types.Time, bool, error) { - result, err := convertTimeToMysqlTime(time.Now().UTC(), fsp) +func evalUTCTimestampWithFsp(ctx sessionctx.Context, fsp int) (types.Time, bool, error) { + var nowTs = &ctx.GetSessionVars().StmtCtx.NowTs + if nowTs.Equal(time.Time{}) { + *nowTs = time.Now() + } + result, err := convertTimeToMysqlTime(nowTs.UTC(), fsp) if err != nil { return types.Time{}, true, errors.Trace(err) } @@ -2277,7 +2297,7 @@ func (b *builtinUTCTimestampWithArgSig) evalTime(row chunk.Row) (types.Time, boo return types.Time{}, true, errors.Errorf("Invalid negative %d specified, must in [0, 6].", num) } - result, isNull, err := evalUTCTimestampWithFsp(int(num)) + result, isNull, err := evalUTCTimestampWithFsp(b.ctx, int(num)) return result, isNull, errors.Trace(err) } @@ -2294,7 +2314,7 @@ func (b *builtinUTCTimestampWithoutArgSig) Clone() builtinFunc { // evalTime evals UTC_TIMESTAMP(). // See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-timestamp func (b *builtinUTCTimestampWithoutArgSig) evalTime(row chunk.Row) (types.Time, bool, error) { - result, isNull, err := evalUTCTimestampWithFsp(0) + result, isNull, err := evalUTCTimestampWithFsp(b.ctx, 0) return result, isNull, errors.Trace(err) } @@ -2328,12 +2348,16 @@ func (c *nowFunctionClass) getFunction(ctx sessionctx.Context, args []Expression } func evalNowWithFsp(ctx sessionctx.Context, fsp int) (types.Time, bool, error) { - sysTs, err := getSystemTimestamp(ctx) - if err != nil { - return types.Time{}, true, errors.Trace(err) + var sysTs = &ctx.GetSessionVars().StmtCtx.SysTs + if sysTs.Equal(time.Time{}) { + var err error + *sysTs, err = getSystemTimestamp(ctx) + if err != nil { + return types.Time{}, true, errors.Trace(err) + } } - result, err := convertTimeToMysqlTime(sysTs, fsp) + result, err := convertTimeToMysqlTime(*sysTs, fsp) if err != nil { return types.Time{}, true, errors.Trace(err) } @@ -3557,7 +3581,11 @@ func (b *builtinUnixTimestampCurrentSig) Clone() builtinFunc { // evalInt evals a UNIX_TIMESTAMP(). // See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_unix-timestamp func (b *builtinUnixTimestampCurrentSig) evalInt(row chunk.Row) (int64, bool, error) { - dec, err := goTimeToMysqlUnixTimestamp(time.Now(), 1) + var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs + if nowTs.Equal(time.Time{}) { + *nowTs = time.Now() + } + dec, err := goTimeToMysqlUnixTimestamp(*nowTs, 1) if err != nil { return 0, true, errors.Trace(err) } @@ -5497,7 +5525,11 @@ func (b *builtinUTCTimeWithoutArgSig) Clone() builtinFunc { // evalDuration evals a builtinUTCTimeWithoutArgSig. // See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-time func (b *builtinUTCTimeWithoutArgSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { - v, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, time.Now().UTC().Format(types.TimeFormat), 0) + var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs + if nowTs.Equal(time.Time{}) { + *nowTs = time.Now() + } + v, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, nowTs.UTC().Format(types.TimeFormat), 0) return v, false, err } @@ -5524,7 +5556,11 @@ func (b *builtinUTCTimeWithArgSig) evalDuration(row chunk.Row) (types.Duration, if fsp < int64(types.MinFsp) { return types.Duration{}, true, errors.Errorf("Invalid negative %d specified, must in [0, 6].", fsp) } - v, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, time.Now().UTC().Format(types.TimeFSPFormat), int(fsp)) + var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs + if nowTs.Equal(time.Time{}) { + *nowTs = time.Now() + } + v, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, nowTs.UTC().Format(types.TimeFSPFormat), int(fsp)) return v, false, err } diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index fadadc2783599..844e4e759c410 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -762,6 +762,11 @@ func (s *testEvaluatorSuite) TestTime(c *C) { c.Assert(err, IsNil) } +func resetStmtContext(ctx sessionctx.Context) { + ctx.GetSessionVars().StmtCtx.NowTs = time.Time{} + ctx.GetSessionVars().StmtCtx.SysTs = time.Time{} +} + func (s *testEvaluatorSuite) TestNowAndUTCTimestamp(c *C) { defer testleak.AfterTest(c)() @@ -778,6 +783,7 @@ func (s *testEvaluatorSuite) TestNowAndUTCTimestamp(c *C) { {funcs[ast.Now], func() time.Time { return time.Now() }}, {funcs[ast.UTCTimestamp], func() time.Time { return time.Now().UTC() }}, } { + resetStmtContext(s.ctx) f, err := x.fc.getFunction(s.ctx, s.datumsToConstants(nil)) c.Assert(err, IsNil) v, err := evalBuiltinFunc(f, chunk.Row{}) @@ -789,6 +795,7 @@ func (s *testEvaluatorSuite) TestNowAndUTCTimestamp(c *C) { c.Assert(strings.Contains(t.String(), "."), IsFalse) c.Assert(ts.Sub(gotime(t, ts.Location())), LessEqual, time.Second) + resetStmtContext(s.ctx) f, err = x.fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(6))) c.Assert(err, IsNil) v, err = evalBuiltinFunc(f, chunk.Row{}) @@ -798,11 +805,13 @@ func (s *testEvaluatorSuite) TestNowAndUTCTimestamp(c *C) { c.Assert(strings.Contains(t.String(), "."), IsTrue) c.Assert(ts.Sub(gotime(t, ts.Location())), LessEqual, time.Millisecond) + resetStmtContext(s.ctx) f, err = x.fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(8))) c.Assert(err, IsNil) _, err = evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, NotNil) + resetStmtContext(s.ctx) f, err = x.fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(-2))) c.Assert(err, IsNil) _, err = evalBuiltinFunc(f, chunk.Row{}) @@ -813,6 +822,7 @@ func (s *testEvaluatorSuite) TestNowAndUTCTimestamp(c *C) { variable.SetSessionSystemVar(s.ctx.GetSessionVars(), "time_zone", types.NewDatum("+00:00")) variable.SetSessionSystemVar(s.ctx.GetSessionVars(), "timestamp", types.NewDatum(1234)) fc := funcs[ast.Now] + resetStmtContext(s.ctx) f, err := fc.getFunction(s.ctx, s.datumsToConstants(nil)) c.Assert(err, IsNil) v, err := evalBuiltinFunc(f, chunk.Row{}) @@ -877,6 +887,7 @@ func (s *testEvaluatorSuite) TestAddTimeSig(c *C) { // This is a test for issue 7334 du := newDateArighmeticalUtil() + resetStmtContext(s.ctx) now, _, err := evalNowWithFsp(s.ctx, 0) c.Assert(err, IsNil) res, _, err := du.add(s.ctx, now, "1", "MICROSECOND") @@ -1203,6 +1214,7 @@ func (s *testEvaluatorSuite) TestUTCTime(c *C) { }{{0, 8}, {3, 12}, {6, 15}, {-1, 0}, {7, 0}} for _, test := range tests { + resetStmtContext(s.ctx) f, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(test.param))) c.Assert(err, IsNil) v, err := evalBuiltinFunc(f, chunk.Row{}) @@ -1229,6 +1241,7 @@ func (s *testEvaluatorSuite) TestUTCDate(c *C) { defer testleak.AfterTest(c)() last := time.Now().UTC() fc := funcs[ast.UTCDate] + resetStmtContext(mock.NewContext()) f, err := fc.getFunction(mock.NewContext(), s.datumsToConstants(nil)) c.Assert(err, IsNil) v, err := evalBuiltinFunc(f, chunk.Row{}) @@ -1500,6 +1513,7 @@ func (s *testEvaluatorSuite) TestTimestampDiff(c *C) { types.NewStringDatum(test.t1), types.NewStringDatum(test.t2), } + resetStmtContext(s.ctx) f, err := fc.getFunction(s.ctx, s.datumsToConstants(args)) c.Assert(err, IsNil) d, err := evalBuiltinFunc(f, chunk.Row{}) @@ -1509,6 +1523,7 @@ func (s *testEvaluatorSuite) TestTimestampDiff(c *C) { sc := s.ctx.GetSessionVars().StmtCtx sc.IgnoreTruncate = true sc.IgnoreZeroInDate = true + resetStmtContext(s.ctx) f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{types.NewStringDatum("DAY"), types.NewStringDatum("2017-01-00"), types.NewStringDatum("2017-01-01")})) @@ -1517,6 +1532,7 @@ func (s *testEvaluatorSuite) TestTimestampDiff(c *C) { c.Assert(err, IsNil) c.Assert(d.Kind(), Equals, types.KindNull) + resetStmtContext(s.ctx) f, err = fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{types.NewStringDatum("DAY"), {}, types.NewStringDatum("2017-01-01")})) c.Assert(err, IsNil) @@ -1528,6 +1544,7 @@ func (s *testEvaluatorSuite) TestTimestampDiff(c *C) { func (s *testEvaluatorSuite) TestUnixTimestamp(c *C) { // Test UNIX_TIMESTAMP(). fc := funcs[ast.UnixTimestamp] + resetStmtContext(s.ctx) f, err := fc.getFunction(s.ctx, nil) c.Assert(err, IsNil) d, err := evalBuiltinFunc(f, chunk.Row{}) @@ -1537,12 +1554,14 @@ func (s *testEvaluatorSuite) TestUnixTimestamp(c *C) { // https://github.com/pingcap/tidb/issues/2496 // Test UNIX_TIMESTAMP(NOW()). + resetStmtContext(s.ctx) now, isNull, err := evalNowWithFsp(s.ctx, 0) c.Assert(err, IsNil) c.Assert(isNull, IsFalse) n := types.Datum{} n.SetMysqlTime(now) args := []types.Datum{n} + resetStmtContext(s.ctx) f, err = fc.getFunction(s.ctx, s.datumsToConstants(args)) c.Assert(err, IsNil) d, err = evalBuiltinFunc(f, chunk.Row{}) @@ -1554,6 +1573,7 @@ func (s *testEvaluatorSuite) TestUnixTimestamp(c *C) { // https://github.com/pingcap/tidb/issues/2852 // Test UNIX_TIMESTAMP(NULL). args = []types.Datum{types.NewDatum(nil)} + resetStmtContext(s.ctx) f, err = fc.getFunction(s.ctx, s.datumsToConstants(args)) c.Assert(err, IsNil) d, err = evalBuiltinFunc(f, chunk.Row{}) @@ -1598,6 +1618,7 @@ func (s *testEvaluatorSuite) TestUnixTimestamp(c *C) { fmt.Printf("Begin Test %v\n", test) expr := s.datumsToConstants([]types.Datum{test.input}) expr[0].GetType().Decimal = test.inputDecimal + resetStmtContext(s.ctx) f, err := fc.getFunction(s.ctx, expr) c.Assert(err, IsNil, Commentf("%+v", test)) d, err := evalBuiltinFunc(f, chunk.Row{}) @@ -1681,6 +1702,7 @@ func (s *testEvaluatorSuite) TestTimestamp(c *C) { } fc := funcs[ast.Timestamp] for _, test := range tests { + resetStmtContext(s.ctx) f, err := fc.getFunction(s.ctx, s.datumsToConstants(test.t)) c.Assert(err, IsNil) d, err := evalBuiltinFunc(f, chunk.Row{}) @@ -1690,6 +1712,7 @@ func (s *testEvaluatorSuite) TestTimestamp(c *C) { } nilDatum := types.NewDatum(nil) + resetStmtContext(s.ctx) f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{nilDatum})) c.Assert(err, IsNil) d, err := evalBuiltinFunc(f, chunk.Row{}) @@ -2357,6 +2380,7 @@ func (s *testEvaluatorSuite) TestWithTimeZone(c *C) { for _, t := range tests { now := time.Now().In(sv.TimeZone) + resetStmtContext(s.ctx) f, err := funcs[t.method].getFunction(s.ctx, s.datumsToConstants(t.Input)) c.Assert(err, IsNil) d, err := evalBuiltinFunc(f, chunk.Row{}) diff --git a/expression/function_traits.go b/expression/function_traits.go index 8d8913fd354d4..24ffa391db6c0 100644 --- a/expression/function_traits.go +++ b/expression/function_traits.go @@ -19,23 +19,12 @@ import ( // UnCacheableFunctions stores functions which can not be cached to plan cache. var UnCacheableFunctions = map[string]struct{}{ - ast.Now: {}, - ast.CurrentTimestamp: {}, - ast.UTCTime: {}, - ast.Curtime: {}, - ast.CurrentTime: {}, - ast.UTCTimestamp: {}, - ast.UnixTimestamp: {}, - ast.Sysdate: {}, - ast.Curdate: {}, - ast.CurrentDate: {}, - ast.UTCDate: {}, - ast.Database: {}, - ast.CurrentUser: {}, - ast.User: {}, - ast.ConnectionID: {}, - ast.LastInsertId: {}, - ast.Version: {}, + ast.Database: {}, + ast.CurrentUser: {}, + ast.User: {}, + ast.ConnectionID: {}, + ast.LastInsertId: {}, + ast.Version: {}, } // unFoldableFunctions stores functions which can not be folded duration constant folding stage. @@ -52,6 +41,23 @@ var unFoldableFunctions = map[string]struct{}{ ast.GetParam: {}, } +// DeferredFunctions stores non-deterministic functions, which can be deferred only when the plan cache is enabled. +var DeferredFunctions = map[string]struct{}{ + ast.Now: {}, + ast.CurrentTimestamp: {}, + ast.UTCTime: {}, + ast.Curtime: {}, + ast.CurrentTime: {}, + ast.UTCTimestamp: {}, + ast.UnixTimestamp: {}, + ast.Sysdate: {}, + ast.Curdate: {}, + ast.CurrentDate: {}, + ast.UTCDate: {}, + ast.Rand: {}, + ast.UUID: {}, +} + // inequalFunctions stores functions which cannot be propagated from column equal condition. var inequalFunctions = map[string]struct{}{ ast.IsNull: {}, diff --git a/expression/scalar_function.go b/expression/scalar_function.go index e0c1a13b8ea00..2e70a1684770c 100644 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -70,8 +70,8 @@ func (sf *ScalarFunction) MarshalJSON() ([]byte, error) { return []byte(fmt.Sprintf("\"%s\"", sf)), nil } -// NewFunction creates a new scalar function or constant. -func NewFunction(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) { +// newFunctionImpl creates a new scalar function or constant. +func newFunctionImpl(ctx sessionctx.Context, fold bool, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) { if retType == nil { return nil, errors.Errorf("RetType cannot be nil for ScalarFunction.") } @@ -96,7 +96,20 @@ func NewFunction(ctx sessionctx.Context, funcName string, retType *types.FieldTy RetType: retType, Function: f, } - return FoldConstant(sf), nil + if fold { + return FoldConstant(sf), nil + } + return sf, nil +} + +// NewFunction creates a new scalar function or constant via a constant folding. +func NewFunction(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) { + return newFunctionImpl(ctx, true, funcName, retType, args...) +} + +// NewFunctionBase creates a new scalar function with no constant folding. +func NewFunctionBase(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) { + return newFunctionImpl(ctx, false, funcName, retType, args...) } // NewFunctionInternal is similar to NewFunction, but do not returns error, should only be used internally. diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 3693a5d92d55f..50fa8d75f1377 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -757,13 +757,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok value := &expression.Constant{Value: v.Datum, RetType: &v.Type} er.ctxStack = append(er.ctxStack, value) case *driver.ParamMarkerExpr: - tp := types.NewFieldType(mysql.TypeUnspecified) - types.DefaultParamTypeForValue(v.GetValue(), tp) - value := &expression.Constant{Value: v.Datum, RetType: tp} - if er.useCache() { - value.DeferredExpr = er.getParamExpression(v) - } - er.ctxStack = append(er.ctxStack, value) + er.paramToExpression(v) case *ast.VariableExpr: er.rewriteVariable(v) case *ast.FuncCallExpr: @@ -820,17 +814,18 @@ func datumToConstant(d types.Datum, tp byte) *expression.Constant { return &expression.Constant{Value: d, RetType: types.NewFieldType(tp)} } -func (er *expressionRewriter) getParamExpression(v *driver.ParamMarkerExpr) expression.Expression { - f, err := expression.NewFunction(er.ctx, - ast.GetParam, - &v.Type, - datumToConstant(types.NewIntDatum(int64(v.Order)), mysql.TypeLonglong)) - if err != nil { - er.err = errors.Trace(err) - return nil - } - f.GetType().Tp = v.Type.Tp - return f +func (er *expressionRewriter) paramToExpression(v *driver.ParamMarkerExpr) { + tp := types.NewFieldType(mysql.TypeUnspecified) + types.DefaultParamTypeForValue(v.GetValue(), tp) + value := &expression.Constant{Value: v.Datum, RetType: tp} + if er.useCache() { + var f expression.Expression + f, er.err = expression.NewFunctionBase(er.ctx, ast.GetParam, &v.Type, + datumToConstant(types.NewIntDatum(int64(v.Order)), mysql.TypeLonglong)) + f.GetType().Tp = v.Type.Tp + value.DeferredExpr = f + } + er.ctxStack = append(er.ctxStack, value) } func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) { @@ -1220,9 +1215,16 @@ func (er *expressionRewriter) funcCallToExpression(v *ast.FuncCallExpr) { return } var function expression.Expression - function, er.err = expression.NewFunction(er.ctx, v.FnName.L, &v.Type, args...) er.ctxStack = er.ctxStack[:stackLen-len(v.Args)] - er.ctxStack = append(er.ctxStack, function) + if _, ok := expression.DeferredFunctions[v.FnName.L]; er.useCache() && ok { + function, er.err = expression.NewFunctionBase(er.ctx, v.FnName.L, &v.Type, args...) + c := &expression.Constant{Value: types.NewDatum(nil), RetType: &v.Type, DeferredExpr: function} + c.GetType().Tp = function.GetType().Tp + er.ctxStack = append(er.ctxStack, c) + } else { + function, er.err = expression.NewFunction(er.ctx, v.FnName.L, &v.Type, args...) + er.ctxStack = append(er.ctxStack, function) + } } func (er *expressionRewriter) toColumn(v *ast.ColumnName) { diff --git a/planner/core/prepare_test.go b/planner/core/prepare_test.go index 1a456fbc679b3..67e43c8927337 100644 --- a/planner/core/prepare_test.go +++ b/planner/core/prepare_test.go @@ -14,10 +14,16 @@ package core_test import ( + "time" + . "github.com/pingcap/check" + "github.com/pingcap/tidb/executor" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testleak" + dto "github.com/prometheus/client_model/go" ) var _ = Suite(&testPrepareSuite{}) @@ -89,3 +95,93 @@ func (s *testPrepareSuite) TestPrepareCacheIndexScan(c *C) { tk.MustQuery("execute stmt1 using @a, @b").Check(testkit.Rows("1 3", "1 3")) tk.MustQuery("execute stmt1 using @a, @b").Check(testkit.Rows("1 3", "1 3")) } + +func (s *testPlanSuite) TestPrepareCacheDeferredFunction(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + orgEnable := core.PreparedPlanCacheEnabled() + orgCapacity := core.PreparedPlanCacheCapacity + defer func() { + dom.Close() + store.Close() + core.SetPreparedPlanCache(orgEnable) + core.PreparedPlanCacheCapacity = orgCapacity + }() + core.SetPreparedPlanCache(true) + core.PreparedPlanCacheCapacity = 100 + + defer testleak.AfterTest(c)() + + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1 (id int PRIMARY KEY, c1 TIMESTAMP(3) NOT NULL DEFAULT '0000-00-00 00:00:00', KEY idx1 (c1))") + tk.MustExec("prepare sel1 from 'select id, c1 from t1 where c1 < now(3)'") + + sql1 := "execute sel1" + expectedPattern := `IndexReader\(Index\(t1.idx1\)\[\[-inf,[0-9]{4}-(0[1-9]|1[0-2])-(0[1-9]|[1-2][0-9]|3[0-1]) (2[0-3]|[01][0-9]):[0-5][0-9]:[0-5][0-9].000\)\]\)` + + var cnt [2]float64 + var planStr [2]string + metrics.PlanCacheCounter.Reset() + counter := metrics.PlanCacheCounter.WithLabelValues("prepare") + for i := 0; i < 2; i++ { + stmt, err := s.ParseOneStmt(sql1, "", "") + c.Check(err, IsNil) + is := tk.Se.GetSessionVars().TxnCtx.InfoSchema.(infoschema.InfoSchema) + builder := core.NewPlanBuilder(tk.Se, is) + p, err := builder.Build(stmt) + c.Check(err, IsNil) + execPlan, ok := p.(*core.Execute) + c.Check(ok, IsTrue) + executor.ResetContextOfStmt(tk.Se, stmt) + err = execPlan.OptimizePreparedPlan(tk.Se, is) + c.Check(err, IsNil) + planStr[i] = core.ToString(execPlan.Plan) + c.Check(planStr[i], Matches, expectedPattern, Commentf("for %s", sql1)) + pb := &dto.Metric{} + counter.Write(pb) + cnt[i] = pb.GetCounter().GetValue() + c.Check(cnt[i], Equals, float64(i)) + time.Sleep(time.Second * 1) + } + c.Assert(planStr[0] < planStr[1], IsTrue) +} + +func (s *testPrepareSuite) TestPrepareCacheNow(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + orgEnable := core.PreparedPlanCacheEnabled() + orgCapacity := core.PreparedPlanCacheCapacity + defer func() { + dom.Close() + store.Close() + core.SetPreparedPlanCache(orgEnable) + core.PreparedPlanCacheCapacity = orgCapacity + }() + core.SetPreparedPlanCache(true) + core.PreparedPlanCacheCapacity = 100 + tk.MustExec("use test") + tk.MustExec(`prepare stmt1 from "select now(), sleep(1), now()"`) + // When executing one statement at the first time, we don't use cache, so we need to execute it at least twice to test the cache. + rs := tk.MustQuery("execute stmt1").Rows() + c.Assert(rs[0][0].(string), Equals, rs[0][2].(string)) + + tk.MustExec(`prepare stmt2 from "select current_timestamp(), sleep(1), current_timestamp()"`) + // When executing one statement at the first time, we don't use cache, so we need to execute it at least twice to test the cache. + rs = tk.MustQuery("execute stmt2").Rows() + c.Assert(rs[0][0].(string), Equals, rs[0][2].(string)) + + tk.MustExec(`prepare stmt3 from "select utc_timestamp(), sleep(1), utc_timestamp()"`) + // When executing one statement at the first time, we don't use cache, so we need to execute it at least twice to test the cache. + rs = tk.MustQuery("execute stmt3").Rows() + c.Assert(rs[0][0].(string), Equals, rs[0][2].(string)) + + tk.MustExec(`prepare stmt4 from "select unix_timestamp(), sleep(1), unix_timestamp()"`) + // When executing one statement at the first time, we don't use cache, so we need to execute it at least twice to test the cache. + rs = tk.MustQuery("execute stmt4").Rows() + c.Assert(rs[0][0].(string), Equals, rs[0][2].(string)) +} diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 66db63081c5b8..ef416642c2c62 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -80,6 +80,8 @@ type StatementContext struct { RuntimeStatsColl *execdetails.RuntimeStatsColl TableIDs []int64 IndexIDs []int64 + NowTs time.Time + SysTs time.Time } // AddAffectedRows adds affected rows.