From a55edea07f2c3bb026d89f4e7fe88ae4755a810f Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Wed, 17 Jan 2024 12:52:26 +1300 Subject: [PATCH] Fix short circuting of && and || in macros (#3655) --- src/macros.jl | 22 ++++++++++++++++++++-- test/test_macros.jl | 26 ++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/src/macros.jl b/src/macros.jl index 410daeae923..dc133ccde06 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -208,9 +208,27 @@ function _rewrite_to_jump_logic(x) return Expr(:call, op_equal_to, x.args[2:end]...) end elseif Meta.isexpr(x, :||) - return Expr(:call, op_or, x.args...) + # Take special care to ensure short-circuiting behavior of operator is + # retained. We don't want to evaluate the second argument if the first + # is `true`. + @assert length(x.args) == 2 + return Expr( + :if, + Expr(:call, ===, x.args[1], true), + true, + Expr(:call, op_or, x.args[1], x.args[2]), + ) elseif Meta.isexpr(x, :&&) - return Expr(:call, op_and, x.args...) + # Take special care to ensure short-circuiting behavior of operator is + # retained. We don't want to evaluate the second argument if the first + # is `false`. + @assert length(x.args) == 2 + return Expr( + :if, + Expr(:call, ===, x.args[1], false), + false, + Expr(:call, op_and, x.args[1], x.args[2]), + ) elseif Meta.isexpr(x, :comparison) lhs = Expr(:call, x.args[2], x.args[1], x.args[3]) rhs = Expr(:call, x.args[4], x.args[3], x.args[5]) diff --git a/test/test_macros.jl b/test/test_macros.jl index 1a4e0aca7c0..03652a0f679 100644 --- a/test/test_macros.jl +++ b/test/test_macros.jl @@ -2332,4 +2332,30 @@ function test_error_parsing_reference_sets() return end +function test_op_and_short_circuit() + model = Model() + @test @expression(model, false && error()) == false + data = Dict(2 => 1) + @expression( + model, + expr, + sum(1 for p in [1, 2] if p in keys(data) && data[p] == 1), + ) + @test expr == 1 + return +end + +function test_op_or_short_circuit() + model = Model() + @test @expression(model, true || error()) == true + data = Dict(2 => 1) + @expression( + model, + expr, + sum(1 for p in [1, 2] if !(p in keys(data)) || data[p] == 2), + ) + @test expr == 1 + return +end + end # module