Skip to content

Commit

Permalink
feat: Add AggregateRel.ToBuilder() (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
EpsilonPrime authored Feb 11, 2025
1 parent 00a91c9 commit 3be7ba3
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 2 deletions.
17 changes: 15 additions & 2 deletions plan/builders.go
Original file line number Diff line number Diff line change
Expand Up @@ -799,19 +799,32 @@ func extractGroup(expressionReferences []uint32, combinationIndex int) []uint32
return group
}

// addRollup constructs the rollup grouping strategy from the provided grouping references.
// AddRollup constructs the rollup grouping strategy from the provided grouping references.
func (arb *AggregateRelBuilder) AddRollup(groupingReferences []uint32) {
for i := len(groupingReferences); i > 0; i-- {
rollupSet := groupingReferences[:i]
arb.groupingReferences = append(arb.groupingReferences, rollupSet)
}
}

// addGroupingSet adds a new grouping set based on the provided grouping references.
// AddGroupingSet adds a new grouping set based on the provided grouping references.
func (arb *AggregateRelBuilder) AddGroupingSet(groupingReferences []uint32) {
arb.groupingReferences = append(arb.groupingReferences, groupingReferences)
}

func (arb *AggregateRelBuilder) ReplaceInput(rel *Rel) {
arb.input = *rel
}

func (arb *AggregateRelBuilder) ClearMeasures() {
arb.measures = nil
}

func (arb *AggregateRelBuilder) ClearGrouping() {
arb.groupingExpressions = nil
arb.groupingReferences = nil
}

func (arb *AggregateRelBuilder) Build() (*AggregateRel, error) {
if err := arb.validate(); err != nil {
return nil, err
Expand Down
74 changes: 74 additions & 0 deletions plan/plan_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1935,6 +1935,80 @@ func TestAggregateRelBuilder(t *testing.T) {
_, err = arb.Build()
assert.Error(t, err)
})

t.Run("ReplaceInput", func(t *testing.T) {
b := plan.NewBuilderDefault()

e := b.GetExprBuilder()
expr1, _ := e.ScalarFunc(addID).Args(
e.Wrap(expr.NewLiteral(int32(3), false)),
e.Wrap(expr.NewLiteral(int32(3), false))).BuildExpr()

aggCount, err := b.AggregateFn(extensions.SubstraitDefaultURIPrefix+"functions_aggregate_generic.yaml",
"count", nil)
require.NoError(t, err)
arb := b.GetRelBuilder().AggregateRel(b.NamedScan([]string{"test"}, baseSchema), []plan.AggRelMeasure{b.Measure(aggCount, nil)})

ref1 := arb.AddExpression(expr1)
err = arb.AddCube([]uint32{ref1})
assert.NoError(t, err)

newInput := plan.Rel(b.NamedScan([]string{"test"}, baseSchema))
arb.ReplaceInput(&newInput)

aggregateRel, err := arb.Build()
assert.NoError(t, err)
assert.Equal(t, newInput, aggregateRel.Input())
})

t.Run("CleanGroupings cleans groupings", func(t *testing.T) {
b := plan.NewBuilderDefault()

e := b.GetExprBuilder()
expr1, _ := e.ScalarFunc(addID).Args(
e.Wrap(expr.NewLiteral(int32(3), false)),
e.Wrap(expr.NewLiteral(int32(3), false))).BuildExpr()

aggCount, err := b.AggregateFn(extensions.SubstraitDefaultURIPrefix+"functions_aggregate_generic.yaml",
"count", nil)
require.NoError(t, err)
arb := b.GetRelBuilder().AggregateRel(b.NamedScan([]string{"test"}, baseSchema), []plan.AggRelMeasure{b.Measure(aggCount, nil)})

ref1 := arb.AddExpression(expr1)
err = arb.AddCube([]uint32{ref1})
assert.NoError(t, err)

arb.ClearGrouping()

aggregateRel, err := arb.Build()
assert.NoError(t, err)
assert.ElementsMatch(t, [][]uint32{}, aggregateRel.GroupingReferences())
assert.ElementsMatch(t, []expr.Expression{}, aggregateRel.GroupingExpressions())
})

t.Run("CleanMeasures cleans measures", func(t *testing.T) {
b := plan.NewBuilderDefault()

e := b.GetExprBuilder()
expr1, _ := e.ScalarFunc(addID).Args(
e.Wrap(expr.NewLiteral(int32(3), false)),
e.Wrap(expr.NewLiteral(int32(3), false))).BuildExpr()

aggCount, err := b.AggregateFn(extensions.SubstraitDefaultURIPrefix+"functions_aggregate_generic.yaml",
"count", nil)
require.NoError(t, err)
arb := b.GetRelBuilder().AggregateRel(b.NamedScan([]string{"test"}, baseSchema), []plan.AggRelMeasure{b.Measure(aggCount, nil)})

ref1 := arb.AddExpression(expr1)
err = arb.AddCube([]uint32{ref1})
assert.NoError(t, err)

arb.ClearMeasures()

aggregateRel, err := arb.Build()
assert.NoError(t, err)
assert.ElementsMatch(t, []plan.AggRelMeasure{}, aggregateRel.Measures())
})
}

func expectedJsonWithIceberg(metadataURI string, snapshot plan.IcebergSnapshot) string {
Expand Down
15 changes: 15 additions & 0 deletions plan/relations.go
Original file line number Diff line number Diff line change
Expand Up @@ -1185,6 +1185,21 @@ func (ar *AggregateRel) Copy(newInputs ...Rel) (Rel, error) {
return &aggregate, nil
}

// ToBuilder returns an AggregateRelBuilder made from the current AggregateRel.
// Copies are made to avoid changes to the original data.
func (ar *AggregateRel) ToBuilder() *AggregateRelBuilder {
newInput := ar.input
newMeasures := make([]AggRelMeasure, len(ar.measures))
copy(newMeasures, ar.measures)
newGroupingExpressions := make([]expr.Expression, len(ar.groupingExpressions))
copy(newGroupingExpressions, ar.groupingExpressions)
newGroupingReferences := make([][]uint32, len(ar.groupingReferences))
copy(newGroupingReferences, ar.groupingReferences)
return &AggregateRelBuilder{
input: newInput, measures: newMeasures,
groupingExpressions: newGroupingExpressions, groupingReferences: newGroupingReferences}
}

func (ar *AggregateRel) rewriteAggregateFunc(rewriteFunc RewriteFunc, f *expr.AggregateFunction) (*expr.AggregateFunction, error) {
if f == nil {
return f, nil
Expand Down
22 changes: 22 additions & 0 deletions plan/relations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,28 @@ func TestRelations_Copy(t *testing.T) {
}
}

func TestAggregateRelToBuilder(t *testing.T) {
extReg := expr.NewExtensionRegistry(extensions.NewSet(), &extensions.DefaultCollection)
aggregateFnID := extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
Name: "avg",
}
aggregateFn, err := expr.NewAggregateFunc(extReg,
aggregateFnID, nil, types.AggInvocationAll,
types.AggPhaseInitialToResult, nil, createPrimitiveFloat(1.0))
require.NoError(t, err)

aggregateRel := &AggregateRel{input: createVirtualTableReadRel(1),
groupingExpressions: []expr.Expression{createPrimitiveFloat(1.0)},
groupingReferences: [][]uint32{{0}},
measures: []AggRelMeasure{{measure: aggregateFn, filter: expr.NewPrimitiveLiteral(false, false)}}}

builder := aggregateRel.ToBuilder()
got, err := builder.Build()
assert.NoError(t, err)
assert.Equal(t, aggregateRel, got)
}

// fakeRel is a pretend relation that allows direct control of its direct output schema.
type fakeRel struct {
RelCommon
Expand Down

0 comments on commit 3be7ba3

Please sign in to comment.