diff --git a/query/graphql/parser/commit.go b/query/graphql/parser/commit.go index af3fc72427..bfafb23b9a 100644 --- a/query/graphql/parser/commit.go +++ b/query/graphql/parser/commit.go @@ -45,8 +45,6 @@ type CommitSelect struct { Limit *Limit OrderBy *OrderBy - Counts []PropertyTransformation - Sums []PropertyTransformation Fields []Selection @@ -73,21 +71,12 @@ func (c CommitSelect) GetSelections() []Selection { return c.Fields } -func (s *CommitSelect) AddCount(transformationDefinition PropertyTransformation) { - s.Counts = append(s.Counts, transformationDefinition) -} - -func (s *CommitSelect) AddSum(transformationDefinition PropertyTransformation) { - s.Sums = append(s.Sums, transformationDefinition) -} - func (c CommitSelect) ToSelect() *Select { return &Select{ Name: c.Name, Alias: c.Alias, Limit: c.Limit, OrderBy: c.OrderBy, - Counts: c.Counts, Statement: c.Statement, Fields: c.Fields, Root: CommitSelection, @@ -131,7 +120,5 @@ func parseCommitSelect(field *ast.Field) (*CommitSelect, error) { var err error commit.Fields, err = parseSelectFields(commit.GetRoot(), field.SelectionSet) - parseAggregates(commit) - return commit, err } diff --git a/query/graphql/parser/query.go b/query/graphql/parser/query.go index 2866ad91b8..245e0a4814 100644 --- a/query/graphql/parser/query.go +++ b/query/graphql/parser/query.go @@ -47,6 +47,11 @@ var ReservedFields = map[string]bool{ DocKeyFieldName: true, } +var aggregates = map[string]struct{}{ + CountFieldName: {}, + SumFieldName: {}, +} + type Query struct { Queries []*OperationDefinition Mutations []*OperationDefinition @@ -82,8 +87,6 @@ type Selection interface { type baseSelect interface { Selection - AddCount(transformationDefinition PropertyTransformation) - AddSum(transformationDefinition PropertyTransformation) } // Select is a complex Field with strong typing @@ -105,8 +108,6 @@ type Select struct { Limit *Limit OrderBy *OrderBy GroupBy *GroupBy - Counts []PropertyTransformation - Sums []PropertyTransformation Fields []Selection @@ -134,14 +135,6 @@ func (s Select) GetAlias() string { return s.Alias } -func (s *Select) AddCount(transformationDefinition PropertyTransformation) { - s.Counts = append(s.Counts, transformationDefinition) -} - -func (s *Select) AddSum(transformationDefinition PropertyTransformation) { - s.Sums = append(s.Sums, transformationDefinition) -} - // Field implements Selection type Field struct { Name string @@ -178,14 +171,6 @@ type GroupBy struct { Fields []string } -// Contains mapping information between a source and destination properties -type PropertyTransformation struct { - // Where the result of transformation should be written to - Destination string - // Where the data to be transformed should be read from - Source []string -} - type SortDirection string const ( @@ -398,11 +383,6 @@ func parseSelect(rootType SelectionType, field *ast.Field) (*Select, error) { return nil, err } - err = parseAggregates(slct) - if err != nil { - return nil, err - } - return slct, err } @@ -413,7 +393,10 @@ func parseSelectFields(root SelectionType, fields *ast.SelectionSet) ([]Selectio switch node := selection.(type) { case *ast.Field: if node.SelectionSet == nil { // regular field - f := parseField(root, node) + f, err := parseField(i, root, node) + if err != nil { + return nil, err + } selections[i] = f } else { // sub type with extra fields subroot := root @@ -435,16 +418,31 @@ func parseSelectFields(root SelectionType, fields *ast.SelectionSet) ([]Selectio // parseField simply parses the Name/Alias // into a Field type -func parseField(root SelectionType, field *ast.Field) *Field { +func parseField(i int, root SelectionType, field *ast.Field) (*Field, error) { + var name string + var alias string + + if _, isAggregate := aggregates[field.Name.Value]; isAggregate { + name = fmt.Sprintf("_agg%v", i) + if field.Alias == nil { + alias = field.Name.Value + } else { + alias = field.Alias.Value + } + } else { + name = field.Name.Value + if field.Alias != nil { + alias = field.Alias.Value + } + } + f := &Field{ Root: root, - Name: field.Name.Value, + Name: name, Statement: field, + Alias: alias, } - if field.Alias != nil { - f.Alias = field.Alias.Value - } - return f + return f, nil } func parseAPIQuery(field *ast.Field) (Selection, error) { @@ -456,49 +454,20 @@ func parseAPIQuery(field *ast.Field) (Selection, error) { } } -// Parses requested aggregates, creating a virtual name (alias) if an alias is not provided to allow for multiple aggregate -// fields. Mutates any aggregate field properties on the select, and adds the result onto the given select object. -func parseAggregates(slct baseSelect) error { - for i, field := range slct.GetSelections() { - switch field.GetName() { - case CountFieldName: - err, transformation := parseAggregate(i, field) - if err != nil { - return err - } - slct.AddCount(transformation) - case SumFieldName: - err, transformation := parseAggregate(i, field) - if err != nil { - return err - } - slct.AddSum(transformation) - } - } - - return nil -} - -// Parses the given aggregate field selector, mutating the given field and returning the resultant PropertyTransformation -func parseAggregate(i int, field Selection) (error, PropertyTransformation) { - virtualName := fmt.Sprintf("_agg%v", i) - f := field.(*Field) - if f.Alias == "" { - f.Alias = f.Name - } - f.Name = virtualName - +// Returns the source of the aggregate as requested by the consumer +func (field Field) GetAggregateSource() ([]string, error) { var path []string - if len(f.Statement.Arguments) == 0 { + + if len(field.Statement.Arguments) == 0 { path = []string{} } else { - switch arguementValue := f.Statement.Arguments[0].Value.GetValue().(type) { + switch arguementValue := field.Statement.Arguments[0].Value.GetValue().(type) { case string: path = []string{arguementValue} case []*ast.ObjectField: if len(arguementValue) == 0 { //Note: Scalar arrays will hit this clause and should be handled appropriately (not adding now as not testable in a time-efficient manner) - return fmt.Errorf("Unexpected error: aggregate field contained no child field selector"), PropertyTransformation{} + return []string{}, fmt.Errorf("Unexpected error: aggregate field contained no child field selector") } innerPath := arguementValue[0].Value.GetValue() if innerPathStringValue, isString := innerPath.(string); isString { @@ -510,5 +479,5 @@ func parseAggregate(i int, field Selection) (error, PropertyTransformation) { } } - return nil, PropertyTransformation{Destination: virtualName, Source: path} + return path, nil } diff --git a/query/graphql/planner/count.go b/query/graphql/planner/count.go index 824c601d02..3c473edd4c 100644 --- a/query/graphql/planner/count.go +++ b/query/graphql/planner/count.go @@ -28,10 +28,15 @@ type countNode struct { virtualFieldId string } -func (p *Planner) Count(c *parser.PropertyTransformation) (*countNode, error) { +func (p *Planner) Count(field parser.Field) (*countNode, error) { + source, err := field.GetAggregateSource() + if err != nil { + return nil, err + } + var sourceProperty string - if len(c.Source) == 1 { - sourceProperty = c.Source[0] + if len(source) == 1 { + sourceProperty = source[0] } else { sourceProperty = "" } @@ -39,7 +44,7 @@ func (p *Planner) Count(c *parser.PropertyTransformation) (*countNode, error) { return &countNode{ p: p, sourceProperty: sourceProperty, - virtualFieldId: c.Destination, + virtualFieldId: field.Name, }, nil } @@ -74,3 +79,5 @@ func (n *countNode) Values() map[string]interface{} { func (n *countNode) Next() (bool, error) { return n.plan.Next() } + +func (n *countNode) SetPlan(p planNode) { n.plan = p } diff --git a/query/graphql/planner/planner.go b/query/graphql/planner/planner.go index e44e35c625..cd9852a40e 100644 --- a/query/graphql/planner/planner.go +++ b/query/graphql/planner/planner.go @@ -224,15 +224,15 @@ func (p *Planner) expandSelectTopNodePlan(plan *selectTopNode, parentPlan *selec return nil } -func (p *Planner) expandAggregatePlans(plan *selectTopNode) { - for _, countPlan := range plan.countPlans { - countPlan.plan = plan.plan - plan.plan = countPlan - } +type aggregateNode interface { + planNode + SetPlan(plan planNode) +} - for _, sumPlan := range plan.sumPlans { - sumPlan.plan = plan.plan - plan.plan = sumPlan +func (p *Planner) expandAggregatePlans(plan *selectTopNode) { + for _, aggregate := range plan.aggregates { + aggregate.SetPlan(plan.plan) + plan.plan = aggregate } } @@ -305,7 +305,7 @@ func (p *Planner) expandLimitPlan(plan *selectTopNode, parentPlan *selectTopNode // if this is a child node, and the parent select has an aggregate then we need to // replace the hard limit with a render limit to allow the full set of child records // to be aggregated - if parentPlan != nil && (len(parentPlan.countPlans) > 0 || len(parentPlan.sumPlans) > 0) { + if parentPlan != nil /*&& (len(parentPlan.countPlans) > 0 || len(parentPlan.sumPlans) > 0)*/ { renderLimit, err := p.RenderLimit(&parser.Limit{ Offset: l.offset, Limit: l.limit, diff --git a/query/graphql/planner/select.go b/query/graphql/planner/select.go index c700564801..5755298f14 100644 --- a/query/graphql/planner/select.go +++ b/query/graphql/planner/select.go @@ -27,9 +27,8 @@ type selectTopNode struct { group *groupNode sort *sortNode limit planNode - countPlans []*countNode - sumPlans []*sumNode render *renderNode + aggregates []aggregateNode // top of the plan graph plan planNode @@ -131,13 +130,13 @@ func (n *selectNode) Close() error { // creating scanNodes, typeIndexJoinNodes, and splitting // the necessary filters. Its designed to work with the // planner.Select construction call. -func (n *selectNode) initSource(parsed *parser.Select) error { +func (n *selectNode) initSource(parsed *parser.Select) ([]aggregateNode, error) { if parsed.CollectionName == "" { parsed.CollectionName = parsed.Name } sourcePlan, err := n.p.getSource(parsed.CollectionName) if err != nil { - return err + return nil, err } n.source = sourcePlan.plan n.origSource = sourcePlan.plan @@ -171,7 +170,7 @@ func (n *selectNode) initSource(parsed *parser.Select) error { return n.initFields(parsed) } -func (n *selectNode) initFields(parsed *parser.Select) error { +func (n *selectNode) initFields(parsed *parser.Select) ([]aggregateNode, error) { // re-organize the fields slice into reverse-alphabetical // this makes sure the reserved database fields that start with // a "_" end up at the end. So if/when we build our MultiNode @@ -180,57 +179,78 @@ func (n *selectNode) initFields(parsed *parser.Select) error { return !(strings.Compare(parsed.Fields[i].GetName(), parsed.Fields[j].GetName()) < 0) }) + aggregates := []aggregateNode{} // loop over the sub type // at the moment, we're only testing a single sub selection for _, field := range parsed.Fields { - if subtype, ok := field.(*parser.Select); ok { + switch f := field.(type) { + case *parser.Select: // @todo: check select type: // - TypeJoin // - commitScan - if subtype.Name == "_version" { // reserved sub type for object queries + if f.Name == "_version" { // reserved sub type for object queries commitSlct := &parser.CommitSelect{ - Name: subtype.Name, - Alias: subtype.Alias, + Name: f.Name, + Alias: f.Alias, Type: parser.LatestCommits, - Fields: subtype.Fields, + Fields: f.Fields, } commitPlan, err := n.p.CommitSelect(commitSlct) if err != nil { - return err + return nil, err } if err := n.addSubPlan(field.GetName(), commitPlan); err != nil { - return err + return nil, err } - } else if subtype.Root == parser.ObjectSelection { - if subtype.Name == parser.GroupFieldName { - n.groupSelect = subtype + } else if f.Root == parser.ObjectSelection { + if f.Name == parser.GroupFieldName { + n.groupSelect = f } else { - n.addTypeIndexJoin(subtype) + n.addTypeIndexJoin(f) } } - } - } + case *parser.Field: + var plan aggregateNode + var aggregateError error + + switch f.Statement.Name.Value { + case parser.CountFieldName: + plan, aggregateError = n.p.Count(*f) + case parser.SumFieldName: + plan, aggregateError = n.p.Sum(&n.sourceInfo, *f) + } - // Handle aggregates of child collection that are not rendered - for _, count := range parsed.Counts { - n.joinAggregatedChild(parsed, count) - } + if aggregateError != nil { + return nil, aggregateError + } + + if plan != nil { + aggregates = append(aggregates, plan) - for _, sum := range parsed.Sums { - n.joinAggregatedChild(parsed, sum) + aggregateError = n.joinAggregatedChild(parsed, *f) + if aggregateError != nil { + return nil, aggregateError + } + } + } } - return nil + return aggregates, nil } // Join any child collections required by the given transformation if the child collections have not been requested for render by the consumer -func (n *selectNode) joinAggregatedChild(parsed *parser.Select, transformation parser.PropertyTransformation) { - if len(transformation.Source) == 0 { - return +func (n *selectNode) joinAggregatedChild(parsed *parser.Select, field parser.Field) error { + source, err := field.GetAggregateSource() + if err != nil { + return err + } + + if len(source) == 0 { + return nil } - fieldName := transformation.Source[0] + fieldName := source[0] hasChildProperty := false for _, field := range parsed.Fields { if fieldName == field.GetName() { @@ -253,6 +273,8 @@ func (n *selectNode) joinAggregatedChild(parsed *parser.Select, transformation p n.addTypeIndexJoin(subtype) } } + + return nil } func (n *selectNode) addTypeIndexJoin(subSelect *parser.Select) error { @@ -318,7 +340,8 @@ func (p *Planner) SelectFromSource(parsed *parser.Select, source planNode, fromC s.sourceInfo = sourceInfo{desc} } - if err := s.initFields(parsed); err != nil { + aggregates, err := s.initFields(parsed) + if err != nil { return nil, err } @@ -337,32 +360,13 @@ func (p *Planner) SelectFromSource(parsed *parser.Select, source planNode, fromC return nil, err } - countPlans := []*countNode{} - for _, countItem := range parsed.Counts { - countNode, err := p.Count(&countItem) - if err != nil { - return nil, err - } - countPlans = append(countPlans, countNode) - } - - sumPlans := []*sumNode{} - for _, sumItem := range parsed.Sums { - sumNode, err := p.Sum(&s.sourceInfo, &sumItem) - if err != nil { - return nil, err - } - sumPlans = append(sumPlans, sumNode) - } - top := &selectTopNode{ source: s, render: p.render(parsed), limit: limitPlan, sort: sortPlan, group: groupPlan, - countPlans: countPlans, - sumPlans: sumPlans, + aggregates: aggregates, } return top, nil } @@ -376,7 +380,8 @@ func (p *Planner) Select(parsed *parser.Select) (planNode, error) { groupBy := parsed.GroupBy s.renderInfo = &renderInfo{} - if err := s.initSource(parsed); err != nil { + aggregates, err := s.initSource(parsed) + if err != nil { return nil, err } @@ -395,32 +400,13 @@ func (p *Planner) Select(parsed *parser.Select) (planNode, error) { return nil, err } - countPlans := []*countNode{} - for _, countItem := range parsed.Counts { - countNode, err := p.Count(&countItem) - if err != nil { - return nil, err - } - countPlans = append(countPlans, countNode) - } - - sumPlans := []*sumNode{} - for _, sumItem := range parsed.Sums { - sumNode, err := p.Sum(&s.sourceInfo, &sumItem) - if err != nil { - return nil, err - } - sumPlans = append(sumPlans, sumNode) - } - top := &selectTopNode{ source: s, render: p.render(parsed), limit: limitPlan, sort: sortPlan, group: groupPlan, - countPlans: countPlans, - sumPlans: sumPlans, + aggregates: aggregates, } return top, nil } diff --git a/query/graphql/planner/sum.go b/query/graphql/planner/sum.go index 08ae2ccbde..3a7e307d9c 100644 --- a/query/graphql/planner/sum.go +++ b/query/graphql/planner/sum.go @@ -27,14 +27,19 @@ type sumNode struct { virtualFieldId string } -func (p *Planner) Sum(sourceInfo *sourceInfo, c *parser.PropertyTransformation) (*sumNode, error) { +func (p *Planner) Sum(sourceInfo *sourceInfo, field parser.Field) (*sumNode, error) { var sourceProperty string var sourceCollection string var isFloat bool - if len(c.Source) == 1 { + source, err := field.GetAggregateSource() + if err != nil { + return nil, err + } + + if len(source) == 1 { // If path length is one - we are summing an inline array - sourceCollection = c.Source[0] + sourceCollection = source[0] sourceProperty = "" fieldDescription, fieldDescriptionFound := sourceInfo.collectionDescription.GetField(sourceCollection) @@ -43,10 +48,10 @@ func (p *Planner) Sum(sourceInfo *sourceInfo, c *parser.PropertyTransformation) } isFloat = fieldDescription.Kind == base.FieldKind_FLOAT_ARRAY - } else if len(c.Source) == 2 { + } else if len(source) == 2 { // If path length is two, we are summing a group or a child relationship - sourceCollection = c.Source[0] - sourceProperty = c.Source[1] + sourceCollection = source[0] + sourceProperty = source[1] var childFieldDescription base.FieldDescription if sourceCollection == parser.GroupFieldName { @@ -83,7 +88,7 @@ func (p *Planner) Sum(sourceInfo *sourceInfo, c *parser.PropertyTransformation) isFloat: isFloat, sourceCollection: sourceCollection, sourceProperty: sourceProperty, - virtualFieldId: c.Destination, + virtualFieldId: field.Name, }, nil } @@ -143,3 +148,5 @@ func (n *sumNode) Values() map[string]interface{} { func (n *sumNode) Next() (bool, error) { return n.plan.Next() } + +func (n *sumNode) SetPlan(p planNode) { n.plan = p }