Skip to content

Commit

Permalink
planner/core: make join reorder by dp work
Browse files Browse the repository at this point in the history
  • Loading branch information
winoros committed Dec 25, 2018
1 parent 1c4d2d9 commit 0fc7469
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 47 deletions.
2 changes: 1 addition & 1 deletion planner/core/cbo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ func (s *testAnalyzeSuite) TestEmptyTable(c *C) {
},
{
sql: "select * from t where c1 in (select c1 from t1)",
best: "RightHashJoin{TableReader(Table(t1)->HashAgg)->HashAgg->TableReader(Table(t))}(test.t1.c1,test.t.c1)->Projection",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t1)->HashAgg)->HashAgg}(test.t.c1,test.t1.c1)->Projection",
},
{
sql: "select * from t, t1 where t.c1 = t1.c1",
Expand Down
8 changes: 4 additions & 4 deletions planner/core/logical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -767,27 +767,27 @@ func (s *testPlanSuite) TestJoinReOrder(c *C) {
}{
{
sql: "select * from t t1, t t2, t t3, t t4, t t5, t t6 where t1.a = t2.b and t2.a = t3.b and t3.c = t4.a and t4.d = t2.c and t5.d = t6.d",
best: "Join{Join{Join{Join{DataScan(t1)->DataScan(t2)}(t1.a,t2.b)->DataScan(t3)}(t2.a,t3.b)->DataScan(t4)}(t3.c,t4.a)(t2.c,t4.d)->Join{DataScan(t5)->DataScan(t6)}(t5.d,t6.d)}->Projection",
best: "Join{Join{Join{DataScan(t1)->DataScan(t2)}(t1.a,t2.b)->Join{DataScan(t4)->DataScan(t3)}(t4.a,t3.c)}(t2.c,t4.d)(t2.a,t3.b)->Join{DataScan(t5)->DataScan(t6)}(t5.d,t6.d)(t2.b,t1.a)}->Projection",
},
{
sql: "select * from t t1, t t2, t t3, t t4, t t5, t t6, t t7, t t8 where t1.a = t8.a",
best: "Join{Join{Join{Join{DataScan(t1)->DataScan(t8)}(t1.a,t8.a)->DataScan(t2)}->Join{DataScan(t3)->DataScan(t4)}}->Join{Join{DataScan(t5)->DataScan(t6)}->DataScan(t7)}}->Projection",
},
{
sql: "select * from t t1, t t2, t t3, t t4, t t5 where t1.a = t5.a and t5.a = t4.a and t4.a = t3.a and t3.a = t2.a and t2.a = t1.a and t1.a = t3.a and t2.a = t4.a and t5.b < 8",
best: "Join{Join{Join{Join{DataScan(t5)->DataScan(t1)}(t5.a,t1.a)->DataScan(t2)}(t1.a,t2.a)->DataScan(t3)}(t2.a,t3.a)(t1.a,t3.a)->DataScan(t4)}(t5.a,t4.a)(t3.a,t4.a)(t2.a,t4.a)->Projection",
best: "Join{Join{Join{Join{DataScan(t1)->DataScan(t5)}(t1.a,t5.a)->DataScan(t3)}(t1.a,t3.a)->DataScan(t2)}(t3.a,t2.a)(t1.a,t2.a)->DataScan(t4)}(t5.a,t4.a)(t3.a,t4.a)(t2.a,t4.a)->Projection",
},
{
sql: "select * from t t1, t t2, t t3, t t4, t t5 where t1.a = t5.a and t5.a = t4.a and t4.a = t3.a and t3.a = t2.a and t2.a = t1.a and t1.a = t3.a and t2.a = t4.a and t3.b = 1 and t4.a = 1",
best: "Join{Join{Join{DataScan(t3)->DataScan(t1)}->Join{DataScan(t2)->DataScan(t4)}}->DataScan(t5)}->Projection",
best: "Join{Join{Join{DataScan(t1)->DataScan(t2)}->Join{DataScan(t3)->DataScan(t4)}}->DataScan(t5)}->Projection",
},
{
sql: "select * from t o where o.b in (select t3.c from t t1, t t2, t t3 where t1.a = t3.a and t2.a = t3.a and t2.a = o.a)",
best: "Apply{DataScan(o)->Join{Join{DataScan(t1)->DataScan(t3)}(t1.a,t3.a)->DataScan(t2)}(t3.a,t2.a)->Projection}->Projection",
},
{
sql: "select * from t o where o.b in (select t3.c from t t1, t t2, t t3 where t1.a = t3.a and t2.a = t3.a and t2.a = o.a and t1.a = 1)",
best: "Apply{DataScan(o)->Join{Join{DataScan(t3)->DataScan(t1)}->DataScan(t2)}->Projection}->Projection",
best: "Apply{DataScan(o)->Join{Join{DataScan(t1)->DataScan(t2)}->DataScan(t3)}->Projection}->Projection",
},
}
for _, tt := range tests {
Expand Down
2 changes: 1 addition & 1 deletion planner/core/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ var optRuleList = []logicalOptRule{
&partitionProcessor{},
&aggregationPushDownSolver{},
&pushDownTopNOptimizer{},
&joinReOrderGreedySolver{},
&joinReOrderSolver{},
}

// logicalOptRule means a logical optimizing rule, which contains decorrelate, ppd, column pruning, etc.
Expand Down
6 changes: 3 additions & 3 deletions planner/core/physical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,11 @@ func (s *testPlanSuite) TestDAGPlanBuilderJoin(c *C) {
},
{
sql: "select * from t t1 join t t2 on t1.a = t2.a join t t3 on t1.a = t3.a",
best: "MergeInnerJoin{MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(t1.a,t2.a)->TableReader(Table(t))}(t1.a,t3.a)",
best: "MergeInnerJoin{MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(t1.a,t3.a)->TableReader(Table(t))}(t1.a,t2.a)->Projection",
},
{
sql: "select * from t t1 join t t2 on t1.a = t2.a join t t3 on t1.b = t3.a",
best: "LeftHashJoin{MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(t1.a,t2.a)->TableReader(Table(t))}(t1.b,t3.a)",
best: "LeftHashJoin{LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(t1.b,t3.a)->TableReader(Table(t))}(t1.a,t2.a)->Projection",
},
{
sql: "select * from t t1 join t t2 on t1.b = t2.a order by t1.a",
Expand All @@ -269,7 +269,7 @@ func (s *testPlanSuite) TestDAGPlanBuilderJoin(c *C) {
},
{
sql: "select * from t t1 join t t2 on t1.b = t2.b join t t3 on t1.b = t3.b",
best: "LeftHashJoin{LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(t1.b,t2.b)->TableReader(Table(t))}(t1.b,t3.b)",
best: "LeftHashJoin{LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(t1.b,t3.b)->TableReader(Table(t))}(t1.b,t2.b)->Projection",
},
{
sql: "select * from t t1 join t t2 on t1.a = t2.a order by t1.a",
Expand Down
135 changes: 113 additions & 22 deletions planner/core/rule_join_reorder_dp.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,33 @@ import (

type joinReorderDPSolver struct {
ctx sessionctx.Context
newJoin func(lChild, rChild LogicalPlan, eqConds []*expression.ScalarFunction) LogicalPlan
newJoin func(lChild, rChild LogicalPlan, eqConds []*expression.ScalarFunction, otherConds []expression.Expression) LogicalPlan
}

type joinGroupEdge struct {
type joinGroupEqEdge struct {
nodeIDs []int
edge *expression.ScalarFunction
}

func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, conds []expression.Expression) (LogicalPlan, error) {
// joinGroupNonEqEdge describes a non-equal ("other") join condition together
// with the set of join group nodes that the condition refers to.
type joinGroupNonEqEdge struct {
// nodeIDs lists the IDs of the join group nodes involved in expr.
// NOTE(review): solve reads this field when re-computing idMask for a sub
// graph, but the visible construction site only sets idMask and expr, so it
// looks like this slice stays empty — confirm whether it should be populated.
nodeIDs []int
// idMask is a bitmask of the nodes referenced by expr: solve sets bit idx for
// every column of expr found at index idx of the join group.
idMask uint
// expr is the condition expression itself.
expr expression.Expression
}

func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, eqConds, otherConds []expression.Expression) (LogicalPlan, error) {
adjacents := make([][]int, len(joinGroup))
totalEdges := make([]joinGroupEdge, 0, len(conds))
addEdge := func(node1, node2 int, edgeContent *expression.ScalarFunction) {
totalEdges = append(totalEdges, joinGroupEdge{
totalEqEdges := make([]joinGroupEqEdge, 0, len(eqConds))
addEqEdge := func(node1, node2 int, edgeContent *expression.ScalarFunction) {
totalEqEdges = append(totalEqEdges, joinGroupEqEdge{
nodeIDs: []int{node1, node2},
edge: edgeContent,
})
adjacents[node1] = append(adjacents[node1], node2)
adjacents[node2] = append(adjacents[node2], node1)
}
// Build Graph for join group
for _, cond := range conds {
for _, cond := range eqConds {
sf := cond.(*expression.ScalarFunction)
lCol := sf.GetArgs()[0].(*expression.Column)
rCol := sf.GetArgs()[1].(*expression.Column)
Expand All @@ -55,7 +61,23 @@ func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, conds []expression.
if err != nil {
return nil, err
}
addEdge(lIdx, rIdx, sf)
addEqEdge(lIdx, rIdx, sf)
}
totalNonEqEdges := make([]joinGroupNonEqEdge, 0, len(otherConds))
for _, cond := range otherConds {
cols := expression.ExtractColumns(cond)
mask := uint(0)
for _, col := range cols {
idx, err := findNodeIndexInGroup(joinGroup, col)
if err != nil {
return nil, err
}
mask |= 1 << uint(idx)
}
totalNonEqEdges = append(totalNonEqEdges, joinGroupNonEqEdge{
idMask: mask,
expr: cond,
})
}
visited := make([]bool, len(joinGroup))
nodeID2VisitID := make([]int, len(joinGroup))
Expand All @@ -66,15 +88,37 @@ func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, conds []expression.
continue
}
visitID2NodeID := s.bfsGraph(i, visited, adjacents, nodeID2VisitID)
nodeIDMask := uint(0)
for _, nodeID := range visitID2NodeID {
nodeIDMask |= uint(nodeID)
}
var subNonEqEdges []joinGroupNonEqEdge
for i := len(totalNonEqEdges) - 1; i >= 0; i-- {
// If this edge is not the subset of the current sub graph.
if totalNonEqEdges[i].idMask&nodeIDMask != nodeIDMask {
continue
}
newMask := uint(0)
for _, nodeID := range totalNonEqEdges[i].nodeIDs {
newMask |= uint(nodeID)
}
totalNonEqEdges[i].idMask = newMask
subNonEqEdges = append(subNonEqEdges, totalNonEqEdges[i])
totalNonEqEdges = append(totalNonEqEdges[:i], totalNonEqEdges[i+1:]...)
}
// Do DP on each sub graph.
join, err := s.dpGraph(visitID2NodeID, nodeID2VisitID, joinGroup, totalEdges)
join, err := s.dpGraph(visitID2NodeID, nodeID2VisitID, joinGroup, totalEqEdges, subNonEqEdges)
if err != nil {
return nil, err
}
joins = append(joins, join)
}
remainedOtherConds := make([]expression.Expression, 0, len(totalNonEqEdges))
for _, edge := range totalNonEqEdges {
remainedOtherConds = append(remainedOtherConds, edge.expr)
}
// Build bushy tree for cartesian joins.
return s.makeBushyJoin(joins), nil
return s.makeBushyJoin(joins, remainedOtherConds), nil
}

// bfsGraph runs a BFS over the sub graph starting at startNode and relabels its nodes for future use.
Expand All @@ -98,7 +142,8 @@ func (s *joinReorderDPSolver) bfsGraph(startNode int, visited []bool, adjacents
return visitID2NodeID
}

func (s *joinReorderDPSolver) dpGraph(newPos2OldPos, oldPos2NewPos []int, joinGroup []LogicalPlan, totalEdges []joinGroupEdge) (LogicalPlan, error) {
func (s *joinReorderDPSolver) dpGraph(newPos2OldPos, oldPos2NewPos []int, joinGroup []LogicalPlan,
totalEqEdges []joinGroupEqEdge, totalNonEqEdges []joinGroupNonEqEdge) (LogicalPlan, error) {
nodeCnt := uint(len(newPos2OldPos))
bestPlan := make([]LogicalPlan, 1<<nodeCnt)
bestCost := make([]int64, 1<<nodeCnt)
Expand All @@ -122,11 +167,12 @@ func (s *joinReorderDPSolver) dpGraph(newPos2OldPos, oldPos2NewPos []int, joinGr
continue
}
// Get the edge connecting the two parts.
usedEdges := s.nodesAreConnected(sub, remain, oldPos2NewPos, totalEdges)
usedEdges, otherConds := s.nodesAreConnected(sub, remain, oldPos2NewPos, totalEqEdges, totalNonEqEdges)
// Here we only check equal condition currently.
if len(usedEdges) == 0 {
continue
}
join, err := s.newJoinWithEdge(bestPlan[sub], bestPlan[remain], usedEdges)
join, err := s.newJoinWithEdge(bestPlan[sub], bestPlan[remain], usedEdges, otherConds)
if err != nil {
return nil, err
}
Expand All @@ -139,21 +185,36 @@ func (s *joinReorderDPSolver) dpGraph(newPos2OldPos, oldPos2NewPos []int, joinGr
return bestPlan[(1<<nodeCnt)-1], nil
}

func (s *joinReorderDPSolver) nodesAreConnected(leftMask, rightMask uint, oldPos2NewPos []int, totalEdges []joinGroupEdge) []joinGroupEdge {
var usedEdges []joinGroupEdge
func (s *joinReorderDPSolver) nodesAreConnected(leftMask, rightMask uint, oldPos2NewPos []int,
totalEdges []joinGroupEqEdge, totalNonEqEdges []joinGroupNonEqEdge) ([]joinGroupEqEdge, []expression.Expression) {
var (
usedEqEdges []joinGroupEqEdge
otherConds []expression.Expression
)
for _, edge := range totalEdges {
lIdx := uint(oldPos2NewPos[edge.nodeIDs[0]])
rIdx := uint(oldPos2NewPos[edge.nodeIDs[1]])
if (leftMask&(1<<lIdx)) > 0 && (rightMask&(1<<rIdx)) > 0 {
usedEdges = append(usedEdges, edge)
usedEqEdges = append(usedEqEdges, edge)
} else if (leftMask&(1<<rIdx)) > 0 && (rightMask&(1<<lIdx)) > 0 {
usedEdges = append(usedEdges, edge)
usedEqEdges = append(usedEqEdges, edge)
}
}
for _, edge := range totalNonEqEdges {
// Skip the expression if the current group (leftMask|rightMask) does not cover all the nodes whose columns it involves.
if edge.idMask&(leftMask|rightMask) != edge.idMask {
continue
}
// Check whether this expression is only built from one side of the join.
if edge.idMask&leftMask == 0 || edge.idMask&rightMask == 0 {
continue
}
otherConds = append(otherConds, edge.expr)
}
return usedEdges
return usedEqEdges, otherConds
}

func (s *joinReorderDPSolver) newJoinWithEdge(leftPlan, rightPlan LogicalPlan, edges []joinGroupEdge) (LogicalPlan, error) {
func (s *joinReorderDPSolver) newJoinWithEdge(leftPlan, rightPlan LogicalPlan, edges []joinGroupEqEdge, otherConds []expression.Expression) (LogicalPlan, error) {
var eqConds []*expression.ScalarFunction
for _, edge := range edges {
lCol := edge.edge.GetArgs()[0].(*expression.Column)
Expand All @@ -165,21 +226,30 @@ func (s *joinReorderDPSolver) newJoinWithEdge(leftPlan, rightPlan LogicalPlan, e
eqConds = append(eqConds, newSf)
}
}
join := s.newJoin(leftPlan, rightPlan, eqConds)
join := s.newJoin(leftPlan, rightPlan, eqConds, otherConds)
_, err := join.deriveStats()
return join, err
}

// Make cartesian join as bushy tree.
func (s *joinReorderDPSolver) makeBushyJoin(cartesianJoinGroup []LogicalPlan) LogicalPlan {
func (s *joinReorderDPSolver) makeBushyJoin(cartesianJoinGroup []LogicalPlan, otherConds []expression.Expression) LogicalPlan {
for len(cartesianJoinGroup) > 1 {
resultJoinGroup := make([]LogicalPlan, 0, len(cartesianJoinGroup))
for i := 0; i < len(cartesianJoinGroup); i += 2 {
if i+1 == len(cartesianJoinGroup) {
resultJoinGroup = append(resultJoinGroup, cartesianJoinGroup[i])
break
}
resultJoinGroup = append(resultJoinGroup, s.newJoin(cartesianJoinGroup[i], cartesianJoinGroup[i+1], nil))
mergedSchema := expression.MergeSchema(cartesianJoinGroup[i].Schema(), cartesianJoinGroup[i+1].Schema())
var usedOtherConds []expression.Expression
for i := len(otherConds) - 1; i >= 0; i-- {
cols := expression.ExtractColumns(otherConds[i])
if mergedSchema.ColumnsIndices(cols) != nil {
usedOtherConds = append(usedOtherConds, otherConds[i])
otherConds = append(otherConds[:i], otherConds[i+1:]...)
}
}
resultJoinGroup = append(resultJoinGroup, s.newJoin(cartesianJoinGroup[i], cartesianJoinGroup[i+1], nil, usedOtherConds))
}
cartesianJoinGroup = resultJoinGroup
}
Expand All @@ -194,3 +264,24 @@ func findNodeIndexInGroup(group []LogicalPlan, col *expression.Column) (int, err
}
return -1, ErrUnknownColumn.GenWithStackByArgs(col, "JOIN REORDER RULE")
}

// newJoinWithConds joins leftPlan and rightPlan with newCartesianJoin and then
// attaches the given conditions to the resulting join: eqConds become its
// equal conditions and otherConds its other conditions.
//
// For each equal condition, the first argument is recorded as a left join key
// and the second argument as a right join key. Both arguments are type-asserted
// to *expression.Column, so a non-column argument panics.
func (s *joinReorderDPSolver) newJoinWithConds(leftPlan, rightPlan LogicalPlan, eqConds []*expression.ScalarFunction, otherConds []expression.Expression) LogicalPlan {
	result := s.newCartesianJoin(leftPlan, rightPlan)
	result.EqualConditions, result.OtherConditions = eqConds, otherConds
	for i := range result.EqualConditions {
		cond := result.EqualConditions[i]
		// Keep the left key append ahead of the right key lookup so the keys
		// are collected in the same order as the condition's arguments.
		result.LeftJoinKeys = append(result.LeftJoinKeys, cond.GetArgs()[0].(*expression.Column))
		result.RightJoinKeys = append(result.RightJoinKeys, cond.GetArgs()[1].(*expression.Column))
	}
	return result
}

// newCartesianJoin builds a condition-less inner LogicalJoin over lChild and
// rChild. The join is flagged as reordered, initialized with the solver's
// session context, given the merged schema of both children, and has the two
// plans installed as its children (left first, right second).
func (s *joinReorderDPSolver) newCartesianJoin(lChild, rChild LogicalPlan) *LogicalJoin {
	cartesian := LogicalJoin{JoinType: InnerJoin, reordered: true}.Init(s.ctx)
	// The output schema is the left child's schema merged with the right child's.
	merged := expression.MergeSchema(lChild.Schema(), rChild.Schema())
	cartesian.SetSchema(merged)
	cartesian.SetChildren(lChild, rChild)
	return cartesian
}
6 changes: 3 additions & 3 deletions planner/core/rule_join_reorder_dp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (mj *mockLogicalJoin) deriveStats() (*property.StatsInfo, error) {
return mj.statsMap[mj.involvedNodeSet], nil
}

func (s *testJoinReorderDPSuite) newMockJoin(lChild, rChild LogicalPlan, eqConds []*expression.ScalarFunction) LogicalPlan {
func (s *testJoinReorderDPSuite) newMockJoin(lChild, rChild LogicalPlan, eqConds []*expression.ScalarFunction, _ []expression.Expression) LogicalPlan {
retJoin := mockLogicalJoin{}.init(s.ctx)
retJoin.schema = expression.MergeSchema(lChild.Schema(), rChild.Schema())
retJoin.statsMap = s.statsMap
Expand Down Expand Up @@ -192,7 +192,7 @@ func (s *testJoinReorderDPSuite) TestDPReorderTPCHQ5(c *C) {
ctx: s.ctx,
newJoin: s.newMockJoin,
}
result, err := solver.solve(joinGroups, eqConds)
result, err := solver.solve(joinGroups, eqConds, nil)
c.Assert(err, IsNil)
c.Assert(s.planToString(result), Equals, "MockJoin{supplier, MockJoin{lineitem, MockJoin{orders, MockJoin{customer, MockJoin{nation, region}}}}}")
}
Expand All @@ -207,7 +207,7 @@ func (s *testJoinReorderDPSuite) TestDPReorderAllCartesian(c *C) {
ctx: s.ctx,
newJoin: s.newMockJoin,
}
result, err := solver.solve(joinGroup, nil)
result, err := solver.solve(joinGroup, nil, nil)
c.Assert(err, IsNil)
c.Assert(s.planToString(result), Equals, "MockJoin{MockJoin{a, b}, MockJoin{c, d}}")
}
38 changes: 25 additions & 13 deletions planner/core/rule_join_reorder_greedy.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func extractJoinGroup(p LogicalPlan) (group []LogicalPlan, eqEdges []*expression
return group, eqEdges, otherConds
}

type joinReOrderGreedySolver struct {
type joinReOrderSolver struct {
}

type joinReorderGreedySingleGroupSolver struct {
Expand Down Expand Up @@ -196,11 +196,11 @@ func (s *joinReorderGreedySingleGroupSolver) newCartesianJoin(lChild, rChild Log
return join
}

func (s *joinReOrderGreedySolver) optimize(p LogicalPlan) (LogicalPlan, error) {
func (s *joinReOrderSolver) optimize(p LogicalPlan) (LogicalPlan, error) {
return s.optimizeRecursive(p.context(), p)
}

func (s *joinReOrderGreedySolver) optimizeRecursive(ctx sessionctx.Context, p LogicalPlan) (LogicalPlan, error) {
func (s *joinReOrderSolver) optimizeRecursive(ctx sessionctx.Context, p LogicalPlan) (LogicalPlan, error) {
var err error
curJoinGroup, eqEdges, otherConds := extractJoinGroup(p)
if len(curJoinGroup) > 1 {
Expand All @@ -210,17 +210,29 @@ func (s *joinReOrderGreedySolver) optimizeRecursive(ctx sessionctx.Context, p Lo
return nil, err
}
}
groupSolver := &joinReorderGreedySingleGroupSolver{
ctx: ctx,
curJoinGroup: curJoinGroup,
eqEdges: eqEdges,
otherConds: otherConds,
}
p, err = groupSolver.solve()
if err != nil {
return nil, err
if len(curJoinGroup) > 10 {
greedySolver := &joinReorderGreedySingleGroupSolver{
ctx: ctx,
curJoinGroup: curJoinGroup,
eqEdges: eqEdges,
otherConds: otherConds,
}
p, err = greedySolver.solve()
if err != nil {
return nil, err
}
return p, nil
} else {
dpSolver := &joinReorderDPSolver{
ctx: ctx,
}
dpSolver.newJoin = dpSolver.newJoinWithConds
p, err = dpSolver.solve(curJoinGroup, expression.ScalarFuncs2Exprs(eqEdges), otherConds)
if err != nil {
return nil, err
}
return p, nil
}
return p, nil
}
newChildren := make([]LogicalPlan, 0, len(p.Children()))
for _, child := range p.Children() {
Expand Down

0 comments on commit 0fc7469

Please sign in to comment.