Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add late_binding_update_u0_p #926

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,34 @@ function updated_u0_p(
return (u0 === missing ? state_values(prob) : u0),
(p === missing ? parameter_values(prob) : p)
end
return _updated_u0_p_internal(prob, u0, p, t0; interpret_symbolicmap, use_defaults)
newu0, newp = _updated_u0_p_internal(prob, u0, p, t0; interpret_symbolicmap, use_defaults)
return late_binding_update_u0_p(prob, u0, p, t0, newu0, newp)
end

"""
$(TYPEDSIGNATURES)

A function to perform custom modifications to `newu0` and/or `newp` after they have been
constructed in `remake`. `root_indp` is the innermost index provider found by recursively
calling `SymbolicIndexingInterface.symbolic_container`, provided for dispatch. Returns
the updated `newu0` and `newp`.
"""
function late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)
return newu0, newp
end

"""
$(TYPEDSIGNATURES)

Calls `late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)` after finding
`root_indp`.
"""
function late_binding_update_u0_p(prob, u0, p, t0, newu0, newp)
root_indp = prob
while hasmethod(symbolic_container, Tuple{typeof(root_indp)}) && (sc = symbolic_container(root_indp)) !== root_indp
root_indp = sc
end
return late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)
end

# overloaded in MTK to intercept symbolic remake
Expand Down
12 changes: 12 additions & 0 deletions test/remake_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ for T in containerTypes
push!(probs, NonlinearLeastSquaresProblem(fn, u0, T(p)))
end

# temporary definition to test this functionality
function SciMLBase.late_binding_update_u0_p(prob, u0, p::SciMLBase.NullParameters, t0, newu0, newp)
return newu0, ones(3)
end

for prob in deepcopy(probs)
prob2 = @inferred remake(prob)
@test prob2.u0 == u0
Expand Down Expand Up @@ -274,8 +279,15 @@ for prob in deepcopy(probs)
end
ForwardDiff.derivative(fakeloss!, 1.0)
end

# test late_binding_update_u0_p
prob2 = remake(prob; p = SciMLBase.NullParameters())
@test prob2.p ≈ ones(3)
end

# delete the method defined here to prevent breaking other tests
Base.delete_method(only(methods(SciMLBase.late_binding_update_u0_p, @__MODULE__)))

# eltype(()) <: Pair, so ensure that this doesn't error
function lorenz!(du, u, _, t)
du[1] = 1 * (u[2] - u[1])
Expand Down
Loading