diff --git a/cmd/explaintest/r/partition_pruning.result b/cmd/explaintest/r/partition_pruning.result index 651e590accb8d..24af08b1345da 100644 --- a/cmd/explaintest/r/partition_pruning.result +++ b/cmd/explaintest/r/partition_pruning.result @@ -3608,6 +3608,11 @@ id count task operator info TableReader_8 10.00 root data:Selection_7 └─Selection_7 10.00 cop eq(test.t2.a, 833) └─TableScan_6 10000.00 cop table:t2, partition:p4, range:[-inf,+inf], keep order:false, stats:pseudo +explain select * from t2 where a in (10,20,30); +id count task operator info +TableReader_8 30.00 root data:Selection_7 +└─Selection_7 30.00 cop in(test.t2.a, 10, 20, 30) + └─TableScan_6 10000.00 cop table:t2, partition:p0, range:[-inf,+inf], keep order:false, stats:pseudo explain select * from t2 where (a = 100 OR a = 900); id count task operator info Union_8 40.00 root diff --git a/cmd/explaintest/t/partition_pruning.test b/cmd/explaintest/t/partition_pruning.test index 6e9eeffce6632..1d55a527d168f 100644 --- a/cmd/explaintest/t/partition_pruning.test +++ b/cmd/explaintest/t/partition_pruning.test @@ -766,6 +766,7 @@ explain select * from t2; explain select * from t2 where a = 101; explain select * from t2 where a = 550; explain select * from t2 where a = 833; +explain select * from t2 where a in (10,20,30); explain select * from t2 where (a = 100 OR a = 900); explain select * from t2 where (a > 100 AND a < 600); explain select * from t2 where b = 4; @@ -1015,4 +1016,3 @@ explain select * from t where ts <= '2020-04-14 23:59:59.123' -- p0,p1; explain select * from t where ts <= '2020-04-25 00:00:00' -- p0,p1,p2; explain select * from t where ts > '2020-04-25 00:00:00' or ts < '2020-01-02 00:00:00' -- p0; explain select * from t where ts > '2020-04-02 00:00:00' and ts < '2020-04-07 00:00:00' -- p0,p1; - diff --git a/expression/simple_rewriter.go b/expression/simple_rewriter.go index 045d22fe1265f..1f4b1e5d31f09 100644 --- a/expression/simple_rewriter.go +++ b/expression/simple_rewriter.go @@ -14,6 +14,8 @@ package expression import ( + "context" + "github.com/pingcap/errors" "github.com/pingcap/parser" "github.com/pingcap/parser/ast" @@ -38,7 +40,16 @@ type simpleRewriter struct { // The expression string must only reference the column in table Info. func ParseSimpleExprWithTableInfo(ctx sessionctx.Context, exprStr string, tableInfo *model.TableInfo) (Expression, error) { exprStr = "select " + exprStr - stmts, warns, err := parser.New().Parse(exprStr, "", "") + var stmts []ast.StmtNode + var err error + var warns []error + if p, ok := ctx.(interface { + ParseSQL(context.Context, string, string, string) ([]ast.StmtNode, []error, error) + }); ok { + stmts, warns, err = p.ParseSQL(context.Background(), exprStr, "", "") + } else { + stmts, warns, err = parser.New().Parse(exprStr, "", "") + } for _, warn := range warns { ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) } @@ -76,7 +87,16 @@ func RewriteSimpleExprWithTableInfo(ctx sessionctx.Context, tbl *model.TableInfo // The expression string must only reference the column in the given schema. func ParseSimpleExprsWithSchema(ctx sessionctx.Context, exprStr string, schema *Schema) ([]Expression, error) { exprStr = "select " + exprStr - stmts, warns, err := parser.New().Parse(exprStr, "", "") + var stmts []ast.StmtNode + var err error + var warns []error + if p, ok := ctx.(interface { + ParseSQL(context.Context, string, string, string) ([]ast.StmtNode, []error, error) + }); ok { + stmts, warns, err = p.ParseSQL(context.Background(), exprStr, "", "") + } else { + stmts, warns, err = parser.New().Parse(exprStr, "", "") + } for _, warn := range warns { ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) } diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 1a055f4fc9395..3d91d6e17af7b 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -255,6 +255,29 @@ func (s *testIntegrationSuite) TestPartitionTableStats(c *C) { } } +func (s *testIntegrationSuite) TestPartitionPruningForInExpr(c *C) { + tk := testkit.NewTestKit(c, s.store) + + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int(11), b int) partition by range (a) (partition p0 values less than (4), partition p1 values less than(10), partition p2 values less than maxvalue);") + tk.MustExec("insert into t values (1, 1),(10, 10),(11, 11)") + + var input []string + var output []struct { + SQL string + Plan []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + }) + tk.MustQuery(tt).Check(testkit.Rows(output[i].Plan...)) + } +} + func (s *testIntegrationSuite) TestErrNoDB(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("create user test") diff --git a/planner/core/rule_partition_processor.go b/planner/core/rule_partition_processor.go index 888e7c094e496..bf41c2e38ddee 100644 --- a/planner/core/rule_partition_processor.go +++ b/planner/core/rule_partition_processor.go @@ -565,6 +565,9 @@ func partitionRangeForExpr(sctx sessionctx.Context, expr expression.Expression, args := op.GetArgs() newRange := partitionRangeForOrExpr(sctx, args[0], args[1], lessThan, col, partFn) return result.intersection(newRange) + } else if op.FuncName.L == ast.In { + newRange := partitionRangeForInExpr(sctx, op.GetArgs(), lessThan, col, partFn) + return result.intersection(newRange) } } @@ -587,6 +590,39 @@ func partitionRangeForOrExpr(sctx sessionctx.Context, expr1, expr2 expression.Ex return tmp1.union(tmp2) } +func partitionRangeForInExpr(sctx sessionctx.Context, args []expression.Expression, + lessThan lessThanData, partCol *expression.Column, partFn *expression.ScalarFunction) partitionRangeOR { + col, ok := args[0].(*expression.Column) + if !ok || col.ID != partCol.ID { + return fullRange(lessThan.length()) + } + + var result partitionRangeOR + unsigned := mysql.HasUnsignedFlag(col.RetType.Flag) + for i := 1; i < len(args); i++ { + constExpr, ok := args[i].(*expression.Constant) + if !ok { + return fullRange(lessThan.length()) + } + switch constExpr.Value.Kind() { + case types.KindInt64, types.KindUint64: + case types.KindNull: + result = append(result, partitionRange{0, 1}) + continue + default: + return fullRange(lessThan.length()) + } + val, err := constExpr.Value.ToInt64(sctx.GetSessionVars().StmtCtx) + if err != nil { + return fullRange(lessThan.length()) + } + + start, end := pruneUseBinarySearch(lessThan, dataForPrune{op: ast.EQ, c: val}, unsigned) + result = append(result, partitionRange{start, end}) + } + return result.simplify() +} + // monotoneIncFuncs are those functions that for any x y, if x > y => f(x) > f(y) var monotoneIncFuncs = map[string]struct{}{ ast.ToDays: {}, diff --git a/planner/core/testdata/integration_suite_in.json b/planner/core/testdata/integration_suite_in.json index 79872319d71f9..a7aca433780e3 100644 --- a/planner/core/testdata/integration_suite_in.json +++ b/planner/core/testdata/integration_suite_in.json @@ -40,5 +40,17 @@ "explain select * from t order by a limit 3", "select * from t order by a limit 3" ] + }, + { + "name": "TestPartitionPruningForInExpr", + "cases": [ + "explain select * from t where a in (1, 2,'11')", + "explain select * from t where a in (17, null)", + "explain select * from t where a in (16, 'abc')", + "explain select * from t where a in (15, 0.12, 3.47)", + "explain select * from t where a in (0.12, 3.47)", + "explain select * from t where a in (14, floor(3.47))", + "explain select * from t where b in (3, 4)" + ] } ] diff --git a/planner/core/testdata/integration_suite_out.json b/planner/core/testdata/integration_suite_out.json index 4d881172525d9..47e38927d2bd1 100644 --- a/planner/core/testdata/integration_suite_out.json +++ b/planner/core/testdata/integration_suite_out.json @@ -111,5 +111,94 @@ ] } ] + }, + { + "Name": "TestPartitionPruningForInExpr", + "Cases": [ + { + "SQL": "explain select * from t where a in (1, 2,'11')", + "Plan": [ + "Union_8 60.00 root ", + "├─TableReader_11 30.00 root data:Selection_10", + "│ └─Selection_10 30.00 cop in(test.t.a, 1, 2, 11)", + "│ └─TableScan_9 10000.00 cop table:t, partition:p0, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_14 30.00 root data:Selection_13", + " └─Selection_13 30.00 cop in(test.t.a, 1, 2, 11)", + " └─TableScan_12 10000.00 cop table:t, partition:p2, range:[-inf,+inf], keep order:false, stats:pseudo" + ] + }, + { + "SQL": "explain select * from t where a in (17, null)", + "Plan": [ + "Union_8 20.00 root ", + "├─TableReader_11 10.00 root data:Selection_10", + "│ └─Selection_10 10.00 cop in(test.t.a, 17, NULL)", + "│ └─TableScan_9 10000.00 cop table:t, partition:p0, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_14 10.00 root data:Selection_13", + " └─Selection_13 10.00 cop in(test.t.a, 17, NULL)", + " └─TableScan_12 10000.00 cop table:t, partition:p2, range:[-inf,+inf], keep order:false, stats:pseudo" + ] + }, + { + "SQL": "explain select * from t where a in (16, 'abc')", + "Plan": [ + "Union_8 40.00 root ", + "├─TableReader_11 20.00 root data:Selection_10", + "│ └─Selection_10 20.00 cop in(test.t.a, 16, 0)", + "│ └─TableScan_9 10000.00 cop table:t, partition:p0, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_14 20.00 root data:Selection_13", + " └─Selection_13 20.00 cop in(test.t.a, 16, 0)", + " └─TableScan_12 10000.00 cop table:t, partition:p2, range:[-inf,+inf], keep order:false, stats:pseudo" + ] + }, + { + "SQL": "explain select * from t where a in (15, 0.12, 3.47)", + "Plan": [ + "Union_9 30.00 root ", + "├─TableReader_12 10.00 root data:Selection_11", + "│ └─Selection_11 10.00 cop or(eq(test.t.a, 15), 0)", + "│ └─TableScan_10 10000.00 cop table:t, partition:p0, range:[-inf,+inf], keep order:false, stats:pseudo", + "├─TableReader_15 10.00 root data:Selection_14", + "│ └─Selection_14 10.00 cop or(eq(test.t.a, 15), 0)", + "│ └─TableScan_13 10000.00 cop table:t, partition:p1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_18 10.00 root data:Selection_17", + " └─Selection_17 10.00 cop or(eq(test.t.a, 15), 0)", + " └─TableScan_16 10000.00 cop table:t, partition:p2, range:[-inf,+inf], keep order:false, stats:pseudo" + ] + }, + { + "SQL": "explain select * from t where a in (0.12, 3.47)", + "Plan": [ + "TableDual_6 0.00 root rows:0" + ] + }, + { + "SQL": "explain select * from t where a in (14, floor(3.47))", + "Plan": [ + "Union_8 40.00 root ", + "├─TableReader_11 20.00 root data:Selection_10", + "│ └─Selection_10 20.00 cop in(test.t.a, 14, 3)", + "│ └─TableScan_9 10000.00 cop table:t, partition:p0, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_14 20.00 root data:Selection_13", + " └─Selection_13 20.00 cop in(test.t.a, 14, 3)", + " └─TableScan_12 10000.00 cop table:t, partition:p2, range:[-inf,+inf], keep order:false, stats:pseudo" + ] + }, + { + "SQL": "explain select * from t where b in (3, 4)", + "Plan": [ + "Union_9 60.00 root ", + "├─TableReader_12 20.00 root data:Selection_11", + "│ └─Selection_11 20.00 cop in(test.t.b, 3, 4)", + "│ └─TableScan_10 10000.00 cop table:t, partition:p0, range:[-inf,+inf], keep order:false, stats:pseudo", + "├─TableReader_15 20.00 root data:Selection_14", + "│ └─Selection_14 20.00 cop in(test.t.b, 3, 4)", + "│ └─TableScan_13 10000.00 cop table:t, partition:p1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_18 20.00 root data:Selection_17", + " └─Selection_17 20.00 cop in(test.t.b, 3, 4)", + " └─TableScan_16 10000.00 cop table:t, partition:p2, range:[-inf,+inf], keep order:false, stats:pseudo" + ] + } + ] } ]