Skip to content

Commit

Permalink
*: fix in-compatible behavior when modify value from Navicat GUI (pin…
Browse files Browse the repository at this point in the history
  • Loading branch information
XuHuaiyu authored Mar 22, 2018
1 parent 546b5ac commit 5445e17
Show file tree
Hide file tree
Showing 10 changed files with 56 additions and 55 deletions.
3 changes: 1 addition & 2 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef) (
// Set `NoDefaultValueFlag` if this field doesn't have a default value and
// it is `not null` and not an `AUTO_INCREMENT` field or `TIMESTAMP` field.
setNoDefaultValueFlag(col, hasDefaultValue)

if col.Charset == charset.CharsetBin {
if col.FieldType.EvalType().IsStringKind() && col.Charset == charset.CharsetBin {
col.Flag |= mysql.BinaryFlag
}
if col.Tp == mysql.TypeBit {
Expand Down
20 changes: 13 additions & 7 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,6 @@ var (
_ builtinFunc = &builtinIfJSONSig{}
)

type caseWhenFunctionClass struct {
baseFunctionClass
}

// Infer result type for builtin IF, IFNULL && NULLIF.
func inferType4ControlFuncs(lhs, rhs *types.FieldType) *types.FieldType {
resultFieldType := &types.FieldType{}
Expand Down Expand Up @@ -85,12 +81,12 @@ func inferType4ControlFuncs(lhs, rhs *types.FieldType) *types.FieldType {
}
if types.IsNonBinaryStr(lhs) && !types.IsBinaryStr(rhs) {
resultFieldType.Charset, resultFieldType.Collate, resultFieldType.Flag = charset.CharsetUTF8, charset.CollationUTF8, 0
if mysql.HasBinaryFlag(lhs.Flag) {
if mysql.HasBinaryFlag(lhs.Flag) || !types.IsNonBinaryStr(rhs) {
resultFieldType.Flag |= mysql.BinaryFlag
}
} else if types.IsNonBinaryStr(rhs) && !types.IsBinaryStr(lhs) {
resultFieldType.Charset, resultFieldType.Collate, resultFieldType.Flag = charset.CharsetUTF8, charset.CollationUTF8, 0
if mysql.HasBinaryFlag(rhs.Flag) {
if mysql.HasBinaryFlag(rhs.Flag) || !types.IsNonBinaryStr(lhs) {
resultFieldType.Flag |= mysql.BinaryFlag
}
} else if types.IsBinaryStr(lhs) || types.IsBinaryStr(rhs) || !evalType.IsStringKind() {
Expand Down Expand Up @@ -132,25 +128,31 @@ func inferType4ControlFuncs(lhs, rhs *types.FieldType) *types.FieldType {
return resultFieldType
}

type caseWhenFunctionClass struct {
baseFunctionClass
}

func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (sig builtinFunc, err error) {
if err = c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
l := len(args)
// Fill in each 'THEN' clause parameter type.
fieldTps := make([]*types.FieldType, 0, (l+1)/2)
decimal, flen, isBinaryStr := args[1].GetType().Decimal, 0, false
decimal, flen, isBinaryStr, isBinaryFlag := args[1].GetType().Decimal, 0, false, false
for i := 1; i < l; i += 2 {
fieldTps = append(fieldTps, args[i].GetType())
decimal = mathutil.Max(decimal, args[i].GetType().Decimal)
flen = mathutil.Max(flen, args[i].GetType().Flen)
isBinaryStr = isBinaryStr || types.IsBinaryStr(args[i].GetType())
isBinaryFlag = isBinaryFlag || !types.IsNonBinaryStr(args[i].GetType())
}
if l%2 == 1 {
fieldTps = append(fieldTps, args[l-1].GetType())
decimal = mathutil.Max(decimal, args[l-1].GetType().Decimal)
flen = mathutil.Max(flen, args[l-1].GetType().Flen)
isBinaryStr = isBinaryStr || types.IsBinaryStr(args[l-1].GetType())
isBinaryFlag = isBinaryFlag || !types.IsNonBinaryStr(args[l-1].GetType())
}

fieldTp := types.AggFieldType(fieldTps)
Expand All @@ -163,6 +165,9 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if fieldTp.EvalType().IsStringKind() && !isBinaryStr {
fieldTp.Charset, fieldTp.Collate = mysql.DefaultCharset, mysql.DefaultCollationName
}
if isBinaryFlag {
fieldTp.Flag |= mysql.BinaryFlag
}
// Set retType to BINARY(0) if all arguments are of type NULL.
if fieldTp.Tp == mysql.TypeNull {
fieldTp.Flen, fieldTp.Decimal = 0, -1
Expand Down Expand Up @@ -395,6 +400,7 @@ func (c *ifFunctionClass) getFunction(ctx sessionctx.Context, args []Expression)
retTp := inferType4ControlFuncs(args[1].GetType(), args[2].GetType())
evalTps := retTp.EvalType()
bf := newBaseBuiltinFuncWithTp(ctx, args, evalTps, types.ETInt, evalTps, evalTps)
retTp.Flag |= bf.tp.Flag
bf.tp = retTp
switch evalTps {
case types.ETInt:
Expand Down
1 change: 1 addition & 0 deletions expression/builtin_miscellaneous.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ func (c *anyValueFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
}
argTp := args[0].GetType().EvalType()
bf := newBaseBuiltinFuncWithTp(ctx, args, argTp, argTp)
args[0].GetType().Flag |= bf.tp.Flag
*bf.tp = *args[0].GetType()
var sig builtinFunc
switch argTp {
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func reverseRunes(origin []rune) []rune {
func SetBinFlagOrBinStr(argTp *types.FieldType, resTp *types.FieldType) {
if types.IsBinaryStr(argTp) {
types.SetBinChsClnFlag(resTp)
} else if mysql.HasBinaryFlag(argTp.Flag) {
} else if mysql.HasBinaryFlag(argTp.Flag) || !types.IsNonBinaryStr(argTp) {
resTp.Flag |= mysql.BinaryFlag
}
}
Expand Down
60 changes: 27 additions & 33 deletions expression/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ func (s *testInferTypeSuite) TestInferType(c *C) {
tests = append(tests, s.createTestCase4Literals()...)
tests = append(tests, s.createTestCase4JSONFuncs()...)
tests = append(tests, s.createTestCase4MiscellaneousFunc()...)
tests = append(tests, s.createTestCase4AggregationFunc()...)

for _, tt := range tests {
ctx := testKit.Se.(sessionctx.Context)
Expand Down Expand Up @@ -205,35 +204,35 @@ func (s *testInferTypeSuite) createTestCase4Columns() []typeInferTestCase {
return []typeInferTestCase{
{"c_bit ", mysql.TypeBit, charset.CharsetBin, mysql.UnsignedFlag, 10, 0},
{"c_year ", mysql.TypeYear, charset.CharsetBin, mysql.UnsignedFlag | mysql.ZerofillFlag, 4, 0},
{"c_int_d ", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"c_uint_d ", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 10, 0},
{"c_bigint_d ", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 20, 0},
{"c_ubigint_d ", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 20, 0},
{"c_float_d ", mysql.TypeFloat, charset.CharsetBin, mysql.BinaryFlag, 12, types.UnspecifiedLength},
{"c_ufloat_d ", mysql.TypeFloat, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 12, types.UnspecifiedLength},
{"c_double_d ", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 22, types.UnspecifiedLength},
{"c_udouble_d ", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 22, types.UnspecifiedLength},
{"c_decimal ", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 6, 3}, // TODO: Flen should be 8
{"c_udecimal ", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 10, 3}, // TODO: Flen should be 11
{"c_decimal_d ", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"c_udecimal_d ", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 11, 0}, // TODO: Flen should be 10
{"c_int_d ", mysql.TypeLong, charset.CharsetBin, 0, 11, 0},
{"c_uint_d ", mysql.TypeLong, charset.CharsetBin, mysql.UnsignedFlag, 10, 0},
{"c_bigint_d ", mysql.TypeLonglong, charset.CharsetBin, 0, 20, 0},
{"c_ubigint_d ", mysql.TypeLonglong, charset.CharsetBin, mysql.UnsignedFlag, 20, 0},
{"c_float_d ", mysql.TypeFloat, charset.CharsetBin, 0, 12, types.UnspecifiedLength},
{"c_ufloat_d ", mysql.TypeFloat, charset.CharsetBin, mysql.UnsignedFlag, 12, types.UnspecifiedLength},
{"c_double_d ", mysql.TypeDouble, charset.CharsetBin, 0, 22, types.UnspecifiedLength},
{"c_udouble_d ", mysql.TypeDouble, charset.CharsetBin, mysql.UnsignedFlag, 22, types.UnspecifiedLength},
{"c_decimal ", mysql.TypeNewDecimal, charset.CharsetBin, 0, 6, 3}, // TODO: Flen should be 8
{"c_udecimal ", mysql.TypeNewDecimal, charset.CharsetBin, mysql.UnsignedFlag, 10, 3}, // TODO: Flen should be 11
{"c_decimal_d ", mysql.TypeNewDecimal, charset.CharsetBin, 0, 11, 0},
{"c_udecimal_d ", mysql.TypeNewDecimal, charset.CharsetBin, mysql.UnsignedFlag, 11, 0}, // TODO: Flen should be 10
{"c_datetime ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 22, 2},
{"c_datetime_d ", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, 0},
{"c_time ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 14, 3},
{"c_time_d ", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, 10, 0},
{"c_timestamp ", mysql.TypeTimestamp, charset.CharsetBin, mysql.NotNullFlag | mysql.BinaryFlag | mysql.TimestampFlag, 24, 4},
{"c_timestamp_d", mysql.TypeTimestamp, charset.CharsetBin, mysql.NotNullFlag | mysql.BinaryFlag | mysql.TimestampFlag, 19, 0},
{"c_char ", mysql.TypeString, charset.CharsetUTF8, 0, 20, 0},
{"c_char ", mysql.TypeString, charset.CharsetUTF8, 0, 20, 0}, // TODO: flag should be BinaryFlag
{"c_bchar ", mysql.TypeString, charset.CharsetUTF8, mysql.BinaryFlag, 20, 0},
{"c_varchar ", mysql.TypeVarchar, charset.CharsetUTF8, 0, 20, 0}, // TODO: tp should be TypeVarString
{"c_bvarchar ", mysql.TypeVarchar, charset.CharsetUTF8, mysql.BinaryFlag, 20, 0}, // TODO: tp should be TypeVarString
{"c_text_d ", mysql.TypeBlob, charset.CharsetUTF8, 0, 65535, 0}, // TODO: BlobFlag
{"c_btext_d ", mysql.TypeBlob, charset.CharsetUTF8, mysql.BinaryFlag, 65535, 0}, // TODO: BlobFlag
{"c_binary ", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, 0},
{"c_varbinary ", mysql.TypeVarchar, charset.CharsetBin, mysql.BinaryFlag, 20, 0}, // TODO: tp should be TypeVarString
{"c_blob_d ", mysql.TypeBlob, charset.CharsetBin, mysql.BinaryFlag, 65535, 0}, // TODO: BlobFlag
{"c_set ", mysql.TypeSet, charset.CharsetUTF8, 0, types.UnspecifiedLength, 0}, // TODO: SetFlag, Flen should be 5
{"c_enum ", mysql.TypeEnum, charset.CharsetUTF8, 0, types.UnspecifiedLength, 0}, // TODO: EnumFlag, Flen should be 1
{"c_varchar ", mysql.TypeVarchar, charset.CharsetUTF8, 0, 20, 0}, // TODO: BinaryFlag, tp should be TypeVarString
{"c_bvarchar ", mysql.TypeVarchar, charset.CharsetUTF8, mysql.BinaryFlag, 20, 0}, // TODO: BinaryFlag, tp should be TypeVarString
{"c_text_d ", mysql.TypeBlob, charset.CharsetUTF8, 0, 65535, 0}, // TODO: BlobFlag, BinaryFlag
{"c_btext_d ", mysql.TypeBlob, charset.CharsetUTF8, mysql.BinaryFlag, 65535, 0}, // TODO: BlobFlag, BinaryFlag
{"c_binary ", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, 0}, // TODO: BinaryFlag
{"c_varbinary ", mysql.TypeVarchar, charset.CharsetBin, mysql.BinaryFlag, 20, 0}, // TODO: BinaryFlag, tp should be TypeVarString
{"c_blob_d ", mysql.TypeBlob, charset.CharsetBin, mysql.BinaryFlag, 65535, 0}, // TODO: BlobFlag, BinaryFlag
{"c_set ", mysql.TypeSet, charset.CharsetUTF8, 0, types.UnspecifiedLength, 0}, // TODO: SetFlag, BinaryFlag, Flen should be 5
{"c_enum ", mysql.TypeEnum, charset.CharsetUTF8, 0, types.UnspecifiedLength, 0}, // TODO: EnumFlag, BinaryFlag, Flen should be 1
}
}

Expand Down Expand Up @@ -795,20 +794,20 @@ func (s *testInferTypeSuite) createTestCase4ControlFuncs() []typeInferTestCase {
return []typeInferTestCase{
{"ifnull(c_int_d, c_int_d)", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"ifnull(c_int_d, c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 14, 3},
{"ifnull(c_int_d, c_char)", mysql.TypeString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength},
{"ifnull(c_int_d, c_char)", mysql.TypeString, charset.CharsetUTF8, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"ifnull(c_int_d, c_binary)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"ifnull(c_char, c_binary)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"ifnull(null, null)", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag, 0, types.UnspecifiedLength},
{"ifnull(c_double_d, c_timestamp_d)", mysql.TypeVarchar, charset.CharsetUTF8, mysql.NotNullFlag, 22, types.UnspecifiedLength},
{"ifnull(c_json, c_decimal)", mysql.TypeLongBlob, charset.CharsetUTF8, 0, math.MaxUint32, types.UnspecifiedLength},
{"if(c_int_d, c_decimal, c_int_d)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 14, 3},
{"if(c_int_d, c_char, c_int_d)", mysql.TypeString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength},
{"if(c_int_d, c_char, c_int_d)", mysql.TypeString, charset.CharsetUTF8, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"if(c_int_d, c_binary, c_int_d)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"if(c_int_d, c_bchar, c_int_d)", mysql.TypeString, charset.CharsetUTF8, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"if(c_int_d, c_char, c_decimal)", mysql.TypeString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength},
{"if(c_int_d, c_char, c_decimal)", mysql.TypeString, charset.CharsetUTF8, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"if(c_int_d, c_datetime, c_int_d)", mysql.TypeVarchar, charset.CharsetUTF8, 0, 22, types.UnspecifiedLength},
{"if(c_int_d, c_int_d, c_double_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 22, types.UnspecifiedLength},
{"if(c_int_d, c_time_d, c_datetime)", mysql.TypeDatetime, charset.CharsetUTF8, 0, 22, 2},
{"if(c_int_d, c_time_d, c_datetime)", mysql.TypeDatetime, charset.CharsetUTF8, mysql.BinaryFlag, 22, 2}, // TODO: should not be BinaryFlag
{"if(c_int_d, c_time, c_json)", mysql.TypeLongBlob, charset.CharsetUTF8, 0, math.MaxUint32, types.UnspecifiedLength},
{"if(null, null, null)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 0, 0},
{"case when c_int_d then c_char else c_varchar end", mysql.TypeVarchar, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength},
Expand All @@ -835,6 +834,7 @@ func (s *testInferTypeSuite) createTestCase4Aggregations() []typeInferTestCase {
{"avg(1.0)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 5},
{"avg(1.2e2)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},
{"avg(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0},
{"group_concat(c_int_d)", mysql.TypeVarString, charset.CharsetUTF8, 0, mysql.MaxBlobWidth, 0},
}
}

Expand Down Expand Up @@ -1951,9 +1951,3 @@ func (s *testInferTypeSuite) createTestCase4MiscellaneousFunc() []typeInferTestC
{"release_lock(c_text_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
}
}

func (s *testInferTypeSuite) createTestCase4AggregationFunc() []typeInferTestCase {
return []typeInferTestCase{
{"group_concat(c_int_d)", mysql.TypeVarString, charset.CharsetUTF8, 0, mysql.MaxBlobWidth, 0},
}
}
4 changes: 2 additions & 2 deletions parser/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -3485,10 +3485,10 @@ SumExpr:
args := []ast.ExprNode{ast.NewValueExpr(1)}
$$ = &ast.AggregateFuncExpr{F: $1, Args: args}
}
| builtinGroupConcat '(' BuggyDefaultFalseDistinctOpt ExpressionList OptGConcatSeparator ')'
| builtinGroupConcat '(' BuggyDefaultFalseDistinctOpt ExpressionList OrderByOptional OptGConcatSeparator ')'
{
args := $4.([]ast.ExprNode)
args = append(args, $5.(ast.ExprNode))
args = append(args, $6.(ast.ExprNode))
$$ = &ast.AggregateFuncExpr{F: $1, Args: args, Distinct: $3.(bool)}
}
| builtinMax '(' BuggyDefaultFalseDistinctOpt Expression ')'
Expand Down
1 change: 1 addition & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1214,6 +1214,7 @@ func (s *testParserSuite) TestBuiltin(c *C) {
{`select group_concat(c2,c1 SEPARATOR ';') from t group by c1;`, true},
{`select group_concat(distinct c2,c1) from t group by c1;`, true},
{`select group_concat(distinctrow c2,c1) from t group by c1;`, true},
{`SELECT student_name, GROUP_CONCAT(DISTINCT test_score ORDER BY test_score DESC SEPARATOR ' ') FROM student GROUP BY student_name;`, true},

// for encryption and compression functions
{`select AES_ENCRYPT('text',UNHEX('F3229A0B371ED2D9441B830D21A390C3'))`, true},
Expand Down
10 changes: 5 additions & 5 deletions plan/logical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ func (s *testPlanSuite) TestPredicatePushDown(c *C) {
},
{
sql: "select a, d from (select * from t union all select * from t union all select * from t) z where a < 10",
best: "UnionAll{DataScan(t)->Sel([lt(cast(test.t.a), 10)])->Projection->Projection->DataScan(t)->Sel([lt(cast(test.t.a), 10)])->Projection->Projection->DataScan(t)->Sel([lt(cast(test.t.a), 10)])->Projection->Projection}->Projection",
best: "UnionAll{DataScan(t)->Projection->DataScan(t)->Projection->DataScan(t)->Projection}->Projection",
},
{
sql: "select (select count(*) from t where t.a = k.a) from t k",
Expand Down Expand Up @@ -729,7 +729,7 @@ func (s *testPlanSuite) TestEagerAggregation(c *C) {
},
{
sql: "select sum(c1) from (select c c1, d c2 from t a union all select a c1, b c2 from t b union all select b c1, e c2 from t c) x group by c2",
best: "UnionAll{DataScan(a)->Projection->Aggr(sum(cast(a.c1)),firstrow(cast(a.c2)))->DataScan(b)->Projection->Aggr(sum(cast(b.c1)),firstrow(cast(b.c2)))->DataScan(c)->Projection->Aggr(sum(cast(c.c1)),firstrow(c.c2))}->Aggr(sum(join_agg_0))->Projection",
best: "UnionAll{DataScan(a)->Aggr(sum(a.c),firstrow(a.d))->DataScan(b)->Aggr(sum(b.a),firstrow(b.b))->DataScan(c)->Aggr(sum(c.b),firstrow(c.e))}->Aggr(sum(join_agg_0))->Projection",
},
{
sql: "select max(a.b), max(b.b) from t a join t b on a.c = b.c group by a.a",
Expand All @@ -741,7 +741,7 @@ func (s *testPlanSuite) TestEagerAggregation(c *C) {
},
{
sql: "select max(c.b) from (select * from t a union all select * from t b) c group by c.a",
best: "UnionAll{DataScan(a)->Projection->Aggr(max(cast(a.b)),firstrow(cast(a.a)))->DataScan(b)->Projection->Aggr(max(cast(b.b)),firstrow(cast(b.a)))}->Aggr(max(join_agg_0))->Projection",
best: "UnionAll{DataScan(a)->Projection->DataScan(b)->Projection}->Projection->Projection",
},
{
sql: "select max(a.c) from t a join t b on a.a=b.a and a.b=b.b group by a.b",
Expand Down Expand Up @@ -1505,12 +1505,12 @@ func (s *testPlanSuite) TestTopNPushDown(c *C) {
// Test TopN + UA + Proj.
{
sql: "select * from t union all (select * from t s) order by a,b limit 5",
best: "UnionAll{DataScan(t)->TopN([cast(test.t.a) cast(test.t.b)],0,5)->Projection->DataScan(s)->TopN([cast(s.a) cast(s.b)],0,5)->Projection}->TopN([t.a t.b],0,5)",
best: "UnionAll{DataScan(t)->TopN([test.t.a test.t.b],0,5)->Projection->DataScan(s)->TopN([s.a s.b],0,5)->Projection}->TopN([t.a t.b],0,5)",
},
// Test TopN + UA + Proj.
{
sql: "select * from t union all (select * from t s) order by a,b limit 5, 5",
best: "UnionAll{DataScan(t)->TopN([cast(test.t.a) cast(test.t.b)],0,10)->Projection->DataScan(s)->TopN([cast(s.a) cast(s.b)],0,10)->Projection}->TopN([t.a t.b],5,5)",
best: "UnionAll{DataScan(t)->TopN([test.t.a test.t.b],0,10)->Projection->DataScan(s)->TopN([s.a s.b],0,10)->Projection}->TopN([t.a t.b],5,5)",
},
// Test Limit + UA + Proj + Sort.
{
Expand Down
8 changes: 4 additions & 4 deletions plan/physical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -616,22 +616,22 @@ func (s *testPlanSuite) TestDAGPlanBuilderUnion(c *C) {
// Test simple union.
{
sql: "select * from t union all select * from t",
best: "UnionAll{TableReader(Table(t))->Projection->TableReader(Table(t))->Projection}",
best: "UnionAll{TableReader(Table(t))->TableReader(Table(t))}",
},
// Test Order by + Union.
{
sql: "select * from t union all (select * from t) order by a ",
best: "UnionAll{TableReader(Table(t))->Projection->TableReader(Table(t))->Projection}->Sort",
best: "UnionAll{TableReader(Table(t))->TableReader(Table(t))}->Sort",
},
// Test Limit + Union.
{
sql: "select * from t union all (select * from t) limit 1",
best: "UnionAll{TableReader(Table(t)->Limit)->Projection->TableReader(Table(t)->Limit)->Projection}->Limit",
best: "UnionAll{TableReader(Table(t)->Limit)->TableReader(Table(t)->Limit)}->Limit",
},
// Test TopN + Union.
{
sql: "select a from t union all (select c from t) order by a limit 1",
best: "UnionAll{TableReader(Table(t))->Projection->TableReader(Table(t))->Projection}->TopN([t.a],0,1)",
best: "UnionAll{TableReader(Table(t)->Limit)->IndexReader(Index(t.c_d_e)[[<nil>,+inf]]->Limit)}->TopN([t.a],0,1)",
},
}
for _, tt := range tests {
Expand Down
2 changes: 1 addition & 1 deletion types/field_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ func NewFieldType(tp byte) *FieldType {

// Equal checks whether two FieldType objects are equal.
func (ft *FieldType) Equal(other *FieldType) bool {
// We do not need to compare `ft.Flag == other.Flag` when wrapping cast upon an Expression.
partialEqual := ft.Tp == other.Tp &&
ft.Flag == other.Flag &&
ft.Flen == other.Flen &&
ft.Decimal == other.Decimal &&
ft.Charset == other.Charset &&
Expand Down

0 comments on commit 5445e17

Please sign in to comment.