Skip to content

Commit

Permalink
planner: refactor the join reorder codes (#34380)
Browse files Browse the repository at this point in the history
ref #29932
  • Loading branch information
Reminiscent authored May 11, 2022
1 parent bbd7541 commit fe07324
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 61 deletions.
73 changes: 70 additions & 3 deletions planner/core/rule_join_reorder.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"sort"

"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/plancodec"
"github.com/pingcap/tidb/util/tracing"
Expand Down Expand Up @@ -102,7 +103,10 @@ func (s *joinReOrderSolver) optimizeRecursive(ctx sessionctx.Context, p LogicalP
baseGroupSolver := &baseSingleGroupJoinOrderSolver{
ctx: ctx,
otherConds: otherConds,
eqEdges: eqEdges,
joinTypes: joinTypes,
}

originalSchema := p.Schema()

// Not support outer join reorder when using the DP algorithm
Expand All @@ -116,16 +120,14 @@ func (s *joinReOrderSolver) optimizeRecursive(ctx sessionctx.Context, p LogicalP
if len(curJoinGroup) > ctx.GetSessionVars().TiDBOptJoinReorderThreshold || !isSupportDP {
groupSolver := &joinReorderGreedySolver{
baseSingleGroupJoinOrderSolver: baseGroupSolver,
eqEdges: eqEdges,
joinTypes: joinTypes,
}
p, err = groupSolver.solve(curJoinGroup, tracer)
} else {
dpSolver := &joinReorderDPSolver{
baseSingleGroupJoinOrderSolver: baseGroupSolver,
}
dpSolver.newJoin = dpSolver.newJoinWithEdges
p, err = dpSolver.solve(curJoinGroup, expression.ScalarFuncs2Exprs(eqEdges), tracer)
p, err = dpSolver.solve(curJoinGroup, tracer)
}
if err != nil {
return nil, err
Expand Down Expand Up @@ -168,6 +170,26 @@ type baseSingleGroupJoinOrderSolver struct {
ctx sessionctx.Context
curJoinGroup []*jrNode
otherConds []expression.Expression
eqEdges []*expression.ScalarFunction
joinTypes []JoinType
}

// generateJoinOrderNode used to derive the stats for the joinNodePlans and generate the jrNode groups based on the cost.
func (s *baseSingleGroupJoinOrderSolver) generateJoinOrderNode(joinNodePlans []LogicalPlan, tracer *joinReorderTrace) ([]*jrNode, error) {
joinGroup := make([]*jrNode, 0, len(joinNodePlans))
for _, node := range joinNodePlans {
_, err := node.recursiveDeriveStats(nil)
if err != nil {
return nil, err
}
cost := s.baseNodeCumCost(node)
joinGroup = append(joinGroup, &jrNode{
p: node,
cumCost: cost,
})
tracer.appendLogicalJoinCost(node, cost)
}
return joinGroup, nil
}

// baseNodeCumCost calculate the cumulative cost of the node in the join group.
Expand All @@ -179,6 +201,51 @@ func (s *baseSingleGroupJoinOrderSolver) baseNodeCumCost(groupNode LogicalPlan)
return cost
}

// checkConnection used to check whether two nodes have equal conditions or not.
func (s *baseSingleGroupJoinOrderSolver) checkConnection(leftPlan, rightPlan LogicalPlan) (leftNode, rightNode LogicalPlan, usedEdges []*expression.ScalarFunction, joinType JoinType) {
joinType = InnerJoin
leftNode, rightNode = leftPlan, rightPlan
for idx, edge := range s.eqEdges {
lCol := edge.GetArgs()[0].(*expression.Column)
rCol := edge.GetArgs()[1].(*expression.Column)
if leftPlan.Schema().Contains(lCol) && rightPlan.Schema().Contains(rCol) {
joinType = s.joinTypes[idx]
usedEdges = append(usedEdges, edge)
} else if rightPlan.Schema().Contains(lCol) && leftPlan.Schema().Contains(rCol) {
joinType = s.joinTypes[idx]
if joinType != InnerJoin {
rightNode, leftNode = leftPlan, rightPlan
usedEdges = append(usedEdges, edge)
} else {
newSf := expression.NewFunctionInternal(s.ctx, ast.EQ, edge.GetType(), rCol, lCol).(*expression.ScalarFunction)
usedEdges = append(usedEdges, newSf)
}
}
}
return
}

// makeJoin build join tree for the nodes which have equal conditions to connect them.
func (s *baseSingleGroupJoinOrderSolver) makeJoin(leftPlan, rightPlan LogicalPlan, eqEdges []*expression.ScalarFunction, joinType JoinType) (LogicalPlan, []expression.Expression) {
remainOtherConds := make([]expression.Expression, len(s.otherConds))
copy(remainOtherConds, s.otherConds)
var otherConds []expression.Expression
var leftConds []expression.Expression
var rightConds []expression.Expression
mergedSchema := expression.MergeSchema(leftPlan.Schema(), rightPlan.Schema())

remainOtherConds, leftConds = expression.FilterOutInPlace(remainOtherConds, func(expr expression.Expression) bool {
return expression.ExprFromSchema(expr, leftPlan.Schema()) && !expression.ExprFromSchema(expr, rightPlan.Schema())
})
remainOtherConds, rightConds = expression.FilterOutInPlace(remainOtherConds, func(expr expression.Expression) bool {
return expression.ExprFromSchema(expr, rightPlan.Schema()) && !expression.ExprFromSchema(expr, leftPlan.Schema())
})
remainOtherConds, otherConds = expression.FilterOutInPlace(remainOtherConds, func(expr expression.Expression) bool {
return expression.ExprFromSchema(expr, mergedSchema)
})
return s.newJoinWithEdges(leftPlan, rightPlan, eqEdges, otherConds, leftConds, rightConds, joinType), remainOtherConds
}

// makeBushyJoin build bushy tree for the nodes which have no equal condition to connect them.
func (s *baseSingleGroupJoinOrderSolver) makeBushyJoin(cartesianJoinGroup []LogicalPlan) LogicalPlan {
resultJoinGroup := make([]LogicalPlan, 0, (len(cartesianJoinGroup)+1)/2)
Expand Down
3 changes: 2 additions & 1 deletion planner/core/rule_join_reorder_dp.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ type joinGroupNonEqEdge struct {
expr expression.Expression
}

func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, eqConds []expression.Expression, tracer *joinReorderTrace) (LogicalPlan, error) {
func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, tracer *joinReorderTrace) (LogicalPlan, error) {
eqConds := expression.ScalarFuncs2Exprs(s.eqEdges)
for _, node := range joinGroup {
_, err := node.recursiveDeriveStats(nil)
if err != nil {
Expand Down
20 changes: 14 additions & 6 deletions planner/core/rule_join_reorder_dp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,21 @@ func TestDPReorderTPCHQ5(t *testing.T) {
eqConds = append(eqConds, expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[2].Schema().Columns[0], joinGroups[4].Schema().Columns[0]))
eqConds = append(eqConds, expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[3].Schema().Columns[0], joinGroups[4].Schema().Columns[0]))
eqConds = append(eqConds, expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[4].Schema().Columns[0], joinGroups[5].Schema().Columns[0]))
eqEdges := make([]*expression.ScalarFunction, 0, len(eqConds))
for _, cond := range eqConds {
sf, isSF := cond.(*expression.ScalarFunction)
require.True(t, isSF)
eqEdges = append(eqEdges, sf)
}
baseGroupSolver := &baseSingleGroupJoinOrderSolver{
ctx: ctx,
eqEdges: eqEdges,
}
solver := &joinReorderDPSolver{
baseSingleGroupJoinOrderSolver: &baseSingleGroupJoinOrderSolver{
ctx: ctx,
},
newJoin: newMockJoin(ctx, statsMap),
baseSingleGroupJoinOrderSolver: baseGroupSolver,
newJoin: newMockJoin(ctx, statsMap),
}
result, err := solver.solve(joinGroups, eqConds, nil)
result, err := solver.solve(joinGroups, nil)
require.NoError(t, err)

expected := "MockJoin{supplier, MockJoin{lineitem, MockJoin{orders, MockJoin{customer, MockJoin{nation, region}}}}}"
Expand All @@ -210,7 +218,7 @@ func TestDPReorderAllCartesian(t *testing.T) {
},
newJoin: newMockJoin(ctx, statsMap),
}
result, err := solver.solve(joinGroup, nil, nil)
result, err := solver.solve(joinGroup, nil)
require.NoError(t, err)

expected := "MockJoin{MockJoin{a, b}, MockJoin{c, d}}"
Expand Down
57 changes: 6 additions & 51 deletions planner/core/rule_join_reorder_greedy.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,10 @@ import (
"sort"

"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/parser/ast"
)

type joinReorderGreedySolver struct {
*baseSingleGroupJoinOrderSolver
eqEdges []*expression.ScalarFunction
joinTypes []JoinType
}

// solve reorders the join nodes in the group based on a greedy algorithm.
Expand All @@ -43,19 +40,11 @@ type joinReorderGreedySolver struct {
// For the nodes and join trees which don't have a join equal condition to
// connect them, we make a bushy join tree to do the cartesian joins finally.
func (s *joinReorderGreedySolver) solve(joinNodePlans []LogicalPlan, tracer *joinReorderTrace) (LogicalPlan, error) {
for _, node := range joinNodePlans {
_, err := node.recursiveDeriveStats(nil)
if err != nil {
return nil, err
}
cost := s.baseNodeCumCost(node)
s.curJoinGroup = append(s.curJoinGroup, &jrNode{
p: node,
cumCost: cost,
})
tracer.appendLogicalJoinCost(node, cost)
var err error
s.curJoinGroup, err = s.generateJoinOrderNode(joinNodePlans, tracer)
if err != nil {
return nil, err
}

// Sort plans by cost
sort.SliceStable(s.curJoinGroup, func(i, j int) bool {
return s.curJoinGroup[i].cumCost < s.curJoinGroup[j].cumCost
Expand Down Expand Up @@ -114,43 +103,9 @@ func (s *joinReorderGreedySolver) constructConnectedJoinTree(tracer *joinReorder
}

func (s *joinReorderGreedySolver) checkConnectionAndMakeJoin(leftPlan, rightPlan LogicalPlan) (LogicalPlan, []expression.Expression) {
var usedEdges []*expression.ScalarFunction
remainOtherConds := make([]expression.Expression, len(s.otherConds))
copy(remainOtherConds, s.otherConds)
joinType := InnerJoin
for idx, edge := range s.eqEdges {
lCol := edge.GetArgs()[0].(*expression.Column)
rCol := edge.GetArgs()[1].(*expression.Column)
if leftPlan.Schema().Contains(lCol) && rightPlan.Schema().Contains(rCol) {
joinType = s.joinTypes[idx]
usedEdges = append(usedEdges, edge)
} else if rightPlan.Schema().Contains(lCol) && leftPlan.Schema().Contains(rCol) {
joinType = s.joinTypes[idx]
if joinType != InnerJoin {
rightPlan, leftPlan = leftPlan, rightPlan
usedEdges = append(usedEdges, edge)
} else {
newSf := expression.NewFunctionInternal(s.ctx, ast.EQ, edge.GetType(), rCol, lCol).(*expression.ScalarFunction)
usedEdges = append(usedEdges, newSf)
}
}
}
leftPlan, rightPlan, usedEdges, joinType := s.checkConnection(leftPlan, rightPlan)
if len(usedEdges) == 0 {
return nil, nil
}
var otherConds []expression.Expression
var leftConds []expression.Expression
var rightConds []expression.Expression
mergedSchema := expression.MergeSchema(leftPlan.Schema(), rightPlan.Schema())

remainOtherConds, leftConds = expression.FilterOutInPlace(remainOtherConds, func(expr expression.Expression) bool {
return expression.ExprFromSchema(expr, leftPlan.Schema()) && !expression.ExprFromSchema(expr, rightPlan.Schema())
})
remainOtherConds, rightConds = expression.FilterOutInPlace(remainOtherConds, func(expr expression.Expression) bool {
return expression.ExprFromSchema(expr, rightPlan.Schema()) && !expression.ExprFromSchema(expr, leftPlan.Schema())
})
remainOtherConds, otherConds = expression.FilterOutInPlace(remainOtherConds, func(expr expression.Expression) bool {
return expression.ExprFromSchema(expr, mergedSchema)
})
return s.newJoinWithEdges(leftPlan, rightPlan, usedEdges, otherConds, leftConds, rightConds, joinType), remainOtherConds
return s.makeJoin(leftPlan, rightPlan, usedEdges, joinType)
}

0 comments on commit fe07324

Please sign in to comment.