diff --git a/executor/partition_table.go b/executor/partition_table.go index de6f17e8d2cc0..714cdc206b95e 100644 --- a/executor/partition_table.go +++ b/executor/partition_table.go @@ -56,6 +56,8 @@ func updateExecutorTableID(ctx context.Context, exec *tipb.Executor, recursive b child = exec.Window.Child case tipb.ExecType_TypeSort: child = exec.Sort.Child + case tipb.ExecType_TypeExpand: + child = exec.Expand.Child default: return errors.Trace(fmt.Errorf("unknown new tipb protocol %d", exec.Tp)) } diff --git a/expression/BUILD.bazel b/expression/BUILD.bazel index fcaa063a23c9c..f7abba73c0e9b 100644 --- a/expression/BUILD.bazel +++ b/expression/BUILD.bazel @@ -54,6 +54,7 @@ go_library( "expression.go", "extension.go", "function_traits.go", + "grouping_sets.go", "helper.go", "partition_pruner.go", "scalar_function.go", @@ -78,6 +79,7 @@ go_library( "//parser/opcode", "//parser/terror", "//parser/types", + "//planner/funcdep", "//privilege", "//sessionctx", "//sessionctx/stmtctx", @@ -171,6 +173,7 @@ go_test( "expr_to_pb_test.go", "expression_test.go", "function_traits_test.go", + "grouping_sets_test.go", "helper_test.go", "main_test.go", "multi_valued_index_test.go", @@ -227,6 +230,7 @@ go_test( "@com_github_stretchr_testify//require", "@com_github_tikv_client_go_v2//oracle", "@com_github_tikv_client_go_v2//tikv", + "@io_opencensus_go//stats/view", "@org_uber_go_goleak//:goleak", ], ) diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index 2af7d4d49f348..3015b4718a226 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -42,6 +42,8 @@ type AggFuncDesc struct { HasDistinct bool // OrderByItems represents the order by clause used in GROUP_CONCAT OrderByItems []*util.ByItems + // GroupingID is used for distinguishing with not-set 0, starting from 1. + GroupingID int } // NewAggFuncDesc creates an aggregation function signature descriptor. diff --git a/expression/constant.go b/expression/constant.go index db9f94147f38a..1f48595985c9b 100644 --- a/expression/constant.go +++ b/expression/constant.go @@ -52,6 +52,18 @@ func NewZero() *Constant { } } +// NewUInt64Const stands for constant of a given number. +func NewUInt64Const(num int) *Constant { + retT := types.NewFieldType(mysql.TypeLonglong) + retT.AddFlag(mysql.UnsignedFlag) // shrink range to avoid integral promotion + retT.SetFlen(mysql.MaxIntWidth) + retT.SetDecimal(0) + return &Constant{ + Value: types.NewDatum(num), + RetType: retT, + } +} + // NewNull stands for null constant. func NewNull() *Constant { retT := types.NewFieldType(mysql.TypeTiny) diff --git a/expression/grouping_sets.go b/expression/grouping_sets.go new file mode 100644 index 0000000000000..b252cacdf059a --- /dev/null +++ b/expression/grouping_sets.go @@ -0,0 +1,398 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package expression + +import ( + "strings" + + "github.com/pingcap/tidb/kv" + fd "github.com/pingcap/tidb/planner/funcdep" + "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/util/size" + "github.com/pingcap/tipb/go-tipb" +) + +// GroupingSets indicates the grouping sets definition. +type GroupingSets []GroupingSet + +// GroupingSet indicates one grouping set definition. +type GroupingSet []GroupingExprs + +// GroupingExprs indicates one grouping-expressions inside a grouping set. +type GroupingExprs []Expression + +// Merge function will explore the internal grouping expressions and try to find the minimum grouping sets. (prefix merging) +func (gss GroupingSets) Merge() GroupingSets { + // for now, there is precondition that all grouping expressions are columns. + // for example: (a,b,c) and (a,b) and (a) will be merged as a one. + // Eg: + // before merging, there are 4 grouping sets. + // GroupingSets: + // [ + // [[a,b,c],] // every set including a grouping Expressions for initial. + // [[a,b],] + // [[a],] + // [[e],] + // ] + // + // after merging, there is only 2 grouping set. + // GroupingSets: + // [ + // [[a],[a,b],[a,b,c],] // one set including 3 grouping Expressions after merging if possible. + // [[e],] + // ] + // + // care about the prefix order, which should be taken as following the group layout expanding rule. (simple way) + // [[a],[b,a],[c,b,a],] is also conforming the rule, gradually including one/more column(s) inside for one time. + + newGroupingSets := make(GroupingSets, 0, len(gss)) + for _, oneGroupingSet := range gss { + for _, oneGroupingExpr := range oneGroupingSet { + if len(newGroupingSets) == 0 { + // means there is nothing in new grouping sets, adding current one anyway. + newGroupingSets = append(newGroupingSets, newGroupingSet(oneGroupingExpr)) + continue + } + newGroupingSets = newGroupingSets.MergeOne(oneGroupingExpr) + } + } + return newGroupingSets +} + +// MergeOne is used to merge one grouping expressions into current grouping sets. +func (gss GroupingSets) MergeOne(targetOne GroupingExprs) GroupingSets { + // for every existing grouping set, check the grouping-exprs inside and whether the current grouping-exprs is + // super-set of it or sub-set of it, adding current one to the correct position of grouping-exprs slice. + // + // [[a,b]] + // | + // / offset 0 + // | + // [b] when adding [b] grouping expr here, since it's a sub-set of current [a,b] with offset 0, take the offset 0. + // | + // insert with offset 0, and the other elements move right. + // + // [[b], [a,b]] + // offset 0 1 + // \ + // [a,b,c,d] + // when adding [a,b,c,d] grouping expr here, since it's a super-set of current [a,b] with offset 1, take the offset as 1+1. + // + // result grouping set: [[b], [a,b], [a,b,c,d]], expanding with step with two or more columns is acceptable and reasonable. + // | | | + // +----+-------+ every previous one is the subset of the latter one. + // + for i, oneNewGroupingSet := range gss { + // for every group set,try to find its position to insert if possible,otherwise create a new grouping set. + for j := len(oneNewGroupingSet) - 1; j >= 0; j-- { + cur := oneNewGroupingSet[j] + if targetOne.SubSetOf(cur) { + if j == 0 { + // the right pos should be the head (-1) + cp := make(GroupingSet, 0, len(oneNewGroupingSet)+1) + cp = append(cp, targetOne) + cp = append(cp, oneNewGroupingSet...) + gss[i] = cp + return gss + } + // do the left shift to find the right insert pos. 
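+				// targetOne also fits inside this smaller prefix, so keep scanning towards the head for the tightest position.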
+				continue
+			}
+			if j == len(oneNewGroupingSet)-1 {
+				// which means the targetOne itself is a super-set of the current right-most grouping exprs.
+				if cur.SubSetOf(targetOne) {
+					// the right pos should be len(oneNewGroupingSet).
+					oneNewGroupingSet = append(oneNewGroupingSet, targetOne)
+					gss[i] = oneNewGroupingSet
+					return gss
+				}
+				// which means the targetOne can't fit itself into this grouping set, continue with the next grouping set.
+				break
+			}
+			// successfully fit in, current j is the right pos to insert.
+			cp := make(GroupingSet, 0, len(oneNewGroupingSet)+1)
+			cp = append(cp, oneNewGroupingSet[:j+1]...)
+			cp = append(cp, targetOne)
+			cp = append(cp, oneNewGroupingSet[j+1:]...)
+			gss[i] = cp
+			return gss
+		}
+	}
+	// here we couldn't find even one GroupingSet to place the targetOne in, so create a new one.
+	gss = append(gss, newGroupingSet(targetOne))
+	// gss is an alias of a slice, so we should return it back after it has been changed.
+	return gss
+}
+
+// TargetOne is used to find a valid group layout for a normal agg. Note that the args of a normal agg are not necessarily columns.
+func (gss GroupingSets) TargetOne(normalAggArgs []Expression) int {
+	// there are three cases.
+	// 1: group sets {<a,b>}, {<c>} with a normal agg(d): agg(d) can occur in either the group of <a,b> or that of <c> after the tuple split. (normal columns are appended)
+	// 2: group sets {<a,b>}, {<c>} with a normal agg(c): agg(c) can only be computed from the group of <c>, since c is filled with null in the group of <a,b>.
+	// 3: group sets {<a,b>}, {<c>} with a normal agg over multiple columns: this is a little more difficult, and is banned in canUse3Stage4MultiDistinctAgg by now.
+	//    eg1: agg(c,d): c and d can only be found together in the group of <c>, to which d is also attached, while c is filled with null in the group of <a,b>.
+	//    eg2: agg(b,c,d): we couldn't find a valid group in either <a,b> or <c>, unless we copy column b and attach it to the group data of <c>. (copying c into <a,b> is also effective)
+	//
+	// From the theoretical side, why do we have to fill the non-current-group columns with null values, even though we may fill them as null and copy them again like c in case 3?
+	// The basic reason is that we need a unified group layout sequence to feed distinct aggs that require different layouts. As we do here, we group the
+	// original source with the column sequence <a,b,c,groupingID>. For distinct(a,b), we don't want c in this group layout to impact what we are targeting for --- the
+	// real group on <a,b>. So even though the groupingID in a repeated data row identifies which distinct agg this row is prepared for, we still need to fill c
+	// as null and set groupingID to 1, to guarantee that grouping over the repeated rows is equivalent to grouping on <a,b>, which is what the upper-layer
+	// distinct(a,b) wants. For distinct(c), the same is true.
+	//
+	// A normal agg had better choose its target group-set data; otherwise what it gets is all the groups, most of them endless null values filled in by the
+	// Expand operator, and those null values can also influence a non-null-strict normal agg, although you don't want them to.
+	//
+	// here we only consider cases 1 & 2, since a normal agg with multi-column args is banned.
+	columnInNormalAggArgs := make([]*Column, 0, len(normalAggArgs))
+	for _, one := range normalAggArgs {
+		columnInNormalAggArgs = append(columnInNormalAggArgs, ExtractColumns(one)...)
+	}
+	if len(columnInNormalAggArgs) == 0 {
+		// no column in the normal agg, eg: count(1); take the default grouping set offset 0 (its ID is 0+1).
+		return 0
+	}
+	// for other normal agg args like: count(a), count(a+b), count(not(a is null)) and so on.
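+	// For illustration, with group sets {<a,b>}, {<c>}: TargetOne over d returns offset 0 (d survives in every
+	// replica), TargetOne over c returns offset 1 (c is only kept un-nulled in <c>), and TargetOne over {b,c}
+	// returns -1, since no single grouping set keeps both b and c un-nulled.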
+	normalAggArgsIDSet := fd.NewFastIntSet()
+	for _, one := range columnInNormalAggArgs {
+		normalAggArgsIDSet.Insert(int(one.UniqueID))
+	}
+
+	// identify the suitable grouping set for normal agg.
+	allGroupingColIDs := gss.AllSetsColIDs()
+	for idx, groupingSet := range gss {
+		// diffCols are those columns being filled with null in the group row data of current grouping set.
+		diffCols := allGroupingColIDs.Difference(*groupingSet.allSetColIDs())
+		if diffCols.Intersects(normalAggArgsIDSet) {
+			// normal agg's target arg columns are being filled with null value in this grouping set, continue next grouping set check.
+			continue
+		}
+		return idx
+	}
+	// todo: if we couldn't find an existing valid group layout, we need to copy the column out from being filled with null value.
+	return -1
+}
+
+// NeedCloneColumn indicates whether we need to copy a column when expanding the datasource.
+func (gss GroupingSets) NeedCloneColumn() bool {
+	// for grouping sets like: {<a,c>},{<c>} / {<b,c>},{<a,c>}
+	// the column c should be copied one more time here, otherwise it will be filled with null values and not visible for the other grouping set again.
+	setIDs := make([]*fd.FastIntSet, 0, len(gss))
+	for _, groupingSet := range gss {
+		setIDs = append(setIDs, groupingSet.allSetColIDs())
+	}
+	for idx, oneSetIDs := range setIDs {
+		for j := idx + 1; j < len(setIDs); j++ {
+			otherSetIDs := setIDs[j]
+			if oneSetIDs.Intersects(*otherSetIDs) {
+				return true
+			}
+		}
+	}
+	return false
+}
+
+// IsEmpty indicates whether current grouping set is empty.
+func (gs GroupingSet) IsEmpty() bool {
+	if len(gs) == 0 {
+		return true
+	}
+	for _, g := range gs {
+		if !g.IsEmpty() {
+			return false
+		}
+	}
+	return true
+}
+
+func (gs GroupingSet) allSetColIDs() *fd.FastIntSet {
+	res := fd.NewFastIntSet()
+	for _, groupingExprs := range gs {
+		for _, one := range groupingExprs {
+			res.Insert(int(one.(*Column).UniqueID))
+		}
+	}
+	return &res
+}
+
+// ExtractCols is used to extract basic columns from one grouping set.
+func (gs GroupingSet) ExtractCols() []*Column {
+	cols := make([]*Column, 0, len(gs))
+	for _, groupingExprs := range gs {
+		for _, one := range groupingExprs {
+			cols = append(cols, one.(*Column))
+		}
+	}
+	return cols
+}
+
+// Clone is used to clone a copy of current grouping set.
+func (gs GroupingSet) Clone() GroupingSet {
+	gc := make(GroupingSet, 0, len(gs))
+	for _, one := range gs {
+		gc = append(gc, one.Clone())
+	}
+	return gc
+}
+
+// String is used to output a string which simply describes current grouping set.
+func (gs GroupingSet) String() string {
+	var str strings.Builder
+	str.WriteString("{")
+	for i, one := range gs {
+		if i != 0 {
+			str.WriteString(",")
+		}
+		str.WriteString(one.String())
+	}
+	str.WriteString("}")
+	return str.String()
+}
+
+// MemoryUsage is used to output current memory usage by current grouping set.
+func (gs GroupingSet) MemoryUsage() int64 {
+	sum := size.SizeOfSlice + int64(cap(gs))*size.SizeOfPointer
+	for _, one := range gs {
+		sum += one.MemoryUsage()
+	}
+	return sum
+}
+
+// ToPB is used to convert current grouping set to pb constructor.
+func (gs GroupingSet) ToPB(sc *stmtctx.StatementContext, client kv.Client) (*tipb.GroupingSet, error) { + res := &tipb.GroupingSet{} + for _, gExprs := range gs { + gExprsPB, err := ExpressionsToPBList(sc, gExprs, client) + if err != nil { + return nil, err + } + res.GroupingExprs = append(res.GroupingExprs, &tipb.GroupingExpr{GroupingExpr: gExprsPB}) + } + return res, nil +} + +// IsEmpty indicates whether current grouping sets is empty. +func (gss GroupingSets) IsEmpty() bool { + if len(gss) == 0 { + return true + } + for _, gs := range gss { + if !gs.IsEmpty() { + return false + } + } + return true +} + +// AllSetsColIDs is used to collect all the column id inside into a fast int set. +func (gss GroupingSets) AllSetsColIDs() *fd.FastIntSet { + res := fd.NewFastIntSet() + for _, groupingSet := range gss { + res.UnionWith(*groupingSet.allSetColIDs()) + } + return &res +} + +// String is used to output a string which simply described current grouping sets. +func (gss GroupingSets) String() string { + var str strings.Builder + str.WriteString("[") + for i, gs := range gss { + if i != 0 { + str.WriteString(",") + } + str.WriteString(gs.String()) + } + str.WriteString("]") + return str.String() +} + +// ToPB is used to convert current grouping sets to pb constructor. +func (gss GroupingSets) ToPB(sc *stmtctx.StatementContext, client kv.Client) ([]*tipb.GroupingSet, error) { + res := make([]*tipb.GroupingSet, 0, len(gss)) + for _, gs := range gss { + one, err := gs.ToPB(sc, client) + if err != nil { + return nil, err + } + res = append(res, one) + } + return res, nil +} + +func newGroupingSet(oneGroupingExpr GroupingExprs) GroupingSet { + res := make(GroupingSet, 0, 1) + res = append(res, oneGroupingExpr) + return res +} + +// IsEmpty indicates whether current grouping expressions are empty. +func (g GroupingExprs) IsEmpty() bool { + return len(g) == 0 +} + +// SubSetOf is used to do the logical computation of subset between two grouping expressions. +func (g GroupingExprs) SubSetOf(other GroupingExprs) bool { + oldOne := fd.NewFastIntSet() + newOne := fd.NewFastIntSet() + for _, one := range g { + oldOne.Insert(int(one.(*Column).UniqueID)) + } + for _, one := range other { + newOne.Insert(int(one.(*Column).UniqueID)) + } + return oldOne.SubsetOf(newOne) +} + +// IDSet is used to collect column ids inside grouping expressions into a fast int set. +func (g GroupingExprs) IDSet() *fd.FastIntSet { + res := fd.NewFastIntSet() + for _, one := range g { + res.Insert(int(one.(*Column).UniqueID)) + } + return &res +} + +// Clone is used to clone a copy of current grouping expressions. +func (g GroupingExprs) Clone() GroupingExprs { + gc := make(GroupingExprs, 0, len(g)) + for _, one := range g { + gc = append(gc, one.Clone()) + } + return gc +} + +// String is used to output a string which simply described current grouping expressions. +func (g GroupingExprs) String() string { + var str strings.Builder + str.WriteString("<") + for i, one := range g { + if i != 0 { + str.WriteString(",") + } + str.WriteString(one.String()) + } + str.WriteString(">") + return str.String() +} + +// MemoryUsage is used to output current memory usage by current grouping expressions. 
+func (g GroupingExprs) MemoryUsage() int64 { + sum := size.SizeOfSlice + int64(cap(g))*size.SizeOfInterface + for _, one := range g { + sum += one.MemoryUsage() + } + return sum +} diff --git a/expression/grouping_sets_test.go b/expression/grouping_sets_test.go new file mode 100644 index 0000000000000..e4f636bb9845d --- /dev/null +++ b/expression/grouping_sets_test.go @@ -0,0 +1,248 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "testing" + + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/types" + "github.com/stretchr/testify/require" + "go.opencensus.io/stats/view" +) + +func TestGroupSetsTargetOne(t *testing.T) { + defer view.Stop() + a := &Column{ + UniqueID: 1, + } + b := &Column{ + UniqueID: 2, + } + c := &Column{ + UniqueID: 3, + } + d := &Column{ + UniqueID: 4, + } + // non-merged group sets + // 1: group sets {}, {} when the normal agg(d), d can occur in either group of or after tuple split. (normal column are appended) + // 2: group sets {}, {} when the normal agg(c), c can only be found in group of , since it will be filled with null in group of + // 3: group sets {}, {} when the normal agg with multi col args, it will be a little difficult, which is banned in canUse3Stage4MultiDistinctAgg by now. + newGroupingSets := make(GroupingSets, 0, 3) + groupingExprs1 := []Expression{a, b} + groupingExprs2 := []Expression{c} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs1}) + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs2}) + + targetOne := []Expression{d} + offset := newGroupingSets.TargetOne(targetOne) + require.NotEqual(t, offset, -1) + require.Equal(t, offset, 0) // and both ok + + targetOne = []Expression{c} + offset = newGroupingSets.TargetOne(targetOne) + require.NotEqual(t, offset, -1) + require.Equal(t, offset, 1) // only is ok + + targetOne = []Expression{b} + offset = newGroupingSets.TargetOne(targetOne) + require.NotEqual(t, offset, -1) + require.Equal(t, offset, 0) // only is ok + + targetOne = []Expression{a} + offset = newGroupingSets.TargetOne(targetOne) + require.NotEqual(t, offset, -1) + require.Equal(t, offset, 0) // only is ok + + targetOne = []Expression{b, c} + offset = newGroupingSets.TargetOne(targetOne) + require.Equal(t, offset, -1) // no valid one can be found. + + // merged group sets + // merged group sets {, } when the normal agg(d), + // from prospect from data grouping, we need keep the widest grouping layout when trying + // to shuffle data targeting for grouping layout {, }; Additional, since column b + // hasn't been pre-agg or something because current layout is , we should keep all the + // complete rows until to the upper receiver, which means there is no agg/group operator before + // shuffler take place. 
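+	// a minimal sketch of the merged case (illustrative): one prefix-ordered set [<a>, <a,b>] should still host agg(d), at offset 0.
+	mergedSets := GroupingSets{GroupingSet{GroupingExprs{a}, GroupingExprs{a, b}}}
+	offset = mergedSets.TargetOne([]Expression{d})
+	require.Equal(t, offset, 0)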
+} + +func TestGroupSetsTargetOneCompoundArgs(t *testing.T) { + defer view.Stop() + a := &Column{ + UniqueID: 1, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + b := &Column{ + UniqueID: 2, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + c := &Column{ + UniqueID: 3, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + d := &Column{ + UniqueID: 4, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + // grouping set: {a,b}, {c} + newGroupingSets := make(GroupingSets, 0, 3) + groupingExprs1 := []Expression{a, b} + groupingExprs2 := []Expression{c} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs1}) + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs2}) + + var normalAggArgs Expression + // mock normal agg count(1) + normalAggArgs = newLonglong(1) + offset := newGroupingSets.TargetOne([]Expression{normalAggArgs}) + require.NotEqual(t, offset, -1) + require.Equal(t, offset, 0) // default + + // mock normal agg count(d+1) + normalAggArgs = newFunction(ast.Plus, d, newLonglong(1)) + offset = newGroupingSets.TargetOne([]Expression{normalAggArgs}) + require.NotEqual(t, offset, -1) + require.Equal(t, offset, 0) // default + + // mock normal agg count(d+c) + normalAggArgs = newFunction(ast.Plus, d, c) + offset = newGroupingSets.TargetOne([]Expression{normalAggArgs}) + require.NotEqual(t, offset, -1) + require.Equal(t, offset, 1) // only {c} can supply d and c + + // mock normal agg count(d+a) + normalAggArgs = newFunction(ast.Plus, d, a) + offset = newGroupingSets.TargetOne([]Expression{normalAggArgs}) + require.NotEqual(t, offset, -1) + require.Equal(t, offset, 0) // only {a,b} can supply d and a + + // mock normal agg count(d+a+c) + normalAggArgs = newFunction(ast.Plus, d, newFunction(ast.Plus, a, c)) + offset = newGroupingSets.TargetOne([]Expression{normalAggArgs}) + require.Equal(t, offset, -1) // couldn't find a group that supply d, a and c simultaneously. +} + +func TestGroupingSetsMergeOneUnitTest(t *testing.T) { + defer view.Stop() + a := &Column{ + UniqueID: 1, + } + b := &Column{ + UniqueID: 2, + } + c := &Column{ + UniqueID: 3, + } + d := &Column{ + UniqueID: 4, + } + // test case about the right most fitness. 
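+	// i.e. when the incoming grouping exprs are a super-set of the current right-most one, they are appended at the tail.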
+ newGroupingSets := make(GroupingSets, 0, 3) + newGroupingSets = newGroupingSets[:0] + groupingExprs := []Expression{a} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + newGroupingSets.MergeOne([]Expression{a, b}) + + require.Equal(t, len(newGroupingSets), 1) + require.Equal(t, len(newGroupingSets[0]), 2) + //{a} + require.Equal(t, len(newGroupingSets[0][0]), 1) + //{a,b} + require.Equal(t, len(newGroupingSets[0][1]), 2) + + newGroupingSets = newGroupingSets[:0] + groupingExprs = []Expression{a} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + newGroupingSets.MergeOne([]Expression{d, c, b, a}) + newGroupingSets.MergeOne([]Expression{d, b, a}) + newGroupingSets.MergeOne([]Expression{b, a}) + + require.Equal(t, len(newGroupingSets), 1) + require.Equal(t, len(newGroupingSets[0]), 4) + //{a} + require.Equal(t, len(newGroupingSets[0][0]), 1) + //{b,a} + require.Equal(t, len(newGroupingSets[0][1]), 2) + //{d,b,a} + require.Equal(t, len(newGroupingSets[0][2]), 3) + //{d,c,b,a} + require.Equal(t, len(newGroupingSets[0][3]), 4) +} + +func TestGroupingSetsMergeUnitTest(t *testing.T) { + defer view.Stop() + a := &Column{ + UniqueID: 1, + } + b := &Column{ + UniqueID: 2, + } + c := &Column{ + UniqueID: 3, + } + d := &Column{ + UniqueID: 4, + } + // [[c,d,a,b], [b], [a,b]] + newGroupingSets := make(GroupingSets, 0, 3) + groupingExprs := []Expression{c, d, a, b} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + groupingExprs = []Expression{b} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + groupingExprs = []Expression{a, b} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + + // [[b], [a,b], [c,d,a,b]] + newGroupingSets = newGroupingSets.Merge() + require.Equal(t, len(newGroupingSets), 1) + // only one grouping set with 3 grouping expressions. + require.Equal(t, len(newGroupingSets[0]), 3) + // [b] + require.Equal(t, len(newGroupingSets[0][0]), 1) + // [a,b] + require.Equal(t, len(newGroupingSets[0][1]), 2) + // [c,d,a,b]] + require.Equal(t, len(newGroupingSets[0][2]), 4) + + // [ + // [[a],[b,a],[c,b,a],] // one set including 3 grouping Expressions after merging if possible. 
+ // [[d],] + // ] + newGroupingSets = newGroupingSets[:0] + groupingExprs = []Expression{c, b, a} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + groupingExprs = []Expression{a} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + groupingExprs = []Expression{b, a} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + groupingExprs = []Expression{d} + newGroupingSets = append(newGroupingSets, GroupingSet{groupingExprs}) + + // [[a],[b,a],[c,b,a],] + // [[d],] + newGroupingSets = newGroupingSets.Merge() + require.Equal(t, len(newGroupingSets), 2) + // [a] + require.Equal(t, len(newGroupingSets[0][0]), 1) + // [b,a] + require.Equal(t, len(newGroupingSets[0][1]), 2) + // [c,b,a] + require.Equal(t, len(newGroupingSets[0][2]), 3) + // [d] + require.Equal(t, len(newGroupingSets[1][0]), 1) +} diff --git a/planner/core/casetest/enforce_mpp_test.go b/planner/core/casetest/enforce_mpp_test.go index aaa2037efc0ce..152b649437f7d 100644 --- a/planner/core/casetest/enforce_mpp_test.go +++ b/planner/core/casetest/enforce_mpp_test.go @@ -20,8 +20,10 @@ import ( "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/planner/core/internal" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/testkit/external" "github.com/pingcap/tidb/testkit/testdata" "github.com/pingcap/tidb/util/collate" "github.com/stretchr/testify/require" @@ -487,3 +489,63 @@ func TestMPPSingleDistinct3Stage(t *testing.T) { require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())) } } + +// todo: some post optimization after resolveIndices will inject another projection below agg, which change the column name used in higher operator, +// +// since it doesn't change the schema out (index ref is still the right), so by now it's fine. SEE case: EXPLAIN select count(distinct a), count(distinct b), sum(c) from t. +func TestMPPMultiDistinct3Stage(t *testing.T) { + store := testkit.CreateMockStore(t, internal.WithMockTiFlash(2)) + tk := testkit.NewTestKit(t, store) + + // test table + tk.MustExec("use test;") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, c int, d int);") + tk.MustExec("alter table t set tiflash replica 1") + tb := external.GetTableByName(t, tk, "test", "t") + err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true) + require.NoError(t, err) + tk.MustExec("set @@session.tidb_opt_enable_three_stage_multi_distinct_agg=1") + defer tk.MustExec("set @@session.tidb_opt_enable_three_stage_multi_distinct_agg=0") + tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\";") + tk.MustExec("set @@session.tidb_enforce_mpp=1") + tk.MustExec("set @@session.tidb_allow_mpp=ON;") + // todo: current mock regionCache won't scale the regions among tiFlash nodes. The under layer still collect data from only one of the nodes. 
+ tk.MustExec("split table t BETWEEN (0) AND (5000) REGIONS 5;") + tk.MustExec("insert into t values(1000, 1000, 1000, 1)") + tk.MustExec("insert into t values(1000, 1000, 1000, 1)") + tk.MustExec("insert into t values(2000, 2000, 2000, 1)") + tk.MustExec("insert into t values(2000, 2000, 2000, 1)") + tk.MustExec("insert into t values(3000, 3000, 3000, 1)") + tk.MustExec("insert into t values(3000, 3000, 3000, 1)") + tk.MustExec("insert into t values(4000, 4000, 4000, 1)") + tk.MustExec("insert into t values(4000, 4000, 4000, 1)") + tk.MustExec("insert into t values(5000, 5000, 5000, 1)") + tk.MustExec("insert into t values(5000, 5000, 5000, 1)") + + var input []string + var output []struct { + SQL string + Plan []string + Warn []string + } + enforceMPPSuiteData := GetEnforceMPPSuiteData() + enforceMPPSuiteData.LoadTestCases(t, &input, &output) + for i, tt := range input { + testdata.OnRecord(func() { + output[i].SQL = tt + }) + if strings.HasPrefix(tt, "set") || strings.HasPrefix(tt, "UPDATE") { + tk.MustExec(tt) + continue + } + testdata.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + output[i].Warn = testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()) + }) + res := tk.MustQuery(tt) + res.Check(testkit.Rows(output[i].Plan...)) + require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())) + } +} diff --git a/planner/core/casetest/testdata/enforce_mpp_suite_in.json b/planner/core/casetest/testdata/enforce_mpp_suite_in.json index fc6d31dec9da9..885a9709018d8 100644 --- a/planner/core/casetest/testdata/enforce_mpp_suite_in.json +++ b/planner/core/casetest/testdata/enforce_mpp_suite_in.json @@ -132,5 +132,31 @@ "EXPLAIN select count(distinct c+a), count(a) from t;", "EXPLAIN select sum(b), count(distinct c+a, b, e), count(a+b) from t;" ] + }, + { + "name": "TestMPPMultiDistinct3Stage", + "cases": [ + "EXPLAIN select count(distinct a) from t", + "select count(distinct a) from t", + "EXPLAIN select count(distinct a), count(distinct b) from t", + "select count(distinct a), count(distinct b) from t", + "EXPLAIN select count(distinct a), count(distinct b), count(c) from t", + "select count(distinct a), count(distinct b), count(c) from t", + "EXPLAIN select count(distinct a), count(distinct b), count(c+1) from t", + "select count(distinct a), count(distinct b), count(c+1) from t", + "EXPLAIN select count(distinct a), count(distinct b), sum(c) from t", + "select count(distinct a), count(distinct b), sum(c) from t", + "EXPLAIN select count(distinct a, b), count(distinct b), count(c), sum(d) from t", + "select count(distinct a, b), count(distinct b), count(c), sum(d) from t", + "EXPLAIN select count(distinct a+b), sum(c) from t", + "select count(distinct a+b), sum(c) from t", + "EXPLAIN select count(distinct a+b), count(distinct b+c), count(c) from t", + "select count(distinct a+b), count(distinct b+c), count(c) from t", + "explain select count(distinct a,c), count(distinct b,c), count(c) from t", + "select count(distinct a), count(distinct b), count(*) from t", + "explain select count(distinct a), count(distinct b), count(*) from t", + "select count(distinct a), count(distinct b), avg(c+d) from t", + "explain select count(distinct a), count(distinct b), avg(c+d) from t" + ] } ] diff --git a/planner/core/casetest/testdata/enforce_mpp_suite_out.json b/planner/core/casetest/testdata/enforce_mpp_suite_out.json index 483291bfabd22..5b4e97f542b8e 
100644 --- a/planner/core/casetest/testdata/enforce_mpp_suite_out.json +++ b/planner/core/casetest/testdata/enforce_mpp_suite_out.json @@ -1203,5 +1203,277 @@ "Warn": null } ] + }, + { + "Name": "TestMPPMultiDistinct3Stage", + "Cases": [ + { + "SQL": "EXPLAIN select count(distinct a) from t", + "Plan": [ + "TableReader_30 1.00 root MppVersion: 1, data:ExchangeSender_29", + "└─ExchangeSender_29 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_23 1.00 mpp[tiflash] Column#6", + " └─HashAgg_24 1.00 mpp[tiflash] funcs:sum(Column#8)->Column#6", + " └─ExchangeReceiver_28 1.00 mpp[tiflash] ", + " └─ExchangeSender_27 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_24 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#8", + " └─ExchangeReceiver_26 1.00 mpp[tiflash] ", + " └─ExchangeSender_25 1.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary]", + " └─HashAgg_22 1.00 mpp[tiflash] group by:test.t.a, ", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "select count(distinct a) from t", + "Plan": [ + "5" + ], + "Warn": null + }, + { + "SQL": "EXPLAIN select count(distinct a), count(distinct b) from t", + "Plan": [ + "TableReader_36 1.00 root MppVersion: 1, data:ExchangeSender_35", + "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#12)->Column#6, funcs:sum(Column#13)->Column#7", + " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", + " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#12, funcs:count(distinct test.t.b)->Column#13", + " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#11, collate: binary]", + " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#11, test.t.a, test.t.b, ", + " └─Expand_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#11, [{},{}]", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "select count(distinct a), count(distinct b) from t", + "Plan": [ + "5 5" + ], + "Warn": null + }, + { + "SQL": "EXPLAIN select count(distinct a), count(distinct b), count(c) from t", + "Plan": [ + "TableReader_36 1.00 root MppVersion: 1, data:ExchangeSender_35", + "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#18)->Column#6, funcs:sum(Column#19)->Column#7, funcs:sum(Column#20)->Column#8", + " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", + " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#18, funcs:count(distinct test.t.b)->Column#19, funcs:sum(Column#15)->Column#20", + " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#16, collate: binary]", + " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#16, test.t.a, test.t.b, 
funcs:count(Column#17)->Column#15", + " └─Projection_30 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#16, case(eq(Column#16, 1), test.t.c, )->Column#17", + " └─Expand_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#16, [{},{}]", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "select count(distinct a), count(distinct b), count(c) from t", + "Plan": [ + "5 5 10" + ], + "Warn": null + }, + { + "SQL": "EXPLAIN select count(distinct a), count(distinct b), count(c+1) from t", + "Plan": [ + "TableReader_36 1.00 root MppVersion: 1, data:ExchangeSender_35", + "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#18)->Column#6, funcs:sum(Column#19)->Column#7, funcs:sum(Column#20)->Column#8", + " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", + " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#18, funcs:count(distinct test.t.b)->Column#19, funcs:sum(Column#15)->Column#20", + " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#16, collate: binary]", + " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#16, test.t.a, test.t.b, funcs:count(Column#17)->Column#15", + " └─Projection_30 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#16, case(eq(Column#16, 1), plus(test.t.c, 1), )->Column#17", + " └─Expand_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#16, [{},{}]", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "select count(distinct a), count(distinct b), count(c+1) from t", + "Plan": [ + "5 5 10" + ], + "Warn": null + }, + { + "SQL": "EXPLAIN select count(distinct a), count(distinct b), sum(c) from t", + "Plan": [ + "TableReader_36 1.00 root MppVersion: 1, data:ExchangeSender_35", + "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#18)->Column#6, funcs:sum(Column#19)->Column#7, funcs:sum(Column#20)->Column#8", + " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", + " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#18, funcs:count(distinct test.t.b)->Column#19, funcs:sum(Column#15)->Column#20", + " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#16, collate: binary]", + " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#22, Column#23, Column#24, funcs:sum(Column#21)->Column#15", + " └─Projection_37 20000.00 mpp[tiflash] cast(Column#17, decimal(10,0) BINARY)->Column#21, test.t.a, test.t.b, Column#16", + " └─Projection_30 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, Column#16, case(eq(Column#16, 1), test.t.c, )->Column#17", + " └─Expand_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#16, [{},{}]", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], 
+ "Warn": null + }, + { + "SQL": "select count(distinct a), count(distinct b), sum(c) from t", + "Plan": [ + "5 5 30000" + ], + "Warn": null + }, + { + "SQL": "EXPLAIN select count(distinct a, b), count(distinct b), count(c), sum(d) from t", + "Plan": [ + "TableReader_26 1.00 root MppVersion: 1, data:ExchangeSender_25", + "└─ExchangeSender_25 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_21 1.00 mpp[tiflash] Column#6, Column#7, Column#8, Column#9", + " └─HashAgg_22 1.00 mpp[tiflash] funcs:count(distinct test.t.a, test.t.b)->Column#6, funcs:count(distinct test.t.b)->Column#7, funcs:sum(Column#12)->Column#8, funcs:sum(Column#13)->Column#9", + " └─ExchangeReceiver_24 1.00 mpp[tiflash] ", + " └─ExchangeSender_23 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_20 1.00 mpp[tiflash] group by:Column#16, Column#17, funcs:count(Column#14)->Column#12, funcs:sum(Column#15)->Column#13", + " └─Projection_27 10000.00 mpp[tiflash] test.t.c, cast(test.t.d, decimal(10,0) BINARY)->Column#15, test.t.a, test.t.b", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": [ + "Some grouping sets should be merged", + "Some grouping sets should be merged" + ] + }, + { + "SQL": "select count(distinct a, b), count(distinct b), count(c), sum(d) from t", + "Plan": [ + "5 5 10 10" + ], + "Warn": [ + "Some grouping sets should be merged", + "Some grouping sets should be merged" + ] + }, + { + "SQL": "EXPLAIN select count(distinct a+b), sum(c) from t", + "Plan": [ + "TableReader_26 1.00 root MppVersion: 1, data:ExchangeSender_25", + "└─ExchangeSender_25 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_21 1.00 mpp[tiflash] Column#6, Column#7", + " └─HashAgg_22 1.00 mpp[tiflash] funcs:count(distinct Column#10)->Column#6, funcs:sum(Column#11)->Column#7", + " └─ExchangeReceiver_24 1.00 mpp[tiflash] ", + " └─ExchangeSender_23 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_20 1.00 mpp[tiflash] group by:Column#13, funcs:sum(Column#12)->Column#11", + " └─Projection_27 10000.00 mpp[tiflash] cast(test.t.c, decimal(10,0) BINARY)->Column#12, plus(test.t.a, test.t.b)->Column#13", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "select count(distinct a+b), sum(c) from t", + "Plan": [ + "5 30000" + ], + "Warn": null + }, + { + "SQL": "EXPLAIN select count(distinct a+b), count(distinct b+c), count(c) from t", + "Plan": [ + "TableReader_26 1.00 root MppVersion: 1, data:ExchangeSender_25", + "└─ExchangeSender_25 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_21 1.00 mpp[tiflash] Column#6, Column#7, Column#8", + " └─HashAgg_22 1.00 mpp[tiflash] funcs:count(distinct Column#12)->Column#6, funcs:count(distinct Column#13)->Column#7, funcs:sum(Column#14)->Column#8", + " └─ExchangeReceiver_24 1.00 mpp[tiflash] ", + " └─ExchangeSender_23 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_20 1.00 mpp[tiflash] group by:Column#16, Column#17, funcs:count(Column#15)->Column#14", + " └─Projection_27 10000.00 mpp[tiflash] test.t.c, plus(test.t.a, test.t.b)->Column#16, plus(test.t.b, test.t.c)->Column#17", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "select count(distinct a+b), count(distinct b+c), count(c) from t", + "Plan": [ + "5 5 10" + ], + "Warn": null + }, + { + "SQL": "explain select count(distinct a,c), count(distinct b,c), count(c) from t", + "Plan": [ + "TableReader_26 
1.00 root MppVersion: 1, data:ExchangeSender_25", + "└─ExchangeSender_25 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_21 1.00 mpp[tiflash] Column#6, Column#7, Column#8", + " └─HashAgg_22 1.00 mpp[tiflash] funcs:count(distinct test.t.a, test.t.c)->Column#6, funcs:count(distinct test.t.b, test.t.c)->Column#7, funcs:sum(Column#10)->Column#8", + " └─ExchangeReceiver_24 1.00 mpp[tiflash] ", + " └─ExchangeSender_23 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_20 1.00 mpp[tiflash] group by:test.t.a, test.t.b, test.t.c, funcs:count(test.t.c)->Column#10", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "select count(distinct a), count(distinct b), count(*) from t", + "Plan": [ + "5 5 10" + ], + "Warn": null + }, + { + "SQL": "explain select count(distinct a), count(distinct b), count(*) from t", + "Plan": [ + "TableReader_36 1.00 root MppVersion: 1, data:ExchangeSender_35", + "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, Column#8", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#18)->Column#6, funcs:sum(Column#19)->Column#7, funcs:sum(Column#20)->Column#8", + " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", + " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#18, funcs:count(distinct test.t.b)->Column#19, funcs:sum(Column#15)->Column#20", + " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#16, collate: binary]", + " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#16, test.t.a, test.t.b, funcs:count(Column#17)->Column#15", + " └─Projection_30 20000.00 mpp[tiflash] test.t.a, test.t.b, Column#16, case(eq(Column#16, 1), 1, )->Column#17", + " └─Expand_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#16, [{},{}]", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "select count(distinct a), count(distinct b), avg(c+d) from t", + "Plan": [ + "5 5 3001.0000" + ], + "Warn": null + }, + { + "SQL": "explain select count(distinct a), count(distinct b), avg(c+d) from t", + "Plan": [ + "TableReader_36 1.00 root MppVersion: 1, data:ExchangeSender_35", + "└─ExchangeSender_35 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_26 1.00 mpp[tiflash] Column#6, Column#7, div(Column#8, cast(case(eq(Column#19, 0), 1, Column#19), decimal(20,0) BINARY))->Column#8", + " └─HashAgg_27 1.00 mpp[tiflash] funcs:sum(Column#25)->Column#6, funcs:sum(Column#26)->Column#7, funcs:sum(Column#27)->Column#19, funcs:sum(Column#28)->Column#8", + " └─ExchangeReceiver_34 1.00 mpp[tiflash] ", + " └─ExchangeSender_33 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashAgg_29 1.00 mpp[tiflash] funcs:count(distinct test.t.a)->Column#25, funcs:count(distinct test.t.b)->Column#26, funcs:sum(Column#20)->Column#27, funcs:sum(Column#21)->Column#28", + " └─ExchangeReceiver_32 16000.00 mpp[tiflash] ", + " └─ExchangeSender_31 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary], [name: test.t.b, collate: binary], [name: Column#22, collate: binary]", + " └─HashAgg_25 16000.00 mpp[tiflash] group by:Column#31, Column#32, Column#33, 
funcs:count(Column#29)->Column#20, funcs:sum(Column#30)->Column#21", + " └─Projection_37 20000.00 mpp[tiflash] Column#23, cast(Column#24, decimal(20,0) BINARY)->Column#30, test.t.a, test.t.b, Column#22", + " └─Projection_30 20000.00 mpp[tiflash] test.t.a, test.t.b, test.t.c, test.t.d, Column#22, case(eq(Column#22, 1), plus(test.t.c, test.t.d), )->Column#23, case(eq(Column#22, 1), plus(test.t.c, test.t.d), )->Column#24", + " └─Expand_28 20000.00 mpp[tiflash] group set num:2, groupingID:Column#22, [{},{}]", + " └─TableFullScan_11 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null + } + ] } ] diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index f9dc616bb19f5..a21d19960f461 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -2870,6 +2870,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert partitionCols := la.GetPotentialPartitionKeys() // trying to match the required partitions. if prop.MPPPartitionTp == property.HashType { + // partition key required by upper layer is subset of current layout. if matches := prop.IsSubsetOf(partitionCols); len(matches) != 0 { partitionCols = choosePartitionKeys(partitionCols, matches) } else { @@ -2896,6 +2897,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert } // 2-phase agg + // no partition property down,record partition cols inside agg itself, enforce shuffler latter. childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.AnyType, RejectSort: true} agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp) agg.SetSchema(la.schema.Clone()) @@ -2917,6 +2919,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp) agg.SetSchema(la.schema.Clone()) if la.HasDistinct() || la.HasOrderBy() { + // mpp scalar mode means the data will be pass through to only one tiFlash node at last. agg.MppRunMode = MppScalar } else { agg.MppRunMode = MppTiDB diff --git a/planner/core/explain.go b/planner/core/explain.go index 93d3533bbe3cc..e9e4d7fa1877b 100644 --- a/planner/core/explain.go +++ b/planner/core/explain.go @@ -359,6 +359,18 @@ func (p *PhysicalLimit) ExplainInfo() string { return str.String() } +// ExplainInfo implements Plan interface. +func (p *PhysicalExpand) ExplainInfo() string { + var str strings.Builder + str.WriteString("group set num:") + str.WriteString(strconv.FormatInt(int64(len(p.GroupingSets)), 10)) + str.WriteString(", groupingID:") + str.WriteString(p.GroupingIDCol.String()) + str.WriteString(", ") + str.WriteString(p.GroupingSets.String()) + return str.String() +} + // ExplainInfo implements Plan interface. func (p *basePhysicalAgg) ExplainInfo() string { return p.explainInfo(false) diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index 47a2027e589ed..8118aa07d1f51 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -2452,7 +2452,7 @@ func (p *LogicalCTE) findBestTask(prop *property.PhysicalProperty, _ *PlanCounte // The physical plan has been build when derive stats. 
pcte := PhysicalCTE{SeedPlan: p.cte.seedPartPhysicalPlan, RecurPlan: p.cte.recursivePartPhysicalPlan, CTE: p.cte, cteAsName: p.cteAsName, cteName: p.cteName}.Init(p.ctx, p.stats) pcte.SetSchema(p.schema) - t = &rootTask{pcte, false} + t = &rootTask{p: pcte, isEmpty: false} if prop.CanAddEnforcer { t = enforceProperty(prop, t, p.basePlan.ctx) } diff --git a/planner/core/fragment.go b/planner/core/fragment.go index 1a4f5e6c61f9b..a712f93b418ae 100644 --- a/planner/core/fragment.go +++ b/planner/core/fragment.go @@ -151,6 +151,7 @@ func (e *mppTaskGenerator) constructMPPTasksByChildrenTasks(tasks []*kv.MPPTask) newTasks := make([]*kv.MPPTask, 0, len(tasks)) for _, task := range tasks { addr := task.Meta.GetAddress() + // for upper fragment, the task num is equal to address num covered by lower tasks _, ok := addressMap[addr] if !ok { mppTask := &kv.MPPTask{ @@ -193,7 +194,7 @@ func (f *Fragment) init(p PhysicalPlan) error { // We would remove all the union-all operators by 'untwist'ing and copying the plans above union-all. // This will make every route from root (ExchangeSender) to leaf nodes (ExchangeReceiver and TableScan) -// a new ioslated tree (and also a fragment) without union all. These trees (fragments then tasks) will +// a new isolated tree (and also a fragment) without union all. These trees (fragments then tasks) will // finally be gathered to TiDB or be exchanged to upper tasks again. // For instance, given a plan "select c1 from t union all select c1 from s" // after untwist, there will be two plans in `forest` slice: @@ -278,6 +279,7 @@ func (e *mppTaskGenerator) generateMPPTasksForExchangeSender(s *PhysicalExchange } results := make([]*kv.MPPTask, 0, len(frags)) for _, f := range frags { + // from the bottom up tasks, err := e.generateMPPTasksForFragment(f) if err != nil { return nil, nil, errors.Trace(err) @@ -291,6 +293,7 @@ func (e *mppTaskGenerator) generateMPPTasksForExchangeSender(s *PhysicalExchange func (e *mppTaskGenerator) generateMPPTasksForFragment(f *Fragment) (tasks []*kv.MPPTask, err error) { for _, r := range f.ExchangeReceivers { + // chain call: to get lower fragments and tasks r.Tasks, r.frags, err = e.generateMPPTasksForExchangeSender(r.GetExchangeSender()) if err != nil { return nil, errors.Trace(err) diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index da0a1e2b0ddf7..84d6e39fd667c 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -32,6 +32,7 @@ import ( "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/plancodec" "github.com/pingcap/tidb/util/ranger" "github.com/pingcap/tidb/util/size" "github.com/pingcap/tidb/util/stringutil" @@ -1495,7 +1496,61 @@ func (p *PhysicalExchangeReceiver) MemoryUsage() (sum int64) { return } -// PhysicalExchangeSender dispatches data to upstream tasks. That means push mode processing, +// PhysicalExpand is used to expand underlying data sources to feed different grouping sets. +type PhysicalExpand struct { + // data after repeat-OP will generate a new grouping-ID column to indicate what grouping set is it for. + physicalSchemaProducer + + // generated grouping ID column itself. + GroupingIDCol *expression.Column + + // GroupingSets is used to define what kind of group layout should the underlying data follow. + // For simple case: select count(distinct a), count(distinct b) from t; the grouping expressions are [a] and [b]. 
+ GroupingSets expression.GroupingSets +} + +// Init only assigns type and context. +func (p PhysicalExpand) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int) *PhysicalExpand { + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeExpand, &p, offset) + p.stats = stats + return &p +} + +// Clone implements PhysicalPlan interface. +func (p *PhysicalExpand) Clone() (PhysicalPlan, error) { + np := new(PhysicalExpand) + base, err := p.basePhysicalPlan.cloneWithSelf(np) + if err != nil { + return nil, errors.Trace(err) + } + np.basePhysicalPlan = *base + // clone ID cols. + np.GroupingIDCol = p.GroupingIDCol.Clone().(*expression.Column) + + // clone grouping expressions. + clonedGroupingSets := make([]expression.GroupingSet, 0, len(p.GroupingSets)) + for _, one := range p.GroupingSets { + clonedGroupingSets = append(clonedGroupingSets, one.Clone()) + } + np.GroupingSets = p.GroupingSets + return np, nil +} + +// MemoryUsage return the memory usage of PhysicalExpand +func (p *PhysicalExpand) MemoryUsage() (sum int64) { + if p == nil { + return + } + + sum = p.physicalSchemaProducer.MemoryUsage() + size.SizeOfSlice + int64(cap(p.GroupingSets))*size.SizeOfPointer + for _, gs := range p.GroupingSets { + sum += gs.MemoryUsage() + } + sum += p.GroupingIDCol.MemoryUsage() + return +} + +// PhysicalExchangeSender dispatches data to upstream tasks. That means push mode processing. type PhysicalExchangeSender struct { basePhysicalPlan @@ -1507,7 +1562,7 @@ type PhysicalExchangeSender struct { CompressionMode kv.ExchangeCompressionMode } -// Clone implment PhysicalPlan interface. +// Clone implements PhysicalPlan interface. func (p *PhysicalExchangeSender) Clone() (PhysicalPlan, error) { np := new(PhysicalExchangeSender) base, err := p.basePhysicalPlan.cloneWithSelf(np) diff --git a/planner/core/plan_to_pb.go b/planner/core/plan_to_pb.go index 61edd9471a20c..8b385d3d86a43 100644 --- a/planner/core/plan_to_pb.go +++ b/planner/core/plan_to_pb.go @@ -33,6 +33,29 @@ func (p *basePhysicalPlan) ToPB(_ sessionctx.Context, _ kv.StoreType) (*tipb.Exe return nil, errors.Errorf("plan %s fails converts to PB", p.basePlan.ExplainID()) } +// ToPB implements PhysicalPlan ToPB interface. +func (p *PhysicalExpand) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { + sc := ctx.GetSessionVars().StmtCtx + client := ctx.GetClient() + groupingSetsPB, err := p.GroupingSets.ToPB(sc, client) + if err != nil { + return nil, err + } + expand := &tipb.Expand{ + GroupingSets: groupingSetsPB, + } + executorID := "" + if storeType == kv.TiFlash { + var err error + expand.Child, err = p.children[0].ToPB(ctx, storeType) + if err != nil { + return nil, errors.Trace(err) + } + executorID = p.ExplainID().String() + } + return &tipb.Executor{Tp: tipb.ExecType_TypeExpand, Expand: expand, ExecutorId: &executorID}, nil +} + // ToPB implements PhysicalPlan ToPB interface. 
func (p *PhysicalHashAgg) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { sc := ctx.GetSessionVars().StmtCtx diff --git a/planner/core/resolve_indices.go b/planner/core/resolve_indices.go index b602c46a78bd2..a38644ec07564 100644 --- a/planner/core/resolve_indices.go +++ b/planner/core/resolve_indices.go @@ -390,7 +390,7 @@ func (p *PhysicalSelection) ResolveIndices() (err error) { return nil } -// ResolveIndicesItself resolve indices for PhyicalPlan itself +// ResolveIndicesItself resolve indices for PhysicalPlan itself func (p *PhysicalExchangeSender) ResolveIndicesItself() (err error) { for i, col := range p.HashCols { colExpr, err1 := col.Col.ResolveIndices(p.children[0].Schema()) @@ -411,6 +411,31 @@ func (p *PhysicalExchangeSender) ResolveIndices() (err error) { return p.ResolveIndicesItself() } +// ResolveIndicesItself resolve indices for PhysicalPlan itself +func (p *PhysicalExpand) ResolveIndicesItself() error { + for _, gs := range p.GroupingSets { + for _, groupingExprs := range gs { + for k, groupingExpr := range groupingExprs { + gExpr, err := groupingExpr.ResolveIndices(p.children[0].Schema()) + if err != nil { + return err + } + groupingExprs[k] = gExpr + } + } + } + return nil +} + +// ResolveIndices implements Plan interface. +func (p *PhysicalExpand) ResolveIndices() (err error) { + err = p.physicalSchemaProducer.ResolveIndices() + if err != nil { + return err + } + return p.ResolveIndicesItself() +} + // ResolveIndices implements Plan interface. func (p *basePhysicalAgg) ResolveIndices() (err error) { err = p.physicalSchemaProducer.ResolveIndices() diff --git a/planner/core/stats.go b/planner/core/stats.go index 96065fccf12b5..eb10273a60189 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -567,6 +567,16 @@ func getColsNDVWithMatchedLen(cols []*expression.Column, schema *expression.Sche return NDV, 1 } +func getColsDNVWithMatchedLenFromUniqueIDs(ids []int64, schema *expression.Schema, profile *property.StatsInfo) (float64, int) { + cols := make([]*expression.Column, 0, len(ids)) + for _, id := range ids { + cols = append(cols, &expression.Column{ + UniqueID: id, + }) + } + return getColsNDVWithMatchedLen(cols, schema, profile) +} + func (p *LogicalProjection) getGroupNDVs(colGroups [][]*expression.Column, childProfile *property.StatsInfo, selfSchema *expression.Schema) []property.GroupNDV { if len(colGroups) == 0 || len(childProfile.GroupNDVs) == 0 { return nil diff --git a/planner/core/task.go b/planner/core/task.go index f111ea62e7c8a..5d1af579df163 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -1630,6 +1630,11 @@ func BuildFinalModeAggregation( byItems = append(byItems, &util.ByItems{Expr: getDistinctExpr(byItem.Expr, true), Desc: byItem.Desc}) } + if aggFunc.HasDistinct && isMPPTask && aggFunc.GroupingID > 0 { + // keep the groupingID as it was, otherwise the new split final aggregate's ganna lost its groupingID info. + finalAggFunc.GroupingID = aggFunc.GroupingID + } + finalAggFunc.OrderByItems = byItems finalAggFunc.HasDistinct = aggFunc.HasDistinct // In logical optimize phase, the Agg->PartitionUnion->TableReader may become @@ -1649,12 +1654,14 @@ func BuildFinalModeAggregation( return } if aggregation.NeedCount(finalAggFunc.Name) { + // only Avg and Count need count if isMPPTask && finalAggFunc.Name == ast.AggFuncCount { // For MPP Task, the final count() is changed to sum(). 
// Note: MPP mode does not run avg() directly, instead, avg() -> sum()/(case when count() = 0 then 1 else count() end), // so we do not process it here. finalAggFunc.Name = ast.AggFuncSum } else { + // avg branch ft := types.NewFieldType(mysql.TypeLonglong) ft.SetFlen(21) ft.SetCharset(charset.CharsetBin) @@ -1713,6 +1720,7 @@ func BuildFinalModeAggregation( args = append(args, aggFunc.Args[len(aggFunc.Args)-1]) } } else { + // other agg desc just split into two parts partialFuncDesc := aggFunc.Clone() partial.AggFuncs = append(partial.AggFuncs, partialFuncDesc) if aggFunc.Name == ast.AggFuncFirstRow { @@ -1787,6 +1795,7 @@ func (p *basePhysicalAgg) convertAvgForMPP() *PhysicalProjection { divide.(*expression.ScalarFunction).RetType = p.schema.Columns[i].RetType exprs = append(exprs, divide) } else { + // other non-avg agg use the old schema as it did. newAggFuncs = append(newAggFuncs, aggFunc) newSchema.Append(p.schema.Columns[i]) exprs = append(exprs, p.schema.Columns[i]) @@ -1865,11 +1874,89 @@ func (p *basePhysicalAgg) newPartialAggregate(copTaskType kv.StoreType, isMPPTas MppRunMode: p.MppRunMode, }.initForHash(p.ctx, p.stats, p.blockOffset, prop) finalAgg.schema = finalPref.Schema + // partialAgg and finalAgg use the same ref of stats return partialAgg, finalAgg } -// canUse3StageDistinctAgg returns true if this agg can use 3 stage for distinct aggregation -func (p *basePhysicalAgg) canUse3StageDistinctAgg() bool { +func (p *basePhysicalAgg) scale3StageForDistinctAgg() (bool, expression.GroupingSets) { + if p.canUse3Stage4SingleDistinctAgg() { + return true, nil + } + return p.canUse3Stage4MultiDistinctAgg() +} + +// canUse3Stage4MultiDistinctAgg returns true if this agg can use 3 stage for multi distinct aggregation +func (p *basePhysicalAgg) canUse3Stage4MultiDistinctAgg() (can bool, gss expression.GroupingSets) { + if !p.ctx.GetSessionVars().Enable3StageDistinctAgg || !p.ctx.GetSessionVars().Enable3StageMultiDistinctAgg || len(p.GroupByItems) > 0 { + return false, nil + } + defer func() { + // some clean work. + if !can { + for _, fun := range p.AggFuncs { + fun.GroupingID = 0 + } + } + }() + // groupingSets is alias of []GroupingSet, the below equal to = make([]GroupingSet, 0, 2) + groupingSets := make(expression.GroupingSets, 0, 2) + for _, fun := range p.AggFuncs { + if fun.HasDistinct { + if fun.Name != ast.AggFuncCount { + // now only for multi count(distinct x) + return false, nil + } + for _, arg := range fun.Args { + // bail out when args are not simple column, see GitHub issue #35417 + if _, ok := arg.(*expression.Column); !ok { + return false, nil + } + } + // here it's a valid count distinct agg with normal column args, collecting its distinct expr. + groupingSets = append(groupingSets, expression.GroupingSet{fun.Args}) + // groupingID now is the offset of target grouping in GroupingSets. + // todo: it may be changed after grouping set merge in the future. + fun.GroupingID = len(groupingSets) + } else if len(fun.Args) > 1 { + return false, nil + } + // banned group_concat(x order by y) + if len(fun.OrderByItems) > 0 || fun.Mode != aggregation.CompleteMode { + return false, nil + } + } + compressed := groupingSets.Merge() + if len(compressed) != len(groupingSets) { + p.SCtx().GetSessionVars().StmtCtx.AppendWarning(errors.Errorf("Some grouping sets should be merged")) + // todo arenatlx: some grouping set should be merged which is not supported by now temporarily. + return false, nil + } + if groupingSets.NeedCloneColumn() { + // todo: column clone haven't implemented. 
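To make the GroupingID bookkeeping concrete, here is a minimal standalone sketch of how each count(distinct x) contributes one grouping set and records its 1-based offset as GroupingID, for count(distinct a), count(distinct b), count(c). The aggFunc type below is a plain stand-in, not the planner's AggFuncDesc or GroupingSets.

// Standalone illustration of the GroupingID assignment; simplified types only.
package main

import "fmt"

type aggFunc struct {
	name       string
	distinct   bool
	args       []string // column names, assuming all args are simple columns
	groupingID int      // 0 means "not set", real IDs start from 1
}

func main() {
	aggs := []*aggFunc{
		{name: "count", distinct: true, args: []string{"a"}},
		{name: "count", distinct: true, args: []string{"b"}},
		{name: "count", distinct: false, args: []string{"c"}},
	}
	var groupingSets [][]string
	for _, f := range aggs {
		if f.distinct {
			// every distinct count spawns one grouping set; its 1-based
			// offset becomes the GroupingID of that aggregate.
			groupingSets = append(groupingSets, f.args)
			f.groupingID = len(groupingSets)
		}
	}
	// non-distinct aggregates are attached to an existing set later
	// (TargetOne in the real code); here we only show the layout.
	fmt.Println(groupingSets) // [[a] [b]]
	for _, f := range aggs {
		fmt.Println(f.name, f.args, f.groupingID) // count [a] 1, count [b] 2, count [c] 0
	}
}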
+ return false, nil + } + if len(groupingSets) > 1 { + // fill the grouping ID for normal agg. + for _, fun := range p.AggFuncs { + if fun.GroupingID == 0 { + // the grouping ID hasn't set. find the targeting grouping set. + groupingSetOffset := groupingSets.TargetOne(fun.Args) + if groupingSetOffset == -1 { + // todo: if we couldn't find a existed current valid group layout, we need to copy the column out from being filled with null value. + p.SCtx().GetSessionVars().StmtCtx.AppendWarning(errors.Errorf("couldn't find a proper group set for normal agg")) + return false, nil + } + // starting with 1 + fun.GroupingID = groupingSetOffset + 1 + } + } + return true, groupingSets + } + return false, nil +} + +// canUse3Stage4SingleDistinctAgg returns true if this agg can use 3 stage for distinct aggregation +func (p *basePhysicalAgg) canUse3Stage4SingleDistinctAgg() bool { num := 0 if !p.ctx.GetSessionVars().Enable3StageDistinctAgg || len(p.GroupByItems) > 0 { return false @@ -2047,6 +2134,305 @@ func (p *PhysicalHashAgg) attach2TaskForMpp1Phase(mpp *mppTask) task { return mpp } +// scaleStats4GroupingSets scale the derived stats because the lower source has been expanded. +// +// parent OP <- logicalAgg <- children OP (derived stats) +// | +// v +// parent OP <- physicalAgg <- children OP (stats used) +// | +// +----------+----------+----------+ +// Final Mid Partial Expand +// +// physical agg stats is reasonable from the whole, because expand operator is designed to facilitate +// the Mid and Partial Agg, which means when leaving the Final, its output rowcount could be exactly +// the same as what it derived(estimated) before entering physical optimization phase. +// +// From the cost model correctness, for these inserted sub-agg and even expand operator, we should +// recompute the stats for them particularly. +// +// for example: grouping sets {},{}, group by items {a,b,c,groupingID} +// after expand: +// +// a, b, c, groupingID +// ... null c 1 ---+ +// ... null c 1 +------- replica group 1 +// ... null c 1 ---+ +// null ... c 2 ---+ +// null ... c 2 +------- replica group 2 +// null ... c 2 ---+ +// +// since null value is seen the same when grouping data (groupingID in one replica is always the same): +// - so the num of group in replica 1 is equal to NDV(a,c) +// - so the num of group in replica 2 is equal to NDV(b,c) +// +// in a summary, the total num of group of all replica is equal to = Σ:NDV(each-grouping-set-cols, normal-group-cols) +func (p *PhysicalHashAgg) scaleStats4GroupingSets(groupingSets expression.GroupingSets, groupingIDCol *expression.Column, + childSchema *expression.Schema, childStats *property.StatsInfo) { + idSets := groupingSets.AllSetsColIDs() + normalGbyCols := make([]*expression.Column, 0, len(p.GroupByItems)) + for _, gbyExpr := range p.GroupByItems { + cols := expression.ExtractColumns(gbyExpr) + for _, col := range cols { + if !idSets.Has(int(col.UniqueID)) && col.UniqueID != groupingIDCol.UniqueID { + normalGbyCols = append(normalGbyCols, col) + } + } + } + sumNDV := float64(0) + for _, groupingSet := range groupingSets { + // for every grouping set, pick its cols out, and combine with normal group cols to get the NDV. + groupingSetCols := groupingSet.ExtractCols() + groupingSetCols = append(groupingSetCols, normalGbyCols...) 
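The estimate described in the comment above reads as a plain sum: total groups = Σ NDV(grouping-set cols, normal group-by cols). A rough standalone sketch under that reading, with a made-up ndv lookup in place of the real statistics code:

// Schematic Σ NDV computation; ndv is a fake lookup, not TiDB's statistics API.
package main

import (
	"fmt"
	"strings"
)

func ndv(cols []string) float64 {
	// placeholder: pretend we already know the distinct count of each column group.
	fake := map[string]float64{"a,c": 100, "b,c": 40}
	return fake[strings.Join(cols, ",")]
}

func main() {
	groupingSets := [][]string{{"a"}, {"b"}}
	normalGroupBy := []string{"c"}
	sumNDV := 0.0
	for _, gs := range groupingSets {
		// each replica group contributes NDV(grouping-set cols + normal group-by cols).
		cols := append(append([]string{}, gs...), normalGroupBy...)
		sumNDV += ndv(cols)
	}
	fmt.Println(sumNDV) // 140: estimated row count coming out of the partial agg
}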
+ NDV, _ := getColsNDVWithMatchedLen(groupingSetCols, childSchema, childStats) + sumNDV += NDV + } + // After group operator, all same rows are grouped into one row, that means all + // change the sub-agg's stats + if p.stats != nil { + // equivalence to a new cloned one. (cause finalAgg and partialAgg may share a same copy of stats) + cpStats := p.stats.Scale(1) + cpStats.RowCount = sumNDV + // We cannot estimate the ColNDVs for every output, so we use a conservative strategy. + for k := range cpStats.ColNDVs { + cpStats.ColNDVs[k] = sumNDV + } + // for old groupNDV, if it's containing one more grouping set cols, just plus the NDV where the col is excluded. + // for example: old grouping NDV(b,c), where b is in grouping sets {},{}. so when countering the new NDV: + // cases: + // new grouping NDV(b,c) := old NDV(b,c) + NDV(null, c) = old NDV(b,c) + DNV(c). + // new grouping NDV(a,b,c) := old NDV(a,b,c) + NDV(null,b,c) + NDV(a,null,c) = old NDV(a,b,c) + NDV(b,c) + NDV(a,c) + allGroupingSetsIDs := groupingSets.AllSetsColIDs() + for _, oneGNDV := range cpStats.GroupNDVs { + newGNDV := oneGNDV.NDV + intersectionIDs := make([]int64, 0, len(oneGNDV.Cols)) + for i, id := range oneGNDV.Cols { + if allGroupingSetsIDs.Has(int(id)) { + // when meet an id in grouping sets, skip it (cause its null) and append the rest ids to count the incrementNDV. + beforeLen := len(intersectionIDs) + intersectionIDs = append(intersectionIDs, oneGNDV.Cols[i:]...) + incrementNDV, _ := getColsDNVWithMatchedLenFromUniqueIDs(intersectionIDs, childSchema, childStats) + newGNDV += incrementNDV + // restore the before intersectionIDs slice. + intersectionIDs = intersectionIDs[:beforeLen] + } + // insert ids one by one. + intersectionIDs = append(intersectionIDs, id) + } + oneGNDV.NDV = newGNDV + } + p.stats = cpStats + } +} + +// adjust3StagePhaseAgg generate 3 stage aggregation for single/multi count distinct if applicable. +// +// select count(distinct a), count(b) from foo +// +// will generate plan: +// +// HashAgg sum(#1), sum(#2) -> final agg +// +- Exchange Passthrough +// +- HashAgg count(distinct a) #1, sum(#3) #2 -> middle agg +// +- Exchange HashPartition by a +// +- HashAgg count(b) #3, group by a -> partial agg +// +- TableScan foo +// +// select count(distinct a), count(distinct b), count(c) from foo +// +// will generate plan: +// +// HashAgg sum(#1), sum(#2), sum(#3) -> final agg +// +- Exchange Passthrough +// +- HashAgg count(distinct a) #1, count(distinct b) #2, sum(#4) #3 -> middle agg +// +- Exchange HashPartition by a,b,groupingID +// +- HashAgg count(c) #4, group by a,b,groupingID -> partial agg +// +- Expand {}, {} -> expand +// +- TableScan foo +func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan, canUse3StageAgg bool, + groupingSets expression.GroupingSets, mpp *mppTask) (final, mid, part, proj4Part PhysicalPlan, _ error) { + if !(partialAgg != nil && canUse3StageAgg) { + // quick path: return the original finalAgg and partiAgg. + return finalAgg, nil, partialAgg, nil, nil + } + if len(groupingSets) == 0 { + // single distinct agg mode. + clonedAgg, err := finalAgg.Clone() + if err != nil { + return nil, nil, nil, nil, err + } + + // step1: adjust middle agg. 
+ middleHashAgg := clonedAgg.(*PhysicalHashAgg) + distinctPos := 0 + middleSchema := expression.NewSchema() + schemaMap := make(map[int64]*expression.Column, len(middleHashAgg.AggFuncs)) + for i, fun := range middleHashAgg.AggFuncs { + col := &expression.Column{ + UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: fun.RetTp, + } + if fun.HasDistinct { + distinctPos = i + fun.Mode = aggregation.Partial1Mode + } else { + fun.Mode = aggregation.Partial2Mode + originalCol := fun.Args[0].(*expression.Column) + // mapping the current partial output column with the agg origin arg column. (final agg arg should use this one) + schemaMap[originalCol.UniqueID] = col + } + middleSchema.Append(col) + } + middleHashAgg.schema = middleSchema + + // step2: adjust final agg. + finalHashAgg := finalAgg.(*PhysicalHashAgg) + finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs)) + for i, fun := range finalHashAgg.AggFuncs { + newArgs := make([]expression.Expression, 0, 1) + if distinctPos == i { + // change count(distinct) to sum() + fun.Name = ast.AggFuncSum + fun.HasDistinct = false + newArgs = append(newArgs, middleSchema.Columns[i]) + } else { + for _, arg := range fun.Args { + newCol, err := arg.RemapColumn(schemaMap) + if err != nil { + return nil, nil, nil, nil, err + } + newArgs = append(newArgs, newCol) + } + } + fun.Mode = aggregation.FinalMode + fun.Args = newArgs + finalAggDescs = append(finalAggDescs, fun) + } + finalHashAgg.AggFuncs = finalAggDescs + // partialAgg is im-mutated from args. + return finalHashAgg, middleHashAgg, partialAgg, nil, nil + } + // multi distinct agg mode, having grouping sets. + // set the default expression to constant 1 for the convenience to choose default group set data. + var groupingIDCol expression.Expression + // enforce Expand operator above the children. + // physical plan is enumerated without children from itself, use mpp subtree instead p.children. + // scale(len(groupingSets)) will change the NDV, while Expand doesn't change the NDV and groupNDV. + stats := mpp.p.statsInfo().Scale(float64(1)) + stats.RowCount = stats.RowCount * float64(len(groupingSets)) + physicalExpand := PhysicalExpand{ + GroupingSets: groupingSets, + }.Init(p.ctx, stats, mpp.p.SelectBlockOffset()) + // generate a new column as groupingID to identify which this row is targeting for. + tp := types.NewFieldType(mysql.TypeLonglong) + tp.SetFlag(mysql.UnsignedFlag | mysql.NotNullFlag) + groupingIDCol = &expression.Column{ + UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: tp, + } + // append the physical expand op with groupingID column. + physicalExpand.SetSchema(mpp.p.Schema().Clone()) + physicalExpand.schema.Append(groupingIDCol.(*expression.Column)) + physicalExpand.GroupingIDCol = groupingIDCol.(*expression.Column) + // attach PhysicalExpand to mpp + attachPlan2Task(physicalExpand, mpp) + + // having group sets + clonedAgg, err := finalAgg.Clone() + if err != nil { + return nil, nil, nil, nil, err + } + cloneHashAgg := clonedAgg.(*PhysicalHashAgg) + // Clone(), it will share same base-plan elements from the finalAgg, including id,tp,stats. Make a new one here. + cloneHashAgg.basePlan = newBasePlan(cloneHashAgg.ctx, cloneHashAgg.tp, cloneHashAgg.blockOffset) + cloneHashAgg.stats = finalAgg.Stats() // reuse the final agg stats here. + + // step1: adjust partial agg, for normal agg here, adjust it to target for specified group data. 
+ // Since we may substitute the first arg of normal agg with case-when expression here, append a + // customized proj here rather than depending on postOptimize to insert a blunt one for us. + // + // proj4Partial output all the base col from lower op + caseWhen proj cols. + proj4Partial := new(PhysicalProjection).Init(p.ctx, mpp.p.statsInfo(), mpp.p.SelectBlockOffset()) + for _, col := range mpp.p.Schema().Columns { + proj4Partial.Exprs = append(proj4Partial.Exprs, col) + } + proj4Partial.SetSchema(mpp.p.Schema().Clone()) + + partialHashAgg := partialAgg.(*PhysicalHashAgg) + partialHashAgg.GroupByItems = append(partialHashAgg.GroupByItems, groupingIDCol) + partialHashAgg.schema.Append(groupingIDCol.(*expression.Column)) + // it will create a new stats for partial agg. + partialHashAgg.scaleStats4GroupingSets(groupingSets, groupingIDCol.(*expression.Column), proj4Partial.Schema(), proj4Partial.statsInfo()) + for _, fun := range partialHashAgg.AggFuncs { + if !fun.HasDistinct { + // for normal agg phase1, we should also modify them to target for specified group data. + // Expr = (case when groupingID = targeted_groupingID then arg else null end) + eqExpr := expression.NewFunctionInternal(p.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), groupingIDCol, expression.NewUInt64Const(fun.GroupingID)) + caseWhen := expression.NewFunctionInternal(p.ctx, ast.Case, fun.Args[0].GetType(), eqExpr, fun.Args[0], expression.NewNull()) + caseWhenProjCol := &expression.Column{ + UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: fun.Args[0].GetType(), + } + proj4Partial.Exprs = append(proj4Partial.Exprs, caseWhen) + proj4Partial.Schema().Append(caseWhenProjCol) + fun.Args[0] = caseWhenProjCol + } + } + + // step2: adjust middle agg + // middleHashAgg shared the same stats with the final agg does. + middleHashAgg := cloneHashAgg + middleSchema := expression.NewSchema() + schemaMap := make(map[int64]*expression.Column, len(middleHashAgg.AggFuncs)) + for _, fun := range middleHashAgg.AggFuncs { + col := &expression.Column{ + UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: fun.RetTp, + } + if fun.HasDistinct { + // let count distinct agg aggregate on whole-scope data rather using case-when expr to target on specified group. (agg null strict attribute) + fun.Mode = aggregation.Partial1Mode + } else { + fun.Mode = aggregation.Partial2Mode + originalCol := fun.Args[0].(*expression.Column) + // record the origin column unique id down before change it to be case when expr. + // mapping the current partial output column with the agg origin arg column. (final agg arg should use this one) + schemaMap[originalCol.UniqueID] = col + } + middleSchema.Append(col) + } + middleHashAgg.schema = middleSchema + + // step3: adjust final agg + finalHashAgg := finalAgg.(*PhysicalHashAgg) + finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs)) + for i, fun := range finalHashAgg.AggFuncs { + newArgs := make([]expression.Expression, 0, 1) + if fun.HasDistinct { + // change count(distinct) agg to sum() + fun.Name = ast.AggFuncSum + fun.HasDistinct = false + // count(distinct a,b) -> become a single partial result col. + newArgs = append(newArgs, middleSchema.Columns[i]) + } else { + // remap final normal agg args to be output schema of middle normal agg. 
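The step1 rewrite above turns every non-distinct aggregate argument into case when groupingID = targeted_groupingID then arg else null end, so the aggregate only sees one replica of the expanded data. A small standalone sketch of the row-level effect, using plain values rather than expression trees:

// Row-level effect of the case-when rewrite; plain Go, not the expression package.
package main

import "fmt"

func caseWhen(groupingID, targetID int64, arg interface{}) interface{} {
	if groupingID == targetID {
		return arg
	}
	return nil // null: this row belongs to another replica group
}

func main() {
	// count(c) was attached to grouping set #1, so it only counts rows whose
	// groupingID is 1; the duplicated replicas contribute null and are skipped.
	rows := []struct {
		c          interface{}
		groupingID int64
	}{{10, 1}, {10, 2}, {20, 1}, {20, 2}}
	cnt := 0
	for _, r := range rows {
		if caseWhen(r.groupingID, 1, r.c) != nil {
			cnt++
		}
	}
	fmt.Println(cnt) // 2: one count per original row, not per replica
}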
+ for _, arg := range fun.Args { + newCol, err := arg.RemapColumn(schemaMap) + if err != nil { + return nil, nil, nil, nil, err + } + newArgs = append(newArgs, newCol) + } + } + fun.Mode = aggregation.FinalMode + fun.Args = newArgs + fun.GroupingID = 0 + finalAggDescs = append(finalAggDescs, fun) + } + finalHashAgg.AggFuncs = finalAggDescs + return finalHashAgg, middleHashAgg, partialHashAgg, proj4Partial, nil +} + func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task { t := tasks[0].copy() mpp, ok := t.(*mppTask) @@ -2108,84 +2494,35 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task { case MppScalar: prop := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.SinglePartitionType} if !mpp.needEnforceExchanger(prop) { + // On the one hand: when the low layer already satisfied the single partition layout, just do the all agg computation in the single node. return p.attach2TaskForMpp1Phase(mpp) } + // On the other hand: try to split the mppScalar agg into multi phases agg **down** to multi nodes since data already distributed across nodes. // we have to check it before the content of p has been modified - canUse3StageAgg := p.canUse3StageDistinctAgg() + canUse3StageAgg, groupingSets := p.scale3StageForDistinctAgg() proj := p.convertAvgForMPP() partialAgg, finalAgg := p.newPartialAggregate(kv.TiFlash, true) if finalAgg == nil { return invalidTask } - // generate 3 stage aggregation for single count distinct if applicable. - // select count(distinct a), count(b) from foo - // will generate plan: - // HashAgg sum(#1), sum(#2) -> final agg - // +- Exchange Passthrough - // +- HashAgg count(distinct a) #1, sum(#3) #2 -> middle agg - // +- Exchange HashPartition by a - // +- HashAgg count(b) #3, group by a -> partial agg - // +- TableScan foo - var middleAgg *PhysicalHashAgg = nil - if partialAgg != nil && canUse3StageAgg { - clonedAgg, err := finalAgg.Clone() - if err != nil { - return invalidTask - } - middleAgg = clonedAgg.(*PhysicalHashAgg) - distinctPos := 0 - middleSchema := expression.NewSchema() - schemaMap := make(map[int64]*expression.Column, len(middleAgg.AggFuncs)) - for i, fun := range middleAgg.AggFuncs { - col := &expression.Column{ - UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: fun.RetTp, - } - if fun.HasDistinct { - distinctPos = i - fun.Mode = aggregation.Partial1Mode - } else { - fun.Mode = aggregation.Partial2Mode - originalCol := fun.Args[0].(*expression.Column) - schemaMap[originalCol.UniqueID] = col - } - middleSchema.Append(col) - } - middleAgg.schema = middleSchema - - finalHashAgg := finalAgg.(*PhysicalHashAgg) - finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs)) - for i, fun := range finalHashAgg.AggFuncs { - newArgs := make([]expression.Expression, 0, 1) - if distinctPos == i { - // change count(distinct) to sum() - fun.Name = ast.AggFuncSum - fun.HasDistinct = false - newArgs = append(newArgs, middleSchema.Columns[i]) - } else { - for _, arg := range fun.Args { - newCol, err := arg.RemapColumn(schemaMap) - if err != nil { - return invalidTask - } - newArgs = append(newArgs, newCol) - } - } - fun.Mode = aggregation.FinalMode - fun.Args = newArgs - finalAggDescs = append(finalAggDescs, fun) - } - finalHashAgg.AggFuncs = finalAggDescs + final, middle, partial, proj4Partial, err := p.adjust3StagePhaseAgg(partialAgg, finalAgg, canUse3StageAgg, groupingSets, mpp) + if err != nil { + return invalidTask + } + + // partial agg proj would be 
null if one scalar agg cannot run in two-phase mode + if proj4Partial != nil { + attachPlan2Task(proj4Partial, mpp) } // partial agg would be null if one scalar agg cannot run in two-phase mode - if partialAgg != nil { - attachPlan2Task(partialAgg, mpp) + if partial != nil { + attachPlan2Task(partial, mpp) } - if middleAgg != nil && canUse3StageAgg { - items := partialAgg.(*PhysicalHashAgg).GroupByItems + if middle != nil && canUse3StageAgg { + items := partial.(*PhysicalHashAgg).GroupByItems partitionCols := make([]*property.MPPPartitionColumn, 0, len(items)) for _, expr := range items { col, ok := expr.(*expression.Column) @@ -2198,14 +2535,15 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task { }) } - prop := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: partitionCols} - newMpp := mpp.enforceExchanger(prop) - attachPlan2Task(middleAgg, newMpp) + exProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: partitionCols} + newMpp := mpp.enforceExchanger(exProp) + attachPlan2Task(middle, newMpp) mpp = newMpp } + // prop here still be the first generated single-partition requirement. newMpp := mpp.enforceExchanger(prop) - attachPlan2Task(finalAgg, newMpp) + attachPlan2Task(final, newMpp) if proj == nil { proj = PhysicalProjection{ Exprs: make([]expression.Expression, 0, len(p.Schema().Columns)), diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index c3babeb0e908f..3faa2fab07358 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -778,6 +778,9 @@ type SessionVars struct { // Enable3StageDistinctAgg indicates whether to allow 3 stage distinct aggregate Enable3StageDistinctAgg bool + // Enable3StageMultiDistinctAgg indicates whether to allow 3 stage multi distinct aggregate + Enable3StageMultiDistinctAgg bool + // MultiStatementMode permits incorrect client library usage. Not recommended to be turned on. 
MultiStatementMode int diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index f3ccd7dd534b4..a6a50c7c2ffdc 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -181,6 +181,10 @@ var defaultSysVars = []*SysVar{ s.Enable3StageDistinctAgg = TiDBOptOn(val) return nil }}, + {Scope: ScopeGlobal | ScopeSession, Name: TiDBOptEnable3StageMultiDistinctAgg, Value: BoolToOnOff(DefTiDB3StageMultiDistinctAgg), Type: TypeBool, SetSession: func(s *SessionVars, val string) error { + s.Enable3StageMultiDistinctAgg = TiDBOptOn(val) + return nil + }}, {Scope: ScopeSession, Name: TiDBOptWriteRowID, Value: BoolToOnOff(DefOptWriteRowID), Type: TypeBool, SetSession: func(s *SessionVars, val string) error { s.AllowWriteRowID = TiDBOptOn(val) return nil diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index ccb9aa13af366..04698c4e3d0ed 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -63,6 +63,9 @@ const ( // TiDBOpt3StageDistinctAgg is used to indicate whether to plan and execute the distinct agg in 3 stages TiDBOpt3StageDistinctAgg = "tidb_opt_three_stage_distinct_agg" + // TiDBOptEnable3StageMultiDistinctAgg is used to indicate whether to plan and execute the multi distinct agg in 3 stages + TiDBOptEnable3StageMultiDistinctAgg = "tidb_opt_enable_three_stage_multi_distinct_agg" + // TiDBBCJThresholdSize is used to limit the size of small table for mpp broadcast join. // Its unit is bytes, if the size of small table is larger than it, we will not use bcj. TiDBBCJThresholdSize = "tidb_broadcast_join_threshold_size" @@ -1106,6 +1109,7 @@ const ( DefTiDBRemoveOrderbyInSubquery = false DefTiDBSkewDistinctAgg = false DefTiDB3StageDistinctAgg = true + DefTiDB3StageMultiDistinctAgg = false DefTiDBReadStaleness = 0 DefTiDBGCMaxWaitTime = 24 * 60 * 60 DefMaxAllowedPacket uint64 = 67108864 diff --git a/store/copr/mpp.go b/store/copr/mpp.go index 1c4ef6a40e85a..3b4efc96b2c2d 100644 --- a/store/copr/mpp.go +++ b/store/copr/mpp.go @@ -350,7 +350,7 @@ func (m *mppIterator) handleDispatchReq(ctx context.Context, bo *Backoffer, req if !req.IsRoot { return } - + // only root task should establish a stream conn with tiFlash to receive result. m.establishMPPConns(bo, req, taskMeta) } diff --git a/store/mockstore/unistore/cophandler/mpp.go b/store/mockstore/unistore/cophandler/mpp.go index 1cf0746e861dd..adeea777871c7 100644 --- a/store/mockstore/unistore/cophandler/mpp.go +++ b/store/mockstore/unistore/cophandler/mpp.go @@ -200,6 +200,63 @@ func (b *mppExecBuilder) buildLimit(pb *tipb.Limit) (*limitExec, error) { return exec, nil } +func (b *mppExecBuilder) buildExpand(pb *tipb.Expand) (mppExec, error) { + child, err := b.buildMPPExecutor(pb.Child) + if err != nil { + return nil, err + } + exec := &expandExec{ + baseMPPExec: baseMPPExec{sc: b.sc, mppCtx: b.mppCtx, children: []mppExec{child}}, + } + + childFieldTypes := child.getFieldTypes() + // convert the grouping sets. 
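The new tidb_opt_enable_three_stage_multi_distinct_agg switch registered above can be exercised with a testkit-style snippet. This is only a sketch assuming the standard mock-store helpers; the test name and table are made up and it is not part of this change.

// Sketch of enabling the switch and running a multi-distinct query via testkit.
package core_test

import (
	"testing"

	"github.com/pingcap/tidb/testkit"
)

func TestThreeStageMultiDistinctAggSketch(t *testing.T) {
	store := testkit.CreateMockStore(t)
	tk := testkit.NewTestKit(t, store)
	tk.MustExec("use test")
	tk.MustExec("create table foo(a int, b int, c int)")
	tk.MustExec("set @@session.tidb_opt_enable_three_stage_multi_distinct_agg = 1")
	// with the switch on and an MPP-capable store, the planner may pick the
	// Expand + 3 stage aggregation shape described in adjust3StagePhaseAgg.
	tk.MustQuery("select count(distinct a), count(distinct b), count(c) from foo")
}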
+ tidbGss := expression.GroupingSets{} + for _, gs := range pb.GroupingSets { + tidbGs := expression.GroupingSet{} + for _, groupingExprs := range gs.GroupingExprs { + tidbGroupingExprs, err := convertToExprs(b.sc, childFieldTypes, groupingExprs.GroupingExpr) + if err != nil { + return nil, err + } + tidbGs = append(tidbGs, tidbGroupingExprs) + } + tidbGss = append(tidbGss, tidbGs) + } + exec.groupingSets = tidbGss + inGroupingSetMap := make(map[int]struct{}, len(exec.groupingSets)) + for _, gs := range exec.groupingSets { + // for every grouping set, collect column offsets under this grouping set. + for _, groupingExprs := range gs { + for _, groupingExpr := range groupingExprs { + col, ok := groupingExpr.(*expression.Column) + if !ok { + return nil, errors.New("grouping set expr is not column ref") + } + inGroupingSetMap[col.Index] = struct{}{} + } + } + } + mutatedFieldTypes := make([]*types.FieldType, 0, len(childFieldTypes)) + // change the field types return from children tobe nullable. + for offset, f := range childFieldTypes { + cf := f.Clone() + if _, ok := inGroupingSetMap[offset]; ok { + // remove the not null flag, make it nullable. + cf.SetFlag(cf.GetFlag() & ^mysql.NotNullFlag) + } + mutatedFieldTypes = append(mutatedFieldTypes, cf) + } + + // adding groupingID uint64|not-null as last one field types. + groupingIDFieldType := types.NewFieldType(mysql.TypeLonglong) + groupingIDFieldType.SetFlag(mysql.NotNullFlag | mysql.UnsignedFlag) + mutatedFieldTypes = append(mutatedFieldTypes, groupingIDFieldType) + + exec.fieldTypes = mutatedFieldTypes + return exec, nil +} + func (b *mppExecBuilder) buildTopN(pb *tipb.TopN) (mppExec, error) { child, err := b.buildMPPExecutor(pb.Child) if err != nil { @@ -253,18 +310,19 @@ func (b *mppExecBuilder) buildMPPExchangeSender(pb *tipb.ExchangeSender) (*exchS exchangeTp: pb.Tp, } if pb.Tp == tipb.ExchangeType_Hash { - if len(pb.PartitionKeys) != 1 { - return nil, errors.New("The number of hash key must be 1") - } - expr, err := expression.PBToExpr(pb.PartitionKeys[0], child.getFieldTypes(), b.sc) - if err != nil { - return nil, errors.Trace(err) - } - col, ok := expr.(*expression.Column) - if !ok { - return nil, errors.New("Hash key must be column type") + // remove the limitation of len(pb.PartitionKeys) == 1 + for _, partitionKey := range pb.PartitionKeys { + expr, err := expression.PBToExpr(partitionKey, child.getFieldTypes(), b.sc) + if err != nil { + return nil, errors.Trace(err) + } + col, ok := expr.(*expression.Column) + if !ok { + return nil, errors.New("Hash key must be column type") + } + e.hashKeyOffsets = append(e.hashKeyOffsets, col.Index) + e.hashKeyTypes = append(e.hashKeyTypes, e.fieldTypes[col.Index]) } - e.hashKeyOffset = col.Index } for _, taskMeta := range pb.EncodedTaskMeta { @@ -493,6 +551,8 @@ func (b *mppExecBuilder) buildMPPExecutor(exec *tipb.Executor) (mppExec, error) case tipb.ExecType_TypePartitionTableScan: ts := exec.PartitionTableScan return b.buildMPPPartitionTableScan(ts) + case tipb.ExecType_TypeExpand: + return b.buildExpand(exec.Expand) default: return nil, errors.Errorf(ErrExecutorNotSupportedMsg + exec.Tp.String()) } diff --git a/store/mockstore/unistore/cophandler/mpp_exec.go b/store/mockstore/unistore/cophandler/mpp_exec.go index d0d0a71d5b85a..a9233d423b7ef 100644 --- a/store/mockstore/unistore/cophandler/mpp_exec.go +++ b/store/mockstore/unistore/cophandler/mpp_exec.go @@ -17,6 +17,7 @@ package cophandler import ( "bytes" "encoding/binary" + "hash/fnv" "io" "math" "sort" @@ -415,6 +416,110 @@ func (e 
*limitExec) next() (*chunk.Chunk, error) { return chk, nil } +// expandExec is the basic mock logic for expand executor in uniStore, +// with which we can validate the mpp plan correctness from explain result and result returned. +type expandExec struct { + baseMPPExec + + lastNum int + lastChunk *chunk.Chunk + groupingSets expression.GroupingSets + groupingSetOffsetMap []map[int]struct{} + groupingSetScope map[int]struct{} +} + +func (e *expandExec) open() error { + if err := e.children[0].open(); err != nil { + return err + } + // building the quick finding map + e.groupingSetOffsetMap = make([]map[int]struct{}, 0, len(e.groupingSets)) + e.groupingSetScope = make(map[int]struct{}, len(e.groupingSets)) + for _, gs := range e.groupingSets { + tmp := make(map[int]struct{}, len(gs)) + // for every grouping set, collect column offsets under this grouping set. + for _, groupingExprs := range gs { + for _, groupingExpr := range groupingExprs { + col, ok := groupingExpr.(*expression.Column) + if !ok { + return errors.New("grouping set expr is not column ref") + } + tmp[col.Index] = struct{}{} + e.groupingSetScope[col.Index] = struct{}{} + } + } + e.groupingSetOffsetMap = append(e.groupingSetOffsetMap, tmp) + } + return nil +} + +func (e *expandExec) isGroupingCol(index int) bool { + if _, ok := e.groupingSetScope[index]; ok { + return true + } + return false +} + +func (e *expandExec) next() (*chunk.Chunk, error) { + var ( + err error + ) + if e.groupingSets.IsEmpty() { + return e.children[0].next() + } + resChk := chunk.NewChunkWithCapacity(e.getFieldTypes(), DefaultBatchSize) + for { + if e.lastChunk == nil || e.lastChunk.NumRows() == e.lastNum { + // fetch one chunk from children. + e.lastChunk, err = e.children[0].next() + if err != nil { + return nil, err + } + e.lastNum = 0 + if e.lastChunk == nil || e.lastChunk.NumRows() == 0 { + break + } + e.execSummary.updateOnlyRows(e.lastChunk.NumRows()) + } + numRows := e.lastChunk.NumRows() + numGroupingOffset := len(e.groupingSets) + + for i := e.lastNum; i < numRows; i++ { + row := e.lastChunk.GetRow(i) + e.lastNum++ + // for every grouping set, expand the base row N times. + for g := 0; g < numGroupingOffset; g++ { + repeatRow := chunk.MutRowFromTypes(e.fieldTypes) + // for every targeted grouping set: + // 1: for every column in this grouping set, setting them as it was. + // 2: for every column in other target grouping set, setting them as null. + // 3: for every column not in any grouping set, setting them as it was. + // * normal agg only aimed at one replica of them with groupingID = 1 + // * so we don't need to change non-related column to be nullable. + // * so we don't need to mutate the column to be null when groupingID > 1 + for datumOffset, datumType := range e.fieldTypes[:len(e.fieldTypes)-1] { + if _, ok := e.groupingSetOffsetMap[g][datumOffset]; ok { + repeatRow.SetDatum(datumOffset, row.GetDatum(datumOffset, datumType)) + } else if !e.isGroupingCol(datumOffset) { + repeatRow.SetDatum(datumOffset, row.GetDatum(datumOffset, datumType)) + } else { + repeatRow.SetDatum(datumOffset, types.NewDatum(nil)) + } + } + // the last one column should be groupingID col. + groupingID := g + 1 + repeatRow.SetDatum(len(e.fieldTypes)-1, types.NewDatum(groupingID)) + resChk.AppendRow(repeatRow.ToRow()) + } + if DefaultBatchSize-resChk.NumRows() < numGroupingOffset { + // no enough room for another repeated N rows, return this chunk immediately. 
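To summarize the mock expansion above: for grouping sets {a} and {b}, every input row is emitted once per set, the columns belonging to the other grouping sets are nulled out, and the 1-based groupingID is appended as the last column. A standalone model with plain slices standing in for chunk rows:

// Standalone model of the per-row expansion; nil plays the role of NULL and the
// appended integer is the 1-based groupingID. Not the unistore chunk code.
package main

import "fmt"

func expand(row []interface{}, groupingSets [][]int, groupingCols map[int]bool) [][]interface{} {
	var out [][]interface{}
	for g, set := range groupingSets {
		inSet := map[int]bool{}
		for _, off := range set {
			inSet[off] = true
		}
		rep := make([]interface{}, len(row)+1)
		for off, v := range row {
			switch {
			case inSet[off]:
				rep[off] = v // column of the targeted grouping set: keep it
			case !groupingCols[off]:
				rep[off] = v // column outside every grouping set: keep it
			default:
				rep[off] = nil // column of another grouping set: null it
			}
		}
		rep[len(row)] = g + 1 // groupingID
		out = append(out, rep)
	}
	return out
}

func main() {
	// schema: a(0), b(1), c(2); grouping sets {a} and {b}.
	rows := expand([]interface{}{1, 2, 3}, [][]int{{0}, {1}}, map[int]bool{0: true, 1: true})
	fmt.Println(rows) // [[1 <nil> 3 1] [<nil> 2 3 2]]
}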
+ return resChk, nil + } + } + } + return resChk, nil +} + type topNExec struct { baseMPPExec @@ -492,10 +597,11 @@ func (e *topNExec) next() (*chunk.Chunk, error) { type exchSenderExec struct { baseMPPExec - tunnels []*ExchangerTunnel - outputOffsets []uint32 - exchangeTp tipb.ExchangeType - hashKeyOffset int + tunnels []*ExchangerTunnel + outputOffsets []uint32 + exchangeTp tipb.ExchangeType + hashKeyOffsets []int + hashKeyTypes []*types.FieldType } func (e *exchSenderExec) open() error { @@ -548,15 +654,22 @@ func (e *exchSenderExec) next() (*chunk.Chunk, error) { for i := 0; i < len(e.tunnels); i++ { targetChunks = append(targetChunks, chunk.NewChunkWithCapacity(e.fieldTypes, rows)) } + hashVals := fnv.New64() + payload := make([]byte, 1) for i := 0; i < rows; i++ { row := chk.GetRow(i) - d := row.GetDatum(e.hashKeyOffset, e.fieldTypes[e.hashKeyOffset]) - if d.IsNull() { - targetChunks[0].AppendRow(row) - } else { - hashKey := int(d.GetInt64() % int64(len(e.tunnels))) - targetChunks[hashKey].AppendRow(row) + hashVals.Reset() + // use hash values to get unique uint64 to mod. + // collect all the hash key datum. + err := codec.HashChunkRow(e.sc, hashVals, row, e.hashKeyTypes, e.hashKeyOffsets, payload) + if err != nil { + for _, tunnel := range e.tunnels { + tunnel.ErrCh <- err + } + return nil, nil } + hashKey := hashVals.Sum64() % uint64(len(e.tunnels)) + targetChunks[hashKey].AppendRow(row) } for i, tunnel := range e.tunnels { if targetChunks[i].NumRows() > 0 { @@ -618,6 +731,7 @@ func (e *exchRecvExec) init() error { } serverMetas = append(serverMetas, meta) } + // for receiver: open conn worker for every receive meta. for _, meta := range serverMetas { e.wg.Add(1) go e.runTunnelWorker(e.mppCtx.TaskHandler, meta) @@ -668,6 +782,7 @@ func (e *exchRecvExec) EstablishConnAndReceiveData(h *MPPTaskHandler, meta *mpp. } return nil, errors.Trace(err) } + // stream resp if mppResponse == nil { return ret, nil } diff --git a/store/mockstore/unistore/rpc.go b/store/mockstore/unistore/rpc.go index 15fe6828ddabf..3fa7aef4bec23 100644 --- a/store/mockstore/unistore/rpc.go +++ b/store/mockstore/unistore/rpc.go @@ -332,6 +332,7 @@ func (c *RPCClient) handleCopStream(ctx context.Context, req *coprocessor.Reques }, nil } +// handleEstablishMPPConnection handle the mock mpp collection came from root or peers. func (c *RPCClient) handleEstablishMPPConnection(ctx context.Context, r *mpp.EstablishMPPConnectionRequest, timeout time.Duration, storeID uint64) (*tikvrpc.MPPStreamResponse, error) { mockServer := new(mockMPPConnectStreamServer) err := c.usSvr.EstablishMPPConnectionWithStoreID(r, mockServer, storeID) @@ -348,6 +349,7 @@ func (c *RPCClient) handleEstablishMPPConnection(ctx context.Context, r *mpp.Est _, cancel := context.WithCancel(ctx) streamResp.Lease.Cancel = cancel streamResp.Timeout = timeout + // mock the stream resp from the server's resp slice first, err := streamResp.Recv() if err != nil { if errors.Cause(err) != io.EOF { diff --git a/util/plancodec/id.go b/util/plancodec/id.go index 0366b7dc3f5a6..20ad0e4301cfa 100644 --- a/util/plancodec/id.go +++ b/util/plancodec/id.go @@ -57,6 +57,8 @@ const ( TypeExchangeSender = "ExchangeSender" // TypeExchangeReceiver is the type of mpp exchanger receiver. TypeExchangeReceiver = "ExchangeReceiver" + // TypeExpand is the type of mpp expand source operator. + TypeExpand = "Expand" // TypeMergeJoin is the type of merge join. TypeMergeJoin = "MergeJoin" // TypeIndexJoin is the type of index look up join. 
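The exchange-sender change earlier in this diff feeds all hash key columns into one fnv64 hash and picks the target tunnel by Sum64() modulo the tunnel count. A schematic standalone version, with strings standing in for the encoded datums that codec.HashChunkRow would produce:

// Schematic multi-key hash partitioning; plain strings, not the codec package.
package main

import (
	"fmt"
	"hash/fnv"
)

func targetTunnel(keys []string, tunnels int) uint64 {
	h := fnv.New64()
	for _, k := range keys {
		h.Write([]byte(k)) // fnv's Write never returns an error
	}
	return h.Sum64() % uint64(tunnels)
}

func main() {
	// rows hashed by (a, b, groupingID) are spread over 3 tunnels.
	fmt.Println(targetTunnel([]string{"1", "NULL", "1"}, 3))
	fmt.Println(targetTunnel([]string{"NULL", "2", "2"}, 3))
}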
@@ -193,6 +195,7 @@ const ( typeShuffleReceiverID int = 55 typeForeignKeyCheck int = 56 typeForeignKeyCascade int = 57 + typeExpandID int = 58 ) // TypeStringToPhysicalID converts the plan type string to plan id. @@ -312,6 +315,8 @@ func TypeStringToPhysicalID(tp string) int { return typeForeignKeyCheck case TypeForeignKeyCascade: return typeForeignKeyCascade + case TypeExpand: + return typeExpandID } // Should never reach here. return 0 @@ -434,6 +439,8 @@ func PhysicalIDToTypeString(id int) string { return TypeForeignKeyCheck case typeForeignKeyCascade: return TypeForeignKeyCascade + case typeExpandID: + return TypeExpand } // Should never reach here.
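The new plan-codec entries can be sanity-checked with a tiny round trip through the two exported mappings above; a sketch, assuming the usual util/plancodec import path:

// Round trip of the new Expand entries in util/plancodec.
package main

import (
	"fmt"

	"github.com/pingcap/tidb/util/plancodec"
)

func main() {
	id := plancodec.TypeStringToPhysicalID(plancodec.TypeExpand)
	fmt.Println(id)                                   // the internal id assigned to Expand
	fmt.Println(plancodec.PhysicalIDToTypeString(id)) // Expand
}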