From 766be2587059c54228881138bee03fa8d06bc38a Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Wed, 12 Feb 2025 16:52:45 +0100 Subject: [PATCH] Implement temporal comparisons This is currently missing and leads to incorrect types to be returned for `LEAST` & `GREATEST` as comparison functions. There's a little mismatch here in behavior compared to MySQL which I argue is actually a bug in MySQL. In MySQL, a temporal type always has the binary collation: ``` mysql> select NOW(6), collation(NOW(6)); +----------------------------+-------------------+ | NOW(6) | collation(NOW(6)) | +----------------------------+-------------------+ | 2025-02-19 15:33:21.732301 | binary | +----------------------------+-------------------+ 1 row in set (0.00 sec) ``` On MySQL 8.4, this results in: ``` mysql> select GREATEST(NOW(6), NOW(6)), collation(GREATEST(NOW(6), NOW(6))); +----------------------------+-------------------------------------+ | GREATEST(NOW(6), NOW(6)) | collation(GREATEST(NOW(6), NOW(6))) | +----------------------------+-------------------------------------+ | 2025-02-19 15:35:00.921308 | latin1_swedish_ci | +----------------------------+-------------------------------------+ 1 row in set (0.00 sec) ``` But on MySQL 8.0, it returns: ``` mysql> select GREATEST(NOW(6), NOW(6)), collation(GREATEST(NOW(6), NOW(6))); +----------------------------+-------------------------------------+ | GREATEST(NOW(6), NOW(6)) | collation(GREATEST(NOW(6), NOW(6))) | +----------------------------+-------------------------------------+ | 2025-02-19 15:35:00.921308 | utf8mb4_0900_ai_ci | +----------------------------+-------------------------------------+ 1 row in set (0.00 sec) ``` Neither of these collations make sense, because it really should not change the collation and return `binary` still. That is what Vitess still does with the changes here (hence the addition to the test framework to allow skipping the collation check). I'll also report the issue upstream to make it behave correctly there as well. Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler.go | 35 +- go/vt/vtgate/evalengine/compiler_asm.go | 53 +- go/vt/vtgate/evalengine/compiler_asm_push.go | 29 + go/vt/vtgate/evalengine/compiler_test.go | 2 +- go/vt/vtgate/evalengine/eval_temporal.go | 99 ++- go/vt/vtgate/evalengine/expr_bvar.go | 4 +- go/vt/vtgate/evalengine/expr_collate.go | 2 +- go/vt/vtgate/evalengine/expr_column.go | 4 +- go/vt/vtgate/evalengine/fn_compare.go | 284 +++++++- go/vt/vtgate/evalengine/fn_compare_test.go | 63 ++ go/vt/vtgate/evalengine/fn_time.go | 48 +- .../evalengine/integration/comparison_test.go | 12 +- go/vt/vtgate/evalengine/testcases/cases.go | 656 +++++++++--------- go/vt/vtgate/evalengine/testcases/helpers.go | 2 +- 14 files changed, 909 insertions(+), 384 deletions(-) create mode 100644 go/vt/vtgate/evalengine/fn_compare_test.go diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index b0a7edd285d..c69df3a300f 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -335,7 +335,7 @@ func (c *compiler) compileToDecimal(ct ctype, offset int) ctype { c.asm.Convert_id(offset) case sqltypes.Uint64: c.asm.Convert_ud(offset) - case sqltypes.Datetime, sqltypes.Time: + case sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: scale = ct.Size size = ct.Size + decimalSizeBase fallthrough @@ -345,6 +345,28 @@ func (c *compiler) compileToDecimal(ct ctype, offset int) ctype { return ctype{Type: sqltypes.Decimal, Flag: ct.Flag, Col: collationNumeric, Scale: scale, Size: size} } +func (c *compiler) compileToTemporal(doct ctype, typ sqltypes.Type, offset, prec int) ctype { + switch doct.Type { + case typ: + if int(doct.Size) == prec { + return doct + } + fallthrough + default: + switch typ { + case sqltypes.Date: + c.asm.Convert_xD(offset, c.sqlmode.AllowZeroDate()) + case sqltypes.Datetime: + c.asm.Convert_xDT(offset, prec, c.sqlmode.AllowZeroDate()) + case sqltypes.Timestamp: + c.asm.Convert_xDTs(offset, prec, c.sqlmode.AllowZeroDate()) + case sqltypes.Time: + c.asm.Convert_xT(offset, prec) + } + } + return ctype{Type: typ, Col: collationBinary, Flag: flagNullable} +} + func (c *compiler) compileToDate(doct ctype, offset int) ctype { switch doct.Type { case sqltypes.Date: @@ -366,6 +388,17 @@ func (c *compiler) compileToDateTime(doct ctype, offset, prec int) ctype { return ctype{Type: sqltypes.Datetime, Size: int32(prec), Col: collationBinary, Flag: flagNullable} } +func (c *compiler) compileToTimestamp(doct ctype, offset, prec int) ctype { + switch doct.Type { + case sqltypes.Timestamp: + c.asm.Convert_tp(offset, prec) + return doct + default: + c.asm.Convert_xDTs(offset, prec, c.sqlmode.AllowZeroDate()) + } + return ctype{Type: sqltypes.Timestamp, Size: int32(prec), Col: collationBinary, Flag: flagNullable} +} + func (c *compiler) compileToTime(doct ctype, offset, prec int) ctype { switch doct.Type { case sqltypes.Time: diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 7dda215353f..d13d22e76cc 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -767,11 +767,11 @@ func (asm *assembler) CmpDates() { }, "CMP DATE(SP-2), DATE(SP-1)") } -func (asm *assembler) Collate(col collations.ID) { +func (asm *assembler) Collate(col collations.TypedCollation) { asm.emit(func(env *ExpressionEnv) int { a := env.vm.stack[env.vm.sp-1].(*evalBytes) a.tt = int16(sqltypes.VarChar) - a.col.Collation = col + a.col = col return 1 }, "COLLATE VARCHAR(SP-1), %d", col) } @@ -1170,6 +1170,21 @@ func (asm *assembler) Convert_xDT(offset, prec int, allowZero bool) { }, "CONV (SP-%d), DATETIME", offset) } +func (asm *assembler) Convert_xDTs(offset, prec int, allowZero bool) { + asm.emit(func(env *ExpressionEnv) int { + // Need to explicitly check here or we otherwise + // store a nil wrapper in an interface vs. a direct + // nil. + dt := evalToTimestamp(env.vm.stack[env.vm.sp-offset], prec, env.now, allowZero) + if dt == nil { + env.vm.stack[env.vm.sp-offset] = nil + } else { + env.vm.stack[env.vm.sp-offset] = dt + } + return 1 + }, "CONV (SP-%d), TIMESTAMP", offset) +} + func (asm *assembler) Convert_xT(offset, prec int) { asm.emit(func(env *ExpressionEnv) int { t := evalToTime(env.vm.stack[env.vm.sp-offset], prec) @@ -2670,6 +2685,40 @@ func (asm *assembler) Fn_MULTICMP_u(args int, lessThan bool) { }, "FN MULTICMP UINT64(SP-%d)...UINT64(SP-1)", args) } +func (asm *assembler) Fn_MULTICMP_temporal(args int, lessThan bool) { + asm.adjustStack(-(args - 1)) + + asm.emit(func(env *ExpressionEnv) int { + var x *evalTemporal + x, _ = env.vm.stack[env.vm.sp-args].(*evalTemporal) + for sp := env.vm.sp - args + 1; sp < env.vm.sp; sp++ { + if env.vm.stack[sp] == nil { + if lessThan { + x = nil + } + continue + } + y := env.vm.stack[sp].(*evalTemporal) + if lessThan == (y.compare(x) < 0) { + x = y + } + } + env.vm.stack[env.vm.sp-args] = x + env.vm.sp -= args - 1 + return 1 + }, "FN MULTICMP TEMPORAL(SP-%d)...TEMPORAL(SP-1)", args) +} + +func (asm *assembler) Fn_MULTICMP_temporal_fallback(f multiComparisonFunc, args int, cmp, prec int) { + asm.adjustStack(-(args - 1)) + + asm.emit(func(env *ExpressionEnv) int { + env.vm.stack[env.vm.sp-args], env.vm.err = f(env, env.vm.stack[env.vm.sp-args:env.vm.sp], cmp, prec) + env.vm.sp -= args - 1 + return 1 + }, "FN MULTICMP_FALLBACK TEMPORAL(SP-%d)...TEMPORAL(SP-1)", args) +} + func (asm *assembler) Fn_REPEAT(base sqltypes.Type, fallback sqltypes.Type) { asm.adjustStack(-1) diff --git a/go/vt/vtgate/evalengine/compiler_asm_push.go b/go/vt/vtgate/evalengine/compiler_asm_push.go index 8f2b5d9f28b..404c8870f87 100644 --- a/go/vt/vtgate/evalengine/compiler_asm_push.go +++ b/go/vt/vtgate/evalengine/compiler_asm_push.go @@ -362,6 +362,23 @@ func (asm *assembler) PushColumn_datetime(offset int) { }, "PUSH DATETIME(:%d)", offset) } +func push_timestamp(env *ExpressionEnv, raw []byte) int { + env.vm.stack[env.vm.sp], env.vm.err = parseTimestamp(raw) + env.vm.sp++ + return 1 +} + +func (asm *assembler) PushColumn_timestamp(offset int) { + asm.adjustStack(1) + asm.emit(func(env *ExpressionEnv) int { + col := env.Row[offset] + if col.IsNull() { + return push_null(env) + } + return push_timestamp(env, col.Raw()) + }, "PUSH TIMESTAMP(:%d)", offset) +} + func (asm *assembler) PushBVar_datetime(key string) { asm.adjustStack(1) asm.emit(func(env *ExpressionEnv) int { @@ -374,6 +391,18 @@ func (asm *assembler) PushBVar_datetime(key string) { }, "PUSH DATETIME(:%q)", key) } +func (asm *assembler) PushBVar_timestamp(key string) { + asm.adjustStack(1) + asm.emit(func(env *ExpressionEnv) int { + var bvar *querypb.BindVariable + bvar, env.vm.err = env.lookupBindVar(key) + if env.vm.err != nil { + return 0 + } + return push_timestamp(env, bvar.Value) + }, "PUSH TIMESTAMP(:%q)", key) +} + func push_date(env *ExpressionEnv, raw []byte) int { env.vm.stack[env.vm.sp], env.vm.err = parseDate(raw) env.vm.sp++ diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 343bb0cd043..95f17488c9a 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -119,7 +119,7 @@ func TestCompilerReference(t *testing.T) { var supported, total int env := evalengine.EmptyExpressionEnv(venv) - tc.Run(func(query string, row []sqltypes.Value) { + tc.Run(func(query string, row []sqltypes.Value, skipCollationCheck bool) { env.Row = row total++ testCompilerCase(t, query, venv, tc.Schema, env) diff --git a/go/vt/vtgate/evalengine/eval_temporal.go b/go/vt/vtgate/evalengine/eval_temporal.go index d73485441c3..2adab98afc4 100644 --- a/go/vt/vtgate/evalengine/eval_temporal.go +++ b/go/vt/vtgate/evalengine/eval_temporal.go @@ -29,7 +29,7 @@ func (e *evalTemporal) ToRawBytes() []byte { switch e.t { case sqltypes.Date: return e.dt.Date.Format() - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return e.dt.Format(e.prec) case sqltypes.Time: return e.dt.Time.Format(e.prec) @@ -54,7 +54,7 @@ func (e *evalTemporal) toInt64() int64 { switch e.SQLType() { case sqltypes.Date: return e.dt.Date.FormatInt64() - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return e.dt.FormatInt64() case sqltypes.Time: return e.dt.Time.FormatInt64() @@ -67,7 +67,7 @@ func (e *evalTemporal) toFloat() float64 { switch e.SQLType() { case sqltypes.Date: return float64(e.dt.Date.FormatInt64()) - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return e.dt.FormatFloat64() case sqltypes.Time: return e.dt.Time.FormatFloat64() @@ -80,7 +80,7 @@ func (e *evalTemporal) toDecimal() decimal.Decimal { switch e.SQLType() { case sqltypes.Date: return decimal.NewFromInt(e.dt.Date.FormatInt64()) - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return e.dt.FormatDecimal() case sqltypes.Time: return e.dt.Time.FormatDecimal() @@ -93,7 +93,7 @@ func (e *evalTemporal) toJSON() *evalJSON { switch e.SQLType() { case sqltypes.Date: return json.NewDate(hack.String(e.dt.Date.Format())) - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return json.NewDateTime(hack.String(e.dt.Format(datetime.DefaultPrecision))) case sqltypes.Time: return json.NewTime(hack.String(e.dt.Time.Format(datetime.DefaultPrecision))) @@ -104,7 +104,7 @@ func (e *evalTemporal) toJSON() *evalJSON { func (e *evalTemporal) toDateTime(l int, now time.Time) *evalTemporal { switch e.SQLType() { - case sqltypes.Datetime, sqltypes.Date: + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Round(l), prec: uint8(l)} case sqltypes.Time: return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Time.Round(l).ToDateTime(now), prec: uint8(l)} @@ -113,9 +113,23 @@ func (e *evalTemporal) toDateTime(l int, now time.Time) *evalTemporal { } } +func (e *evalTemporal) toTimestamp(l int, now time.Time) *evalTemporal { + switch e.SQLType() { + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: + return &evalTemporal{t: sqltypes.Timestamp, dt: e.dt.Round(l), prec: uint8(l)} + case sqltypes.Time: + return &evalTemporal{t: sqltypes.Timestamp, dt: e.dt.Time.Round(l).ToDateTime(now), prec: uint8(l)} + default: + panic("unreachable") + } +} + func (e *evalTemporal) toTime(l int) *evalTemporal { + if l == -1 { + l = int(e.prec) + } switch e.SQLType() { - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: dt := datetime.DateTime{Time: e.dt.Time.Round(l)} return &evalTemporal{t: sqltypes.Time, dt: dt, prec: uint8(l)} case sqltypes.Date: @@ -130,7 +144,7 @@ func (e *evalTemporal) toTime(l int) *evalTemporal { func (e *evalTemporal) toDate(now time.Time) *evalTemporal { switch e.SQLType() { - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: dt := datetime.DateTime{Date: e.dt.Date} return &evalTemporal{t: sqltypes.Date, dt: dt} case sqltypes.Date: @@ -148,6 +162,13 @@ func (e *evalTemporal) isZero() bool { return e.dt.IsZero() } +func (e *evalTemporal) compare(other *evalTemporal) int { + if other == nil { + return 1 + } + return e.dt.Compare(other.dt) +} + func (e *evalTemporal) addInterval(interval *datetime.Interval, coll collations.ID, now time.Time) eval { var tmp *evalTemporal var ok bool @@ -179,6 +200,13 @@ func newEvalDateTime(dt datetime.DateTime, l int, allowZero bool) *evalTemporal return &evalTemporal{t: sqltypes.Datetime, dt: dt.Round(l), prec: uint8(l)} } +func newEvalTimestamp(dt datetime.DateTime, l int, allowZero bool) *evalTemporal { + if !allowZero && dt.IsZero() { + return nil + } + return &evalTemporal{t: sqltypes.Timestamp, dt: dt.Round(l), prec: uint8(l)} +} + func newEvalDate(d datetime.Date, allowZero bool) *evalTemporal { if !allowZero && d.IsZero() { return nil @@ -210,6 +238,14 @@ func parseDateTime(s []byte) (*evalTemporal, error) { return newEvalDateTime(t, l, true), nil } +func parseTimestamp(s []byte) (*evalTemporal, error) { + t, l, ok := datetime.ParseDateTime(hack.String(s), -1) + if !ok { + return nil, errIncorrectTemporal("TIMESTAMP", s) + } + return newEvalTimestamp(t, l, true), nil +} + func parseTime(s []byte) (*evalTemporal, error) { t, l, state := datetime.ParseTime(hack.String(s), -1) if state != datetime.TimeOK { @@ -387,6 +423,53 @@ func evalToDateTime(e eval, l int, now time.Time, allowZero bool) *evalTemporal return nil } +func evalToTimestamp(e eval, l int, now time.Time, allowZero bool) *evalTemporal { + switch e := e.(type) { + case *evalTemporal: + return e.toTimestamp(precision(l, int(e.prec)), now) + case *evalBytes: + if t, l, _ := datetime.ParseDateTime(e.string(), l); !t.IsZero() { + return newEvalTimestamp(t, l, allowZero) + } + if d, _ := datetime.ParseDate(e.string()); !d.IsZero() { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalInt64: + if t, ok := datetime.ParseDateTimeInt64(e.i); ok { + return newEvalTimestamp(t, precision(l, 0), allowZero) + } + if d, ok := datetime.ParseDateInt64(e.i); ok { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalUint64: + if t, ok := datetime.ParseDateTimeInt64(int64(e.u)); ok { + return newEvalTimestamp(t, precision(l, 0), allowZero) + } + if d, ok := datetime.ParseDateInt64(int64(e.u)); ok { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalFloat: + if t, l, ok := datetime.ParseDateTimeFloat(e.f, l); ok { + return newEvalTimestamp(t, l, allowZero) + } + if d, ok := datetime.ParseDateFloat(e.f); ok { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalDecimal: + if t, l, ok := datetime.ParseDateTimeDecimal(e.dec, e.length, l); ok { + return newEvalTimestamp(t, l, allowZero) + } + if d, ok := datetime.ParseDateDecimal(e.dec); ok { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalJSON: + if dt, ok := e.DateTime(); ok { + return newEvalTimestamp(dt, precision(l, datetime.DefaultPrecision), allowZero) + } + } + return nil +} + func evalToDate(e eval, now time.Time, allowZero bool) *evalTemporal { switch e := e.(type) { case *evalTemporal: diff --git a/go/vt/vtgate/evalengine/expr_bvar.go b/go/vt/vtgate/evalengine/expr_bvar.go index 50f231dbe9c..23b40949e83 100644 --- a/go/vt/vtgate/evalengine/expr_bvar.go +++ b/go/vt/vtgate/evalengine/expr_bvar.go @@ -154,8 +154,10 @@ func (bvar *BindVariable) compile(c *compiler) (ctype, error) { c.asm.PushNull() case tt == sqltypes.TypeJSON: c.asm.PushBVar_json(bvar.Key) - case tt == sqltypes.Datetime || tt == sqltypes.Timestamp: + case tt == sqltypes.Datetime: c.asm.PushBVar_datetime(bvar.Key) + case tt == sqltypes.Timestamp: + c.asm.PushBVar_timestamp(bvar.Key) case tt == sqltypes.Date: c.asm.PushBVar_date(bvar.Key) case tt == sqltypes.Time: diff --git a/go/vt/vtgate/evalengine/expr_collate.go b/go/vt/vtgate/evalengine/expr_collate.go index be0eb78882b..b381acf6356 100644 --- a/go/vt/vtgate/evalengine/expr_collate.go +++ b/go/vt/vtgate/evalengine/expr_collate.go @@ -118,7 +118,7 @@ func (expr *CollateExpr) compile(c *compiler) (ctype, error) { } fallthrough case sqltypes.VarBinary: - c.asm.Collate(expr.TypedCollation.Collation) + c.asm.Collate(expr.TypedCollation) default: c.asm.Convert_xc(1, sqltypes.VarChar, expr.TypedCollation.Collation, nil) } diff --git a/go/vt/vtgate/evalengine/expr_column.go b/go/vt/vtgate/evalengine/expr_column.go index e52c522d973..7df113ee5d2 100644 --- a/go/vt/vtgate/evalengine/expr_column.go +++ b/go/vt/vtgate/evalengine/expr_column.go @@ -145,8 +145,10 @@ func (column *Column) compile(c *compiler) (ctype, error) { c.asm.PushNull() case tt == sqltypes.TypeJSON: c.asm.PushColumn_json(column.Offset) - case tt == sqltypes.Datetime || tt == sqltypes.Timestamp: + case tt == sqltypes.Datetime: c.asm.PushColumn_datetime(column.Offset) + case tt == sqltypes.Timestamp: + c.asm.PushColumn_timestamp(column.Offset) case tt == sqltypes.Date: c.asm.PushColumn_date(column.Offset) case tt == sqltypes.Time: diff --git a/go/vt/vtgate/evalengine/fn_compare.go b/go/vt/vtgate/evalengine/fn_compare.go index 1deec6752ef..1084a240bd8 100644 --- a/go/vt/vtgate/evalengine/fn_compare.go +++ b/go/vt/vtgate/evalengine/fn_compare.go @@ -22,6 +22,7 @@ import ( "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/collations/charset" "vitess.io/vitess/go/mysql/collations/colldata" + datetime2 "vitess.io/vitess/go/mysql/datetime" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -32,11 +33,12 @@ type ( CallExpr } - multiComparisonFunc func(collationEnv *collations.Environment, args []eval, cmp int) (eval, error) + multiComparisonFunc func(env *ExpressionEnv, args []eval, cmp, prec int) (eval, error) builtinMultiComparison struct { CallExpr - cmp int + cmp int + prec int } ) @@ -93,7 +95,7 @@ func (b *builtinCoalesce) compile(c *compiler) (ctype, error) { return ctype{Type: ta.result(), Flag: f, Col: ca.result()}, nil } -func getMultiComparisonFunc(args []eval) multiComparisonFunc { +func (call *builtinMultiComparison) getMultiComparisonFunc(args []eval) multiComparisonFunc { var ( integersI int integersU int @@ -101,6 +103,11 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc { decimals int text int binary int + temporal int + datetime int + timestamp int + date int + time int ) /* @@ -114,7 +121,7 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc { for _, arg := range args { if arg == nil { - return func(collationEnv *collations.Environment, args []eval, cmp int) (eval, error) { + return func(_ *ExpressionEnv, _ []eval, _, _ int) (eval, error) { return nil, nil } } @@ -126,18 +133,86 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc { integersU++ case *evalFloat: floats++ + call.prec = datetime2.DefaultPrecision case *evalDecimal: decimals++ + call.prec = max(call.prec, int(arg.length)) case *evalBytes: switch arg.SQLType() { case sqltypes.Text, sqltypes.VarChar: text++ + call.prec = max(call.prec, datetime2.DefaultPrecision) case sqltypes.Blob, sqltypes.Binary, sqltypes.VarBinary: binary++ + if !arg.isHexOrBitLiteral() { + call.prec = max(call.prec, datetime2.DefaultPrecision) + } + } + case *evalTemporal: + temporal++ + call.prec = max(call.prec, int(arg.prec)) + switch arg.SQLType() { + case sqltypes.Datetime: + datetime++ + case sqltypes.Timestamp: + timestamp++ + case sqltypes.Date: + date++ + case sqltypes.Time: + time++ } } } + if temporal == len(args) { + switch { + case datetime > 0: + return compareAllTemporal(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToDateTime(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }) + case timestamp > 0: + return compareAllTemporal(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToTimestamp(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }) + case date > 0 && time > 0: + // When all types are temporal, we convert the case + // of having a date and time all to datetime. + // This is contrary to the case where we have a non-temporal + // type in the list, since MySQL doesn't do that. + return compareAllTemporal(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToDateTime(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }) + case date > 0: + return compareAllTemporal(func(env *ExpressionEnv, arg eval, _ int) *evalTemporal { + return evalToDate(arg, env.now, env.sqlmode.AllowZeroDate()) + }) + case time > 0: + return compareAllTemporal(func(_ *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToTime(arg, prec) + }) + } + } + + switch { + case datetime > 0: + return compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToDateTime(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }) + case timestamp > 0: + return compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToTimestamp(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }) + case date > 0: + return compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, _ int) *evalTemporal { + return evalToDate(arg, env.now, env.sqlmode.AllowZeroDate()) + }) + case time > 0: + // So for time, there's actually no conversion and + // internal comparisons as time. So we don't pass it + // a conversion function. + return compareAllTemporalAsString(nil) + } + if integersI+integersU == len(args) { if integersI == len(args) { return compareAllInteger_i @@ -165,7 +240,93 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc { panic("unexpected argument type") } -func compareAllInteger_u(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllTemporal(f func(env *ExpressionEnv, arg eval, prec int) *evalTemporal) multiComparisonFunc { + return func(env *ExpressionEnv, args []eval, cmp, prec int) (eval, error) { + var x *evalTemporal + for _, arg := range args { + conv := f(env, arg, prec) + if x == nil { + x = conv + continue + } + if (cmp < 0) == (conv.compare(x) < 0) { + x = conv + } + } + return x, nil + } +} + +func compareAllTemporalAsString(f func(env *ExpressionEnv, arg eval, prec int) *evalTemporal) multiComparisonFunc { + return func(env *ExpressionEnv, args []eval, cmp, prec int) (eval, error) { + validArgs := make([]*evalTemporal, 0, len(args)) + var ca collationAggregation + for _, arg := range args { + if err := ca.add(evalCollation(arg), env.collationEnv); err != nil { + return nil, err + } + if f != nil { + conv := f(env, arg, prec) + validArgs = append(validArgs, conv) + } + } + tc := ca.result() + if tc.Coercibility == collations.CoerceNumeric { + tc = typedCoercionCollation(sqltypes.VarChar, env.collationEnv.DefaultConnectionCharset()) + } + if f != nil { + idx := compareTemporalInternal(validArgs, cmp) + if idx >= 0 { + arg := args[idx] + if _, ok := arg.(*evalTemporal); ok { + arg = validArgs[idx] + } + return evalToVarchar(arg, tc.Collation, false) + } + } + txt, err := compareAllText(env, args, cmp, prec) + if err != nil { + return nil, err + } + return evalToVarchar(txt, tc.Collation, false) + } +} + +func compareTemporalInternal(args []*evalTemporal, cmp int) int { + if cmp < 0 { + // If we have any failed conversions and want to have the smallest value, + // we can't find that so we return -1 to indicate that. + // This will result in a fallback to do a string comparison. + for _, arg := range args { + if arg == nil { + return -1 + } + } + } + + x := 0 + for i, arg := range args[1:] { + if arg == nil { + continue + } + if (cmp < 0) == (compareTemporal(args, i+1, x) < 0) { + x = i + 1 + } + } + return x +} + +func compareTemporal(args []*evalTemporal, idx1, idx2 int) int { + if idx1 < 0 { + return 1 + } + if idx2 < 0 { + return -1 + } + return args[idx1].compare(args[idx2]) +} + +func compareAllInteger_u(_ *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { x := args[0].(*evalUint64) for _, arg := range args[1:] { y := arg.(*evalUint64) @@ -176,7 +337,7 @@ func compareAllInteger_u(_ *collations.Environment, args []eval, cmp int) (eval, return x, nil } -func compareAllInteger_i(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllInteger_i(_ *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { x := args[0].(*evalInt64) for _, arg := range args[1:] { y := arg.(*evalInt64) @@ -187,7 +348,7 @@ func compareAllInteger_i(_ *collations.Environment, args []eval, cmp int) (eval, return x, nil } -func compareAllFloat(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllFloat(_ *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { candidateF, ok := evalToFloat(args[0]) if !ok { return nil, errDecimalOutOfRange @@ -212,7 +373,7 @@ func evalDecimalPrecision(e eval) int32 { return 0 } -func compareAllDecimal(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllDecimal(_ *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { decExtreme := evalToDecimal(args[0], 0, 0).dec precExtreme := evalDecimalPrecision(args[0]) @@ -229,12 +390,12 @@ func compareAllDecimal(_ *collations.Environment, args []eval, cmp int) (eval, e return newEvalDecimalWithPrec(decExtreme, precExtreme), nil } -func compareAllText(collationEnv *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllText(env *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { var charsets = make([]charset.Charset, 0, len(args)) var ca collationAggregation for _, arg := range args { col := evalCollation(arg) - if err := ca.add(col, collationEnv); err != nil { + if err := ca.add(col, env.collationEnv); err != nil { return nil, err } charsets = append(charsets, colldata.Lookup(col.Collation).Charset()) @@ -262,7 +423,7 @@ func compareAllText(collationEnv *collations.Environment, args []eval, cmp int) return newEvalText(b1, tc), nil } -func compareAllBinary(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllBinary(_ *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { candidateB := args[0].ToRawBytes() for _, arg := range args[1:] { @@ -280,7 +441,7 @@ func (call *builtinMultiComparison) eval(env *ExpressionEnv) (eval, error) { if err != nil { return nil, err } - return getMultiComparisonFunc(args)(env.collationEnv, args, call.cmp) + return call.getMultiComparisonFunc(args)(env, args, call.cmp, call.prec) } func (call *builtinMultiComparison) compile_c(c *compiler, args []ctype) (ctype, error) { @@ -314,14 +475,20 @@ func (call *builtinMultiComparison) compile_d(c *compiler, args []ctype) (ctype, func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { var ( - signed int - unsigned int - floats int - decimals int - text int - binary int - args []ctype - nullable bool + signed int + unsigned int + floats int + decimals int + temporal int + date int + datetime int + timestamp int + time int + text int + binary int + args []ctype + nullable bool + prec int ) /* @@ -349,12 +516,34 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { unsigned++ case sqltypes.Float64: floats++ + prec = max(prec, datetime2.DefaultPrecision) case sqltypes.Decimal: decimals++ + prec = max(prec, int(tt.Scale)) case sqltypes.Text, sqltypes.VarChar: text++ + prec = max(prec, datetime2.DefaultPrecision) case sqltypes.Blob, sqltypes.Binary, sqltypes.VarBinary: binary++ + if !tt.isHexOrBitLiteral() { + prec = max(prec, datetime2.DefaultPrecision) + } + case sqltypes.Date: + temporal++ + date++ + prec = max(prec, int(tt.Size)) + case sqltypes.Datetime: + temporal++ + datetime++ + prec = max(prec, int(tt.Size)) + case sqltypes.Timestamp: + temporal++ + timestamp++ + prec = max(prec, int(tt.Size)) + case sqltypes.Time: + temporal++ + time++ + prec = max(prec, int(tt.Size)) case sqltypes.Null: nullable = true default: @@ -366,6 +555,61 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { if nullable { f |= flagNullable } + if temporal == len(args) { + var typ sqltypes.Type + switch { + case datetime > 0: + typ = sqltypes.Datetime + case timestamp > 0: + typ = sqltypes.Timestamp + case date > 0 && time > 0: + // When all types are temporal, we convert the case + // of having a date and time all to datetime. + // This is contrary to the case where we have a non-temporal + // type in the list, since MySQL doesn't do that. + typ = sqltypes.Datetime + case date > 0: + typ = sqltypes.Date + case time > 0: + typ = sqltypes.Time + } + for i, tt := range args { + if tt.Type != typ || int(tt.Size) != prec { + c.compileToTemporal(tt, typ, len(args)-i, prec) + } + } + c.asm.Fn_MULTICMP_temporal(len(args), call.cmp < 0) + return ctype{Type: typ, Flag: f, Col: collationBinary}, nil + } else if temporal > 0 { + var ca collationAggregation + for _, arg := range args { + if err := ca.add(arg.Col, c.env.CollationEnv()); err != nil { + return ctype{}, err + } + } + + tc := ca.result() + if tc.Coercibility == collations.CoerceNumeric { + tc = typedCoercionCollation(sqltypes.VarChar, c.collation) + } + switch { + case datetime > 0: + c.asm.Fn_MULTICMP_temporal_fallback(compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToDateTime(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }), len(args), call.cmp, prec) + case timestamp > 0: + c.asm.Fn_MULTICMP_temporal_fallback(compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToTimestamp(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }), len(args), call.cmp, prec) + case date > 0: + c.asm.Fn_MULTICMP_temporal_fallback(compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToDate(arg, env.now, env.sqlmode.AllowZeroDate()) + }), len(args), call.cmp, prec) + case time > 0: + c.asm.Fn_MULTICMP_temporal_fallback(compareAllTemporalAsString(nil), len(args), call.cmp, prec) + } + return ctype{Type: sqltypes.VarChar, Flag: f, Col: tc}, nil + } if signed+unsigned == len(args) { if signed == len(args) { c.asm.Fn_MULTICMP_i(len(args), call.cmp < 0) diff --git a/go/vt/vtgate/evalengine/fn_compare_test.go b/go/vt/vtgate/evalengine/fn_compare_test.go new file mode 100644 index 00000000000..305a0b3bc91 --- /dev/null +++ b/go/vt/vtgate/evalengine/fn_compare_test.go @@ -0,0 +1,63 @@ +package evalengine + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "vitess.io/vitess/go/mysql/datetime" +) + +func TestCompareTemporal(t *testing.T) { + tests := []struct { + name string + val1 *evalTemporal + val2 *evalTemporal + result int + }{ + { + name: "equal values", + val1: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + val2: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + result: 0, + }, + { + name: "larger value", + val1: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + val2: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + result: -1, + }, + { + name: "smaller value", + val1: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + val2: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + result: 1, + }, + { + name: "first nil value", + val1: nil, + val2: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + result: 1, + }, + + { + name: "second nil value", + val1: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + val2: nil, + result: -1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + idx1 := 0 + idx2 := 1 + if tt.val1 == nil { + idx1 = -1 + } + if tt.val2 == nil { + idx2 = -1 + } + assert.Equal(t, tt.result, compareTemporal([]*evalTemporal{tt.val1, tt.val2}, idx1, idx2)) + }) + } +} diff --git a/go/vt/vtgate/evalengine/fn_time.go b/go/vt/vtgate/evalengine/fn_time.go index 322b89faafb..c50a7a265f5 100644 --- a/go/vt/vtgate/evalengine/fn_time.go +++ b/go/vt/vtgate/evalengine/fn_time.go @@ -335,7 +335,7 @@ func (call *builtinDateFormat) compile(c *compiler) (ctype, error) { skip1 := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Datetime, sqltypes.Date: + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: default: c.asm.Convert_xDT(1, datetime.DefaultPrecision, false) } @@ -451,7 +451,7 @@ func (call *builtinConvertTz) compile(c *compiler) (ctype, error) { var prec int32 switch n.Type { - case sqltypes.Datetime, sqltypes.Date: + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: prec = n.Size case sqltypes.Decimal: prec = n.Scale @@ -533,7 +533,7 @@ func (call *builtinDayOfMonth) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, true) } @@ -566,7 +566,7 @@ func (call *builtinDayOfWeek) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -599,7 +599,7 @@ func (call *builtinDayOfYear) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -742,7 +742,7 @@ func (call *builtinFromUnixtime) compile(c *compiler) (ctype, error) { case sqltypes.Decimal: prec = arg.Size c.asm.Fn_FROM_UNIXTIME_d() - case sqltypes.Datetime, sqltypes.Date, sqltypes.Time: + case sqltypes.Datetime, sqltypes.Date, sqltypes.Time, sqltypes.Timestamp: prec = arg.Size if prec == 0 { c.asm.Convert_Ti(1) @@ -814,7 +814,7 @@ func (call *builtinHour) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime, sqltypes.Time: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: default: c.asm.Convert_xT(1, -1) } @@ -1160,7 +1160,7 @@ func (call *builtinMicrosecond) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime, sqltypes.Time: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: default: c.asm.Convert_xT(1, -1) } @@ -1193,7 +1193,7 @@ func (call *builtinMinute) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime, sqltypes.Time: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: default: c.asm.Convert_xT(1, -1) } @@ -1226,7 +1226,7 @@ func (call *builtinMonth) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, true) } @@ -1264,7 +1264,7 @@ func (call *builtinMonthName) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -1309,7 +1309,7 @@ func (call *builtinLastDay) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, c.sqlmode.AllowZeroDate()) } @@ -1344,7 +1344,7 @@ func (call *builtinToDays) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -1481,7 +1481,7 @@ func (call *builtinTimeToSec) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime, sqltypes.Time: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: default: c.asm.Convert_xT(1, -1) } @@ -1516,7 +1516,7 @@ func (call *builtinToSeconds) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xDT(1, -1, false) } @@ -1549,7 +1549,7 @@ func (call *builtinQuarter) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, true) } @@ -1582,7 +1582,7 @@ func (call *builtinSecond) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime, sqltypes.Time: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: default: c.asm.Convert_xT(1, -1) } @@ -1617,7 +1617,7 @@ func (call *builtinTime) compile(c *compiler) (ctype, error) { var prec int32 switch arg.Type { case sqltypes.Time: - case sqltypes.Datetime, sqltypes.Date: + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: prec = arg.Size c.asm.Convert_xT(1, -1) case sqltypes.Decimal: @@ -1717,7 +1717,7 @@ func (call *builtinUnixTimestamp) compile(c *compiler) (ctype, error) { c.asm.Fn_UNIX_TIMESTAMP1() c.asm.jumpDestination(skip) switch arg.Type { - case sqltypes.Datetime, sqltypes.Time, sqltypes.Decimal: + case sqltypes.Datetime, sqltypes.Time, sqltypes.Decimal, sqltypes.Timestamp: if arg.Size == 0 { return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: arg.Flag}, nil } @@ -1782,7 +1782,7 @@ func (call *builtinWeek) compile(c *compiler) (ctype, error) { var skip2 *jump switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -1827,7 +1827,7 @@ func (call *builtinWeekDay) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -1863,7 +1863,7 @@ func (call *builtinWeekOfYear) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -1906,7 +1906,7 @@ func (call *builtinYear) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, true) } @@ -1956,7 +1956,7 @@ func (call *builtinYearWeek) compile(c *compiler) (ctype, error) { var skip2 *jump switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } diff --git a/go/vt/vtgate/evalengine/integration/comparison_test.go b/go/vt/vtgate/evalengine/integration/comparison_test.go index d559cb8ab1d..0e15869a125 100644 --- a/go/vt/vtgate/evalengine/integration/comparison_test.go +++ b/go/vt/vtgate/evalengine/integration/comparison_test.go @@ -82,12 +82,12 @@ func normalizeValue(v sqltypes.Value, coll collations.ID) sqltypes.Value { return v } -func compareRemoteExprEnv(t *testing.T, collationEnv *collations.Environment, env *evalengine.ExpressionEnv, conn *mysql.Conn, expr string, fields []*querypb.Field, cmp *testcases.Comparison) { +func compareRemoteExprEnv(t *testing.T, collationEnv *collations.Environment, env *evalengine.ExpressionEnv, conn *mysql.Conn, expr string, fields []*querypb.Field, cmp *testcases.Comparison, skipCollationCheck bool) { t.Helper() localQuery := "SELECT " + expr remoteQuery := "SELECT " + expr - if debugCheckCollations { + if debugCheckCollations && !skipCollationCheck { remoteQuery = fmt.Sprintf("SELECT %s, COLLATION(%s)", expr, expr) } if len(fields) > 0 { @@ -146,7 +146,7 @@ func compareRemoteExprEnv(t *testing.T, collationEnv *collations.Environment, en var localCollation, remoteCollation collations.ID if localErr == nil { v := local.Value(collations.MySQL8().DefaultConnectionCharset()) - if debugCheckCollations { + if debugCheckCollations && !skipCollationCheck { if v.IsNull() { localCollation = collations.CollationBinaryID } else { @@ -166,7 +166,7 @@ func compareRemoteExprEnv(t *testing.T, collationEnv *collations.Environment, en } else { remoteVal = remote.Rows[0][0] } - if debugCheckCollations { + if debugCheckCollations && !skipCollationCheck { if remote.Rows[0][0].IsNull() { // TODO: passthrough proper collations for nullable fields remoteCollation = collations.CollationBinaryID @@ -275,9 +275,9 @@ func TestMySQL(t *testing.T) { Username: "vt_dba", }) env := evalengine.NewExpressionEnv(ctx, nil, &vcursor{env: venv}) - tc.Run(func(query string, row []sqltypes.Value) { + tc.Run(func(query string, row []sqltypes.Value, skipCollationCheck bool) { env.Row = row - compareRemoteExprEnv(t, collationEnv, env, conn, query, tc.Schema, tc.Compare) + compareRemoteExprEnv(t, collationEnv, env, conn, query, tc.Schema, tc.Compare, skipCollationCheck) }) }) } diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index ff6c0c0f311..5469873b10e 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -178,18 +178,18 @@ var Cases = []TestCase{ func JSONPathOperations(yield Query) { for _, obj := range inputJSONObjects { - yield(fmt.Sprintf("JSON_KEYS('%s')", obj), nil) + yield(fmt.Sprintf("JSON_KEYS('%s')", obj), nil, false) for _, path1 := range inputJSONPaths { - yield(fmt.Sprintf("JSON_EXTRACT('%s', '%s')", obj, path1), nil) - yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'one', '%s')", obj, path1), nil) - yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'all', '%s')", obj, path1), nil) - yield(fmt.Sprintf("JSON_KEYS('%s', '%s')", obj, path1), nil) + yield(fmt.Sprintf("JSON_EXTRACT('%s', '%s')", obj, path1), nil, false) + yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'one', '%s')", obj, path1), nil, false) + yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'all', '%s')", obj, path1), nil, false) + yield(fmt.Sprintf("JSON_KEYS('%s', '%s')", obj, path1), nil, false) for _, path2 := range inputJSONPaths { - yield(fmt.Sprintf("JSON_EXTRACT('%s', '%s', '%s')", obj, path1, path2), nil) - yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'one', '%s', '%s')", obj, path1, path2), nil) - yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'all', '%s', '%s')", obj, path1, path2), nil) + yield(fmt.Sprintf("JSON_EXTRACT('%s', '%s', '%s')", obj, path1, path2), nil, false) + yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'one', '%s', '%s')", obj, path1, path2), nil, false) + yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'all', '%s', '%s')", obj, path1, path2), nil, false) } } } @@ -197,21 +197,21 @@ func JSONPathOperations(yield Query) { func JSONArray(yield Query) { for _, a := range inputJSONPrimitives { - yield(fmt.Sprintf("JSON_ARRAY(%s)", a), nil) + yield(fmt.Sprintf("JSON_ARRAY(%s)", a), nil, false) for _, b := range inputJSONPrimitives { - yield(fmt.Sprintf("JSON_ARRAY(%s, %s)", a, b), nil) + yield(fmt.Sprintf("JSON_ARRAY(%s, %s)", a, b), nil, false) } } - yield("JSON_ARRAY()", nil) + yield("JSON_ARRAY()", nil, false) } func JSONObject(yield Query) { for _, a := range inputJSONPrimitives { for _, b := range inputJSONPrimitives { - yield(fmt.Sprintf("JSON_OBJECT(%s, %s)", a, b), nil) + yield(fmt.Sprintf("JSON_OBJECT(%s, %s)", a, b), nil, false) } } - yield("JSON_OBJECT()", nil) + yield("JSON_OBJECT()", nil, false) } func CharsetConversionOperators(yield Query) { @@ -228,7 +228,7 @@ func CharsetConversionOperators(yield Query) { for _, pfx := range introducers { for _, lhs := range contents { for _, rhs := range charsets { - yield(fmt.Sprintf("HEX(CONVERT(%s %s USING %s))", pfx, lhs, rhs), nil) + yield(fmt.Sprintf("HEX(CONVERT(%s %s USING %s))", pfx, lhs, rhs), nil, false) } } } @@ -250,7 +250,7 @@ func CaseExprWithPredicate(yield Query) { for _, pred1 := range predicates { for _, val1 := range elements { for _, elseVal := range elements { - yield(fmt.Sprintf("case when %s then %s else %s end", pred1, val1, elseVal), nil) + yield(fmt.Sprintf("case when %s then %s else %s end", pred1, val1, elseVal), nil, false) } } } @@ -259,7 +259,7 @@ func CaseExprWithPredicate(yield Query) { genSubsets(elements, 3, func(values []string) { yield(fmt.Sprintf("case when %s then %s when %s then %s when %s then %s end", predicates[0], values[0], predicates[1], values[1], predicates[2], values[2], - ), nil) + ), nil, false) }) }) } @@ -279,13 +279,13 @@ func FnCeil(yield Query) { } for _, num := range ceilInputs { - yield(fmt.Sprintf("CEIL(%s)", num), nil) - yield(fmt.Sprintf("CEILING(%s)", num), nil) + yield(fmt.Sprintf("CEIL(%s)", num), nil, false) + yield(fmt.Sprintf("CEILING(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("CEIL(%s)", num), nil) - yield(fmt.Sprintf("CEILING(%s)", num), nil) + yield(fmt.Sprintf("CEIL(%s)", num), nil, false) + yield(fmt.Sprintf("CEILING(%s)", num), nil, false) } } @@ -304,11 +304,11 @@ func FnFloor(yield Query) { } for _, num := range floorInputs { - yield(fmt.Sprintf("FLOOR(%s)", num), nil) + yield(fmt.Sprintf("FLOOR(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("FLOOR(%s)", num), nil) + yield(fmt.Sprintf("FLOOR(%s)", num), nil, false) } } @@ -327,280 +327,280 @@ func FnAbs(yield Query) { } for _, num := range absInputs { - yield(fmt.Sprintf("ABS(%s)", num), nil) + yield(fmt.Sprintf("ABS(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("ABS(%s)", num), nil) + yield(fmt.Sprintf("ABS(%s)", num), nil, false) } } func FnPi(yield Query) { - yield("PI()+0.000000000000000000", nil) + yield("PI()+0.000000000000000000", nil, false) } func FnAcos(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("ACOS(%s)", num), nil) + yield(fmt.Sprintf("ACOS(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("ACOS(%s)", num), nil) + yield(fmt.Sprintf("ACOS(%s)", num), nil, false) } } func FnAsin(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("ASIN(%s)", num), nil) + yield(fmt.Sprintf("ASIN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("ASIN(%s)", num), nil) + yield(fmt.Sprintf("ASIN(%s)", num), nil, false) } } func FnAtan(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("ATAN(%s)", num), nil) + yield(fmt.Sprintf("ATAN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("ATAN(%s)", num), nil) + yield(fmt.Sprintf("ATAN(%s)", num), nil, false) } } func FnAtan2(yield Query) { for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("ATAN(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("ATAN2(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ATAN(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("ATAN2(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range inputBitwise { - yield(fmt.Sprintf("ATAN(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("ATAN2(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ATAN(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("ATAN2(%s, %s)", num1, num2), nil, false) } } } func FnCos(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("COS(%s)", num), nil) + yield(fmt.Sprintf("COS(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("COS(%s)", num), nil) + yield(fmt.Sprintf("COS(%s)", num), nil, false) } } func FnCot(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("COT(%s)", num), nil) + yield(fmt.Sprintf("COT(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("COT(%s)", num), nil) + yield(fmt.Sprintf("COT(%s)", num), nil, false) } } func FnSin(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("SIN(%s)", num), nil) + yield(fmt.Sprintf("SIN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("SIN(%s)", num), nil) + yield(fmt.Sprintf("SIN(%s)", num), nil, false) } } func FnTan(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("TAN(%s)", num), nil) + yield(fmt.Sprintf("TAN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("TAN(%s)", num), nil) + yield(fmt.Sprintf("TAN(%s)", num), nil, false) } } func FnDegrees(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("DEGREES(%s)", num), nil) + yield(fmt.Sprintf("DEGREES(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("DEGREES(%s)", num), nil) + yield(fmt.Sprintf("DEGREES(%s)", num), nil, false) } } func FnRadians(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("RADIANS(%s)", num), nil) + yield(fmt.Sprintf("RADIANS(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("RADIANS(%s)", num), nil) + yield(fmt.Sprintf("RADIANS(%s)", num), nil, false) } } func FnExp(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("EXP(%s)", num), nil) + yield(fmt.Sprintf("EXP(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("EXP(%s)", num), nil) + yield(fmt.Sprintf("EXP(%s)", num), nil, false) } } func FnLn(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("LN(%s)", num), nil) + yield(fmt.Sprintf("LN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("LN(%s)", num), nil) + yield(fmt.Sprintf("LN(%s)", num), nil, false) } } func FnLog(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("LOG(%s)", num), nil) + yield(fmt.Sprintf("LOG(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("LOG(%s)", num), nil) + yield(fmt.Sprintf("LOG(%s)", num), nil, false) } for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range radianInputs { - yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil, false) } } } func FnLog10(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("LOG10(%s)", num), nil) + yield(fmt.Sprintf("LOG10(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("LOG10(%s)", num), nil) + yield(fmt.Sprintf("LOG10(%s)", num), nil, false) } } func FnMod(yield Query) { for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range radianInputs { - yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil, false) } } } func FnLog2(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("LOG2(%s)", num), nil) + yield(fmt.Sprintf("LOG2(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("LOG2(%s)", num), nil) + yield(fmt.Sprintf("LOG2(%s)", num), nil, false) } } func FnPow(yield Query) { for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range radianInputs { - yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil, false) } } } func FnSign(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("SIGN(%s)", num), nil) + yield(fmt.Sprintf("SIGN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("SIGN(%s)", num), nil) + yield(fmt.Sprintf("SIGN(%s)", num), nil, false) } } func FnSqrt(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("SQRT(%s)", num), nil) + yield(fmt.Sprintf("SQRT(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("SQRT(%s)", num), nil) + yield(fmt.Sprintf("SQRT(%s)", num), nil, false) } } func FnRound(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("ROUND(%s)", num), nil) + yield(fmt.Sprintf("ROUND(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("ROUND(%s)", num), nil) + yield(fmt.Sprintf("ROUND(%s)", num), nil, false) } for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range radianInputs { - yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil, false) } } } @@ -608,34 +608,34 @@ func FnRound(yield Query) { func FnTruncate(yield Query) { for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range radianInputs { - yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil, false) } } } func FnCrc32(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("CRC32(%s)", num), nil) + yield(fmt.Sprintf("CRC32(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("CRC32(%s)", num), nil) + yield(fmt.Sprintf("CRC32(%s)", num), nil, false) } for _, num := range inputConversions { - yield(fmt.Sprintf("CRC32(%s)", num), nil) + yield(fmt.Sprintf("CRC32(%s)", num), nil, false) } } @@ -643,10 +643,10 @@ func FnConv(yield Query) { for _, num1 := range radianInputs { for _, num2 := range radianInputs { for _, num3 := range radianInputs { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } for _, num3 := range inputBitwise { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } } } @@ -654,10 +654,10 @@ func FnConv(yield Query) { for _, num1 := range radianInputs { for _, num2 := range inputBitwise { for _, num3 := range radianInputs { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } for _, num3 := range inputBitwise { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } } } @@ -665,10 +665,10 @@ func FnConv(yield Query) { for _, num1 := range inputBitwise { for _, num2 := range inputBitwise { for _, num3 := range radianInputs { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } for _, num3 := range inputBitwise { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } } } @@ -676,50 +676,50 @@ func FnConv(yield Query) { func FnBin(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("BIN(%s)", num), nil) + yield(fmt.Sprintf("BIN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("BIN(%s)", num), nil) + yield(fmt.Sprintf("BIN(%s)", num), nil, false) } } func FnOct(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("OCT(%s)", num), nil) + yield(fmt.Sprintf("OCT(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("OCT(%s)", num), nil) + yield(fmt.Sprintf("OCT(%s)", num), nil, false) } } func FnMD5(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("MD5(%s)", num), nil) + yield(fmt.Sprintf("MD5(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("MD5(%s)", num), nil) + yield(fmt.Sprintf("MD5(%s)", num), nil, false) } for _, num := range inputConversions { - yield(fmt.Sprintf("MD5(%s)", num), nil) + yield(fmt.Sprintf("MD5(%s)", num), nil, false) } } func FnSHA1(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("SHA1(%s)", num), nil) - yield(fmt.Sprintf("SHA(%s)", num), nil) + yield(fmt.Sprintf("SHA1(%s)", num), nil, false) + yield(fmt.Sprintf("SHA(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("SHA1(%s)", num), nil) - yield(fmt.Sprintf("SHA(%s)", num), nil) + yield(fmt.Sprintf("SHA1(%s)", num), nil, false) + yield(fmt.Sprintf("SHA(%s)", num), nil, false) } for _, num := range inputConversions { - yield(fmt.Sprintf("SHA1(%s)", num), nil) - yield(fmt.Sprintf("SHA(%s)", num), nil) + yield(fmt.Sprintf("SHA1(%s)", num), nil, false) + yield(fmt.Sprintf("SHA(%s)", num), nil, false) } } @@ -727,28 +727,28 @@ func FnSHA2(yield Query) { bitLengths := []string{"0", "224", "256", "384", "512", "1", "0.1", "256.1e0", "1-1", "128+128"} for _, bits := range bitLengths { for _, num := range radianInputs { - yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil) + yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil) + yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil, false) } for _, num := range inputConversions { - yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil) + yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil, false) } } } func FnRandomBytes(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("LENGTH(RANDOM_BYTES(%s))", num), nil) - yield(fmt.Sprintf("COLLATION(RANDOM_BYTES(%s))", num), nil) + yield(fmt.Sprintf("LENGTH(RANDOM_BYTES(%s))", num), nil, false) + yield(fmt.Sprintf("COLLATION(RANDOM_BYTES(%s))", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("LENGTH(RANDOM_BYTES(%s))", num), nil) - yield(fmt.Sprintf("COLLATION(RANDOM_BYTES(%s))", num), nil) + yield(fmt.Sprintf("LENGTH(RANDOM_BYTES(%s))", num), nil, false) + yield(fmt.Sprintf("COLLATION(RANDOM_BYTES(%s))", num), nil, false) } } @@ -762,7 +762,7 @@ func CaseExprWithValue(yield Query) { if !(bugs{}).CanCompare(cmpbase, val1) { continue } - yield(fmt.Sprintf("case %s when %s then 1 else 0 end", cmpbase, val1), nil) + yield(fmt.Sprintf("case %s when %s then 1 else 0 end", cmpbase, val1), nil, false) } } } @@ -775,7 +775,7 @@ func If(yield Query) { for _, cmpbase := range elements { for _, val1 := range elements { for _, val2 := range elements { - yield(fmt.Sprintf("if(%s, %s, %s)", cmpbase, val1, val2), nil) + yield(fmt.Sprintf("if(%s, %s, %s)", cmpbase, val1, val2), nil, false) } } } @@ -796,17 +796,17 @@ func Base64(yield Query) { } for _, lhs := range inputs { - yield(fmt.Sprintf("FROM_BASE64(%s)", lhs), nil) - yield(fmt.Sprintf("TO_BASE64(%s)", lhs), nil) + yield(fmt.Sprintf("FROM_BASE64(%s)", lhs), nil, false) + yield(fmt.Sprintf("TO_BASE64(%s)", lhs), nil, false) } } func Conversion(yield Query) { for _, lhs := range inputConversions { for _, rhs := range inputConversionTypes { - yield(fmt.Sprintf("CAST(%s AS %s)", lhs, rhs), nil) - yield(fmt.Sprintf("CONVERT(%s, %s)", lhs, rhs), nil) - yield(fmt.Sprintf("CAST(CAST(%s AS JSON) AS %s)", lhs, rhs), nil) + yield(fmt.Sprintf("CAST(%s AS %s)", lhs, rhs), nil, false) + yield(fmt.Sprintf("CONVERT(%s, %s)", lhs, rhs), nil, false) + yield(fmt.Sprintf("CAST(CAST(%s AS JSON) AS %s)", lhs, rhs), nil, false) } } } @@ -815,8 +815,8 @@ func LargeDecimals(yield Query) { var largepi = inputPi + inputPi for pos := 0; pos < len(largepi); pos++ { - yield(fmt.Sprintf("%s.%s", largepi[:pos], largepi[pos:]), nil) - yield(fmt.Sprintf("-%s.%s", largepi[:pos], largepi[pos:]), nil) + yield(fmt.Sprintf("%s.%s", largepi[:pos], largepi[pos:]), nil, false) + yield(fmt.Sprintf("-%s.%s", largepi[:pos], largepi[pos:]), nil, false) } } @@ -824,8 +824,8 @@ func LargeIntegers(yield Query) { var largepi = inputPi + inputPi for pos := 1; pos < len(largepi); pos++ { - yield(largepi[:pos], nil) - yield(fmt.Sprintf("-%s", largepi[:pos]), nil) + yield(largepi[:pos], nil, false) + yield(fmt.Sprintf("-%s", largepi[:pos]), nil, false) } } @@ -833,7 +833,7 @@ func DecimalClamping(yield Query) { for pos := 0; pos < len(inputPi); pos++ { for m := 0; m < min(len(inputPi), 67); m += 2 { for d := 0; d <= min(m, 33); d += 2 { - yield(fmt.Sprintf("CAST(%s.%s AS DECIMAL(%d, %d))", inputPi[:pos], inputPi[pos:], m, d), nil) + yield(fmt.Sprintf("CAST(%s.%s AS DECIMAL(%d, %d))", inputPi[:pos], inputPi[pos:], m, d), nil, false) } } } @@ -842,7 +842,7 @@ func DecimalClamping(yield Query) { func BitwiseOperatorsUnary(yield Query) { for _, op := range []string{"~", "BIT_COUNT"} { for _, rhs := range inputBitwise { - yield(fmt.Sprintf("%s(%s)", op, rhs), nil) + yield(fmt.Sprintf("%s(%s)", op, rhs), nil, false) } } } @@ -851,13 +851,13 @@ func BitwiseOperators(yield Query) { for _, op := range []string{"&", "|", "^", "<<", ">>"} { for _, lhs := range inputBitwise { for _, rhs := range inputBitwise { - yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil, false) } } for _, lhs := range inputConversions { for _, rhs := range inputConversions { - yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil, false) } } } @@ -910,7 +910,7 @@ func WeightString(yield Query) { } for _, i := range inputs { - yield(fmt.Sprintf("WEIGHT_STRING(%s)", i), nil) + yield(fmt.Sprintf("WEIGHT_STRING(%s)", i), nil, false) } } @@ -927,18 +927,18 @@ func FloatFormatting(yield Query) { } for _, f := range floats { - yield(fmt.Sprintf("%s + 0.0e0", f), nil) - yield(fmt.Sprintf("-%s", f), nil) + yield(fmt.Sprintf("%s + 0.0e0", f), nil, false) + yield(fmt.Sprintf("-%s", f), nil, false) } for i := 0; i < 64; i++ { v := uint64(1) << i - yield(fmt.Sprintf("%d + 0.0e0", v), nil) - yield(fmt.Sprintf("%d + 0.0e0", v+1), nil) - yield(fmt.Sprintf("%d + 0.0e0", ^v), nil) - yield(fmt.Sprintf("-%de0", v), nil) - yield(fmt.Sprintf("-%de0", v+1), nil) - yield(fmt.Sprintf("-%de0", ^v), nil) + yield(fmt.Sprintf("%d + 0.0e0", v), nil, false) + yield(fmt.Sprintf("%d + 0.0e0", v+1), nil, false) + yield(fmt.Sprintf("%d + 0.0e0", ^v), nil, false) + yield(fmt.Sprintf("-%de0", v), nil, false) + yield(fmt.Sprintf("-%de0", v+1), nil, false) + yield(fmt.Sprintf("-%de0", ^v), nil, false) } } @@ -962,7 +962,7 @@ func UnderscoreAndPercentage(yield Query) { `'poke\_mon' = 'poke\_mon'`, } for _, query := range queries { - yield(query, nil) + yield(query, nil, false) } } @@ -993,7 +993,7 @@ func Types(yield Query) { } for _, query := range queries { - yield(query, nil) + yield(query, nil, false) } } @@ -1003,13 +1003,13 @@ func Arithmetic(yield Query) { for _, op := range operators { for _, lhs := range inputConversions { for _, rhs := range inputConversions { - yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil, false) } } for _, lhs := range inputBitwise { for _, rhs := range inputBitwise { - yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil, false) } } } @@ -1025,9 +1025,9 @@ func HexArithmetic(yield Query) { for _, lhs := range cases { for _, rhs := range cases { - yield(fmt.Sprintf("%s + %s", lhs, rhs), nil) + yield(fmt.Sprintf("%s + %s", lhs, rhs), nil, false) // compare with negative values too - yield(fmt.Sprintf("-%s + -%s", lhs, rhs), nil) + yield(fmt.Sprintf("-%s + -%s", lhs, rhs), nil, false) } } } @@ -1055,7 +1055,7 @@ func NumericTypes(yield Query) { } for _, rhs := range numbers { - yield(rhs, nil) + yield(rhs, nil, false) } } @@ -1072,13 +1072,13 @@ func NegateArithmetic(yield Query) { } for _, rhs := range cases { - yield(fmt.Sprintf("- %s", rhs), nil) - yield(fmt.Sprintf("-%s", rhs), nil) + yield(fmt.Sprintf("- %s", rhs), nil, false) + yield(fmt.Sprintf("-%s", rhs), nil, false) } for _, rhs := range inputConversions { - yield(fmt.Sprintf("- %s", rhs), nil) - yield(fmt.Sprintf("-%s", rhs), nil) + yield(fmt.Sprintf("- %s", rhs), nil, false) + yield(fmt.Sprintf("-%s", rhs), nil, false) } } @@ -1092,7 +1092,7 @@ func CollationOperations(yield Query) { } for _, expr := range cases { - yield(expr, nil) + yield(expr, nil, false) } } @@ -1115,7 +1115,7 @@ func LikeComparison(yield Query) { for _, lhs := range left { for _, rhs := range right { for _, op := range []string{"LIKE", "NOT LIKE"} { - yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil, false) } } } @@ -1149,7 +1149,7 @@ func StrcmpComparison(yield Query) { for _, lhs := range inputs { for _, rhs := range inputs { - yield(fmt.Sprintf("STRCMP(%s, %s)", lhs, rhs), nil) + yield(fmt.Sprintf("STRCMP(%s, %s)", lhs, rhs), nil, false) } } } @@ -1168,7 +1168,7 @@ func MultiComparisons(yield Query) { `"0"`, `"-1"`, `"1"`, `_utf8mb4 'foobar'`, `_utf8mb4 'FOOBAR'`, `_binary '0'`, `_binary '-1'`, `_binary '1'`, - `0x0`, `0x1`, `-0x0`, `-0x1`, + `0x0`, `0x1`, "_utf8mb4 'Abc' COLLATE utf8mb4_0900_as_ci", "_utf8mb4 'aBC' COLLATE utf8mb4_0900_as_ci", "_utf8mb4 'ǍḄÇ' COLLATE utf8mb4_0900_as_ci", @@ -1183,17 +1183,37 @@ func MultiComparisons(yield Query) { "_utf8mb4 'ノ東京の' COLLATE utf8mb4_ja_0900_as_cs", "_utf8mb4 'の東京ノ' COLLATE utf8mb4_ja_0900_as_cs_ks", "_utf8mb4 'ノ東京の' COLLATE utf8mb4_ja_0900_as_cs_ks", + `date'2024-02-18'`, + `date'2023-02-01'`, + `date'2100-02-01'`, + `timestamp'2020-12-31 23:59:59'`, + `timestamp'2025-01-01 00:00:00.123456'`, + `time'23:59:59.5432'`, + `time'120:59:59'`, } for _, method := range []string{"LEAST", "GREATEST"} { + skip := func(arg []string) bool { + skipCollations := false + for _, a := range arg { + if strings.Contains(a, "date'") || strings.Contains(a, "time'") || strings.Contains(a, "timestamp'") { + skipCollations = true + break + } + } + return skipCollations + } + genSubsets(numbers, 2, func(arg []string) { - yield(fmt.Sprintf("%s(%s, %s)", method, arg[0], arg[1]), nil) - yield(fmt.Sprintf("%s(%s, %s)", method, arg[1], arg[0]), nil) + skipCollations := skip(arg) + yield(fmt.Sprintf("%s(%s, %s)", method, arg[0], arg[1]), nil, skipCollations) + yield(fmt.Sprintf("%s(%s, %s)", method, arg[1], arg[0]), nil, skipCollations) }) genSubsets(numbers, 3, func(arg []string) { - yield(fmt.Sprintf("%s(%s, %s, %s)", method, arg[0], arg[1], arg[2]), nil) - yield(fmt.Sprintf("%s(%s, %s, %s)", method, arg[2], arg[1], arg[0]), nil) + skipCollations := skip(arg) + yield(fmt.Sprintf("%s(%s, %s, %s)", method, arg[0], arg[1], arg[2]), nil, skipCollations) + yield(fmt.Sprintf("%s(%s, %s, %s)", method, arg[2], arg[1], arg[0]), nil, skipCollations) }) } } @@ -1213,7 +1233,7 @@ func IntervalStatement(yield Query) { for _, arg1 := range inputs { for _, arg2 := range inputs { for _, arg3 := range inputs { - yield(fmt.Sprintf("INTERVAL(%s, %s, %s, %s)", base, arg1, arg2, arg3), nil) + yield(fmt.Sprintf("INTERVAL(%s, %s, %s, %s)", base, arg1, arg2, arg3), nil, false) } } } @@ -1238,7 +1258,7 @@ func IsStatement(yield Query) { for _, l := range left { for _, r := range right { - yield(fmt.Sprintf("%s IS %s", l, r), nil) + yield(fmt.Sprintf("%s IS %s", l, r), nil, false) } } } @@ -1247,7 +1267,7 @@ func NotStatement(yield Query) { var ops = []string{"NOT", "!"} for _, op := range ops { for _, i := range inputConversions { - yield(fmt.Sprintf("%s %s", op, i), nil) + yield(fmt.Sprintf("%s %s", op, i), nil, false) } } } @@ -1257,7 +1277,7 @@ func LogicalStatement(yield Query) { for _, op := range ops { for _, l := range inputConversions { for _, r := range inputConversions { - yield(fmt.Sprintf("%s %s %s", l, op, r), nil) + yield(fmt.Sprintf("%s %s %s", l, op, r), nil, false) } } } @@ -1275,7 +1295,7 @@ func TupleComparisons(yield Query) { for _, op := range operators { for i := 0; i < len(tuples); i++ { for j := 0; j < len(tuples); j++ { - yield(fmt.Sprintf("%s %s %s", tuples[i], op, tuples[j]), nil) + yield(fmt.Sprintf("%s %s %s", tuples[i], op, tuples[j]), nil, false) } } } @@ -1286,13 +1306,13 @@ func Comparisons(yield Query) { for _, op := range operators { for _, l := range inputComparisonElement { for _, r := range inputComparisonElement { - yield(fmt.Sprintf("%s %s %s", l, op, r), nil) + yield(fmt.Sprintf("%s %s %s", l, op, r), nil, false) } } for _, l := range inputConversions { for _, r := range inputConversions { - yield(fmt.Sprintf("%s %s %s", l, op, r), nil) + yield(fmt.Sprintf("%s %s %s", l, op, r), nil, false) } } } @@ -1331,9 +1351,9 @@ func JSONExtract(yield Query) { expr2 := fmt.Sprintf("cast(%s as char) <=> %s", expr0, expr1) for _, row := range rows { - yield(expr0, []sqltypes.Value{row}) - yield(expr1, []sqltypes.Value{row}) - yield(expr2, []sqltypes.Value{row}) + yield(expr0, []sqltypes.Value{row}, false) + yield(expr1, []sqltypes.Value{row}, false) + yield(expr2, []sqltypes.Value{row}, false) } } } @@ -1350,7 +1370,7 @@ func FnField(yield Query) { for _, s1 := range inputStrings { for _, s2 := range inputStrings { for _, s3 := range inputStrings { - yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil, false) } } } @@ -1358,7 +1378,7 @@ func FnField(yield Query) { for _, s1 := range radianInputs { for _, s2 := range radianInputs { for _, s3 := range radianInputs { - yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil, false) } } } @@ -1367,7 +1387,7 @@ func FnField(yield Query) { for _, s1 := range inputStrings { for _, s2 := range radianInputs { for _, s3 := range inputStrings { - yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil, false) } } } @@ -1376,7 +1396,7 @@ func FnField(yield Query) { for _, s1 := range inputBitwise { for _, s2 := range inputBitwise { for _, s3 := range inputBitwise { - yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil, false) } } } @@ -1386,21 +1406,21 @@ func FnField(yield Query) { "FIELD('Gg', 'Aa', 'Bb', 'Cc', 'Dd', 'Ff')", } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } func FnElt(yield Query) { for _, s1 := range inputStrings { for _, n := range inputBitwise { - yield(fmt.Sprintf("ELT(%s, %s)", n, s1), nil) + yield(fmt.Sprintf("ELT(%s, %s)", n, s1), nil, false) } } for _, s1 := range inputStrings { for _, s2 := range inputStrings { for _, n := range inputBitwise { - yield(fmt.Sprintf("ELT(%s, %s, %s)", n, s1, s2), nil) + yield(fmt.Sprintf("ELT(%s, %s, %s)", n, s1, s2), nil, false) } } } @@ -1414,7 +1434,7 @@ func FnElt(yield Query) { for _, s2 := range inputStrings { for _, s3 := range inputStrings { for _, n := range validIndex { - yield(fmt.Sprintf("ELT(%s, %s, %s, %s)", n, s1, s2, s3), nil) + yield(fmt.Sprintf("ELT(%s, %s, %s, %s)", n, s1, s2, s3), nil, false) } } } @@ -1426,7 +1446,7 @@ func FnElt(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } @@ -1435,7 +1455,7 @@ func FnInsert(yield Query) { for _, ns := range insertStrings { for _, l := range inputBitwise { for _, p := range inputBitwise { - yield(fmt.Sprintf("INSERT(%s, %s, %s, %s)", s, p, l, ns), nil) + yield(fmt.Sprintf("INSERT(%s, %s, %s, %s)", s, p, l, ns), nil, false) } } } @@ -1448,53 +1468,53 @@ func FnInsert(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } func FnLower(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("LOWER(%s)", str), nil) - yield(fmt.Sprintf("LCASE(%s)", str), nil) + yield(fmt.Sprintf("LOWER(%s)", str), nil, false) + yield(fmt.Sprintf("LCASE(%s)", str), nil, false) } } func FnUpper(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("UPPER(%s)", str), nil) - yield(fmt.Sprintf("UCASE(%s)", str), nil) + yield(fmt.Sprintf("UPPER(%s)", str), nil, false) + yield(fmt.Sprintf("UCASE(%s)", str), nil, false) } } func FnCharLength(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("CHAR_LENGTH(%s)", str), nil) - yield(fmt.Sprintf("CHARACTER_LENGTH(%s)", str), nil) + yield(fmt.Sprintf("CHAR_LENGTH(%s)", str), nil, false) + yield(fmt.Sprintf("CHARACTER_LENGTH(%s)", str), nil, false) } } func FnLength(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("LENGTH(%s)", str), nil) - yield(fmt.Sprintf("OCTET_LENGTH(%s)", str), nil) + yield(fmt.Sprintf("LENGTH(%s)", str), nil, false) + yield(fmt.Sprintf("OCTET_LENGTH(%s)", str), nil, false) } } func FnBitLength(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("BIT_LENGTH(%s)", str), nil) + yield(fmt.Sprintf("BIT_LENGTH(%s)", str), nil, false) } } func FnAscii(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("ASCII(%s)", str), nil) + yield(fmt.Sprintf("ASCII(%s)", str), nil, false) } } func FnReverse(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("REVERSE(%s)", str), nil) + yield(fmt.Sprintf("REVERSE(%s)", str), nil, false) } } @@ -1516,13 +1536,13 @@ func FnSpace(yield Query) { } for _, c := range counts { - yield(fmt.Sprintf("SPACE(%s)", c), nil) + yield(fmt.Sprintf("SPACE(%s)", c), nil, false) } } func FnOrd(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("ORD(%s)", str), nil) + yield(fmt.Sprintf("ORD(%s)", str), nil, false) } } @@ -1530,7 +1550,7 @@ func FnRepeat(yield Query) { counts := []string{"-1", "1.9", "3", "1073741825", "'1.9'"} for _, str := range inputStrings { for _, cnt := range counts { - yield(fmt.Sprintf("REPEAT(%s, %s)", str, cnt), nil) + yield(fmt.Sprintf("REPEAT(%s, %s)", str, cnt), nil, false) } } } @@ -1539,7 +1559,7 @@ func FnLeft(yield Query) { counts := []string{"-1", "1.9", "3", "10", "'1.9'"} for _, str := range inputStrings { for _, cnt := range counts { - yield(fmt.Sprintf("LEFT(%s, %s)", str, cnt), nil) + yield(fmt.Sprintf("LEFT(%s, %s)", str, cnt), nil, false) } } } @@ -1549,7 +1569,7 @@ func FnLpad(yield Query) { for _, str := range inputStrings { for _, cnt := range counts { for _, pad := range inputStrings { - yield(fmt.Sprintf("LPAD(%s, %s, %s)", str, cnt, pad), nil) + yield(fmt.Sprintf("LPAD(%s, %s, %s)", str, cnt, pad), nil, false) } } } @@ -1559,7 +1579,7 @@ func FnRight(yield Query) { counts := []string{"-1", "1.9", "3", "10", "'1.9'"} for _, str := range inputStrings { for _, cnt := range counts { - yield(fmt.Sprintf("RIGHT(%s, %s)", str, cnt), nil) + yield(fmt.Sprintf("RIGHT(%s, %s)", str, cnt), nil, false) } } } @@ -1569,7 +1589,7 @@ func FnRpad(yield Query) { for _, str := range inputStrings { for _, cnt := range counts { for _, pad := range inputStrings { - yield(fmt.Sprintf("RPAD(%s, %s, %s)", str, cnt, pad), nil) + yield(fmt.Sprintf("RPAD(%s, %s, %s)", str, cnt, pad), nil, false) } } } @@ -1577,33 +1597,33 @@ func FnRpad(yield Query) { func FnLTrim(yield Query) { for _, str := range inputTrimStrings { - yield(fmt.Sprintf("LTRIM(%s)", str), nil) + yield(fmt.Sprintf("LTRIM(%s)", str), nil, false) } } func FnRTrim(yield Query) { for _, str := range inputTrimStrings { - yield(fmt.Sprintf("RTRIM(%s)", str), nil) + yield(fmt.Sprintf("RTRIM(%s)", str), nil, false) } } func FnTrim(yield Query) { for _, str := range inputTrimStrings { - yield(fmt.Sprintf("TRIM(%s)", str), nil) + yield(fmt.Sprintf("TRIM(%s)", str), nil, false) } modes := []string{"LEADING", "TRAILING", "BOTH"} for _, str := range inputTrimStrings { for _, mode := range modes { - yield(fmt.Sprintf("TRIM(%s FROM %s)", mode, str), nil) + yield(fmt.Sprintf("TRIM(%s FROM %s)", mode, str), nil, false) } } for _, str := range inputTrimStrings { for _, pat := range inputTrimStrings { - yield(fmt.Sprintf("TRIM(%s FROM %s)", pat, str), nil) + yield(fmt.Sprintf("TRIM(%s FROM %s)", pat, str), nil, false) for _, mode := range modes { - yield(fmt.Sprintf("TRIM(%s %s FROM %s)", mode, pat, str), nil) + yield(fmt.Sprintf("TRIM(%s %s FROM %s)", mode, pat, str), nil, false) } } } @@ -1628,15 +1648,15 @@ func FnSubstr(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } for _, str := range inputStrings { for _, i := range radianInputs { - yield(fmt.Sprintf("SUBSTRING(%s, %s)", str, i), nil) + yield(fmt.Sprintf("SUBSTRING(%s, %s)", str, i), nil, false) for _, j := range radianInputs { - yield(fmt.Sprintf("SUBSTRING(%s, %s, %s)", str, i, j), nil) + yield(fmt.Sprintf("SUBSTRING(%s, %s, %s)", str, i, j), nil, false) } } } @@ -1654,17 +1674,17 @@ func FnLocate(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } for _, substr := range locateStrings { for _, str := range locateStrings { - yield(fmt.Sprintf("LOCATE(%s, %s)", substr, str), nil) - yield(fmt.Sprintf("INSTR(%s, %s)", str, substr), nil) - yield(fmt.Sprintf("POSITION(%s IN %s)", str, substr), nil) + yield(fmt.Sprintf("LOCATE(%s, %s)", substr, str), nil, false) + yield(fmt.Sprintf("INSTR(%s, %s)", str, substr), nil, false) + yield(fmt.Sprintf("POSITION(%s IN %s)", str, substr), nil, false) for _, i := range radianInputs { - yield(fmt.Sprintf("LOCATE(%s, %s, %s)", substr, str, i), nil) + yield(fmt.Sprintf("LOCATE(%s, %s, %s)", substr, str, i), nil, false) } } } @@ -1685,13 +1705,13 @@ func FnReplace(yield Query) { } for _, q := range cases { - yield(q, nil) + yield(q, nil, false) } for _, substr := range inputStrings { for _, str := range inputStrings { for _, i := range inputStrings { - yield(fmt.Sprintf("REPLACE(%s, %s, %s)", substr, str, i), nil) + yield(fmt.Sprintf("REPLACE(%s, %s, %s)", substr, str, i), nil, false) } } } @@ -1699,19 +1719,19 @@ func FnReplace(yield Query) { func FnConcat(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("CONCAT(%s)", str), nil) + yield(fmt.Sprintf("CONCAT(%s)", str), nil, false) } for _, str1 := range inputConversions { for _, str2 := range inputConversions { - yield(fmt.Sprintf("CONCAT(%s, %s)", str1, str2), nil) + yield(fmt.Sprintf("CONCAT(%s, %s)", str1, str2), nil, false) } } for _, str1 := range inputStrings { for _, str2 := range inputStrings { for _, str3 := range inputStrings { - yield(fmt.Sprintf("CONCAT(%s, %s, %s)", str1, str2, str3), nil) + yield(fmt.Sprintf("CONCAT(%s, %s, %s)", str1, str2, str3), nil, false) } } } @@ -1719,13 +1739,13 @@ func FnConcat(yield Query) { func FnConcatWs(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("CONCAT_WS(%s, NULL)", str), nil) + yield(fmt.Sprintf("CONCAT_WS(%s, NULL)", str), nil, false) } for _, str1 := range inputConversions { for _, str2 := range inputStrings { for _, str3 := range inputStrings { - yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil) + yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil, false) } } } @@ -1733,7 +1753,7 @@ func FnConcatWs(yield Query) { for _, str1 := range inputStrings { for _, str2 := range inputConversions { for _, str3 := range inputStrings { - yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil) + yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil, false) } } } @@ -1741,7 +1761,7 @@ func FnConcatWs(yield Query) { for _, str1 := range inputStrings { for _, str2 := range inputStrings { for _, str3 := range inputConversions { - yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil) + yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil, false) } } } @@ -1760,13 +1780,13 @@ func FnChar(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } for _, i1 := range radianInputs { for _, i2 := range inputBitwise { for _, i3 := range inputConversions { - yield(fmt.Sprintf("CHAR(%s, %s, %s)", i1, i2, i3), nil) + yield(fmt.Sprintf("CHAR(%s, %s, %s)", i1, i2, i3), nil, false) } } } @@ -1774,15 +1794,15 @@ func FnChar(yield Query) { func FnHex(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("hex(%s)", str), nil) + yield(fmt.Sprintf("hex(%s)", str), nil, false) } for _, str := range inputConversions { - yield(fmt.Sprintf("hex(%s)", str), nil) + yield(fmt.Sprintf("hex(%s)", str), nil, false) } for _, str := range inputBitwise { - yield(fmt.Sprintf("hex(%s)", str), nil) + yield(fmt.Sprintf("hex(%s)", str), nil, false) } } @@ -1802,7 +1822,7 @@ func FnUnhex(yield Query) { } for _, lhs := range inputs { - yield(fmt.Sprintf("UNHEX(%s)", lhs), nil) + yield(fmt.Sprintf("UNHEX(%s)", lhs), nil, false) } } @@ -1814,15 +1834,15 @@ func InStatement(yield Query) { if !(bugs{}).CanCompare(inputs...) { return } - yield(fmt.Sprintf("%s IN (%s, %s)", inputs[0], inputs[1], inputs[2]), nil) - yield(fmt.Sprintf("%s IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil) - yield(fmt.Sprintf("%s IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil) - yield(fmt.Sprintf("%s IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil) + yield(fmt.Sprintf("%s IN (%s, %s)", inputs[0], inputs[1], inputs[2]), nil, false) + yield(fmt.Sprintf("%s IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil, false) + yield(fmt.Sprintf("%s IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil, false) + yield(fmt.Sprintf("%s IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil, false) - yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[0], inputs[1], inputs[2]), nil) - yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil) - yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil) - yield(fmt.Sprintf("%s NOT IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil) + yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[0], inputs[1], inputs[2]), nil, false) + yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil, false) + yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil, false) + yield(fmt.Sprintf("%s NOT IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil, false) }) } @@ -1845,7 +1865,7 @@ func FnNow(yield Query) { "SYSDATE(1)", "SYSDATE(2)", "SYSDATE(3)", "SYSDATE(4)", "SYSDATE(5)", } for _, fn := range fns { - yield(fn, nil) + yield(fn, nil, false) } } @@ -1857,7 +1877,7 @@ func FnInfo(yield Query) { "VERSION()", } for _, fn := range fns { - yield(fn, nil) + yield(fn, nil, false) } } @@ -1871,7 +1891,7 @@ func FnDateFormat(yield Query) { format := buf.String() for _, d := range inputConversions { - yield(fmt.Sprintf("DATE_FORMAT(%s, %q)", d, format), nil) + yield(fmt.Sprintf("DATE_FORMAT(%s, %q)", d, format), nil, false) } } @@ -1897,7 +1917,7 @@ func FnConvertTz(yield Query) { for _, tzFrom := range timezoneInputs { for _, tzTo := range timezoneInputs { q := fmt.Sprintf("CONVERT_TZ(%s, '%s', '%s')", num1, tzFrom, tzTo) - yield(q, nil) + yield(q, nil, false) } } } @@ -1905,26 +1925,26 @@ func FnConvertTz(yield Query) { func FnDate(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("DATE(%s)", d), nil) + yield(fmt.Sprintf("DATE(%s)", d), nil, false) } } func FnDayOfMonth(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("DAYOFMONTH(%s)", d), nil) - yield(fmt.Sprintf("DAY(%s)", d), nil) + yield(fmt.Sprintf("DAYOFMONTH(%s)", d), nil, false) + yield(fmt.Sprintf("DAY(%s)", d), nil, false) } } func FnDayOfWeek(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("DAYOFWEEK(%s)", d), nil) + yield(fmt.Sprintf("DAYOFWEEK(%s)", d), nil, false) } } func FnDayOfYear(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("DAYOFYEAR(%s)", d), nil) + yield(fmt.Sprintf("DAYOFYEAR(%s)", d), nil, false) } } @@ -1938,21 +1958,21 @@ func FnFromUnixtime(yield Query) { format := buf.String() for _, d := range inputConversions { - yield(fmt.Sprintf("FROM_UNIXTIME(%s)", d), nil) - yield(fmt.Sprintf("FROM_UNIXTIME(%s, %q)", d, format), nil) + yield(fmt.Sprintf("FROM_UNIXTIME(%s)", d), nil, false) + yield(fmt.Sprintf("FROM_UNIXTIME(%s, %q)", d, format), nil, false) } } func FnHour(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("HOUR(%s)", d), nil) + yield(fmt.Sprintf("HOUR(%s)", d), nil, false) } } func FnMakedate(yield Query) { for _, y := range inputConversions { for _, d := range inputConversions { - yield(fmt.Sprintf("MAKEDATE(%s, %s)", y, d), nil) + yield(fmt.Sprintf("MAKEDATE(%s, %s)", y, d), nil, false) } } } @@ -1969,7 +1989,7 @@ func FnMaketime(yield Query) { } for _, m := range minutes { for _, s := range inputConversions { - yield(fmt.Sprintf("MAKETIME(%s, %s, %s)", h, m, s), nil) + yield(fmt.Sprintf("MAKETIME(%s, %s, %s)", h, m, s), nil, false) } } } @@ -1977,31 +1997,31 @@ func FnMaketime(yield Query) { func FnMicroSecond(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("MICROSECOND(%s)", d), nil) + yield(fmt.Sprintf("MICROSECOND(%s)", d), nil, false) } } func FnMinute(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("MINUTE(%s)", d), nil) + yield(fmt.Sprintf("MINUTE(%s)", d), nil, false) } } func FnMonth(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("MONTH(%s)", d), nil) + yield(fmt.Sprintf("MONTH(%s)", d), nil, false) } } func FnMonthName(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("MONTHNAME(%s)", d), nil) + yield(fmt.Sprintf("MONTHNAME(%s)", d), nil, false) } } func FnLastDay(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("LAST_DAY(%s)", d), nil) + yield(fmt.Sprintf("LAST_DAY(%s)", d), nil, false) } dates := []string{ @@ -2018,13 +2038,13 @@ func FnLastDay(yield Query) { } for _, d := range dates { - yield(fmt.Sprintf("LAST_DAY(%s)", d), nil) + yield(fmt.Sprintf("LAST_DAY(%s)", d), nil, false) } } func FnToDays(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("TO_DAYS(%s)", d), nil) + yield(fmt.Sprintf("TO_DAYS(%s)", d), nil, false) } dates := []string{ @@ -2042,13 +2062,13 @@ func FnToDays(yield Query) { } for _, d := range dates { - yield(fmt.Sprintf("TO_DAYS(%s)", d), nil) + yield(fmt.Sprintf("TO_DAYS(%s)", d), nil, false) } } func FnFromDays(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("FROM_DAYS(%s)", d), nil) + yield(fmt.Sprintf("FROM_DAYS(%s)", d), nil, false) } days := []string{ @@ -2064,13 +2084,13 @@ func FnFromDays(yield Query) { } for _, d := range days { - yield(fmt.Sprintf("FROM_DAYS(%s)", d), nil) + yield(fmt.Sprintf("FROM_DAYS(%s)", d), nil, false) } } func FnSecToTime(yield Query) { for _, s := range inputConversions { - yield(fmt.Sprintf("SEC_TO_TIME(%s)", s), nil) + yield(fmt.Sprintf("SEC_TO_TIME(%s)", s), nil, false) } mysqlDocSamples := []string{ @@ -2079,13 +2099,13 @@ func FnSecToTime(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } func FnTimeToSec(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("TIME_TO_SEC(%s)", d), nil) + yield(fmt.Sprintf("TIME_TO_SEC(%s)", d), nil, false) } time := []string{ @@ -2103,13 +2123,13 @@ func FnTimeToSec(yield Query) { } for _, t := range time { - yield(fmt.Sprintf("TIME_TO_SEC(%s)", t), nil) + yield(fmt.Sprintf("TIME_TO_SEC(%s)", t), nil, false) } } func FnToSeconds(yield Query) { for _, t := range inputConversions { - yield(fmt.Sprintf("TO_SECONDS(%s)", t), nil) + yield(fmt.Sprintf("TO_SECONDS(%s)", t), nil, false) } timeInputs := []string{ @@ -2127,7 +2147,7 @@ func FnToSeconds(yield Query) { } for _, t := range timeInputs { - yield(fmt.Sprintf("TO_SECONDS(%s)", t), nil) + yield(fmt.Sprintf("TO_SECONDS(%s)", t), nil, false) } mysqlDocSamples := []string{ @@ -2137,25 +2157,25 @@ func FnToSeconds(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } func FnQuarter(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("QUARTER(%s)", d), nil) + yield(fmt.Sprintf("QUARTER(%s)", d), nil, false) } } func FnSecond(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("SECOND(%s)", d), nil) + yield(fmt.Sprintf("SECOND(%s)", d), nil, false) } } func FnTime(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("TIME(%s)", d), nil) + yield(fmt.Sprintf("TIME(%s)", d), nil, false) } times := []string{ "'00:00:00'", @@ -2174,68 +2194,68 @@ func FnTime(yield Query) { } for _, d := range times { - yield(fmt.Sprintf("TIME(%s)", d), nil) + yield(fmt.Sprintf("TIME(%s)", d), nil, false) } } func FnUnixTimestamp(yield Query) { - yield("UNIX_TIMESTAMP()", nil) + yield("UNIX_TIMESTAMP()", nil, false) for _, d := range inputConversions { - yield(fmt.Sprintf("UNIX_TIMESTAMP(%s)", d), nil) - yield(fmt.Sprintf("UNIX_TIMESTAMP(%s) + 1", d), nil) + yield(fmt.Sprintf("UNIX_TIMESTAMP(%s)", d), nil, false) + yield(fmt.Sprintf("UNIX_TIMESTAMP(%s) + 1", d), nil, false) } } func FnWeek(yield Query) { for i := 0; i < 16; i++ { for _, d := range inputConversions { - yield(fmt.Sprintf("WEEK(%s, %d)", d, i), nil) + yield(fmt.Sprintf("WEEK(%s, %d)", d, i), nil, false) } } for _, d := range inputConversions { - yield(fmt.Sprintf("WEEK(%s)", d), nil) + yield(fmt.Sprintf("WEEK(%s)", d), nil, false) } } func FnWeekDay(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("WEEKDAY(%s)", d), nil) + yield(fmt.Sprintf("WEEKDAY(%s)", d), nil, false) } } func FnWeekOfYear(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("WEEKOFYEAR(%s)", d), nil) + yield(fmt.Sprintf("WEEKOFYEAR(%s)", d), nil, false) } } func FnYear(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("YEAR(%s)", d), nil) + yield(fmt.Sprintf("YEAR(%s)", d), nil, false) } } func FnYearWeek(yield Query) { for i := 0; i < 8; i++ { for _, d := range inputConversions { - yield(fmt.Sprintf("YEARWEEK(%s, %d)", d, i), nil) + yield(fmt.Sprintf("YEARWEEK(%s, %d)", d, i), nil, false) } } for _, d := range inputConversions { - yield(fmt.Sprintf("YEARWEEK(%s)", d), nil) + yield(fmt.Sprintf("YEARWEEK(%s)", d), nil, false) } } func FnPeriodAdd(yield Query) { for _, p := range inputBitwise { for _, m := range inputBitwise { - yield(fmt.Sprintf("PERIOD_ADD(%s, %s)", p, m), nil) + yield(fmt.Sprintf("PERIOD_ADD(%s, %s)", p, m), nil, false) } } for _, p := range inputPeriods { for _, m := range inputBitwise { - yield(fmt.Sprintf("PERIOD_ADD(%s, %s)", p, m), nil) + yield(fmt.Sprintf("PERIOD_ADD(%s, %s)", p, m), nil, false) } } @@ -2244,19 +2264,19 @@ func FnPeriodAdd(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } func FnPeriodDiff(yield Query) { for _, p1 := range inputBitwise { for _, p2 := range inputBitwise { - yield(fmt.Sprintf("PERIOD_DIFF(%s, %s)", p1, p2), nil) + yield(fmt.Sprintf("PERIOD_DIFF(%s, %s)", p1, p2), nil, false) } } for _, p1 := range inputPeriods { for _, p2 := range inputPeriods { - yield(fmt.Sprintf("PERIOD_DIFF(%s, %s)", p1, p2), nil) + yield(fmt.Sprintf("PERIOD_DIFF(%s, %s)", p1, p2), nil, false) } } @@ -2265,59 +2285,59 @@ func FnPeriodDiff(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } func FnInetAton(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("INET_ATON(%s)", d), nil) + yield(fmt.Sprintf("INET_ATON(%s)", d), nil, false) } } func FnInetNtoa(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("INET_NTOA(%s)", d), nil) - yield(fmt.Sprintf("INET_NTOA(INET_ATON(%s))", d), nil) + yield(fmt.Sprintf("INET_NTOA(%s)", d), nil, false) + yield(fmt.Sprintf("INET_NTOA(INET_ATON(%s))", d), nil, false) } } func FnInet6Aton(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("INET6_ATON(%s)", d), nil) + yield(fmt.Sprintf("INET6_ATON(%s)", d), nil, false) } } func FnInet6Ntoa(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("INET6_NTOA(%s)", d), nil) - yield(fmt.Sprintf("INET6_NTOA(INET6_ATON(%s))", d), nil) + yield(fmt.Sprintf("INET6_NTOA(%s)", d), nil, false) + yield(fmt.Sprintf("INET6_NTOA(INET6_ATON(%s))", d), nil, false) } } func FnIsIPv4(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("IS_IPV4(%s)", d), nil) + yield(fmt.Sprintf("IS_IPV4(%s)", d), nil, false) } } func FnIsIPv4Compat(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("IS_IPV4_COMPAT(%s)", d), nil) - yield(fmt.Sprintf("IS_IPV4_COMPAT(INET6_ATON(%s))", d), nil) + yield(fmt.Sprintf("IS_IPV4_COMPAT(%s)", d), nil, false) + yield(fmt.Sprintf("IS_IPV4_COMPAT(INET6_ATON(%s))", d), nil, false) } } func FnIsIPv4Mapped(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("IS_IPV4_MAPPED(%s)", d), nil) - yield(fmt.Sprintf("IS_IPV4_MAPPED(INET6_ATON(%s))", d), nil) + yield(fmt.Sprintf("IS_IPV4_MAPPED(%s)", d), nil, false) + yield(fmt.Sprintf("IS_IPV4_MAPPED(INET6_ATON(%s))", d), nil, false) } } func FnIsIPv6(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("IS_IPV6(%s)", d), nil) + yield(fmt.Sprintf("IS_IPV6(%s)", d), nil, false) } } @@ -2335,27 +2355,27 @@ func FnBinToUUID(yield Query) { "'2'", } for _, d := range uuidInputs { - yield(fmt.Sprintf("BIN_TO_UUID(%s)", d), nil) + yield(fmt.Sprintf("BIN_TO_UUID(%s)", d), nil, false) } for _, d := range uuidInputs { for _, a := range args { - yield(fmt.Sprintf("BIN_TO_UUID(%s, %s)", d, a), nil) + yield(fmt.Sprintf("BIN_TO_UUID(%s, %s)", d, a), nil, false) } } } func FnIsUUID(yield Query) { for _, d := range uuidInputs { - yield(fmt.Sprintf("IS_UUID(%s)", d), nil) + yield(fmt.Sprintf("IS_UUID(%s)", d), nil, false) } } func FnUUID(yield Query) { - yield("LENGTH(UUID())", nil) - yield("COLLATION(UUID())", nil) - yield("IS_UUID(UUID())", nil) - yield("LENGTH(UUID_TO_BIN(UUID())", nil) + yield("LENGTH(UUID())", nil, false) + yield("COLLATION(UUID())", nil, false) + yield("IS_UUID(UUID())", nil, false) + yield("LENGTH(UUID_TO_BIN(UUID())", nil, false) } func FnUUIDToBin(yield Query) { @@ -2372,12 +2392,12 @@ func FnUUIDToBin(yield Query) { "'2'", } for _, d := range uuidInputs { - yield(fmt.Sprintf("UUID_TO_BIN(%s)", d), nil) + yield(fmt.Sprintf("UUID_TO_BIN(%s)", d), nil, false) } for _, d := range uuidInputs { for _, a := range args { - yield(fmt.Sprintf("UUID_TO_BIN(%s, %s)", d, a), nil) + yield(fmt.Sprintf("UUID_TO_BIN(%s, %s)", d, a), nil, false) } } } @@ -2418,15 +2438,15 @@ func DateMath(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } for _, d := range dates { for _, i := range inputIntervals { for _, v := range intervalValues { - yield(fmt.Sprintf("DATE_ADD(%s, INTERVAL %s %s)", d, v, i), nil) - yield(fmt.Sprintf("DATE_SUB(%s, INTERVAL %s %s)", d, v, i), nil) - yield(fmt.Sprintf("TIMESTAMPADD(%v, %s, %s)", i, v, d), nil) + yield(fmt.Sprintf("DATE_ADD(%s, INTERVAL %s %s)", d, v, i), nil, false) + yield(fmt.Sprintf("DATE_SUB(%s, INTERVAL %s %s)", d, v, i), nil, false) + yield(fmt.Sprintf("TIMESTAMPADD(%v, %s, %s)", i, v, d), nil, false) } } } @@ -2481,15 +2501,15 @@ func RegexpLike(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } for _, i := range regexInputs { for _, p := range regexInputs { - yield(fmt.Sprintf("%s REGEXP %s", i, p), nil) - yield(fmt.Sprintf("%s NOT REGEXP %s", i, p), nil) + yield(fmt.Sprintf("%s REGEXP %s", i, p), nil, false) + yield(fmt.Sprintf("%s NOT REGEXP %s", i, p), nil, false) for _, m := range regexMatchStrings { - yield(fmt.Sprintf("REGEXP_LIKE(%s, %s, %s)", i, p, m), nil) + yield(fmt.Sprintf("REGEXP_LIKE(%s, %s, %s)", i, p, m), nil, false) } } } @@ -2565,7 +2585,7 @@ func RegexpInstr(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } @@ -2632,7 +2652,7 @@ func RegexpSubstr(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } @@ -2712,6 +2732,6 @@ func RegexpReplace(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } diff --git a/go/vt/vtgate/evalengine/testcases/helpers.go b/go/vt/vtgate/evalengine/testcases/helpers.go index 71602e12c1c..db5ad6475b4 100644 --- a/go/vt/vtgate/evalengine/testcases/helpers.go +++ b/go/vt/vtgate/evalengine/testcases/helpers.go @@ -30,7 +30,7 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" ) -type Query func(query string, row []sqltypes.Value) +type Query func(query string, row []sqltypes.Value, skipCollationCheck bool) type Runner func(yield Query) type TestCase struct { Run Runner