Skip to content

Commit

Permalink
#1477 fix algebraic solver
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jul 16, 2021
1 parent 72560c5 commit d9ff546
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 15 deletions.
1 change: 0 additions & 1 deletion pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,6 @@ def jacp(*args, **kwargs):
# Add sensitivity vectors to the rhs and algebraic equations
jacp = None
if calculate_sensitivites_explicit:
print('CASADI EXPLICIT', name, model.len_rhs)
# The formulation is as per Park, S., Kato, D., Gima, Z., Klein, R.,
# & Moura, S. (2018). Optimal experimental design for
# parameterization of an electrochemical lithium-ion battery model.
Expand Down
27 changes: 20 additions & 7 deletions pybamm/solvers/casadi_algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def _integrate(self, model, t_eval, inputs_dict=None):
"""
# Record whether there are any symbolic inputs
inputs_dict = inputs_dict or {}
has_symbolic_inputs = any(
isinstance(v, casadi.MX) for v in inputs_dict.values()
)
symbolic_inputs = casadi.vertcat(
*[v for v in inputs_dict.values() if isinstance(v, casadi.MX)]
)
Expand All @@ -70,22 +73,29 @@ def _integrate(self, model, t_eval, inputs_dict=None):

y0 = model.y0

# If y0 already satisfies the tolerance for all t then keep it
if has_symbolic_inputs is False and all(
np.all(abs(model.casadi_algebraic(t, y0, inputs).full()) < self.tol)
for t in t_eval
):
pybamm.logger.debug("Keeping same solution at all times")
return pybamm.Solution(
t_eval, y0, model, inputs_dict, termination="success"
)

# The casadi algebraic solver can read rhs equations, but leaves them unchanged
# i.e. the part of the solution vector that corresponds to the differential
# equations will be equal to the initial condition provided. This allows this
# solver to be used for initialising the DAE solvers
if model.rhs == {}:
print('no rhs')
len_rhs = 0
y0_diff = casadi.DM()
y0_alg = y0
else:
# Check y0 to see if it includes sensitivities
if model.len_rhs_and_alg == y0.shape[0]:
print('doesnt include sens')
len_rhs = model.len_rhs
else:
print('includes sens', inputs.shape[0])
len_rhs = model.len_rhs * (inputs.shape[0] + 1)
y0_diff = y0[:len_rhs]
y0_alg = y0[len_rhs:]
Expand Down Expand Up @@ -159,7 +169,8 @@ def _integrate(self, model, t_eval, inputs_dict=None):
for idx, t in enumerate(t_eval):
# Evaluate algebraic with new t and previous y0, if it's already close
# enough then keep it
if np.all(
# We can't do this if there are symbolic inputs
if has_symbolic_inputs is False and np.all(
abs(model.casadi_algebraic(t, y0, inputs).full()) < self.tol
):
pybamm.logger.debug(
Expand All @@ -171,7 +182,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
y_alg = casadi.horzcat(y_alg, y0_alg)
# Otherwise calculate new y_sol
else:
t_y0_diff_inputs = casadi.vertcat(t, y0_diff, inputs)
t_y0_diff_inputs = casadi.vertcat(t, y0_diff, symbolic_inputs)
# Solve
try:
timer.reset()
Expand All @@ -187,9 +198,11 @@ def _integrate(self, model, t_eval, inputs_dict=None):
message = err.args[0]
fun = None

# check the function is below the tol
# If there are no symbolic inputs, check the function is below the tol
# Skip this check if there are symbolic inputs
if success and (
not any(np.isnan(fun)) and np.all(casadi.fabs(fun) < self.tol)
has_symbolic_inputs is True
or (not any(np.isnan(fun)) and np.all(casadi.fabs(fun) < self.tol))
):
# update initial guess for the next iteration
y0_alg = y_alg_sol
Expand Down
7 changes: 1 addition & 6 deletions pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def integer_bisect():
np.array([t_event]),
y_event[:, np.newaxis],
"event",
sensitivities=explicit_sensitivities
sensitivities=bool(self.calculate_sensitivites)
)
solution.integration_time = (
coarse_solution.integration_time + dense_step_sol.integration_time
Expand Down Expand Up @@ -665,11 +665,6 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
y_sol = y_diff
else:
y_sol = casadi.vertcat(y_diff, y_alg)
# If doing sensitivity, return the solution as a function of the inputs
if self.sensitivity == "casadi":
y_sol = casadi.Function("y_sol", [symbolic_inputs], [y_sol])
# Save the solution, can just reuse and change the inputs
self.y_sols[model] = y_sol

sol = pybamm.Solution(
t_eval, y_sol, model, inputs_dict,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def test_solve_sensitivity_scalar_var_scalar_input(self):
t_eval = np.linspace(0, 1, 80)
solution = solver.solve(
model, t_eval, inputs={"p": 0.1, "q": 2, "r": -1, "s": 0.5},
sensitivity=True,
calculate_sensitivities=True,
)
np.testing.assert_allclose(solution.y[0], -1 + 0.2 * solution.t)
np.testing.assert_allclose(
Expand Down

0 comments on commit d9ff546

Please sign in to comment.