From 5b26588eba41156c10a4914492d230bf1b54beae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=99=8E?= Date: Fri, 4 Sep 2020 12:07:32 +0800 Subject: [PATCH] expression: add linear search for the interval function (#19543) * add linear search for interval * update test * remove debug info * fix mistakenly erased code Co-authored-by: ti-srebot <66930949+ti-srebot@users.noreply.github.com> --- expression/builtin_compare.go | 80 +++++++++++++++++++++++++++--- expression/builtin_compare_test.go | 4 ++ expression/builtin_compare_vec.go | 12 ++++- expression/distsql_builtin.go | 4 +- expression/integration_test.go | 12 +++++ 5 files changed, 101 insertions(+), 11 deletions(-) diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index 4c05906bb8b51..df857488d0cf4 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -836,10 +836,19 @@ func (c *intervalFunctionClass) getFunction(ctx sessionctx.Context, args []Expre } allInt := true + hasNullable := false + // if we have nullable columns in the argument list, we won't do a binary search, instead we will linearly scan the arguments. 
+ // this behavior is in line with MySQL's, see MySQL's source code here: + // https://github.com/mysql/mysql-server/blob/f8cdce86448a211511e8a039c62580ae16cb96f5/sql/item_cmpfunc.cc#L2713-L2788 + // https://github.com/mysql/mysql-server/blob/f8cdce86448a211511e8a039c62580ae16cb96f5/sql/item_cmpfunc.cc#L2632-L2686 for i := range args { - if args[i].GetType().EvalType() != types.ETInt { + tp := args[i].GetType() + if tp.EvalType() != types.ETInt { allInt = false } + if !mysql.HasNotNullFlag(tp.Flag) { + hasNullable = true + } } argTps, argTp := make([]types.EvalType, 0, len(args)), types.ETReal @@ -855,10 +864,10 @@ func (c *intervalFunctionClass) getFunction(ctx sessionctx.Context, args []Expre } var sig builtinFunc if allInt { - sig = &builtinIntervalIntSig{bf} + sig = &builtinIntervalIntSig{bf, hasNullable} sig.setPbCode(tipb.ScalarFuncSig_IntervalInt) } else { - sig = &builtinIntervalRealSig{bf} + sig = &builtinIntervalRealSig{bf, hasNullable} sig.setPbCode(tipb.ScalarFuncSig_IntervalReal) } return sig, nil @@ -866,6 +875,7 @@ func (c *intervalFunctionClass) getFunction(ctx sessionctx.Context, args []Expre type builtinIntervalIntSig struct { baseBuiltinFunc + hasNullable bool } func (b *builtinIntervalIntSig) Clone() builtinFunc { @@ -877,17 +887,52 @@ func (b *builtinIntervalIntSig) Clone() builtinFunc { // evalInt evals a builtinIntervalIntSig. 
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_interval func (b *builtinIntervalIntSig) evalInt(row chunk.Row) (int64, bool, error) { - args0, isNull, err := b.args[0].EvalInt(b.ctx, row) + arg0, isNull, err := b.args[0].EvalInt(b.ctx, row) if err != nil { return 0, true, err } if isNull { return -1, false, nil } - idx, err := b.binSearch(args0, mysql.HasUnsignedFlag(b.args[0].GetType().Flag), b.args[1:], row) + isUint1 := mysql.HasUnsignedFlag(b.args[0].GetType().Flag) + var idx int + if b.hasNullable { + idx, err = b.linearSearch(arg0, isUint1, b.args[1:], row) + } else { + idx, err = b.binSearch(arg0, isUint1, b.args[1:], row) + } return int64(idx), err != nil, err } +// linearSearch linearly scans the argument list to find the position of the first value that is larger than the given target. +func (b *builtinIntervalIntSig) linearSearch(target int64, isUint1 bool, args []Expression, row chunk.Row) (i int, err error) { + i = 0 + for ; i < len(args); i++ { + isUint2 := mysql.HasUnsignedFlag(args[i].GetType().Flag) + arg, isNull, err := args[i].EvalInt(b.ctx, row) + if err != nil { + return 0, err + } + var less bool + if !isNull { + switch { + case !isUint1 && !isUint2: + less = target < arg + case isUint1 && isUint2: + less = uint64(target) < uint64(arg) + case !isUint1 && isUint2: + less = target < 0 || uint64(target) < uint64(arg) + case isUint1 && !isUint2: + less = arg > 0 && uint64(target) < uint64(arg) + } + } + if less { + break + } + } + return i, nil +} + // binSearch is a binary search method. // All arguments are treated as integers. // It is required that arg[0] < args[1] < args[2] < ... < args[n] for this function to work correctly. 
@@ -926,28 +971,49 @@ func (b *builtinIntervalIntSig) binSearch(target int64, isUint1 bool, args []Exp type builtinIntervalRealSig struct { baseBuiltinFunc + hasNullable bool } func (b *builtinIntervalRealSig) Clone() builtinFunc { newSig := &builtinIntervalRealSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.hasNullable = b.hasNullable return newSig } // evalInt evals a builtinIntervalRealSig. // See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_interval func (b *builtinIntervalRealSig) evalInt(row chunk.Row) (int64, bool, error) { - args0, isNull, err := b.args[0].EvalReal(b.ctx, row) + arg0, isNull, err := b.args[0].EvalReal(b.ctx, row) if err != nil { return 0, true, err } if isNull { return -1, false, nil } - idx, err := b.binSearch(args0, b.args[1:], row) + var idx int + if b.hasNullable { + idx, err = b.linearSearch(arg0, b.args[1:], row) + } else { + idx, err = b.binSearch(arg0, b.args[1:], row) + } return int64(idx), err != nil, err } +func (b *builtinIntervalRealSig) linearSearch(target float64, args []Expression, row chunk.Row) (i int, err error) { + i = 0 + for ; i < len(args); i++ { + arg, isNull, err := args[i].EvalReal(b.ctx, row) + if err != nil { + return 0, err + } + if !isNull && target < arg { + break + } + } + return i, nil +} + func (b *builtinIntervalRealSig) binSearch(target float64, args []Expression, row chunk.Row) (_ int, err error) { i, j := 0, len(args) for i < j { diff --git a/expression/builtin_compare_test.go b/expression/builtin_compare_test.go index cf9ec1a565a9a..5acfad1a950c8 100644 --- a/expression/builtin_compare_test.go +++ b/expression/builtin_compare_test.go @@ -233,6 +233,10 @@ func (s *testEvaluatorSuite) TestIntervalFunc(c *C) { {types.MakeDatums(uint64(9223372036854775806), -9223372036854775807), 1, false}, {types.MakeDatums("9007199254740991", "9007199254740992"), 0, false}, {types.MakeDatums(1, uint32(1), uint32(1)), 0, true}, + {types.MakeDatums(-1, 2333, nil), 0, false}, + 
{types.MakeDatums(1, nil, nil, nil), 3, false}, + {types.MakeDatums(1, nil, nil, nil, 2), 3, false}, + {types.MakeDatums(uint64(9223372036854775808), nil, nil, nil, 4), 4, false}, // tests for appropriate precision loss {types.MakeDatums(9007199254740992, "9007199254740993"), 1, false}, diff --git a/expression/builtin_compare_vec.go b/expression/builtin_compare_vec.go index dc440e27b6207..d13b7fb07a988 100644 --- a/expression/builtin_compare_vec.go +++ b/expression/builtin_compare_vec.go @@ -434,7 +434,11 @@ func (b *builtinIntervalIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Col i64s[i] = -1 continue } - idx, err = b.binSearch(v, mysql.HasUnsignedFlag(b.args[0].GetType().Flag), b.args[1:], input.GetRow(i)) + if b.hasNullable { + idx, err = b.linearSearch(v, mysql.HasUnsignedFlag(b.args[0].GetType().Flag), b.args[1:], input.GetRow(i)) + } else { + idx, err = b.binSearch(v, mysql.HasUnsignedFlag(b.args[0].GetType().Flag), b.args[1:], input.GetRow(i)) + } if err != nil { return err } @@ -467,7 +471,11 @@ func (b *builtinIntervalRealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Co res[i] = -1 continue } - idx, err = b.binSearch(f64s[i], b.args[1:], input.GetRow(i)) + if b.hasNullable { + idx, err = b.linearSearch(f64s[i], b.args[1:], input.GetRow(i)) + } else { + idx, err = b.binSearch(f64s[i], b.args[1:], input.GetRow(i)) + } if err != nil { return err } diff --git a/expression/distsql_builtin.go b/expression/distsql_builtin.go index 63e31914854be..3ded2d82c9c21 100644 --- a/expression/distsql_builtin.go +++ b/expression/distsql_builtin.go @@ -224,9 +224,9 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti case tipb.ScalarFuncSig_LeastTime: f = &builtinLeastTimeSig{base} case tipb.ScalarFuncSig_IntervalInt: - f = &builtinIntervalIntSig{base} + f = &builtinIntervalIntSig{base, false} // Since interval function won't be pushed down to TiKV, therefore it doesn't matter what value we give to hasNullable case 
tipb.ScalarFuncSig_IntervalReal: - f = &builtinIntervalRealSig{base} + f = &builtinIntervalRealSig{base, false} case tipb.ScalarFuncSig_GEInt: f = &builtinGEIntSig{base} case tipb.ScalarFuncSig_GEReal: diff --git a/expression/integration_test.go b/expression/integration_test.go index abd0d5edee9d1..b9e48c360de0f 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -7036,6 +7036,18 @@ func (s *testIntegrationSuite) TestIssue18515(c *C) { tk.MustExec("select /*+ TIDB_INLJ(t2) */ t1.a, t1.c, t2.a from t t1, t t2 where t1.c=t2.c;") } +func (s *testIntegrationSuite) TestIssue18525(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1 (col0 BLOB, col1 CHAR(74), col2 DATE UNIQUE)") + tk.MustExec("insert into t1 values ('l', '7a34bc7d-6786-461b-92d3-fd0a6cd88f39', '1000-01-03')") + tk.MustExec("insert into t1 values ('l', NULL, '1000-01-04')") + tk.MustExec("insert into t1 values ('b', NULL, '1000-01-02')") + tk.MustQuery("select INTERVAL( ( CONVERT( -11752 USING utf8 ) ), 6558853612195285496, `col1`) from t1").Check(testkit.Rows("0", "0", "0")) + +} + func (s *testIntegrationSerialSuite) TestIssue17989(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test")