diff --git a/src/indicator.jl b/src/indicator.jl index ec8a9f23138..f8254d196a6 100644 --- a/src/indicator.jl +++ b/src/indicator.jl @@ -34,7 +34,7 @@ function parse_one_operator_constraint( _error("Invalid right-hand side `$(rhs)` of indicator constraint. Expected constraint surrounded by `{` and `}`.") end rhs_con = rhs.args[1] - rhs_vectorized, rhs_parsecode, rhs_buildcall = parse_constraint(_error, rhs_con.args...) + rhs_vectorized, rhs_parsecode, rhs_buildcall = parse_constraint_expr(_error, rhs_con) if vectorized != rhs_vectorized _error("Inconsistent use of `.` in symbols to indicate vectorization.") end diff --git a/src/macros.jl b/src/macros.jl index f94cd604bce..37b42b10fa0 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -177,6 +177,13 @@ function parse_one_operator_constraint(_error::Function, vectorized::Bool, sense return parse_code, _build_call(_error, vectorized, :(_functionize($variable)), set) end +function parse_constraint_expr(_error::Function, expr::Expr) + return parse_constraint_head(_error, Val(expr.head), expr.args...) +end +function parse_constraint_head(_error::Function, ::Val{:call}, args...) + return parse_constraint(_error, args...) +end + function parse_constraint(_error::Function, sense::Symbol, lhs, rhs) (sense, vectorized) = _check_vectorized(sense) vectorized, parse_one_operator_constraint(_error, vectorized, Val(sense), lhs, rhs)... @@ -202,6 +209,10 @@ function parse_ternary_constraint(_error::Function, args...) _error("Only two-sided rows of the form lb <= expr <= ub or ub >= expr >= lb are supported.") end +function parse_constraint_head(_error::Function, ::Val{:comparison}, lb, lsign::Symbol, aff, rsign::Symbol, ub) + return parse_constraint(_error, lb, lsign, aff, rsign, ub) +end + function parse_constraint(_error::Function, lb, lsign::Symbol, aff, rsign::Symbol, ub) (lsign, lvectorized) = _check_vectorized(lsign) (rsign, rvectorized) = _check_vectorized(rsign) @@ -215,13 +226,20 @@ function parse_constraint(_error::Function, lb, lsign::Symbol, aff, rsign::Symbo vectorized, parsecode, buildcall end -function parse_constraint(_error::Function, args...) +function _unknown_constraint_expr(_error::Function) # Unknown _error("Constraints must be in one of the following forms:\n" * " expr1 <= expr2\n" * " expr1 >= expr2\n" * " expr1 == expr2\n" * " lb <= expr <= ub") end +function parse_constraint_head(_error::Function, ::Val, args...) + _unknown_constraint_expr(_error) +end +function parse_constraint(_error::Function, args...) + _unknown_constraint_expr(_error) +end + # Generic fallback. function build_constraint(_error::Function, func, set::Union{MOI.AbstractScalarSet, MOI.AbstractVectorSet}) @@ -374,7 +392,7 @@ function _constraint_macro(args, macro_name::Symbol, parsefun::Function) # in a function returning `ConstraintRef`s and give it to `Containers.container`. idxvars, indices = Containers._build_ref_sets(_error, c) - vectorized, parsecode, buildcall = parsefun(_error, x.args...) + vectorized, parsecode, buildcall = parsefun(_error, x) _add_kw_args(buildcall, kw_args) if vectorized # TODO: Pass through names here. @@ -457,9 +475,12 @@ that either `func` or `set` will be some custom type, rather than e.g. a set appearing in the constraint. """ macro constraint(args...) - _constraint_macro(args, :constraint, parse_constraint) + _constraint_macro(args, :constraint, parse_constraint_expr) end +function parse_SD_constraint_expr(_error::Function, expr::Expr) + return parse_SD_constraint(_error, expr.args...) +end function parse_SD_constraint(_error::Function, sense::Symbol, lhs, rhs) # Simple comparison - move everything to the LHS aff = :() @@ -554,7 +575,7 @@ part of the matrix assuming that it is symmetric, see [`PSDCone`](@ref) to see how to use it. """ macro SDconstraint(args...) - _constraint_macro(args, :SDconstraint, parse_SD_constraint) + _constraint_macro(args, :SDconstraint, parse_SD_constraint_expr) end """ @@ -585,8 +606,8 @@ macro build_constraint(constraint_expr) "Are you missing a comparison (<=, >=, or ==)?") end - is_vectorized, parse_code, build_call = parse_constraint( - _error, constraint_expr.args...) + is_vectorized, parse_code, build_call = parse_constraint_expr( + _error, constraint_expr) result_variable = gensym() code = quote $parse_code diff --git a/test/macros.jl b/test/macros.jl index 1f76b982315..692d9062746 100644 --- a/test/macros.jl +++ b/test/macros.jl @@ -155,6 +155,27 @@ function build_constraint_keyword_test(ModelType::Type{<:JuMP.AbstractModel}) end end +struct CustomType +end +function JuMP.parse_constraint_head(_error::Function, ::Val{:(:=)}, lhs, rhs) + return false, :(), :(build_constraint($_error, $(esc(lhs)), $(esc(rhs)))) +end +struct CustomSet <: MOI.AbstractScalarSet +end +function JuMP.build_constraint(_error::Function, func, ::CustomType) + JuMP.build_constraint(_error, func, CustomSet()) +end +function custom_expression_test(ModelType::Type{<:JuMP.AbstractModel}) + @testset "Custom expression" begin + model = ModelType() + @variable(model, x) + @constraint(model, con_ref, x := CustomType()) + con = JuMP.constraint_object(con_ref) + @test jump_function(con) == x + @test moi_set(con) isa CustomSet + end +end + function macros_test(ModelType::Type{<:JuMP.AbstractModel}, VariableRefType::Type{<:JuMP.AbstractVariableRef}) @testset "build_constraint on variable" begin m = ModelType() @@ -336,6 +357,8 @@ function macros_test(ModelType::Type{<:JuMP.AbstractModel}, VariableRefType::Typ end build_constraint_keyword_test(ModelType) + + custom_expression_test(ModelType) end @testset "Macros for JuMP.Model" begin