Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to use inaccessible accumulator var #1097

Merged
merged 2 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -777,18 +777,9 @@ func TestMacroInterop(t *testing.T) {
}

func TestMacroModern(t *testing.T) {
existsOneMacro := ReceiverMacro("exists_one", 2,
func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
return parser.MakeExistsOne(mef, iterRange, args)
})
transformMacro := ReceiverMacro("transform", 2,
func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
return parser.MakeMap(mef, iterRange, args)
})
filterMacro := ReceiverMacro("filter", 2,
func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
return parser.MakeFilter(mef, iterRange, args)
})
existsOneMacro := ReceiverMacro("exists_one", 2, parser.MakeExistsOne)
transformMacro := ReceiverMacro("transform", 2, parser.MakeMap)
filterMacro := ReceiverMacro("filter", 2, parser.MakeFilter)
pairMacro := GlobalMacro("pair", 2,
func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
return mef.NewMap(mef.NewMapEntry(args[0], args[1], false)), nil
Expand Down
9 changes: 9 additions & 0 deletions cel/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,15 @@ func ParserExpressionSizeLimit(limit int) EnvOption {
}
}

// EnableHiddenAccumulatorName sets the parser to use the identifier '@result' for accumulators
// which is not normally accessible from CEL source.
func EnableHiddenAccumulatorName(enabled bool) EnvOption {
return func(e *Env) (*Env, error) {
e.prsrOpts = append(e.prsrOpts, parser.EnableHiddenAccumulatorName(enabled))
return e, nil
}
}

func maybeInteropProvider(provider any) (types.Provider, error) {
switch p := provider.(type) {
case types.Provider:
Expand Down
6 changes: 5 additions & 1 deletion checker/cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -672,9 +672,13 @@ func (c *coster) addPath(e ast.Expr, path []string) {
c.exprPath[e.ID()] = path
}

func isAccumulatorVar(name string) bool {
return name == parser.AccumulatorName || name == parser.HiddenAccumulatorName
}

func (c *coster) newAstNode(e ast.Expr) *astNode {
path := c.getPath(e)
if len(path) > 0 && path[0] == parser.AccumulatorName {
if len(path) > 0 && isAccumulatorVar(path[0]) {
// only provide paths to root vars; omit accumulator vars
path = nil
}
Expand Down
25 changes: 22 additions & 3 deletions common/ast/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ type ExprFactory interface {
//comprehension.
NewAccuIdent(id int64) Expr

// AccuIdentName reports the name of the accumulator variable to be used within a comprehension.
AccuIdentName() string

// NewLiteral creates an Expr value representing a literal value, such as a string or integer.
NewLiteral(id int64, value ref.Val) Expr

Expand Down Expand Up @@ -78,11 +81,23 @@ type ExprFactory interface {
isExprFactory()
}

type baseExprFactory struct{}
type baseExprFactory struct {
accumulatorName string
}

// NewExprFactory creates an ExprFactory instance.
func NewExprFactory() ExprFactory {
return &baseExprFactory{}
return &baseExprFactory{
"__result__",
}
}

// NewExprFactoryWithAccumulator creates an ExprFactory instance with a custom
// accumulator identifier name.
func NewExprFactoryWithAccumulator(id string) ExprFactory {
return &baseExprFactory{
id,
}
}

func (fac *baseExprFactory) NewCall(id int64, function string, args ...Expr) Expr {
Expand Down Expand Up @@ -138,7 +153,11 @@ func (fac *baseExprFactory) NewIdent(id int64, name string) Expr {
}

func (fac *baseExprFactory) NewAccuIdent(id int64) Expr {
return fac.NewIdent(id, "__result__")
return fac.NewIdent(id, fac.AccuIdentName())
}

func (fac *baseExprFactory) AccuIdentName() string {
return fac.accumulatorName
}

func (fac *baseExprFactory) NewLiteral(id int64, value ref.Val) Expr {
Expand Down
28 changes: 14 additions & 14 deletions ext/comprehensions.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func quantifierAll(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
mef.AccuIdentName(),
/*accuInit=*/ mef.NewLiteral(types.True),
/*condition=*/ mef.NewCall(operators.NotStrictlyFalse, mef.NewAccuIdent()),
/*step=*/ mef.NewCall(operators.LogicalAnd, mef.NewAccuIdent(), args[2]),
Expand All @@ -267,7 +267,7 @@ func quantifierExists(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
mef.AccuIdentName(),
/*accuInit=*/ mef.NewLiteral(types.False),
/*condition=*/ mef.NewCall(operators.NotStrictlyFalse, mef.NewCall(operators.LogicalNot, mef.NewAccuIdent())),
/*step=*/ mef.NewCall(operators.LogicalOr, mef.NewAccuIdent(), args[2]),
Expand All @@ -285,7 +285,7 @@ func quantifierExistsOne(mef cel.MacroExprFactory, target ast.Expr, args []ast.E
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
mef.AccuIdentName(),
/*accuInit=*/ mef.NewLiteral(types.Int(0)),
/*condition=*/ mef.NewLiteral(types.True),
/*step=*/ mef.NewCall(operators.Conditional, args[2],
Expand All @@ -311,18 +311,18 @@ func transformList(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (
transform = args[2]
}

// __result__ = __result__ + [transform]
// accumulator = accumulator + [transform]
step := mef.NewCall(operators.Add, mef.NewAccuIdent(), mef.NewList(transform))
if filter != nil {
// __result__ = (filter) ? __result__ + [transform] : __result__
// accumulator = (filter) ? accumulator + [transform] : accumulator
step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent())
}

return mef.NewComprehensionTwoVar(
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
mef.AccuIdentName(),
/*accuInit=*/ mef.NewList(),
/*condition=*/ mef.NewLiteral(types.True),
step,
Expand All @@ -346,17 +346,17 @@ func transformMap(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (a
transform = args[2]
}

// __result__ = cel.@mapInsert(__result__, iterVar1, transform)
// accumulator = cel.@mapInsert(accumulator, iterVar1, transform)
step := mef.NewCall(mapInsert, mef.NewAccuIdent(), mef.NewIdent(iterVar1), transform)
if filter != nil {
// __result__ = (filter) ? cel.@mapInsert(__result__, iterVar1, transform) : __result__
// accumulator = (filter) ? cel.@mapInsert(accumulator, iterVar1, transform) : accumulator
step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent())
}
return mef.NewComprehensionTwoVar(
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
mef.AccuIdentName(),
/*accuInit=*/ mef.NewMap(),
/*condition=*/ mef.NewLiteral(types.True),
step,
Expand All @@ -380,17 +380,17 @@ func transformMapEntry(mef cel.MacroExprFactory, target ast.Expr, args []ast.Exp
transform = args[2]
}

// __result__ = cel.@mapInsert(__result__, transform)
// accumulator = cel.@mapInsert(accumulator, transform)
step := mef.NewCall(mapInsert, mef.NewAccuIdent(), transform)
if filter != nil {
// __result__ = (filter) ? cel.@mapInsert(__result__, transform) : __result__
// accumulator = (filter) ? cel.@mapInsert(accumulator, transform) : accumulator
step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent())
}
return mef.NewComprehensionTwoVar(
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
mef.AccuIdentName(),
/*accuInit=*/ mef.NewMap(),
/*condition=*/ mef.NewLiteral(types.True),
step,
Expand All @@ -410,10 +410,10 @@ func extractIterVars(mef cel.MacroExprFactory, arg0, arg1 ast.Expr) (string, str
if iterVar1 == iterVar2 {
return "", "", mef.NewError(arg1.ID(), fmt.Sprintf("duplicate variable name: %s", iterVar1))
}
if iterVar1 == parser.AccumulatorName {
if iterVar1 == mef.AccuIdentName() || iterVar1 == parser.AccumulatorName {
return "", "", mef.NewError(arg0.ID(), "iteration variable overwrites accumulator variable")
}
if iterVar2 == parser.AccumulatorName {
if iterVar2 == mef.AccuIdentName() || iterVar2 == parser.AccumulatorName {
return "", "", mef.NewError(arg1.ID(), "iteration variable overwrites accumulator variable")
}
return iterVar1, iterVar2, nil
Expand Down
4 changes: 3 additions & 1 deletion ext/comprehensions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,9 @@ func testCompreEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env {
Lists(),
Strings(),
cel.OptionalTypes(),
cel.EnableMacroCallTracking()}
cel.EnableMacroCallTracking(),
cel.EnableHiddenAccumulatorName(true),
}
env, err := cel.NewEnv(append(baseOpts, opts...)...)
if err != nil {
t.Fatalf("cel.NewEnv(TwoVarComprehensions()) failed: %v", err)
Expand Down
5 changes: 5 additions & 0 deletions parser/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,11 @@ func (e *exprHelper) NewAccuIdent() ast.Expr {
return e.exprFactory.NewAccuIdent(e.nextMacroID())
}

// AccuIdentName implements the ExprHelper interface method.
func (e *exprHelper) AccuIdentName() string {
return e.exprFactory.AccuIdentName()
}

// NewGlobalCall implements the ExprHelper interface method.
func (e *exprHelper) NewCall(function string, args ...ast.Expr) ast.Expr {
return e.exprFactory.NewCall(e.nextMacroID(), function, args...)
Expand Down
23 changes: 17 additions & 6 deletions parser/macro.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ type ExprHelper interface {
// NewAccuIdent returns an accumulator identifier for use with comprehension results.
NewAccuIdent() ast.Expr

// AccuIdentName returns the name of the accumulator identifier.
AccuIdentName() string

// NewCall creates a function call Expr value for a global (free) function.
NewCall(function string, args ...ast.Expr) ast.Expr

Expand Down Expand Up @@ -298,6 +301,11 @@ var (
// AccumulatorName is the traditional variable name assigned to the fold accumulator variable.
const AccumulatorName = "__result__"

// HiddenAccumulatorName is a proposed update to the default fold accumlator variable.
// @result is not normally accessible from source, preventing accidental or intentional collisions
// in user expressions.
const HiddenAccumulatorName = "@result"

type quantifierKind int

const (
Expand Down Expand Up @@ -342,7 +350,8 @@ func MakeMap(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common
if !found {
return nil, eh.NewError(args[0].ID(), "argument is not an identifier")
}
if v == AccumulatorName {
accu := eh.AccuIdentName()
if v == accu || v == AccumulatorName {
return nil, eh.NewError(args[0].ID(), "iteration variable overwrites accumulator variable")
}

Expand All @@ -364,7 +373,7 @@ func MakeMap(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common
if filter != nil {
step = eh.NewCall(operators.Conditional, filter, step, eh.NewAccuIdent())
}
return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, eh.NewAccuIdent()), nil
return eh.NewComprehension(target, v, accu, init, condition, step, eh.NewAccuIdent()), nil
}

// MakeFilter expands the input call arguments into a comprehension which produces a list which contains
Expand All @@ -375,7 +384,8 @@ func MakeFilter(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *com
if !found {
return nil, eh.NewError(args[0].ID(), "argument is not an identifier")
}
if v == AccumulatorName {
accu := eh.AccuIdentName()
if v == accu || v == AccumulatorName {
return nil, eh.NewError(args[0].ID(), "iteration variable overwrites accumulator variable")
}

Expand All @@ -384,7 +394,7 @@ func MakeFilter(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *com
condition := eh.NewLiteral(types.True)
step := eh.NewCall(operators.Add, eh.NewAccuIdent(), eh.NewList(args[0]))
step = eh.NewCall(operators.Conditional, filter, step, eh.NewAccuIdent())
return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, eh.NewAccuIdent()), nil
return eh.NewComprehension(target, v, accu, init, condition, step, eh.NewAccuIdent()), nil
}

// MakeHas expands the input call arguments into a presence test, e.g. has(<operand>.field)
Expand All @@ -401,7 +411,8 @@ func makeQuantifier(kind quantifierKind, eh ExprHelper, target ast.Expr, args []
if !found {
return nil, eh.NewError(args[0].ID(), "argument must be a simple name")
}
if v == AccumulatorName {
accu := eh.AccuIdentName()
if v == accu || v == AccumulatorName {
return nil, eh.NewError(args[0].ID(), "iteration variable overwrites accumulator variable")
}

Expand Down Expand Up @@ -431,7 +442,7 @@ func makeQuantifier(kind quantifierKind, eh ExprHelper, target ast.Expr, args []
default:
return nil, eh.NewError(args[0].ID(), fmt.Sprintf("unrecognized quantifier '%v'", kind))
}
return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, result), nil
return eh.NewComprehension(target, v, accu, init, condition, step, result), nil
}

func extractIdent(e ast.Expr) (string, bool) {
Expand Down
13 changes: 13 additions & 0 deletions parser/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type options struct {
enableOptionalSyntax bool
enableVariadicOperatorASTs bool
enableIdentEscapeSyntax bool
enableHiddenAccumulatorName bool
}

// Option configures the behavior of the parser.
Expand Down Expand Up @@ -137,6 +138,18 @@ func EnableIdentEscapeSyntax(enableIdentEscapeSyntax bool) Option {
}
}

// EnableHiddenAccumulatorName uses an accumulator variable name that is not a
// normally accessible identifier in source for comprehension macros. Compatibility notes:
// with this option enabled, a parsed AST would be semantically the same as if disabled, but would
// have different internal identifiers in any of the built-in comprehension sub-expressions. When
// disabled, it is possible but almost certainly a logic error to access the accumulator variable.
func EnableHiddenAccumulatorName(enabled bool) Option {
return func(opts *options) error {
opts.enableHiddenAccumulatorName = enabled
return nil
}
}

// EnableVariadicOperatorASTs enables a compact representation of chained like-kind commutative
// operators. e.g. `a || b || c || d` -> `call(op='||', args=[a, b, c, d])`
//
Expand Down
6 changes: 5 additions & 1 deletion parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ func mustNewParser(opts ...Option) *Parser {
// Parse parses the expression represented by source and returns the result.
func (p *Parser) Parse(source common.Source) (*ast.AST, *common.Errors) {
errs := common.NewErrors(source)
fac := ast.NewExprFactory()
accu := AccumulatorName
if p.enableHiddenAccumulatorName {
accu = HiddenAccumulatorName
}
fac := ast.NewExprFactoryWithAccumulator(accu)
impl := parser{
errors: &parseErrors{errs},
exprFactory: fac,
Expand Down
Loading