diff --git a/expression/builtin_string.go b/expression/builtin_string.go index d07536c42bee5..749d3453d0775 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -180,6 +180,11 @@ func SetBinFlagOrBinStr(argTp *types.FieldType, resTp *types.FieldType) { } } +// addBinFlag add the binary flag to `tp` if its charset is binary +func addBinFlag(tp *types.FieldType) { + SetBinFlagOrBinStr(tp, tp) +} + type lengthFunctionClass struct { baseFunctionClass } @@ -275,10 +280,10 @@ func (c *concatFunctionClass) getFunction(ctx sessionctx.Context, args []Express if err != nil { return nil, err } + addBinFlag(bf.tp) bf.tp.Flen = 0 for i := range args { argType := args[i].GetType() - SetBinFlagOrBinStr(argType, bf.tp) if argType.Flen < 0 { bf.tp.Flen = mysql.MaxBlobWidth @@ -350,9 +355,9 @@ func (c *concatWSFunctionClass) getFunction(ctx sessionctx.Context, args []Expre } bf.tp.Flen = 0 + addBinFlag(bf.tp) for i := range args { argType := args[i].GetType() - SetBinFlagOrBinStr(argType, bf.tp) // skip separator param if i != 0 { @@ -2000,8 +2005,7 @@ func (c *lpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio return nil, err } bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx, args[1]) - SetBinFlagOrBinStr(args[0].GetType(), bf.tp) - SetBinFlagOrBinStr(args[2].GetType(), bf.tp) + addBinFlag(bf.tp) valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) @@ -2133,8 +2137,7 @@ func (c *rpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio return nil, err } bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx, args[1]) - SetBinFlagOrBinStr(args[0].GetType(), bf.tp) - SetBinFlagOrBinStr(args[2].GetType(), bf.tp) + addBinFlag(bf.tp) valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) @@ -2668,9 +2671,7 @@ func (c *makeSetFunctionClass) getFunction(ctx sessionctx.Context, args []Expres if err != nil { return nil, err } - for i, length := 0, len(args); i < length; i++ { - SetBinFlagOrBinStr(args[i].GetType(), bf.tp) - } + addBinFlag(bf.tp) bf.tp.Flen = c.getFlen(bf.ctx, args) if bf.tp.Flen > mysql.MaxBlobWidth { bf.tp.Flen = mysql.MaxBlobWidth @@ -3589,8 +3590,7 @@ func (c *insertFunctionClass) getFunction(ctx sessionctx.Context, args []Express return nil, err } bf.tp.Flen = mysql.MaxBlobWidth - SetBinFlagOrBinStr(args[0].GetType(), bf.tp) - SetBinFlagOrBinStr(args[3].GetType(), bf.tp) + addBinFlag(bf.tp) valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 8a46b32664c24..d1d5d83210d48 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -240,8 +240,14 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"strcmp(c_char, c_char)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 2, 0}, {"space(c_int_d)", mysql.TypeLongBlob, mysql.DefaultCharset, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"CONCAT(c_binary, c_int_d)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 40, types.UnspecifiedLength}, +<<<<<<< HEAD {"CONCAT(c_bchar, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.BinaryFlag, 40, types.UnspecifiedLength}, {"CONCAT('T', 'i', 'DB')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 4, types.UnspecifiedLength}, +======= + {"CONCAT(c_bchar, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"CONCAT(c_bchar, 0x80)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 23, types.UnspecifiedLength}, + {"CONCAT('T', 'i', 'DB')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0 | mysql.NotNullFlag, 4, types.UnspecifiedLength}, +>>>>>>> 7229416ba... functions: fix some string function has wrong collation and flag (#23835) {"CONCAT('T', 'i', 'DB', c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 24, types.UnspecifiedLength}, {"CONCAT_WS('-', 'T', 'i', 'DB')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 6, types.UnspecifiedLength}, {"CONCAT_WS(',', 'TiDB', c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 25, types.UnspecifiedLength}, @@ -451,8 +457,9 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"find_in_set(c_set , c_text_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, {"find_in_set(c_enum , c_text_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, - {"make_set(c_int_d , c_text_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.BinaryFlag, 65535, types.UnspecifiedLength}, + {"make_set(c_int_d , c_text_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 65535, types.UnspecifiedLength}, {"make_set(c_bigint_d , c_text_d, c_binary)", mysql.TypeMediumBlob, charset.CharsetBin, mysql.BinaryFlag, 65556, types.UnspecifiedLength}, + {"make_set(1 , c_text_d, 0x40)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 65535, types.UnspecifiedLength}, {"quote(c_int_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 42, types.UnspecifiedLength}, {"quote(c_bigint_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 42, types.UnspecifiedLength}, @@ -465,6 +472,7 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"convert(c_text_d using 'binary')", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"insert(c_varchar, c_int_d, c_int_d, c_varchar)", mysql.TypeLongBlob, charset.CharsetUTF8MB4, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, + {"insert(c_varchar, c_int_d, c_int_d, 0x40)", mysql.TypeLongBlob, charset.CharsetUTF8MB4, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"insert(c_varchar, c_int_d, c_int_d, c_binary)", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"insert(c_binary, c_int_d, c_int_d, c_varchar)", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"insert(c_binary, c_int_d, c_int_d, c_binary)", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength},