Skip to content

Commit

Permalink
fix: handle scalarized array parameters in initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 31, 2025
1 parent 1b7261a commit 4de5e18
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 8 deletions.
34 changes: 26 additions & 8 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ function generate_initializesystem(sys::AbstractSystem;
# If either of them are `missing` the parameter is an unknown
# But if the parameter is passed a value, use that as an additional
# equation in the system
_val1 = get(pmap, p, nothing)
_val2 = get(defs, p, nothing)
_val3 = get(guesses, p, nothing)
_val1 = get_possibly_array_fallback_singletons(pmap, p)
_val2 = get_possibly_array_fallback_singletons(defs, p)
_val3 = get_possibly_array_fallback_singletons(guesses, p)
varp = tovar(p)
paramsubs[p] = varp
# Has a default of `missing`, and (either an equation using the value passed to `ODEProblem` or a guess)
Expand All @@ -139,7 +139,7 @@ function generate_initializesystem(sys::AbstractSystem;
error("Invalid setup: parameter $(p) has no default value, initial value, or guess")
end
# given a symbolic value to ODEProblem
elseif symbolic_type(_val1) != NotSymbolic()
elseif symbolic_type(_val1) != NotSymbolic() || is_array_of_symbolics(_val1)
push!(eqs_ics, varp ~ _val1)
push!(defs, varp => _val3)
# No value passed to `ODEProblem`, but a default and a guess are present
Expand Down Expand Up @@ -268,16 +268,34 @@ struct InitializationSystemMetadata
oop_reconstruct_u0_p::Union{Nothing, ReconstructInitializeprob}
end

function get_possibly_array_fallback_singletons(varmap, p)
if haskey(varmap, p)
return varmap[p]
end
symbolic_type(p) == ArraySymbolic() || return nothing
scal = collect(p)
if all(x -> haskey(varmap, x), scal)
res = [varmap[x] for x in scal]
if any(x -> x === nothing, res)
return nothing
elseif any(x -> x === missing, res)
return missing
end
return res
end
return nothing
end

function is_parameter_solvable(p, pmap, defs, guesses)
p = unwrap(p)
is_variable_floatingpoint(p) || return false
_val1 = pmap isa AbstractDict ? get(pmap, p, nothing) : nothing
_val2 = get(defs, p, nothing)
_val3 = get(guesses, p, nothing)
_val1 = pmap isa AbstractDict ? get_possibly_array_fallback_singletons(pmap, p) : nothing
_val2 = get_possibly_array_fallback_singletons(defs, p)
_val3 = get_possibly_array_fallback_singletons(guesses, p)
# either (missing is a default or was passed to the ODEProblem) or (nothing was passed to
# the ODEProblem and it has a default and a guess)
return ((_val1 === missing || _val2 === missing) ||
(symbolic_type(_val1) != NotSymbolic() ||
(symbolic_type(_val1) != NotSymbolic() || is_array_of_symbolics(_val1) ||
_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
end

Expand Down
11 changes: 11 additions & 0 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1281,3 +1281,14 @@ end
@test sol[S, 1] 999
@test SciMLBase.successful_retcode(sol)
end

@testset "Solvable array parameters with scalarized guesses" begin
@variables x(t)
@parameters p[1:2] q
@mtkbuild sys = ODESystem(D(x) ~ p[1] + p[2] + q, t; defaults = [p[1] => q, p[2] => 2q], guesses = [p[1] => q, p[2] => 2q])
@test ModelingToolkit.is_parameter_solvable(p, Dict(), defaults(sys), guesses(sys))
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [q => 2.0])
@test length(ModelingToolkit.observed(prob.f.initialization_data.initializeprob.f.sys)) == 3
sol = solve(prob, Tsit5())
@test sol.ps[p] [2.0, 4.0]
end

0 comments on commit 4de5e18

Please sign in to comment.