Skip to content

Commit

Permalink
expression: add linear search for the interval function (pingcap#19543)…
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-srebot authored Sep 8, 2020
1 parent 3dfed22 commit da76c5c
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 7 deletions.
80 changes: 73 additions & 7 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -821,10 +821,19 @@ func (c *intervalFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
}

allInt := true
hasNullable := false
// if we have nullable columns in the argument list, we won't do a binary search, instead we will linearly scan the arguments.
// this behavior is in line with MySQL's, see MySQL's source code here:
// https://github.com/mysql/mysql-server/blob/f8cdce86448a211511e8a039c62580ae16cb96f5/sql/item_cmpfunc.cc#L2713-L2788
// https://github.com/mysql/mysql-server/blob/f8cdce86448a211511e8a039c62580ae16cb96f5/sql/item_cmpfunc.cc#L2632-L2686
for i := range args {
if args[i].GetType().EvalType() != types.ETInt {
tp := args[i].GetType()
if tp.EvalType() != types.ETInt {
allInt = false
}
if !mysql.HasNotNullFlag(tp.Flag) {
hasNullable = true
}
}

argTps, argTp := make([]types.EvalType, 0, len(args)), types.ETReal
Expand All @@ -837,15 +846,16 @@ func (c *intervalFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, argTps...)
var sig builtinFunc
if allInt {
sig = &builtinIntervalIntSig{bf}
sig = &builtinIntervalIntSig{bf, hasNullable}
} else {
sig = &builtinIntervalRealSig{bf}
sig = &builtinIntervalRealSig{bf, hasNullable}
}
return sig, nil
}

// builtinIntervalIntSig is the INTERVAL() signature used when every argument
// evaluates as an integer.
type builtinIntervalIntSig struct {
baseBuiltinFunc
// hasNullable is true when at least one argument lacks the NOT NULL flag.
// evalInt then scans the arguments linearly instead of binary searching them.
hasNullable bool
}

func (b *builtinIntervalIntSig) Clone() builtinFunc {
Expand All @@ -857,17 +867,52 @@ func (b *builtinIntervalIntSig) Clone() builtinFunc {
// evalInt evals a builtinIntervalIntSig.
// It evaluates the first argument as an integer and searches the remaining
// arguments for its position. A NULL first argument yields -1 with the second
// (isNull) result set to false. On error the second result is true.
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_interval
func (b *builtinIntervalIntSig) evalInt(row chunk.Row) (int64, bool, error) {
args0, isNull, err := b.args[0].EvalInt(b.ctx, row)
arg0, isNull, err := b.args[0].EvalInt(b.ctx, row)
if err != nil {
return 0, true, err
}
// A NULL first argument is reported as -1, not as a NULL result.
if isNull {
return -1, false, nil
}
idx, err := b.binSearch(args0, mysql.HasUnsignedFlag(b.args[0].GetType().Flag), b.args[1:], row)
// isUint1 records whether the first argument must be compared as unsigned.
isUint1 := mysql.HasUnsignedFlag(b.args[0].GetType().Flag)
var idx int
// When any argument may be NULL, scan linearly; otherwise binary search.
if b.hasNullable {
idx, err = b.linearSearch(arg0, isUint1, b.args[1:], row)
} else {
idx, err = b.binSearch(arg0, isUint1, b.args[1:], row)
}
return int64(idx), err != nil, err
}

// linearSearch linearly scans the argument list and returns the position of
// the first value that is strictly larger than the given target. When no
// argument is larger than target it returns len(args).
//
// isUint1 tells whether target must be interpreted as an unsigned integer; the
// signedness of each argument is read from its own field type. A NULL argument
// is never considered larger than target, so the scan simply moves past it.
// If evaluating an argument fails, the error is returned with position 0.
func (b *builtinIntervalIntSig) linearSearch(target int64, isUint1 bool, args []Expression, row chunk.Row) (int, error) {
	for i, expr := range args {
		isUint2 := mysql.HasUnsignedFlag(expr.GetType().Flag)
		arg, isNull, err := expr.EvalInt(b.ctx, row)
		if err != nil {
			return 0, err
		}
		if isNull {
			continue
		}
		var less bool
		switch {
		case !isUint1 && !isUint2:
			// Both signed: plain comparison.
			less = target < arg
		case isUint1 && isUint2:
			// Both unsigned: compare the raw bit patterns as uint64.
			less = uint64(target) < uint64(arg)
		case !isUint1 && isUint2:
			// Signed target vs. unsigned arg: a negative target is always smaller.
			less = target < 0 || uint64(target) < uint64(arg)
		case isUint1 && !isUint2:
			// Unsigned target vs. signed arg: a non-positive arg can never be larger.
			less = arg > 0 && uint64(target) < uint64(arg)
		}
		if less {
			return i, nil
		}
	}
	return len(args), nil
}

// binSearch is a binary search method.
// All arguments are treated as integers.
// It is required that arg[0] < args[1] < args[2] < ... < args[n] for this function to work correctly.
Expand Down Expand Up @@ -906,28 +951,49 @@ func (b *builtinIntervalIntSig) binSearch(target int64, isUint1 bool, args []Exp

// builtinIntervalRealSig is the INTERVAL() signature used when at least one
// argument does not evaluate as an integer; arguments are compared as reals.
type builtinIntervalRealSig struct {
baseBuiltinFunc
// hasNullable is true when at least one argument lacks the NOT NULL flag.
// evalInt then scans the arguments linearly instead of binary searching them.
hasNullable bool
}

// Clone returns a copy of the signature: the embedded baseBuiltinFunc is
// duplicated through cloneFrom and the hasNullable flag is carried over.
func (b *builtinIntervalRealSig) Clone() builtinFunc {
	cloned := &builtinIntervalRealSig{hasNullable: b.hasNullable}
	cloned.cloneFrom(&b.baseBuiltinFunc)
	return cloned
}

// evalInt evals a builtinIntervalRealSig.
// It evaluates the first argument as a real and searches the remaining
// arguments for its position. A NULL first argument yields -1 with the second
// (isNull) result set to false. On error the second result is true.
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_interval
func (b *builtinIntervalRealSig) evalInt(row chunk.Row) (int64, bool, error) {
args0, isNull, err := b.args[0].EvalReal(b.ctx, row)
arg0, isNull, err := b.args[0].EvalReal(b.ctx, row)
if err != nil {
return 0, true, err
}
// A NULL first argument is reported as -1, not as a NULL result.
if isNull {
return -1, false, nil
}
idx, err := b.binSearch(args0, b.args[1:], row)
var idx int
// When any argument may be NULL, scan linearly; otherwise binary search.
if b.hasNullable {
idx, err = b.linearSearch(arg0, b.args[1:], row)
} else {
idx, err = b.binSearch(arg0, b.args[1:], row)
}
return int64(idx), err != nil, err
}

// linearSearch linearly scans the argument list and returns the position of
// the first non-NULL value that is strictly larger than the given target.
// When no argument is larger than target it returns len(args). If evaluating
// an argument fails, the error is returned with position 0.
func (b *builtinIntervalRealSig) linearSearch(target float64, args []Expression, row chunk.Row) (int, error) {
	for i, expr := range args {
		arg, isNull, err := expr.EvalReal(b.ctx, row)
		if err != nil {
			return 0, err
		}
		// NULL arguments are never larger than target; keep scanning.
		if !isNull && target < arg {
			return i, nil
		}
	}
	return len(args), nil
}

func (b *builtinIntervalRealSig) binSearch(target float64, args []Expression, row chunk.Row) (_ int, err error) {
i, j := 0, len(args)
for i < j {
Expand Down
4 changes: 4 additions & 0 deletions expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ func (s *testEvaluatorSuite) TestIntervalFunc(c *C) {
{types.MakeDatums(uint64(9223372036854775806), -9223372036854775807), 1, false},
{types.MakeDatums("9007199254740991", "9007199254740992"), 0, false},
{types.MakeDatums(1, uint32(1), uint32(1)), 0, true},
{types.MakeDatums(-1, 2333, nil), 0, false},
{types.MakeDatums(1, nil, nil, nil), 3, false},
{types.MakeDatums(1, nil, nil, nil, 2), 3, false},
{types.MakeDatums(uint64(9223372036854775808), nil, nil, nil, 4), 4, false},

// tests for appropriate precision loss
{types.MakeDatums(9007199254740992, "9007199254740993"), 1, false},
Expand Down

0 comments on commit da76c5c

Please sign in to comment.