Skip to content

Commit

Permalink
Merge pull request #3373 from SciML/dw/odefunctionexpr_specialize
Browse files Browse the repository at this point in the history
Support specialization in `ODEFunctionExpr`
  • Loading branch information
ChrisRackauckas authored Feb 5, 2025
2 parents 387df59 + 58dde09 commit f5ea344
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 22 deletions.
48 changes: 26 additions & 22 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ Create a Julia expression for an `ODEFunction` from the [`ODESystem`](@ref).
The arguments `dvs` and `ps` are used to set the order of the dependent
variable and parameter vectors, respectively.
"""
struct ODEFunctionExpr{iip} end
struct ODEFunctionExpr{iip, specialize} end

struct ODEFunctionClosure{O, I} <: Function
f_oop::O
Expand All @@ -551,7 +551,7 @@ end
(f::ODEFunctionClosure)(u, p, t) = f.f_oop(u, p, t)
(f::ODEFunctionClosure)(du, u, p, t) = f.f_iip(du, u, p, t)

function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
function ODEFunctionExpr{iip, specialize}(sys::AbstractODESystem, dvs = unknowns(sys),
ps = parameters(sys), u0 = nothing;
version = nothing, tgrad = false,
jac = false, p = nothing,
Expand All @@ -560,14 +560,12 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
steady_state = false,
sparsity = false,
observedfun_exp = nothing,
kwargs...) where {iip}
kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunctionExpr`")
end
f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...)

dict = Dict()

fsym = gensym(:f)
_f = :($fsym = $ODEFunctionClosure($f_oop, $f_iip))
tgradsym = gensym(:tgrad)
Expand All @@ -590,30 +588,28 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
_jac = :($jacsym = nothing)
end

Msym = gensym(:M)
M = calculate_massmatrix(sys)

_M = if sparse && !(u0 === nothing || M === I)
SparseArrays.sparse(M)
if sparse && !(u0 === nothing || M === I)
_M = :($Msym = $(SparseArrays.sparse(M)))
elseif u0 === nothing || M === I
M
_M = :($Msym = $M)
else
ArrayInterface.restructure(u0 .* u0', M)
_M = :($Msym = $(ArrayInterface.restructure(u0 .* u0', M)))
end

jp_expr = sparse ? :($similar($(get_jac(sys)[]), Float64)) : :nothing
ex = quote
$_f
$_tgrad
$_jac
M = $_M
ODEFunction{$iip}($fsym,
sys = $sys,
jac = $jacsym,
tgrad = $tgradsym,
mass_matrix = M,
jac_prototype = $jp_expr,
sparsity = $(sparsity ? jacobian_sparsity(sys) : nothing),
observed = $observedfun_exp)
let $_f, $_tgrad, $_jac, $_M
ODEFunction{$iip, $specialize}($fsym,
sys = $sys,
jac = $jacsym,
tgrad = $tgradsym,
mass_matrix = $Msym,
jac_prototype = $jp_expr,
sparsity = $(sparsity ? jacobian_sparsity(sys) : nothing),
observed = $observedfun_exp)
end
end
!linenumbers ? Base.remove_linenums!(ex) : ex
end
Expand All @@ -622,6 +618,14 @@ function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
ODEFunctionExpr{true}(sys, args...; kwargs...)
end

function ODEFunctionExpr{true}(sys::AbstractODESystem, args...; kwargs...)
return ODEFunctionExpr{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
end

function ODEFunctionExpr{false}(sys::AbstractODESystem, args...; kwargs...)
return ODEFunctionExpr{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

"""
```julia
DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
Expand Down
19 changes: 19 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,25 @@ f.f(du, u, p, 0.1)
@test du == [4, 0, -16]
@test_throws ArgumentError f.f(u, p, 0.1)

#check iip
f = eval(ODEFunctionExpr(de, [x, y, z], [σ, ρ, β]))
f2 = ODEFunction(de, [x, y, z], [σ, ρ, β])
@test SciMLBase.isinplace(f) === SciMLBase.isinplace(f2)
@test SciMLBase.specialization(f) === SciMLBase.specialization(f2)
for iip in (true, false)
f = eval(ODEFunctionExpr{iip}(de, [x, y, z], [σ, ρ, β]))
f2 = ODEFunction{iip}(de, [x, y, z], [σ, ρ, β])
@test SciMLBase.isinplace(f) === SciMLBase.isinplace(f2) === iip
@test SciMLBase.specialization(f) === SciMLBase.specialization(f2)

for specialize in (SciMLBase.AutoSpecialize, SciMLBase.FullSpecialize)
f = eval(ODEFunctionExpr{iip, specialize}(de, [x, y, z], [σ, ρ, β]))
f2 = ODEFunction{iip, specialize}(de, [x, y, z], [σ, ρ, β])
@test SciMLBase.isinplace(f) === SciMLBase.isinplace(f2) === iip
@test SciMLBase.specialization(f) === SciMLBase.specialization(f2) === specialize
end
end

#check sparsity
f = eval(ODEFunctionExpr(de, [x, y, z], [σ, ρ, β], sparsity = true))
@test f.sparsity == ModelingToolkit.jacobian_sparsity(de)
Expand Down

0 comments on commit f5ea344

Please sign in to comment.