From f5d6ea75ab963b6ba11c31b445da5d9604146d5a Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Thu, 18 Oct 2018 19:24:37 +0800 Subject: [PATCH 1/3] util: move disjoint set to util package --- expression/constant_propagation.go | 46 ++++++------------------ util/disjointset/int_set.go | 45 +++++++++++++++++++++++ util/disjointset/int_set_test.go | 57 ++++++++++++++++++++++++++++++ 3 files changed, 113 insertions(+), 35 deletions(-) create mode 100644 util/disjointset/int_set.go create mode 100644 util/disjointset/int_set_test.go diff --git a/expression/constant_propagation.go b/expression/constant_propagation.go index 87b217100627a..351fe81d2526d 100644 --- a/expression/constant_propagation.go +++ b/expression/constant_propagation.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/disjointset" "github.com/pkg/errors" log "github.com/sirupsen/logrus" ) @@ -27,34 +28,11 @@ import ( // MaxPropagateColsCnt means the max number of columns that can participate propagation. var MaxPropagateColsCnt = 100 -type multiEqualSet struct { - parent []int -} - -func (m *multiEqualSet) init(l int) { - m.parent = make([]int, l) - for i := range m.parent { - m.parent[i] = i - } -} - -func (m *multiEqualSet) addRelation(a int, b int) { - m.parent[m.findRoot(a)] = m.findRoot(b) -} - -func (m *multiEqualSet) findRoot(a int) int { - if a == m.parent[a] { - return a - } - m.parent[a] = m.findRoot(m.parent[a]) - return m.parent[a] -} - type basePropConstSolver struct { - colMapper map[int64]int // colMapper maps column to its index - eqList []*Constant // if eqList[i] != nil, it means col_i = eqList[i] - unionSet *multiEqualSet // unionSet stores the relations like col_i = col_j - columns []*Column // columns stores all columns appearing in the conditions + colMapper map[int64]int // colMapper maps column to its index + eqList []*Constant // if eqList[i] != nil, it means col_i = eqList[i] + unionSet *disjointset.IntSet // unionSet stores the relations like col_i = col_j + columns []*Column // columns stores all columns appearing in the conditions ctx sessionctx.Context } @@ -208,8 +186,7 @@ func (s *propConstSolver) propagateConstantEQ() { // We maintain a unionSet representing the equivalent for every two columns. func (s *propConstSolver) propagateColumnEQ() { visited := make([]bool, len(s.conditions)) - s.unionSet = &multiEqualSet{} - s.unionSet.init(len(s.columns)) + s.unionSet = disjointset.NewIntSet(len(s.columns)) for i := range s.conditions { if fun, ok := s.conditions[i].(*ScalarFunction); ok && fun.FuncName.L == ast.EQ { lCol, lOk := fun.GetArgs()[0].(*Column) @@ -217,7 +194,7 @@ func (s *propConstSolver) propagateColumnEQ() { if lOk && rOk { lID := s.getColID(lCol) rID := s.getColID(rCol) - s.unionSet.addRelation(lID, rID) + s.unionSet.AddRelation(lID, rID) visited[i] = true } } @@ -227,7 +204,7 @@ func (s *propConstSolver) propagateColumnEQ() { for i, coli := range s.columns { for j := i + 1; j < len(s.columns); j++ { // unionSet doesn't have iterate(), we use a two layer loop to iterate col_i = col_j relation - if s.unionSet.findRoot(i) != s.unionSet.findRoot(j) { + if s.unionSet.FindRoot(i) != s.unionSet.FindRoot(j) { continue } colj := s.columns[j] @@ -489,8 +466,7 @@ func (s *propOuterJoinConstSolver) deriveConds(outerCol, innerCol *Column, schem // Derived new expressions must be appended into join condition, not filter condition. func (s *propOuterJoinConstSolver) propagateColumnEQ() { visited := make([]bool, len(s.joinConds)+len(s.filterConds)) - s.unionSet = &multiEqualSet{} - s.unionSet.init(len(s.columns)) + s.unionSet = disjointset.NewIntSet(len(s.columns)) var outerCol, innerCol *Column // Only consider column equal condition in joinConds. // If we have column equal in filter condition, the outer join should have been simplified already. @@ -499,7 +475,7 @@ func (s *propOuterJoinConstSolver) propagateColumnEQ() { if outerCol != nil { outerID := s.getColID(outerCol) innerID := s.getColID(innerCol) - s.unionSet.addRelation(outerID, innerID) + s.unionSet.AddRelation(outerID, innerID) visited[i] = true } } @@ -508,7 +484,7 @@ func (s *propOuterJoinConstSolver) propagateColumnEQ() { for i, coli := range s.columns { for j := i + 1; j < len(s.columns); j++ { // unionSet doesn't have iterate(), we use a two layer loop to iterate col_i = col_j relation. - if s.unionSet.findRoot(i) != s.unionSet.findRoot(j) { + if s.unionSet.FindRoot(i) != s.unionSet.FindRoot(j) { continue } colj := s.columns[j] diff --git a/util/disjointset/int_set.go b/util/disjointset/int_set.go new file mode 100644 index 0000000000000..319cf051c9d95 --- /dev/null +++ b/util/disjointset/int_set.go @@ -0,0 +1,45 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package disjointset + +type IntSet struct { + parent []int +} + +func NewIntSet(size int) *IntSet { + p := make([]int, size) + for i := range p { + p[i] = i + } + return &IntSet{parent: p} +} + +func (m *IntSet) Init(l int) { + m.parent = make([]int, l) + for i := range m.parent { + m.parent[i] = i + } +} + +func (m *IntSet) AddRelation(a int, b int) { + m.parent[m.FindRoot(a)] = m.FindRoot(b) +} + +func (m *IntSet) FindRoot(a int) int { + if a == m.parent[a] { + return a + } + m.parent[a] = m.FindRoot(m.parent[a]) + return m.parent[a] +} diff --git a/util/disjointset/int_set_test.go b/util/disjointset/int_set_test.go new file mode 100644 index 0000000000000..f53ac9f7f2346 --- /dev/null +++ b/util/disjointset/int_set_test.go @@ -0,0 +1,57 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package disjointset + +import ( + "testing" + + . "github.com/pingcap/check" +) + +var _ = Suite(&testDisjointSetSuite{}) + +func TestT(t *testing.T) { + CustomVerboseFlag = true + TestingT(t) +} + +type testDisjointSetSuite struct { +} + +func (s *testDisjointSetSuite) TestIntDisjointSet(c *C) { + set := NewIntSet(5) + c.Assert(len(set.parent), Equals, 5) + for i := range set.parent { + c.Assert(set.parent[i], Equals, i) + } + set.Init(10) + c.Assert(len(set.parent), Equals, 10) + for i := range set.parent { + c.Assert(set.parent[i], Equals, i) + } + set.AddRelation(0, 1) + set.AddRelation(1, 3) + set.AddRelation(4, 2) + set.AddRelation(2, 6) + set.AddRelation(3, 5) + set.AddRelation(7, 8) + set.AddRelation(9, 6) + c.Assert(set.FindRoot(0), Equals, set.FindRoot(1)) + c.Assert(set.FindRoot(3), Equals, set.FindRoot(1)) + c.Assert(set.FindRoot(5), Equals, set.FindRoot(1)) + c.Assert(set.FindRoot(2), Equals, set.FindRoot(4)) + c.Assert(set.FindRoot(6), Equals, set.FindRoot(4)) + c.Assert(set.FindRoot(9), Equals, set.FindRoot(2)) + c.Assert(set.FindRoot(7), Equals, set.FindRoot(8)) +} From 837b355f1897c1fbbe586d70f57a40c5daa4ecfc Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Thu, 18 Oct 2018 19:40:29 +0800 Subject: [PATCH 2/3] fix make check --- util/disjointset/int_set.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/util/disjointset/int_set.go b/util/disjointset/int_set.go index 319cf051c9d95..e4ba1bca55d79 100644 --- a/util/disjointset/int_set.go +++ b/util/disjointset/int_set.go @@ -17,6 +17,7 @@ type IntSet struct { parent []int } +// NewIntSet returns a new int disjoint set. func NewIntSet(size int) *IntSet { p := make([]int, size) for i := range p { @@ -25,6 +26,7 @@ func NewIntSet(size int) *IntSet { return &IntSet{parent: p} } +// Init inits or reset the int disjoint set. func (m *IntSet) Init(l int) { m.parent = make([]int, l) for i := range m.parent { @@ -32,10 +34,12 @@ func (m *IntSet) Init(l int) { } } +// AddRelation merges two sets in int disjoint set. func (m *IntSet) AddRelation(a int, b int) { m.parent[m.FindRoot(a)] = m.FindRoot(b) } +// FindRoot finds the representative element of the set that `a` belongs to. func (m *IntSet) FindRoot(a int) int { if a == m.parent[a] { return a From 8b2cef60fa887adb58c09982e6eca7d5006b33f2 Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Thu, 18 Oct 2018 19:44:19 +0800 Subject: [PATCH 3/3] address comments --- expression/constant_propagation.go | 4 ++-- util/disjointset/int_set.go | 13 +++---------- util/disjointset/int_set_test.go | 21 ++++++++------------- 3 files changed, 13 insertions(+), 25 deletions(-) diff --git a/expression/constant_propagation.go b/expression/constant_propagation.go index 351fe81d2526d..11f9bbe9b8055 100644 --- a/expression/constant_propagation.go +++ b/expression/constant_propagation.go @@ -194,7 +194,7 @@ func (s *propConstSolver) propagateColumnEQ() { if lOk && rOk { lID := s.getColID(lCol) rID := s.getColID(rCol) - s.unionSet.AddRelation(lID, rID) + s.unionSet.Union(lID, rID) visited[i] = true } } @@ -475,7 +475,7 @@ func (s *propOuterJoinConstSolver) propagateColumnEQ() { if outerCol != nil { outerID := s.getColID(outerCol) innerID := s.getColID(innerCol) - s.unionSet.AddRelation(outerID, innerID) + s.unionSet.Union(outerID, innerID) visited[i] = true } } diff --git a/util/disjointset/int_set.go b/util/disjointset/int_set.go index e4ba1bca55d79..0881b4aa1b03f 100644 --- a/util/disjointset/int_set.go +++ b/util/disjointset/int_set.go @@ -13,6 +13,7 @@ package disjointset +// IntSet is the int disjoint set. type IntSet struct { parent []int } @@ -26,16 +27,8 @@ func NewIntSet(size int) *IntSet { return &IntSet{parent: p} } -// Init inits or reset the int disjoint set. -func (m *IntSet) Init(l int) { - m.parent = make([]int, l) - for i := range m.parent { - m.parent[i] = i - } -} - -// AddRelation merges two sets in int disjoint set. -func (m *IntSet) AddRelation(a int, b int) { +// Union unions two sets in int disjoint set. +func (m *IntSet) Union(a int, b int) { m.parent[m.FindRoot(a)] = m.FindRoot(b) } diff --git a/util/disjointset/int_set_test.go b/util/disjointset/int_set_test.go index f53ac9f7f2346..222c63ec5fc9f 100644 --- a/util/disjointset/int_set_test.go +++ b/util/disjointset/int_set_test.go @@ -30,23 +30,18 @@ type testDisjointSetSuite struct { } func (s *testDisjointSetSuite) TestIntDisjointSet(c *C) { - set := NewIntSet(5) - c.Assert(len(set.parent), Equals, 5) - for i := range set.parent { - c.Assert(set.parent[i], Equals, i) - } - set.Init(10) + set := NewIntSet(10) c.Assert(len(set.parent), Equals, 10) for i := range set.parent { c.Assert(set.parent[i], Equals, i) } - set.AddRelation(0, 1) - set.AddRelation(1, 3) - set.AddRelation(4, 2) - set.AddRelation(2, 6) - set.AddRelation(3, 5) - set.AddRelation(7, 8) - set.AddRelation(9, 6) + set.Union(0, 1) + set.Union(1, 3) + set.Union(4, 2) + set.Union(2, 6) + set.Union(3, 5) + set.Union(7, 8) + set.Union(9, 6) c.Assert(set.FindRoot(0), Equals, set.FindRoot(1)) c.Assert(set.FindRoot(3), Equals, set.FindRoot(1)) c.Assert(set.FindRoot(5), Equals, set.FindRoot(1))