Skip to content

Commit

Permalink
Add short-circuit evaluation optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
bizywizy committed Feb 2, 2024
1 parent 1c8c9e6 commit 98d9aee
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 0 deletions.
21 changes: 21 additions & 0 deletions compiler/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,24 @@ func TestCompile_OpCallFast(t *testing.T) {
require.Equal(t, vm.OpCallFast, program.Bytecode[4])
require.Equal(t, 3, program.Arguments[4])
}

func TestCompile_optimizes_short_circuit(t *testing.T) {
env := mock.Env{}
program, err := expr.Compile("let a = true; let b = false; let c = true; a || b || c", expr.Env(env))
require.NoError(t, err)
require.Equal(t, vm.OpTrue, program.Bytecode[0])
require.Equal(t, vm.OpStore, program.Bytecode[1])
require.Equal(t, vm.OpFalse, program.Bytecode[2])
require.Equal(t, vm.OpStore, program.Bytecode[3])
require.Equal(t, vm.OpTrue, program.Bytecode[4])
require.Equal(t, vm.OpStore, program.Bytecode[5])
require.Equal(t, vm.OpLoadVar, program.Bytecode[6])
require.Equal(t, vm.OpJumpIfTrue, program.Bytecode[7])
require.Equal(t, 5, program.Arguments[7])
require.Equal(t, vm.OpPop, program.Bytecode[8])
require.Equal(t, vm.OpLoadVar, program.Bytecode[9])
require.Equal(t, vm.OpJumpIfTrue, program.Bytecode[10])
require.Equal(t, 2, program.Arguments[10])
require.Equal(t, vm.OpPop, program.Bytecode[11])
require.Equal(t, vm.OpLoadVar, program.Bytecode[12])
}
1 change: 1 addition & 0 deletions optimizer/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,6 @@ func Optimize(node *Node, config *conf.Config) error {
Walk(node, &filterLen{})
Walk(node, &filterLast{})
Walk(node, &filterFirst{})
Walk(node, &shortCircuitEvaluation{})
return nil
}
100 changes: 100 additions & 0 deletions optimizer/optimizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,103 @@ func TestOptimize_filter_map_first(t *testing.T) {

assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
}

func TestOptimize_short_circuit_evaluation(t *testing.T) {
tests := []struct {
input string
expected ast.Node
}{
{
input: `a && b && c && d`,
expected: &ast.BinaryNode{
Operator: "&&",
Left: &ast.IdentifierNode{Value: "a"},
Right: &ast.BinaryNode{
Operator: "&&",
Left: &ast.IdentifierNode{Value: "b"},
Right: &ast.BinaryNode{
Operator: "&&",
Left: &ast.IdentifierNode{Value: "c"},
Right: &ast.IdentifierNode{Value: "d"},
},
},
},
},
{
input: `a || b || c || d`,
expected: &ast.BinaryNode{
Operator: "||",
Left: &ast.IdentifierNode{Value: "a"},
Right: &ast.BinaryNode{
Operator: "||",
Left: &ast.IdentifierNode{Value: "b"},
Right: &ast.BinaryNode{
Operator: "||",
Left: &ast.IdentifierNode{Value: "c"},
Right: &ast.IdentifierNode{Value: "d"},
},
},
},
},
{
input: `a && b || c && d`,
expected: &ast.BinaryNode{
Operator: "||",
Left: &ast.BinaryNode{
Operator: "&&",
Left: &ast.IdentifierNode{Value: "a"},
Right: &ast.IdentifierNode{Value: "b"},
},
Right: &ast.BinaryNode{
Operator: "&&",
Left: &ast.IdentifierNode{Value: "c"},
Right: &ast.IdentifierNode{Value: "d"},
},
},
},
{
input: `filter([1, 2, 3, 4, 5], # > 3 && # != 4 && # != 5)`,
expected: &ast.BuiltinNode{
Name: "filter",
Arguments: []ast.Node{
&ast.ConstantNode{Value: []any{1, 2, 3, 4, 5}},
&ast.ClosureNode{
Node: &ast.BinaryNode{
Operator: "&&",
Left: &ast.BinaryNode{
Operator: ">",
Left: &ast.PointerNode{},
Right: &ast.IntegerNode{Value: 3},
},
Right: &ast.BinaryNode{
Operator: "&&",
Left: &ast.BinaryNode{
Operator: "!=",
Left: &ast.PointerNode{},
Right: &ast.IntegerNode{Value: 4},
},
Right: &ast.BinaryNode{
Operator: "!=",
Left: &ast.PointerNode{},
Right: &ast.IntegerNode{Value: 5},
},
},
},
},
},
},
},
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
tree, err := parser.Parse(tt.input)
require.NoError(t, err)

err = optimizer.Optimize(&tree.Node, nil)
require.NoError(t, err)

assert.Equal(t, ast.Dump(tt.expected), ast.Dump(tree.Node))
})
}
}
31 changes: 31 additions & 0 deletions optimizer/short_circuit_evaluation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package optimizer

import (
. "github.com/expr-lang/expr/ast"
"github.com/expr-lang/expr/parser/operator"
)

type shortCircuitEvaluation struct{}

func (v *shortCircuitEvaluation) Visit(node *Node) {
if n, ok := (*node).(*BinaryNode); ok {
if operator.IsBoolean(n.Operator) {
if left, ok := n.Left.(*BinaryNode); ok {
if left.Operator == n.Operator {
Patch(node, &BinaryNode{
Operator: left.Operator,
Left: left.Left,
Right: &BinaryNode{
Operator: n.Operator,
Left: left.Right,
Right: n.Right,
},
})
if n, ok := (*node).(*BinaryNode); ok {
v.Visit(&n.Right)
}
}
}
}
}
}

0 comments on commit 98d9aee

Please sign in to comment.