diff --git a/checker/cost.go b/checker/cost.go index 3a6eb0d2..1b325eac 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -28,15 +28,20 @@ import ( // CostEstimator estimates the sizes of variable length input data and the costs of functions. type CostEstimator interface { - // EstimateSize returns a SizeEstimate for the given AstNode, or nil if - // the estimator has no estimate to provide. The size is equivalent to the result of the CEL `size()` function: - // length of strings and bytes, number of map entries or number of list items. - // EstimateSize is only called for AstNodes where - // CEL does not know the size; EstimateSize is not called for values defined inline in CEL where the size - // is already obvious to CEL. + // EstimateSize returns a SizeEstimate for the given AstNode, or nil if the estimator has no + // estimate to provide. + // + // The size is equivalent to the result of the CEL `size()` function: + // * Number of unicode characters in a string + // * Number of bytes in a sequence + // * Number of map entries or number of list items. + // + // EstimateSize is only called for AstNodes where CEL does not know the size; EstimateSize is not + // called for values defined inline in CEL where the size is already obvious to CEL. EstimateSize(element AstNode) *SizeEstimate - // EstimateCallCost returns the estimated cost of an invocation, or nil if - // the estimator has no estimate to provide. + + // EstimateCallCost returns the estimated cost of an invocation, or nil if the estimator has no + // estimate to provide. EstimateCallCost(function, overloadID string, target *AstNode, args []AstNode) *CallEstimate } @@ -44,6 +49,7 @@ type CostEstimator interface { // The ResultSize should only be provided if the call results in a map, list, string or bytes. type CallEstimate struct { CostEstimate + ResultSize *SizeEstimate } @@ -53,10 +59,13 @@ type AstNode interface { // represent type directly reachable from the provided type declarations. // The first path element is a variable. All subsequent path elements are one of: field name, '@items', '@keys', '@values'. Path() []string + // Type returns the deduced type of the AstNode. Type() *types.Type + // Expr returns the expression of the AstNode. Expr() ast.Expr + // ComputedSize returns a size estimate of the AstNode derived from information available in the CEL expression. // For constants and inline list and map declarations, the exact size is returned. For concatenated list, strings // and bytes, the size is derived from the size estimates of the operands. nil is returned if there is no @@ -84,36 +93,7 @@ func (e astNode) Expr() ast.Expr { } func (e astNode) ComputedSize() *SizeEstimate { - if e.derivedSize != nil { - return e.derivedSize - } - var v uint64 - switch e.expr.Kind() { - case ast.LiteralKind: - switch ck := e.expr.AsLiteral().(type) { - case types.String: - // converting to runes here is an O(n) operation, but - // this is consistent with how size is computed at runtime, - // and how the language definition defines string size - v = uint64(len([]rune(ck))) - case types.Bytes: - v = uint64(len(ck)) - case types.Bool, types.Double, types.Duration, - types.Int, types.Timestamp, types.Uint, - types.Null: - v = uint64(1) - default: - return nil - } - case ast.ListKind: - v = uint64(e.expr.AsList().Size()) - case ast.MapKind: - v = uint64(e.expr.AsMap().Size()) - default: - return nil - } - - return &SizeEstimate{Min: v, Max: v} + return e.derivedSize } // SizeEstimate represents an estimated size of a variable length string, bytes, map or list. @@ -121,6 +101,16 @@ type SizeEstimate struct { Min, Max uint64 } +// UnknownSizeEstimate returns a size between 0 and max uint +func UnknownSizeEstimate() SizeEstimate { + return unknownSizeEstimate +} + +// FixedSizeEstimate returns a size estimate with a fixed min and max range. +func FixedSizeEstimate(size uint64) SizeEstimate { + return SizeEstimate{Min: size, Max: size} +} + // Add adds to another SizeEstimate and returns the sum. // If add would result in an uint64 overflow, the result is math.MaxUint64. func (se SizeEstimate) Add(sizeEstimate SizeEstimate) SizeEstimate { @@ -175,12 +165,22 @@ type CostEstimate struct { Min, Max uint64 } +// UnknownCostEstimate returns a cost with an unknown impact. +func UnknownCostEstimate() CostEstimate { + return unknownCostEstimate +} + +// FixedCostEstimate returns a cost with a fixed min and max range. +func FixedCostEstimate(cost uint64) CostEstimate { + return CostEstimate{Min: cost, Max: cost} +} + // Add adds the costs and returns the sum. // If add would result in an uint64 overflow for the min or max, the value is set to math.MaxUint64. func (ce CostEstimate) Add(cost CostEstimate) CostEstimate { return CostEstimate{ - addUint64NoOverflow(ce.Min, cost.Min), - addUint64NoOverflow(ce.Max, cost.Max), + Min: addUint64NoOverflow(ce.Min, cost.Min), + Max: addUint64NoOverflow(ce.Max, cost.Max), } } @@ -188,8 +188,8 @@ func (ce CostEstimate) Add(cost CostEstimate) CostEstimate { // If multiply would result in an uint64 overflow, the result is math.MaxUint64. func (ce CostEstimate) Multiply(cost CostEstimate) CostEstimate { return CostEstimate{ - multiplyUint64NoOverflow(ce.Min, cost.Min), - multiplyUint64NoOverflow(ce.Max, cost.Max), + Min: multiplyUint64NoOverflow(ce.Min, cost.Min), + Max: multiplyUint64NoOverflow(ce.Max, cost.Max), } } @@ -197,8 +197,8 @@ func (ce CostEstimate) Multiply(cost CostEstimate) CostEstimate { // nearest integer of the result, rounded up. func (ce CostEstimate) MultiplyByCostFactor(costPerUnit float64) CostEstimate { return CostEstimate{ - multiplyByCostFactor(ce.Min, costPerUnit), - multiplyByCostFactor(ce.Max, costPerUnit), + Min: multiplyByCostFactor(ce.Min, costPerUnit), + Max: multiplyByCostFactor(ce.Max, costPerUnit), } } @@ -245,49 +245,6 @@ func multiplyByCostFactor(x uint64, y float64) uint64 { return uint64(ceil) } -var ( - selectAndIdentCost = CostEstimate{Min: common.SelectAndIdentCost, Max: common.SelectAndIdentCost} - constCost = CostEstimate{Min: common.ConstCost, Max: common.ConstCost} - - createListBaseCost = CostEstimate{Min: common.ListCreateBaseCost, Max: common.ListCreateBaseCost} - createMapBaseCost = CostEstimate{Min: common.MapCreateBaseCost, Max: common.MapCreateBaseCost} - createMessageBaseCost = CostEstimate{Min: common.StructCreateBaseCost, Max: common.StructCreateBaseCost} -) - -type coster struct { - // exprPath maps from Expr Id to field path. - exprPath map[int64][]string - // iterRanges tracks the iterRange of each iterVar. - iterRanges iterRangeScopes - // computedSizes tracks the computed sizes of call results. - computedSizes map[int64]SizeEstimate - checkedAST *ast.AST - estimator CostEstimator - overloadEstimators map[string]FunctionEstimator - // presenceTestCost will either be a zero or one based on whether has() macros count against cost computations. - presenceTestCost CostEstimate -} - -// Use a stack of iterVar -> iterRange Expr Ids to handle shadowed variable names. -type iterRangeScopes map[string][]int64 - -func (vs iterRangeScopes) push(varName string, expr ast.Expr) { - vs[varName] = append(vs[varName], expr.ID()) -} - -func (vs iterRangeScopes) pop(varName string) { - varStack := vs[varName] - vs[varName] = varStack[:len(varStack)-1] -} - -func (vs iterRangeScopes) peek(varName string) (int64, bool) { - varStack := vs[varName] - if len(varStack) > 0 { - return varStack[len(varStack)-1], true - } - return 0, false -} - // CostOption configures flags which affect cost computations. type CostOption func(*coster) error @@ -300,7 +257,7 @@ func PresenceTestHasCost(hasCost bool) CostOption { c.presenceTestCost = selectAndIdentCost return nil } - c.presenceTestCost = CostEstimate{Min: 0, Max: 0} + c.presenceTestCost = FixedCostEstimate(0) return nil } } @@ -325,10 +282,11 @@ func Cost(checked *ast.AST, estimator CostEstimator, opts ...CostOption) (CostEs checkedAST: checked, estimator: estimator, overloadEstimators: map[string]FunctionEstimator{}, - exprPath: map[int64][]string{}, - iterRanges: map[string][]int64{}, + exprPaths: map[int64][]string{}, + localVars: make(scopes), computedSizes: map[int64]SizeEstimate{}, - presenceTestCost: CostEstimate{Min: 1, Max: 1}, + computedEntrySizes: map[int64]entrySizeEstimate{}, + presenceTestCost: FixedCostEstimate(1), } for _, opt := range opts { err := opt(c) @@ -339,6 +297,165 @@ func Cost(checked *ast.AST, estimator CostEstimator, opts ...CostOption) (CostEs return c.cost(checked.Expr()), nil } +type coster struct { + // exprPaths maps from Expr Id to field path. + exprPaths map[int64][]string + // localVars tracks the local and iteration variables assigned during evaluation. + localVars scopes + // computedSizes tracks the computed sizes of call results. + computedSizes map[int64]SizeEstimate + // computedEntrySizes tracks the size of list and map entries + computedEntrySizes map[int64]entrySizeEstimate + + checkedAST *ast.AST + estimator CostEstimator + overloadEstimators map[string]FunctionEstimator + // presenceTestCost will either be a zero or one based on whether has() macros count against cost computations. + presenceTestCost CostEstimate +} + +// entrySizeEstimate captures the container kind and associated key/index and value SizeEstimate values. +// +// An entrySizeEstimate only exists if both the key/index and the value have SizeEstimate values, otherwise +// a nil entrySizeEstimate should be used. +type entrySizeEstimate struct { + containerKind types.Kind + key SizeEstimate + val SizeEstimate +} + +// container returns the container kind (list or map) of the entry. +func (s *entrySizeEstimate) container() types.Kind { + if s == nil { + return types.UnknownKind + } + return s.containerKind +} + +// keySize returns the SizeEstimate for the key if one exists. +func (s *entrySizeEstimate) keySize() *SizeEstimate { + if s == nil { + return nil + } + return &s.key +} + +// valSize returns the SizeEstimate for the value if one exists. +func (s *entrySizeEstimate) valSize() *SizeEstimate { + if s == nil { + return nil + } + return &s.val +} + +func (s *entrySizeEstimate) union(other *entrySizeEstimate) *entrySizeEstimate { + if s == nil || other == nil { + return nil + } + sk := s.key.Union(other.key) + sv := s.val.Union(other.val) + return &entrySizeEstimate{ + containerKind: s.containerKind, + key: sk, + val: sv, + } +} + +// localVar captures the local variable size and entrySize estimates if they exist for variables +type localVar struct { + exprID int64 + path []string + size *SizeEstimate + entrySize *entrySizeEstimate +} + +// scopes is a stack of variable name to integer id stack to handle scopes created by cel.bind() like macros +type scopes map[string][]*localVar + +func (s scopes) push(varName string, expr ast.Expr, path []string, size *SizeEstimate, entrySize *entrySizeEstimate) { + s[varName] = append(s[varName], &localVar{ + exprID: expr.ID(), + path: path, + size: size, + entrySize: entrySize, + }) +} + +func (s scopes) pop(varName string) { + varStack := s[varName] + s[varName] = varStack[:len(varStack)-1] +} + +func (s scopes) peek(varName string) (*localVar, bool) { + varStack := s[varName] + if len(varStack) > 0 { + return varStack[len(varStack)-1], true + } + return nil, false +} + +func (c *coster) pushIterKey(varName string, rangeExpr ast.Expr) { + entrySize := c.computeEntrySize(rangeExpr) + size := entrySize.keySize() + path := c.getPath(rangeExpr) + container := entrySize.container() + if container == types.UnknownKind { + container = c.getType(rangeExpr).Kind() + } + subpath := "@keys" + if container == types.ListKind { + subpath = "@indices" + } + c.localVars.push(varName, rangeExpr, append(path, subpath), size, nil) +} + +func (c *coster) pushIterValue(varName string, rangeExpr ast.Expr) { + entrySize := c.computeEntrySize(rangeExpr) + size := entrySize.valSize() + path := c.getPath(rangeExpr) + container := entrySize.container() + if container == types.UnknownKind { + container = c.getType(rangeExpr).Kind() + } + subpath := "@values" + if container == types.ListKind { + subpath = "@items" + } + c.localVars.push(varName, rangeExpr, append(path, subpath), size, nil) +} + +func (c *coster) pushIterSingle(varName string, rangeExpr ast.Expr) { + entrySize := c.computeEntrySize(rangeExpr) + size := entrySize.keySize() + subpath := "@keys" + container := entrySize.container() + if container == types.UnknownKind { + container = c.getType(rangeExpr).Kind() + } + if container == types.ListKind { + size = entrySize.valSize() + subpath = "@items" + } + path := c.getPath(rangeExpr) + c.localVars.push(varName, rangeExpr, append(path, subpath), size, nil) +} + +func (c *coster) pushLocalVar(varName string, e ast.Expr) { + path := c.getPath(e) + // note: retrieve the entry size for the local variable based on the size of the binding expression + // since the binding expression could be a list or map, the entry size should also be propagated + entrySize := c.computeEntrySize(e) + c.localVars.push(varName, e, path, c.computeSize(e), entrySize) +} + +func (c *coster) peekLocalVar(varName string) (*localVar, bool) { + return c.localVars.peek(varName) +} + +func (c *coster) popLocalVar(varName string) { + c.localVars.pop(varName) +} + func (c *coster) cost(e ast.Expr) CostEstimate { if e == nil { return CostEstimate{} @@ -360,7 +477,11 @@ func (c *coster) cost(e ast.Expr) CostEstimate { case ast.StructKind: cost = c.costCreateStruct(e) case ast.ComprehensionKind: - cost = c.costComprehension(e) + if c.isBind(e) { + cost = c.costBind(e) + } else { + cost = c.costComprehension(e) + } default: return CostEstimate{} } @@ -370,17 +491,11 @@ func (c *coster) cost(e ast.Expr) CostEstimate { func (c *coster) costIdent(e ast.Expr) CostEstimate { identName := e.AsIdent() // build and track the field path - if iterRange, ok := c.iterRanges.peek(identName); ok { - switch c.checkedAST.GetType(iterRange).Kind() { - case types.ListKind: - c.addPath(e, append(c.exprPath[iterRange], "@items")) - case types.MapKind: - c.addPath(e, append(c.exprPath[iterRange], "@keys")) - } + if v, ok := c.peekLocalVar(identName); ok { + c.addPath(e, v.path) } else { c.addPath(e, []string{identName}) } - return selectAndIdentCost } @@ -405,14 +520,18 @@ func (c *coster) costSelect(e ast.Expr) CostEstimate { // build and track the field path c.addPath(e, append(c.getPath(sel.Operand()), sel.FieldName())) - return sum } func (c *coster) costCall(e ast.Expr) CostEstimate { + // Dyn is just a way to disable type-checking, so return the cost of 1 with the cost of the argument + if dynEstimate := c.maybeUnwrapDynCall(e); dynEstimate != nil { + return *dynEstimate + } + + // Continue estimating the cost of all other calls. call := e.AsCall() args := call.Args() - var sum CostEstimate argTypes := make([]AstNode, len(args)) @@ -435,7 +554,7 @@ func (c *coster) costCall(e ast.Expr) CostEstimate { fnCost := CostEstimate{Min: uint64(math.MaxUint64), Max: 0} var resultSize *SizeEstimate for _, overload := range overloadIDs { - overloadCost := c.functionCost(call.FunctionName(), overload, &targetType, argTypes, argCosts) + overloadCost := c.functionCost(e, call.FunctionName(), overload, &targetType, argTypes, argCosts) fnCost = fnCost.Union(overloadCost.CostEstimate) if overloadCost.ResultSize != nil { if resultSize == nil { @@ -449,37 +568,73 @@ func (c *coster) costCall(e ast.Expr) CostEstimate { switch overload { case overloads.IndexList: if len(args) > 0 { + // note: assigning resultSize here could be redundant with the path-based lookup later + resultSize = c.computeEntrySize(args[0]).valSize() c.addPath(e, append(c.getPath(args[0]), "@items")) } case overloads.IndexMap: if len(args) > 0 { + resultSize = c.computeEntrySize(args[0]).valSize() c.addPath(e, append(c.getPath(args[0]), "@values")) } } + if resultSize == nil { + resultSize = c.computeSize(e) + } } - if resultSize != nil { - c.computedSizes[e.ID()] = *resultSize - } + c.setSize(e, resultSize) return sum.Add(fnCost) } +func (c *coster) maybeUnwrapDynCall(e ast.Expr) *CostEstimate { + call := e.AsCall() + if call.FunctionName() != "dyn" { + return nil + } + arg := call.Args()[0] + argCost := c.cost(arg) + c.copySizeEstimates(e, arg) + callCost := FixedCostEstimate(1).Add(argCost) + return &callCost +} + func (c *coster) costCreateList(e ast.Expr) CostEstimate { create := e.AsList() var sum CostEstimate + itemSize := SizeEstimate{Min: math.MaxUint64, Max: 0} + if create.Size() == 0 { + itemSize.Min = 0 + } for _, e := range create.Elements() { sum = sum.Add(c.cost(e)) + is := c.sizeOrUnknown(e) + itemSize = itemSize.Union(is) } + c.setEntrySize(e, &entrySizeEstimate{containerKind: types.ListKind, key: FixedSizeEstimate(1), val: itemSize}) return sum.Add(createListBaseCost) } func (c *coster) costCreateMap(e ast.Expr) CostEstimate { mapVal := e.AsMap() var sum CostEstimate + keySize := SizeEstimate{Min: math.MaxUint64, Max: 0} + valSize := SizeEstimate{Min: math.MaxUint64, Max: 0} + if mapVal.Size() == 0 { + valSize.Min = 0 + keySize.Min = 0 + } for _, ent := range mapVal.Entries() { entry := ent.AsMapEntry() sum = sum.Add(c.cost(entry.Key())) sum = sum.Add(c.cost(entry.Value())) + // Compute the key size range + ks := c.sizeOrUnknown(entry.Key()) + keySize = keySize.Union(ks) + // Compute the value size range + vs := c.sizeOrUnknown(entry.Value()) + valSize = valSize.Union(vs) } + c.setEntrySize(e, &entrySizeEstimate{containerKind: types.MapKind, key: keySize, val: valSize}) return sum.Add(createMapBaseCost) } @@ -498,43 +653,76 @@ func (c *coster) costComprehension(e ast.Expr) CostEstimate { var sum CostEstimate sum = sum.Add(c.cost(comp.IterRange())) sum = sum.Add(c.cost(comp.AccuInit())) + c.pushLocalVar(comp.AccuVar(), comp.AccuInit()) - // Track the iterRange of each IterVar for field path construction - c.iterRanges.push(comp.IterVar(), comp.IterRange()) + // Track the iterRange of each IterVar and AccuVar for field path construction + if comp.HasIterVar2() { + c.pushIterKey(comp.IterVar(), comp.IterRange()) + c.pushIterValue(comp.IterVar2(), comp.IterRange()) + } else { + c.pushIterSingle(comp.IterVar(), comp.IterRange()) + } + + // Determine the cost for each element in the loop loopCost := c.cost(comp.LoopCondition()) stepCost := c.cost(comp.LoopStep()) - c.iterRanges.pop(comp.IterVar()) - sum = sum.Add(c.cost(comp.Result())) - rangeCnt := c.sizeEstimate(c.newAstNode(comp.IterRange())) - c.computedSizes[e.ID()] = rangeCnt + // Clear the intermediate variable tracking. + c.popLocalVar(comp.IterVar()) + if comp.HasIterVar2() { + c.popLocalVar(comp.IterVar2()) + } + + // Determine the result cost. + sum = sum.Add(c.cost(comp.Result())) + c.localVars.pop(comp.AccuVar()) + // Estimate the cost of the loop. + rangeCnt := c.sizeOrUnknown(comp.IterRange()) rangeCost := rangeCnt.MultiplyByCost(stepCost.Add(loopCost)) sum = sum.Add(rangeCost) + switch k := comp.AccuInit().Kind(); k { + case ast.LiteralKind: + c.setSize(e, c.computeSize(comp.AccuInit())) + case ast.ListKind, ast.MapKind: + c.setSize(e, &rangeCnt) + // For a step which produces a container value, it will have an entry size associated + // with its expression id. + if stepEntrySize := c.computeEntrySize(comp.LoopStep()); stepEntrySize != nil { + c.setEntrySize(e, stepEntrySize) + break + } + } return sum } -func (c *coster) sizeEstimate(t AstNode) SizeEstimate { - if l := t.ComputedSize(); l != nil { - return *l - } - if l := c.estimator.EstimateSize(t); l != nil { - return *l - } - // return an estimate of 1 for return types of set - // lengths, since strings/bytes/more complex objects could be of - // variable length - if isScalar(t.Type()) { - // TODO: since the logic for size estimation is split between - // ComputedSize and isScalar, changing one will likely require changing - // the other, so they should be merged in the future if possible - return SizeEstimate{Min: 1, Max: 1} - } - return SizeEstimate{Min: 0, Max: math.MaxUint64} +func (c *coster) isBind(e ast.Expr) bool { + comp := e.AsComprehension() + iterRange := comp.IterRange() + loopCond := comp.LoopCondition() + return iterRange.Kind() == ast.ListKind && iterRange.AsList().Size() == 0 && + loopCond.Kind() == ast.LiteralKind && loopCond.AsLiteral() == types.False && + comp.AccuVar() != parser.AccumulatorName +} + +func (c *coster) costBind(e ast.Expr) CostEstimate { + comp := e.AsComprehension() + var sum CostEstimate + // Binds are lazily initialized, so we retain the cost of an empty iteration range. + sum = sum.Add(c.cost(comp.IterRange())) + sum = sum.Add(c.cost(comp.AccuInit())) + + c.pushLocalVar(comp.AccuVar(), comp.AccuInit()) + sum = sum.Add(c.cost(comp.Result())) + c.popLocalVar(comp.AccuVar()) + + // Associate the bind output size with the result size. + c.copySizeEstimates(e, comp.Result()) + return sum } -func (c *coster) functionCost(function, overloadID string, target *AstNode, args []AstNode, argCosts []CostEstimate) CallEstimate { +func (c *coster) functionCost(e ast.Expr, function, overloadID string, target *AstNode, args []AstNode, argCosts []CostEstimate) CallEstimate { argCostSum := func() CostEstimate { var sum CostEstimate for _, a := range argCosts { @@ -559,35 +747,42 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args case overloads.ExtFormatString: if target != nil { // ResultSize not calculated because we can't bound the max size. - return CallEstimate{CostEstimate: c.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())} + return CallEstimate{ + CostEstimate: c.sizeOrUnknown(*target).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())} } case overloads.StringToBytes: if len(args) == 1 { - sz := c.sizeEstimate(args[0]) + sz := c.sizeOrUnknown(args[0]) // ResultSize max is when each char converts to 4 bytes. - return CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &SizeEstimate{Min: sz.Min, Max: sz.Max * 4}} + return CallEstimate{ + CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), + ResultSize: &SizeEstimate{Min: sz.Min, Max: sz.Max * 4}} } case overloads.BytesToString: if len(args) == 1 { - sz := c.sizeEstimate(args[0]) + sz := c.sizeOrUnknown(args[0]) // ResultSize min is when 4 bytes convert to 1 char. - return CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &SizeEstimate{Min: sz.Min / 4, Max: sz.Max}} + return CallEstimate{ + CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), + ResultSize: &SizeEstimate{Min: sz.Min / 4, Max: sz.Max}} } case overloads.ExtQuoteString: if len(args) == 1 { - sz := c.sizeEstimate(args[0]) + sz := c.sizeOrUnknown(args[0]) // ResultSize max is when each char is escaped. 2 quote chars always added. - return CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &SizeEstimate{Min: sz.Min + 2, Max: sz.Max*2 + 2}} + return CallEstimate{ + CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), + ResultSize: &SizeEstimate{Min: sz.Min + 2, Max: sz.Max*2 + 2}} } case overloads.StartsWithString, overloads.EndsWithString: if len(args) == 1 { - return CallEstimate{CostEstimate: c.sizeEstimate(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())} + return CallEstimate{CostEstimate: c.sizeOrUnknown(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())} } case overloads.InList: // If a list is composed entirely of constant values this is O(1), but we don't account for that here. // We just assume all list containment checks are O(n). if len(args) == 2 { - return CallEstimate{CostEstimate: c.sizeEstimate(args[1]).MultiplyByCostFactor(1).Add(argCostSum())} + return CallEstimate{CostEstimate: c.sizeOrUnknown(args[1]).MultiplyByCostFactor(1).Add(argCostSum())} } // O(nm) functions case overloads.MatchesString: @@ -595,19 +790,19 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args if target != nil && len(args) == 1 { // Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0 // in case where string is empty but regex is still expensive. - strCost := c.sizeEstimate(*target).Add(SizeEstimate{Min: 1, Max: 1}).MultiplyByCostFactor(common.StringTraversalCostFactor) + strCost := c.sizeOrUnknown(*target).Add(SizeEstimate{Min: 1, Max: 1}).MultiplyByCostFactor(common.StringTraversalCostFactor) // We don't know how many expressions are in the regex, just the string length (a huge // improvement here would be to somehow get a count the number of expressions in the regex or // how many states are in the regex state machine and use that to measure regex cost). // For now, we're making a guess that each expression in a regex is typically at least 4 chars // in length. - regexCost := c.sizeEstimate(args[0]).MultiplyByCostFactor(common.RegexStringLengthCostFactor) + regexCost := c.sizeOrUnknown(args[0]).MultiplyByCostFactor(common.RegexStringLengthCostFactor) return CallEstimate{CostEstimate: strCost.Multiply(regexCost).Add(argCostSum())} } case overloads.ContainsString: if target != nil && len(args) == 1 { - strCost := c.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor) - substrCost := c.sizeEstimate(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor) + strCost := c.sizeOrUnknown(*target).MultiplyByCostFactor(common.StringTraversalCostFactor) + substrCost := c.sizeOrUnknown(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor) return CallEstimate{CostEstimate: strCost.Multiply(substrCost).Add(argCostSum())} } case overloads.LogicalOr, overloads.LogicalAnd: @@ -617,7 +812,9 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args argCost := CostEstimate{Min: lhs.Min, Max: lhs.Add(rhs).Max} return CallEstimate{CostEstimate: argCost} case overloads.Conditional: - size := c.sizeEstimate(args[1]).Union(c.sizeEstimate(args[2])) + size := c.sizeOrUnknown(args[1]).Union(c.sizeOrUnknown(args[2])) + resultEntrySize := c.computeEntrySize(args[1].Expr()).union(c.computeEntrySize(args[2].Expr())) + c.setEntrySize(e, resultEntrySize) conditionalCost := argCosts[0] ifTrueCost := argCosts[1] ifFalseCost := argCosts[2] @@ -625,13 +822,19 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args return CallEstimate{CostEstimate: argCost, ResultSize: &size} case overloads.AddString, overloads.AddBytes, overloads.AddList: if len(args) == 2 { - lhsSize := c.sizeEstimate(args[0]) - rhsSize := c.sizeEstimate(args[1]) + lhsSize := c.sizeOrUnknown(args[0]) + rhsSize := c.sizeOrUnknown(args[1]) resultSize := lhsSize.Add(rhsSize) + rhsEntrySize := c.computeEntrySize(args[0].Expr()) + lhsEntrySize := c.computeEntrySize(args[1].Expr()) + resultEntrySize := rhsEntrySize.union(lhsEntrySize) + if resultEntrySize != nil { + c.setEntrySize(e, resultEntrySize) + } switch overloadID { case overloads.AddList: // list concatenation is O(1), but we handle it here to track size - return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}.Add(argCostSum()), ResultSize: &resultSize} + return CallEstimate{CostEstimate: FixedCostEstimate(1).Add(argCostSum()), ResultSize: &resultSize} default: return CallEstimate{CostEstimate: resultSize.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &resultSize} } @@ -639,8 +842,8 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args case overloads.LessString, overloads.GreaterString, overloads.LessEqualsString, overloads.GreaterEqualsString, overloads.LessBytes, overloads.GreaterBytes, overloads.LessEqualsBytes, overloads.GreaterEqualsBytes, overloads.Equals, overloads.NotEquals: - lhsCost := c.sizeEstimate(args[0]) - rhsCost := c.sizeEstimate(args[1]) + lhsCost := c.sizeOrUnknown(args[0]) + rhsCost := c.sizeOrUnknown(args[1]) min := uint64(0) smallestMax := lhsCost.Max if rhsCost.Max < smallestMax { @@ -650,14 +853,16 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args min = 1 } // equality of 2 scalar values results in a cost of 1 - return CallEstimate{CostEstimate: CostEstimate{Min: min, Max: smallestMax}.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())} + return CallEstimate{ + CostEstimate: CostEstimate{Min: min, Max: smallestMax}.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), + } } // O(1) functions // See CostTracker.costCall for more details about O(1) cost calculations // Benchmarks suggest that most of the other operations take +/- 50% of a base cost unit // which on an Intel xeon 2.20GHz CPU is 50ns. - return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}.Add(argCostSum())} + return CallEstimate{CostEstimate: FixedCostEstimate(1).Add(argCostSum())} } func (c *coster) getType(e ast.Expr) *types.Type { @@ -665,11 +870,16 @@ func (c *coster) getType(e ast.Expr) *types.Type { } func (c *coster) getPath(e ast.Expr) []string { - return c.exprPath[e.ID()] + if e.Kind() == ast.IdentKind { + if v, found := c.peekLocalVar(e.AsIdent()); found { + return v.path[:] + } + } + return c.exprPaths[e.ID()][:] } func (c *coster) addPath(e ast.Expr, path []string) { - c.exprPath[e.ID()] = path + c.exprPaths[e.ID()] = path } func isAccumulatorVar(name string) bool { @@ -682,15 +892,121 @@ func (c *coster) newAstNode(e ast.Expr) *astNode { // only provide paths to root vars; omit accumulator vars path = nil } - var derivedSize *SizeEstimate - if size, ok := c.computedSizes[e.ID()]; ok { - derivedSize = &size - } return &astNode{ path: path, t: c.getType(e), expr: e, - derivedSize: derivedSize} + derivedSize: c.computeSize(e)} +} + +func (c *coster) setSize(e ast.Expr, size *SizeEstimate) { + if size == nil { + return + } + // Store the computed size with the expression + c.computedSizes[e.ID()] = *size +} + +func (c *coster) sizeOrUnknown(node any) SizeEstimate { + switch v := node.(type) { + case ast.Expr: + if sz := c.computeSize(v); sz != nil { + return *sz + } + case AstNode: + if sz := v.ComputedSize(); sz != nil { + return *sz + } + } + return UnknownSizeEstimate() +} + +func (c *coster) copySizeEstimates(dst, src ast.Expr) { + c.setSize(dst, c.computeSize(src)) + c.setEntrySize(dst, c.computeEntrySize(src)) +} + +func (c *coster) computeSize(e ast.Expr) *SizeEstimate { + if size, ok := c.computedSizes[e.ID()]; ok { + return &size + } + if size := computeExprSize(e); size != nil { + return size + } + if size := computeTypeSize(c.getType(e)); size != nil { + return size + } + if e.Kind() == ast.IdentKind { + varName := e.AsIdent() + if v, ok := c.peekLocalVar(varName); ok && v.size != nil { + return v.size + } + } + node := astNode{expr: e, path: c.getPath(e), t: c.getType(e)} + if size := c.estimator.EstimateSize(node); size != nil { + // storing the computed size should reduce calls to EstimateSize() + c.computedSizes[e.ID()] = *size + return size + } + return nil +} + +func (c *coster) setEntrySize(e ast.Expr, size *entrySizeEstimate) { + if size == nil { + return + } + c.computedEntrySizes[e.ID()] = *size +} + +func (c *coster) computeEntrySize(e ast.Expr) *entrySizeEstimate { + if sz, found := c.computedEntrySizes[e.ID()]; found { + return &sz + } + if e.Kind() == ast.IdentKind { + varName := e.AsIdent() + if v, ok := c.peekLocalVar(varName); ok && v.entrySize != nil { + return v.entrySize + } + } + return nil +} + +func computeExprSize(expr ast.Expr) *SizeEstimate { + var v uint64 + switch expr.Kind() { + case ast.LiteralKind: + switch ck := expr.AsLiteral().(type) { + case types.String: + // converting to runes here is an O(n) operation, but + // this is consistent with how size is computed at runtime, + // and how the language definition defines string size + v = uint64(len([]rune(ck))) + case types.Bytes: + v = uint64(len(ck)) + case types.Bool, types.Double, types.Duration, + types.Int, types.Timestamp, types.Uint, + types.Null: + v = uint64(1) + default: + return nil + } + case ast.ListKind: + v = uint64(expr.AsList().Size()) + case ast.MapKind: + v = uint64(expr.AsMap().Size()) + default: + return nil + } + cost := FixedSizeEstimate(v) + return &cost +} + +func computeTypeSize(t *types.Type) *SizeEstimate { + if isScalar(t) { + cost := FixedSizeEstimate(1) + return &cost + } + return nil } // isScalar returns true if the given type is known to be of a constant size at @@ -698,12 +1014,27 @@ func (c *coster) newAstNode(e ast.Expr) *astNode { // in addition to protobuf.Any and protobuf.Value (their size is not knowable at compile time). func isScalar(t *types.Type) bool { switch t.Kind() { - case types.BoolKind, types.DoubleKind, types.DurationKind, types.IntKind, types.TimestampKind, types.UintKind: + case types.BoolKind, types.DoubleKind, types.DurationKind, types.IntKind, + types.NullTypeKind, types.TimestampKind, types.TypeKind, types.UintKind: return true + case types.OpaqueKind: + if t.TypeName() == "optional_type" { + return isScalar(t.Parameters()[0]) + } } return false } var ( doubleTwoTo64 = math.Ldexp(1.0, 64) + + unknownSizeEstimate = SizeEstimate{Min: 0, Max: math.MaxUint64} + unknownCostEstimate = unknownSizeEstimate.MultiplyByCostFactor(1) + + selectAndIdentCost = FixedCostEstimate(common.SelectAndIdentCost) + constCost = FixedCostEstimate(common.ConstCost) + + createListBaseCost = FixedCostEstimate(common.ListCreateBaseCost) + createMapBaseCost = FixedCostEstimate(common.MapCreateBaseCost) + createMessageBaseCost = FixedCostEstimate(common.StructCreateBaseCost) ) diff --git a/checker/cost_test.go b/checker/cost_test.go index 4437663b..fb74cdf2 100644 --- a/checker/cost_test.go +++ b/checker/cost_test.go @@ -40,7 +40,7 @@ func TestCost(t *testing.T) { nestedMap := types.NewMapType(types.StringType, allMap) zeroCost := CostEstimate{} - oneCost := CostEstimate{Min: 1, Max: 1} + oneCost := FixedCostEstimate(1) cases := []struct { name string expr string @@ -255,6 +255,11 @@ func TestCost(t *testing.T) { expr: `size("123")`, wanted: oneCost, }, + { + name: "bytes size", + expr: `size(b"123")`, + wanted: oneCost, + }, { name: "bytes to string conversion", vars: []*decls.VariableDecl{decls.NewVariable("input", types.BytesType)}, @@ -462,6 +467,36 @@ func TestCost(t *testing.T) { }, wanted: CostEstimate{Min: 5, Max: 5}, }, + { + name: "list size from concat", + expr: `([x, y] + list1 + list2).size()`, + vars: []*decls.VariableDecl{ + decls.NewVariable("x", types.IntType), + decls.NewVariable("y", types.IntType), + decls.NewVariable("list1", types.NewListType(types.IntType)), + decls.NewVariable("list2", types.NewListType(types.IntType)), + }, + hints: map[string]uint64{ + "list1": 10, + "list2": 20, + }, + wanted: CostEstimate{Min: 17, Max: 17}, + }, + { + name: "list cost tracking through comprehension", + expr: `[list1, list2].exists(l, l.exists(v, v.startsWith('hi')))`, + vars: []*decls.VariableDecl{ + decls.NewVariable("list1", types.NewListType(types.StringType)), + decls.NewVariable("list2", types.NewListType(types.StringType)), + }, + hints: map[string]uint64{ + "list1": 10, + "list1.@items": 64, + "list2": 20, + "list2.@items": 128, + }, + wanted: CostEstimate{Min: 21, Max: 265}, + }, { name: "str endsWith equality", expr: `str1.endsWith("abcdefghijklmnopqrstuvwxyz") == str2.endsWith("abcdefghijklmnopqrstuvwxyz")`, @@ -539,27 +574,37 @@ func TestCost(t *testing.T) { wanted: CostEstimate{Min: 61, Max: 61}, }, { - name: "nested array selection", + name: "nested map selection", expr: `{'a': [1,2], 'b': [1,2], 'c': [1,2], 'd': [1,2], 'e': [1,2]}.b`, wanted: CostEstimate{Min: 81, Max: 81}, }, { - // Estimated cost does not track the sizes of nested aggregate types - // (lists, maps, ...) and so assumes a worst case cost when an - // expression applies a comprehension to a nested aggregated type, - // even if the size information is available. - // TODO: This should be fixed. name: "comprehension on nested list", + expr: `[[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]].all(y, y.all(y, y == 1))`, + wanted: CostEstimate{Min: 76, Max: 136}, + }, + { + name: "comprehension on transformed nested list", expr: `[1,2,3,4,5].map(x, [x, x]).all(y, y.all(y, y == 1))`, - wanted: CostEstimate{Min: 157, Max: 18446744073709551615}, + wanted: CostEstimate{Min: 157, Max: 217}, }, { - // Make sure we're accounting for not just the iteration range size, - // but also the overall comprehension size. The chained map calls - // will treat the result of one map as the iteration range of the other, - // so they're planned in reverse; however, the `+` should verify that - // the comprehension result has a size. - name: "comprehension size", + name: "comprehension on nested literal list", + expr: `["a", "ab", "abc", "abcd", "abcde"].map(x, [x, x]).all(y, y.all(y, y.startsWith('a')))`, + wanted: CostEstimate{Min: 157, Max: 217}, + }, + { + name: "comprehension on nested variable list", + expr: `input.map(x, [x, x]).all(y, y.all(y, y.startsWith('a')))`, + vars: []*decls.VariableDecl{decls.NewVariable("input", types.NewListType(types.StringType))}, + hints: map[string]uint64{ + "input": 5, + "input.@items": 10, + }, + wanted: CostEstimate{Min: 13, Max: 208}, + }, + { + name: "comprehension chaining with concat", expr: `[1,2,3,4,5].map(x, x).map(x, x) + [1]`, wanted: CostEstimate{Min: 173, Max: 173}, }, @@ -568,9 +613,25 @@ func TestCost(t *testing.T) { expr: `[1,2,3].all(i, i in [1,2,3].map(j, j + j))`, wanted: CostEstimate{Min: 20, Max: 230}, }, + { + name: "nested dyn comprehension", + expr: `dyn([1,2,3]).all(i, i in dyn([1,2,3]).map(j, j + j))`, + wanted: CostEstimate{Min: 21, Max: 234}, + }, + { + name: "literal map access", + expr: `{'hello': 'hi'}['hello'] != {'hello': 'bye'}['hello']`, + wanted: CostEstimate{Min: 63, Max: 63}, + }, + { + name: "literal list access", + expr: `['hello', 'hi'][0] != ['hello', 'bye'][1]`, + wanted: CostEstimate{Min: 23, Max: 23}, + }, } - for _, tc := range cases { + for _, tst := range cases { + tc := tst t.Run(tc.name, func(t *testing.T) { if tc.hints == nil { tc.hints = map[string]uint64{} diff --git a/ext/bindings_test.go b/ext/bindings_test.go index bd6ecf7b..df89eec4 100644 --- a/ext/bindings_test.go +++ b/ext/bindings_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" @@ -27,49 +28,93 @@ import ( ) var bindingTests = []struct { - expr string - parseOnly bool + name string + expr string + vars []cel.EnvOption + in map[string]any + hints map[string]uint64 + estimatedCost checker.CostEstimate + actualCost uint64 }{ - {expr: `cel.bind(a, 'hell' + 'o' + '!', [a, a, a].join(', ')) == - ['hell' + 'o' + '!', 'hell' + 'o' + '!', 'hell' + 'o' + '!'].join(', ')`}, - // Variable shadowing - {expr: `cel.bind(a, - cel.bind(a, 'world', a + '!'), - 'hello ' + a) == 'hello ' + 'world' + '!'`}, + { + name: "single bind", + expr: `cel.bind(a, 'hell' + 'o' + '!', "%s, %s, %s".format([a, a, a])) == + 'hello!, hello!, hello' + '!'`, + estimatedCost: checker.CostEstimate{Min: 30, Max: 32}, + actualCost: 32, + }, + { + name: "multiple binds", + expr: `cel.bind(a, 'hello!', + cel.bind(b, 'goodbye', + a + ' and, ' + b)) == 'hello! and, goodbye'`, + estimatedCost: checker.CostEstimate{Min: 27, Max: 28}, + actualCost: 28, + }, + { + name: "shadow binds", + expr: `cel.bind(a, + cel.bind(a, 'world', a + '!'), + 'hello ' + a) == 'hello ' + 'world' + '!'`, + estimatedCost: checker.CostEstimate{Min: 30, Max: 31}, + actualCost: 31, + }, + { + name: "nested bind with int list", + expr: `cel.bind(a, x, + cel.bind(b, a[0], + cel.bind(c, a[1], b + c))) == 10`, + vars: []cel.EnvOption{cel.Variable("x", cel.ListType(cel.IntType))}, + in: map[string]any{ + "x": []int64{3, 7}, + }, + hints: map[string]uint64{ + "x": 3, + }, + estimatedCost: checker.CostEstimate{Min: 39, Max: 39}, + actualCost: 39, + }, + { + name: "nested bind with string list", + expr: `cel.bind(a, x, + cel.bind(b, a[0], + cel.bind(c, a[1], b + c))) == "threeseven"`, + vars: []cel.EnvOption{cel.Variable("x", cel.ListType(cel.StringType))}, + in: map[string]any{ + "x": []string{"three", "seven"}, + }, + hints: map[string]uint64{ + "x": 3, + "x.@items": 10, + }, + estimatedCost: checker.CostEstimate{Min: 38, Max: 40}, + actualCost: 39, + }, } func TestBindings(t *testing.T) { - env, err := cel.NewEnv(Bindings(BindingsVersion(0)), Strings()) - if err != nil { - t.Fatalf("cel.NewEnv(Bindings(), Strings()) failed: %v", err) - } - for i, tst := range bindingTests { + for _, tst := range bindingTests { tc := tst - t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { var asts []*cel.Ast + opts := append([]cel.EnvOption{Bindings(BindingsVersion(0)), Strings()}, tc.vars...) + env, err := cel.NewEnv(opts...) + if err != nil { + t.Fatalf("cel.NewEnv(Bindings(), Strings()) failed: %v", err) + } pAst, iss := env.Parse(tc.expr) if iss.Err() != nil { t.Fatalf("env.Parse(%v) failed: %v", tc.expr, iss.Err()) } asts = append(asts, pAst) - if !tc.parseOnly { - cAst, iss := env.Check(pAst) - if iss.Err() != nil { - t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err()) - } - asts = append(asts, cAst) + cAst, iss := env.Check(pAst) + if iss.Err() != nil { + t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err()) } + testCheckCost(t, env, cAst, tc.hints, tc.estimatedCost) + asts = append(asts, cAst) for _, ast := range asts { - prg, err := env.Program(ast) - if err != nil { - t.Fatal(err) - } - out, _, err := prg.Eval(cel.NoVars()) - if err != nil { - t.Fatal(err) - } else if out.Value() != true { - t.Errorf("got %v, wanted true for expr: %s", out.Value(), tc.expr) - } + testEvalWithCost(t, env, ast, tc.in, tc.actualCost) } }) } diff --git a/ext/comprehensions_test.go b/ext/comprehensions_test.go index ede8a437..89a6052b 100644 --- a/ext/comprehensions_test.go +++ b/ext/comprehensions_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker" "github.com/google/cel-go/common/types" "github.com/google/cel-go/interpreter" ) @@ -214,6 +215,184 @@ func TestTwoVarComprehensions(t *testing.T) { } } +func TestTwoVarComprehensionsCost(t *testing.T) { + tests := []struct { + name string + expr string + vars []cel.EnvOption + in map[string]any + hints map[string]uint64 + estimatedCost checker.CostEstimate + actualCost uint64 + }{ + { + name: "all list literal", + expr: `[1, 2, 3, 4].all(i, v, i < 5 && v > 0)`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 39}, + actualCost: 39, + }, + { + name: "all map literal - true", + expr: `{1: 1, 2: 2, 3: 3}.all(i, v, i < 5 && v > 0)`, + estimatedCost: checker.CostEstimate{Min: 40, Max: 52}, + actualCost: 52, + }, + { + name: "all map literal - false", + expr: `!{0: 0}.all(i, v, i < 5 && v > 0)`, + estimatedCost: checker.CostEstimate{Min: 35, Max: 39}, + actualCost: 39, + }, + { + name: "all map(int,int) variable", + expr: `m.all(i, v, i < 5 && v > 0)`, + vars: []cel.EnvOption{cel.Variable("m", cel.MapType(cel.IntType, cel.IntType))}, + hints: map[string]uint64{ + "m": 3, + }, + in: map[string]any{ + "m": map[int]int{1: 1, 2: 2}, + }, + estimatedCost: checker.CostEstimate{Min: 2, Max: 23}, + actualCost: 16, + }, + { + name: "all map(string,string) variable", + expr: `m.all(k, v, k < v)`, + vars: []cel.EnvOption{cel.Variable("m", cel.MapType(cel.StringType, cel.StringType))}, + hints: map[string]uint64{ + "m": 3, + "m.@keys": 16, + "m.@values": 128, + }, + in: map[string]any{ + "m": map[string]string{"he": "hello", "go": "goodbye"}, + }, + estimatedCost: checker.CostEstimate{Min: 2, Max: 23}, + actualCost: 14, + }, + { + name: "transformList empty", + expr: `[].transformList(i, v, v) == []`, + estimatedCost: checker.FixedCostEstimate(31), + actualCost: 31, + }, + { + name: "transformList single element", + expr: `[1].transformList(i, v, i) == [0]`, + estimatedCost: checker.FixedCostEstimate(45), + actualCost: 45, + }, + { + name: "transformList with filter", + expr: `[3, 2, 1].transformList(i, v, v > i, v) == [3, 2]`, + estimatedCost: checker.CostEstimate{Min: 44, Max: 80}, + actualCost: 67, + }, + { + name: "transformMap empty list", + expr: `[].transformMap(k, v, v + 1) == {}`, + estimatedCost: checker.FixedCostEstimate(71), + actualCost: 71, + }, + { + name: "transformMap empty map", + expr: `{}.transformMap(k, v, v + 1) == {}`, + estimatedCost: checker.FixedCostEstimate(91), + actualCost: 91, + }, + { + name: "transformMap literal scalar map", + expr: `{1: 2}.transformMap(k, v, v + 1) == {1: 3}`, + estimatedCost: checker.FixedCostEstimate(97), + actualCost: 97, + }, + { + name: "transformMap local bind", + expr: `cel.bind(m, {"hello": "hello"}, + m.transformMap(k, v, v + "world")) == {"hello": "helloworld"}`, + estimatedCost: checker.FixedCostEstimate(108), + actualCost: 108, + }, + { + name: "transformMap filter map", + expr: `{1: 2, 3: 4, 5: 6}.transformMap(k, v, k % 3 == 0, v + 1) == {3: 5}`, + estimatedCost: checker.CostEstimate{Min: 104, Max: 116}, + actualCost: 106, + }, + { + name: "transformMap variable input", + expr: `m.transformMap(k, v, k.startsWith('legacy') && v.size() == 1, v + [2]) == {'legacy-solo': [1, 2]}`, + vars: []cel.EnvOption{ + cel.Variable("m", cel.MapType(cel.StringType, cel.ListType(cel.IntType))), + }, + in: map[string]any{ + "m": map[string][]int{ + "legacy-solo": {1}, + "legacy-pair": {3, 2}, + }, + }, + hints: map[string]uint64{ + "m": 5, + "m.@keys": 16, + "m.@values": 10, + "m.@values.@items": 2, + }, + estimatedCost: checker.CostEstimate{Min: 73, Max: 173}, + actualCost: 100, + }, + { + name: "transformMapEntry literal input", + expr: `{1: 2}.transformMapEntry(k, v, {v: k}) == {2: 1}`, + estimatedCost: checker.FixedCostEstimate(126), + actualCost: 126, + }, + { + name: "transformMapEntry variable input", + expr: `m.transformMapEntry(k, v, {v: k}) == m.transformMapEntry(k, v, {v: k})`, + vars: []cel.EnvOption{ + cel.Variable("m", cel.MapType(cel.StringType, cel.IntType)), + }, + in: map[string]any{ + "m": map[string]int{ + "legacy-solo": 1, + "legacy-pair": 2, + }, + }, + hints: map[string]uint64{ + "m": 5, + "m.@keys": 16, + "m.@values": 10, + }, + estimatedCost: checker.CostEstimate{Min: 65, Max: 405}, + actualCost: 201, + }, + } + + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + env := testCompreEnv(t, tc.vars...) + var asts []*cel.Ast + pAst, iss := env.Parse(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Parse(%v) failed: %v", tc.expr, iss.Err()) + } + asts = append(asts, pAst) + cAst, iss := env.Check(pAst) + if iss.Err() != nil { + t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err()) + } + + testCheckCost(t, env, cAst, tc.hints, tc.estimatedCost) + asts = append(asts, cAst) + for _, ast := range asts { + testEvalWithCost(t, env, ast, tc.in, tc.actualCost) + } + }) + } +} + func TestTwoVarComprehensionsStaticErrors(t *testing.T) { tests := []struct { expr string diff --git a/ext/sets_test.go b/ext/sets_test.go index 70ca2393..74163db6 100644 --- a/ext/sets_test.go +++ b/ext/sets_test.go @@ -29,7 +29,7 @@ import ( ) func TestSets(t *testing.T) { - setsTests := []struct { + tests := []struct { expr string vars []cel.EnvOption in map[string]any @@ -61,222 +61,222 @@ func TestSets(t *testing.T) { }, { expr: `sets.contains([], [])`, - estimatedCost: checker.CostEstimate{Min: 21, Max: 21}, + estimatedCost: checker.FixedCostEstimate(21), actualCost: 21, }, { expr: `sets.contains([1], [])`, - estimatedCost: checker.CostEstimate{Min: 21, Max: 21}, + estimatedCost: checker.FixedCostEstimate(21), actualCost: 21, }, { expr: `sets.contains([1], [1])`, - estimatedCost: checker.CostEstimate{Min: 22, Max: 22}, + estimatedCost: checker.FixedCostEstimate(22), actualCost: 22, }, { expr: `sets.contains([1], [1, 1])`, - estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + estimatedCost: checker.FixedCostEstimate(23), actualCost: 23, }, { expr: `sets.contains([1, 1], [1])`, - estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + estimatedCost: checker.FixedCostEstimate(23), actualCost: 23, }, { expr: `sets.contains([2, 1], [1])`, - estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + estimatedCost: checker.FixedCostEstimate(23), actualCost: 23, }, { expr: `sets.contains([1, 2, 3, 4], [2, 3])`, - estimatedCost: checker.CostEstimate{Min: 29, Max: 29}, + estimatedCost: checker.FixedCostEstimate(29), actualCost: 29, }, { expr: `sets.contains([1], [1.0, 1])`, - estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + estimatedCost: checker.FixedCostEstimate(23), actualCost: 23, }, { expr: `sets.contains([1, 2], [2u, 2.0])`, - estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + estimatedCost: checker.FixedCostEstimate(25), actualCost: 25, }, { expr: `sets.contains([1, 2u], [2, 2.0])`, - estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + estimatedCost: checker.FixedCostEstimate(25), actualCost: 25, }, { expr: `sets.contains([1, 2.0, 3u], [1.0, 2u, 3])`, - estimatedCost: checker.CostEstimate{Min: 30, Max: 30}, + estimatedCost: checker.FixedCostEstimate(30), actualCost: 30, }, { expr: `sets.contains([[1], [2, 3]], [[2, 3.0]])`, // 10 for each list creation, top-level list sizes are 2, 1 - estimatedCost: checker.CostEstimate{Min: 53, Max: 53}, + estimatedCost: checker.FixedCostEstimate(53), actualCost: 53, }, { expr: `!sets.contains([1], [2])`, - estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + estimatedCost: checker.FixedCostEstimate(23), actualCost: 23, }, { expr: `!sets.contains([1], [1, 2])`, - estimatedCost: checker.CostEstimate{Min: 24, Max: 24}, + estimatedCost: checker.FixedCostEstimate(24), actualCost: 24, }, { expr: `!sets.contains([1], ["1", 1])`, - estimatedCost: checker.CostEstimate{Min: 24, Max: 24}, + estimatedCost: checker.FixedCostEstimate(24), actualCost: 24, }, { expr: `!sets.contains([1], [1.1, 1u])`, - estimatedCost: checker.CostEstimate{Min: 24, Max: 24}, + estimatedCost: checker.FixedCostEstimate(24), actualCost: 24, }, // set equivalence (note the cost factor is higher as it's basically two contains checks) { expr: `sets.equivalent([], [])`, - estimatedCost: checker.CostEstimate{Min: 21, Max: 21}, + estimatedCost: checker.FixedCostEstimate(21), actualCost: 21, }, { expr: `sets.equivalent([1], [1])`, - estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + estimatedCost: checker.FixedCostEstimate(23), actualCost: 23, }, { expr: `sets.equivalent([1], [1, 1])`, - estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + estimatedCost: checker.FixedCostEstimate(25), actualCost: 25, }, { expr: `sets.equivalent([1, 1], [1])`, - estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + estimatedCost: checker.FixedCostEstimate(25), actualCost: 25, }, { expr: `sets.equivalent([1], [1u, 1.0])`, - estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + estimatedCost: checker.FixedCostEstimate(25), actualCost: 25, }, { expr: `sets.equivalent([1], [1u, 1.0])`, - estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + estimatedCost: checker.FixedCostEstimate(25), actualCost: 25, }, { expr: `sets.equivalent([1, 2, 3], [3u, 2.0, 1])`, - estimatedCost: checker.CostEstimate{Min: 39, Max: 39}, + estimatedCost: checker.FixedCostEstimate(39), actualCost: 39, }, { expr: `sets.equivalent([[1.0], [2, 3]], [[1], [2, 3.0]])`, - estimatedCost: checker.CostEstimate{Min: 69, Max: 69}, + estimatedCost: checker.FixedCostEstimate(69), actualCost: 69, }, { expr: `!sets.equivalent([2, 1], [1])`, - estimatedCost: checker.CostEstimate{Min: 26, Max: 26}, + estimatedCost: checker.FixedCostEstimate(26), actualCost: 26, }, { expr: `!sets.equivalent([1], [1, 2])`, - estimatedCost: checker.CostEstimate{Min: 26, Max: 26}, + estimatedCost: checker.FixedCostEstimate(26), actualCost: 26, }, { expr: `!sets.equivalent([1, 2], [2u, 2, 2.0])`, - estimatedCost: checker.CostEstimate{Min: 34, Max: 34}, + estimatedCost: checker.FixedCostEstimate(34), actualCost: 34, }, { expr: `!sets.equivalent([1, 2], [1u, 2, 2.3])`, - estimatedCost: checker.CostEstimate{Min: 34, Max: 34}, + estimatedCost: checker.FixedCostEstimate(34), actualCost: 34, }, // set intersection { expr: `sets.intersects([1], [1])`, - estimatedCost: checker.CostEstimate{Min: 22, Max: 22}, + estimatedCost: checker.FixedCostEstimate(22), actualCost: 22, }, { expr: `sets.intersects([1], [1, 1])`, - estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + estimatedCost: checker.FixedCostEstimate(23), actualCost: 23, }, { expr: `sets.intersects([1, 1], [1])`, - estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + estimatedCost: checker.FixedCostEstimate(23), actualCost: 23, }, { expr: `sets.intersects([2, 1], [1])`, - estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + estimatedCost: checker.FixedCostEstimate(23), actualCost: 23, }, { expr: `sets.intersects([1], [1, 2])`, - estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + estimatedCost: checker.FixedCostEstimate(23), actualCost: 23, }, { expr: `sets.intersects([1], [1.0, 2])`, - estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + estimatedCost: checker.FixedCostEstimate(23), actualCost: 23, }, { expr: `sets.intersects([1, 2], [2u, 2, 2.0])`, - estimatedCost: checker.CostEstimate{Min: 27, Max: 27}, + estimatedCost: checker.FixedCostEstimate(27), actualCost: 27, }, { expr: `sets.intersects([1, 2], [1u, 2, 2.3])`, - estimatedCost: checker.CostEstimate{Min: 27, Max: 27}, + estimatedCost: checker.FixedCostEstimate(27), actualCost: 27, }, { expr: `sets.intersects([[1], [2, 3]], [[1, 2], [2, 3.0]])`, - estimatedCost: checker.CostEstimate{Min: 65, Max: 65}, + estimatedCost: checker.FixedCostEstimate(65), actualCost: 65, }, { expr: `!sets.intersects([], [])`, - estimatedCost: checker.CostEstimate{Min: 22, Max: 22}, + estimatedCost: checker.FixedCostEstimate(22), actualCost: 22, }, { expr: `!sets.intersects([1], [])`, - estimatedCost: checker.CostEstimate{Min: 22, Max: 22}, + estimatedCost: checker.FixedCostEstimate(22), actualCost: 22, }, { expr: `!sets.intersects([1], [2])`, - estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + estimatedCost: checker.FixedCostEstimate(23), actualCost: 23, }, { expr: `!sets.intersects([1], ["1", 2])`, - estimatedCost: checker.CostEstimate{Min: 24, Max: 24}, + estimatedCost: checker.FixedCostEstimate(24), actualCost: 24, }, { expr: `!sets.intersects([1], [1.1, 2u])`, - estimatedCost: checker.CostEstimate{Min: 24, Max: 24}, + estimatedCost: checker.FixedCostEstimate(24), actualCost: 24, }, } - for _, tst := range setsTests { + for _, tst := range tests { tc := tst t.Run(tc.expr, func(t *testing.T) { env := testSetsEnv(t, tc.vars...) @@ -291,42 +291,10 @@ func TestSets(t *testing.T) { t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err()) } - hints := map[string]uint64{} - if len(tc.hints) != 0 { - hints = tc.hints - } - est, err := env.EstimateCost(cAst, testSetsCostEstimator{hints: hints}) - if err != nil { - t.Fatalf("env.EstimateCost() failed: %v", err) - } - if !reflect.DeepEqual(est, tc.estimatedCost) { - t.Errorf("env.EstimateCost() got %v, wanted %v", est, tc.estimatedCost) - } + testCheckCost(t, env, cAst, tc.hints, tc.estimatedCost) asts = append(asts, cAst) - for _, ast := range asts { - prgOpts := []cel.ProgramOption{} - if ast.IsChecked() { - prgOpts = append(prgOpts, cel.CostTracking(nil)) - } - prg, err := env.Program(ast, prgOpts...) - if err != nil { - t.Fatalf("env.Program() failed: %v", err) - } - in := tc.in - if in == nil { - in = map[string]any{} - } - out, det, err := prg.Eval(in) - if err != nil { - t.Fatalf("prg.Eval() failed: %v", err) - } - if out.Value() != true { - t.Errorf("prg.Eval() got %v, wanted true for expr: %s", out.Value(), tc.expr) - } - if det.ActualCost() != nil && *det.ActualCost() != tc.actualCost { - t.Errorf("prg.Eval() had cost %v, wanted %v", *det.ActualCost(), tc.actualCost) - } + testEvalWithCost(t, env, ast, tc.in, tc.actualCost) } }) } @@ -514,17 +482,56 @@ func testSetsEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env { return env } -type testSetsCostEstimator struct { +func testCheckCost(t *testing.T, env *cel.Env, ast *cel.Ast, hints map[string]uint64, wantEst checker.CostEstimate) { + t.Helper() + if len(hints) == 0 { + hints = map[string]uint64{} + } + est, err := env.EstimateCost(ast, testCostHintEstimator{hints: hints}) + if err != nil { + t.Fatalf("env.EstimateCost() failed: %v", err) + } + if !reflect.DeepEqual(est, wantEst) { + t.Errorf("env.EstimateCost() got %v, wanted %v", est, wantEst) + } +} + +func testEvalWithCost(t *testing.T, env *cel.Env, ast *cel.Ast, in any, wantCost uint64) { + t.Helper() + prgOpts := []cel.ProgramOption{} + if ast.IsChecked() { + prgOpts = append(prgOpts, cel.CostTracking(nil)) + } + prg, err := env.Program(ast, prgOpts...) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + var data any = cel.NoVars() + if in != nil { + data = in + } + out, det, err := prg.Eval(data) + if err != nil { + t.Fatal(err) + } else if out.Value() != true { + t.Errorf("got %v, wanted true", out.Value()) + } + if det.ActualCost() != nil && *det.ActualCost() != wantCost { + t.Errorf("prg.Eval() had cost %v, wanted %v", *det.ActualCost(), wantCost) + } +} + +type testCostHintEstimator struct { hints map[string]uint64 } -func (tc testSetsCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { +func (tc testCostHintEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok { return &checker.SizeEstimate{Min: 0, Max: l} } return nil } -func (testSetsCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { +func (testCostHintEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { return nil } diff --git a/ext/strings_test.go b/ext/strings_test.go index 23a945eb..22aab84e 100644 --- a/ext/strings_test.go +++ b/ext/strings_test.go @@ -1615,103 +1615,103 @@ func TestQuoteUnquote(t *testing.T) { { name: "remove quotes only", testStr: "this is a test", - expectedEstimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectedEstimatedCost: checker.FixedCostEstimate(2), expectedRuntimeCost: 2, }, { name: "mid-string newline", testStr: "first\nsecond", - expectedEstimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectedEstimatedCost: checker.FixedCostEstimate(2), expectedRuntimeCost: 2, }, { name: "bell", testStr: "bell\a", - expectedEstimatedCost: checker.CostEstimate{Min: 1, Max: 1}, + expectedEstimatedCost: checker.FixedCostEstimate(1), expectedRuntimeCost: 1, }, { name: "backspace", testStr: "\bbackspace", - expectedEstimatedCost: checker.CostEstimate{Min: 1, Max: 1}, + expectedEstimatedCost: checker.FixedCostEstimate(1), expectedRuntimeCost: 1, }, { name: "form feed", testStr: "\fform feed", - expectedEstimatedCost: checker.CostEstimate{Min: 1, Max: 1}, + expectedEstimatedCost: checker.FixedCostEstimate(1), expectedRuntimeCost: 1, }, { name: "carriage return", testStr: "carriage \r return", - expectedEstimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectedEstimatedCost: checker.FixedCostEstimate(2), expectedRuntimeCost: 2, }, { name: "horizontal tab", testStr: "horizontal \ttab", - expectedEstimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectedEstimatedCost: checker.FixedCostEstimate(2), expectedRuntimeCost: 2, }, { name: "vertical tab", testStr: "vertical \v tab", - expectedEstimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectedEstimatedCost: checker.FixedCostEstimate(2), expectedRuntimeCost: 2, }, { name: "double slash", testStr: "double \\\\ slash", - expectedEstimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectedEstimatedCost: checker.FixedCostEstimate(2), expectedRuntimeCost: 2, }, { name: "two escape sequences", testStr: "two escape sequences \a\n", - expectedEstimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + expectedEstimatedCost: checker.FixedCostEstimate(3), expectedRuntimeCost: 3, }, { name: "ends with slash", testStr: "ends with \\", - expectedEstimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectedEstimatedCost: checker.FixedCostEstimate(2), expectedRuntimeCost: 2, }, { name: "starts with slash", testStr: "\\ starts with", - expectedEstimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectedEstimatedCost: checker.FixedCostEstimate(2), expectedRuntimeCost: 2, }, { name: "printable unicode", testStr: "printable unicode😀", - expectedEstimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectedEstimatedCost: checker.FixedCostEstimate(2), expectedRuntimeCost: 2, }, { name: "mid-string quote", testStr: "mid-string \" quote", - expectedEstimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectedEstimatedCost: checker.FixedCostEstimate(2), expectedRuntimeCost: 2, }, { name: "single-quote with double quote", testStr: `single-quote with "double quote"`, - expectedEstimatedCost: checker.CostEstimate{Min: 4, Max: 4}, + expectedEstimatedCost: checker.FixedCostEstimate(4), expectedRuntimeCost: 4, }, { name: "CEL-only escape sequences", testStr: "\\? and \\`", - expectedEstimatedCost: checker.CostEstimate{Min: 1, Max: 1}, + expectedEstimatedCost: checker.FixedCostEstimate(1), expectedRuntimeCost: 1, }, { name: "test cost", testStr: "this is a very very very long string used to ensure that cost tracking works", - expectedEstimatedCost: checker.CostEstimate{Min: 8, Max: 8}, + expectedEstimatedCost: checker.FixedCostEstimate(8), expectedRuntimeCost: 8, }, { diff --git a/interpreter/runtimecost.go b/interpreter/runtimecost.go index b9b307c1..8f47c53d 100644 --- a/interpreter/runtimecost.go +++ b/interpreter/runtimecost.go @@ -198,20 +198,20 @@ func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result re switch call.OverloadID() { // O(n) functions case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString, overloads.ExtQuoteString, overloads.ExtFormatString: - cost += uint64(math.Ceil(float64(c.actualSize(args[0])) * common.StringTraversalCostFactor)) + cost += uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor)) case overloads.InList: // If a list is composed entirely of constant values this is O(1), but we don't account for that here. // We just assume all list containment checks are O(n). - cost += c.actualSize(args[1]) + cost += actualSize(args[1]) // O(min(m, n)) functions case overloads.LessString, overloads.GreaterString, overloads.LessEqualsString, overloads.GreaterEqualsString, overloads.LessBytes, overloads.GreaterBytes, overloads.LessEqualsBytes, overloads.GreaterEqualsBytes, overloads.Equals, overloads.NotEquals: // When we check the equality of 2 scalar values (e.g. 2 integers, 2 floating-point numbers, 2 booleans etc.), - // the CostTracker.actualSize() function by definition returns 1 for each operand, resulting in an overall cost + // the CostTracker.ActualSize() function by definition returns 1 for each operand, resulting in an overall cost // of 1. - lhsSize := c.actualSize(args[0]) - rhsSize := c.actualSize(args[1]) + lhsSize := actualSize(args[0]) + rhsSize := actualSize(args[1]) minSize := lhsSize if rhsSize < minSize { minSize = rhsSize @@ -220,23 +220,23 @@ func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result re // O(m+n) functions case overloads.AddString, overloads.AddBytes: // In the worst case scenario, we would need to reallocate a new backing store and copy both operands over. - cost += uint64(math.Ceil(float64(c.actualSize(args[0])+c.actualSize(args[1])) * common.StringTraversalCostFactor)) + cost += uint64(math.Ceil(float64(actualSize(args[0])+actualSize(args[1])) * common.StringTraversalCostFactor)) // O(nm) functions case overloads.MatchesString: // https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL // Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0 // in case where string is empty but regex is still expensive. - strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(args[0]))) * common.StringTraversalCostFactor)) + strCost := uint64(math.Ceil((1.0 + float64(actualSize(args[0]))) * common.StringTraversalCostFactor)) // We don't know how many expressions are in the regex, just the string length (a huge // improvement here would be to somehow get a count the number of expressions in the regex or // how many states are in the regex state machine and use that to measure regex cost). // For now, we're making a guess that each expression in a regex is typically at least 4 chars // in length. - regexCost := uint64(math.Ceil(float64(c.actualSize(args[1])) * common.RegexStringLengthCostFactor)) + regexCost := uint64(math.Ceil(float64(actualSize(args[1])) * common.RegexStringLengthCostFactor)) cost += strCost * regexCost case overloads.ContainsString: - strCost := uint64(math.Ceil(float64(c.actualSize(args[0])) * common.StringTraversalCostFactor)) - substrCost := uint64(math.Ceil(float64(c.actualSize(args[1])) * common.StringTraversalCostFactor)) + strCost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor)) + substrCost := uint64(math.Ceil(float64(actualSize(args[1])) * common.StringTraversalCostFactor)) cost += strCost * substrCost default: @@ -253,11 +253,15 @@ func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result re return cost } -// actualSize returns the size of value -func (c *CostTracker) actualSize(value ref.Val) uint64 { +// actualSize returns the size of the value for all traits.Sizer values, a fixed size for all proto-based +// objects, and a size of 1 for all other value types. +func actualSize(value ref.Val) uint64 { if sz, ok := value.(traits.Sizer); ok { return uint64(sz.Size().(types.Int)) } + if opt, ok := value.(*types.Optional); ok && opt.HasValue() { + return actualSize(opt.GetValue()) + } return 1 } diff --git a/interpreter/runtimecost_test.go b/interpreter/runtimecost_test.go index 687a47b8..6e6bc6b5 100644 --- a/interpreter/runtimecost_test.go +++ b/interpreter/runtimecost_test.go @@ -30,7 +30,6 @@ import ( "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" - "github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/parser" proto3pb "github.com/google/cel-go/test/proto3pb" @@ -230,7 +229,7 @@ func (tc testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeE func (tc testCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { switch overloadID { case overloads.TimestampToYear: - return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}} + return &checker.CallEstimate{CostEstimate: checker.FixedCostEstimate(7)} } return nil } @@ -888,10 +887,3 @@ func TestRuntimeCost(t *testing.T) { }) } } - -func actualSize(val ref.Val) uint64 { - if sz, ok := val.(traits.Sizer); ok { - return uint64(sz.Size().(types.Int)) - } - return 1 -}