diff --git a/expression/builtin_time.go b/expression/builtin_time.go index 3f001cd3fe586..8c83f69a2c4ce 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -4994,6 +4994,14 @@ func (b *builtinPeriodDiffSig) evalInt(row chunk.Row) (int64, bool, error) { return 0, isNull, errors.Trace(err) } + if !validPeriod(p1) { + return 0, false, errIncorrectArgs.GenWithStackByArgs("period_diff") + } + + if !validPeriod(p2) { + return 0, false, errIncorrectArgs.GenWithStackByArgs("period_diff") + } + return int64(period2Month(uint64(p1)) - period2Month(uint64(p2))), false, nil } diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index 3f59ef6e8acac..ff118ecaefdff 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -2351,19 +2351,24 @@ func (s *testEvaluatorSuite) TestPeriodDiff(c *C) { }{ {201611, 201611, true, 0}, {200802, 200703, true, 11}, - {0, 999999999, true, -120000086}, - {9999999, 0, true, 1200086}, - {411, 200413, true, -2}, - {197000, 207700, true, -1284}, {201701, 201611, true, 2}, {201702, 201611, true, 3}, {201510, 201611, true, -13}, {201702, 1611, true, 3}, {197102, 7011, true, 3}, - {12509, 12323, true, 10}, - {12509, 12323, true, 10}, } + tests2 := []struct { + Period1 int64 + Period2 int64 + }{ + {0, 999999999}, + {9999999, 0}, + {411, 200413}, + {197000, 207700}, + {12509, 12323}, + {12509, 12323}, + } fc := funcs[ast.PeriodDiff] for _, test := range tests { period1 := types.NewIntDatum(test.Period1) @@ -2381,6 +2386,18 @@ func (s *testEvaluatorSuite) TestPeriodDiff(c *C) { value := result.GetInt64() c.Assert(value, Equals, test.Expect) } + + for _, test := range tests2 { + period1 := types.NewIntDatum(test.Period1) + period2 := types.NewIntDatum(test.Period2) + f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{period1, period2})) + c.Assert(err, IsNil) + c.Assert(f, NotNil) + _, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[expression:1210]Incorrect arguments to period_diff") + } + // nil args := []types.Datum{types.NewDatum(nil), types.NewIntDatum(0)} f, err := fc.getFunction(s.ctx, s.datumsToConstants(args)) diff --git a/expression/integration_test.go b/expression/integration_test.go index 9530c00e39557..b76d6538772a9 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -1462,10 +1462,16 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { } // for period_diff - result = tk.MustQuery(`SELECT period_diff(191, 2), period_diff(191, -2), period_diff(0, 0), period_diff(191, 191);`) - result.Check(testkit.Rows("101 -2213609288845122103 0 0")) - result = tk.MustQuery(`SELECT period_diff(NULL, 2), period_diff(-191, NULL), period_diff(NULL, NULL), period_diff(12.09, 2), period_diff("21aa", "11aa"), period_diff("", "");`) - result.Check(testkit.Rows(" 10 10 0")) + result = tk.MustQuery(`SELECT period_diff(200807, 200705), period_diff(200807, 200908);`) + result.Check(testkit.Rows("14 -13")) + result = tk.MustQuery(`SELECT period_diff(NULL, 2), period_diff(-191, NULL), period_diff(NULL, NULL), period_diff(12.09, 2), period_diff("12aa", "11aa");`) + result.Check(testkit.Rows(" 10 1")) + for _, errPeriod := range []string{ + "period_diff(-00013,1)", "period_diff(00013,1)", "period_diff(0, 0)", "period_diff(200013, 1)", "period_diff(5612, 4513)", "period_diff('', '')", + } { + err := tk.QueryToErr(fmt.Sprintf("SELECT %v;", errPeriod)) + c.Assert(err.Error(), Equals, "[expression:1210]Incorrect arguments to period_diff") + } // TODO: fix `CAST(xx as duration)` and release the test below: // result = tk.MustQuery(`SELECT hour("aaa"), hour(123456), hour(1234567);`)