Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
38610: exec: add support for RIGHT OUTER in merge join r=yuzefovich a=yuzefovich

First commit adds support for RIGHT OUTER JOIN in vectorized merge join.

The second commit does some housekeeping.

Addresses: cockroachdb#37166.

Co-authored-by: Yahor Yuzefovich <[email protected]>
  • Loading branch information
craig[bot] and yuzefovich committed Jul 11, 2019
2 parents 3f42686 + 77ad34e commit e41c2da
Show file tree
Hide file tree
Showing 7 changed files with 517 additions and 125 deletions.
7 changes: 5 additions & 2 deletions pkg/sql/distsqlrun/column_exec_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,11 @@ func newColOperator(
if !core.MergeJoiner.OnExpr.Empty() {
return nil, nil, errors.Newf("can't plan merge join with on expressions")
}
if core.MergeJoiner.Type != sqlbase.InnerJoin && core.MergeJoiner.Type != sqlbase.LeftOuterJoin {
return nil, nil, errors.Newf("can plan only inner and left outer merge join")
if core.MergeJoiner.Type != sqlbase.JoinType_INNER &&
core.MergeJoiner.Type != sqlbase.JoinType_LEFT_OUTER &&
core.MergeJoiner.Type != sqlbase.JoinType_RIGHT_OUTER {
return nil, nil, errors.Newf("can plan only inner, left outer, and " +
"right outer merge joins")
}

leftTypes := conv.FromColumnTypes(spec.Input[0].ColumnTypes)
Expand Down
102 changes: 54 additions & 48 deletions pkg/sql/distsqlrun/columnar_operators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ package distsqlrun

import (
"context"
"fmt"
"math/rand"
"sort"
"testing"
Expand Down Expand Up @@ -117,63 +118,68 @@ func TestMergeJoinerAgainstProcessor(t *testing.T) {
var da sqlbase.DatumAlloc
evalCtx := tree.MakeTestingEvalContext(cluster.MakeTestingClusterSettings())
defer evalCtx.Stop(context.Background())
rng, _ := randutil.NewPseudoRand()
seed := rand.Int()
rng := rand.New(rand.NewSource(int64(seed)))

nRows := 100
maxCols := 5
maxNum := 10
nRuns := 10
nRows := 10
maxCols := 3
maxNum := 5
nullProbability := 0.1
typs := make([]types.T, maxCols)
for i := range typs {
// TODO (georgeutsin): Randomize the types of the columns.
typs[i] = *types.Int
}
for _, joinType := range []sqlbase.JoinType{
sqlbase.JoinType_INNER,
sqlbase.JoinType_LEFT_OUTER,
} {
for nCols := 1; nCols <= maxCols; nCols++ {
inputTypes := typs[:nCols]
// Note: we're only generating column orderings on all nCols columns since
// if there are columns not in the ordering, the results are not fully
// deterministic.
lOrderingCols := generateColumnOrdering(rng, nCols, nCols)
rOrderingCols := generateColumnOrdering(rng, nCols, nCols)
// Set the directions of both columns to be the same.
for i, lCol := range lOrderingCols {
rOrderingCols[i].Direction = lCol.Direction
}
for run := 1; run < nRuns; run++ {
for _, joinType := range []sqlbase.JoinType{
sqlbase.JoinType_INNER,
sqlbase.JoinType_LEFT_OUTER,
sqlbase.JoinType_RIGHT_OUTER,
} {
for nCols := 1; nCols <= maxCols; nCols++ {
for nOrderingCols := 1; nOrderingCols <= nCols; nOrderingCols++ {
inputTypes := typs[:nCols]
lOrderingCols := generateColumnOrdering(rng, nCols, nOrderingCols)
rOrderingCols := generateColumnOrdering(rng, nCols, nOrderingCols)
// Set the directions of both columns to be the same.
for i, lCol := range lOrderingCols {
rOrderingCols[i].Direction = lCol.Direction
}

lRows := sqlbase.MakeRandIntRowsInRange(rng, nRows, nCols, maxNum, nullProbability)
rRows := sqlbase.MakeRandIntRowsInRange(rng, nRows, nCols, maxNum, nullProbability)
lMatchedCols := distsqlpb.ConvertToColumnOrdering(distsqlpb.Ordering{Columns: lOrderingCols})
rMatchedCols := distsqlpb.ConvertToColumnOrdering(distsqlpb.Ordering{Columns: rOrderingCols})
sort.Slice(lRows, func(i, j int) bool {
cmp, err := lRows[i].Compare(inputTypes, &da, lMatchedCols, &evalCtx, lRows[j])
if err != nil {
t.Fatal(err)
}
return cmp < 0
})
sort.Slice(rRows, func(i, j int) bool {
cmp, err := rRows[i].Compare(inputTypes, &da, rMatchedCols, &evalCtx, rRows[j])
if err != nil {
t.Fatal(err)
}
return cmp < 0
})
lRows := sqlbase.MakeRandIntRowsInRange(rng, nRows, nCols, maxNum, nullProbability)
rRows := sqlbase.MakeRandIntRowsInRange(rng, nRows, nCols, maxNum, nullProbability)
lMatchedCols := distsqlpb.ConvertToColumnOrdering(distsqlpb.Ordering{Columns: lOrderingCols})
rMatchedCols := distsqlpb.ConvertToColumnOrdering(distsqlpb.Ordering{Columns: rOrderingCols})
sort.Slice(lRows, func(i, j int) bool {
cmp, err := lRows[i].Compare(inputTypes, &da, lMatchedCols, &evalCtx, lRows[j])
if err != nil {
t.Fatal(err)
}
return cmp < 0
})
sort.Slice(rRows, func(i, j int) bool {
cmp, err := rRows[i].Compare(inputTypes, &da, rMatchedCols, &evalCtx, rRows[j])
if err != nil {
t.Fatal(err)
}
return cmp < 0
})

mjSpec := &distsqlpb.MergeJoinerSpec{
LeftOrdering: distsqlpb.Ordering{Columns: lOrderingCols},
RightOrdering: distsqlpb.Ordering{Columns: rOrderingCols},
Type: joinType,
}
pspec := &distsqlpb.ProcessorSpec{
Input: []distsqlpb.InputSyncSpec{{ColumnTypes: inputTypes}, {ColumnTypes: inputTypes}},
Core: distsqlpb.ProcessorCoreUnion{MergeJoiner: mjSpec},
}
if err := verifyColOperator(false /* anyOrder */, [][]types.T{inputTypes, inputTypes}, []sqlbase.EncDatumRows{lRows, rRows}, append(inputTypes, inputTypes...), pspec); err != nil {
t.Fatal(err)
mjSpec := &distsqlpb.MergeJoinerSpec{
LeftOrdering: distsqlpb.Ordering{Columns: lOrderingCols},
RightOrdering: distsqlpb.Ordering{Columns: rOrderingCols},
Type: joinType,
}
pspec := &distsqlpb.ProcessorSpec{
Input: []distsqlpb.InputSyncSpec{{ColumnTypes: inputTypes}, {ColumnTypes: inputTypes}},
Core: distsqlpb.ProcessorCoreUnion{MergeJoiner: mjSpec},
}
if err := verifyColOperator(false /* anyOrder */, [][]types.T{inputTypes, inputTypes}, []sqlbase.EncDatumRows{lRows, rRows}, append(inputTypes, inputTypes...), pspec); err != nil {
fmt.Printf("--- seed = %d run = %d ---\n", seed, run)
t.Fatal(err)
}
}
}
}
}
Expand Down
24 changes: 18 additions & 6 deletions pkg/sql/exec/execgen/cmd/execgen/mergejoiner_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ type selPermutation struct {
}

type joinTypeInfo struct {
IsInner bool
IsLeftOuter bool
String string
IsInner bool
IsLeftOuter bool
IsRightOuter bool

String string
}

func genMergeJoinOps(wr io.Writer) error {
Expand Down Expand Up @@ -72,12 +74,21 @@ func genMergeJoinOps(wr io.Writer) error {
leftUnmatchedGroupSwitch := makeFunctionRegex("_LEFT_UNMATCHED_GROUP_SWITCH", 1)
s = leftUnmatchedGroupSwitch.ReplaceAllString(s, `{{template "leftUnmatchedGroupSwitch" buildDict "Global" $ "JoinType" $1}}`)

rightUnmatchedGroupSwitch := makeFunctionRegex("_RIGHT_UNMATCHED_GROUP_SWITCH", 1)
s = rightUnmatchedGroupSwitch.ReplaceAllString(s, `{{template "rightUnmatchedGroupSwitch" buildDict "Global" $ "JoinType" $1}}`)

nullFromLeftSwitch := makeFunctionRegex("_NULL_FROM_LEFT_SWITCH", 1)
s = nullFromLeftSwitch.ReplaceAllString(s, `{{template "nullFromLeftSwitch" buildDict "Global" $ "JoinType" $1}}`)

nullFromRightSwitch := makeFunctionRegex("_NULL_FROM_RIGHT_SWITCH", 1)
s = nullFromRightSwitch.ReplaceAllString(s, `{{template "nullFromRightSwitch" buildDict "Global" $ "JoinType" $1}}`)

incrementLeftSwitch := makeFunctionRegex("_INCREMENT_LEFT_SWITCH", 4)
s = incrementLeftSwitch.ReplaceAllString(s, `{{template "incrementLeftSwitch" buildDict "Global" $ "JoinType" $1 "Sel" $2 "MJOverload" $3 "lHasNulls" $4}}`)

incrementRightSwitch := makeFunctionRegex("_INCREMENT_RIGHT_SWITCH", 4)
s = incrementRightSwitch.ReplaceAllString(s, `{{template "incrementRightSwitch" buildDict "Global" $ "JoinType" $1 "Sel" $2 "MJOverload" $3 "rHasNulls" $4}}`)

processNotLastGroupInColumnSwitch := makeFunctionRegex("_PROCESS_NOT_LAST_GROUP_IN_COLUMN_SWITCH", 1)
s = processNotLastGroupInColumnSwitch.ReplaceAllString(s, `{{template "processNotLastGroupInColumnSwitch" buildDict "Global" $ "JoinType" $1}}`)

Expand All @@ -93,9 +104,6 @@ func genMergeJoinOps(wr io.Writer) error {
rightSwitch := makeFunctionRegex("_RIGHT_SWITCH", 2)
s = rightSwitch.ReplaceAllString(s, `{{template "rightSwitch" buildDict "Global" $ "IsSel" $1 "HasNulls" $2 }}`)

nullInBufferedGroupSwitch := makeFunctionRegex("_NULL_IN_BUFFERED_GROUP_SWITCH", 1)
s = nullInBufferedGroupSwitch.ReplaceAllString(s, `{{template "nullInBufferedGroupSwitch" buildDict "Global" $ "JoinType" $1}}`)

assignEqRe := makeFunctionRegex("_ASSIGN_EQ", 3)
s = assignEqRe.ReplaceAllString(s, `{{.Eq.Assign $1 $2 $3}}`)

Expand Down Expand Up @@ -162,6 +170,10 @@ func genMergeJoinOps(wr io.Writer) error {
IsLeftOuter: true,
String: "LeftOuter",
},
{
IsRightOuter: true,
String: "RightOuter",
},
}

return tmpl.Execute(wr, struct {
Expand Down
41 changes: 5 additions & 36 deletions pkg/sql/exec/mergejoiner.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,10 @@ type group struct {
// nullGroup indicates whether the output corresponding to the group should
// consist of all nulls.
nullGroup bool
// unmatched indicates that the row in the group does not have matching rows
// unmatched indicates that the rows in the group do not have matching rows
// from the other side (i.e. other side's group will be a null group).
// NOTE: at the moment, the assumption is that such group will consist of a
// single row.
// TODO(yuzefovich): update the logic if the assumption ever changes.
// NOTE: during the probing phase, the assumption is that such group will
// consist of a single row.
unmatched bool
}

Expand Down Expand Up @@ -174,9 +173,6 @@ var _ Operator = &feedOperator{}
// The second pass is where the groups and their associated cross products are
// materialized into the full output.

// TODO(georgeutsin): Add outer joins functionality and templating to support
// different equality types.

// Two buffers are used, one for the group on the left table and one for the
// group on the right table. These buffers are only used if the group ends with
// a batch, to make sure that we don't miss any cross product entries while
Expand Down Expand Up @@ -204,6 +200,8 @@ func NewMergeJoinOp(
return &mergeJoinInnerOp{base}, err
case sqlbase.JoinType_LEFT_OUTER:
return &mergeJoinLeftOuterOp{base}, err
case sqlbase.JoinType_RIGHT_OUTER:
return &mergeJoinRightOuterOp{base}, err
default:
panic("unsupported join type")
}
Expand Down Expand Up @@ -459,35 +457,6 @@ func (o *mergeJoinBase) setBuilderSourceToBufferedGroup() {
o.proberState.rBufferedGroup.needToReset = true
}

// exhaustLeftSourceForLeftOuter configures the builder state so that every
// tuple still pending on the left side is emitted with nulls in place of the
// right side of the output. Callers must only invoke it when performing a
// LEFT OUTER join and only after the right source has been exhausted.
func (o *mergeJoinBase) exhaustLeftSourceForLeftOuter() {
	// Snapshot the range of left rows that have not been processed yet; the
	// same range and count describe both the left and the right group.
	start := o.proberState.lIdx
	end := o.proberState.lLength
	remaining := end - start

	// The init guarantees that lGroups and rGroups have capacity of at least
	// 1, so reslicing each of them down to a single element is safe.
	o.builderState.lGroups = o.builderState.lGroups[:1]
	o.builderState.rGroups = o.builderState.rGroups[:1]

	// Left side: the remaining rows, flagged as having no match.
	o.builderState.lGroups[0] = group{
		rowStartIdx: start,
		rowEndIdx:   end,
		numRepeats:  1,
		toBuild:     remaining,
		unmatched:   true,
	}
	// Right side: a null group with the same number of rows to build.
	o.builderState.rGroups[0] = group{
		rowStartIdx: start,
		rowEndIdx:   end,
		numRepeats:  1,
		toBuild:     remaining,
		nullGroup:   true,
	}

	o.builderState.lBatch = o.proberState.lBatch
	o.builderState.rBatch = o.proberState.rBatch

	// Mark the left batch as fully consumed by the prober.
	o.proberState.lIdx = end
}

// build creates the cross product, and writes it to the output member.
func (o *mergeJoinBase) build() {
if o.output.Width() != 0 {
Expand Down
Loading

0 comments on commit e41c2da

Please sign in to comment.