Skip to content

Commit

Permalink
Merge pull request cockroachdb#18386 from richardwu/stddev-var-local-…
Browse files Browse the repository at this point in the history
…final

distsql: refactor distributed aggregation to support arbitrary local aggregate inputs in final stage
  • Loading branch information
richardwu authored Sep 15, 2017
2 parents a8e90b9 + 95cff88 commit 1c6b721
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 105 deletions.
83 changes: 55 additions & 28 deletions pkg/sql/distsql_physical_planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -1372,67 +1372,94 @@ func (dsp *distSQLPlanner) addAggregators(
//
// Count the total number of aggregation in the local/final stages and keep
// track of whether any of them needs a final rendering.
numAgg := 0
nLocalAgg := 0
nFinalAgg := 0
needRender := false
for _, e := range aggregations {
info := distsqlplan.DistAggregationTable[e.Func]
numAgg += len(info.LocalStage)
nLocalAgg += len(info.LocalStage)
nFinalAgg += len(info.FinalStage)
if info.FinalRendering != nil {
needRender = true
}
}

localAgg := make([]distsqlrun.AggregatorSpec_Aggregation, numAgg, numAgg+len(groupCols))
intermediateTypes := make([]sqlbase.ColumnType, numAgg, numAgg+len(groupCols))
finalAgg := make([]distsqlrun.AggregatorSpec_Aggregation, numAgg)
localAgg := make([]distsqlrun.AggregatorSpec_Aggregation, nLocalAgg, nLocalAgg+len(groupCols))
intermediateTypes := make([]sqlbase.ColumnType, nLocalAgg, nLocalAgg+len(groupCols))
finalAgg := make([]distsqlrun.AggregatorSpec_Aggregation, nFinalAgg)
finalGroupCols := make([]uint32, len(groupCols))
var finalPreRenderTypes []sqlbase.ColumnType
if needRender {
finalPreRenderTypes = make([]sqlbase.ColumnType, numAgg)
finalPreRenderTypes = make([]sqlbase.ColumnType, nFinalAgg)
}

// Each aggregation can have multiple aggregations in the local/final
// stages. We concatenate all these into localAgg/finalAgg; aIdx is an index
// inside localAgg/finalAgg.
aIdx := 0
// stages. We concatenate all these into localAgg/finalAgg; localIdx is an index
// inside localAgg and finalIdx is an index inside finalAgg.
localIdx := 0
finalIdx := 0
for _, e := range aggregations {
info := distsqlplan.DistAggregationTable[e.Func]
for i, localFunc := range info.LocalStage {
localAgg[aIdx] = distsqlrun.AggregatorSpec_Aggregation{
// firstLocalIdxCurAgg points to the first local index in the current
// aggregation e.
// This is used when computing AggregatorSpec_Aggregation.ColIdx
// for FinalStage aggregators since their input column indices
// are specified as a relative offset to the first local aggregator's
// index.
// This is required since we append all localAggs and finalAggs across
// all aggregations for the given plan together.
firstLocalIdxCurAgg := uint32(localIdx)
// First prepare and spec local aggregations.
// Note the planNode first feeds the input (inputTypes) into the local aggregators.
for _, localFunc := range info.LocalStage {
localAgg[localIdx] = distsqlrun.AggregatorSpec_Aggregation{
Func: localFunc,
ColIdx: e.ColIdx,
FilterColIdx: e.FilterColIdx,
}

var localResultType sqlbase.ColumnType

argTypes := make([]sqlbase.ColumnType, len(e.ColIdx))
for i, c := range e.ColIdx {
argTypes[i] = inputTypes[c]
}

var err error
_, localResultType, err = distsqlrun.GetAggregateInfo(localFunc, argTypes...)
_, intermediateTypes[localIdx], err = distsqlrun.GetAggregateInfo(localFunc, argTypes...)
if err != nil {
return err
}
intermediateTypes[aIdx] = localResultType
localIdx++
}

finalAgg[aIdx] = distsqlrun.AggregatorSpec_Aggregation{
Func: info.FinalStage[i],
// The input of final expression aIdx is the output of the
// local expression aIdx.
ColIdx: []uint32{uint32(aIdx)},
for _, finalInfo := range info.FinalStage {
// The input of the final aggregators are specified as the indices of the local aggregation
// values. We need to offset firstLocalIdxCurAgg by the relative indices specified in finalInfo.LocalIdxs.
argIdxs := make([]uint32, len(finalInfo.LocalIdxs))
for i, c := range finalInfo.LocalIdxs {
argIdxs[i] = c + firstLocalIdxCurAgg
}
finalAgg[finalIdx] = distsqlrun.AggregatorSpec_Aggregation{
Func: finalInfo.Fn,
ColIdx: argIdxs,
}

if needRender {
_, finalPreRenderTypes[aIdx], err = distsqlrun.GetAggregateInfo(
info.FinalStage[i], localResultType,
argTypes := make([]sqlbase.ColumnType, len(finalInfo.LocalIdxs))
for i, c := range finalInfo.LocalIdxs {
// We want to access the corresponding local output type
// for the current aggregation e. c is the offset from
// the first local aggregator for the current aggregation e.
argTypes[i] = intermediateTypes[firstLocalIdxCurAgg+c]
}
var err error
_, finalPreRenderTypes[finalIdx], err = distsqlrun.GetAggregateInfo(
finalInfo.Fn, argTypes...,
)
if err != nil {
return err
}
}
aIdx++
finalIdx++
}
}

Expand Down Expand Up @@ -1481,21 +1508,21 @@ func (dsp *distSQLPlanner) addAggregators(
// Build rendering expressions.
renderExprs := make([]distsqlrun.Expression, len(aggregations))
h := distsqlplan.MakeTypeIndexedVarHelper(finalPreRenderTypes)
// aIdx is an index inside finalAgg. It is used to keep track of the
// finalIdx is an index inside finalAgg. It is used to keep track of the
// finalAgg results that correspond to each aggregation.
aIdx := 0
finalIdx := 0
for i, e := range aggregations {
info := distsqlplan.DistAggregationTable[e.Func]
if info.FinalRendering == nil {
renderExprs[i] = distsqlplan.MakeExpression(h.IndexedVar(aIdx), nil)
renderExprs[i] = distsqlplan.MakeExpression(h.IndexedVar(finalIdx), nil)
} else {
expr, err := info.FinalRendering(&h, aIdx)
expr, err := info.FinalRendering(&h, finalIdx)
if err != nil {
return err
}
renderExprs[i] = distsqlplan.MakeExpression(expr, nil)
}
aIdx += len(info.LocalStage)
finalIdx += len(info.FinalStage)
}
finalAggPost.RenderExprs = renderExprs
}
Expand Down
100 changes: 84 additions & 16 deletions pkg/sql/distsqlplan/aggregator_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
)

// FinalStageInfo is a wrapper around an aggregation function performed
// in the final stage of distributed aggregations that allows us to specify the
// corresponding inputs from the local aggregations by their indices in the LocalStage.
type FinalStageInfo struct {
Fn distsqlrun.AggregatorSpec_Func
// Specifies the ordered slice of outputs from local aggregations to propagate
// as inputs to Fn. This must be ordered according to the underlying aggregate builtin
// arguments signature found in aggregate_builtins.go.
LocalIdxs []uint32
}

// DistAggregationInfo is a blueprint for planning distributed aggregations. It
// describes two stages - a local stage performs local aggregations wherever
// data is available and generates partial results, and a final stage aggregates
Expand All @@ -45,10 +56,11 @@ type DistAggregationInfo struct {
// the same input.
LocalStage []distsqlrun.AggregatorSpec_Func

// The final stage consists of the same number of aggregations as the local
// stage (the input of each one is the corresponding result from each instance
// of the local stage).
FinalStage []distsqlrun.AggregatorSpec_Func
// The final stage consists of one or more aggregations that take in an
// arbitrary number of inputs from the local stages. The inputs are ordered and
// mapped by the indices of the local aggregations in LocalStage (specified by
// inlocalIdxs).
FinalStage []FinalStageInfo

// An optional rendering expression used to obtain the final result; required
// if there is more than one aggregation in each of the stages.
Expand All @@ -70,52 +82,102 @@ type DistAggregationInfo struct {
FinalRendering func(h *parser.IndexedVarHelper, varIdxOffset int) (parser.TypedExpr, error)
}

// Convenient value for FinalStageInfo.LocalIdxs when there is only one aggregation
// function in each of the LocalStage and FinalStage. Otherwise, specify the explicit
// index corresponding to the local stage.
var passThroughLocalIdxs = []uint32{0}

// DistAggregationTable is DistAggregationInfo look-up table. Functions that
// don't have an entry in the table are not optimized with a local stage.
var DistAggregationTable = map[distsqlrun.AggregatorSpec_Func]DistAggregationInfo{
distsqlrun.AggregatorSpec_IDENT: {
LocalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_IDENT},
FinalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_IDENT},
FinalStage: []FinalStageInfo{
{
Fn: distsqlrun.AggregatorSpec_IDENT,
LocalIdxs: passThroughLocalIdxs,
},
},
},

distsqlrun.AggregatorSpec_BOOL_AND: {
LocalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_BOOL_AND},
FinalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_BOOL_AND},
FinalStage: []FinalStageInfo{
{
Fn: distsqlrun.AggregatorSpec_BOOL_AND,
LocalIdxs: passThroughLocalIdxs,
},
},
},

distsqlrun.AggregatorSpec_BOOL_OR: {
LocalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_BOOL_OR},
FinalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_BOOL_OR},
FinalStage: []FinalStageInfo{
{
Fn: distsqlrun.AggregatorSpec_BOOL_OR,
LocalIdxs: passThroughLocalIdxs,
},
},
},

distsqlrun.AggregatorSpec_COUNT: {
LocalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_COUNT},
FinalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_SUM_INT},
FinalStage: []FinalStageInfo{
{
Fn: distsqlrun.AggregatorSpec_SUM_INT,
LocalIdxs: passThroughLocalIdxs,
},
},
},

distsqlrun.AggregatorSpec_COUNT_ROWS: {
LocalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_COUNT_ROWS},
FinalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_SUM_INT},
FinalStage: []FinalStageInfo{
{
Fn: distsqlrun.AggregatorSpec_SUM_INT,
LocalIdxs: passThroughLocalIdxs,
},
},
},

distsqlrun.AggregatorSpec_MAX: {
LocalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_MAX},
FinalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_MAX},
FinalStage: []FinalStageInfo{
{
Fn: distsqlrun.AggregatorSpec_MAX,
LocalIdxs: passThroughLocalIdxs,
},
},
},

distsqlrun.AggregatorSpec_MIN: {
LocalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_MIN},
FinalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_MIN},
FinalStage: []FinalStageInfo{
{
Fn: distsqlrun.AggregatorSpec_MIN,
LocalIdxs: passThroughLocalIdxs,
},
},
},

distsqlrun.AggregatorSpec_SUM: {
LocalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_SUM},
FinalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_SUM},
FinalStage: []FinalStageInfo{
{
Fn: distsqlrun.AggregatorSpec_SUM,
LocalIdxs: passThroughLocalIdxs,
},
},
},

distsqlrun.AggregatorSpec_XOR_AGG: {
LocalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_XOR_AGG},
FinalStage: []distsqlrun.AggregatorSpec_Func{distsqlrun.AggregatorSpec_XOR_AGG},
FinalStage: []FinalStageInfo{
{
Fn: distsqlrun.AggregatorSpec_XOR_AGG,
LocalIdxs: passThroughLocalIdxs,
},
},
},

// AVG is more tricky than the ones above; we need two intermediate values in
Expand All @@ -130,9 +192,15 @@ var DistAggregationTable = map[distsqlrun.AggregatorSpec_Func]DistAggregationInf
distsqlrun.AggregatorSpec_SUM,
distsqlrun.AggregatorSpec_COUNT,
},
FinalStage: []distsqlrun.AggregatorSpec_Func{
distsqlrun.AggregatorSpec_SUM,
distsqlrun.AggregatorSpec_SUM_INT,
FinalStage: []FinalStageInfo{
{
Fn: distsqlrun.AggregatorSpec_SUM,
LocalIdxs: []uint32{0},
},
{
Fn: distsqlrun.AggregatorSpec_SUM_INT,
LocalIdxs: []uint32{1},
},
},
FinalRendering: func(h *parser.IndexedVarHelper, varIdxOffset int) (parser.TypedExpr, error) {
sum := h.IndexedVar(varIdxOffset)
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/distsqlplan/aggregator_funcs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,11 @@ func checkDistAggregationInfo(
localAggregations[i] = distsqlrun.AggregatorSpec_Aggregation{Func: fn, ColIdx: []uint32{0}}
}
finalAggregations := make([]distsqlrun.AggregatorSpec_Aggregation, numIntermediary)
for i, fn := range info.FinalStage {
for i, finalInfo := range info.FinalStage {
// Each local aggregation feeds into a final aggregation.
finalAggregations[i] = distsqlrun.AggregatorSpec_Aggregation{
Func: fn,
ColIdx: []uint32{uint32(i)},
Func: finalInfo.Fn,
ColIdx: finalInfo.LocalIdxs,
}
}

Expand Down
40 changes: 32 additions & 8 deletions pkg/sql/distsqlrun/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,15 +298,30 @@ func (ag *aggregator) accumulateRows(ctx context.Context) (err error) {
continue
}
}
var value parser.Datum
if len(a.ColIdx) != 0 {
c := a.ColIdx[0]
// Extract the corresponding arguments from the row to feed into the
// aggregate function.
// Most functions require at most one argument thus we separate
// the first argument and allocation of (if applicable) a variadic
// collection of arguments thereafter.
var firstArg parser.Datum
var otherArgs parser.Datums
if len(a.ColIdx) > 1 {
otherArgs = make(parser.Datums, len(a.ColIdx)-1)
}
isFirstArg := true
for j, c := range a.ColIdx {
if err := row[c].EnsureDecoded(&ag.datumAlloc); err != nil {
return err
}
value = row[c].Datum
if isFirstArg {
firstArg = row[c].Datum
isFirstArg = false
continue
}
otherArgs[j-1] = row[c].Datum
}
if err := ag.funcs[i].add(ctx, encoded, value); err != nil {

if err := ag.funcs[i].add(ctx, encoded, firstArg, otherArgs); err != nil {
return err
}
}
Expand Down Expand Up @@ -335,12 +350,21 @@ func (ag *aggregator) newAggregateFuncHolder(
}
}

func (a *aggregateFuncHolder) add(ctx context.Context, bucket []byte, d parser.Datum) error {
func (a *aggregateFuncHolder) add(
ctx context.Context, bucket []byte, firstArg parser.Datum, otherArgs parser.Datums,
) error {
if a.seen != nil {
encoded, err := sqlbase.EncodeDatum(bucket, d)
encoded, err := sqlbase.EncodeDatum(bucket, firstArg)
if err != nil {
return err
}
// Encode additional arguments if necessary.
if otherArgs != nil {
encoded, err = sqlbase.EncodeDatums(bucket, otherArgs)
if err != nil {
return err
}
}
if _, ok := a.seen[string(encoded)]; ok {
// skip
return nil
Expand All @@ -367,7 +391,7 @@ func (a *aggregateFuncHolder) add(ctx context.Context, bucket []byte, d parser.D
a.buckets[string(bucket)] = impl
}

return impl.Add(ctx, d)
return impl.Add(ctx, firstArg, otherArgs...)
}

func (a *aggregateFuncHolder) get(bucket string) (parser.Datum, error) {
Expand Down
Loading

0 comments on commit 1c6b721

Please sign in to comment.