Skip to content

Commit

Permalink
cherry pick pingcap#19543 to release-4.0
Browse files Browse the repository at this point in the history
Signed-off-by: ti-srebot <[email protected]>
  • Loading branch information
ichn-hu authored and ti-srebot committed Sep 4, 2020
1 parent a5922c2 commit 4c7c04f
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 11 deletions.
80 changes: 73 additions & 7 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -836,10 +836,19 @@ func (c *intervalFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
}

allInt := true
hasNullable := false
// if we have nullable columns in the argument list, we won't do a binary search, instead we will linearly scan the arguments.
// this behavior is in line with MySQL's, see MySQL's source code here:
// https://github.com/mysql/mysql-server/blob/f8cdce86448a211511e8a039c62580ae16cb96f5/sql/item_cmpfunc.cc#L2713-L2788
// https://github.com/mysql/mysql-server/blob/f8cdce86448a211511e8a039c62580ae16cb96f5/sql/item_cmpfunc.cc#L2632-L2686
for i := range args {
if args[i].GetType().EvalType() != types.ETInt {
tp := args[i].GetType()
if tp.EvalType() != types.ETInt {
allInt = false
}
if !mysql.HasNotNullFlag(tp.Flag) {
hasNullable = true
}
}

argTps, argTp := make([]types.EvalType, 0, len(args)), types.ETReal
Expand All @@ -855,17 +864,18 @@ func (c *intervalFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
}
var sig builtinFunc
if allInt {
sig = &builtinIntervalIntSig{bf}
sig = &builtinIntervalIntSig{bf, hasNullable}
sig.setPbCode(tipb.ScalarFuncSig_IntervalInt)
} else {
sig = &builtinIntervalRealSig{bf}
sig = &builtinIntervalRealSig{bf, hasNullable}
sig.setPbCode(tipb.ScalarFuncSig_IntervalReal)
}
return sig, nil
}

type builtinIntervalIntSig struct {
baseBuiltinFunc
hasNullable bool
}

func (b *builtinIntervalIntSig) Clone() builtinFunc {
Expand All @@ -877,17 +887,52 @@ func (b *builtinIntervalIntSig) Clone() builtinFunc {
// evalInt evals a builtinIntervalIntSig.
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_interval
func (b *builtinIntervalIntSig) evalInt(row chunk.Row) (int64, bool, error) {
args0, isNull, err := b.args[0].EvalInt(b.ctx, row)
arg0, isNull, err := b.args[0].EvalInt(b.ctx, row)
if err != nil {
return 0, true, err
}
if isNull {
return -1, false, nil
}
idx, err := b.binSearch(args0, mysql.HasUnsignedFlag(b.args[0].GetType().Flag), b.args[1:], row)
isUint1 := mysql.HasUnsignedFlag(b.args[0].GetType().Flag)
var idx int
if b.hasNullable {
idx, err = b.linearSearch(arg0, isUint1, b.args[1:], row)
} else {
idx, err = b.binSearch(arg0, isUint1, b.args[1:], row)
}
return int64(idx), err != nil, err
}

// linearSearch linearly scans the argument least to find the position of the first value that is larger than the given target.
func (b *builtinIntervalIntSig) linearSearch(target int64, isUint1 bool, args []Expression, row chunk.Row) (i int, err error) {
i = 0
for ; i < len(args); i++ {
isUint2 := mysql.HasUnsignedFlag(args[i].GetType().Flag)
arg, isNull, err := args[i].EvalInt(b.ctx, row)
if err != nil {
return 0, err
}
var less bool
if !isNull {
switch {
case !isUint1 && !isUint2:
less = target < arg
case isUint1 && isUint2:
less = uint64(target) < uint64(arg)
case !isUint1 && isUint2:
less = target < 0 || uint64(target) < uint64(arg)
case isUint1 && !isUint2:
less = arg > 0 && uint64(target) < uint64(arg)
}
}
if less {
break
}
}
return i, nil
}

// binSearch is a binary search method.
// All arguments are treated as integers.
// It is required that arg[0] < args[1] < args[2] < ... < args[n] for this function to work correctly.
Expand Down Expand Up @@ -926,28 +971,49 @@ func (b *builtinIntervalIntSig) binSearch(target int64, isUint1 bool, args []Exp

type builtinIntervalRealSig struct {
baseBuiltinFunc
hasNullable bool
}

func (b *builtinIntervalRealSig) Clone() builtinFunc {
newSig := &builtinIntervalRealSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.hasNullable = b.hasNullable
return newSig
}

// evalInt evals a builtinIntervalRealSig.
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_interval
func (b *builtinIntervalRealSig) evalInt(row chunk.Row) (int64, bool, error) {
args0, isNull, err := b.args[0].EvalReal(b.ctx, row)
arg0, isNull, err := b.args[0].EvalReal(b.ctx, row)
if err != nil {
return 0, true, err
}
if isNull {
return -1, false, nil
}
idx, err := b.binSearch(args0, b.args[1:], row)
var idx int
if b.hasNullable {
idx, err = b.linearSearch(arg0, b.args[1:], row)
} else {
idx, err = b.binSearch(arg0, b.args[1:], row)
}
return int64(idx), err != nil, err
}

func (b *builtinIntervalRealSig) linearSearch(target float64, args []Expression, row chunk.Row) (i int, err error) {
i = 0
for ; i < len(args); i++ {
arg, isNull, err := args[i].EvalReal(b.ctx, row)
if err != nil {
return 0, err
}
if !isNull && target < arg {
break
}
}
return i, nil
}

func (b *builtinIntervalRealSig) binSearch(target float64, args []Expression, row chunk.Row) (_ int, err error) {
i, j := 0, len(args)
for i < j {
Expand Down
4 changes: 4 additions & 0 deletions expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ func (s *testEvaluatorSuite) TestIntervalFunc(c *C) {
{types.MakeDatums(uint64(9223372036854775806), -9223372036854775807), 1, false},
{types.MakeDatums("9007199254740991", "9007199254740992"), 0, false},
{types.MakeDatums(1, uint32(1), uint32(1)), 0, true},
{types.MakeDatums(-1, 2333, nil), 0, false},
{types.MakeDatums(1, nil, nil, nil), 3, false},
{types.MakeDatums(1, nil, nil, nil, 2), 3, false},
{types.MakeDatums(uint64(9223372036854775808), nil, nil, nil, 4), 4, false},

// tests for appropriate precision loss
{types.MakeDatums(9007199254740992, "9007199254740993"), 1, false},
Expand Down
12 changes: 10 additions & 2 deletions expression/builtin_compare_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,11 @@ func (b *builtinIntervalIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Col
i64s[i] = -1
continue
}
idx, err = b.binSearch(v, mysql.HasUnsignedFlag(b.args[0].GetType().Flag), b.args[1:], input.GetRow(i))
if b.hasNullable {
idx, err = b.linearSearch(v, mysql.HasUnsignedFlag(b.args[0].GetType().Flag), b.args[1:], input.GetRow(i))
} else {
idx, err = b.binSearch(v, mysql.HasUnsignedFlag(b.args[0].GetType().Flag), b.args[1:], input.GetRow(i))
}
if err != nil {
return err
}
Expand Down Expand Up @@ -467,7 +471,11 @@ func (b *builtinIntervalRealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Co
res[i] = -1
continue
}
idx, err = b.binSearch(f64s[i], b.args[1:], input.GetRow(i))
if b.hasNullable {
idx, err = b.linearSearch(f64s[i], b.args[1:], input.GetRow(i))
} else {
idx, err = b.binSearch(f64s[i], b.args[1:], input.GetRow(i))
}
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions expression/distsql_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,9 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti
case tipb.ScalarFuncSig_LeastTime:
f = &builtinLeastTimeSig{base}
case tipb.ScalarFuncSig_IntervalInt:
f = &builtinIntervalIntSig{base}
f = &builtinIntervalIntSig{base, false} // Since interval function won't be pushed down to TiKV, therefore it doesn't matter what value we give to hasNullable
case tipb.ScalarFuncSig_IntervalReal:
f = &builtinIntervalRealSig{base}
f = &builtinIntervalRealSig{base, false}
case tipb.ScalarFuncSig_GEInt:
f = &builtinGEIntSig{base}
case tipb.ScalarFuncSig_GEReal:
Expand Down
12 changes: 12 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6626,6 +6626,18 @@ func (s *testIntegrationSuite) TestIssue18515(c *C) {
tk.MustExec("select /*+ TIDB_INLJ(t2) */ t1.a, t1.c, t2.a from t t1, t t2 where t1.c=t2.c;")
}

func (s *testIntegrationSuite) TestIssue18525(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t1")
tk.MustExec("create table t1 (col0 BLOB, col1 CHAR(74), col2 DATE UNIQUE)")
tk.MustExec("insert into t1 values ('l', '7a34bc7d-6786-461b-92d3-fd0a6cd88f39', '1000-01-03')")
tk.MustExec("insert into t1 values ('l', NULL, '1000-01-04')")
tk.MustExec("insert into t1 values ('b', NULL, '1000-01-02')")
tk.MustQuery("select INTERVAL( ( CONVERT( -11752 USING utf8 ) ), 6558853612195285496, `col1`) from t1").Check(testkit.Rows("0", "0", "0"))

}

func (s *testIntegrationSerialSuite) TestIssue17989(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down

0 comments on commit 4c7c04f

Please sign in to comment.