From 6ba84eef8ccc6bbcdd2d84a6528efb8e2ace50a0 Mon Sep 17 00:00:00 2001 From: Shenghui Wu <793703860@qq.com> Date: Thu, 3 Mar 2022 14:33:46 +0800 Subject: [PATCH] expression: fine-grained precision infer for decimal arithmetic operator (#32401) close pingcap/tidb#30961 --- expression/builtin_arithmetic.go | 2 +- expression/builtin_arithmetic_test.go | 4 ++-- expression/typeinfer_test.go | 18 +++++++++--------- .../core/testdata/integration_suite_out.json | 4 ++-- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/expression/builtin_arithmetic.go b/expression/builtin_arithmetic.go index e76bec6c157ec..681cf34a84c0a 100644 --- a/expression/builtin_arithmetic.go +++ b/expression/builtin_arithmetic.go @@ -111,7 +111,7 @@ func setFlenDecimal4RealOrDecimal(ctx sessionctx.Context, retTp *types.FieldType if isMultiply { digitsInt = a.Flen - a.Decimal + b.Flen - b.Decimal } - retTp.Flen = digitsInt + retTp.Decimal + 3 + retTp.Flen = digitsInt + retTp.Decimal + 1 if isReal { retTp.Flen = mathutil.Min(retTp.Flen, mysql.MaxRealWidth) return diff --git a/expression/builtin_arithmetic_test.go b/expression/builtin_arithmetic_test.go index d2426d6f40243..120ce08754cb1 100644 --- a/expression/builtin_arithmetic_test.go +++ b/expression/builtin_arithmetic_test.go @@ -41,7 +41,7 @@ func TestSetFlenDecimal4RealOrDecimal(t *testing.T) { } setFlenDecimal4RealOrDecimal(mock.NewContext(), ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false) require.Equal(t, 1, ret.Decimal) - require.Equal(t, 6, ret.Flen) + require.Equal(t, 4, ret.Flen) b.Flen = 65 setFlenDecimal4RealOrDecimal(mock.NewContext(), ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false) @@ -72,7 +72,7 @@ func TestSetFlenDecimal4RealOrDecimal(t *testing.T) { } setFlenDecimal4RealOrDecimal(mock.NewContext(), ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true) require.Equal(t, 1, ret.Decimal) - require.Equal(t, 8, ret.Flen) + require.Equal(t, 6, ret.Flen) b.Flen = 65 setFlenDecimal4RealOrDecimal(mock.NewContext(), ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true) diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index a0f0ddeb44e1e..5ebf61ca8b5e5 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -717,9 +717,9 @@ func (s *InferTypeSuite) createTestCase4ArithmeticFuncs() []typeInferTestCase { {"c_int_d + c_char", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_int_d + c_time_d", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, {"c_int_d + c_double_d", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, - {"c_int_d + c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 26, 3}, - {"c_datetime + c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 26, 3}, - {"c_bigint_d + c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 26, 3}, + {"c_int_d + c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 24, 3}, + {"c_datetime + c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 24, 3}, + {"c_bigint_d + c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 24, 3}, {"c_double_d + c_decimal", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_double_d + c_char", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_double_d + c_enum", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, @@ -729,9 +729,9 @@ func (s *InferTypeSuite) createTestCase4ArithmeticFuncs() []typeInferTestCase { {"c_int_d - c_char", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_int_d - c_time_d", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, {"c_int_d - c_double_d", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, - {"c_int_d - c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 26, 3}, - {"c_datetime - c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 26, 3}, - {"c_bigint_d - c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 26, 3}, + {"c_int_d - c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 24, 3}, + {"c_datetime - c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 24, 3}, + {"c_bigint_d - c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 24, 3}, {"c_double_d - c_decimal", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_double_d - c_char", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_double_d - c_enum", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, @@ -741,9 +741,9 @@ func (s *InferTypeSuite) createTestCase4ArithmeticFuncs() []typeInferTestCase { {"c_int_d * c_char", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_int_d * c_time_d", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, {"c_int_d * c_double_d", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, - {"c_int_d * c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 29, 3}, - {"c_datetime * c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 31, 5}, - {"c_bigint_d * c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 29, 3}, + {"c_int_d * c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 27, 3}, + {"c_datetime * c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 29, 5}, + {"c_bigint_d * c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 27, 3}, {"c_double_d * c_decimal", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_double_d * c_char", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_double_d * c_enum", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, diff --git a/planner/core/testdata/integration_suite_out.json b/planner/core/testdata/integration_suite_out.json index fe09f26006537..6d3a10ddf9fa9 100644 --- a/planner/core/testdata/integration_suite_out.json +++ b/planner/core/testdata/integration_suite_out.json @@ -4770,8 +4770,8 @@ " └─ExchangeSender 12500.00 cop[tiflash] ExchangeType: PassThrough", " └─HashJoin 12500.00 cop[tiflash] inner join, equal:[eq(Column#13, Column#14) eq(Column#15, Column#16)]", " ├─ExchangeReceiver(Build) 10000.00 cop[tiflash] ", - " │ └─ExchangeSender 10000.00 cop[tiflash] ExchangeType: HashPartition, Hash Cols: [name: Column#21, collate: binary], [name: Column#15, collate: binary]", - " │ └─Projection 10000.00 cop[tiflash] test.t.c1, test.t.c2, Column#13, Column#15, cast(Column#13, decimal(15,8) BINARY)->Column#21", + " │ └─ExchangeSender 10000.00 cop[tiflash] ExchangeType: HashPartition, Hash Cols: [name: Column#21, collate: binary], [name: Column#22, collate: binary]", + " │ └─Projection 10000.00 cop[tiflash] test.t.c1, test.t.c2, Column#13, Column#15, cast(Column#13, decimal(13,8) BINARY)->Column#21, cast(Column#15, decimal(10,5) BINARY)->Column#22", " │ └─Projection 10000.00 cop[tiflash] test.t.c1, test.t.c2, mul(test.t.c1, 3)->Column#13, plus(test.t.c1, 1)->Column#15", " │ └─TableFullScan 10000.00 cop[tiflash] table:t1 keep order:false, stats:pseudo", " └─ExchangeReceiver(Probe) 10000.00 cop[tiflash] ",