diff --git a/src/remake.jl b/src/remake.jl index 4a0f1726b..0741bd3bd 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -1075,7 +1075,34 @@ function updated_u0_p( return (u0 === missing ? state_values(prob) : u0), (p === missing ? parameter_values(prob) : p) end - return _updated_u0_p_internal(prob, u0, p, t0; interpret_symbolicmap, use_defaults) + newu0, newp = _updated_u0_p_internal(prob, u0, p, t0; interpret_symbolicmap, use_defaults) + return late_binding_update_u0_p(prob, u0, p, t0, newu0, newp) +end + +""" + $(TYPEDSIGNATURES) + +A function to perform custom modifications to `newu0` and/or `newp` after they have been +constructed in `remake`. `root_indp` is the innermost index provider found by recursively +calling `SymbolicIndexingInterface.symbolic_container`, provided for dispatch. Returns +the updated `newu0` and `newp`. +""" +function late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp) + return newu0, newp +end + +""" + $(TYPEDSIGNATURES) + +Calls `late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)` after finding +`root_indp`. +""" +function late_binding_update_u0_p(prob, u0, p, t0, newu0, newp) + root_indp = prob + while hasmethod(symbolic_container, Tuple{typeof(root_indp)}) && (sc = symbolic_container(root_indp)) !== root_indp + root_indp = sc + end + return late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp) end # overloaded in MTK to intercept symbolic remake diff --git a/test/remake_tests.jl b/test/remake_tests.jl index 4af29623e..7b89d3460 100644 --- a/test/remake_tests.jl +++ b/test/remake_tests.jl @@ -83,6 +83,11 @@ for T in containerTypes push!(probs, NonlinearLeastSquaresProblem(fn, u0, T(p))) end +# temporary definition to test this functionality +function SciMLBase.late_binding_update_u0_p(prob, u0, p::SciMLBase.NullParameters, t0, newu0, newp) + return newu0, ones(3) +end + for prob in deepcopy(probs) prob2 = @inferred remake(prob) @test prob2.u0 == u0 @@ -274,8 +279,15 @@ for prob in deepcopy(probs) end ForwardDiff.derivative(fakeloss!, 1.0) end + + # test late_binding_update_u0_p + prob2 = remake(prob; p = SciMLBase.NullParameters()) + @test prob2.p ≈ ones(3) end +# delete the method defined here to prevent breaking other tests +Base.delete_method(only(methods(SciMLBase.late_binding_update_u0_p, @__MODULE__))) + # eltype(()) <: Pair, so ensure that this doesn't error function lorenz!(du, u, _, t) du[1] = 1 * (u[2] - u[1])