diff --git a/checker/cost.go b/checker/cost.go index 1b325eac..b9cd8a2e 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -930,6 +930,14 @@ func (c *coster) computeSize(e ast.Expr) *SizeEstimate { if size, ok := c.computedSizes[e.ID()]; ok { return &size } + // Ensure size estimates are computed first as users may choose to override the costs that + // CEL would otherwise ascribe to the type. + node := astNode{expr: e, path: c.getPath(e), t: c.getType(e)} + if size := c.estimator.EstimateSize(node); size != nil { + // storing the computed size should reduce calls to EstimateSize() + c.computedSizes[e.ID()] = *size + return size + } if size := computeExprSize(e); size != nil { return size } @@ -942,12 +950,6 @@ func (c *coster) computeSize(e ast.Expr) *SizeEstimate { return v.size } } - node := astNode{expr: e, path: c.getPath(e), t: c.getType(e)} - if size := c.estimator.EstimateSize(node); size != nil { - // storing the computed size should reduce calls to EstimateSize() - c.computedSizes[e.ID()] = *size - return size - } return nil } @@ -1014,8 +1016,7 @@ func computeTypeSize(t *types.Type) *SizeEstimate { // in addition to protobuf.Any and protobuf.Value (their size is not knowable at compile time). func isScalar(t *types.Type) bool { switch t.Kind() { - case types.BoolKind, types.DoubleKind, types.DurationKind, types.IntKind, - types.NullTypeKind, types.TimestampKind, types.TypeKind, types.UintKind: + case types.BoolKind, types.DoubleKind, types.DurationKind, types.IntKind, types.TimestampKind, types.UintKind: return true case types.OpaqueKind: if t.TypeName() == "optional_type" { diff --git a/checker/cost_test.go b/checker/cost_test.go index fb74cdf2..2bec0e94 100644 --- a/checker/cost_test.go +++ b/checker/cost_test.go @@ -628,6 +628,93 @@ func TestCost(t *testing.T) { expr: `['hello', 'hi'][0] != ['hello', 'bye'][1]`, wanted: CostEstimate{Min: 23, Max: 23}, }, + { + name: "type call", + expr: `type(1)`, + wanted: CostEstimate{Min: 1, Max: 1}, + }, + { + name: "type call variable", + expr: `type(self.val1)`, + vars: []*decls.VariableDecl{ + decls.NewVariable("self", types.NewMapType(types.StringType, types.IntType)), + }, + wanted: CostEstimate{Min: 3, Max: 3}, + }, + { + name: "type call variable equality", + expr: `type(self.val1) == int`, + vars: []*decls.VariableDecl{ + decls.NewVariable("self", types.NewMapType(types.StringType, types.IntType)), + }, + wanted: CostEstimate{Min: 5, Max: 1844674407370955268}, + }, + { + name: "type literal equality cost", + expr: `type(1) == int`, + wanted: CostEstimate{Min: 3, Max: 1844674407370955266}, + }, + { + name: "type variable equality cost", + expr: `type(1) == int`, + wanted: CostEstimate{Min: 3, Max: 1844674407370955266}, + }, + { + name: "namespace variable equality", + expr: `self.val1 == 1.0`, + vars: []*decls.VariableDecl{ + decls.NewVariable("self.val1", types.DoubleType), + }, + wanted: CostEstimate{Min: 2, Max: 2}, + }, + { + name: "simple map variable equality", + expr: `self.val1 == 1.0`, + vars: []*decls.VariableDecl{ + decls.NewVariable("self", types.NewMapType(types.StringType, types.DoubleType)), + }, + wanted: CostEstimate{Min: 3, Max: 3}, + }, + { + name: "date-time math", + vars: []*decls.VariableDecl{ + decls.NewVariable("self", types.NewMapType(types.StringType, types.TimestampType)), + }, + expr: `self.val1 == timestamp('2011-08-18T00:00:00.000+01:00') + duration('19h3m37s10ms')`, + wanted: FixedCostEstimate(6), + }, + { + name: "date-time math self-conversion", + vars: []*decls.VariableDecl{ + decls.NewVariable("self", types.NewMapType(types.StringType, types.TimestampType)), + }, + expr: `timestamp(self.val1) == timestamp('2011-08-18T00:00:00.000+01:00') + duration('19h3m37s10ms')`, + wanted: FixedCostEstimate(7), + }, + { + name: "boolean vars equal", + vars: []*decls.VariableDecl{ + decls.NewVariable("self", types.NewMapType(types.StringType, types.BoolType)), + }, + expr: `self.val1 != self.val2`, + wanted: FixedCostEstimate(5), + }, + { + name: "boolean var equals literal", + vars: []*decls.VariableDecl{ + decls.NewVariable("self", types.NewMapType(types.StringType, types.BoolType)), + }, + expr: `self.val1 != true`, + wanted: FixedCostEstimate(3), + }, + { + name: "double var equals literal", + vars: []*decls.VariableDecl{ + decls.NewVariable("self", types.NewMapType(types.StringType, types.DoubleType)), + }, + expr: `self.val1 == 1.0`, + wanted: FixedCostEstimate(3), + }, } for _, tst := range cases {