Skip to content

Commit

Permalink
#1477 fix some remaining bugs with algebraic solver bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Aug 5, 2021
1 parent f8bc091 commit 6ca02be
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 9 deletions.
14 changes: 10 additions & 4 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,10 +613,11 @@ def jacp(*args, **kwargs):
n_inputs = model.len_rhs_sens // model.len_rhs
elif model.len_alg != 0:
n_inputs = model.len_alg_sens // model.len_alg
model.bounds = (
np.repeat(model.bounds[0], n_inputs + 1),
np.repeat(model.bounds[1], n_inputs + 1),
)
if model.bounds[0].shape[0] < model.len_alg + model.len_alg_sens:
model.bounds = (
np.repeat(model.bounds[0], n_inputs + 1),
np.repeat(model.bounds[1], n_inputs + 1),
)
if (model.mass_matrix is not None
and model.mass_matrix.shape[0] == model.len_rhs_and_alg):

Expand All @@ -634,6 +635,11 @@ def jacp(*args, **kwargs):
)
else:
# take care if calculate_sensitivites used then not used
if model.bounds[0].shape[0] > model.len_alg:
model.bounds = (
model.bounds[0][:model.len_alg],
model.bounds[1][:model.len_alg],
)
if (model.mass_matrix is not None and
model.mass_matrix.shape[0] > model.len_rhs_and_alg):
if model.mass_matrix_inv is not None:
Expand Down
6 changes: 5 additions & 1 deletion pybamm/solvers/casadi_algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,11 @@ def _integrate(self, model, t_eval, inputs_dict=None):

# Return solution object (no events, so pass None to t_event, y_event)

explicit_sensitivities = bool(model.calculate_sensitivities)
try:
explicit_sensitivities = bool(model.calculate_sensitivities)
except AttributeError:
explicit_sensitivities = False

sol = pybamm.Solution(
[t_eval], y_sol, model, inputs_dict, termination="success",
sensitivities=explicit_sensitivities
Expand Down
4 changes: 4 additions & 0 deletions pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,10 @@ def _integrate(self, model, t_eval, inputs_dict=None):
# update y0
y0 = solution.all_ys[-1][:, -1]

# now we extract sensitivities from the solution
if (bool(model.calculate_sensitivities)):
solution.sensitivities = True

return solution

def _solve_for_event(self, coarse_solution, init_event_signs):
Expand Down
14 changes: 10 additions & 4 deletions pybamm/solvers/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,8 @@ def __init__(
else:
self.all_inputs = all_inputs

# sensitivities must be a dict or bool
if not isinstance(sensitivities, (bool, dict)):
raise TypeError('sensitivities arg needs to be a bool or dict')
self._sensitivities = sensitivities

self.sensitivities = sensitivities

self._t_event = t_event
self._y_event = y_event
Expand Down Expand Up @@ -285,6 +283,14 @@ def sensitivities(self):
self._sensitivities = {}
return self._sensitivities

@sensitivities.setter
def sensitivities(self, value):
"""Updates the sensitivity"""
# sensitivities must be a dict or bool
if not isinstance(value, (bool, dict)):
raise TypeError('sensitivities arg needs to be a bool or dict')
self._sensitivities = value

def set_y(self):
try:
if isinstance(self.all_ys[0], (casadi.DM, casadi.MX)):
Expand Down

0 comments on commit 6ca02be

Please sign in to comment.