From e0908fda7f4509092a3d4ce570c0bd02fb003f65 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Thu, 15 Dec 2022 17:45:39 +0800 Subject: [PATCH 01/25] multi distinct agg initial commit Signed-off-by: AilinKid <314806019@qq.com> --- executor/partition_table.go | 2 + expression/aggregation/descriptor.go | 2 + expression/constant.go | 12 + expression/grouping_sets.go | 327 +++++++++++++ expression/grouping_sets_test.go | 142 ++++++ go.mod | 1 + planner/core/casetest/enforce_mpp_test.go | 58 +++ .../testdata/enforce_mpp_suite_in.json | 19 + .../testdata/enforce_mpp_suite_out.json | 188 ++++++++ planner/core/exhaust_physical_plans.go | 3 + planner/core/explain.go | 12 + planner/core/find_best_task.go | 2 +- planner/core/fragment.go | 5 +- planner/core/physical_plans.go | 59 ++- planner/core/plan_to_pb.go | 23 + planner/core/resolve_indices.go | 27 +- planner/core/task.go | 448 +++++++++++++++--- store/copr/mpp.go | 2 +- store/mockstore/unistore/cophandler/mpp.go | 82 +++- .../mockstore/unistore/cophandler/mpp_exec.go | 137 +++++- store/mockstore/unistore/rpc.go | 2 + util/plancodec/id.go | 7 + 22 files changed, 1463 insertions(+), 97 deletions(-) create mode 100644 expression/grouping_sets.go create mode 100644 expression/grouping_sets_test.go diff --git a/executor/partition_table.go b/executor/partition_table.go index de6f17e8d2cc0..7397275aec539 100644 --- a/executor/partition_table.go +++ b/executor/partition_table.go @@ -56,6 +56,8 @@ func updateExecutorTableID(ctx context.Context, exec *tipb.Executor, recursive b child = exec.Window.Child case tipb.ExecType_TypeSort: child = exec.Sort.Child + case tipb.ExecType_TypeRepeatSource: + child = exec.RepeatSource.Child default: return errors.Trace(fmt.Errorf("unknown new tipb protocol %d", exec.Tp)) } diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index 2af7d4d49f348..60d23072cd399 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -42,6 +42,8 @@ type AggFuncDesc struct { HasDistinct bool // OrderByItems represents the order by clause used in GROUP_CONCAT OrderByItems []*util.ByItems + // Grouping Set ID, for distinguishing with not-set 0, starting from 1. + GroupingID int64 } // NewAggFuncDesc creates an aggregation function signature descriptor. diff --git a/expression/constant.go b/expression/constant.go index db9f94147f38a..cb2b5e5aac2e4 100644 --- a/expression/constant.go +++ b/expression/constant.go @@ -52,6 +52,18 @@ func NewZero() *Constant { } } +// NewNumber stands for constant of a given number. +func NewNumber(num int64) *Constant { + retT := types.NewFieldType(mysql.TypeLonglong) + retT.AddFlag(mysql.UnsignedFlag) // shrink range to avoid integral promotion + retT.SetFlen(mysql.MaxIntWidth) + retT.SetDecimal(0) + return &Constant{ + Value: types.NewDatum(num), + RetType: retT, + } +} + // NewNull stands for null constant. func NewNull() *Constant { retT := types.NewFieldType(mysql.TypeTiny) diff --git a/expression/grouping_sets.go b/expression/grouping_sets.go new file mode 100644 index 0000000000000..28a24928087ef --- /dev/null +++ b/expression/grouping_sets.go @@ -0,0 +1,327 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "strings" + + "github.com/pingcap/tidb/kv" + fd "github.com/pingcap/tidb/planner/funcdep" + "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/util/size" + "github.com/pingcap/tipb/go-tipb" +) + +// Merge function will explore the internal grouping expressions and try to find the minimum grouping sets. (prefix merging) +func (gs GroupingSets) Merge() GroupingSets { + // for now, there is precondition that all grouping expressions are columns. + // for example: (a,b,c) and (a,b) and (a) will be merged as a one. + // Eg: + // before merging, there are 3 grouping sets. + // GroupingSets: + // [ + // [[a,b,c],] // every set including a grouping Expressions for initial. + // [[a,b],] + // [[a],] + // [[e],] + // ] + // + // after merging, there is only 1 grouping set. + // GroupingSets: + // [ + // [[a],[a,b],[a,b,c],] // one set including 3 grouping Expressions after merging if possible. + // [[e],] + // ] + // + // care about the prefix order, which should be taken as following the group layout expanding rule. (simple way) + // [[a],[b,a],[c,b,a],] is also conforming the rule, gradually including one/more column(s) inside for one time. + + newGroupingSets := make(GroupingSets, 0, len(gs)) + for _, oneGroupingSet := range gs { + for _, oneGroupingExpr := range oneGroupingSet { + if len(newGroupingSets) == 0 { + // means there is nothing in new grouping sets, adding current one anyway. + newGroupingSets = append(newGroupingSets, newGroupingSet(oneGroupingExpr)) + continue + } + newGroupingSets = newGroupingSets.mergeOne(oneGroupingExpr) + } + } + return newGroupingSets +} + +func (gs GroupingSets) mergeOne(targetOne GroupingExprs) GroupingSets { + // for every existing grouping set, check the grouping-exprs inside and whether the current grouping-exprs is + // super-set of it or sub-set of it, adding current one to the correct position of grouping-exprs slice. + // + // [[a,b]] + // | + // / offset 0 + // | + // [b] when adding [b] grouping expr here, since it's a sub-set of current [a,b] with offset 0, take the offset 0. + // | + // insert with offset 0, and the other elements move right. + // + // [[b], [a,b]] + // offset 0 1 + // \ + // [a,b,c,d] + // when adding [a,b,c,d] grouping expr here, since it's a super-set of current [a,b] with offset 1, take the offset as 1+1. + // + // result grouping set: [[b], [a,b], [a,b,c,d]], expanding with step as two or more columns is acceptable and reasonable. + // | | | + // +----+-------+ every previous one is the subset of the latter one. + // + for i, oneNewGroupingSet := range gs { + // for every group set,try to find its position to insert if possible,otherwise create a new grouping set. + for j := len(oneNewGroupingSet) - 1; j >= 0; j-- { + cur := oneNewGroupingSet[j] + if targetOne.SubSetOf(cur) { + if j == 0 { + // the right pos should be the head (-1) + cp := make(GroupingSet, 0, len(oneNewGroupingSet)+1) + cp = append(cp, targetOne) + cp = append(cp, oneNewGroupingSet...) + gs[i] = cp + return gs + } + // do the left shift to find the right insert pos. 
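+				// e.g. merging <b> into [<a,b>, <a,b,c>]: at j=1, <b> is a sub-set of <a,b,c>, so we shift left;
+				// at j=0, <b> is a sub-set of <a,b> as well, so <b> is inserted at the head: [<b>, <a,b>, <a,b,c>].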
+ continue + } + if j == len(oneNewGroupingSet)-1 { + // which means the targetOne can't fit itself in this grouping set, continue next grouping set. + break + } else { + // successfully fit in, current j is the right pos to insert. + cp := make(GroupingSet, 0, len(oneNewGroupingSet)+1) + cp = append(cp, oneNewGroupingSet[:j+1]...) + cp = append(cp, targetOne) + cp = append(cp, oneNewGroupingSet[j+1:]...) + gs[i] = cp + return gs + } + } + } + // here means we couldn't find even one GroupingSet to fill the targetOne, creating a new one. + gs = append(gs, newGroupingSet(targetOne)) + // gs is a alias of slice [], we should return it back after being changed. + return gs +} + +// TargetOne is used to find a valid group layout for normal agg. +func (gs GroupingSets) TargetOne(targetOne GroupingExprs) int { + // it has three cases. + // 1: group sets {}, {} when the normal agg(d), agg(d) can only occur in the group of or after tuple split. (normal column are appended) + // 2: group sets {}, {} when the normal agg(c), agg(c) can only be found in group of , since it will be filled with null in group of + // 3: group sets {}, {} when the normal agg with multi col args, it will be a little difficult, which is banned in canUse3Stage4MultiDistinctAgg by now. + // eg1: agg(c,d), the c and d can only be found in group of in which d is also attached, while c will be filled with null in group of + // eg2: agg(b,c,d), we couldn't find a valid group either in or , unless we copy one column from b and attach it to the group data of . (cp c to is also effective) + // + // from the theoretical, why we have to fill the non-current-group-column with null value? even we may fill it as null value and copy it again like c in case 3, + // the basic reason is that we need a unified group layout sequence to feed different layout-required distinct agg. Like what we did here, we should group the + // original source with sequence columns as . For distinct(a,b), we don't wanna c in this group layout to impact what we are targeting for --- the + // real group from . So even we had the groupingID in repeated data row to identify which distinct agg this row is prepared for, we still need to fill the c + // as null and groupingID as 1 to guarantee group operator can equate to group which is equal to group, that's what the upper layer + // distinct(a,b) want for. For distinct(c), the same is true. + // + // For normal agg you better choose your targeting group-set data, otherwise, all your get is endless null values filled by repeatSource operator. + // + // here we only consider 1&2 since normal agg with multi col args is banned. + allGroupingColIDs := gs.allSetsColIDs() + for idx, groupingSet := range gs { + // diffCols are those columns being filled with null in the group row data of current grouping set. + diffCols := allGroupingColIDs.Difference(*groupingSet.allSetColIDs()) + if diffCols.Intersects(*targetOne.IDSet()) { + // normal agg's target arg columns are being filled with null value in this grouping set, continue next grouping set check. + continue + } + return idx + } + // todo: if we couldn't find a existed current valid group layout, we need to copy the column out from being filled with null value. 
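+	// Illustration with grouping sets {<a,b>} and {<c>} (mirrored by the unit test in grouping_sets_test.go):
+	//   TargetOne(<d>)   -> 0  // d is a normal column appended to every group, so the first set wins.
+	//   TargetOne(<c>)   -> 1  // c is filled with null in the <a,b> group rows, only set 1 is valid.
+	//   TargetOne(<b,c>) -> -1 // no single grouping set keeps both b and c non-null.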
+ return -1 +} + +func (gs GroupingSet) IsEmpty() bool { + if len(gs) == 0 { + return true + } + for _, g := range gs { + if !g.IsEmpty() { + return false + } + } + return true +} + +func (gs GroupingSet) allSetColIDs() *fd.FastIntSet { + res := fd.NewFastIntSet() + for _, groupingExprs := range gs { + for _, one := range groupingExprs { + res.Insert(int(one.(*Column).UniqueID)) + } + } + return &res +} + +func (gs GroupingSet) Clone() GroupingSet { + gc := make(GroupingSet, 0, len(gs)) + for _, one := range gs { + gc = append(gc, one.Clone()) + } + return gc +} + +func (gs GroupingSet) String() string { + var str strings.Builder + str.WriteString("{") + for i, one := range gs { + if i != 0 { + str.WriteString(",") + } + str.WriteString(one.String()) + } + str.WriteString("}") + return str.String() +} + +func (gs GroupingSet) MemoryUsage() int64 { + sum := size.SizeOfSlice + int64(cap(gs))*size.SizeOfPointer + for _, one := range gs { + sum += one.MemoryUsage() + } + return sum +} + +func (gs GroupingSet) ToPB(sc *stmtctx.StatementContext, client kv.Client) (*tipb.GroupingSet, error) { + res := &tipb.GroupingSet{} + for _, gExprs := range gs { + gExprsPB, err := ExpressionsToPBList(sc, gExprs, client) + if err != nil { + return nil, err + } + res.GroupingExprs = append(res.GroupingExprs, &tipb.GroupingExpr{GroupingExpr: gExprsPB}) + } + return res, nil +} + +func (gss GroupingSets) IsEmpty() bool { + if len(gss) == 0 { + return true + } + for _, gs := range gss { + if !gs.IsEmpty() { + return false + } + } + return true +} + +func (gss GroupingSets) allSetsColIDs() *fd.FastIntSet { + res := fd.NewFastIntSet() + for _, groupingSet := range gss { + res.UnionWith(*groupingSet.allSetColIDs()) + } + return &res +} + +func (gss GroupingSets) String() string { + var str strings.Builder + str.WriteString("[") + for i, gs := range gss { + if i != 0 { + str.WriteString(",") + } + str.WriteString(gs.String()) + } + str.WriteString("]") + return str.String() +} + +func (gss GroupingSets) ToPB(sc *stmtctx.StatementContext, client kv.Client) ([]*tipb.GroupingSet, error) { + res := make([]*tipb.GroupingSet, 0, len(gss)) + for _, gs := range gss { + one, err := gs.ToPB(sc, client) + if err != nil { + return nil, err + } + res = append(res, one) + } + return res, nil +} + +func newGroupingSet(oneGroupingExpr GroupingExprs) GroupingSet { + res := make(GroupingSet, 0, 1) + res = append(res, oneGroupingExpr) + return res +} + +type GroupingSets []GroupingSet + +type GroupingSet []GroupingExprs + +type GroupingExprs []Expression + +func (g GroupingExprs) IsEmpty() bool { + return len(g) == 0 +} + +func (g GroupingExprs) SubSetOf(other GroupingExprs) bool { + oldOne := fd.NewFastIntSet() + newOne := fd.NewFastIntSet() + for _, one := range g { + oldOne.Insert(int(one.(*Column).UniqueID)) + } + for _, one := range other { + newOne.Insert(int(one.(*Column).UniqueID)) + } + return oldOne.SubsetOf(newOne) +} + +func (g GroupingExprs) IDSet() *fd.FastIntSet { + res := fd.NewFastIntSet() + for _, one := range g { + res.Insert(int(one.(*Column).UniqueID)) + } + return &res +} + +func (g GroupingExprs) Clone() GroupingExprs { + gc := make(GroupingExprs, 0, len(g)) + for _, one := range g { + gc = append(gc, one.Clone()) + } + return gc +} + +func (g GroupingExprs) String() string { + var str strings.Builder + str.WriteString("<") + for i, one := range g { + if i != 0 { + str.WriteString(",") + } + str.WriteString(one.String()) + } + str.WriteString(">") + return str.String() +} + +func (g GroupingExprs) MemoryUsage() 
int64 { + sum := size.SizeOfSlice + int64(cap(g))*size.SizeOfInterface + for _, one := range g { + sum += one.MemoryUsage() + } + return sum +} diff --git a/expression/grouping_sets_test.go b/expression/grouping_sets_test.go new file mode 100644 index 0000000000000..9a73abe47d186 --- /dev/null +++ b/expression/grouping_sets_test.go @@ -0,0 +1,142 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "testing" + + "github.com/stretchr/testify/require" + "go.opencensus.io/stats/view" +) + +func TestGroupSetsTargetOne(t *testing.T) { + defer view.Stop() + a := &Column{ + UniqueID: 1, + } + b := &Column{ + UniqueID: 2, + } + c := &Column{ + UniqueID: 3, + } + d := &Column{ + UniqueID: 4, + } + // non-merged group sets + // 1: group sets {}, {} when the normal agg(d), d can occur in either group of or after tuple split. (normal column are appended) + // 2: group sets {}, {} when the normal agg(c), c can only be found in group of , since it will be filled with null in group of + // 3: group sets {}, {} when the normal agg with multi col args, it will be a little difficult, which is banned in canUse3Stage4MultiDistinctAgg by now. + newGroupingSets := make(GroupingSets, 0, 3) + groupingExprs1 := []Expression{a, b} + groupingExprs2 := []Expression{c} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs1}) + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs2}) + + targetOne := []Expression{d} + offset := newGroupingSets.TargetOne(targetOne) + require.NotEqual(t, offset, -1) + require.Equal(t, offset, 0) // and both ok + + targetOne = []Expression{c} + offset = newGroupingSets.TargetOne(targetOne) + require.NotEqual(t, offset, -1) + require.Equal(t, offset, 1) // only is ok + + targetOne = []Expression{b} + offset = newGroupingSets.TargetOne(targetOne) + require.NotEqual(t, offset, -1) + require.Equal(t, offset, 0) // only is ok + + targetOne = []Expression{a} + offset = newGroupingSets.TargetOne(targetOne) + require.NotEqual(t, offset, -1) + require.Equal(t, offset, 0) // only is ok + + targetOne = []Expression{b, c} + offset = newGroupingSets.TargetOne(targetOne) + require.Equal(t, offset, -1) // no valid one can be found. + + // merged group sets + // merged group sets {, } when the normal agg(d), + // from prospect from data grouping, we need keep the widest grouping layout when trying + // to shuffle data targeting for grouping layout {, }; Additional, since column b + // hasn't been pre-agg or something because current layout is , we should keep all the + // complete rows until to the upper receiver, which means there is no agg/group operator before + // shuffler take place. 
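+	// A minimal check of that merged case, reusing the columns declared above; this assumes TargetOne
+	// keeps returning the first (and here only) compatible set once sets are merged, which is what the
+	// current implementation does, though the merged-set contract is still being settled.
+	mergedSets := GroupingSets{GroupingSet{GroupingExprs{a, b}, GroupingExprs{a, b, c}}}
+	offset = mergedSets.TargetOne(GroupingExprs{d})
+	require.Equal(t, offset, 0) // the single merged set keeps the widest layout <a,b,c>.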
+} + +func TestGroupingSetsMergeUnitTest(t *testing.T) { + defer view.Stop() + a := &Column{ + UniqueID: 1, + } + b := &Column{ + UniqueID: 2, + } + c := &Column{ + UniqueID: 3, + } + d := &Column{ + UniqueID: 4, + } + // [[c,d,a,b], [b], [a,b]] + newGroupingSets := make(GroupingSets, 0, 3) + groupingExprs := []Expression{c, d, a, b} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + groupingExprs = []Expression{b} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + groupingExprs = []Expression{a, b} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + + // [[b], [a,b], [c,d,a,b]] + newGroupingSets = newGroupingSets.Merge() + require.Equal(t, len(newGroupingSets), 1) + // only one grouping set with 3 grouping expressions. + require.Equal(t, len(newGroupingSets[0]), 3) + // [b] + require.Equal(t, len(newGroupingSets[0][0]), 1) + // [a,b] + require.Equal(t, len(newGroupingSets[0][1]), 2) + // [c,d,a,b]] + require.Equal(t, len(newGroupingSets[0][2]), 4) + + // [ + // [[a],[b,a],[c,b,a],] // one set including 3 grouping Expressions after merging if possible. + // [[d],] + // ] + newGroupingSets = newGroupingSets[:0] + groupingExprs = []Expression{c, b, a} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + groupingExprs = []Expression{a} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + groupingExprs = []Expression{b, a} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + groupingExprs = []Expression{d} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + + // [[a],[b,a],[c,b,a],] + // [[d],] + newGroupingSets = newGroupingSets.Merge() + require.Equal(t, len(newGroupingSets), 2) + // [a] + require.Equal(t, len(newGroupingSets[0][0]), 1) + // [b,a] + require.Equal(t, len(newGroupingSets[0][1]), 2) + // [c,b,a] + require.Equal(t, len(newGroupingSets[0][2]), 3) + // [d] + require.Equal(t, len(newGroupingSets[1][0]), 1) +} diff --git a/go.mod b/go.mod index c3948c93c1d05..016d1fb1a9e6e 100644 --- a/go.mod +++ b/go.mod @@ -275,4 +275,5 @@ replace ( github.com/dgrijalva/jwt-go => github.com/form3tech-oss/jwt-go v3.2.6-0.20210809144907-32ab6a8243d7+incompatible github.com/pingcap/tidb/parser => ./parser go.opencensus.io => go.opencensus.io v0.23.1-0.20220331163232-052120675fac + github.com/pingcap/tipb => ../tipb ) diff --git a/planner/core/casetest/enforce_mpp_test.go b/planner/core/casetest/enforce_mpp_test.go index aaa2037efc0ce..043256b25c5dd 100644 --- a/planner/core/casetest/enforce_mpp_test.go +++ b/planner/core/casetest/enforce_mpp_test.go @@ -487,3 +487,61 @@ func TestMPPSingleDistinct3Stage(t *testing.T) { require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())) } } + +// todo: some post optimization after resolveIndices will inject another projection below agg, which change the column name used in higher operator, +// +// since it doesn't change the schema out (index ref is still the right), so by now it's fine. SEE case: EXPLAIN select count(distinct a), count(distinct b), sum(c) from t. 
+func TestMPPMultiDistinct3Stage(t *testing.T) { + store := testkit.CreateMockStore(t, withMockTiFlash(2)) + tk := testkit.NewTestKit(t, store) + + // test table + tk.MustExec("use test;") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, c int, d int);") + tk.MustExec("alter table t set tiflash replica 1") + tb := external.GetTableByName(t, tk, "test", "t") + err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true) + require.NoError(t, err) + tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\";") + tk.MustExec("set @@session.tidb_enforce_mpp=1") + tk.MustExec("set @@session.tidb_allow_mpp=ON;") + // todo: current mock regionCache won't scale the regions among tiFlash nodes. The under layer still collect data from only one of the nodes. + tk.MustExec("split table t BETWEEN (0) AND (5000) REGIONS 5;") + tk.MustExec("insert into t values(1000, 1000, 1000, 1)") + tk.MustExec("insert into t values(1000, 1000, 1000, 1)") + tk.MustExec("insert into t values(2000, 2000, 2000, 1)") + tk.MustExec("insert into t values(2000, 2000, 2000, 1)") + tk.MustExec("insert into t values(3000, 3000, 3000, 1)") + tk.MustExec("insert into t values(3000, 3000, 3000, 1)") + tk.MustExec("insert into t values(4000, 4000, 4000, 1)") + tk.MustExec("insert into t values(4000, 4000, 4000, 1)") + tk.MustExec("insert into t values(5000, 5000, 5000, 1)") + tk.MustExec("insert into t values(5000, 5000, 5000, 1)") + + var input []string + var output []struct { + SQL string + Plan []string + Warn []string + } + enforceMPPSuiteData := plannercore.GetEnforceMPPSuiteData() + enforceMPPSuiteData.LoadTestCases(t, &input, &output) + for i, tt := range input { + testdata.OnRecord(func() { + output[i].SQL = tt + }) + if strings.HasPrefix(tt, "set") || strings.HasPrefix(tt, "UPDATE") { + tk.MustExec(tt) + continue + } + testdata.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + output[i].Warn = testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()) + }) + res := tk.MustQuery(tt) + res.Check(testkit.Rows(output[i].Plan...)) + require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())) + } +} diff --git a/planner/core/casetest/testdata/enforce_mpp_suite_in.json b/planner/core/casetest/testdata/enforce_mpp_suite_in.json index fc6d31dec9da9..18f1da4a8578e 100644 --- a/planner/core/casetest/testdata/enforce_mpp_suite_in.json +++ b/planner/core/casetest/testdata/enforce_mpp_suite_in.json @@ -132,5 +132,24 @@ "EXPLAIN select count(distinct c+a), count(a) from t;", "EXPLAIN select sum(b), count(distinct c+a, b, e), count(a+b) from t;" ] + }, + { + "name": "TestMPPMultiDistinct3Stage", + "cases": [ + "EXPLAIN select count(distinct a) from t", + "select count(distinct a) from t", + "EXPLAIN select count(distinct a), count(distinct b) from t", + "select count(distinct a), count(distinct b) from t", + "EXPLAIN select count(distinct a), count(distinct b), count(c) from t", + "select count(distinct a), count(distinct b), count(c) from t", + "EXPLAIN select count(distinct a), count(distinct b), sum(c) from t", + "select count(distinct a), count(distinct b), sum(c) from t", + "EXPLAIN select count(distinct a, b), count(distinct b), count(c), sum(d) from t", + "select count(distinct a, b), count(distinct b), count(c), sum(d) from t", + "EXPLAIN select count(distinct a+b), sum(c) from t", + "select 
count(distinct a+b), sum(c) from t", + "EXPLAIN select count(distinct a+b), count(distinct b+c), count(c) from t", + "select count(distinct a+b), count(distinct b+c), count(c) from t" + ] } ] diff --git a/planner/core/casetest/testdata/enforce_mpp_suite_out.json b/planner/core/casetest/testdata/enforce_mpp_suite_out.json index 483291bfabd22..f97fe2fdb4d6f 100644 --- a/planner/core/casetest/testdata/enforce_mpp_suite_out.json +++ b/planner/core/casetest/testdata/enforce_mpp_suite_out.json @@ -1203,5 +1203,193 @@ "Warn": null } ] + }, + { + "Name": "TestMPPMultiDistinct3Stage", + "Cases": [ + { + "SQL": "EXPLAIN select count(distinct a) from t", + "Plan": [ + "TableReader_30 1.00 root data:ExchangeSender_29", + "└─ExchangeSender_29 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_23 1.00 mpp[tiflash] Column#6", + " └─HashAgg_24 1.00 mpp[tiflash] funcs:sum(Column#8)->Column#6", + " └─ExchangeReceiver_28 1.00 mpp[tiflash] ", + " └─ExchangeSender_27 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_24 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#8", + " └─ExchangeReceiver_26 1.00 mpp[tiflash] ", + " └─ExchangeSender_25 1.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary]", + " └─HashAgg_22 1.00 mpp[tiflash] group by:test.t.a, ", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "select count(distinct a) from t", + "Plan": [ + "5" + ], + "Warn": null + }, + { + "SQL": "EXPLAIN select count(distinct a), count(distinct b) from t", + "Plan": [ + "TableReader_36 1.00 root data:ExchangeSender_35", + "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#14)->Column#6, funcs:sum(Column#16)->Column#7", + " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", + " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:count(distinct Column#15)->Column#14, funcs:count(distinct Column#17)->Column#16", + " └─Projection_30 1.00 mpp[tiflash] test.t.a, test.t.b, Column#13, case(eq(Column#13, 1), test.t.a, )->Column#15, case(eq(Column#13, 2), test.t.b, )->Column#17", + " └─ExchangeReceiver_32 1.00 mpp[tiflash] ", + " └─ExchangeSender_31 1.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#13, collate: binary]", + " └─HashAgg_25 1.00 mpp[tiflash] group by:Column#13, test.t.a, test.t.b, ", + " └─RepeatSource_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#13, [{},{}]", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "select count(distinct a), count(distinct b) from t", + "Plan": [ + "5 5" + ], + "Warn": null + }, + { + "SQL": "EXPLAIN select count(distinct a), count(distinct b), count(c) from t", + "Plan": [ + "TableReader_36 1.00 root data:ExchangeSender_35", + "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#20)->Column#6, funcs:sum(Column#22)->Column#7, funcs:sum(Column#24)->Column#8", + " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", + " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:count(distinct Column#21)->Column#20, 
funcs:count(distinct Column#23)->Column#22, funcs:sum(Column#17)->Column#24", + " └─Projection_30 1.00 mpp[tiflash] Column#17, test.t.a, test.t.b, Column#18, case(eq(Column#18, 1), test.t.a, )->Column#21, case(eq(Column#18, 2), test.t.b, )->Column#23", + " └─ExchangeReceiver_32 1.00 mpp[tiflash] ", + " └─ExchangeSender_31 1.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#18, collate: binary]", + " └─HashAgg_25 1.00 mpp[tiflash] group by:Column#18, test.t.a, test.t.b, funcs:count(Column#19)->Column#17", + " └─Projection_29 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#18, case(eq(Column#18, 1), test.t.c, )->Column#19", + " └─RepeatSource_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#18, [{},{}]", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "select count(distinct a), count(distinct b), count(c) from t", + "Plan": [ + "5 5 10" + ], + "Warn": null + }, + { + "SQL": "EXPLAIN select count(distinct a), count(distinct b), sum(c) from t", + "Plan": [ + "TableReader_36 1.00 root data:ExchangeSender_35", + "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#20)->Column#6, funcs:sum(Column#22)->Column#7, funcs:sum(Column#24)->Column#8", + " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", + " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:count(distinct Column#21)->Column#20, funcs:count(distinct Column#23)->Column#22, funcs:sum(Column#17)->Column#24", + " └─Projection_30 1.00 mpp[tiflash] Column#17, test.t.a, test.t.b, Column#18, case(eq(Column#18, 1), test.t.a, )->Column#21, case(eq(Column#18, 2), test.t.b, )->Column#23", + " └─ExchangeReceiver_32 1.00 mpp[tiflash] ", + " └─ExchangeSender_31 1.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#18, collate: binary]", + " └─HashAgg_25 1.00 mpp[tiflash] group by:Column#26, Column#27, Column#28, funcs:sum(Column#25)->Column#17", + " └─Projection_37 20000.00 mpp[tiflash] cast(Column#19, decimal(10,0) BINARY)->Column#25, test.t.a, test.t.b, Column#18", + " └─Projection_29 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#18, case(eq(Column#18, 1), test.t.c, )->Column#19", + " └─RepeatSource_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#18, [{},{}]", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "select count(distinct a), count(distinct b), sum(c) from t", + "Plan": [ + "5 5 30000" + ], + "Warn": null + }, + { + "SQL": "EXPLAIN select count(distinct a, b), count(distinct b), count(c), sum(d) from t", + "Plan": [ + "TableReader_36 1.00 root data:ExchangeSender_35", + "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8, Column#9", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#26)->Column#6, funcs:sum(Column#28)->Column#7, funcs:sum(Column#30)->Column#8, funcs:sum(Column#31)->Column#9", + " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", + " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:count(distinct Column#27, test.t.b)->Column#26, 
funcs:count(distinct Column#29)->Column#28, funcs:sum(Column#21)->Column#30, funcs:sum(Column#22)->Column#31", + " └─Projection_30 1.00 mpp[tiflash] Column#21, Column#22, test.t.a, test.t.b, Column#23, case(eq(Column#23, 1), test.t.a, )->Column#27, case(eq(Column#23, 2), test.t.b, )->Column#29", + " └─ExchangeReceiver_32 1.00 mpp[tiflash] ", + " └─ExchangeSender_31 1.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#23, collate: binary]", + " └─HashAgg_25 1.00 mpp[tiflash] group by:Column#34, Column#35, Column#36, funcs:count(Column#32)->Column#21, funcs:sum(Column#33)->Column#22", + " └─Projection_37 20000.00 mpp[tiflash] Column#24, cast(Column#25, decimal(10,0) BINARY)->Column#33, test.t.a, test.t.b, Column#23", + " └─Projection_29 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, test.t.d, Column#23, case(eq(Column#23, 1), test.t.c, )->Column#24, case(eq(Column#23, 1), test.t.d, )->Column#25", + " └─RepeatSource_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#23, [{},{}]", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": [ + "Some grouping sets should be merged", + "Some grouping sets should be merged" + ] + }, + { + "SQL": "select count(distinct a, b), count(distinct b), count(c), sum(d) from t", + "Plan": [ + "5 5 10 10" + ], + "Warn": [ + "Some grouping sets should be merged", + "Some grouping sets should be merged" + ] + }, + { + "SQL": "EXPLAIN select count(distinct a+b), sum(c) from t", + "Plan": [ + "TableReader_26 1.00 root data:ExchangeSender_25", + "└─ExchangeSender_25 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_21 1.00 mpp[tiflash] Column#6, Column#7", + " └─HashAgg_22 1.00 mpp[tiflash] funcs:count(distinct Column#10)->Column#6, funcs:sum(Column#11)->Column#7", + " └─ExchangeReceiver_24 1.00 mpp[tiflash] ", + " └─ExchangeSender_23 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_20 1.00 mpp[tiflash] group by:Column#13, funcs:sum(Column#12)->Column#11", + " └─Projection_27 10000.00 mpp[tiflash] cast(test.t.c, decimal(10,0) BINARY)->Column#12, plus(test.t.a, test.t.b)->Column#13", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "select count(distinct a+b), sum(c) from t", + "Plan": [ + "5 30000" + ], + "Warn": null + }, + { + "SQL": "EXPLAIN select count(distinct a+b), count(distinct b+c), count(c) from t", + "Plan": [ + "TableReader_26 1.00 root data:ExchangeSender_25", + "└─ExchangeSender_25 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_21 1.00 mpp[tiflash] Column#6, Column#7, Column#8", + " └─HashAgg_22 1.00 mpp[tiflash] funcs:count(distinct Column#12)->Column#6, funcs:count(distinct Column#13)->Column#7, funcs:sum(Column#14)->Column#8", + " └─ExchangeReceiver_24 1.00 mpp[tiflash] ", + " └─ExchangeSender_23 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_20 1.00 mpp[tiflash] group by:Column#16, Column#17, funcs:count(Column#15)->Column#14", + " └─Projection_27 10000.00 mpp[tiflash] test.t.c, plus(test.t.a, test.t.b)->Column#16, plus(test.t.b, test.t.c)->Column#17", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "select count(distinct a+b), count(distinct b+c), count(c) from t", + "Plan": [ + "5 5 10" + ], + "Warn": null + } + ] } ] diff --git a/planner/core/exhaust_physical_plans.go 
b/planner/core/exhaust_physical_plans.go index f9dc616bb19f5..a21d19960f461 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -2870,6 +2870,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert partitionCols := la.GetPotentialPartitionKeys() // trying to match the required partitions. if prop.MPPPartitionTp == property.HashType { + // partition key required by upper layer is subset of current layout. if matches := prop.IsSubsetOf(partitionCols); len(matches) != 0 { partitionCols = choosePartitionKeys(partitionCols, matches) } else { @@ -2896,6 +2897,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert } // 2-phase agg + // no partition property down,record partition cols inside agg itself, enforce shuffler latter. childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.AnyType, RejectSort: true} agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp) agg.SetSchema(la.schema.Clone()) @@ -2917,6 +2919,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp) agg.SetSchema(la.schema.Clone()) if la.HasDistinct() || la.HasOrderBy() { + // mpp scalar mode means the data will be pass through to only one tiFlash node at last. agg.MppRunMode = MppScalar } else { agg.MppRunMode = MppTiDB diff --git a/planner/core/explain.go b/planner/core/explain.go index 93d3533bbe3cc..bca27228ed195 100644 --- a/planner/core/explain.go +++ b/planner/core/explain.go @@ -359,6 +359,18 @@ func (p *PhysicalLimit) ExplainInfo() string { return str.String() } +// ExplainInfo implements Plan interface. +func (p *PhysicalRepeatSource) ExplainInfo() string { + var str strings.Builder + str.WriteString("group set num:") + str.WriteString(strconv.FormatInt(int64(len(p.GroupingSets)), 10)) + str.WriteString(", groupingID:Column#") + str.WriteString(strconv.FormatInt(p.GroupingIDCol.UniqueID, 10)) + str.WriteString(", ") + str.WriteString(p.GroupingSets.String()) + return str.String() +} + // ExplainInfo implements Plan interface. func (p *basePhysicalAgg) ExplainInfo() string { return p.explainInfo(false) diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index 47a2027e589ed..8118aa07d1f51 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -2452,7 +2452,7 @@ func (p *LogicalCTE) findBestTask(prop *property.PhysicalProperty, _ *PlanCounte // The physical plan has been build when derive stats. 
pcte := PhysicalCTE{SeedPlan: p.cte.seedPartPhysicalPlan, RecurPlan: p.cte.recursivePartPhysicalPlan, CTE: p.cte, cteAsName: p.cteAsName, cteName: p.cteName}.Init(p.ctx, p.stats) pcte.SetSchema(p.schema) - t = &rootTask{pcte, false} + t = &rootTask{p: pcte, isEmpty: false} if prop.CanAddEnforcer { t = enforceProperty(prop, t, p.basePlan.ctx) } diff --git a/planner/core/fragment.go b/planner/core/fragment.go index 1a4f5e6c61f9b..a712f93b418ae 100644 --- a/planner/core/fragment.go +++ b/planner/core/fragment.go @@ -151,6 +151,7 @@ func (e *mppTaskGenerator) constructMPPTasksByChildrenTasks(tasks []*kv.MPPTask) newTasks := make([]*kv.MPPTask, 0, len(tasks)) for _, task := range tasks { addr := task.Meta.GetAddress() + // for upper fragment, the task num is equal to address num covered by lower tasks _, ok := addressMap[addr] if !ok { mppTask := &kv.MPPTask{ @@ -193,7 +194,7 @@ func (f *Fragment) init(p PhysicalPlan) error { // We would remove all the union-all operators by 'untwist'ing and copying the plans above union-all. // This will make every route from root (ExchangeSender) to leaf nodes (ExchangeReceiver and TableScan) -// a new ioslated tree (and also a fragment) without union all. These trees (fragments then tasks) will +// a new isolated tree (and also a fragment) without union all. These trees (fragments then tasks) will // finally be gathered to TiDB or be exchanged to upper tasks again. // For instance, given a plan "select c1 from t union all select c1 from s" // after untwist, there will be two plans in `forest` slice: @@ -278,6 +279,7 @@ func (e *mppTaskGenerator) generateMPPTasksForExchangeSender(s *PhysicalExchange } results := make([]*kv.MPPTask, 0, len(frags)) for _, f := range frags { + // from the bottom up tasks, err := e.generateMPPTasksForFragment(f) if err != nil { return nil, nil, errors.Trace(err) @@ -291,6 +293,7 @@ func (e *mppTaskGenerator) generateMPPTasksForExchangeSender(s *PhysicalExchange func (e *mppTaskGenerator) generateMPPTasksForFragment(f *Fragment) (tasks []*kv.MPPTask, err error) { for _, r := range f.ExchangeReceivers { + // chain call: to get lower fragments and tasks r.Tasks, r.frags, err = e.generateMPPTasksForExchangeSender(r.GetExchangeSender()) if err != nil { return nil, errors.Trace(err) diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index da0a1e2b0ddf7..cb26f7444e465 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -32,6 +32,7 @@ import ( "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/plancodec" "github.com/pingcap/tidb/util/ranger" "github.com/pingcap/tidb/util/size" "github.com/pingcap/tidb/util/stringutil" @@ -1495,7 +1496,61 @@ func (p *PhysicalExchangeReceiver) MemoryUsage() (sum int64) { return } -// PhysicalExchangeSender dispatches data to upstream tasks. That means push mode processing, +// PhysicalRepeatSource is used to repeat underlying data sources to feed different grouping sets. +type PhysicalRepeatSource struct { + // data after repeat-OP will generate a new grouping-ID column to indicate what grouping set is it for. + physicalSchemaProducer + + // generated grouping ID column itself. + GroupingIDCol *expression.Column + + // GroupingSets is used to define what kind of group layout should the underlying data follow. + // For simple case: select count(distinct a), count(distinct b) from t; the grouping expressions are [a] and [b]. 
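+	// Illustrative data flow (not tied to a specific query): with sets [{<a>},{<b>}], an input row
+	// (a=1, b=2, c=3) is repeated once per set as (1, NULL, 3, groupingID=1) and (NULL, 2, 3, groupingID=2);
+	// grouping columns of the other set are filled with null, normal columns are kept as is, and the
+	// appended groupingID marks which grouping set the copy is generated for.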
+ GroupingSets expression.GroupingSets +} + +// Init only assigns type and context. +func (p PhysicalRepeatSource) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int) *PhysicalRepeatSource { + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeRepeatSource, &p, offset) + p.stats = stats + return &p +} + +// Clone implements PhysicalPlan interface. +func (p *PhysicalRepeatSource) Clone() (PhysicalPlan, error) { + np := new(PhysicalRepeatSource) + base, err := p.basePhysicalPlan.cloneWithSelf(np) + if err != nil { + return nil, errors.Trace(err) + } + np.basePhysicalPlan = *base + // clone ID cols. + np.GroupingIDCol = p.GroupingIDCol.Clone().(*expression.Column) + + // clone grouping expressions. + clonedGroupingSets := make([]expression.GroupingSet, 0, len(p.GroupingSets)) + for _, one := range p.GroupingSets { + clonedGroupingSets = append(clonedGroupingSets, one.Clone()) + } + np.GroupingSets = p.GroupingSets + return np, nil +} + +// MemoryUsage return the memory usage of PhysicalRepeatSource +func (p *PhysicalRepeatSource) MemoryUsage() (sum int64) { + if p == nil { + return + } + + sum = p.physicalSchemaProducer.MemoryUsage() + size.SizeOfSlice + int64(cap(p.GroupingSets))*size.SizeOfPointer + for _, gs := range p.GroupingSets { + sum += gs.MemoryUsage() + } + sum += p.GroupingIDCol.MemoryUsage() + return +} + +// PhysicalExchangeSender dispatches data to upstream tasks. That means push mode processing. type PhysicalExchangeSender struct { basePhysicalPlan @@ -1507,7 +1562,7 @@ type PhysicalExchangeSender struct { CompressionMode kv.ExchangeCompressionMode } -// Clone implment PhysicalPlan interface. +// Clone implements PhysicalPlan interface. func (p *PhysicalExchangeSender) Clone() (PhysicalPlan, error) { np := new(PhysicalExchangeSender) base, err := p.basePhysicalPlan.cloneWithSelf(np) diff --git a/planner/core/plan_to_pb.go b/planner/core/plan_to_pb.go index 61edd9471a20c..4b5269f262ccc 100644 --- a/planner/core/plan_to_pb.go +++ b/planner/core/plan_to_pb.go @@ -33,6 +33,29 @@ func (p *basePhysicalPlan) ToPB(_ sessionctx.Context, _ kv.StoreType) (*tipb.Exe return nil, errors.Errorf("plan %s fails converts to PB", p.basePlan.ExplainID()) } +// ToPB implements PhysicalPlan ToPB interface. +func (p *PhysicalRepeatSource) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { + sc := ctx.GetSessionVars().StmtCtx + client := ctx.GetClient() + groupingSetsPB, err := p.GroupingSets.ToPB(sc, client) + if err != nil { + return nil, err + } + repeatSource := &tipb.RepeatSource{ + GroupingSets: groupingSetsPB, + } + executorID := "" + if storeType == kv.TiFlash { + var err error + repeatSource.Child, err = p.children[0].ToPB(ctx, storeType) + if err != nil { + return nil, errors.Trace(err) + } + executorID = p.ExplainID().String() + } + return &tipb.Executor{Tp: tipb.ExecType_TypeRepeatSource, RepeatSource: repeatSource, ExecutorId: &executorID}, nil +} + // ToPB implements PhysicalPlan ToPB interface. 
func (p *PhysicalHashAgg) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { sc := ctx.GetSessionVars().StmtCtx diff --git a/planner/core/resolve_indices.go b/planner/core/resolve_indices.go index b602c46a78bd2..451cf149f3568 100644 --- a/planner/core/resolve_indices.go +++ b/planner/core/resolve_indices.go @@ -390,7 +390,7 @@ func (p *PhysicalSelection) ResolveIndices() (err error) { return nil } -// ResolveIndicesItself resolve indices for PhyicalPlan itself +// ResolveIndicesItself resolve indices for PhysicalPlan itself func (p *PhysicalExchangeSender) ResolveIndicesItself() (err error) { for i, col := range p.HashCols { colExpr, err1 := col.Col.ResolveIndices(p.children[0].Schema()) @@ -411,6 +411,31 @@ func (p *PhysicalExchangeSender) ResolveIndices() (err error) { return p.ResolveIndicesItself() } +// ResolveIndicesItself resolve indices for PhysicalPlan itself +func (p *PhysicalRepeatSource) ResolveIndicesItself() error { + for _, gs := range p.GroupingSets { + for _, groupingExprs := range gs { + for k, groupingExpr := range groupingExprs { + gExpr, err := groupingExpr.ResolveIndices(p.children[0].Schema()) + if err != nil { + return err + } + groupingExprs[k] = gExpr + } + } + } + return nil +} + +// ResolveIndices implements Plan interface. +func (p *PhysicalRepeatSource) ResolveIndices() (err error) { + err = p.physicalSchemaProducer.ResolveIndices() + if err != nil { + return err + } + return p.ResolveIndicesItself() +} + // ResolveIndices implements Plan interface. func (p *basePhysicalAgg) ResolveIndices() (err error) { err = p.physicalSchemaProducer.ResolveIndices() diff --git a/planner/core/task.go b/planner/core/task.go index f111ea62e7c8a..f04a4d0ffd5ae 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -15,7 +15,9 @@ package core import ( + "fmt" "math" + "strings" "github.com/pingcap/errors" "github.com/pingcap/failpoint" @@ -141,6 +143,12 @@ func attachPlan2Task(p PhysicalPlan, t task) task { p.SetChildren(v.p) v.p = p case *mppTask: + if v.attachTaskCallBack != nil { + v.attachTaskCallBack(p) + } + if p == nil || v == nil || v.p == nil { + fmt.Println(1) + } p.SetChildren(v.p) v.p = p } @@ -788,6 +796,10 @@ type rootTask struct { p PhysicalPlan isEmpty bool // isEmpty indicates if this task contains a dual table and returns empty data. // TODO: The flag 'isEmpty' is only checked by Projection and UnionAll. We should support more cases in the future. + + // derive stats is done before physical plan enumeration, which when repeat source is inserted, + // the all upper layer physical op's base stats should be changed. + attachTaskCallBack func(PhysicalPlan) } func (t *rootTask) copy() task { @@ -1553,6 +1565,7 @@ func BuildFinalModeAggregation( finalAggFunc.OrderByItems = aggFunc.OrderByItems args := make([]expression.Expression, 0, len(aggFunc.Args)) if aggFunc.HasDistinct { + // 感觉这里是变下推到 kv 的 agg 结构的 /* eg: SELECT COUNT(DISTINCT a), SUM(b) FROM t GROUP BY c @@ -1564,6 +1577,7 @@ func BuildFinalModeAggregation( */ // onlyAddFirstRow means if the distinctArg does not occur in group by items, // it should be replaced with a firstrow() agg function, needed for the order by items of group_concat() + // 这个地方也是我们所看到的,为啥 distinct agg column 会出现在 group by 中的原因 getDistinctExpr := func(distinctArg expression.Expression, onlyAddFirstRow bool) (ret expression.Expression) { // 1. 
add all args to partial.GroupByItems foundInGroupBy := false @@ -1630,6 +1644,11 @@ func BuildFinalModeAggregation( byItems = append(byItems, &util.ByItems{Expr: getDistinctExpr(byItem.Expr, true), Desc: byItem.Desc}) } + if aggFunc.HasDistinct && isMPPTask && aggFunc.GroupingID > 0 { + // keep the groupingID as it was, otherwise the new split final aggregate's ganna lost its groupingID info. + finalAggFunc.GroupingID = aggFunc.GroupingID + } + finalAggFunc.OrderByItems = byItems finalAggFunc.HasDistinct = aggFunc.HasDistinct // In logical optimize phase, the Agg->PartitionUnion->TableReader may become @@ -1649,12 +1668,14 @@ func BuildFinalModeAggregation( return } if aggregation.NeedCount(finalAggFunc.Name) { + // only Avg and Count need count if isMPPTask && finalAggFunc.Name == ast.AggFuncCount { // For MPP Task, the final count() is changed to sum(). // Note: MPP mode does not run avg() directly, instead, avg() -> sum()/(case when count() = 0 then 1 else count() end), // so we do not process it here. finalAggFunc.Name = ast.AggFuncSum } else { + // avg branch ft := types.NewFieldType(mysql.TypeLonglong) ft.SetFlen(21) ft.SetCharset(charset.CharsetBin) @@ -1713,6 +1734,7 @@ func BuildFinalModeAggregation( args = append(args, aggFunc.Args[len(aggFunc.Args)-1]) } } else { + // other agg desc just split into two parts partialFuncDesc := aggFunc.Clone() partial.AggFuncs = append(partial.AggFuncs, partialFuncDesc) if aggFunc.Name == ast.AggFuncFirstRow { @@ -1787,6 +1809,7 @@ func (p *basePhysicalAgg) convertAvgForMPP() *PhysicalProjection { divide.(*expression.ScalarFunction).RetType = p.schema.Columns[i].RetType exprs = append(exprs, divide) } else { + // other non-avg agg use the old schema as it did. newAggFuncs = append(newAggFuncs, aggFunc) newSchema.Append(p.schema.Columns[i]) exprs = append(exprs, p.schema.Columns[i]) @@ -1868,8 +1891,80 @@ func (p *basePhysicalAgg) newPartialAggregate(copTaskType kv.StoreType, isMPPTas return partialAgg, finalAgg } +func (p *basePhysicalAgg) scale3StageForDistinctAgg() (bool, expression.GroupingSets) { + if p.canUse3Stage4SingleDistinctAgg() { + return true, nil + } + return p.canUse3Stage4MultiDistinctAgg() +} + +// canUse3Stage4MultiDistinctAgg returns true if this agg can use 3 stage for multi distinct aggregation +func (p *basePhysicalAgg) canUse3Stage4MultiDistinctAgg() (can bool, gss expression.GroupingSets) { + if !p.ctx.GetSessionVars().Enable3StageDistinctAgg || len(p.GroupByItems) > 0 { + return false, nil + } + defer func() { + // some clean work. + if can == false { + for _, fun := range p.AggFuncs { + fun.GroupingID = 0 + } + } + }() + // GroupingSets is alias of []GroupingSet, the below equal to = make([]GroupingSet, 0, 2) + GroupingSets := make(expression.GroupingSets, 0, 2) + for _, fun := range p.AggFuncs { + if fun.HasDistinct { + if fun.Name != ast.AggFuncCount { + // now only for multi count(distinct x) + return false, nil + } + for _, arg := range fun.Args { + // bail out when args are not simple column, see GitHub issue #35417 + if _, ok := arg.(*expression.Column); !ok { + return false, nil + } + } + // here it's a valid count distinct agg with normal column args, collecting its distinct expr. + GroupingSets = append(GroupingSets, expression.GroupingSet{fun.Args}) + // groupingID now is the offset of target grouping in GroupingSets. + // todo: it may be changed after grouping set merge in the future. 
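+			// e.g. for `select count(distinct a), count(distinct b) from t`, count(distinct a) gets
+			// GroupingID 1 and count(distinct b) gets GroupingID 2, i.e. the 1-based offset of its
+			// grouping set collected just above.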
+ fun.GroupingID = int64(len(GroupingSets)) + } else if len(fun.Args) > 1 { + return false, nil + } + // banned count(distinct x order by y) + if len(fun.OrderByItems) > 0 || fun.Mode != aggregation.CompleteMode { + return false, nil + } + } + Compressed := GroupingSets.Merge() + if len(Compressed) != len(GroupingSets) { + // todo arenatlx: some grouping set should be merged which is not supported by now temporarily. + p.SCtx().GetSessionVars().StmtCtx.AppendWarning(errors.Errorf("Some grouping sets should be merged")) + } + if len(GroupingSets) > 1 { + // fill the grouping ID for normal agg. + for _, fun := range p.AggFuncs { + if fun.GroupingID == 0 { + // the grouping ID hasn't set. find the targeting grouping set. + groupingSetOffset := GroupingSets.TargetOne(fun.Args) + if groupingSetOffset == -1 { + // todo: if we couldn't find a existed current valid group layout, we need to copy the column out from being filled with null value. + p.SCtx().GetSessionVars().StmtCtx.AppendWarning(errors.Errorf("couldn't find a proper group set for normal agg")) + return false, nil + } + // starting with 1 + fun.GroupingID = int64(groupingSetOffset + 1) + } + } + return true, GroupingSets + } + return false, nil +} + // canUse3StageDistinctAgg returns true if this agg can use 3 stage for distinct aggregation -func (p *basePhysicalAgg) canUse3StageDistinctAgg() bool { +func (p *basePhysicalAgg) canUse3Stage4SingleDistinctAgg() bool { num := 0 if !p.ctx.GetSessionVars().Enable3StageDistinctAgg || len(p.GroupByItems) > 0 { return false @@ -2047,12 +2142,262 @@ func (p *PhysicalHashAgg) attach2TaskForMpp1Phase(mpp *mppTask) task { return mpp } +func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan, canUse3StageAgg bool, + groupingSets expression.GroupingSets, mpp *mppTask) (final, mid, proj1, part, proj2 PhysicalPlan, _ error) { + // generate 3 stage aggregation for single/multi count distinct if applicable. + // + // select count(distinct a), count(b) from foo + // will generate plan: + // HashAgg sum(#1), sum(#2) -> final agg + // +- Exchange Passthrough + // +- HashAgg count(distinct a) #1, sum(#3) #2 -> middle agg + // +- Exchange HashPartition by a + // +- HashAgg count(b) #3, group by a -> partial agg + // +- TableScan foo + // + // select count(distinct a), count(distinct b), count(c) from foo + // will generate plan: + // HashAgg sum(#1), sum(#2), sum(#3) -> final agg + // +- Exchange Passthrough + // +- HashAgg count(distinct a) #1, count(distinct b) #2, sum(#4) #3 -> middle agg + // +- Exchange HashPartition by a + // +- HashAgg count(c) #4, group by a,b,groupingID -> partial agg + // +- RepeatSource a, b -> repeat source + // +- TableScan foo + // + // set the default expression to constant 1 for the convenience to choose default group set data. + var groupingIDCol expression.Expression + groupingIDCol = expression.NewNumber(1) + if len(groupingSets) > 0 { + // enforce the repeatSource operator above the children. + // single distinct agg mode can eliminate repeatSource op insertion. + + // physical plan is enumerated without children from itself, use mpp subtree instead p.children. + stats := mpp.p.statsInfo().Scale(float64(len(groupingSets))) + + physicalRepeatSource := PhysicalRepeatSource{ + GroupingSets: groupingSets, + }.Init(p.ctx, stats, mpp.p.SelectBlockOffset()) + + // generate a new column as groupingID to identify which this row is targeting for. 
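+		// note (assumption about the runtime side): the repeatSource executor is expected to emit the
+		// values 1..len(groupingSets) in this column, which must line up with the 1-based
+		// AggFuncDesc.GroupingID assigned in canUse3Stage4MultiDistinctAgg, since the case-when
+		// filters built below compare this column against exactly those constants.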
+ groupingIDCol = &expression.Column{ + UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: types.NewFieldType(mysql.TypeLonglong), + } + // append the physical repeatSource op with groupingID column. + physicalRepeatSource.SetSchema(mpp.p.Schema().Clone()) + physicalRepeatSource.schema.Append(groupingIDCol.(*expression.Column)) + physicalRepeatSource.GroupingIDCol = groupingIDCol.(*expression.Column) + + // set attach mpp task call back, aim to expand the input rowcount of stats for upper ops. + //mpp.attachTaskCallBack = func(p PhysicalPlan) { + // p.statsInfo().RowCount = p.statsInfo().RowCount * float64(len(groupingSets)) + //} + + // attach PhysicalRepeatSource to mpp + attachPlan2Task(physicalRepeatSource, mpp) + } + + if partialAgg != nil && canUse3StageAgg { + if groupingSets == nil { + // single distinct agg mode. + clonedAgg, err := finalAgg.Clone() + if err != nil { + return nil, nil, nil, nil, nil, err + } + + // step1: adjust middle agg + middleHashAgg := clonedAgg.(*PhysicalHashAgg) + distinctPos := 0 + middleSchema := expression.NewSchema() + schemaMap := make(map[int64]*expression.Column, len(middleHashAgg.AggFuncs)) + for i, fun := range middleHashAgg.AggFuncs { + col := &expression.Column{ + UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: fun.RetTp, + } + if fun.HasDistinct { + distinctPos = i + fun.Mode = aggregation.Partial1Mode + } else { + fun.Mode = aggregation.Partial2Mode + originalCol := fun.Args[0].(*expression.Column) + // mapping the current partial output column with the agg origin arg column. (final agg arg should use this one) + schemaMap[originalCol.UniqueID] = col + } + middleSchema.Append(col) + } + middleHashAgg.schema = middleSchema + + // step2: adjust final agg + finalHashAgg := finalAgg.(*PhysicalHashAgg) + finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs)) + for i, fun := range finalHashAgg.AggFuncs { + newArgs := make([]expression.Expression, 0, 1) + if distinctPos == i { + // change count(distinct) to sum() + fun.Name = ast.AggFuncSum + fun.HasDistinct = false + newArgs = append(newArgs, middleSchema.Columns[i]) + } else { + for _, arg := range fun.Args { + newCol, err := arg.RemapColumn(schemaMap) + if err != nil { + return nil, nil, nil, nil, nil, err + } + newArgs = append(newArgs, newCol) + } + } + fun.Mode = aggregation.FinalMode + fun.Args = newArgs + finalAggDescs = append(finalAggDescs, fun) + } + finalHashAgg.AggFuncs = finalAggDescs + // partialAgg is im-mutated from args + return finalHashAgg, middleHashAgg, nil, partialAgg, nil, nil + } else { + // having group sets + clonedAgg, err := finalAgg.Clone() + if err != nil { + return nil, nil, nil, nil, nil, err + } + // select count(distinct a), count(distinct b), count(c) from t + // final agg : sum(#1), sum(#2), sum(#3) + // | | +--------------------------------------------+ + // | +--------------------------------+ | + // +-------------+ | | + // v v v + // middle agg : count(distinct caseWhen4a) as #1, count(distinct caseWhen4b) as #2, sum(#4) as #3 # all partial result + // | | + // +---------------------+ | + // v v + // proj4Middle : a, b, groupingID, #4(partial res), caseWhen4a, caseWhen4b # caseWhen4a & caseWhen4a depend on groupingID and a, b + // | + // | + // exchange hash partition: # shuffle data by + // | + // v + // partial agg : a, b, groupingID, count(caseWhen4c) as #4 # since group by a, b, groupingID we can directly output all this three. 
+ // | + // v + // proj4Partial: a, b, c, groupingID, caseWhen4c # caseWhen4c depend on groupingID and c + // + // repeatSource: a, b, c, groupingID # appended groupingID to diff group set + // + // lower source: a, b, c # don't care what the lower source is + + // step1: adjust partial agg, for normal agg here, adjust it to target for specified group data. + // Since we may substitute the first arg of normal agg with case-when expression here, append a + // customized proj here rather than depend on postOptimize to insert a blunt one for us. + partialHashAgg := partialAgg.(*PhysicalHashAgg) + partialHashAgg.GroupByItems = append(partialHashAgg.GroupByItems, groupingIDCol) + partialHashAgg.schema.Append(groupingIDCol.(*expression.Column)) + // proj4Partial output all the base col from lower op + caseWhen proj cols. + proj4Partial := new(PhysicalProjection).Init(p.ctx, mpp.p.statsInfo(), mpp.p.SelectBlockOffset()) + for _, col := range mpp.p.Schema().Columns { + proj4Partial.Exprs = append(proj4Partial.Exprs, col) + } + proj4Partial.SetSchema(mpp.p.Schema().Clone()) + for _, fun := range partialHashAgg.AggFuncs { + if !fun.HasDistinct { + // for normal agg phase1, we should also modify them to target for specified group data. + eqExpr := expression.NewFunctionInternal(p.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), groupingIDCol, expression.NewNumber(fun.GroupingID)) + caseWhen := expression.NewFunctionInternal(p.ctx, ast.Case, fun.Args[0].GetType(), eqExpr, fun.Args[0], expression.NewNull()) + caseWhenProjCol := &expression.Column{ + UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: fun.Args[0].GetType(), + } + proj4Partial.Exprs = append(proj4Partial.Exprs, caseWhen) + proj4Partial.Schema().Append(caseWhenProjCol) + fun.Args[0] = caseWhenProjCol + } + } + + // step2: adjust middle agg + middleHashAgg := clonedAgg.(*PhysicalHashAgg) + middleSchema := expression.NewSchema() + schemaMap := make(map[int64]*expression.Column, len(middleHashAgg.AggFuncs)) + // proj4Middle output all the base col from lower op + caseWhen proj cols. + proj4Middle := new(PhysicalProjection).Init(p.ctx, partialHashAgg.stats, partialHashAgg.SelectBlockOffset()) + for _, col := range partialHashAgg.Schema().Columns { + proj4Middle.Exprs = append(proj4Middle.Exprs, col) + } + proj4Middle.SetSchema(partialHashAgg.Schema().Clone()) + for _, fun := range middleHashAgg.AggFuncs { + col := &expression.Column{ + UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: fun.RetTp, + } + if fun.HasDistinct { + fun.Mode = aggregation.Partial1Mode + // change the middle distinct agg args to target for specified group data. + // even for distinct agg multi args like count(distinct a,b), either null of a,b will cause count func to skip this row. + // so we only need to set the either one arg of them to be a case when expression to target for its self-group: + // Expr = (case when groupingID = ? 
then arg else null end) + // count(distinct a,b) => count(distinct (case when groupingID = targetID then a else null end), b) + eqExpr := expression.NewFunctionInternal(p.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), groupingIDCol, expression.NewNumber(fun.GroupingID)) + caseWhen := expression.NewFunctionInternal(p.ctx, ast.Case, fun.Args[0].GetType(), eqExpr, fun.Args[0], expression.NewNull()) + caseWhenProjCol := &expression.Column{ + UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: fun.Args[0].GetType(), + } + proj4Middle.Exprs = append(proj4Middle.Exprs, caseWhen) + proj4Middle.Schema().Append(caseWhenProjCol) + fun.Args[0] = caseWhenProjCol + } else { + fun.Mode = aggregation.Partial2Mode + originalCol := fun.Args[0].(*expression.Column) + // record the origin column unique id down before change it to be case when expr. + // mapping the current partial output column with the agg origin arg column. (final agg arg should use this one) + schemaMap[originalCol.UniqueID] = col + } + middleSchema.Append(col) + } + middleHashAgg.schema = middleSchema + + // step3: adjust final agg + finalHashAgg := finalAgg.(*PhysicalHashAgg) + finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs)) + for i, fun := range finalHashAgg.AggFuncs { + newArgs := make([]expression.Expression, 0, 1) + if fun.HasDistinct { + // change count(distinct) agg to sum() + fun.Name = ast.AggFuncSum + fun.HasDistinct = false + // count(distinct a,b) -> become a single partial result col. + newArgs = append(newArgs, middleSchema.Columns[i]) + } else { + // remap final normal agg args to be output schema of middle normal agg. + for _, arg := range fun.Args { + newCol, err := arg.RemapColumn(schemaMap) + if err != nil { + return nil, nil, nil, nil, nil, err + } + newArgs = append(newArgs, newCol) + } + } + fun.Mode = aggregation.FinalMode + fun.Args = newArgs + fun.GroupingID = 0 + finalAggDescs = append(finalAggDescs, fun) + } + finalHashAgg.AggFuncs = finalAggDescs + return finalHashAgg, middleHashAgg, proj4Middle, partialHashAgg, proj4Partial, nil + } + } + // return the original finalAgg and partiAgg + return finalAgg, nil, nil, partialAgg, nil, nil +} + func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task { t := tasks[0].copy() mpp, ok := t.(*mppTask) if !ok { return invalidTask } + if strings.HasPrefix(p.ctx.GetSessionVars().StmtCtx.OriginalSQL, "select count(distinct a, b), count(distinct b), count(c) from t group by d") { + fmt.Println(1) + } switch p.MppRunMode { case Mpp1Phase: // 1-phase agg: when the partition columns can be satisfied, where the plan does not need to enforce Exchange @@ -2086,6 +2431,7 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task { }) } } + // 形成新的 partition cols property prop := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: partitionCols} newMpp := mpp.enforceExchangerImpl(prop) if newMpp.invalid() { @@ -2108,84 +2454,35 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task { case MppScalar: prop := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.SinglePartitionType} if !mpp.needEnforceExchanger(prop) { + // one hand: when the low layer already satisfied the single partition layout, just do the all agg computation in the single node. 
return p.attach2TaskForMpp1Phase(mpp) } + // other hand: try to split the mppScalar agg into multi phases agg **down** to multi nodes since data already distributed across nodes. // we have to check it before the content of p has been modified - canUse3StageAgg := p.canUse3StageDistinctAgg() + canUse3StageAgg, groupingSets := p.scale3StageForDistinctAgg() proj := p.convertAvgForMPP() partialAgg, finalAgg := p.newPartialAggregate(kv.TiFlash, true) if finalAgg == nil { return invalidTask } - // generate 3 stage aggregation for single count distinct if applicable. - // select count(distinct a), count(b) from foo - // will generate plan: - // HashAgg sum(#1), sum(#2) -> final agg - // +- Exchange Passthrough - // +- HashAgg count(distinct a) #1, sum(#3) #2 -> middle agg - // +- Exchange HashPartition by a - // +- HashAgg count(b) #3, group by a -> partial agg - // +- TableScan foo - var middleAgg *PhysicalHashAgg = nil - if partialAgg != nil && canUse3StageAgg { - clonedAgg, err := finalAgg.Clone() - if err != nil { - return invalidTask - } - middleAgg = clonedAgg.(*PhysicalHashAgg) - distinctPos := 0 - middleSchema := expression.NewSchema() - schemaMap := make(map[int64]*expression.Column, len(middleAgg.AggFuncs)) - for i, fun := range middleAgg.AggFuncs { - col := &expression.Column{ - UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: fun.RetTp, - } - if fun.HasDistinct { - distinctPos = i - fun.Mode = aggregation.Partial1Mode - } else { - fun.Mode = aggregation.Partial2Mode - originalCol := fun.Args[0].(*expression.Column) - schemaMap[originalCol.UniqueID] = col - } - middleSchema.Append(col) - } - middleAgg.schema = middleSchema + final, middle, proj4Middle, partial, proj4Partial, err := p.adjust3StagePhaseAgg(partialAgg, finalAgg, canUse3StageAgg, groupingSets, mpp) + if err != nil { + return invalidTask + } - finalHashAgg := finalAgg.(*PhysicalHashAgg) - finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs)) - for i, fun := range finalHashAgg.AggFuncs { - newArgs := make([]expression.Expression, 0, 1) - if distinctPos == i { - // change count(distinct) to sum() - fun.Name = ast.AggFuncSum - fun.HasDistinct = false - newArgs = append(newArgs, middleSchema.Columns[i]) - } else { - for _, arg := range fun.Args { - newCol, err := arg.RemapColumn(schemaMap) - if err != nil { - return invalidTask - } - newArgs = append(newArgs, newCol) - } - } - fun.Mode = aggregation.FinalMode - fun.Args = newArgs - finalAggDescs = append(finalAggDescs, fun) - } - finalHashAgg.AggFuncs = finalAggDescs + // partial agg proj would be null if one scalar agg cannot run in two-phase mode + if proj4Partial != nil { + attachPlan2Task(proj4Partial, mpp) } // partial agg would be null if one scalar agg cannot run in two-phase mode - if partialAgg != nil { - attachPlan2Task(partialAgg, mpp) + if partial != nil { + attachPlan2Task(partial, mpp) } - if middleAgg != nil && canUse3StageAgg { - items := partialAgg.(*PhysicalHashAgg).GroupByItems + if middle != nil && canUse3StageAgg { + items := partial.(*PhysicalHashAgg).GroupByItems partitionCols := make([]*property.MPPPartitionColumn, 0, len(items)) for _, expr := range items { col, ok := expr.(*expression.Column) @@ -2198,14 +2495,18 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task { }) } - prop := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: partitionCols} - newMpp := mpp.enforceExchanger(prop) - 
attachPlan2Task(middleAgg, newMpp) + exProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: partitionCols} + newMpp := mpp.enforceExchanger(exProp) + if proj4Middle != nil { + attachPlan2Task(proj4Middle, newMpp) + } + attachPlan2Task(middle, newMpp) mpp = newMpp } + // prop here still be the first generated single-partition requirement. newMpp := mpp.enforceExchanger(prop) - attachPlan2Task(finalAgg, newMpp) + attachPlan2Task(final, newMpp) if proj == nil { proj = PhysicalProjection{ Exprs: make([]expression.Expression, 0, len(p.Schema().Columns)), @@ -2262,6 +2563,7 @@ func (p *PhysicalHashAgg) attach2Task(tasks ...task) task { attachPlan2Task(p, t) } } else if _, ok := t.(*mppTask); ok { + // 如果底层暂时还是个 mppTask 的话 return final.attach2TaskForMpp(tasks...) } else { attachPlan2Task(p, t) @@ -2301,6 +2603,10 @@ func (p *PhysicalWindow) attach2Task(tasks ...task) task { type mppTask struct { p PhysicalPlan + // derive stats is done before physical plan enumeration, which when repeat source is inserted, + // the all upper layer physical op's base stats should be changed. + attachTaskCallBack func(PhysicalPlan) + partTp property.MPPPartitionType hashCols []*property.MPPPartitionColumn @@ -2380,6 +2686,7 @@ func accumulateNetSeekCost4MPP(p PhysicalPlan) (cost float64) { return } +// mppTask 转 root task 的时候,需要加一个 sender 并封装为 table reader func (t *mppTask) convertToRootTaskImpl(ctx sessionctx.Context) *rootTask { sender := PhysicalExchangeSender{ ExchangeType: tipb.ExchangeType_PassThrough, @@ -2393,7 +2700,8 @@ func (t *mppTask) convertToRootTaskImpl(ctx sessionctx.Context) *rootTask { p.stats = t.p.statsInfo() collectPartitionInfosFromMPPPlan(p, t.p) rt := &rootTask{ - p: p, + p: p, + attachTaskCallBack: t.attachTaskCallBack, } if len(t.rootTaskConds) > 0 { @@ -2470,7 +2778,7 @@ func (t *mppTask) enforceExchangerImpl(prop *property.PhysicalProperty) *mppTask ctx := t.p.SCtx() sender := PhysicalExchangeSender{ ExchangeType: prop.MPPPartitionTp.ToExchangeType(), - HashCols: prop.MPPPartitionCols, + HashCols: prop.MPPPartitionCols, // 执行层面怎么挂钩 hash key 和具体的 node 节点? }.Init(ctx, t.p.statsInfo()) if ctx.GetSessionVars().ChooseMppVersion() >= kv.MppVersionV1 { diff --git a/store/copr/mpp.go b/store/copr/mpp.go index 1c4ef6a40e85a..3b4efc96b2c2d 100644 --- a/store/copr/mpp.go +++ b/store/copr/mpp.go @@ -350,7 +350,7 @@ func (m *mppIterator) handleDispatchReq(ctx context.Context, bo *Backoffer, req if !req.IsRoot { return } - + // only root task should establish a stream conn with tiFlash to receive result. m.establishMPPConns(bo, req, taskMeta) } diff --git a/store/mockstore/unistore/cophandler/mpp.go b/store/mockstore/unistore/cophandler/mpp.go index 1cf0746e861dd..785e63fdeea18 100644 --- a/store/mockstore/unistore/cophandler/mpp.go +++ b/store/mockstore/unistore/cophandler/mpp.go @@ -200,6 +200,63 @@ func (b *mppExecBuilder) buildLimit(pb *tipb.Limit) (*limitExec, error) { return exec, nil } +func (b *mppExecBuilder) buildRepeatSource(pb *tipb.RepeatSource) (mppExec, error) { + child, err := b.buildMPPExecutor(pb.Child) + if err != nil { + return nil, err + } + exec := &repeatSourceExec{ + baseMPPExec: baseMPPExec{sc: b.sc, mppCtx: b.mppCtx, children: []mppExec{child}}, + } + + childFieldTypes := child.getFieldTypes() + // convert the grouping sets. 
+ tidbGss := expression.GroupingSets{} + for _, gs := range pb.GroupingSets { + tidbGs := expression.GroupingSet{} + for _, groupingExprs := range gs.GroupingExprs { + tidbGroupingExprs, err := convertToExprs(b.sc, childFieldTypes, groupingExprs.GroupingExpr) + if err != nil { + return nil, err + } + tidbGs = append(tidbGs, tidbGroupingExprs) + } + tidbGss = append(tidbGss, tidbGs) + } + exec.groupingSets = tidbGss + inGroupingSetMap := make(map[int]struct{}, len(exec.groupingSets)) + for _, gs := range exec.groupingSets { + // for every grouping set, collect column offsets under this grouping set. + for _, groupingExprs := range gs { + for _, groupingExpr := range groupingExprs { + if col, ok := groupingExpr.(*expression.Column); !ok { + return nil, errors.New("grouping set expr is not column ref") + } else { + inGroupingSetMap[col.Index] = struct{}{} + } + } + } + } + var mutatedFieldTypes []*types.FieldType + // change the field types return from children tobe nullable. + for offset, f := range childFieldTypes { + cf := f.Clone() + if _, ok := inGroupingSetMap[offset]; ok { + // remove the not null flag, make it nullable. + cf.SetFlag(cf.GetFlag() & ^mysql.NotNullFlag) + } + mutatedFieldTypes = append(mutatedFieldTypes, cf) + } + + // adding groupingID uint64|not-null as last one field types. + groupingIDFieldType := types.NewFieldType(mysql.TypeLonglong) + groupingIDFieldType.SetFlag(mysql.NotNullFlag | mysql.UnsignedFlag) + mutatedFieldTypes = append(mutatedFieldTypes, groupingIDFieldType) + + exec.fieldTypes = mutatedFieldTypes + return exec, nil +} + func (b *mppExecBuilder) buildTopN(pb *tipb.TopN) (mppExec, error) { child, err := b.buildMPPExecutor(pb.Child) if err != nil { @@ -253,18 +310,19 @@ func (b *mppExecBuilder) buildMPPExchangeSender(pb *tipb.ExchangeSender) (*exchS exchangeTp: pb.Tp, } if pb.Tp == tipb.ExchangeType_Hash { - if len(pb.PartitionKeys) != 1 { - return nil, errors.New("The number of hash key must be 1") - } - expr, err := expression.PBToExpr(pb.PartitionKeys[0], child.getFieldTypes(), b.sc) - if err != nil { - return nil, errors.Trace(err) - } - col, ok := expr.(*expression.Column) - if !ok { - return nil, errors.New("Hash key must be column type") + // remove the limitation of len(pb.PartitionKeys) == 1 + for _, partitionKey := range pb.PartitionKeys { + expr, err := expression.PBToExpr(partitionKey, child.getFieldTypes(), b.sc) + if err != nil { + return nil, errors.Trace(err) + } + col, ok := expr.(*expression.Column) + if !ok { + return nil, errors.New("Hash key must be column type") + } + e.hashKeyOffsets = append(e.hashKeyOffsets, col.Index) + e.hashKeyTypes = append(e.hashKeyTypes, e.fieldTypes[col.Index]) } - e.hashKeyOffset = col.Index } for _, taskMeta := range pb.EncodedTaskMeta { @@ -493,6 +551,8 @@ func (b *mppExecBuilder) buildMPPExecutor(exec *tipb.Executor) (mppExec, error) case tipb.ExecType_TypePartitionTableScan: ts := exec.PartitionTableScan return b.buildMPPPartitionTableScan(ts) + case tipb.ExecType_TypeRepeatSource: + return b.buildRepeatSource(exec.RepeatSource) default: return nil, errors.Errorf(ErrExecutorNotSupportedMsg + exec.Tp.String()) } diff --git a/store/mockstore/unistore/cophandler/mpp_exec.go b/store/mockstore/unistore/cophandler/mpp_exec.go index d0d0a71d5b85a..f69c904239658 100644 --- a/store/mockstore/unistore/cophandler/mpp_exec.go +++ b/store/mockstore/unistore/cophandler/mpp_exec.go @@ -17,6 +17,7 @@ package cophandler import ( "bytes" "encoding/binary" + "hash/fnv" "io" "math" "sort" @@ -415,6 +416,112 @@ func (e 
*limitExec) next() (*chunk.Chunk, error) { return chk, nil } +// repeatSourceExec is the basic mock logic for repeat executor in uniStore, +// with which we can validate the mpp plan correctness from explain result and result returned. +type repeatSourceExec struct { + baseMPPExec + + lastNum int + lastChunk *chunk.Chunk + groupingSets expression.GroupingSets + groupingSetOffsetMap []map[int]struct{} + groupingSetScope map[int]struct{} +} + +func (e *repeatSourceExec) open() error { + var err error + err = e.children[0].open() + if err != nil { + return err + } + // building the quick finding map + e.groupingSetOffsetMap = make([]map[int]struct{}, 0, len(e.groupingSets)) + e.groupingSetScope = make(map[int]struct{}, len(e.groupingSets)) + for _, gs := range e.groupingSets { + tmp := make(map[int]struct{}, len(gs)) + // for every grouping set, collect column offsets under this grouping set. + for _, groupingExprs := range gs { + for _, groupingExpr := range groupingExprs { + if col, ok := groupingExpr.(*expression.Column); !ok { + return errors.New("grouping set expr is not column ref") + } else { + tmp[col.Index] = struct{}{} + e.groupingSetScope[col.Index] = struct{}{} + } + } + } + e.groupingSetOffsetMap = append(e.groupingSetOffsetMap, tmp) + } + return nil +} + +func (e *repeatSourceExec) isGroupingCol(index int) bool { + if _, ok := e.groupingSetScope[index]; ok { + return true + } + return false +} + +func (e *repeatSourceExec) next() (*chunk.Chunk, error) { + var ( + err error + ) + if e.groupingSets.IsEmpty() { + return e.children[0].next() + } + resChk := chunk.NewChunkWithCapacity(e.getFieldTypes(), DefaultBatchSize) + for { + if e.lastChunk == nil || e.lastChunk.NumRows() == e.lastNum { + // fetch one chunk from children. + e.lastChunk, err = e.children[0].next() + if err != nil { + return nil, err + } + e.lastNum = 0 + if e.lastChunk == nil || e.lastChunk.NumRows() == 0 { + break + } + e.execSummary.updateOnlyRows(e.lastChunk.NumRows()) + } + numRows := e.lastChunk.NumRows() + numGroupingOffset := len(e.groupingSets) + + for i := e.lastNum; i < numRows; i++ { + row := e.lastChunk.GetRow(i) + e.lastNum++ + // for every grouping set, repeat the base row N times. + for g := 0; g < numGroupingOffset; g++ { + repeatRow := chunk.MutRowFromTypes(e.fieldTypes) + // for every targeted grouping set: + // 1: for every column in this grouping set, setting them as it was. + // 2: for every column in other target grouping set, setting them as null. + // 3: for every column not in any grouping set, setting them as it was. + // * normal agg only aimed at one replica of them with groupingID = 1 + // * so we don't need to change non-related column to be nullable. + // * so we don't need to mutate the column to be null when groupingID > 1 + for datumOffset, datumType := range e.fieldTypes[:len(e.fieldTypes)-1] { + if _, ok := e.groupingSetOffsetMap[g][datumOffset]; ok { + repeatRow.SetDatum(datumOffset, row.GetDatum(datumOffset, datumType)) + } else if !e.isGroupingCol(datumOffset) { + repeatRow.SetDatum(datumOffset, row.GetDatum(datumOffset, datumType)) + } else { + repeatRow.SetDatum(datumOffset, types.NewDatum(nil)) + } + } + // the last one column should be groupingID col. + groupingID := g + 1 + repeatRow.SetDatum(len(e.fieldTypes)-1, types.NewDatum(groupingID)) + resChk.AppendRow(repeatRow.ToRow()) + } + if DefaultBatchSize-resChk.NumRows() < numGroupingOffset { + // no enough room for another repeated N rows, return this chunk immediately. 
+ return resChk, nil + } + } + } + return resChk, nil +} + type topNExec struct { baseMPPExec @@ -492,10 +599,11 @@ func (e *topNExec) next() (*chunk.Chunk, error) { type exchSenderExec struct { baseMPPExec - tunnels []*ExchangerTunnel - outputOffsets []uint32 - exchangeTp tipb.ExchangeType - hashKeyOffset int + tunnels []*ExchangerTunnel + outputOffsets []uint32 + exchangeTp tipb.ExchangeType + hashKeyOffsets []int + hashKeyTypes []*types.FieldType } func (e *exchSenderExec) open() error { @@ -548,15 +656,22 @@ func (e *exchSenderExec) next() (*chunk.Chunk, error) { for i := 0; i < len(e.tunnels); i++ { targetChunks = append(targetChunks, chunk.NewChunkWithCapacity(e.fieldTypes, rows)) } + hashVals := fnv.New64() + payload := make([]byte, 1) for i := 0; i < rows; i++ { row := chk.GetRow(i) - d := row.GetDatum(e.hashKeyOffset, e.fieldTypes[e.hashKeyOffset]) - if d.IsNull() { - targetChunks[0].AppendRow(row) - } else { - hashKey := int(d.GetInt64() % int64(len(e.tunnels))) - targetChunks[hashKey].AppendRow(row) + hashVals.Reset() + // use hash values to get unique uint64 to mod. + // collect all the hash key datum. + err := codec.HashChunkRow(e.sc, hashVals, row, e.hashKeyTypes, e.hashKeyOffsets, payload) + if err != nil { + for _, tunnel := range e.tunnels { + tunnel.ErrCh <- err + } + return nil, nil } + hashKey := hashVals.Sum64() % uint64(len(e.tunnels)) + targetChunks[hashKey].AppendRow(row) } for i, tunnel := range e.tunnels { if targetChunks[i].NumRows() > 0 { @@ -618,6 +733,7 @@ func (e *exchRecvExec) init() error { } serverMetas = append(serverMetas, meta) } + // for receiver: open conn worker for every receive meta. for _, meta := range serverMetas { e.wg.Add(1) go e.runTunnelWorker(e.mppCtx.TaskHandler, meta) @@ -668,6 +784,7 @@ func (e *exchRecvExec) EstablishConnAndReceiveData(h *MPPTaskHandler, meta *mpp. } return nil, errors.Trace(err) } + // stream resp if mppResponse == nil { return ret, nil } diff --git a/store/mockstore/unistore/rpc.go b/store/mockstore/unistore/rpc.go index 15fe6828ddabf..3fa7aef4bec23 100644 --- a/store/mockstore/unistore/rpc.go +++ b/store/mockstore/unistore/rpc.go @@ -332,6 +332,7 @@ func (c *RPCClient) handleCopStream(ctx context.Context, req *coprocessor.Reques }, nil } +// handleEstablishMPPConnection handle the mock mpp collection came from root or peers. func (c *RPCClient) handleEstablishMPPConnection(ctx context.Context, r *mpp.EstablishMPPConnectionRequest, timeout time.Duration, storeID uint64) (*tikvrpc.MPPStreamResponse, error) { mockServer := new(mockMPPConnectStreamServer) err := c.usSvr.EstablishMPPConnectionWithStoreID(r, mockServer, storeID) @@ -348,6 +349,7 @@ func (c *RPCClient) handleEstablishMPPConnection(ctx context.Context, r *mpp.Est _, cancel := context.WithCancel(ctx) streamResp.Lease.Cancel = cancel streamResp.Timeout = timeout + // mock the stream resp from the server's resp slice first, err := streamResp.Recv() if err != nil { if errors.Cause(err) != io.EOF { diff --git a/util/plancodec/id.go b/util/plancodec/id.go index 0366b7dc3f5a6..5e107a9d62ce2 100644 --- a/util/plancodec/id.go +++ b/util/plancodec/id.go @@ -57,6 +57,8 @@ const ( TypeExchangeSender = "ExchangeSender" // TypeExchangeReceiver is the type of mpp exchanger receiver. TypeExchangeReceiver = "ExchangeReceiver" + // TypeRepeatSource is the type of mpp repeat source operator. + TypeRepeatSource = "RepeatSource" // TypeMergeJoin is the type of merge join. TypeMergeJoin = "MergeJoin" // TypeIndexJoin is the type of index look up join. 
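Note on the mock expansion above: repeatSourceExec.next() replicates every child row once per grouping set, keeps the columns of the targeted set, turns grouping columns of the other sets into NULL, passes through columns that are in no grouping set, and appends a 1-based groupingID. Below is a minimal standalone sketch of that per-row contract, assuming plain maps and string column names as stand-ins for chunk columns (none of these helper names exist in unistore):

package main

import "fmt"

// expandRow emits one replica of the input row per grouping set:
// grouping columns of the targeted set keep their value, grouping
// columns of the other sets become NULL (nil), non-grouping columns
// pass through unchanged, and a 1-based groupingID is appended.
func expandRow(row map[string]any, groupingSets [][]string) []map[string]any {
	inAnySet := make(map[string]bool)
	for _, set := range groupingSets {
		for _, col := range set {
			inAnySet[col] = true
		}
	}
	replicas := make([]map[string]any, 0, len(groupingSets))
	for id, set := range groupingSets {
		targeted := make(map[string]bool, len(set))
		for _, col := range set {
			targeted[col] = true
		}
		replica := make(map[string]any, len(row)+1)
		for col, val := range row {
			if targeted[col] || !inAnySet[col] {
				replica[col] = val // this set's grouping col, or a non-grouping col
			} else {
				replica[col] = nil // grouping col of another set -> NULL
			}
		}
		replica["groupingID"] = id + 1
		replicas = append(replicas, replica)
	}
	return replicas
}

func main() {
	row := map[string]any{"a": "a1", "b": "b1", "c": "c1"}
	for _, r := range expandRow(row, [][]string{{"a"}, {"b"}}) {
		fmt.Println(r)
	}
	// map[a:a1 b:<nil> c:c1 groupingID:1]
	// map[a:<nil> b:b1 c:c1 groupingID:2]
}

Column c, which belongs to no grouping set, survives in every replica; that is why the normal aggregates only need the case-when mask on the groupingID they target (1 in the plans above) rather than any rewrite of the non-grouping columns.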
@@ -193,6 +195,7 @@ const ( typeShuffleReceiverID int = 55 typeForeignKeyCheck int = 56 typeForeignKeyCascade int = 57 + typeRepeatSourceID int = 58 ) // TypeStringToPhysicalID converts the plan type string to plan id. @@ -312,6 +315,8 @@ func TypeStringToPhysicalID(tp string) int { return typeForeignKeyCheck case TypeForeignKeyCascade: return typeForeignKeyCascade + case TypeRepeatSource: + return typeRepeatSourceID } // Should never reach here. return 0 @@ -434,6 +439,8 @@ func PhysicalIDToTypeString(id int) string { return TypeForeignKeyCheck case typeForeignKeyCascade: return TypeForeignKeyCascade + case typeRepeatSourceID: + return TypeRepeatSource } // Should never reach here. From 0442a2094174fed9979533bf172c8952523f2b73 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Sun, 18 Dec 2022 17:17:58 +0800 Subject: [PATCH 02/25] adjust the stats for 3 stages agg for-multi distinct-agg Signed-off-by: AilinKid <314806019@qq.com> --- expression/grouping_sets.go | 25 ++- .../testdata/enforce_mpp_suite_out.json | 106 ++++++------- planner/core/stats.go | 10 ++ planner/core/task.go | 144 ++++++++++++++---- 4 files changed, 198 insertions(+), 87 deletions(-) diff --git a/expression/grouping_sets.go b/expression/grouping_sets.go index 28a24928087ef..dcdc9ac16bc31 100644 --- a/expression/grouping_sets.go +++ b/expression/grouping_sets.go @@ -29,7 +29,7 @@ func (gs GroupingSets) Merge() GroupingSets { // for now, there is precondition that all grouping expressions are columns. // for example: (a,b,c) and (a,b) and (a) will be merged as a one. // Eg: - // before merging, there are 3 grouping sets. + // before merging, there are 4 grouping sets. // GroupingSets: // [ // [[a,b,c],] // every set including a grouping Expressions for initial. @@ -38,7 +38,7 @@ func (gs GroupingSets) Merge() GroupingSets { // [[e],] // ] // - // after merging, there is only 1 grouping set. + // after merging, there is only 2 grouping set. // GroupingSets: // [ // [[a],[a,b],[a,b,c],] // one set including 3 grouping Expressions after merging if possible. @@ -80,7 +80,7 @@ func (gs GroupingSets) mergeOne(targetOne GroupingExprs) GroupingSets { // [a,b,c,d] // when adding [a,b,c,d] grouping expr here, since it's a super-set of current [a,b] with offset 1, take the offset as 1+1. // - // result grouping set: [[b], [a,b], [a,b,c,d]], expanding with step as two or more columns is acceptable and reasonable. + // result grouping set: [[b], [a,b], [a,b,c,d]], expanding with step with two or more columns is acceptable and reasonable. // | | | // +----+-------+ every previous one is the subset of the latter one. // @@ -116,7 +116,7 @@ func (gs GroupingSets) mergeOne(targetOne GroupingExprs) GroupingSets { } // here means we couldn't find even one GroupingSet to fill the targetOne, creating a new one. gs = append(gs, newGroupingSet(targetOne)) - // gs is a alias of slice [], we should return it back after being changed. + // gs is an alias of slice [], we should return it back after being changed. return gs } @@ -136,10 +136,11 @@ func (gs GroupingSets) TargetOne(targetOne GroupingExprs) int { // as null and groupingID as 1 to guarantee group operator can equate to group which is equal to group, that's what the upper layer // distinct(a,b) want for. For distinct(c), the same is true. // - // For normal agg you better choose your targeting group-set data, otherwise, all your get is endless null values filled by repeatSource operator. 
+ // For normal agg you better choose your targeting group-set data, otherwise, otherwise your get is all groups,most of them is endless null values filled by + // repeatSource operator, and null value can also influence your non null-strict normal agg, although you don't want them to. // // here we only consider 1&2 since normal agg with multi col args is banned. - allGroupingColIDs := gs.allSetsColIDs() + allGroupingColIDs := gs.AllSetsColIDs() for idx, groupingSet := range gs { // diffCols are those columns being filled with null in the group row data of current grouping set. diffCols := allGroupingColIDs.Difference(*groupingSet.allSetColIDs()) @@ -175,6 +176,16 @@ func (gs GroupingSet) allSetColIDs() *fd.FastIntSet { return &res } +func (gs GroupingSet) ExtractCols() []*Column { + cols := make([]*Column, 0, len(gs)) + for _, groupingExprs := range gs { + for _, one := range groupingExprs { + cols = append(cols, one.(*Column)) + } + } + return cols +} + func (gs GroupingSet) Clone() GroupingSet { gc := make(GroupingSet, 0, len(gs)) for _, one := range gs { @@ -228,7 +239,7 @@ func (gss GroupingSets) IsEmpty() bool { return true } -func (gss GroupingSets) allSetsColIDs() *fd.FastIntSet { +func (gss GroupingSets) AllSetsColIDs() *fd.FastIntSet { res := fd.NewFastIntSet() for _, groupingSet := range gss { res.UnionWith(*groupingSet.allSetColIDs()) diff --git a/planner/core/casetest/testdata/enforce_mpp_suite_out.json b/planner/core/casetest/testdata/enforce_mpp_suite_out.json index f97fe2fdb4d6f..fa7780097826d 100644 --- a/planner/core/casetest/testdata/enforce_mpp_suite_out.json +++ b/planner/core/casetest/testdata/enforce_mpp_suite_out.json @@ -1234,18 +1234,18 @@ { "SQL": "EXPLAIN select count(distinct a), count(distinct b) from t", "Plan": [ - "TableReader_36 1.00 root data:ExchangeSender_35", - "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7", - " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#14)->Column#6, funcs:sum(Column#16)->Column#7", - " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", - " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─HashAgg_27 1.00 mpp[tiflash] funcs:count(distinct Column#15)->Column#14, funcs:count(distinct Column#17)->Column#16", - " └─Projection_30 1.00 mpp[tiflash] test.t.a, test.t.b, Column#13, case(eq(Column#13, 1), test.t.a, )->Column#15, case(eq(Column#13, 2), test.t.b, )->Column#17", - " └─ExchangeReceiver_32 1.00 mpp[tiflash] ", - " └─ExchangeSender_31 1.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#13, collate: binary]", - " └─HashAgg_25 1.00 mpp[tiflash] group by:Column#13, test.t.a, test.t.b, ", - " └─RepeatSource_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#13, [{},{}]", + "TableReader_38 1.00 root data:ExchangeSender_37", + "└─ExchangeSender_37 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_27 1.00 mpp[tiflash] Column#6, Column#7", + " └─HashAgg_28 1.00 mpp[tiflash] funcs:sum(Column#14)->Column#6, funcs:sum(Column#16)->Column#7", + " └─ExchangeReceiver_36 1.00 mpp[tiflash] ", + " └─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_30 1.00 mpp[tiflash] funcs:count(distinct Column#15)->Column#14, funcs:count(distinct Column#17)->Column#16", + " └─Projection_32 16000.00 mpp[tiflash] test.t.a, test.t.b, Column#13, case(eq(Column#13, 1), test.t.a, )->Column#15, case(eq(Column#13, 2), test.t.b, 
)->Column#17", + " └─ExchangeReceiver_34 16000.00 mpp[tiflash] ", + " └─ExchangeSender_33 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#13, collate: binary]", + " └─HashAgg_26 16000.00 mpp[tiflash] group by:Column#13, test.t.a, test.t.b, ", + " └─RepeatSource_29 20000.00 mpp[tiflash] group set num:2, groupingID:Column#13, [{},{}]", " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null @@ -1260,19 +1260,19 @@ { "SQL": "EXPLAIN select count(distinct a), count(distinct b), count(c) from t", "Plan": [ - "TableReader_36 1.00 root data:ExchangeSender_35", - "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8", - " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#20)->Column#6, funcs:sum(Column#22)->Column#7, funcs:sum(Column#24)->Column#8", - " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", - " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─HashAgg_27 1.00 mpp[tiflash] funcs:count(distinct Column#21)->Column#20, funcs:count(distinct Column#23)->Column#22, funcs:sum(Column#17)->Column#24", - " └─Projection_30 1.00 mpp[tiflash] Column#17, test.t.a, test.t.b, Column#18, case(eq(Column#18, 1), test.t.a, )->Column#21, case(eq(Column#18, 2), test.t.b, )->Column#23", - " └─ExchangeReceiver_32 1.00 mpp[tiflash] ", - " └─ExchangeSender_31 1.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#18, collate: binary]", - " └─HashAgg_25 1.00 mpp[tiflash] group by:Column#18, test.t.a, test.t.b, funcs:count(Column#19)->Column#17", - " └─Projection_29 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#18, case(eq(Column#18, 1), test.t.c, )->Column#19", - " └─RepeatSource_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#18, [{},{}]", + "TableReader_38 1.00 root data:ExchangeSender_37", + "└─ExchangeSender_37 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_27 1.00 mpp[tiflash] Column#6, Column#7, Column#8", + " └─HashAgg_28 1.00 mpp[tiflash] funcs:sum(Column#20)->Column#6, funcs:sum(Column#22)->Column#7, funcs:sum(Column#24)->Column#8", + " └─ExchangeReceiver_36 1.00 mpp[tiflash] ", + " └─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_30 1.00 mpp[tiflash] funcs:count(distinct Column#21)->Column#20, funcs:count(distinct Column#23)->Column#22, funcs:sum(Column#17)->Column#24", + " └─Projection_32 16000.00 mpp[tiflash] Column#17, test.t.a, test.t.b, Column#18, case(eq(Column#18, 1), test.t.a, )->Column#21, case(eq(Column#18, 2), test.t.b, )->Column#23", + " └─ExchangeReceiver_34 16000.00 mpp[tiflash] ", + " └─ExchangeSender_33 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#18, collate: binary]", + " └─HashAgg_26 16000.00 mpp[tiflash] group by:Column#18, test.t.a, test.t.b, funcs:count(Column#19)->Column#17", + " └─Projection_31 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#18, case(eq(Column#18, 1), test.t.c, )->Column#19", + " └─RepeatSource_29 20000.00 mpp[tiflash] group set num:2, groupingID:Column#18, [{},{}]", " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null @@ -1287,20 +1287,20 @@ { "SQL": "EXPLAIN select count(distinct a), count(distinct b), sum(c) from t", 
"Plan": [ - "TableReader_36 1.00 root data:ExchangeSender_35", - "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8", - " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#20)->Column#6, funcs:sum(Column#22)->Column#7, funcs:sum(Column#24)->Column#8", - " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", - " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─HashAgg_27 1.00 mpp[tiflash] funcs:count(distinct Column#21)->Column#20, funcs:count(distinct Column#23)->Column#22, funcs:sum(Column#17)->Column#24", - " └─Projection_30 1.00 mpp[tiflash] Column#17, test.t.a, test.t.b, Column#18, case(eq(Column#18, 1), test.t.a, )->Column#21, case(eq(Column#18, 2), test.t.b, )->Column#23", - " └─ExchangeReceiver_32 1.00 mpp[tiflash] ", - " └─ExchangeSender_31 1.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#18, collate: binary]", - " └─HashAgg_25 1.00 mpp[tiflash] group by:Column#26, Column#27, Column#28, funcs:sum(Column#25)->Column#17", - " └─Projection_37 20000.00 mpp[tiflash] cast(Column#19, decimal(10,0) BINARY)->Column#25, test.t.a, test.t.b, Column#18", - " └─Projection_29 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#18, case(eq(Column#18, 1), test.t.c, )->Column#19", - " └─RepeatSource_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#18, [{},{}]", + "TableReader_38 1.00 root data:ExchangeSender_37", + "└─ExchangeSender_37 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_27 1.00 mpp[tiflash] Column#6, Column#7, Column#8", + " └─HashAgg_28 1.00 mpp[tiflash] funcs:sum(Column#20)->Column#6, funcs:sum(Column#22)->Column#7, funcs:sum(Column#24)->Column#8", + " └─ExchangeReceiver_36 1.00 mpp[tiflash] ", + " └─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_30 1.00 mpp[tiflash] funcs:count(distinct Column#21)->Column#20, funcs:count(distinct Column#23)->Column#22, funcs:sum(Column#17)->Column#24", + " └─Projection_32 16000.00 mpp[tiflash] Column#17, test.t.a, test.t.b, Column#18, case(eq(Column#18, 1), test.t.a, )->Column#21, case(eq(Column#18, 2), test.t.b, )->Column#23", + " └─ExchangeReceiver_34 16000.00 mpp[tiflash] ", + " └─ExchangeSender_33 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#18, collate: binary]", + " └─HashAgg_26 16000.00 mpp[tiflash] group by:Column#26, Column#27, Column#28, funcs:sum(Column#25)->Column#17", + " └─Projection_39 20000.00 mpp[tiflash] cast(Column#19, decimal(10,0) BINARY)->Column#25, test.t.a, test.t.b, Column#18", + " └─Projection_31 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#18, case(eq(Column#18, 1), test.t.c, )->Column#19", + " └─RepeatSource_29 20000.00 mpp[tiflash] group set num:2, groupingID:Column#18, [{},{}]", " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null @@ -1315,20 +1315,20 @@ { "SQL": "EXPLAIN select count(distinct a, b), count(distinct b), count(c), sum(d) from t", "Plan": [ - "TableReader_36 1.00 root data:ExchangeSender_35", - "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8, Column#9", - " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#26)->Column#6, funcs:sum(Column#28)->Column#7, funcs:sum(Column#30)->Column#8, funcs:sum(Column#31)->Column#9", - " 
└─ExchangeReceiver_34 1.00 mpp[tiflash] ", - " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─HashAgg_27 1.00 mpp[tiflash] funcs:count(distinct Column#27, test.t.b)->Column#26, funcs:count(distinct Column#29)->Column#28, funcs:sum(Column#21)->Column#30, funcs:sum(Column#22)->Column#31", - " └─Projection_30 1.00 mpp[tiflash] Column#21, Column#22, test.t.a, test.t.b, Column#23, case(eq(Column#23, 1), test.t.a, )->Column#27, case(eq(Column#23, 2), test.t.b, )->Column#29", - " └─ExchangeReceiver_32 1.00 mpp[tiflash] ", - " └─ExchangeSender_31 1.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#23, collate: binary]", - " └─HashAgg_25 1.00 mpp[tiflash] group by:Column#34, Column#35, Column#36, funcs:count(Column#32)->Column#21, funcs:sum(Column#33)->Column#22", - " └─Projection_37 20000.00 mpp[tiflash] Column#24, cast(Column#25, decimal(10,0) BINARY)->Column#33, test.t.a, test.t.b, Column#23", - " └─Projection_29 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, test.t.d, Column#23, case(eq(Column#23, 1), test.t.c, )->Column#24, case(eq(Column#23, 1), test.t.d, )->Column#25", - " └─RepeatSource_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#23, [{},{}]", + "TableReader_38 1.00 root data:ExchangeSender_37", + "└─ExchangeSender_37 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_27 1.00 mpp[tiflash] Column#6, Column#7, Column#8, Column#9", + " └─HashAgg_28 1.00 mpp[tiflash] funcs:sum(Column#26)->Column#6, funcs:sum(Column#28)->Column#7, funcs:sum(Column#30)->Column#8, funcs:sum(Column#31)->Column#9", + " └─ExchangeReceiver_36 1.00 mpp[tiflash] ", + " └─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_30 1.00 mpp[tiflash] funcs:count(distinct Column#27, test.t.b)->Column#26, funcs:count(distinct Column#29)->Column#28, funcs:sum(Column#21)->Column#30, funcs:sum(Column#22)->Column#31", + " └─Projection_32 16000.00 mpp[tiflash] Column#21, Column#22, test.t.a, test.t.b, Column#23, case(eq(Column#23, 1), test.t.a, )->Column#27, case(eq(Column#23, 2), test.t.b, )->Column#29", + " └─ExchangeReceiver_34 16000.00 mpp[tiflash] ", + " └─ExchangeSender_33 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#23, collate: binary]", + " └─HashAgg_26 16000.00 mpp[tiflash] group by:Column#34, Column#35, Column#36, funcs:count(Column#32)->Column#21, funcs:sum(Column#33)->Column#22", + " └─Projection_39 20000.00 mpp[tiflash] Column#24, cast(Column#25, decimal(10,0) BINARY)->Column#33, test.t.a, test.t.b, Column#23", + " └─Projection_31 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, test.t.d, Column#23, case(eq(Column#23, 1), test.t.c, )->Column#24, case(eq(Column#23, 1), test.t.d, )->Column#25", + " └─RepeatSource_29 20000.00 mpp[tiflash] group set num:2, groupingID:Column#23, [{},{}]", " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": [ diff --git a/planner/core/stats.go b/planner/core/stats.go index 96065fccf12b5..eb10273a60189 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -567,6 +567,16 @@ func getColsNDVWithMatchedLen(cols []*expression.Column, schema *expression.Sche return NDV, 1 } +func getColsDNVWithMatchedLenFromUniqueIDs(ids []int64, schema *expression.Schema, profile *property.StatsInfo) (float64, int) { + cols := make([]*expression.Column, 0, len(ids)) + for _, id := range ids 
{ + cols = append(cols, &expression.Column{ + UniqueID: id, + }) + } + return getColsNDVWithMatchedLen(cols, schema, profile) +} + func (p *LogicalProjection) getGroupNDVs(colGroups [][]*expression.Column, childProfile *property.StatsInfo, selfSchema *expression.Schema) []property.GroupNDV { if len(colGroups) == 0 || len(childProfile.GroupNDVs) == 0 { return nil diff --git a/planner/core/task.go b/planner/core/task.go index f04a4d0ffd5ae..0870db47b9038 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -143,12 +143,6 @@ func attachPlan2Task(p PhysicalPlan, t task) task { p.SetChildren(v.p) v.p = p case *mppTask: - if v.attachTaskCallBack != nil { - v.attachTaskCallBack(p) - } - if p == nil || v == nil || v.p == nil { - fmt.Println(1) - } p.SetChildren(v.p) v.p = p } @@ -1888,6 +1882,7 @@ func (p *basePhysicalAgg) newPartialAggregate(copTaskType kv.StoreType, isMPPTas MppRunMode: p.MppRunMode, }.initForHash(p.ctx, p.stats, p.blockOffset, prop) finalAgg.schema = finalPref.Schema + // partialAgg and finalAgg use the same ref of stats return partialAgg, finalAgg } @@ -2142,6 +2137,97 @@ func (p *PhysicalHashAgg) attach2TaskForMpp1Phase(mpp *mppTask) task { return mpp } +// scaleStats4GroupingSets scale the derived stats because the lower source has been repeated. +// +// parent OP <- logicalAgg <- children OP (derived stats) +// | +// v +// parent OP <- physicalAgg <- children OP (stats used) +// | +// +----------+----------+----------+ +// Final Mid Partial Repeat +// +// physical agg stats is reasonable from the whole, because repeat operator is designed to facilitate +// the Mid and Partial Agg, which means when leaving the Final, its output rowcount could be exactly +// the same as what it derived(estimated) before entering physical optimization phase. +// +// From the cost model correctness, for these inserted sub-agg and even repeat source operator, we should +// recompute the stats for them particularly. +// +// for example: grouping sets {},{}, group by items {a,b,c,groupingID} +// after repeat source: +// +// a, b, c, groupingID +// ... null c 1 ---+ +// ... null c 1 +------- replica group 1 +// ... null c 1 ---+ +// null ... c 2 ---+ +// null ... c 2 +------- replica group 2 +// null ... c 2 ---+ +// +// since null value is seen the same when grouping data (groupingID in one replica is always the same): +// - so the num of group in replica 1 is equal to NDV(a,c) +// - so the num of group in replica 2 is equal to NDV(b,c) +// +// in a summary, the total num of group of all replica is equal to = Σ:NDV(each-grouping-set-cols, normal-group-cols) +func (p *PhysicalHashAgg) scaleStats4GroupingSets(groupingSets expression.GroupingSets, groupingIDCol *expression.Column, + childSchema *expression.Schema, childStats *property.StatsInfo) { + idSets := groupingSets.AllSetsColIDs() + normalGbyCols := make([]*expression.Column, 0, len(p.GroupByItems)) + for _, gbyExpr := range p.GroupByItems { + cols := expression.ExtractColumns(gbyExpr) + for _, col := range cols { + if !idSets.Has(int(col.UniqueID)) && col.UniqueID != groupingIDCol.UniqueID { + normalGbyCols = append(normalGbyCols, col) + } + } + } + sumNDV := float64(0) + for _, groupingSet := range groupingSets { + // for every grouping set, pick its cols out, and combine with normal group cols to get the NDV. + groupingSetCols := groupingSet.ExtractCols() + groupingSetCols = append(groupingSetCols, normalGbyCols...) 
+ NDV, _ := getColsNDVWithMatchedLen(groupingSetCols, childSchema, childStats) + sumNDV += NDV + } + // After group operator, all same rows are grouped into one row, that means all + // change the sub-agg's stats + if p.stats != nil { + // equivalence to a new cloned one. (cause finalAgg and partialAgg may share a same copy of stats) + cpStats := p.stats.Scale(1) + cpStats.RowCount = sumNDV + // We cannot estimate the ColNDVs for every output, so we use a conservative strategy. + for k, _ := range cpStats.ColNDVs { + cpStats.ColNDVs[k] = sumNDV + } + // for old groupNDV, if it's containing one more grouping set cols, just plus the NDV where the col is excluded. + // for example: old grouping NDV(b,c), where b is in grouping sets {},{}. so when countering the new NDV: + // cases: + // new grouping NDV(b,c) := old NDV(b,c) + NDV(null, c) = old NDV(b,c) + DNV(c). + // new grouping NDV(a,b,c) := old NDV(a,b,c) + NDV(null,b,c) + NDV(a,null,c) = old NDV(a,b,c) + NDV(b,c) + NDV(a,c) + allGroupingSetsIDs := groupingSets.AllSetsColIDs() + for _, oneGNDV := range cpStats.GroupNDVs { + newGNDV := oneGNDV.NDV + intersectionIDs := make([]int64, 0, len(oneGNDV.Cols)) + for i, id := range oneGNDV.Cols { + if allGroupingSetsIDs.Has(int(id)) { + // when meet an id in grouping sets, skip it (cause its null) and append the rest ids to count the incrementNDV. + beforeLen := len(intersectionIDs) + intersectionIDs = append(intersectionIDs, oneGNDV.Cols[i:]...) + incrementNDV, _ := getColsDNVWithMatchedLenFromUniqueIDs(intersectionIDs, childSchema, childStats) + newGNDV += incrementNDV + // restore the before intersectionIDs slice. + intersectionIDs = intersectionIDs[:beforeLen] + } + // insert ids one by one. + intersectionIDs = append(intersectionIDs, id) + } + oneGNDV.NDV = newGNDV + } + p.stats = cpStats + } +} + func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan, canUse3StageAgg bool, groupingSets expression.GroupingSets, mpp *mppTask) (final, mid, proj1, part, proj2 PhysicalPlan, _ error) { // generate 3 stage aggregation for single/multi count distinct if applicable. @@ -2160,9 +2246,9 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan // HashAgg sum(#1), sum(#2), sum(#3) -> final agg // +- Exchange Passthrough // +- HashAgg count(distinct a) #1, count(distinct b) #2, sum(#4) #3 -> middle agg - // +- Exchange HashPartition by a + // +- Exchange HashPartition by a,b,groupingID // +- HashAgg count(c) #4, group by a,b,groupingID -> partial agg - // +- RepeatSource a, b -> repeat source + // +- RepeatSource {}, {} -> repeat source // +- TableScan foo // // set the default expression to constant 1 for the convenience to choose default group set data. @@ -2173,27 +2259,26 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan // single distinct agg mode can eliminate repeatSource op insertion. // physical plan is enumerated without children from itself, use mpp subtree instead p.children. - stats := mpp.p.statsInfo().Scale(float64(len(groupingSets))) + // scale(len(groupingSets)) will change the NDV, while Repeat doesn't change the NDV and groupNDV. + stats := mpp.p.statsInfo().Scale(float64(1)) + stats.RowCount = stats.RowCount * float64(len(groupingSets)) physicalRepeatSource := PhysicalRepeatSource{ GroupingSets: groupingSets, }.Init(p.ctx, stats, mpp.p.SelectBlockOffset()) // generate a new column as groupingID to identify which this row is targeting for. 
+ tp := types.NewFieldType(mysql.TypeLonglong) + tp.SetFlag(mysql.UnsignedFlag | mysql.NotNullFlag) groupingIDCol = &expression.Column{ UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: types.NewFieldType(mysql.TypeLonglong), + RetType: tp, } // append the physical repeatSource op with groupingID column. physicalRepeatSource.SetSchema(mpp.p.Schema().Clone()) physicalRepeatSource.schema.Append(groupingIDCol.(*expression.Column)) physicalRepeatSource.GroupingIDCol = groupingIDCol.(*expression.Column) - // set attach mpp task call back, aim to expand the input rowcount of stats for upper ops. - //mpp.attachTaskCallBack = func(p PhysicalPlan) { - // p.statsInfo().RowCount = p.statsInfo().RowCount * float64(len(groupingSets)) - //} - // attach PhysicalRepeatSource to mpp attachPlan2Task(physicalRepeatSource, mpp) } @@ -2261,6 +2346,11 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan if err != nil { return nil, nil, nil, nil, nil, err } + cloneHashAgg := clonedAgg.(*PhysicalHashAgg) + // Clone(), it will share same base-plan elements from the finalAgg, including id,tp,stats. Make a new one here. + cloneHashAgg.basePlan = newBasePlan(cloneHashAgg.ctx, cloneHashAgg.tp, cloneHashAgg.blockOffset) + cloneHashAgg.stats = finalAgg.Stats() // reuse the final agg stats here. + // select count(distinct a), count(distinct b), count(c) from t // final agg : sum(#1), sum(#2), sum(#3) // | | +--------------------------------------------+ @@ -2289,15 +2379,19 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan // step1: adjust partial agg, for normal agg here, adjust it to target for specified group data. // Since we may substitute the first arg of normal agg with case-when expression here, append a // customized proj here rather than depend on postOptimize to insert a blunt one for us. - partialHashAgg := partialAgg.(*PhysicalHashAgg) - partialHashAgg.GroupByItems = append(partialHashAgg.GroupByItems, groupingIDCol) - partialHashAgg.schema.Append(groupingIDCol.(*expression.Column)) + // // proj4Partial output all the base col from lower op + caseWhen proj cols. proj4Partial := new(PhysicalProjection).Init(p.ctx, mpp.p.statsInfo(), mpp.p.SelectBlockOffset()) for _, col := range mpp.p.Schema().Columns { proj4Partial.Exprs = append(proj4Partial.Exprs, col) } proj4Partial.SetSchema(mpp.p.Schema().Clone()) + + partialHashAgg := partialAgg.(*PhysicalHashAgg) + partialHashAgg.GroupByItems = append(partialHashAgg.GroupByItems, groupingIDCol) + partialHashAgg.schema.Append(groupingIDCol.(*expression.Column)) + // it will create a new stats for partial agg. + partialHashAgg.scaleStats4GroupingSets(groupingSets, groupingIDCol.(*expression.Column), proj4Partial.Schema(), proj4Partial.statsInfo()) for _, fun := range partialHashAgg.AggFuncs { if !fun.HasDistinct { // for normal agg phase1, we should also modify them to target for specified group data. @@ -2314,11 +2408,12 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan } // step2: adjust middle agg - middleHashAgg := clonedAgg.(*PhysicalHashAgg) + // proj4Middle output all the base col from lower op + caseWhen proj cols.(reuse partial agg stats) + proj4Middle := new(PhysicalProjection).Init(p.ctx, partialHashAgg.stats, partialHashAgg.SelectBlockOffset()) + // middleHashAgg shared the same stats with the final agg does. 
+ middleHashAgg := cloneHashAgg middleSchema := expression.NewSchema() schemaMap := make(map[int64]*expression.Column, len(middleHashAgg.AggFuncs)) - // proj4Middle output all the base col from lower op + caseWhen proj cols. - proj4Middle := new(PhysicalProjection).Init(p.ctx, partialHashAgg.stats, partialHashAgg.SelectBlockOffset()) for _, col := range partialHashAgg.Schema().Columns { proj4Middle.Exprs = append(proj4Middle.Exprs, col) } @@ -2603,10 +2698,6 @@ func (p *PhysicalWindow) attach2Task(tasks ...task) task { type mppTask struct { p PhysicalPlan - // derive stats is done before physical plan enumeration, which when repeat source is inserted, - // the all upper layer physical op's base stats should be changed. - attachTaskCallBack func(PhysicalPlan) - partTp property.MPPPartitionType hashCols []*property.MPPPartitionColumn @@ -2700,8 +2791,7 @@ func (t *mppTask) convertToRootTaskImpl(ctx sessionctx.Context) *rootTask { p.stats = t.p.statsInfo() collectPartitionInfosFromMPPPlan(p, t.p) rt := &rootTask{ - p: p, - attachTaskCallBack: t.attachTaskCallBack, + p: p, } if len(t.rootTaskConds) > 0 { From f71e92a335c7fd0d62e9637a967766c173c4f241 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Sun, 18 Dec 2022 17:29:52 +0800 Subject: [PATCH 03/25] address comments from yiding Signed-off-by: AilinKid <314806019@qq.com> --- expression/grouping_sets.go | 12 ++++++------ planner/core/explain.go | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/expression/grouping_sets.go b/expression/grouping_sets.go index dcdc9ac16bc31..bb1fde46f48ef 100644 --- a/expression/grouping_sets.go +++ b/expression/grouping_sets.go @@ -24,6 +24,12 @@ import ( "github.com/pingcap/tipb/go-tipb" ) +type GroupingSets []GroupingSet + +type GroupingSet []GroupingExprs + +type GroupingExprs []Expression + // Merge function will explore the internal grouping expressions and try to find the minimum grouping sets. (prefix merging) func (gs GroupingSets) Merge() GroupingSets { // for now, there is precondition that all grouping expressions are columns. 
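With the grouping-set types now declared at the top of the file, a quick illustration of the prefix merging that the Merge() comment promises may help: sets that are subsets of one another collapse into one chain, unrelated sets stay separate. The toy sketch below sorts by length and chains greedily, using string column names instead of *expression.Column; it is a simplified stand-in for illustration, not the planner's own algorithm:

package main

import (
	"fmt"
	"sort"
)

// subset reports whether every column of small also appears in big.
func subset(small, big []string) bool {
	m := make(map[string]bool, len(big))
	for _, c := range big {
		m[c] = true
	}
	for _, c := range small {
		if !m[c] {
			return false
		}
	}
	return true
}

// merge folds grouping sets into chains set[0] ⊆ set[1] ⊆ ... so that
// one Expand replica can serve the whole chain.
func merge(sets [][]string) [][][]string {
	sort.Slice(sets, func(i, j int) bool { return len(sets[i]) < len(sets[j]) })
	var chains [][][]string
next:
	for _, s := range sets {
		for i, chain := range chains {
			if subset(chain[len(chain)-1], s) {
				chains[i] = append(chain, s)
				continue next
			}
		}
		chains = append(chains, [][]string{s})
	}
	return chains
}

func main() {
	fmt.Println(merge([][]string{{"a", "b", "c"}, {"a", "b"}, {"a"}, {"e"}}))
	// two merged sets: [[a] [a b] [a b c]] and [[e]] (chain order may vary)
}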
@@ -278,12 +284,6 @@ func newGroupingSet(oneGroupingExpr GroupingExprs) GroupingSet { return res } -type GroupingSets []GroupingSet - -type GroupingSet []GroupingExprs - -type GroupingExprs []Expression - func (g GroupingExprs) IsEmpty() bool { return len(g) == 0 } diff --git a/planner/core/explain.go b/planner/core/explain.go index bca27228ed195..f622f75deec09 100644 --- a/planner/core/explain.go +++ b/planner/core/explain.go @@ -364,8 +364,8 @@ func (p *PhysicalRepeatSource) ExplainInfo() string { var str strings.Builder str.WriteString("group set num:") str.WriteString(strconv.FormatInt(int64(len(p.GroupingSets)), 10)) - str.WriteString(", groupingID:Column#") - str.WriteString(strconv.FormatInt(p.GroupingIDCol.UniqueID, 10)) + str.WriteString(", groupingID:") + str.WriteString(p.GroupingIDCol.String()) str.WriteString(", ") str.WriteString(p.GroupingSets.String()) return str.String() From 17244a443ae0c5a26c9fdd359292307a9f8185a8 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Mon, 19 Dec 2022 14:06:47 +0800 Subject: [PATCH 04/25] change the name from RepeatSource to Expand Signed-off-by: AilinKid <314806019@qq.com> --- executor/partition_table.go | 4 +-- expression/grouping_sets.go | 2 +- planner/core/explain.go | 2 +- planner/core/physical_plans.go | 16 +++++----- planner/core/plan_to_pb.go | 8 ++--- planner/core/resolve_indices.go | 4 +-- planner/core/task.go | 30 ++++++++----------- store/mockstore/unistore/cophandler/mpp.go | 8 ++--- .../mockstore/unistore/cophandler/mpp_exec.go | 12 ++++---- util/plancodec/id.go | 14 ++++----- 10 files changed, 48 insertions(+), 52 deletions(-) diff --git a/executor/partition_table.go b/executor/partition_table.go index 7397275aec539..714cdc206b95e 100644 --- a/executor/partition_table.go +++ b/executor/partition_table.go @@ -56,8 +56,8 @@ func updateExecutorTableID(ctx context.Context, exec *tipb.Executor, recursive b child = exec.Window.Child case tipb.ExecType_TypeSort: child = exec.Sort.Child - case tipb.ExecType_TypeRepeatSource: - child = exec.RepeatSource.Child + case tipb.ExecType_TypeExpand: + child = exec.Expand.Child default: return errors.Trace(fmt.Errorf("unknown new tipb protocol %d", exec.Tp)) } diff --git a/expression/grouping_sets.go b/expression/grouping_sets.go index bb1fde46f48ef..a8e22dc4f973a 100644 --- a/expression/grouping_sets.go +++ b/expression/grouping_sets.go @@ -143,7 +143,7 @@ func (gs GroupingSets) TargetOne(targetOne GroupingExprs) int { // distinct(a,b) want for. For distinct(c), the same is true. // // For normal agg you better choose your targeting group-set data, otherwise, otherwise your get is all groups,most of them is endless null values filled by - // repeatSource operator, and null value can also influence your non null-strict normal agg, although you don't want them to. + // Expand operator, and null value can also influence your non null-strict normal agg, although you don't want them to. // // here we only consider 1&2 since normal agg with multi col args is banned. allGroupingColIDs := gs.AllSetsColIDs() diff --git a/planner/core/explain.go b/planner/core/explain.go index f622f75deec09..e9e4d7fa1877b 100644 --- a/planner/core/explain.go +++ b/planner/core/explain.go @@ -360,7 +360,7 @@ func (p *PhysicalLimit) ExplainInfo() string { } // ExplainInfo implements Plan interface. 
-func (p *PhysicalRepeatSource) ExplainInfo() string { +func (p *PhysicalExpand) ExplainInfo() string { var str strings.Builder str.WriteString("group set num:") str.WriteString(strconv.FormatInt(int64(len(p.GroupingSets)), 10)) diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index cb26f7444e465..84d6e39fd667c 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -1496,8 +1496,8 @@ func (p *PhysicalExchangeReceiver) MemoryUsage() (sum int64) { return } -// PhysicalRepeatSource is used to repeat underlying data sources to feed different grouping sets. -type PhysicalRepeatSource struct { +// PhysicalExpand is used to expand underlying data sources to feed different grouping sets. +type PhysicalExpand struct { // data after repeat-OP will generate a new grouping-ID column to indicate what grouping set is it for. physicalSchemaProducer @@ -1510,15 +1510,15 @@ type PhysicalRepeatSource struct { } // Init only assigns type and context. -func (p PhysicalRepeatSource) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int) *PhysicalRepeatSource { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeRepeatSource, &p, offset) +func (p PhysicalExpand) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int) *PhysicalExpand { + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeExpand, &p, offset) p.stats = stats return &p } // Clone implements PhysicalPlan interface. -func (p *PhysicalRepeatSource) Clone() (PhysicalPlan, error) { - np := new(PhysicalRepeatSource) +func (p *PhysicalExpand) Clone() (PhysicalPlan, error) { + np := new(PhysicalExpand) base, err := p.basePhysicalPlan.cloneWithSelf(np) if err != nil { return nil, errors.Trace(err) @@ -1536,8 +1536,8 @@ func (p *PhysicalRepeatSource) Clone() (PhysicalPlan, error) { return np, nil } -// MemoryUsage return the memory usage of PhysicalRepeatSource -func (p *PhysicalRepeatSource) MemoryUsage() (sum int64) { +// MemoryUsage return the memory usage of PhysicalExpand +func (p *PhysicalExpand) MemoryUsage() (sum int64) { if p == nil { return } diff --git a/planner/core/plan_to_pb.go b/planner/core/plan_to_pb.go index 4b5269f262ccc..8b385d3d86a43 100644 --- a/planner/core/plan_to_pb.go +++ b/planner/core/plan_to_pb.go @@ -34,26 +34,26 @@ func (p *basePhysicalPlan) ToPB(_ sessionctx.Context, _ kv.StoreType) (*tipb.Exe } // ToPB implements PhysicalPlan ToPB interface. -func (p *PhysicalRepeatSource) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { +func (p *PhysicalExpand) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { sc := ctx.GetSessionVars().StmtCtx client := ctx.GetClient() groupingSetsPB, err := p.GroupingSets.ToPB(sc, client) if err != nil { return nil, err } - repeatSource := &tipb.RepeatSource{ + expand := &tipb.Expand{ GroupingSets: groupingSetsPB, } executorID := "" if storeType == kv.TiFlash { var err error - repeatSource.Child, err = p.children[0].ToPB(ctx, storeType) + expand.Child, err = p.children[0].ToPB(ctx, storeType) if err != nil { return nil, errors.Trace(err) } executorID = p.ExplainID().String() } - return &tipb.Executor{Tp: tipb.ExecType_TypeRepeatSource, RepeatSource: repeatSource, ExecutorId: &executorID}, nil + return &tipb.Executor{Tp: tipb.ExecType_TypeExpand, Expand: expand, ExecutorId: &executorID}, nil } // ToPB implements PhysicalPlan ToPB interface. 
diff --git a/planner/core/resolve_indices.go b/planner/core/resolve_indices.go index 451cf149f3568..a38644ec07564 100644 --- a/planner/core/resolve_indices.go +++ b/planner/core/resolve_indices.go @@ -412,7 +412,7 @@ func (p *PhysicalExchangeSender) ResolveIndices() (err error) { } // ResolveIndicesItself resolve indices for PhysicalPlan itself -func (p *PhysicalRepeatSource) ResolveIndicesItself() error { +func (p *PhysicalExpand) ResolveIndicesItself() error { for _, gs := range p.GroupingSets { for _, groupingExprs := range gs { for k, groupingExpr := range groupingExprs { @@ -428,7 +428,7 @@ func (p *PhysicalRepeatSource) ResolveIndicesItself() error { } // ResolveIndices implements Plan interface. -func (p *PhysicalRepeatSource) ResolveIndices() (err error) { +func (p *PhysicalExpand) ResolveIndices() (err error) { err = p.physicalSchemaProducer.ResolveIndices() if err != nil { return err diff --git a/planner/core/task.go b/planner/core/task.go index 0870db47b9038..787a296ad9980 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -790,10 +790,6 @@ type rootTask struct { p PhysicalPlan isEmpty bool // isEmpty indicates if this task contains a dual table and returns empty data. // TODO: The flag 'isEmpty' is only checked by Projection and UnionAll. We should support more cases in the future. - - // derive stats is done before physical plan enumeration, which when repeat source is inserted, - // the all upper layer physical op's base stats should be changed. - attachTaskCallBack func(PhysicalPlan) } func (t *rootTask) copy() task { @@ -2151,11 +2147,11 @@ func (p *PhysicalHashAgg) attach2TaskForMpp1Phase(mpp *mppTask) task { // the Mid and Partial Agg, which means when leaving the Final, its output rowcount could be exactly // the same as what it derived(estimated) before entering physical optimization phase. // -// From the cost model correctness, for these inserted sub-agg and even repeat source operator, we should +// From the cost model correctness, for these inserted sub-agg and even expand operator, we should // recompute the stats for them particularly. // // for example: grouping sets {},{}, group by items {a,b,c,groupingID} -// after repeat source: +// after expand: // // a, b, c, groupingID // ... null c 1 ---+ @@ -2248,22 +2244,22 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan // +- HashAgg count(distinct a) #1, count(distinct b) #2, sum(#4) #3 -> middle agg // +- Exchange HashPartition by a,b,groupingID // +- HashAgg count(c) #4, group by a,b,groupingID -> partial agg - // +- RepeatSource {}, {} -> repeat source + // +- Expand {}, {} -> expand // +- TableScan foo // // set the default expression to constant 1 for the convenience to choose default group set data. var groupingIDCol expression.Expression groupingIDCol = expression.NewNumber(1) if len(groupingSets) > 0 { - // enforce the repeatSource operator above the children. - // single distinct agg mode can eliminate repeatSource op insertion. + // enforce Expand operator above the children. + // single distinct agg mode can eliminate expand op insertion. // physical plan is enumerated without children from itself, use mpp subtree instead p.children. // scale(len(groupingSets)) will change the NDV, while Repeat doesn't change the NDV and groupNDV. 
stats := mpp.p.statsInfo().Scale(float64(1)) stats.RowCount = stats.RowCount * float64(len(groupingSets)) - physicalRepeatSource := PhysicalRepeatSource{ + physicalExpand := PhysicalExpand{ GroupingSets: groupingSets, }.Init(p.ctx, stats, mpp.p.SelectBlockOffset()) @@ -2274,13 +2270,13 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), RetType: tp, } - // append the physical repeatSource op with groupingID column. - physicalRepeatSource.SetSchema(mpp.p.Schema().Clone()) - physicalRepeatSource.schema.Append(groupingIDCol.(*expression.Column)) - physicalRepeatSource.GroupingIDCol = groupingIDCol.(*expression.Column) + // append the physical expand op with groupingID column. + physicalExpand.SetSchema(mpp.p.Schema().Clone()) + physicalExpand.schema.Append(groupingIDCol.(*expression.Column)) + physicalExpand.GroupingIDCol = groupingIDCol.(*expression.Column) - // attach PhysicalRepeatSource to mpp - attachPlan2Task(physicalRepeatSource, mpp) + // attach PhysicalExpand to mpp + attachPlan2Task(physicalExpand, mpp) } if partialAgg != nil && canUse3StageAgg { @@ -2372,7 +2368,7 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan // v // proj4Partial: a, b, c, groupingID, caseWhen4c # caseWhen4c depend on groupingID and c // - // repeatSource: a, b, c, groupingID # appended groupingID to diff group set + // expand: a, b, c, groupingID # appended groupingID to diff group set // // lower source: a, b, c # don't care what the lower source is diff --git a/store/mockstore/unistore/cophandler/mpp.go b/store/mockstore/unistore/cophandler/mpp.go index 785e63fdeea18..eec358c038ba9 100644 --- a/store/mockstore/unistore/cophandler/mpp.go +++ b/store/mockstore/unistore/cophandler/mpp.go @@ -200,12 +200,12 @@ func (b *mppExecBuilder) buildLimit(pb *tipb.Limit) (*limitExec, error) { return exec, nil } -func (b *mppExecBuilder) buildRepeatSource(pb *tipb.RepeatSource) (mppExec, error) { +func (b *mppExecBuilder) buildExpand(pb *tipb.Expand) (mppExec, error) { child, err := b.buildMPPExecutor(pb.Child) if err != nil { return nil, err } - exec := &repeatSourceExec{ + exec := &expandExec{ baseMPPExec: baseMPPExec{sc: b.sc, mppCtx: b.mppCtx, children: []mppExec{child}}, } @@ -551,8 +551,8 @@ func (b *mppExecBuilder) buildMPPExecutor(exec *tipb.Executor) (mppExec, error) case tipb.ExecType_TypePartitionTableScan: ts := exec.PartitionTableScan return b.buildMPPPartitionTableScan(ts) - case tipb.ExecType_TypeRepeatSource: - return b.buildRepeatSource(exec.RepeatSource) + case tipb.ExecType_TypeExpand: + return b.buildExpand(exec.Expand) default: return nil, errors.Errorf(ErrExecutorNotSupportedMsg + exec.Tp.String()) } diff --git a/store/mockstore/unistore/cophandler/mpp_exec.go b/store/mockstore/unistore/cophandler/mpp_exec.go index f69c904239658..1a2f504afb4a4 100644 --- a/store/mockstore/unistore/cophandler/mpp_exec.go +++ b/store/mockstore/unistore/cophandler/mpp_exec.go @@ -416,9 +416,9 @@ func (e *limitExec) next() (*chunk.Chunk, error) { return chk, nil } -// repeatSourceExec is the basic mock logic for repeat executor in uniStore, +// expandExec is the basic mock logic for expand executor in uniStore, // with which we can validate the mpp plan correctness from explain result and result returned. 
-type repeatSourceExec struct { +type expandExec struct { baseMPPExec lastNum int @@ -428,7 +428,7 @@ type repeatSourceExec struct { groupingSetScope map[int]struct{} } -func (e *repeatSourceExec) open() error { +func (e *expandExec) open() error { var err error err = e.children[0].open() if err != nil { @@ -455,14 +455,14 @@ func (e *repeatSourceExec) open() error { return nil } -func (e *repeatSourceExec) isGroupingCol(index int) bool { +func (e *expandExec) isGroupingCol(index int) bool { if _, ok := e.groupingSetScope[index]; ok { return true } return false } -func (e *repeatSourceExec) next() (*chunk.Chunk, error) { +func (e *expandExec) next() (*chunk.Chunk, error) { var ( err error ) @@ -489,7 +489,7 @@ func (e *repeatSourceExec) next() (*chunk.Chunk, error) { for i := e.lastNum; i < numRows; i++ { row := e.lastChunk.GetRow(i) e.lastNum++ - // for every grouping set, repeat the base row N times. + // for every grouping set, expand the base row N times. for g := 0; g < numGroupingOffset; g++ { repeatRow := chunk.MutRowFromTypes(e.fieldTypes) // for every targeted grouping set: diff --git a/util/plancodec/id.go b/util/plancodec/id.go index 5e107a9d62ce2..20ad0e4301cfa 100644 --- a/util/plancodec/id.go +++ b/util/plancodec/id.go @@ -57,8 +57,8 @@ const ( TypeExchangeSender = "ExchangeSender" // TypeExchangeReceiver is the type of mpp exchanger receiver. TypeExchangeReceiver = "ExchangeReceiver" - // TypeRepeatSource is the type of mpp repeat source operator. - TypeRepeatSource = "RepeatSource" + // TypeExpand is the type of mpp expand source operator. + TypeExpand = "Expand" // TypeMergeJoin is the type of merge join. TypeMergeJoin = "MergeJoin" // TypeIndexJoin is the type of index look up join. @@ -195,7 +195,7 @@ const ( typeShuffleReceiverID int = 55 typeForeignKeyCheck int = 56 typeForeignKeyCascade int = 57 - typeRepeatSourceID int = 58 + typeExpandID int = 58 ) // TypeStringToPhysicalID converts the plan type string to plan id. @@ -315,8 +315,8 @@ func TypeStringToPhysicalID(tp string) int { return typeForeignKeyCheck case TypeForeignKeyCascade: return typeForeignKeyCascade - case TypeRepeatSource: - return typeRepeatSourceID + case TypeExpand: + return typeExpandID } // Should never reach here. return 0 @@ -439,8 +439,8 @@ func PhysicalIDToTypeString(id int) string { return TypeForeignKeyCheck case typeForeignKeyCascade: return TypeForeignKeyCascade - case typeRepeatSourceID: - return TypeRepeatSource + case typeExpandID: + return TypeExpand } // Should never reach here. 
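A minimal standalone sketch of the row replication that the Expand operator (renamed from RepeatSource in the patch above) performs, written with plain Go values instead of chunk.Chunk. The column layout, the offsets, and the two grouping sets {a} and {b} are assumptions made only for illustration; this is not code from the patch.

package main

import "fmt"

// row holds nullable column values; nil stands for NULL.
type row []any

// expand replicates every input row once per grouping set. In each copy the
// grouping columns that belong to some other grouping set are nulled out,
// ordinary (non-grouping) columns are kept, and a 1-based groupingID is
// appended so downstream aggregation can tell the copies apart.
func expand(rows []row, groupingSets [][]int, numCols int) []row {
	allGrouping := make(map[int]bool) // union of grouping column offsets
	for _, set := range groupingSets {
		for _, off := range set {
			allGrouping[off] = true
		}
	}
	out := make([]row, 0, len(rows)*len(groupingSets))
	for _, r := range rows {
		for id, set := range groupingSets {
			inSet := make(map[int]bool, len(set))
			for _, off := range set {
				inSet[off] = true
			}
			nr := make(row, 0, numCols+1)
			for off := 0; off < numCols; off++ {
				if allGrouping[off] && !inSet[off] {
					nr = append(nr, nil) // grouping column of another set: fill NULL
				} else {
					nr = append(nr, r[off]) // own grouping column or normal column: keep
				}
			}
			nr = append(nr, int64(id+1)) // groupingID starts from 1
			out = append(out, nr)
		}
	}
	return out
}

func main() {
	// columns: a, b, c at offsets 0, 1, 2; grouping sets {a} and {b}.
	rows := []row{{int64(1), int64(2), int64(3)}}
	for _, r := range expand(rows, [][]int{{0}, {1}}, 3) {
		fmt.Println(r)
	}
	// prints:
	// [1 <nil> 3 1]
	// [<nil> 2 3 2]
}

Every source row comes out once per grouping set, which is also why the partial aggregate above Expand groups by the original grouping columns plus groupingID.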
From 2829859e826b2ecd61f7861ae8ecaa3f6f32ab0a Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Mon, 19 Dec 2022 14:14:48 +0800 Subject: [PATCH 05/25] fix the test name from repeat source to expand Signed-off-by: AilinKid <314806019@qq.com> --- planner/core/casetest/testdata/enforce_mpp_suite_out.json | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/planner/core/casetest/testdata/enforce_mpp_suite_out.json b/planner/core/casetest/testdata/enforce_mpp_suite_out.json index fa7780097826d..b099bd678fe3d 100644 --- a/planner/core/casetest/testdata/enforce_mpp_suite_out.json +++ b/planner/core/casetest/testdata/enforce_mpp_suite_out.json @@ -1245,7 +1245,7 @@ " └─ExchangeReceiver_34 16000.00 mpp[tiflash] ", " └─ExchangeSender_33 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#13, collate: binary]", " └─HashAgg_26 16000.00 mpp[tiflash] group by:Column#13, test.t.a, test.t.b, ", - " └─RepeatSource_29 20000.00 mpp[tiflash] group set num:2, groupingID:Column#13, [{},{}]", + " └─Expand_29 20000.00 mpp[tiflash] group set num:2, groupingID:Column#13, [{},{}]", " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null @@ -1272,7 +1272,7 @@ " └─ExchangeSender_33 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#18, collate: binary]", " └─HashAgg_26 16000.00 mpp[tiflash] group by:Column#18, test.t.a, test.t.b, funcs:count(Column#19)->Column#17", " └─Projection_31 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#18, case(eq(Column#18, 1), test.t.c, )->Column#19", - " └─RepeatSource_29 20000.00 mpp[tiflash] group set num:2, groupingID:Column#18, [{},{}]", + " └─Expand_29 20000.00 mpp[tiflash] group set num:2, groupingID:Column#18, [{},{}]", " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null @@ -1300,7 +1300,7 @@ " └─HashAgg_26 16000.00 mpp[tiflash] group by:Column#26, Column#27, Column#28, funcs:sum(Column#25)->Column#17", " └─Projection_39 20000.00 mpp[tiflash] cast(Column#19, decimal(10,0) BINARY)->Column#25, test.t.a, test.t.b, Column#18", " └─Projection_31 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#18, case(eq(Column#18, 1), test.t.c, )->Column#19", - " └─RepeatSource_29 20000.00 mpp[tiflash] group set num:2, groupingID:Column#18, [{},{}]", + " └─Expand_29 20000.00 mpp[tiflash] group set num:2, groupingID:Column#18, [{},{}]", " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null @@ -1328,7 +1328,7 @@ " └─HashAgg_26 16000.00 mpp[tiflash] group by:Column#34, Column#35, Column#36, funcs:count(Column#32)->Column#21, funcs:sum(Column#33)->Column#22", " └─Projection_39 20000.00 mpp[tiflash] Column#24, cast(Column#25, decimal(10,0) BINARY)->Column#33, test.t.a, test.t.b, Column#23", " └─Projection_31 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, test.t.d, Column#23, case(eq(Column#23, 1), test.t.c, )->Column#24, case(eq(Column#23, 1), test.t.d, )->Column#25", - " └─RepeatSource_29 20000.00 mpp[tiflash] group set num:2, groupingID:Column#23, [{},{}]", + " └─Expand_29 20000.00 mpp[tiflash] group set num:2, groupingID:Column#23, [{},{}]", " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": [ From 7ba7c64dca198efe9947af52f7efe34fe7af7373 Mon Sep 17 00:00:00 2001 From: AilinKid 
<314806019@qq.com> Date: Tue, 27 Dec 2022 23:09:32 +0800 Subject: [PATCH 06/25] update go mod Signed-off-by: AilinKid <314806019@qq.com> --- go.mod | 1 - 1 file changed, 1 deletion(-) diff --git a/go.mod b/go.mod index 016d1fb1a9e6e..c3948c93c1d05 100644 --- a/go.mod +++ b/go.mod @@ -275,5 +275,4 @@ replace ( github.com/dgrijalva/jwt-go => github.com/form3tech-oss/jwt-go v3.2.6-0.20210809144907-32ab6a8243d7+incompatible github.com/pingcap/tidb/parser => ./parser go.opencensus.io => go.opencensus.io v0.23.1-0.20220331163232-052120675fac - github.com/pingcap/tipb => ../tipb ) From 162d5e15f372795fbc75af2c75e9cec36547562d Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Tue, 27 Dec 2022 23:50:43 +0800 Subject: [PATCH 07/25] address comment Signed-off-by: AilinKid <314806019@qq.com> --- expression/aggregation/descriptor.go | 2 +- expression/constant.go | 4 +- planner/core/task.go | 485 +++++++++++++-------------- 3 files changed, 242 insertions(+), 249 deletions(-) diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index 60d23072cd399..3894355a2c546 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -42,7 +42,7 @@ type AggFuncDesc struct { HasDistinct bool // OrderByItems represents the order by clause used in GROUP_CONCAT OrderByItems []*util.ByItems - // Grouping Set ID, for distinguishing with not-set 0, starting from 1. + // GroupingID is used for distinguishing with not-set 0, starting from 1. GroupingID int64 } diff --git a/expression/constant.go b/expression/constant.go index cb2b5e5aac2e4..a04b71e639d59 100644 --- a/expression/constant.go +++ b/expression/constant.go @@ -52,8 +52,8 @@ func NewZero() *Constant { } } -// NewNumber stands for constant of a given number. -func NewNumber(num int64) *Constant { +// NewUInt64Const stands for constant of a given number. +func NewUInt64Const(num int64) *Constant { retT := types.NewFieldType(mysql.TypeLonglong) retT.AddFlag(mysql.UnsignedFlag) // shrink range to avoid integral promotion retT.SetFlen(mysql.MaxIntWidth) diff --git a/planner/core/task.go b/planner/core/task.go index 787a296ad9980..06d7fa3f3e1e1 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -1555,7 +1555,6 @@ func BuildFinalModeAggregation( finalAggFunc.OrderByItems = aggFunc.OrderByItems args := make([]expression.Expression, 0, len(aggFunc.Args)) if aggFunc.HasDistinct { - // 感觉这里是变下推到 kv 的 agg 结构的 /* eg: SELECT COUNT(DISTINCT a), SUM(b) FROM t GROUP BY c @@ -1567,7 +1566,6 @@ func BuildFinalModeAggregation( */ // onlyAddFirstRow means if the distinctArg does not occur in group by items, // it should be replaced with a firstrow() agg function, needed for the order by items of group_concat() - // 这个地方也是我们所看到的,为啥 distinct agg column 会出现在 group by 中的原因 getDistinctExpr := func(distinctArg expression.Expression, onlyAddFirstRow bool) (ret expression.Expression) { // 1. add all args to partial.GroupByItems foundInGroupBy := false @@ -1924,7 +1922,7 @@ func (p *basePhysicalAgg) canUse3Stage4MultiDistinctAgg() (can bool, gss express } else if len(fun.Args) > 1 { return false, nil } - // banned count(distinct x order by y) + // banned group_concat(x order by y) if len(fun.OrderByItems) > 0 || fun.Mode != aggregation.CompleteMode { return false, nil } @@ -2224,260 +2222,256 @@ func (p *PhysicalHashAgg) scaleStats4GroupingSets(groupingSets expression.Groupi } } +// adjust3StagePhaseAgg generate 3 stage aggregation for single/multi count distinct if applicable. 
+// +// select count(distinct a), count(b) from foo +// +// will generate plan: +// +// HashAgg sum(#1), sum(#2) -> final agg +// +- Exchange Passthrough +// +- HashAgg count(distinct a) #1, sum(#3) #2 -> middle agg +// +- Exchange HashPartition by a +// +- HashAgg count(b) #3, group by a -> partial agg +// +- TableScan foo +// +// select count(distinct a), count(distinct b), count(c) from foo +// +// will generate plan: +// +// HashAgg sum(#1), sum(#2), sum(#3) -> final agg +// +- Exchange Passthrough +// +- HashAgg count(distinct a) #1, count(distinct b) #2, sum(#4) #3 -> middle agg +// +- Exchange HashPartition by a,b,groupingID +// +- HashAgg count(c) #4, group by a,b,groupingID -> partial agg +// +- Expand {}, {} -> expand +// +- TableScan foo func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan, canUse3StageAgg bool, groupingSets expression.GroupingSets, mpp *mppTask) (final, mid, proj1, part, proj2 PhysicalPlan, _ error) { - // generate 3 stage aggregation for single/multi count distinct if applicable. + if !(partialAgg != nil && canUse3StageAgg) { + // quick path: return the original finalAgg and partiAgg + return finalAgg, nil, nil, partialAgg, nil, nil + } + if len(groupingSets) == 0 { + // single distinct agg mode. + clonedAgg, err := finalAgg.Clone() + if err != nil { + return nil, nil, nil, nil, nil, err + } + + // step1: adjust middle agg + middleHashAgg := clonedAgg.(*PhysicalHashAgg) + distinctPos := 0 + middleSchema := expression.NewSchema() + schemaMap := make(map[int64]*expression.Column, len(middleHashAgg.AggFuncs)) + for i, fun := range middleHashAgg.AggFuncs { + col := &expression.Column{ + UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: fun.RetTp, + } + if fun.HasDistinct { + distinctPos = i + fun.Mode = aggregation.Partial1Mode + } else { + fun.Mode = aggregation.Partial2Mode + originalCol := fun.Args[0].(*expression.Column) + // mapping the current partial output column with the agg origin arg column. (final agg arg should use this one) + schemaMap[originalCol.UniqueID] = col + } + middleSchema.Append(col) + } + middleHashAgg.schema = middleSchema + + // step2: adjust final agg + finalHashAgg := finalAgg.(*PhysicalHashAgg) + finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs)) + for i, fun := range finalHashAgg.AggFuncs { + newArgs := make([]expression.Expression, 0, 1) + if distinctPos == i { + // change count(distinct) to sum() + fun.Name = ast.AggFuncSum + fun.HasDistinct = false + newArgs = append(newArgs, middleSchema.Columns[i]) + } else { + for _, arg := range fun.Args { + newCol, err := arg.RemapColumn(schemaMap) + if err != nil { + return nil, nil, nil, nil, nil, err + } + newArgs = append(newArgs, newCol) + } + } + fun.Mode = aggregation.FinalMode + fun.Args = newArgs + finalAggDescs = append(finalAggDescs, fun) + } + finalHashAgg.AggFuncs = finalAggDescs + // partialAgg is im-mutated from args + return finalHashAgg, middleHashAgg, nil, partialAgg, nil, nil + } + // multi distinct agg mode, having grouping sets. + // set the default expression to constant 1 for the convenience to choose default group set data. + var groupingIDCol expression.Expression + groupingIDCol = expression.NewUInt64Const(1) + // enforce Expand operator above the children. + // physical plan is enumerated without children from itself, use mpp subtree instead p.children. + // scale(len(groupingSets)) will change the NDV, while Repeat doesn't change the NDV and groupNDV. 
+ stats := mpp.p.statsInfo().Scale(float64(1)) + stats.RowCount = stats.RowCount * float64(len(groupingSets)) + physicalExpand := PhysicalExpand{ + GroupingSets: groupingSets, + }.Init(p.ctx, stats, mpp.p.SelectBlockOffset()) + // generate a new column as groupingID to identify which this row is targeting for. + tp := types.NewFieldType(mysql.TypeLonglong) + tp.SetFlag(mysql.UnsignedFlag | mysql.NotNullFlag) + groupingIDCol = &expression.Column{ + UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: tp, + } + // append the physical expand op with groupingID column. + physicalExpand.SetSchema(mpp.p.Schema().Clone()) + physicalExpand.schema.Append(groupingIDCol.(*expression.Column)) + physicalExpand.GroupingIDCol = groupingIDCol.(*expression.Column) + // attach PhysicalExpand to mpp + attachPlan2Task(physicalExpand, mpp) + + // having group sets + clonedAgg, err := finalAgg.Clone() + if err != nil { + return nil, nil, nil, nil, nil, err + } + cloneHashAgg := clonedAgg.(*PhysicalHashAgg) + // Clone(), it will share same base-plan elements from the finalAgg, including id,tp,stats. Make a new one here. + cloneHashAgg.basePlan = newBasePlan(cloneHashAgg.ctx, cloneHashAgg.tp, cloneHashAgg.blockOffset) + cloneHashAgg.stats = finalAgg.Stats() // reuse the final agg stats here. + + // select count(distinct a), count(distinct b), count(c) from t + // final agg : sum(#1), sum(#2), sum(#3) + // | | +--------------------------------------------+ + // | +--------------------------------+ | + // +-------------+ | | + // v v v + // middle agg : count(distinct caseWhen4a) as #1, count(distinct caseWhen4b) as #2, sum(#4) as #3 # all partial result + // | | + // +---------------------+ | + // v v + // proj4Middle : a, b, groupingID, #4(partial res), caseWhen4a, caseWhen4b # caseWhen4a & caseWhen4a depend on groupingID and a, b + // | + // | + // exchange hash partition: # shuffle data by + // | + // v + // partial agg : a, b, groupingID, count(caseWhen4c) as #4 # since group by a, b, groupingID we can directly output all this three. + // | + // v + // proj4Partial: a, b, c, groupingID, caseWhen4c # caseWhen4c depend on groupingID and c // - // select count(distinct a), count(b) from foo - // will generate plan: - // HashAgg sum(#1), sum(#2) -> final agg - // +- Exchange Passthrough - // +- HashAgg count(distinct a) #1, sum(#3) #2 -> middle agg - // +- Exchange HashPartition by a - // +- HashAgg count(b) #3, group by a -> partial agg - // +- TableScan foo + // expand: a, b, c, groupingID # appended groupingID to diff group set // - // select count(distinct a), count(distinct b), count(c) from foo - // will generate plan: - // HashAgg sum(#1), sum(#2), sum(#3) -> final agg - // +- Exchange Passthrough - // +- HashAgg count(distinct a) #1, count(distinct b) #2, sum(#4) #3 -> middle agg - // +- Exchange HashPartition by a,b,groupingID - // +- HashAgg count(c) #4, group by a,b,groupingID -> partial agg - // +- Expand {}, {} -> expand - // +- TableScan foo + // lower source: a, b, c # don't care what the lower source is + + // step1: adjust partial agg, for normal agg here, adjust it to target for specified group data. + // Since we may substitute the first arg of normal agg with case-when expression here, append a + // customized proj here rather than depend on postOptimize to insert a blunt one for us. // - // set the default expression to constant 1 for the convenience to choose default group set data. 
- var groupingIDCol expression.Expression - groupingIDCol = expression.NewNumber(1) - if len(groupingSets) > 0 { - // enforce Expand operator above the children. - // single distinct agg mode can eliminate expand op insertion. - - // physical plan is enumerated without children from itself, use mpp subtree instead p.children. - // scale(len(groupingSets)) will change the NDV, while Repeat doesn't change the NDV and groupNDV. - stats := mpp.p.statsInfo().Scale(float64(1)) - stats.RowCount = stats.RowCount * float64(len(groupingSets)) - - physicalExpand := PhysicalExpand{ - GroupingSets: groupingSets, - }.Init(p.ctx, stats, mpp.p.SelectBlockOffset()) - - // generate a new column as groupingID to identify which this row is targeting for. - tp := types.NewFieldType(mysql.TypeLonglong) - tp.SetFlag(mysql.UnsignedFlag | mysql.NotNullFlag) - groupingIDCol = &expression.Column{ + // proj4Partial output all the base col from lower op + caseWhen proj cols. + proj4Partial := new(PhysicalProjection).Init(p.ctx, mpp.p.statsInfo(), mpp.p.SelectBlockOffset()) + for _, col := range mpp.p.Schema().Columns { + proj4Partial.Exprs = append(proj4Partial.Exprs, col) + } + proj4Partial.SetSchema(mpp.p.Schema().Clone()) + + partialHashAgg := partialAgg.(*PhysicalHashAgg) + partialHashAgg.GroupByItems = append(partialHashAgg.GroupByItems, groupingIDCol) + partialHashAgg.schema.Append(groupingIDCol.(*expression.Column)) + // it will create a new stats for partial agg. + partialHashAgg.scaleStats4GroupingSets(groupingSets, groupingIDCol.(*expression.Column), proj4Partial.Schema(), proj4Partial.statsInfo()) + for _, fun := range partialHashAgg.AggFuncs { + if !fun.HasDistinct { + // for normal agg phase1, we should also modify them to target for specified group data. + eqExpr := expression.NewFunctionInternal(p.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), groupingIDCol, expression.NewUInt64Const(fun.GroupingID)) + caseWhen := expression.NewFunctionInternal(p.ctx, ast.Case, fun.Args[0].GetType(), eqExpr, fun.Args[0], expression.NewNull()) + caseWhenProjCol := &expression.Column{ + UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: fun.Args[0].GetType(), + } + proj4Partial.Exprs = append(proj4Partial.Exprs, caseWhen) + proj4Partial.Schema().Append(caseWhenProjCol) + fun.Args[0] = caseWhenProjCol + } + } + + // step2: adjust middle agg + // proj4Middle output all the base col from lower op + caseWhen proj cols.(reuse partial agg stats) + proj4Middle := new(PhysicalProjection).Init(p.ctx, partialHashAgg.stats, partialHashAgg.SelectBlockOffset()) + // middleHashAgg shared the same stats with the final agg does. + middleHashAgg := cloneHashAgg + middleSchema := expression.NewSchema() + schemaMap := make(map[int64]*expression.Column, len(middleHashAgg.AggFuncs)) + for _, col := range partialHashAgg.Schema().Columns { + proj4Middle.Exprs = append(proj4Middle.Exprs, col) + } + proj4Middle.SetSchema(partialHashAgg.Schema().Clone()) + for _, fun := range middleHashAgg.AggFuncs { + col := &expression.Column{ UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: tp, + RetType: fun.RetTp, } - // append the physical expand op with groupingID column. 
- physicalExpand.SetSchema(mpp.p.Schema().Clone()) - physicalExpand.schema.Append(groupingIDCol.(*expression.Column)) - physicalExpand.GroupingIDCol = groupingIDCol.(*expression.Column) - - // attach PhysicalExpand to mpp - attachPlan2Task(physicalExpand, mpp) + if fun.HasDistinct { + fun.Mode = aggregation.Partial1Mode + // change the middle distinct agg args to target for specified group data. + // even for distinct agg multi args like count(distinct a,b), either null of a,b will cause count func to skip this row. + // so we only need to set the either one arg of them to be a case when expression to target for its self-group: + // Expr = (case when groupingID = ? then arg else null end) + // count(distinct a,b) => count(distinct (case when groupingID = targetID then a else null end), b) + eqExpr := expression.NewFunctionInternal(p.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), groupingIDCol, expression.NewUInt64Const(fun.GroupingID)) + caseWhen := expression.NewFunctionInternal(p.ctx, ast.Case, fun.Args[0].GetType(), eqExpr, fun.Args[0], expression.NewNull()) + caseWhenProjCol := &expression.Column{ + UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: fun.Args[0].GetType(), + } + proj4Middle.Exprs = append(proj4Middle.Exprs, caseWhen) + proj4Middle.Schema().Append(caseWhenProjCol) + fun.Args[0] = caseWhenProjCol + } else { + fun.Mode = aggregation.Partial2Mode + originalCol := fun.Args[0].(*expression.Column) + // record the origin column unique id down before change it to be case when expr. + // mapping the current partial output column with the agg origin arg column. (final agg arg should use this one) + schemaMap[originalCol.UniqueID] = col + } + middleSchema.Append(col) } + middleHashAgg.schema = middleSchema - if partialAgg != nil && canUse3StageAgg { - if groupingSets == nil { - // single distinct agg mode. - clonedAgg, err := finalAgg.Clone() - if err != nil { - return nil, nil, nil, nil, nil, err - } - - // step1: adjust middle agg - middleHashAgg := clonedAgg.(*PhysicalHashAgg) - distinctPos := 0 - middleSchema := expression.NewSchema() - schemaMap := make(map[int64]*expression.Column, len(middleHashAgg.AggFuncs)) - for i, fun := range middleHashAgg.AggFuncs { - col := &expression.Column{ - UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: fun.RetTp, - } - if fun.HasDistinct { - distinctPos = i - fun.Mode = aggregation.Partial1Mode - } else { - fun.Mode = aggregation.Partial2Mode - originalCol := fun.Args[0].(*expression.Column) - // mapping the current partial output column with the agg origin arg column. 
(final agg arg should use this one) - schemaMap[originalCol.UniqueID] = col - } - middleSchema.Append(col) - } - middleHashAgg.schema = middleSchema - - // step2: adjust final agg - finalHashAgg := finalAgg.(*PhysicalHashAgg) - finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs)) - for i, fun := range finalHashAgg.AggFuncs { - newArgs := make([]expression.Expression, 0, 1) - if distinctPos == i { - // change count(distinct) to sum() - fun.Name = ast.AggFuncSum - fun.HasDistinct = false - newArgs = append(newArgs, middleSchema.Columns[i]) - } else { - for _, arg := range fun.Args { - newCol, err := arg.RemapColumn(schemaMap) - if err != nil { - return nil, nil, nil, nil, nil, err - } - newArgs = append(newArgs, newCol) - } - } - fun.Mode = aggregation.FinalMode - fun.Args = newArgs - finalAggDescs = append(finalAggDescs, fun) - } - finalHashAgg.AggFuncs = finalAggDescs - // partialAgg is im-mutated from args - return finalHashAgg, middleHashAgg, nil, partialAgg, nil, nil + // step3: adjust final agg + finalHashAgg := finalAgg.(*PhysicalHashAgg) + finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs)) + for i, fun := range finalHashAgg.AggFuncs { + newArgs := make([]expression.Expression, 0, 1) + if fun.HasDistinct { + // change count(distinct) agg to sum() + fun.Name = ast.AggFuncSum + fun.HasDistinct = false + // count(distinct a,b) -> become a single partial result col. + newArgs = append(newArgs, middleSchema.Columns[i]) } else { - // having group sets - clonedAgg, err := finalAgg.Clone() - if err != nil { - return nil, nil, nil, nil, nil, err - } - cloneHashAgg := clonedAgg.(*PhysicalHashAgg) - // Clone(), it will share same base-plan elements from the finalAgg, including id,tp,stats. Make a new one here. - cloneHashAgg.basePlan = newBasePlan(cloneHashAgg.ctx, cloneHashAgg.tp, cloneHashAgg.blockOffset) - cloneHashAgg.stats = finalAgg.Stats() // reuse the final agg stats here. - - // select count(distinct a), count(distinct b), count(c) from t - // final agg : sum(#1), sum(#2), sum(#3) - // | | +--------------------------------------------+ - // | +--------------------------------+ | - // +-------------+ | | - // v v v - // middle agg : count(distinct caseWhen4a) as #1, count(distinct caseWhen4b) as #2, sum(#4) as #3 # all partial result - // | | - // +---------------------+ | - // v v - // proj4Middle : a, b, groupingID, #4(partial res), caseWhen4a, caseWhen4b # caseWhen4a & caseWhen4a depend on groupingID and a, b - // | - // | - // exchange hash partition: # shuffle data by - // | - // v - // partial agg : a, b, groupingID, count(caseWhen4c) as #4 # since group by a, b, groupingID we can directly output all this three. - // | - // v - // proj4Partial: a, b, c, groupingID, caseWhen4c # caseWhen4c depend on groupingID and c - // - // expand: a, b, c, groupingID # appended groupingID to diff group set - // - // lower source: a, b, c # don't care what the lower source is - - // step1: adjust partial agg, for normal agg here, adjust it to target for specified group data. - // Since we may substitute the first arg of normal agg with case-when expression here, append a - // customized proj here rather than depend on postOptimize to insert a blunt one for us. - // - // proj4Partial output all the base col from lower op + caseWhen proj cols. 
- proj4Partial := new(PhysicalProjection).Init(p.ctx, mpp.p.statsInfo(), mpp.p.SelectBlockOffset()) - for _, col := range mpp.p.Schema().Columns { - proj4Partial.Exprs = append(proj4Partial.Exprs, col) - } - proj4Partial.SetSchema(mpp.p.Schema().Clone()) - - partialHashAgg := partialAgg.(*PhysicalHashAgg) - partialHashAgg.GroupByItems = append(partialHashAgg.GroupByItems, groupingIDCol) - partialHashAgg.schema.Append(groupingIDCol.(*expression.Column)) - // it will create a new stats for partial agg. - partialHashAgg.scaleStats4GroupingSets(groupingSets, groupingIDCol.(*expression.Column), proj4Partial.Schema(), proj4Partial.statsInfo()) - for _, fun := range partialHashAgg.AggFuncs { - if !fun.HasDistinct { - // for normal agg phase1, we should also modify them to target for specified group data. - eqExpr := expression.NewFunctionInternal(p.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), groupingIDCol, expression.NewNumber(fun.GroupingID)) - caseWhen := expression.NewFunctionInternal(p.ctx, ast.Case, fun.Args[0].GetType(), eqExpr, fun.Args[0], expression.NewNull()) - caseWhenProjCol := &expression.Column{ - UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: fun.Args[0].GetType(), - } - proj4Partial.Exprs = append(proj4Partial.Exprs, caseWhen) - proj4Partial.Schema().Append(caseWhenProjCol) - fun.Args[0] = caseWhenProjCol - } - } - - // step2: adjust middle agg - // proj4Middle output all the base col from lower op + caseWhen proj cols.(reuse partial agg stats) - proj4Middle := new(PhysicalProjection).Init(p.ctx, partialHashAgg.stats, partialHashAgg.SelectBlockOffset()) - // middleHashAgg shared the same stats with the final agg does. - middleHashAgg := cloneHashAgg - middleSchema := expression.NewSchema() - schemaMap := make(map[int64]*expression.Column, len(middleHashAgg.AggFuncs)) - for _, col := range partialHashAgg.Schema().Columns { - proj4Middle.Exprs = append(proj4Middle.Exprs, col) - } - proj4Middle.SetSchema(partialHashAgg.Schema().Clone()) - for _, fun := range middleHashAgg.AggFuncs { - col := &expression.Column{ - UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: fun.RetTp, - } - if fun.HasDistinct { - fun.Mode = aggregation.Partial1Mode - // change the middle distinct agg args to target for specified group data. - // even for distinct agg multi args like count(distinct a,b), either null of a,b will cause count func to skip this row. - // so we only need to set the either one arg of them to be a case when expression to target for its self-group: - // Expr = (case when groupingID = ? then arg else null end) - // count(distinct a,b) => count(distinct (case when groupingID = targetID then a else null end), b) - eqExpr := expression.NewFunctionInternal(p.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), groupingIDCol, expression.NewNumber(fun.GroupingID)) - caseWhen := expression.NewFunctionInternal(p.ctx, ast.Case, fun.Args[0].GetType(), eqExpr, fun.Args[0], expression.NewNull()) - caseWhenProjCol := &expression.Column{ - UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: fun.Args[0].GetType(), - } - proj4Middle.Exprs = append(proj4Middle.Exprs, caseWhen) - proj4Middle.Schema().Append(caseWhenProjCol) - fun.Args[0] = caseWhenProjCol - } else { - fun.Mode = aggregation.Partial2Mode - originalCol := fun.Args[0].(*expression.Column) - // record the origin column unique id down before change it to be case when expr. - // mapping the current partial output column with the agg origin arg column. 
(final agg arg should use this one) - schemaMap[originalCol.UniqueID] = col - } - middleSchema.Append(col) - } - middleHashAgg.schema = middleSchema - - // step3: adjust final agg - finalHashAgg := finalAgg.(*PhysicalHashAgg) - finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs)) - for i, fun := range finalHashAgg.AggFuncs { - newArgs := make([]expression.Expression, 0, 1) - if fun.HasDistinct { - // change count(distinct) agg to sum() - fun.Name = ast.AggFuncSum - fun.HasDistinct = false - // count(distinct a,b) -> become a single partial result col. - newArgs = append(newArgs, middleSchema.Columns[i]) - } else { - // remap final normal agg args to be output schema of middle normal agg. - for _, arg := range fun.Args { - newCol, err := arg.RemapColumn(schemaMap) - if err != nil { - return nil, nil, nil, nil, nil, err - } - newArgs = append(newArgs, newCol) - } + // remap final normal agg args to be output schema of middle normal agg. + for _, arg := range fun.Args { + newCol, err := arg.RemapColumn(schemaMap) + if err != nil { + return nil, nil, nil, nil, nil, err } - fun.Mode = aggregation.FinalMode - fun.Args = newArgs - fun.GroupingID = 0 - finalAggDescs = append(finalAggDescs, fun) + newArgs = append(newArgs, newCol) } - finalHashAgg.AggFuncs = finalAggDescs - return finalHashAgg, middleHashAgg, proj4Middle, partialHashAgg, proj4Partial, nil } + fun.Mode = aggregation.FinalMode + fun.Args = newArgs + fun.GroupingID = 0 + finalAggDescs = append(finalAggDescs, fun) } - // return the original finalAgg and partiAgg - return finalAgg, nil, nil, partialAgg, nil, nil + finalHashAgg.AggFuncs = finalAggDescs + return finalHashAgg, middleHashAgg, proj4Middle, partialHashAgg, proj4Partial, nil } func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task { @@ -2545,10 +2539,10 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task { case MppScalar: prop := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.SinglePartitionType} if !mpp.needEnforceExchanger(prop) { - // one hand: when the low layer already satisfied the single partition layout, just do the all agg computation in the single node. + // On the one hand: when the low layer already satisfied the single partition layout, just do the all agg computation in the single node. return p.attach2TaskForMpp1Phase(mpp) } - // other hand: try to split the mppScalar agg into multi phases agg **down** to multi nodes since data already distributed across nodes. + // On the other hand: try to split the mppScalar agg into multi phases agg **down** to multi nodes since data already distributed across nodes. 
// we have to check it before the content of p has been modified canUse3StageAgg, groupingSets := p.scale3StageForDistinctAgg() proj := p.convertAvgForMPP() @@ -2773,7 +2767,6 @@ func accumulateNetSeekCost4MPP(p PhysicalPlan) (cost float64) { return } -// mppTask 转 root task 的时候,需要加一个 sender 并封装为 table reader func (t *mppTask) convertToRootTaskImpl(ctx sessionctx.Context) *rootTask { sender := PhysicalExchangeSender{ ExchangeType: tipb.ExchangeType_PassThrough, From c996b80d312383e6e7191b47a158d64660edd396 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Wed, 28 Dec 2022 00:26:56 +0800 Subject: [PATCH 08/25] make bazel prepare Signed-off-by: AilinKid <314806019@qq.com> --- expression/BUILD.bazel | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/expression/BUILD.bazel b/expression/BUILD.bazel index fcaa063a23c9c..f7abba73c0e9b 100644 --- a/expression/BUILD.bazel +++ b/expression/BUILD.bazel @@ -54,6 +54,7 @@ go_library( "expression.go", "extension.go", "function_traits.go", + "grouping_sets.go", "helper.go", "partition_pruner.go", "scalar_function.go", @@ -78,6 +79,7 @@ go_library( "//parser/opcode", "//parser/terror", "//parser/types", + "//planner/funcdep", "//privilege", "//sessionctx", "//sessionctx/stmtctx", @@ -171,6 +173,7 @@ go_test( "expr_to_pb_test.go", "expression_test.go", "function_traits_test.go", + "grouping_sets_test.go", "helper_test.go", "main_test.go", "multi_valued_index_test.go", @@ -227,6 +230,7 @@ go_test( "@com_github_stretchr_testify//require", "@com_github_tikv_client_go_v2//oracle", "@com_github_tikv_client_go_v2//tikv", + "@io_opencensus_go//stats/view", "@org_uber_go_goleak//:goleak", ], ) From 739b579464acef121a214aec21f9941d4cd94fd0 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Wed, 28 Dec 2022 10:28:12 +0800 Subject: [PATCH 09/25] make fmt Signed-off-by: AilinKid <314806019@qq.com> --- expression/grouping_sets.go | 39 ++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/expression/grouping_sets.go b/expression/grouping_sets.go index a8e22dc4f973a..bc884b842d066 100644 --- a/expression/grouping_sets.go +++ b/expression/grouping_sets.go @@ -31,7 +31,7 @@ type GroupingSet []GroupingExprs type GroupingExprs []Expression // Merge function will explore the internal grouping expressions and try to find the minimum grouping sets. (prefix merging) -func (gs GroupingSets) Merge() GroupingSets { +func (gss GroupingSets) Merge() GroupingSets { // for now, there is precondition that all grouping expressions are columns. // for example: (a,b,c) and (a,b) and (a) will be merged as a one. // Eg: @@ -54,8 +54,8 @@ func (gs GroupingSets) Merge() GroupingSets { // care about the prefix order, which should be taken as following the group layout expanding rule. (simple way) // [[a],[b,a],[c,b,a],] is also conforming the rule, gradually including one/more column(s) inside for one time. - newGroupingSets := make(GroupingSets, 0, len(gs)) - for _, oneGroupingSet := range gs { + newGroupingSets := make(GroupingSets, 0, len(gss)) + for _, oneGroupingSet := range gss { for _, oneGroupingExpr := range oneGroupingSet { if len(newGroupingSets) == 0 { // means there is nothing in new grouping sets, adding current one anyway. 
@@ -68,7 +68,7 @@ func (gs GroupingSets) Merge() GroupingSets { return newGroupingSets } -func (gs GroupingSets) mergeOne(targetOne GroupingExprs) GroupingSets { +func (gss GroupingSets) mergeOne(targetOne GroupingExprs) GroupingSets { // for every existing grouping set, check the grouping-exprs inside and whether the current grouping-exprs is // super-set of it or sub-set of it, adding current one to the correct position of grouping-exprs slice. // @@ -90,7 +90,7 @@ func (gs GroupingSets) mergeOne(targetOne GroupingExprs) GroupingSets { // | | | // +----+-------+ every previous one is the subset of the latter one. // - for i, oneNewGroupingSet := range gs { + for i, oneNewGroupingSet := range gss { // for every group set,try to find its position to insert if possible,otherwise create a new grouping set. for j := len(oneNewGroupingSet) - 1; j >= 0; j-- { cur := oneNewGroupingSet[j] @@ -100,8 +100,8 @@ func (gs GroupingSets) mergeOne(targetOne GroupingExprs) GroupingSets { cp := make(GroupingSet, 0, len(oneNewGroupingSet)+1) cp = append(cp, targetOne) cp = append(cp, oneNewGroupingSet...) - gs[i] = cp - return gs + gss[i] = cp + return gss } // do the left shift to find the right insert pos. continue @@ -109,25 +109,24 @@ func (gs GroupingSets) mergeOne(targetOne GroupingExprs) GroupingSets { if j == len(oneNewGroupingSet)-1 { // which means the targetOne can't fit itself in this grouping set, continue next grouping set. break - } else { - // successfully fit in, current j is the right pos to insert. - cp := make(GroupingSet, 0, len(oneNewGroupingSet)+1) - cp = append(cp, oneNewGroupingSet[:j+1]...) - cp = append(cp, targetOne) - cp = append(cp, oneNewGroupingSet[j+1:]...) - gs[i] = cp - return gs } + // successfully fit in, current j is the right pos to insert. + cp := make(GroupingSet, 0, len(oneNewGroupingSet)+1) + cp = append(cp, oneNewGroupingSet[:j+1]...) + cp = append(cp, targetOne) + cp = append(cp, oneNewGroupingSet[j+1:]...) + gss[i] = cp + return gss } } // here means we couldn't find even one GroupingSet to fill the targetOne, creating a new one. - gs = append(gs, newGroupingSet(targetOne)) + gss = append(gss, newGroupingSet(targetOne)) // gs is an alias of slice [], we should return it back after being changed. - return gs + return gss } // TargetOne is used to find a valid group layout for normal agg. -func (gs GroupingSets) TargetOne(targetOne GroupingExprs) int { +func (gss GroupingSets) TargetOne(targetOne GroupingExprs) int { // it has three cases. // 1: group sets {}, {} when the normal agg(d), agg(d) can only occur in the group of or after tuple split. (normal column are appended) // 2: group sets {}, {} when the normal agg(c), agg(c) can only be found in group of , since it will be filled with null in group of @@ -146,8 +145,8 @@ func (gs GroupingSets) TargetOne(targetOne GroupingExprs) int { // Expand operator, and null value can also influence your non null-strict normal agg, although you don't want them to. // // here we only consider 1&2 since normal agg with multi col args is banned. - allGroupingColIDs := gs.AllSetsColIDs() - for idx, groupingSet := range gs { + allGroupingColIDs := gss.AllSetsColIDs() + for idx, groupingSet := range gss { // diffCols are those columns being filled with null in the group row data of current grouping set. 
diffCols := allGroupingColIDs.Difference(*groupingSet.allSetColIDs()) if diffCols.Intersects(*targetOne.IDSet()) { From e2d5069abcc5fb0cf4a1623160e1fe00e1c38560 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Wed, 28 Dec 2022 10:43:21 +0800 Subject: [PATCH 10/25] add function comment Signed-off-by: AilinKid <314806019@qq.com> --- expression/grouping_sets.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/expression/grouping_sets.go b/expression/grouping_sets.go index bc884b842d066..a78f0dd3a81f9 100644 --- a/expression/grouping_sets.go +++ b/expression/grouping_sets.go @@ -24,10 +24,13 @@ import ( "github.com/pingcap/tipb/go-tipb" ) +// GroupingSets indicates the grouping sets definition. type GroupingSets []GroupingSet +// GroupingSet indicates one grouping set definition. type GroupingSet []GroupingExprs +// GroupingExprs indicates one grouping-expressions inside a grouping set. type GroupingExprs []Expression // Merge function will explore the internal grouping expressions and try to find the minimum grouping sets. (prefix merging) @@ -159,6 +162,7 @@ func (gss GroupingSets) TargetOne(targetOne GroupingExprs) int { return -1 } +// IsEmpty indicates whether current grouping set is empty. func (gs GroupingSet) IsEmpty() bool { if len(gs) == 0 { return true @@ -181,6 +185,7 @@ func (gs GroupingSet) allSetColIDs() *fd.FastIntSet { return &res } +// ExtractCols is used to extract basic columns from one grouping set. func (gs GroupingSet) ExtractCols() []*Column { cols := make([]*Column, 0, len(gs)) for _, groupingExprs := range gs { @@ -191,6 +196,7 @@ func (gs GroupingSet) ExtractCols() []*Column { return cols } +// Clone is used to clone a copy of current grouping set. func (gs GroupingSet) Clone() GroupingSet { gc := make(GroupingSet, 0, len(gs)) for _, one := range gs { @@ -199,6 +205,7 @@ func (gs GroupingSet) Clone() GroupingSet { return gc } +// String is used to output a string which simply described current grouping set. func (gs GroupingSet) String() string { var str strings.Builder str.WriteString("{") @@ -212,6 +219,7 @@ func (gs GroupingSet) String() string { return str.String() } +// MemoryUsage is used to output current memory usage by current grouping set. func (gs GroupingSet) MemoryUsage() int64 { sum := size.SizeOfSlice + int64(cap(gs))*size.SizeOfPointer for _, one := range gs { @@ -220,6 +228,7 @@ func (gs GroupingSet) MemoryUsage() int64 { return sum } +// ToPB is used to convert current grouping set to pb constructor. func (gs GroupingSet) ToPB(sc *stmtctx.StatementContext, client kv.Client) (*tipb.GroupingSet, error) { res := &tipb.GroupingSet{} for _, gExprs := range gs { @@ -232,6 +241,7 @@ func (gs GroupingSet) ToPB(sc *stmtctx.StatementContext, client kv.Client) (*tip return res, nil } +// IsEmpty indicates whether current grouping sets is empty. func (gss GroupingSets) IsEmpty() bool { if len(gss) == 0 { return true @@ -244,6 +254,7 @@ func (gss GroupingSets) IsEmpty() bool { return true } +// AllSetsColIDs is used to collect all the column id inside into a fast int set. func (gss GroupingSets) AllSetsColIDs() *fd.FastIntSet { res := fd.NewFastIntSet() for _, groupingSet := range gss { @@ -252,6 +263,7 @@ func (gss GroupingSets) AllSetsColIDs() *fd.FastIntSet { return &res } +// String is used to output a string which simply described current grouping sets. 
func (gss GroupingSets) String() string { var str strings.Builder str.WriteString("[") @@ -265,6 +277,7 @@ func (gss GroupingSets) String() string { return str.String() } +// ToPB is used to convert current grouping sets to pb constructor. func (gss GroupingSets) ToPB(sc *stmtctx.StatementContext, client kv.Client) ([]*tipb.GroupingSet, error) { res := make([]*tipb.GroupingSet, 0, len(gss)) for _, gs := range gss { @@ -283,10 +296,12 @@ func newGroupingSet(oneGroupingExpr GroupingExprs) GroupingSet { return res } +// IsEmpty indicates whether current grouping expressions are empty. func (g GroupingExprs) IsEmpty() bool { return len(g) == 0 } +// SubSetOf is used to do the logical computation of subset between two grouping expressions. func (g GroupingExprs) SubSetOf(other GroupingExprs) bool { oldOne := fd.NewFastIntSet() newOne := fd.NewFastIntSet() @@ -299,6 +314,7 @@ func (g GroupingExprs) SubSetOf(other GroupingExprs) bool { return oldOne.SubsetOf(newOne) } +// IDSet is used to collect column ids inside grouping expressions into a fast int set. func (g GroupingExprs) IDSet() *fd.FastIntSet { res := fd.NewFastIntSet() for _, one := range g { @@ -307,6 +323,7 @@ func (g GroupingExprs) IDSet() *fd.FastIntSet { return &res } +// Clone is used to clone a copy of current grouping expressions. func (g GroupingExprs) Clone() GroupingExprs { gc := make(GroupingExprs, 0, len(g)) for _, one := range g { @@ -315,6 +332,7 @@ func (g GroupingExprs) Clone() GroupingExprs { return gc } +// String is used to output a string which simply described current grouping expressions. func (g GroupingExprs) String() string { var str strings.Builder str.WriteString("<") @@ -328,6 +346,7 @@ func (g GroupingExprs) String() string { return str.String() } +// MemoryUsage is used to output current memory usage by current grouping expressions. func (g GroupingExprs) MemoryUsage() int64 { sum := size.SizeOfSlice + int64(cap(g))*size.SizeOfInterface for _, one := range g { From b2ef4225eb8c0863c18446dcaba125e87afd21b9 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Wed, 28 Dec 2022 10:48:57 +0800 Subject: [PATCH 11/25] fix lint Signed-off-by: AilinKid <314806019@qq.com> --- store/mockstore/unistore/cophandler/mpp.go | 8 ++++---- store/mockstore/unistore/cophandler/mpp_exec.go | 12 +++++------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/store/mockstore/unistore/cophandler/mpp.go b/store/mockstore/unistore/cophandler/mpp.go index eec358c038ba9..adeea777871c7 100644 --- a/store/mockstore/unistore/cophandler/mpp.go +++ b/store/mockstore/unistore/cophandler/mpp.go @@ -229,15 +229,15 @@ func (b *mppExecBuilder) buildExpand(pb *tipb.Expand) (mppExec, error) { // for every grouping set, collect column offsets under this grouping set. for _, groupingExprs := range gs { for _, groupingExpr := range groupingExprs { - if col, ok := groupingExpr.(*expression.Column); !ok { + col, ok := groupingExpr.(*expression.Column) + if !ok { return nil, errors.New("grouping set expr is not column ref") - } else { - inGroupingSetMap[col.Index] = struct{}{} } + inGroupingSetMap[col.Index] = struct{}{} } } } - var mutatedFieldTypes []*types.FieldType + mutatedFieldTypes := make([]*types.FieldType, 0, len(childFieldTypes)) // change the field types return from children tobe nullable. 
for offset, f := range childFieldTypes { cf := f.Clone() diff --git a/store/mockstore/unistore/cophandler/mpp_exec.go b/store/mockstore/unistore/cophandler/mpp_exec.go index 1a2f504afb4a4..a9233d423b7ef 100644 --- a/store/mockstore/unistore/cophandler/mpp_exec.go +++ b/store/mockstore/unistore/cophandler/mpp_exec.go @@ -429,9 +429,7 @@ type expandExec struct { } func (e *expandExec) open() error { - var err error - err = e.children[0].open() - if err != nil { + if err := e.children[0].open(); err != nil { return err } // building the quick finding map @@ -442,12 +440,12 @@ func (e *expandExec) open() error { // for every grouping set, collect column offsets under this grouping set. for _, groupingExprs := range gs { for _, groupingExpr := range groupingExprs { - if col, ok := groupingExpr.(*expression.Column); !ok { + col, ok := groupingExpr.(*expression.Column) + if !ok { return errors.New("grouping set expr is not column ref") - } else { - tmp[col.Index] = struct{}{} - e.groupingSetScope[col.Index] = struct{}{} } + tmp[col.Index] = struct{}{} + e.groupingSetScope[col.Index] = struct{}{} } } e.groupingSetOffsetMap = append(e.groupingSetOffsetMap, tmp) From 8d38687888b128b4e316100ee7da3da111a6a17d Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Wed, 28 Dec 2022 19:17:57 +0800 Subject: [PATCH 12/25] fix lint Signed-off-by: AilinKid <314806019@qq.com> --- planner/core/task.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/planner/core/task.go b/planner/core/task.go index 06d7fa3f3e1e1..1aaf84f3911b8 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -1894,7 +1894,7 @@ func (p *basePhysicalAgg) canUse3Stage4MultiDistinctAgg() (can bool, gss express } defer func() { // some clean work. - if can == false { + if !can { for _, fun := range p.AggFuncs { fun.GroupingID = 0 } @@ -2191,7 +2191,7 @@ func (p *PhysicalHashAgg) scaleStats4GroupingSets(groupingSets expression.Groupi cpStats := p.stats.Scale(1) cpStats.RowCount = sumNDV // We cannot estimate the ColNDVs for every output, so we use a conservative strategy. - for k, _ := range cpStats.ColNDVs { + for k := range cpStats.ColNDVs { cpStats.ColNDVs[k] = sumNDV } // for old groupNDV, if it's containing one more grouping set cols, just plus the NDV where the col is excluded. @@ -2312,7 +2312,6 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan // multi distinct agg mode, having grouping sets. // set the default expression to constant 1 for the convenience to choose default group set data. var groupingIDCol expression.Expression - groupingIDCol = expression.NewUInt64Const(1) // enforce Expand operator above the children. // physical plan is enumerated without children from itself, use mpp subtree instead p.children. // scale(len(groupingSets)) will change the NDV, while Repeat doesn't change the NDV and groupNDV. 
From a4e0506586714faedc5bf93d456d627af53abd9a Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Wed, 28 Dec 2022 20:10:20 +0800 Subject: [PATCH 13/25] fix old test Signed-off-by: AilinKid <314806019@qq.com> --- .../casetest/testdata/plan_suite_out.json | 41 +++++++++++++------ 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/planner/core/casetest/testdata/plan_suite_out.json b/planner/core/casetest/testdata/plan_suite_out.json index 86ee4c51259b2..91669f2dc17d9 100644 --- a/planner/core/casetest/testdata/plan_suite_out.json +++ b/planner/core/casetest/testdata/plan_suite_out.json @@ -4280,17 +4280,22 @@ "TableReader 1.00 root MppVersion: 1, data:ExchangeSender", "└─ExchangeSender 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection 1.00 mpp[tiflash] Column#7", - " └─HashAgg 1.00 mpp[tiflash] group by:Column#12, funcs:sum(Column#13)->Column#7", + " └─HashAgg 1.00 mpp[tiflash] group by:Column#22, funcs:sum(Column#23)->Column#7", " └─ExchangeReceiver 1.00 mpp[tiflash] ", - " └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: Column#12, collate: binary]", - " └─HashAgg 1.00 mpp[tiflash] group by:Column#15, funcs:count(Column#14)->Column#13", - " └─Projection 1.00 mpp[tiflash] Column#5, plus(Column#6, 1)->Column#15", + " └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: Column#22, collate: binary]", + " └─HashAgg 1.00 mpp[tiflash] group by:Column#25, funcs:count(Column#24)->Column#23", + " └─Projection 1.00 mpp[tiflash] Column#5, plus(Column#6, 1)->Column#25", " └─Projection 1.00 mpp[tiflash] Column#5, Column#6", - " └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct test.employee.deptid)->Column#5, funcs:count(distinct test.employee.empid)->Column#6", + " └─HashAgg 1.00 mpp[tiflash] funcs:sum(Column#14)->Column#5, funcs:sum(Column#16)->Column#6", " └─ExchangeReceiver 1.00 mpp[tiflash] ", " └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─HashAgg 1.00 mpp[tiflash] group by:test.employee.deptid, test.employee.empid, ", - " └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo" + " └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct Column#15)->Column#14, funcs:count(distinct Column#17)->Column#16", + " └─Projection 16000.00 mpp[tiflash] test.employee.deptid, test.employee.empid, Column#13, case(eq(Column#13, 1), test.employee.deptid, )->Column#15, case(eq(Column#13, 2), test.employee.empid, )->Column#17", + " └─ExchangeReceiver 16000.00 mpp[tiflash] ", + " └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.employee.deptid, collate: binary], [name: test.employee.empid, collate: binary], [name: Column#13, collate: binary]", + " └─HashAgg 16000.00 mpp[tiflash] group by:Column#13, test.employee.deptid, test.employee.empid, ", + " └─Expand 20000.00 mpp[tiflash] group set num:2, groupingID:Column#13, [{},{}]", + " └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo" ] }, { @@ -4301,11 +4306,16 @@ " └─Projection 1.00 mpp[tiflash] Column#7", " └─HashAgg 1.00 mpp[tiflash] group by:Column#6, funcs:count(Column#5)->Column#7", " └─Projection 1.00 mpp[tiflash] Column#5, Column#6", - " └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct test.employee.deptid)->Column#5, funcs:count(distinct test.employee.empid)->Column#6", + " └─HashAgg 1.00 mpp[tiflash] funcs:sum(Column#14)->Column#5, funcs:sum(Column#16)->Column#6", " └─ExchangeReceiver 1.00 
mpp[tiflash] ", " └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─HashAgg 1.00 mpp[tiflash] group by:test.employee.deptid, test.employee.empid, ", - " └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo" + " └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct Column#15)->Column#14, funcs:count(distinct Column#17)->Column#16", + " └─Projection 16000.00 mpp[tiflash] test.employee.deptid, test.employee.empid, Column#13, case(eq(Column#13, 1), test.employee.deptid, )->Column#15, case(eq(Column#13, 2), test.employee.empid, )->Column#17", + " └─ExchangeReceiver 16000.00 mpp[tiflash] ", + " └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.employee.deptid, collate: binary], [name: test.employee.empid, collate: binary], [name: Column#13, collate: binary]", + " └─HashAgg 16000.00 mpp[tiflash] group by:Column#13, test.employee.deptid, test.employee.empid, ", + " └─Expand 20000.00 mpp[tiflash] group set num:2, groupingID:Column#13, [{},{}]", + " └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo" ] }, { @@ -4317,11 +4327,16 @@ " ├─ExchangeReceiver(Build) 1.00 mpp[tiflash] ", " │ └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: Broadcast", " │ └─Projection 1.00 mpp[tiflash] Column#9, Column#10", - " │ └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct test.employee.deptid)->Column#9, funcs:count(distinct test.employee.empid)->Column#10", + " │ └─HashAgg 1.00 mpp[tiflash] funcs:sum(Column#12)->Column#9, funcs:sum(Column#14)->Column#10", " │ └─ExchangeReceiver 1.00 mpp[tiflash] ", " │ └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: PassThrough", - " │ └─HashAgg 1.00 mpp[tiflash] group by:test.employee.deptid, test.employee.empid, ", - " │ └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo", + " │ └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct Column#13)->Column#12, funcs:count(distinct Column#15)->Column#14", + " │ └─Projection 16000.00 mpp[tiflash] test.employee.deptid, test.employee.empid, Column#11, case(eq(Column#11, 1), test.employee.deptid, )->Column#13, case(eq(Column#11, 2), test.employee.empid, )->Column#15", + " │ └─ExchangeReceiver 16000.00 mpp[tiflash] ", + " │ └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.employee.deptid, collate: binary], [name: test.employee.empid, collate: binary], [name: Column#11, collate: binary]", + " │ └─HashAgg 16000.00 mpp[tiflash] group by:Column#11, test.employee.deptid, test.employee.empid, ", + " │ └─Expand 20000.00 mpp[tiflash] group set num:2, groupingID:Column#11, [{},{}]", + " │ └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo", " └─TableFullScan(Probe) 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo" ] }, From 347df77a9eb81a66075545d4aa35e6ff004e009f Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Thu, 29 Dec 2022 15:50:02 +0800 Subject: [PATCH 14/25] make fmt Signed-off-by: AilinKid <314806019@qq.com> --- planner/core/task.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/planner/core/task.go b/planner/core/task.go index 1aaf84f3911b8..9428d18061103 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -2515,7 +2515,6 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task { }) } } - // 形成新的 partition cols property prop := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: 
math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: partitionCols} newMpp := mpp.enforceExchangerImpl(prop) if newMpp.invalid() { @@ -2647,7 +2646,6 @@ func (p *PhysicalHashAgg) attach2Task(tasks ...task) task { attachPlan2Task(p, t) } } else if _, ok := t.(*mppTask); ok { - // 如果底层暂时还是个 mppTask 的话 return final.attach2TaskForMpp(tasks...) } else { attachPlan2Task(p, t) @@ -2856,7 +2854,7 @@ func (t *mppTask) enforceExchangerImpl(prop *property.PhysicalProperty) *mppTask ctx := t.p.SCtx() sender := PhysicalExchangeSender{ ExchangeType: prop.MPPPartitionTp.ToExchangeType(), - HashCols: prop.MPPPartitionCols, // 执行层面怎么挂钩 hash key 和具体的 node 节点? + HashCols: prop.MPPPartitionCols, }.Init(ctx, t.p.statsInfo()) if ctx.GetSessionVars().ChooseMppVersion() >= kv.MppVersionV1 { From 5e8f4ae7d18d0f6608d690d6f8a3e6ef197cbb87 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Thu, 29 Dec 2022 16:18:35 +0800 Subject: [PATCH 15/25] banned cases need column clone Signed-off-by: AilinKid <314806019@qq.com> --- expression/grouping_sets.go | 19 ++++++++++ .../testdata/enforce_mpp_suite_in.json | 3 +- .../testdata/enforce_mpp_suite_out.json | 38 +++++++++++-------- planner/core/task.go | 7 +++- 4 files changed, 50 insertions(+), 17 deletions(-) diff --git a/expression/grouping_sets.go b/expression/grouping_sets.go index a78f0dd3a81f9..b24e1f6aad2eb 100644 --- a/expression/grouping_sets.go +++ b/expression/grouping_sets.go @@ -162,6 +162,25 @@ func (gss GroupingSets) TargetOne(targetOne GroupingExprs) int { return -1 } +// NeedCloneColumn indicate whether we need to copy column to when expanding datasource. +func (gss GroupingSets) NeedCloneColumn() bool { + // for grouping sets like: {},{} / {},{} + // the column c should be copied one more time here, otherwise it will be filled with null values and not visible for the other grouping set again. + setIDs := make([]*fd.FastIntSet, 0, len(gss)) + for _, groupingSet := range gss { + setIDs = append(setIDs, groupingSet.allSetColIDs()) + } + for idx, oneSetIDs := range setIDs { + for j := idx + 1; j < len(setIDs); j++ { + otherSetIDs := setIDs[j] + if oneSetIDs.Intersects(*otherSetIDs) { + return true + } + } + } + return false +} + // IsEmpty indicates whether current grouping set is empty. 
func (gs GroupingSet) IsEmpty() bool { if len(gs) == 0 { diff --git a/planner/core/casetest/testdata/enforce_mpp_suite_in.json b/planner/core/casetest/testdata/enforce_mpp_suite_in.json index 18f1da4a8578e..9b74e49a81808 100644 --- a/planner/core/casetest/testdata/enforce_mpp_suite_in.json +++ b/planner/core/casetest/testdata/enforce_mpp_suite_in.json @@ -149,7 +149,8 @@ "EXPLAIN select count(distinct a+b), sum(c) from t", "select count(distinct a+b), sum(c) from t", "EXPLAIN select count(distinct a+b), count(distinct b+c), count(c) from t", - "select count(distinct a+b), count(distinct b+c), count(c) from t" + "select count(distinct a+b), count(distinct b+c), count(c) from t", + "explain select count(distinct a,c), count(distinct b,c), count(c) from t" ] } ] diff --git a/planner/core/casetest/testdata/enforce_mpp_suite_out.json b/planner/core/casetest/testdata/enforce_mpp_suite_out.json index b099bd678fe3d..ebd9545578f3c 100644 --- a/planner/core/casetest/testdata/enforce_mpp_suite_out.json +++ b/planner/core/casetest/testdata/enforce_mpp_suite_out.json @@ -1315,21 +1315,15 @@ { "SQL": "EXPLAIN select count(distinct a, b), count(distinct b), count(c), sum(d) from t", "Plan": [ - "TableReader_38 1.00 root data:ExchangeSender_37", - "└─ExchangeSender_37 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─Projection_27 1.00 mpp[tiflash] Column#6, Column#7, Column#8, Column#9", - " └─HashAgg_28 1.00 mpp[tiflash] funcs:sum(Column#26)->Column#6, funcs:sum(Column#28)->Column#7, funcs:sum(Column#30)->Column#8, funcs:sum(Column#31)->Column#9", - " └─ExchangeReceiver_36 1.00 mpp[tiflash] ", - " └─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─HashAgg_30 1.00 mpp[tiflash] funcs:count(distinct Column#27, test.t.b)->Column#26, funcs:count(distinct Column#29)->Column#28, funcs:sum(Column#21)->Column#30, funcs:sum(Column#22)->Column#31", - " └─Projection_32 16000.00 mpp[tiflash] Column#21, Column#22, test.t.a, test.t.b, Column#23, case(eq(Column#23, 1), test.t.a, )->Column#27, case(eq(Column#23, 2), test.t.b, )->Column#29", - " └─ExchangeReceiver_34 16000.00 mpp[tiflash] ", - " └─ExchangeSender_33 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#23, collate: binary]", - " └─HashAgg_26 16000.00 mpp[tiflash] group by:Column#34, Column#35, Column#36, funcs:count(Column#32)->Column#21, funcs:sum(Column#33)->Column#22", - " └─Projection_39 20000.00 mpp[tiflash] Column#24, cast(Column#25, decimal(10,0) BINARY)->Column#33, test.t.a, test.t.b, Column#23", - " └─Projection_31 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, test.t.d, Column#23, case(eq(Column#23, 1), test.t.c, )->Column#24, case(eq(Column#23, 1), test.t.d, )->Column#25", - " └─Expand_29 20000.00 mpp[tiflash] group set num:2, groupingID:Column#23, [{},{}]", - " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + "TableReader_26 1.00 root data:ExchangeSender_25", + "└─ExchangeSender_25 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_21 1.00 mpp[tiflash] Column#6, Column#7, Column#8, Column#9", + " └─HashAgg_22 1.00 mpp[tiflash] funcs:count(distinct test.t.a, test.t.b)->Column#6, funcs:count(distinct test.t.b)->Column#7, funcs:sum(Column#12)->Column#8, funcs:sum(Column#13)->Column#9", + " └─ExchangeReceiver_24 1.00 mpp[tiflash] ", + " └─ExchangeSender_23 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_20 1.00 mpp[tiflash] group by:Column#16, Column#17, 
funcs:count(Column#14)->Column#12, funcs:sum(Column#15)->Column#13", + " └─Projection_27 10000.00 mpp[tiflash] test.t.c, cast(test.t.d, decimal(10,0) BINARY)->Column#15, test.t.a, test.t.b", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": [ "Some grouping sets should be merged", @@ -1389,6 +1383,20 @@ "5 5 10" ], "Warn": null + }, + { + "SQL": "explain select count(distinct a,c), count(distinct b,c), count(c) from t", + "Plan": [ + "TableReader_26 1.00 root data:ExchangeSender_25", + "└─ExchangeSender_25 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_21 1.00 mpp[tiflash] Column#6, Column#7, Column#8", + " └─HashAgg_22 1.00 mpp[tiflash] funcs:count(distinct test.t.a, test.t.c)->Column#6, funcs:count(distinct test.t.b, test.t.c)->Column#7, funcs:sum(Column#10)->Column#8", + " └─ExchangeReceiver_24 1.00 mpp[tiflash] ", + " └─ExchangeSender_23 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_20 1.00 mpp[tiflash] group by:test.t.a, test.t.b, test.t.c, funcs:count(test.t.c)->Column#10", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null } ] } diff --git a/planner/core/task.go b/planner/core/task.go index 9428d18061103..f67e1cc775049 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -1929,8 +1929,13 @@ func (p *basePhysicalAgg) canUse3Stage4MultiDistinctAgg() (can bool, gss express } Compressed := GroupingSets.Merge() if len(Compressed) != len(GroupingSets) { - // todo arenatlx: some grouping set should be merged which is not supported by now temporarily. p.SCtx().GetSessionVars().StmtCtx.AppendWarning(errors.Errorf("Some grouping sets should be merged")) + // todo arenatlx: some grouping set should be merged which is not supported by now temporarily. + return false, nil + } + if GroupingSets.NeedCloneColumn() { + // todo: column clone haven't implemented. + return false, nil } if len(GroupingSets) > 1 { // fill the grouping ID for normal agg. From 83bd74f6409d5b0098b876e92343f2e6f226315b Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Thu, 29 Dec 2022 16:19:30 +0800 Subject: [PATCH 16/25] make fmt --- expression/constant.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/expression/constant.go b/expression/constant.go index a04b71e639d59..a0807ce6cc1c3 100644 --- a/expression/constant.go +++ b/expression/constant.go @@ -52,7 +52,7 @@ func NewZero() *Constant { } } -// NewUInt64Const stands for constant of a given number. +// NewUInt64Const stands for constant of a given number. 
func NewUInt64Const(num int64) *Constant { retT := types.NewFieldType(mysql.TypeLonglong) retT.AddFlag(mysql.UnsignedFlag) // shrink range to avoid integral promotion From 854a3fad20be14480a65e64c69dd693407e80ddb Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Thu, 29 Dec 2022 16:20:47 +0800 Subject: [PATCH 17/25] remove debug log Signed-off-by: AilinKid <314806019@qq.com> --- planner/core/task.go | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/planner/core/task.go b/planner/core/task.go index f67e1cc775049..f634d987e8e66 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -15,10 +15,6 @@ package core import ( - "fmt" - "math" - "strings" - "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/expression" @@ -43,6 +39,7 @@ import ( "github.com/pingcap/tidb/util/size" "github.com/pingcap/tipb/go-tipb" "go.uber.org/zap" + "math" ) var ( @@ -2484,9 +2481,6 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task { if !ok { return invalidTask } - if strings.HasPrefix(p.ctx.GetSessionVars().StmtCtx.OriginalSQL, "select count(distinct a, b), count(distinct b), count(c) from t group by d") { - fmt.Println(1) - } switch p.MppRunMode { case Mpp1Phase: // 1-phase agg: when the partition columns can be satisfied, where the plan does not need to enforce Exchange From ac6e5a2e86e71e5c4a260e82bf54641c8bf12d9c Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Fri, 30 Dec 2022 15:40:55 +0800 Subject: [PATCH 18/25] address elsa comment Signed-off-by: AilinKid <314806019@qq.com> --- expression/grouping_sets.go | 12 ++++++-- expression/grouping_sets_test.go | 47 ++++++++++++++++++++++++++++++++ planner/core/task.go | 5 ++-- 3 files changed, 60 insertions(+), 4 deletions(-) diff --git a/expression/grouping_sets.go b/expression/grouping_sets.go index b24e1f6aad2eb..27d4b57789e28 100644 --- a/expression/grouping_sets.go +++ b/expression/grouping_sets.go @@ -65,13 +65,14 @@ func (gss GroupingSets) Merge() GroupingSets { newGroupingSets = append(newGroupingSets, newGroupingSet(oneGroupingExpr)) continue } - newGroupingSets = newGroupingSets.mergeOne(oneGroupingExpr) + newGroupingSets = newGroupingSets.MergeOne(oneGroupingExpr) } } return newGroupingSets } -func (gss GroupingSets) mergeOne(targetOne GroupingExprs) GroupingSets { +// MergeOne is used to merge one grouping expressions into current grouping sets. +func (gss GroupingSets) MergeOne(targetOne GroupingExprs) GroupingSets { // for every existing grouping set, check the grouping-exprs inside and whether the current grouping-exprs is // super-set of it or sub-set of it, adding current one to the correct position of grouping-exprs slice. // @@ -110,6 +111,13 @@ func (gss GroupingSets) mergeOne(targetOne GroupingExprs) GroupingSets { continue } if j == len(oneNewGroupingSet)-1 { + // which means the targetOne itself is the super set of current right-most grouping set. + if cur.SubSetOf(targetOne) { + // the right pos should be the len(oneNewGroupingSet) + oneNewGroupingSet = append(oneNewGroupingSet, targetOne) + gss[i] = oneNewGroupingSet + return gss + } // which means the targetOne can't fit itself in this grouping set, continue next grouping set. break } diff --git a/expression/grouping_sets_test.go b/expression/grouping_sets_test.go index 9a73abe47d186..65d3d3df8beb6 100644 --- a/expression/grouping_sets_test.go +++ b/expression/grouping_sets_test.go @@ -78,6 +78,53 @@ func TestGroupSetsTargetOne(t *testing.T) { // shuffler take place. 
} +func TestGroupingSetsMergeOneUnitTest(t *testing.T) { + defer view.Stop() + a := &Column{ + UniqueID: 1, + } + b := &Column{ + UniqueID: 2, + } + c := &Column{ + UniqueID: 3, + } + d := &Column{ + UniqueID: 4, + } + // test case about the right most fitness. + newGroupingSets := make(GroupingSets, 0, 3) + newGroupingSets = newGroupingSets[:0] + groupingExprs := []Expression{a} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + newGroupingSets.MergeOne([]Expression{a, b}) + + require.Equal(t, len(newGroupingSets), 1) + require.Equal(t, len(newGroupingSets[0]), 2) + //{a} + require.Equal(t, len(newGroupingSets[0][0]), 1) + //{a,b} + require.Equal(t, len(newGroupingSets[0][1]), 2) + + newGroupingSets = newGroupingSets[:0] + groupingExprs = []Expression{a} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + newGroupingSets.MergeOne([]Expression{d, c, b, a}) + newGroupingSets.MergeOne([]Expression{d, b, a}) + newGroupingSets.MergeOne([]Expression{b, a}) + + require.Equal(t, len(newGroupingSets), 1) + require.Equal(t, len(newGroupingSets[0]), 4) + //{a} + require.Equal(t, len(newGroupingSets[0][0]), 1) + //{b,a} + require.Equal(t, len(newGroupingSets[0][1]), 2) + //{d,b,a} + require.Equal(t, len(newGroupingSets[0][2]), 3) + //{d,c,b,a} + require.Equal(t, len(newGroupingSets[0][3]), 4) +} + func TestGroupingSetsMergeUnitTest(t *testing.T) { defer view.Stop() a := &Column{ diff --git a/planner/core/task.go b/planner/core/task.go index f634d987e8e66..69c3eee052fab 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -15,6 +15,8 @@ package core import ( + "math" + "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/expression" @@ -39,7 +41,6 @@ import ( "github.com/pingcap/tidb/util/size" "github.com/pingcap/tipb/go-tipb" "go.uber.org/zap" - "math" ) var ( @@ -1954,7 +1955,7 @@ func (p *basePhysicalAgg) canUse3Stage4MultiDistinctAgg() (can bool, gss express return false, nil } -// canUse3StageDistinctAgg returns true if this agg can use 3 stage for distinct aggregation +// canUse3Stage4SingleDistinctAgg returns true if this agg can use 3 stage for distinct aggregation func (p *basePhysicalAgg) canUse3Stage4SingleDistinctAgg() bool { num := 0 if !p.ctx.GetSessionVars().Enable3StageDistinctAgg || len(p.GroupByItems) > 0 { From 6322c55609a5e8d02ed6d7ffa2ac8188841d98d5 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Wed, 4 Jan 2023 14:55:37 +0800 Subject: [PATCH 19/25] address haisheng's comment Signed-off-by: AilinKid <314806019@qq.com> --- expression/aggregation/descriptor.go | 2 +- expression/constant.go | 2 +- planner/core/task.go | 22 +++++++++++----------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index 3894355a2c546..3015b4718a226 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -43,7 +43,7 @@ type AggFuncDesc struct { // OrderByItems represents the order by clause used in GROUP_CONCAT OrderByItems []*util.ByItems // GroupingID is used for distinguishing with not-set 0, starting from 1. - GroupingID int64 + GroupingID int } // NewAggFuncDesc creates an aggregation function signature descriptor. 
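The GroupingID stored on the descriptor above is what ties a non-distinct aggregate to a single replica stream after Expand: the planner wraps the aggregate's argument in a case-when on the grouping column (visible in the EXPLAIN test data as case(eq(Column#N, id), col, ) expressions), so rows that belong to the other grouping sets contribute NULL and are skipped. A rough, runnable sketch of that trick follows; caseWhenTarget and the sample data are invented for illustration and are not TiDB APIs.

package main

import "fmt"

// caseWhenTarget mimics "case when groupingID = targetID then arg else null end":
// replicas that belong to other grouping sets yield NULL, which COUNT ignores.
func caseWhenTarget(groupingID, targetID int64, arg any) any {
	if groupingID == targetID {
		return arg
	}
	return nil
}

func main() {
	// expanded rows for grouping sets {a} (id 1) and {b} (id 2); count(c) is
	// assigned GroupingID 1, so it only counts c on the id-1 replicas.
	groupingIDs := []int64{1, 2, 1, 2}
	c := []any{100, 100, 200, 200}
	var cnt int64
	for i, id := range groupingIDs {
		if caseWhenTarget(id, 1, c[i]) != nil {
			cnt++
		}
	}
	fmt.Println(cnt) // 2, the same answer count(c) gives on the un-expanded rows
}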
diff --git a/expression/constant.go b/expression/constant.go index a0807ce6cc1c3..1f48595985c9b 100644 --- a/expression/constant.go +++ b/expression/constant.go @@ -53,7 +53,7 @@ func NewZero() *Constant { } // NewUInt64Const stands for constant of a given number. -func NewUInt64Const(num int64) *Constant { +func NewUInt64Const(num int) *Constant { retT := types.NewFieldType(mysql.TypeLonglong) retT.AddFlag(mysql.UnsignedFlag) // shrink range to avoid integral promotion retT.SetFlen(mysql.MaxIntWidth) diff --git a/planner/core/task.go b/planner/core/task.go index 69c3eee052fab..c0a1868df760a 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -1898,8 +1898,8 @@ func (p *basePhysicalAgg) canUse3Stage4MultiDistinctAgg() (can bool, gss express } } }() - // GroupingSets is alias of []GroupingSet, the below equal to = make([]GroupingSet, 0, 2) - GroupingSets := make(expression.GroupingSets, 0, 2) + // groupingSets is alias of []GroupingSet, the below equal to = make([]GroupingSet, 0, 2) + groupingSets := make(expression.GroupingSets, 0, 2) for _, fun := range p.AggFuncs { if fun.HasDistinct { if fun.Name != ast.AggFuncCount { @@ -1913,10 +1913,10 @@ func (p *basePhysicalAgg) canUse3Stage4MultiDistinctAgg() (can bool, gss express } } // here it's a valid count distinct agg with normal column args, collecting its distinct expr. - GroupingSets = append(GroupingSets, expression.GroupingSet{fun.Args}) + groupingSets = append(groupingSets, expression.GroupingSet{fun.Args}) // groupingID now is the offset of target grouping in GroupingSets. // todo: it may be changed after grouping set merge in the future. - fun.GroupingID = int64(len(GroupingSets)) + fun.GroupingID = len(groupingSets) } else if len(fun.Args) > 1 { return false, nil } @@ -1925,32 +1925,32 @@ func (p *basePhysicalAgg) canUse3Stage4MultiDistinctAgg() (can bool, gss express return false, nil } } - Compressed := GroupingSets.Merge() - if len(Compressed) != len(GroupingSets) { + compressed := groupingSets.Merge() + if len(compressed) != len(groupingSets) { p.SCtx().GetSessionVars().StmtCtx.AppendWarning(errors.Errorf("Some grouping sets should be merged")) // todo arenatlx: some grouping set should be merged which is not supported by now temporarily. return false, nil } - if GroupingSets.NeedCloneColumn() { + if groupingSets.NeedCloneColumn() { // todo: column clone haven't implemented. return false, nil } - if len(GroupingSets) > 1 { + if len(groupingSets) > 1 { // fill the grouping ID for normal agg. for _, fun := range p.AggFuncs { if fun.GroupingID == 0 { // the grouping ID hasn't set. find the targeting grouping set. - groupingSetOffset := GroupingSets.TargetOne(fun.Args) + groupingSetOffset := groupingSets.TargetOne(fun.Args) if groupingSetOffset == -1 { // todo: if we couldn't find a existed current valid group layout, we need to copy the column out from being filled with null value. 
p.SCtx().GetSessionVars().StmtCtx.AppendWarning(errors.Errorf("couldn't find a proper group set for normal agg")) return false, nil } // starting with 1 - fun.GroupingID = int64(groupingSetOffset + 1) + fun.GroupingID = groupingSetOffset + 1 } } - return true, GroupingSets + return true, groupingSets } return false, nil } From 46cafc21ea3d1beaf13f0f7a2f0150e02c53fc59 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Fri, 6 Jan 2023 16:41:58 +0800 Subject: [PATCH 20/25] left distinct count agg to aggregate on whole-scope data Signed-off-by: AilinKid <314806019@qq.com> --- expression/grouping_sets.go | 21 ++- expression/grouping_sets_test.go | 59 +++++++ .../testdata/enforce_mpp_suite_in.json | 8 +- .../testdata/enforce_mpp_suite_out.json | 160 +++++++++++++----- planner/core/task.go | 86 +++------- 5 files changed, 222 insertions(+), 112 deletions(-) diff --git a/expression/grouping_sets.go b/expression/grouping_sets.go index 27d4b57789e28..b252cacdf059a 100644 --- a/expression/grouping_sets.go +++ b/expression/grouping_sets.go @@ -136,8 +136,8 @@ func (gss GroupingSets) MergeOne(targetOne GroupingExprs) GroupingSets { return gss } -// TargetOne is used to find a valid group layout for normal agg. -func (gss GroupingSets) TargetOne(targetOne GroupingExprs) int { +// TargetOne is used to find a valid group layout for normal agg, note that: the args in normal agg are not necessary to be column. +func (gss GroupingSets) TargetOne(normalAggArgs []Expression) int { // it has three cases. // 1: group sets {}, {} when the normal agg(d), agg(d) can only occur in the group of or after tuple split. (normal column are appended) // 2: group sets {}, {} when the normal agg(c), agg(c) can only be found in group of , since it will be filled with null in group of @@ -156,11 +156,26 @@ func (gss GroupingSets) TargetOne(targetOne GroupingExprs) int { // Expand operator, and null value can also influence your non null-strict normal agg, although you don't want them to. // // here we only consider 1&2 since normal agg with multi col args is banned. + columnInNormalAggArgs := make([]*Column, 0, len(normalAggArgs)) + for _, one := range normalAggArgs { + columnInNormalAggArgs = append(columnInNormalAggArgs, ExtractColumns(one)...) + } + if len(columnInNormalAggArgs) == 0 { + // no column in normal agg. eg: count(1), specify the default grouping set ID 0+1. + return 0 + } + // for other normal agg args like: count(a), count(a+b), count(not(a is null)) and so on. + normalAggArgsIDSet := fd.NewFastIntSet() + for _, one := range columnInNormalAggArgs { + normalAggArgsIDSet.Insert(int(one.UniqueID)) + } + + // identify the suitable grouping set for normal agg. allGroupingColIDs := gss.AllSetsColIDs() for idx, groupingSet := range gss { // diffCols are those columns being filled with null in the group row data of current grouping set. diffCols := allGroupingColIDs.Difference(*groupingSet.allSetColIDs()) - if diffCols.Intersects(*targetOne.IDSet()) { + if diffCols.Intersects(normalAggArgsIDSet) { // normal agg's target arg columns are being filled with null value in this grouping set, continue next grouping set check. 
continue } diff --git a/expression/grouping_sets_test.go b/expression/grouping_sets_test.go index 65d3d3df8beb6..d62cf50e5e6ff 100644 --- a/expression/grouping_sets_test.go +++ b/expression/grouping_sets_test.go @@ -15,6 +15,9 @@ package expression import ( + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/types" "testing" "github.com/stretchr/testify/require" @@ -78,6 +81,62 @@ func TestGroupSetsTargetOne(t *testing.T) { // shuffler take place. } +func TestGroupSetsTargetOneCompoundArgs(t *testing.T) { + defer view.Stop() + a := &Column{ + UniqueID: 1, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + b := &Column{ + UniqueID: 2, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + c := &Column{ + UniqueID: 3, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + d := &Column{ + UniqueID: 4, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + // grouping set: {a,b}, {c} + newGroupingSets := make(GroupingSets, 0, 3) + groupingExprs1 := []Expression{a, b} + groupingExprs2 := []Expression{c} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs1}) + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs2}) + + var normalAggArgs Expression + // mock normal agg count(1) + normalAggArgs = newLonglong(1) + offset := newGroupingSets.TargetOne([]Expression{normalAggArgs}) + require.NotEqual(t, offset, -1) + require.Equal(t, offset, 0) // default + + // mock normal agg count(d+1) + normalAggArgs = newFunction(ast.Plus, d, newLonglong(1)) + offset = newGroupingSets.TargetOne([]Expression{normalAggArgs}) + require.NotEqual(t, offset, -1) + require.Equal(t, offset, 0) // default + + // mock normal agg count(d+c) + normalAggArgs = newFunction(ast.Plus, d, c) + offset = newGroupingSets.TargetOne([]Expression{normalAggArgs}) + require.NotEqual(t, offset, -1) + require.Equal(t, offset, 1) // only {c} can supply d and c + + // mock normal agg count(d+a) + normalAggArgs = newFunction(ast.Plus, d, a) + offset = newGroupingSets.TargetOne([]Expression{normalAggArgs}) + require.NotEqual(t, offset, -1) + require.Equal(t, offset, 0) // only {a,b} can supply d and a + + // mock normal agg count(d+a+c) + normalAggArgs = newFunction(ast.Plus, d, newFunction(ast.Plus, a, c)) + offset = newGroupingSets.TargetOne([]Expression{normalAggArgs}) + require.Equal(t, offset, -1) // couldn't find a group that supply d, a and c simultaneously. 
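	// To summarize the rule these assertions exercise: a normal aggregate can only be
	// routed to a grouping set whose nulled-out columns (columns owned by some other
	// grouping set) do not overlap the columns referenced by the aggregate's arguments;
	// when no such set exists, TargetOne returns -1 and the multi-distinct 3-stage
	// rewrite is not applied.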
+} + func TestGroupingSetsMergeOneUnitTest(t *testing.T) { defer view.Stop() a := &Column{ diff --git a/planner/core/casetest/testdata/enforce_mpp_suite_in.json b/planner/core/casetest/testdata/enforce_mpp_suite_in.json index 9b74e49a81808..885a9709018d8 100644 --- a/planner/core/casetest/testdata/enforce_mpp_suite_in.json +++ b/planner/core/casetest/testdata/enforce_mpp_suite_in.json @@ -142,6 +142,8 @@ "select count(distinct a), count(distinct b) from t", "EXPLAIN select count(distinct a), count(distinct b), count(c) from t", "select count(distinct a), count(distinct b), count(c) from t", + "EXPLAIN select count(distinct a), count(distinct b), count(c+1) from t", + "select count(distinct a), count(distinct b), count(c+1) from t", "EXPLAIN select count(distinct a), count(distinct b), sum(c) from t", "select count(distinct a), count(distinct b), sum(c) from t", "EXPLAIN select count(distinct a, b), count(distinct b), count(c), sum(d) from t", @@ -150,7 +152,11 @@ "select count(distinct a+b), sum(c) from t", "EXPLAIN select count(distinct a+b), count(distinct b+c), count(c) from t", "select count(distinct a+b), count(distinct b+c), count(c) from t", - "explain select count(distinct a,c), count(distinct b,c), count(c) from t" + "explain select count(distinct a,c), count(distinct b,c), count(c) from t", + "select count(distinct a), count(distinct b), count(*) from t", + "explain select count(distinct a), count(distinct b), count(*) from t", + "select count(distinct a), count(distinct b), avg(c+d) from t", + "explain select count(distinct a), count(distinct b), avg(c+d) from t" ] } ] diff --git a/planner/core/casetest/testdata/enforce_mpp_suite_out.json b/planner/core/casetest/testdata/enforce_mpp_suite_out.json index ebd9545578f3c..54a89e25c8b70 100644 --- a/planner/core/casetest/testdata/enforce_mpp_suite_out.json +++ b/planner/core/casetest/testdata/enforce_mpp_suite_out.json @@ -1234,19 +1234,18 @@ { "SQL": "EXPLAIN select count(distinct a), count(distinct b) from t", "Plan": [ - "TableReader_38 1.00 root data:ExchangeSender_37", - "└─ExchangeSender_37 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─Projection_27 1.00 mpp[tiflash] Column#6, Column#7", - " └─HashAgg_28 1.00 mpp[tiflash] funcs:sum(Column#14)->Column#6, funcs:sum(Column#16)->Column#7", - " └─ExchangeReceiver_36 1.00 mpp[tiflash] ", - " └─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─HashAgg_30 1.00 mpp[tiflash] funcs:count(distinct Column#15)->Column#14, funcs:count(distinct Column#17)->Column#16", - " └─Projection_32 16000.00 mpp[tiflash] test.t.a, test.t.b, Column#13, case(eq(Column#13, 1), test.t.a, )->Column#15, case(eq(Column#13, 2), test.t.b, )->Column#17", - " └─ExchangeReceiver_34 16000.00 mpp[tiflash] ", - " └─ExchangeSender_33 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#13, collate: binary]", - " └─HashAgg_26 16000.00 mpp[tiflash] group by:Column#13, test.t.a, test.t.b, ", - " └─Expand_29 20000.00 mpp[tiflash] group set num:2, groupingID:Column#13, [{},{}]", - " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + "TableReader_36 1.00 root data:ExchangeSender_35", + "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#12)->Column#6, funcs:sum(Column#13)->Column#7", + " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", + " └─ExchangeSender_33 
1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#12, funcs:count(distinct test.t.b)->Column#13", + " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#11, collate: binary]", + " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#11, test.t.a, test.t.b, ", + " └─Expand_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#11, [{},{}]", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -1260,20 +1259,19 @@ { "SQL": "EXPLAIN select count(distinct a), count(distinct b), count(c) from t", "Plan": [ - "TableReader_38 1.00 root data:ExchangeSender_37", - "└─ExchangeSender_37 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─Projection_27 1.00 mpp[tiflash] Column#6, Column#7, Column#8", - " └─HashAgg_28 1.00 mpp[tiflash] funcs:sum(Column#20)->Column#6, funcs:sum(Column#22)->Column#7, funcs:sum(Column#24)->Column#8", - " └─ExchangeReceiver_36 1.00 mpp[tiflash] ", - " └─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─HashAgg_30 1.00 mpp[tiflash] funcs:count(distinct Column#21)->Column#20, funcs:count(distinct Column#23)->Column#22, funcs:sum(Column#17)->Column#24", - " └─Projection_32 16000.00 mpp[tiflash] Column#17, test.t.a, test.t.b, Column#18, case(eq(Column#18, 1), test.t.a, )->Column#21, case(eq(Column#18, 2), test.t.b, )->Column#23", - " └─ExchangeReceiver_34 16000.00 mpp[tiflash] ", - " └─ExchangeSender_33 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#18, collate: binary]", - " └─HashAgg_26 16000.00 mpp[tiflash] group by:Column#18, test.t.a, test.t.b, funcs:count(Column#19)->Column#17", - " └─Projection_31 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#18, case(eq(Column#18, 1), test.t.c, )->Column#19", - " └─Expand_29 20000.00 mpp[tiflash] group set num:2, groupingID:Column#18, [{},{}]", - " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + "TableReader_36 1.00 root data:ExchangeSender_35", + "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#18)->Column#6, funcs:sum(Column#19)->Column#7, funcs:sum(Column#20)->Column#8", + " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", + " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#18, funcs:count(distinct test.t.b)->Column#19, funcs:sum(Column#15)->Column#20", + " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#16, collate: binary]", + " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#16, test.t.a, test.t.b, funcs:count(Column#17)->Column#15", + " └─Projection_30 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#16, case(eq(Column#16, 1), test.t.c, )->Column#17", + " └─Expand_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#16, [{},{}]", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -1284,24 +1282,49 @@ ], "Warn": null 
}, + { + "SQL": "EXPLAIN select count(distinct a), count(distinct b), count(c+1) from t", + "Plan": [ + "TableReader_36 1.00 root data:ExchangeSender_35", + "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#18)->Column#6, funcs:sum(Column#19)->Column#7, funcs:sum(Column#20)->Column#8", + " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", + " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#18, funcs:count(distinct test.t.b)->Column#19, funcs:sum(Column#15)->Column#20", + " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#16, collate: binary]", + " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#16, test.t.a, test.t.b, funcs:count(Column#17)->Column#15", + " └─Projection_30 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#16, case(eq(Column#16, 1), plus(test.t.c, 1), )->Column#17", + " └─Expand_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#16, [{},{}]", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "select count(distinct a), count(distinct b), count(c+1) from t", + "Plan": [ + "5 5 10" + ], + "Warn": null + }, { "SQL": "EXPLAIN select count(distinct a), count(distinct b), sum(c) from t", "Plan": [ - "TableReader_38 1.00 root data:ExchangeSender_37", - "└─ExchangeSender_37 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─Projection_27 1.00 mpp[tiflash] Column#6, Column#7, Column#8", - " └─HashAgg_28 1.00 mpp[tiflash] funcs:sum(Column#20)->Column#6, funcs:sum(Column#22)->Column#7, funcs:sum(Column#24)->Column#8", - " └─ExchangeReceiver_36 1.00 mpp[tiflash] ", - " └─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─HashAgg_30 1.00 mpp[tiflash] funcs:count(distinct Column#21)->Column#20, funcs:count(distinct Column#23)->Column#22, funcs:sum(Column#17)->Column#24", - " └─Projection_32 16000.00 mpp[tiflash] Column#17, test.t.a, test.t.b, Column#18, case(eq(Column#18, 1), test.t.a, )->Column#21, case(eq(Column#18, 2), test.t.b, )->Column#23", - " └─ExchangeReceiver_34 16000.00 mpp[tiflash] ", - " └─ExchangeSender_33 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#18, collate: binary]", - " └─HashAgg_26 16000.00 mpp[tiflash] group by:Column#26, Column#27, Column#28, funcs:sum(Column#25)->Column#17", - " └─Projection_39 20000.00 mpp[tiflash] cast(Column#19, decimal(10,0) BINARY)->Column#25, test.t.a, test.t.b, Column#18", - " └─Projection_31 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#18, case(eq(Column#18, 1), test.t.c, )->Column#19", - " └─Expand_29 20000.00 mpp[tiflash] group set num:2, groupingID:Column#18, [{},{}]", - " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + "TableReader_36 1.00 root data:ExchangeSender_35", + "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#18)->Column#6, funcs:sum(Column#19)->Column#7, funcs:sum(Column#20)->Column#8", + " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", + " 
└─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#18, funcs:count(distinct test.t.b)->Column#19, funcs:sum(Column#15)->Column#20", + " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#16, collate: binary]", + " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#22, Column#23, Column#24, funcs:sum(Column#21)->Column#15", + " └─Projection_37 20000.00 mpp[tiflash] cast(Column#17, decimal(10,0) BINARY)->Column#21, test.t.a, test.t.b, Column#16", + " └─Projection_30 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#16, case(eq(Column#16, 1), test.t.c, )->Column#17", + " └─Expand_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#16, [{},{}]", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -1397,6 +1420,59 @@ " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null + }, + { + "SQL": "select count(distinct a), count(distinct b), count(*) from t", + "Plan": [ + "5 5 10" + ], + "Warn": null + }, + { + "SQL": "explain select count(distinct a), count(distinct b), count(*) from t", + "Plan": [ + "TableReader_36 1.00 root data:ExchangeSender_35", + "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#18)->Column#6, funcs:sum(Column#19)->Column#7, funcs:sum(Column#20)->Column#8", + " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", + " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#18, funcs:count(distinct test.t.b)->Column#19, funcs:sum(Column#15)->Column#20", + " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#16, collate: binary]", + " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#16, test.t.a, test.t.b, funcs:count(Column#17)->Column#15", + " └─Projection_30 20000.00 mpp[tiflash] test.t.a, test.t.b, Column#16, case(eq(Column#16, 1), 1, )->Column#17", + " └─Expand_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#16, [{},{}]", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "select count(distinct a), count(distinct b), avg(c+d) from t", + "Plan": [ + "5 5 3001.0000" + ], + "Warn": null + }, + { + "SQL": "explain select count(distinct a), count(distinct b), avg(c+d) from t", + "Plan": [ + "TableReader_36 1.00 root data:ExchangeSender_35", + "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, div(Column#8, cast(case(eq(Column#19, 0), 1, Column#19), decimal(20,0) BINARY))->Column#8", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#25)->Column#6, funcs:sum(Column#26)->Column#7, funcs:sum(Column#27)->Column#19, funcs:sum(Column#28)->Column#8", + " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", + " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#25, funcs:count(distinct test.t.b)->Column#26, 
funcs:sum(Column#20)->Column#27, funcs:sum(Column#21)->Column#28", + " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#22, collate: binary]", + " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#31, Column#32, Column#33, funcs:count(Column#29)->Column#20, funcs:sum(Column#30)->Column#21", + " └─Projection_37 20000.00 mpp[tiflash] Column#23, cast(Column#24, decimal(20,0) BINARY)->Column#30, test.t.a, test.t.b, Column#22", + " └─Projection_30 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, test.t.d, Column#22, case(eq(Column#22, 1), plus(test.t.c, test.t.d), )->Column#23, case(eq(Column#22, 1), plus(test.t.c, test.t.d), )->Column#24", + " └─Expand_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#22, [{},{}]", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null } ] } diff --git a/planner/core/task.go b/planner/core/task.go index c0a1868df760a..b999721adf49f 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -2134,7 +2134,7 @@ func (p *PhysicalHashAgg) attach2TaskForMpp1Phase(mpp *mppTask) task { return mpp } -// scaleStats4GroupingSets scale the derived stats because the lower source has been repeated. +// scaleStats4GroupingSets scale the derived stats because the lower source has been expanded. // // parent OP <- logicalAgg <- children OP (derived stats) // | @@ -2142,9 +2142,9 @@ func (p *PhysicalHashAgg) attach2TaskForMpp1Phase(mpp *mppTask) task { // parent OP <- physicalAgg <- children OP (stats used) // | // +----------+----------+----------+ -// Final Mid Partial Repeat +// Final Mid Partial Expand // -// physical agg stats is reasonable from the whole, because repeat operator is designed to facilitate +// physical agg stats is reasonable from the whole, because expand operator is designed to facilitate // the Mid and Partial Agg, which means when leaving the Final, its output rowcount could be exactly // the same as what it derived(estimated) before entering physical optimization phase. // @@ -2250,19 +2250,19 @@ func (p *PhysicalHashAgg) scaleStats4GroupingSets(groupingSets expression.Groupi // +- Expand {}, {} -> expand // +- TableScan foo func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan, canUse3StageAgg bool, - groupingSets expression.GroupingSets, mpp *mppTask) (final, mid, proj1, part, proj2 PhysicalPlan, _ error) { + groupingSets expression.GroupingSets, mpp *mppTask) (final, mid, part, proj4Part PhysicalPlan, _ error) { if !(partialAgg != nil && canUse3StageAgg) { - // quick path: return the original finalAgg and partiAgg - return finalAgg, nil, nil, partialAgg, nil, nil + // quick path: return the original finalAgg and partiAgg. + return finalAgg, nil, partialAgg, nil, nil } if len(groupingSets) == 0 { // single distinct agg mode. clonedAgg, err := finalAgg.Clone() if err != nil { - return nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, err } - // step1: adjust middle agg + // step1: adjust middle agg. middleHashAgg := clonedAgg.(*PhysicalHashAgg) distinctPos := 0 middleSchema := expression.NewSchema() @@ -2285,7 +2285,7 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan } middleHashAgg.schema = middleSchema - // step2: adjust final agg + // step2: adjust final agg. 
finalHashAgg := finalAgg.(*PhysicalHashAgg) finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs)) for i, fun := range finalHashAgg.AggFuncs { @@ -2299,7 +2299,7 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan for _, arg := range fun.Args { newCol, err := arg.RemapColumn(schemaMap) if err != nil { - return nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, err } newArgs = append(newArgs, newCol) } @@ -2309,15 +2309,15 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan finalAggDescs = append(finalAggDescs, fun) } finalHashAgg.AggFuncs = finalAggDescs - // partialAgg is im-mutated from args - return finalHashAgg, middleHashAgg, nil, partialAgg, nil, nil + // partialAgg is im-mutated from args. + return finalHashAgg, middleHashAgg, partialAgg, nil, nil } // multi distinct agg mode, having grouping sets. // set the default expression to constant 1 for the convenience to choose default group set data. var groupingIDCol expression.Expression // enforce Expand operator above the children. // physical plan is enumerated without children from itself, use mpp subtree instead p.children. - // scale(len(groupingSets)) will change the NDV, while Repeat doesn't change the NDV and groupNDV. + // scale(len(groupingSets)) will change the NDV, while Expand doesn't change the NDV and groupNDV. stats := mpp.p.statsInfo().Scale(float64(1)) stats.RowCount = stats.RowCount * float64(len(groupingSets)) physicalExpand := PhysicalExpand{ @@ -2340,41 +2340,16 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan // having group sets clonedAgg, err := finalAgg.Clone() if err != nil { - return nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, err } cloneHashAgg := clonedAgg.(*PhysicalHashAgg) // Clone(), it will share same base-plan elements from the finalAgg, including id,tp,stats. Make a new one here. cloneHashAgg.basePlan = newBasePlan(cloneHashAgg.ctx, cloneHashAgg.tp, cloneHashAgg.blockOffset) cloneHashAgg.stats = finalAgg.Stats() // reuse the final agg stats here. - // select count(distinct a), count(distinct b), count(c) from t - // final agg : sum(#1), sum(#2), sum(#3) - // | | +--------------------------------------------+ - // | +--------------------------------+ | - // +-------------+ | | - // v v v - // middle agg : count(distinct caseWhen4a) as #1, count(distinct caseWhen4b) as #2, sum(#4) as #3 # all partial result - // | | - // +---------------------+ | - // v v - // proj4Middle : a, b, groupingID, #4(partial res), caseWhen4a, caseWhen4b # caseWhen4a & caseWhen4a depend on groupingID and a, b - // | - // | - // exchange hash partition: # shuffle data by - // | - // v - // partial agg : a, b, groupingID, count(caseWhen4c) as #4 # since group by a, b, groupingID we can directly output all this three. - // | - // v - // proj4Partial: a, b, c, groupingID, caseWhen4c # caseWhen4c depend on groupingID and c - // - // expand: a, b, c, groupingID # appended groupingID to diff group set - // - // lower source: a, b, c # don't care what the lower source is - // step1: adjust partial agg, for normal agg here, adjust it to target for specified group data. // Since we may substitute the first arg of normal agg with case-when expression here, append a - // customized proj here rather than depend on postOptimize to insert a blunt one for us. + // customized proj here rather than depending on postOptimize to insert a blunt one for us. 
// // proj4Partial output all the base col from lower op + caseWhen proj cols. proj4Partial := new(PhysicalProjection).Init(p.ctx, mpp.p.statsInfo(), mpp.p.SelectBlockOffset()) @@ -2391,6 +2366,7 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan for _, fun := range partialHashAgg.AggFuncs { if !fun.HasDistinct { // for normal agg phase1, we should also modify them to target for specified group data. + // Expr = (case when groupingID = targeted_groupingID then arg else null end) eqExpr := expression.NewFunctionInternal(p.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), groupingIDCol, expression.NewUInt64Const(fun.GroupingID)) caseWhen := expression.NewFunctionInternal(p.ctx, ast.Case, fun.Args[0].GetType(), eqExpr, fun.Args[0], expression.NewNull()) caseWhenProjCol := &expression.Column{ @@ -2404,37 +2380,18 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan } // step2: adjust middle agg - // proj4Middle output all the base col from lower op + caseWhen proj cols.(reuse partial agg stats) - proj4Middle := new(PhysicalProjection).Init(p.ctx, partialHashAgg.stats, partialHashAgg.SelectBlockOffset()) // middleHashAgg shared the same stats with the final agg does. middleHashAgg := cloneHashAgg middleSchema := expression.NewSchema() schemaMap := make(map[int64]*expression.Column, len(middleHashAgg.AggFuncs)) - for _, col := range partialHashAgg.Schema().Columns { - proj4Middle.Exprs = append(proj4Middle.Exprs, col) - } - proj4Middle.SetSchema(partialHashAgg.Schema().Clone()) for _, fun := range middleHashAgg.AggFuncs { col := &expression.Column{ UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), RetType: fun.RetTp, } if fun.HasDistinct { + // let count distinct agg aggregate on whole-scope data rather using case-when expr to target on specified group. (agg null strict attribute) fun.Mode = aggregation.Partial1Mode - // change the middle distinct agg args to target for specified group data. - // even for distinct agg multi args like count(distinct a,b), either null of a,b will cause count func to skip this row. - // so we only need to set the either one arg of them to be a case when expression to target for its self-group: - // Expr = (case when groupingID = ? 
then arg else null end) - // count(distinct a,b) => count(distinct (case when groupingID = targetID then a else null end), b) - eqExpr := expression.NewFunctionInternal(p.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), groupingIDCol, expression.NewUInt64Const(fun.GroupingID)) - caseWhen := expression.NewFunctionInternal(p.ctx, ast.Case, fun.Args[0].GetType(), eqExpr, fun.Args[0], expression.NewNull()) - caseWhenProjCol := &expression.Column{ - UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: fun.Args[0].GetType(), - } - proj4Middle.Exprs = append(proj4Middle.Exprs, caseWhen) - proj4Middle.Schema().Append(caseWhenProjCol) - fun.Args[0] = caseWhenProjCol } else { fun.Mode = aggregation.Partial2Mode originalCol := fun.Args[0].(*expression.Column) @@ -2462,7 +2419,7 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan for _, arg := range fun.Args { newCol, err := arg.RemapColumn(schemaMap) if err != nil { - return nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, err } newArgs = append(newArgs, newCol) } @@ -2473,7 +2430,7 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan finalAggDescs = append(finalAggDescs, fun) } finalHashAgg.AggFuncs = finalAggDescs - return finalHashAgg, middleHashAgg, proj4Middle, partialHashAgg, proj4Partial, nil + return finalHashAgg, middleHashAgg, partialHashAgg, proj4Partial, nil } func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task { @@ -2549,7 +2506,7 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task { return invalidTask } - final, middle, proj4Middle, partial, proj4Partial, err := p.adjust3StagePhaseAgg(partialAgg, finalAgg, canUse3StageAgg, groupingSets, mpp) + final, middle, partial, proj4Partial, err := p.adjust3StagePhaseAgg(partialAgg, finalAgg, canUse3StageAgg, groupingSets, mpp) if err != nil { return invalidTask } @@ -2580,9 +2537,6 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task { exProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: partitionCols} newMpp := mpp.enforceExchanger(exProp) - if proj4Middle != nil { - attachPlan2Task(proj4Middle, newMpp) - } attachPlan2Task(middle, newMpp) mpp = newMpp } From bd49460421719221785f2b8804bff709dbfc0951 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Fri, 6 Jan 2023 17:09:19 +0800 Subject: [PATCH 21/25] fix test Signed-off-by: AilinKid <314806019@qq.com> --- .../casetest/testdata/plan_suite_out.json | 53 +++++++++---------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/planner/core/casetest/testdata/plan_suite_out.json b/planner/core/casetest/testdata/plan_suite_out.json index 91669f2dc17d9..8d0e7449613ac 100644 --- a/planner/core/casetest/testdata/plan_suite_out.json +++ b/planner/core/casetest/testdata/plan_suite_out.json @@ -4280,22 +4280,21 @@ "TableReader 1.00 root MppVersion: 1, data:ExchangeSender", "└─ExchangeSender 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection 1.00 mpp[tiflash] Column#7", - " └─HashAgg 1.00 mpp[tiflash] group by:Column#22, funcs:sum(Column#23)->Column#7", + " └─HashAgg 1.00 mpp[tiflash] group by:Column#18, funcs:sum(Column#19)->Column#7", " └─ExchangeReceiver 1.00 mpp[tiflash] ", - " └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: Column#22, collate: binary]", - " └─HashAgg 1.00 mpp[tiflash] group by:Column#25, funcs:count(Column#24)->Column#23", - " 
└─Projection 1.00 mpp[tiflash] Column#5, plus(Column#6, 1)->Column#25", + " └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: Column#18, collate: binary]", + " └─HashAgg 1.00 mpp[tiflash] group by:Column#21, funcs:count(Column#20)->Column#19", + " └─Projection 1.00 mpp[tiflash] Column#5, plus(Column#6, 1)->Column#21", " └─Projection 1.00 mpp[tiflash] Column#5, Column#6", - " └─HashAgg 1.00 mpp[tiflash] funcs:sum(Column#14)->Column#5, funcs:sum(Column#16)->Column#6", + " └─HashAgg 1.00 mpp[tiflash] funcs:sum(Column#12)->Column#5, funcs:sum(Column#13)->Column#6", " └─ExchangeReceiver 1.00 mpp[tiflash] ", " └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct Column#15)->Column#14, funcs:count(distinct Column#17)->Column#16", - " └─Projection 16000.00 mpp[tiflash] test.employee.deptid, test.employee.empid, Column#13, case(eq(Column#13, 1), test.employee.deptid, )->Column#15, case(eq(Column#13, 2), test.employee.empid, )->Column#17", - " └─ExchangeReceiver 16000.00 mpp[tiflash] ", - " └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.employee.deptid, collate: binary], [name: test.employee.empid, collate: binary], [name: Column#13, collate: binary]", - " └─HashAgg 16000.00 mpp[tiflash] group by:Column#13, test.employee.deptid, test.employee.empid, ", - " └─Expand 20000.00 mpp[tiflash] group set num:2, groupingID:Column#13, [{},{}]", - " └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo" + " └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct test.employee.deptid)->Column#12, funcs:count(distinct test.employee.empid)->Column#13", + " └─ExchangeReceiver 16000.00 mpp[tiflash] ", + " └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.employee.deptid, collate: binary], [name: test.employee.empid, collate: binary], [name: Column#11, collate: binary]", + " └─HashAgg 16000.00 mpp[tiflash] group by:Column#11, test.employee.deptid, test.employee.empid, ", + " └─Expand 20000.00 mpp[tiflash] group set num:2, groupingID:Column#11, [{},{}]", + " └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo" ] }, { @@ -4306,16 +4305,15 @@ " └─Projection 1.00 mpp[tiflash] Column#7", " └─HashAgg 1.00 mpp[tiflash] group by:Column#6, funcs:count(Column#5)->Column#7", " └─Projection 1.00 mpp[tiflash] Column#5, Column#6", - " └─HashAgg 1.00 mpp[tiflash] funcs:sum(Column#14)->Column#5, funcs:sum(Column#16)->Column#6", + " └─HashAgg 1.00 mpp[tiflash] funcs:sum(Column#12)->Column#5, funcs:sum(Column#13)->Column#6", " └─ExchangeReceiver 1.00 mpp[tiflash] ", " └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct Column#15)->Column#14, funcs:count(distinct Column#17)->Column#16", - " └─Projection 16000.00 mpp[tiflash] test.employee.deptid, test.employee.empid, Column#13, case(eq(Column#13, 1), test.employee.deptid, )->Column#15, case(eq(Column#13, 2), test.employee.empid, )->Column#17", - " └─ExchangeReceiver 16000.00 mpp[tiflash] ", - " └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.employee.deptid, collate: binary], [name: test.employee.empid, collate: binary], [name: Column#13, collate: binary]", - " └─HashAgg 16000.00 mpp[tiflash] group by:Column#13, test.employee.deptid, test.employee.empid, ", - " └─Expand 
20000.00 mpp[tiflash] group set num:2, groupingID:Column#13, [{},{}]", - " └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo" + " └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct test.employee.deptid)->Column#12, funcs:count(distinct test.employee.empid)->Column#13", + " └─ExchangeReceiver 16000.00 mpp[tiflash] ", + " └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.employee.deptid, collate: binary], [name: test.employee.empid, collate: binary], [name: Column#11, collate: binary]", + " └─HashAgg 16000.00 mpp[tiflash] group by:Column#11, test.employee.deptid, test.employee.empid, ", + " └─Expand 20000.00 mpp[tiflash] group set num:2, groupingID:Column#11, [{},{}]", + " └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo" ] }, { @@ -4327,16 +4325,15 @@ " ├─ExchangeReceiver(Build) 1.00 mpp[tiflash] ", " │ └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: Broadcast", " │ └─Projection 1.00 mpp[tiflash] Column#9, Column#10", - " │ └─HashAgg 1.00 mpp[tiflash] funcs:sum(Column#12)->Column#9, funcs:sum(Column#14)->Column#10", + " │ └─HashAgg 1.00 mpp[tiflash] funcs:sum(Column#12)->Column#9, funcs:sum(Column#13)->Column#10", " │ └─ExchangeReceiver 1.00 mpp[tiflash] ", " │ └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: PassThrough", - " │ └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct Column#13)->Column#12, funcs:count(distinct Column#15)->Column#14", - " │ └─Projection 16000.00 mpp[tiflash] test.employee.deptid, test.employee.empid, Column#11, case(eq(Column#11, 1), test.employee.deptid, )->Column#13, case(eq(Column#11, 2), test.employee.empid, )->Column#15", - " │ └─ExchangeReceiver 16000.00 mpp[tiflash] ", - " │ └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.employee.deptid, collate: binary], [name: test.employee.empid, collate: binary], [name: Column#11, collate: binary]", - " │ └─HashAgg 16000.00 mpp[tiflash] group by:Column#11, test.employee.deptid, test.employee.empid, ", - " │ └─Expand 20000.00 mpp[tiflash] group set num:2, groupingID:Column#11, [{},{}]", - " │ └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo", + " │ └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct test.employee.deptid)->Column#12, funcs:count(distinct test.employee.empid)->Column#13", + " │ └─ExchangeReceiver 16000.00 mpp[tiflash] ", + " │ └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.employee.deptid, collate: binary], [name: test.employee.empid, collate: binary], [name: Column#11, collate: binary]", + " │ └─HashAgg 16000.00 mpp[tiflash] group by:Column#11, test.employee.deptid, test.employee.empid, ", + " │ └─Expand 20000.00 mpp[tiflash] group set num:2, groupingID:Column#11, [{},{}]", + " │ └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo", " └─TableFullScan(Probe) 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo" ] }, From 98dd3281746deb3d8f4cf0ec96aa24a075dff2a7 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Fri, 6 Jan 2023 17:37:22 +0800 Subject: [PATCH 22/25] make fmt Signed-off-by: AilinKid <314806019@qq.com> --- expression/grouping_sets_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/expression/grouping_sets_test.go b/expression/grouping_sets_test.go index d62cf50e5e6ff..e4f636bb9845d 100644 --- a/expression/grouping_sets_test.go 
+++ b/expression/grouping_sets_test.go @@ -15,11 +15,11 @@ package expression import ( + "testing" + "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/types" - "testing" - "github.com/stretchr/testify/require" "go.opencensus.io/stats/view" ) From 7fd676673637ece3a2d66018c121748e6271c4e5 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Tue, 7 Feb 2023 18:18:15 +0800 Subject: [PATCH 23/25] fix test Signed-off-by: AilinKid <314806019@qq.com> --- .../testdata/enforce_mpp_suite_out.json | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/planner/core/casetest/testdata/enforce_mpp_suite_out.json b/planner/core/casetest/testdata/enforce_mpp_suite_out.json index 54a89e25c8b70..5b4e97f542b8e 100644 --- a/planner/core/casetest/testdata/enforce_mpp_suite_out.json +++ b/planner/core/casetest/testdata/enforce_mpp_suite_out.json @@ -1210,7 +1210,7 @@ { "SQL": "EXPLAIN select count(distinct a) from t", "Plan": [ - "TableReader_30 1.00 root data:ExchangeSender_29", + "TableReader_30 1.00 root MppVersion: 1, data:ExchangeSender_29", "└─ExchangeSender_29 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection_23 1.00 mpp[tiflash] Column#6", " └─HashAgg_24 1.00 mpp[tiflash] funcs:sum(Column#8)->Column#6", @@ -1218,7 +1218,7 @@ " └─ExchangeSender_27 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─HashAgg_24 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#8", " └─ExchangeReceiver_26 1.00 mpp[tiflash] ", - " └─ExchangeSender_25 1.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary]", + " └─ExchangeSender_25 1.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary]", " └─HashAgg_22 1.00 mpp[tiflash] group by:test.t.a, ", " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], @@ -1234,7 +1234,7 @@ { "SQL": "EXPLAIN select count(distinct a), count(distinct b) from t", "Plan": [ - "TableReader_36 1.00 root data:ExchangeSender_35", + "TableReader_36 1.00 root MppVersion: 1, data:ExchangeSender_35", "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7", " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#12)->Column#6, funcs:sum(Column#13)->Column#7", @@ -1242,7 +1242,7 @@ " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#12, funcs:count(distinct test.t.b)->Column#13", " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", - " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#11, collate: binary]", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#11, collate: binary]", " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#11, test.t.a, test.t.b, ", " └─Expand_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#11, [{},{}]", " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" @@ -1259,7 +1259,7 @@ { "SQL": "EXPLAIN select count(distinct a), count(distinct b), count(c) from t", "Plan": [ - "TableReader_36 1.00 root data:ExchangeSender_35", + "TableReader_36 1.00 root MppVersion: 1, data:ExchangeSender_35", "└─ExchangeSender_35 1.00 mpp[tiflash] 
ExchangeType: PassThrough", " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8", " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#18)->Column#6, funcs:sum(Column#19)->Column#7, funcs:sum(Column#20)->Column#8", @@ -1267,7 +1267,7 @@ " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#18, funcs:count(distinct test.t.b)->Column#19, funcs:sum(Column#15)->Column#20", " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", - " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#16, collate: binary]", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#16, collate: binary]", " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#16, test.t.a, test.t.b, funcs:count(Column#17)->Column#15", " └─Projection_30 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#16, case(eq(Column#16, 1), test.t.c, )->Column#17", " └─Expand_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#16, [{},{}]", @@ -1285,7 +1285,7 @@ { "SQL": "EXPLAIN select count(distinct a), count(distinct b), count(c+1) from t", "Plan": [ - "TableReader_36 1.00 root data:ExchangeSender_35", + "TableReader_36 1.00 root MppVersion: 1, data:ExchangeSender_35", "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8", " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#18)->Column#6, funcs:sum(Column#19)->Column#7, funcs:sum(Column#20)->Column#8", @@ -1293,7 +1293,7 @@ " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#18, funcs:count(distinct test.t.b)->Column#19, funcs:sum(Column#15)->Column#20", " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", - " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#16, collate: binary]", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#16, collate: binary]", " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#16, test.t.a, test.t.b, funcs:count(Column#17)->Column#15", " └─Projection_30 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#16, case(eq(Column#16, 1), plus(test.t.c, 1), )->Column#17", " └─Expand_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#16, [{},{}]", @@ -1311,7 +1311,7 @@ { "SQL": "EXPLAIN select count(distinct a), count(distinct b), sum(c) from t", "Plan": [ - "TableReader_36 1.00 root data:ExchangeSender_35", + "TableReader_36 1.00 root MppVersion: 1, data:ExchangeSender_35", "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8", " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#18)->Column#6, funcs:sum(Column#19)->Column#7, funcs:sum(Column#20)->Column#8", @@ -1319,7 +1319,7 @@ " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#18, funcs:count(distinct test.t.b)->Column#19, funcs:sum(Column#15)->Column#20", " └─ExchangeReceiver_32 16000.00 
mpp[tiflash] ", - " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#16, collate: binary]", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#16, collate: binary]", " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#22, Column#23, Column#24, funcs:sum(Column#21)->Column#15", " └─Projection_37 20000.00 mpp[tiflash] cast(Column#17, decimal(10,0) BINARY)->Column#21, test.t.a, test.t.b, Column#16", " └─Projection_30 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#16, case(eq(Column#16, 1), test.t.c, )->Column#17", @@ -1338,7 +1338,7 @@ { "SQL": "EXPLAIN select count(distinct a, b), count(distinct b), count(c), sum(d) from t", "Plan": [ - "TableReader_26 1.00 root data:ExchangeSender_25", + "TableReader_26 1.00 root MppVersion: 1, data:ExchangeSender_25", "└─ExchangeSender_25 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection_21 1.00 mpp[tiflash] Column#6, Column#7, Column#8, Column#9", " └─HashAgg_22 1.00 mpp[tiflash] funcs:count(distinct test.t.a, test.t.b)->Column#6, funcs:count(distinct test.t.b)->Column#7, funcs:sum(Column#12)->Column#8, funcs:sum(Column#13)->Column#9", @@ -1366,7 +1366,7 @@ { "SQL": "EXPLAIN select count(distinct a+b), sum(c) from t", "Plan": [ - "TableReader_26 1.00 root data:ExchangeSender_25", + "TableReader_26 1.00 root MppVersion: 1, data:ExchangeSender_25", "└─ExchangeSender_25 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection_21 1.00 mpp[tiflash] Column#6, Column#7", " └─HashAgg_22 1.00 mpp[tiflash] funcs:count(distinct Column#10)->Column#6, funcs:sum(Column#11)->Column#7", @@ -1388,7 +1388,7 @@ { "SQL": "EXPLAIN select count(distinct a+b), count(distinct b+c), count(c) from t", "Plan": [ - "TableReader_26 1.00 root data:ExchangeSender_25", + "TableReader_26 1.00 root MppVersion: 1, data:ExchangeSender_25", "└─ExchangeSender_25 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection_21 1.00 mpp[tiflash] Column#6, Column#7, Column#8", " └─HashAgg_22 1.00 mpp[tiflash] funcs:count(distinct Column#12)->Column#6, funcs:count(distinct Column#13)->Column#7, funcs:sum(Column#14)->Column#8", @@ -1410,7 +1410,7 @@ { "SQL": "explain select count(distinct a,c), count(distinct b,c), count(c) from t", "Plan": [ - "TableReader_26 1.00 root data:ExchangeSender_25", + "TableReader_26 1.00 root MppVersion: 1, data:ExchangeSender_25", "└─ExchangeSender_25 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection_21 1.00 mpp[tiflash] Column#6, Column#7, Column#8", " └─HashAgg_22 1.00 mpp[tiflash] funcs:count(distinct test.t.a, test.t.c)->Column#6, funcs:count(distinct test.t.b, test.t.c)->Column#7, funcs:sum(Column#10)->Column#8", @@ -1431,7 +1431,7 @@ { "SQL": "explain select count(distinct a), count(distinct b), count(*) from t", "Plan": [ - "TableReader_36 1.00 root data:ExchangeSender_35", + "TableReader_36 1.00 root MppVersion: 1, data:ExchangeSender_35", "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8", " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#18)->Column#6, funcs:sum(Column#19)->Column#7, funcs:sum(Column#20)->Column#8", @@ -1439,7 +1439,7 @@ " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#18, 
funcs:count(distinct test.t.b)->Column#19, funcs:sum(Column#15)->Column#20", " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", - " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#16, collate: binary]", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#16, collate: binary]", " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#16, test.t.a, test.t.b, funcs:count(Column#17)->Column#15", " └─Projection_30 20000.00 mpp[tiflash] test.t.a, test.t.b, Column#16, case(eq(Column#16, 1), 1, )->Column#17", " └─Expand_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#16, [{},{}]", @@ -1457,7 +1457,7 @@ { "SQL": "explain select count(distinct a), count(distinct b), avg(c+d) from t", "Plan": [ - "TableReader_36 1.00 root data:ExchangeSender_35", + "TableReader_36 1.00 root MppVersion: 1, data:ExchangeSender_35", "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, div(Column#8, cast(case(eq(Column#19, 0), 1, Column#19), decimal(20,0) BINARY))->Column#8", " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#25)->Column#6, funcs:sum(Column#26)->Column#7, funcs:sum(Column#27)->Column#19, funcs:sum(Column#28)->Column#8", @@ -1465,7 +1465,7 @@ " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#25, funcs:count(distinct test.t.b)->Column#26, funcs:sum(Column#20)->Column#27, funcs:sum(Column#21)->Column#28", " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", - " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#22, collate: binary]", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#22, collate: binary]", " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#31, Column#32, Column#33, funcs:count(Column#29)->Column#20, funcs:sum(Column#30)->Column#21", " └─Projection_37 20000.00 mpp[tiflash] Column#23, cast(Column#24, decimal(20,0) BINARY)->Column#30, test.t.a, test.t.b, Column#22", " └─Projection_30 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, test.t.d, Column#22, case(eq(Column#22, 1), plus(test.t.c, test.t.d), )->Column#23, case(eq(Column#22, 1), plus(test.t.c, test.t.d), )->Column#24", From b353681098fcdbfc1d2c4abca23950057196ebd0 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Mon, 20 Feb 2023 17:26:59 +0800 Subject: [PATCH 24/25] add a switch to avoid plan generation fail before tiflash side merged Signed-off-by: AilinKid <314806019@qq.com> --- planner/core/casetest/enforce_mpp_test.go | 7 +++++-- planner/core/task.go | 2 +- sessionctx/variable/session.go | 3 +++ sessionctx/variable/sysvar.go | 4 ++++ sessionctx/variable/tidb_vars.go | 4 ++++ 5 files changed, 17 insertions(+), 3 deletions(-) diff --git a/planner/core/casetest/enforce_mpp_test.go b/planner/core/casetest/enforce_mpp_test.go index 043256b25c5dd..10c49e9091726 100644 --- a/planner/core/casetest/enforce_mpp_test.go +++ b/planner/core/casetest/enforce_mpp_test.go @@ -20,8 +20,10 @@ import ( "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/parser/model" 
+ "github.com/pingcap/tidb/planner/core/internal" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/testkit/external" "github.com/pingcap/tidb/testkit/testdata" "github.com/pingcap/tidb/util/collate" "github.com/stretchr/testify/require" @@ -492,7 +494,7 @@ func TestMPPSingleDistinct3Stage(t *testing.T) { // // since it doesn't change the schema out (index ref is still the right), so by now it's fine. SEE case: EXPLAIN select count(distinct a), count(distinct b), sum(c) from t. func TestMPPMultiDistinct3Stage(t *testing.T) { - store := testkit.CreateMockStore(t, withMockTiFlash(2)) + store := testkit.CreateMockStore(t, internal.WithMockTiFlash(2)) tk := testkit.NewTestKit(t, store) // test table @@ -503,6 +505,7 @@ func TestMPPMultiDistinct3Stage(t *testing.T) { tb := external.GetTableByName(t, tk, "test", "t") err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true) require.NoError(t, err) + tk.MustExec("set @@session.tidb_opt_enable_three_stage_multi_distinct_agg=1") tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\";") tk.MustExec("set @@session.tidb_enforce_mpp=1") tk.MustExec("set @@session.tidb_allow_mpp=ON;") @@ -525,7 +528,7 @@ func TestMPPMultiDistinct3Stage(t *testing.T) { Plan []string Warn []string } - enforceMPPSuiteData := plannercore.GetEnforceMPPSuiteData() + enforceMPPSuiteData := GetEnforceMPPSuiteData() enforceMPPSuiteData.LoadTestCases(t, &input, &output) for i, tt := range input { testdata.OnRecord(func() { diff --git a/planner/core/task.go b/planner/core/task.go index b999721adf49f..5d1af579df163 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -1887,7 +1887,7 @@ func (p *basePhysicalAgg) scale3StageForDistinctAgg() (bool, expression.Grouping // canUse3Stage4MultiDistinctAgg returns true if this agg can use 3 stage for multi distinct aggregation func (p *basePhysicalAgg) canUse3Stage4MultiDistinctAgg() (can bool, gss expression.GroupingSets) { - if !p.ctx.GetSessionVars().Enable3StageDistinctAgg || len(p.GroupByItems) > 0 { + if !p.ctx.GetSessionVars().Enable3StageDistinctAgg || !p.ctx.GetSessionVars().Enable3StageMultiDistinctAgg || len(p.GroupByItems) > 0 { return false, nil } defer func() { diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index c3babeb0e908f..3faa2fab07358 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -778,6 +778,9 @@ type SessionVars struct { // Enable3StageDistinctAgg indicates whether to allow 3 stage distinct aggregate Enable3StageDistinctAgg bool + // Enable3StageMultiDistinctAgg indicates whether to allow 3 stage multi distinct aggregate + Enable3StageMultiDistinctAgg bool + // MultiStatementMode permits incorrect client library usage. Not recommended to be turned on. 
MultiStatementMode int diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index f3ccd7dd534b4..a6a50c7c2ffdc 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -181,6 +181,10 @@ var defaultSysVars = []*SysVar{ s.Enable3StageDistinctAgg = TiDBOptOn(val) return nil }}, + {Scope: ScopeGlobal | ScopeSession, Name: TiDBOptEnable3StageMultiDistinctAgg, Value: BoolToOnOff(DefTiDB3StageMultiDistinctAgg), Type: TypeBool, SetSession: func(s *SessionVars, val string) error { + s.Enable3StageMultiDistinctAgg = TiDBOptOn(val) + return nil + }}, {Scope: ScopeSession, Name: TiDBOptWriteRowID, Value: BoolToOnOff(DefOptWriteRowID), Type: TypeBool, SetSession: func(s *SessionVars, val string) error { s.AllowWriteRowID = TiDBOptOn(val) return nil diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index ccb9aa13af366..04698c4e3d0ed 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -63,6 +63,9 @@ const ( // TiDBOpt3StageDistinctAgg is used to indicate whether to plan and execute the distinct agg in 3 stages TiDBOpt3StageDistinctAgg = "tidb_opt_three_stage_distinct_agg" + // TiDBOptEnable3StageMultiDistinctAgg is used to indicate whether to plan and execute the multi distinct agg in 3 stages + TiDBOptEnable3StageMultiDistinctAgg = "tidb_opt_enable_three_stage_multi_distinct_agg" + // TiDBBCJThresholdSize is used to limit the size of small table for mpp broadcast join. // Its unit is bytes, if the size of small table is larger than it, we will not use bcj. TiDBBCJThresholdSize = "tidb_broadcast_join_threshold_size" @@ -1106,6 +1109,7 @@ const ( DefTiDBRemoveOrderbyInSubquery = false DefTiDBSkewDistinctAgg = false DefTiDB3StageDistinctAgg = true + DefTiDB3StageMultiDistinctAgg = false DefTiDBReadStaleness = 0 DefTiDBGCMaxWaitTime = 24 * 60 * 60 DefMaxAllowedPacket uint64 = 67108864 From d9f693b05f9f2e0309cbb7b9f8303d0b9320f116 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Mon, 20 Feb 2023 17:59:40 +0800 Subject: [PATCH 25/25] fix other test Signed-off-by: AilinKid <314806019@qq.com> --- planner/core/casetest/enforce_mpp_test.go | 1 + .../casetest/testdata/plan_suite_out.json | 38 +++++++------------ 2 files changed, 14 insertions(+), 25 deletions(-) diff --git a/planner/core/casetest/enforce_mpp_test.go b/planner/core/casetest/enforce_mpp_test.go index 10c49e9091726..152b649437f7d 100644 --- a/planner/core/casetest/enforce_mpp_test.go +++ b/planner/core/casetest/enforce_mpp_test.go @@ -506,6 +506,7 @@ func TestMPPMultiDistinct3Stage(t *testing.T) { err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true) require.NoError(t, err) tk.MustExec("set @@session.tidb_opt_enable_three_stage_multi_distinct_agg=1") + defer tk.MustExec("set @@session.tidb_opt_enable_three_stage_multi_distinct_agg=0") tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\";") tk.MustExec("set @@session.tidb_enforce_mpp=1") tk.MustExec("set @@session.tidb_allow_mpp=ON;") diff --git a/planner/core/casetest/testdata/plan_suite_out.json b/planner/core/casetest/testdata/plan_suite_out.json index 8d0e7449613ac..86ee4c51259b2 100644 --- a/planner/core/casetest/testdata/plan_suite_out.json +++ b/planner/core/casetest/testdata/plan_suite_out.json @@ -4280,21 +4280,17 @@ "TableReader 1.00 root MppVersion: 1, data:ExchangeSender", "└─ExchangeSender 1.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection 1.00 mpp[tiflash] Column#7", - " └─HashAgg 1.00 
mpp[tiflash] group by:Column#18, funcs:sum(Column#19)->Column#7", + " └─HashAgg 1.00 mpp[tiflash] group by:Column#12, funcs:sum(Column#13)->Column#7", " └─ExchangeReceiver 1.00 mpp[tiflash] ", - " └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: Column#18, collate: binary]", - " └─HashAgg 1.00 mpp[tiflash] group by:Column#21, funcs:count(Column#20)->Column#19", - " └─Projection 1.00 mpp[tiflash] Column#5, plus(Column#6, 1)->Column#21", + " └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: Column#12, collate: binary]", + " └─HashAgg 1.00 mpp[tiflash] group by:Column#15, funcs:count(Column#14)->Column#13", + " └─Projection 1.00 mpp[tiflash] Column#5, plus(Column#6, 1)->Column#15", " └─Projection 1.00 mpp[tiflash] Column#5, Column#6", - " └─HashAgg 1.00 mpp[tiflash] funcs:sum(Column#12)->Column#5, funcs:sum(Column#13)->Column#6", + " └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct test.employee.deptid)->Column#5, funcs:count(distinct test.employee.empid)->Column#6", " └─ExchangeReceiver 1.00 mpp[tiflash] ", " └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct test.employee.deptid)->Column#12, funcs:count(distinct test.employee.empid)->Column#13", - " └─ExchangeReceiver 16000.00 mpp[tiflash] ", - " └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.employee.deptid, collate: binary], [name: test.employee.empid, collate: binary], [name: Column#11, collate: binary]", - " └─HashAgg 16000.00 mpp[tiflash] group by:Column#11, test.employee.deptid, test.employee.empid, ", - " └─Expand 20000.00 mpp[tiflash] group set num:2, groupingID:Column#11, [{},{}]", - " └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo" + " └─HashAgg 1.00 mpp[tiflash] group by:test.employee.deptid, test.employee.empid, ", + " └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo" ] }, { @@ -4305,15 +4301,11 @@ " └─Projection 1.00 mpp[tiflash] Column#7", " └─HashAgg 1.00 mpp[tiflash] group by:Column#6, funcs:count(Column#5)->Column#7", " └─Projection 1.00 mpp[tiflash] Column#5, Column#6", - " └─HashAgg 1.00 mpp[tiflash] funcs:sum(Column#12)->Column#5, funcs:sum(Column#13)->Column#6", + " └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct test.employee.deptid)->Column#5, funcs:count(distinct test.employee.empid)->Column#6", " └─ExchangeReceiver 1.00 mpp[tiflash] ", " └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: PassThrough", - " └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct test.employee.deptid)->Column#12, funcs:count(distinct test.employee.empid)->Column#13", - " └─ExchangeReceiver 16000.00 mpp[tiflash] ", - " └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.employee.deptid, collate: binary], [name: test.employee.empid, collate: binary], [name: Column#11, collate: binary]", - " └─HashAgg 16000.00 mpp[tiflash] group by:Column#11, test.employee.deptid, test.employee.empid, ", - " └─Expand 20000.00 mpp[tiflash] group set num:2, groupingID:Column#11, [{},{}]", - " └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo" + " └─HashAgg 1.00 mpp[tiflash] group by:test.employee.deptid, test.employee.empid, ", + " └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo" ] }, { @@ -4325,15 +4317,11 @@ " ├─ExchangeReceiver(Build) 1.00 
mpp[tiflash] ", " │ └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: Broadcast", " │ └─Projection 1.00 mpp[tiflash] Column#9, Column#10", - " │ └─HashAgg 1.00 mpp[tiflash] funcs:sum(Column#12)->Column#9, funcs:sum(Column#13)->Column#10", + " │ └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct test.employee.deptid)->Column#9, funcs:count(distinct test.employee.empid)->Column#10", " │ └─ExchangeReceiver 1.00 mpp[tiflash] ", " │ └─ExchangeSender 1.00 mpp[tiflash] ExchangeType: PassThrough", - " │ └─HashAgg 1.00 mpp[tiflash] funcs:count(distinct test.employee.deptid)->Column#12, funcs:count(distinct test.employee.empid)->Column#13", - " │ └─ExchangeReceiver 16000.00 mpp[tiflash] ", - " │ └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.employee.deptid, collate: binary], [name: test.employee.empid, collate: binary], [name: Column#11, collate: binary]", - " │ └─HashAgg 16000.00 mpp[tiflash] group by:Column#11, test.employee.deptid, test.employee.empid, ", - " │ └─Expand 20000.00 mpp[tiflash] group set num:2, groupingID:Column#11, [{},{}]", - " │ └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo", + " │ └─HashAgg 1.00 mpp[tiflash] group by:test.employee.deptid, test.employee.empid, ", + " │ └─TableFullScan 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo", " └─TableFullScan(Probe) 10000.00 mpp[tiflash] table:employee keep order:false, stats:pseudo" ] },