Skip to content

Commit

Permalink
#1477 generalising 'explicit forward' option so any solver can use it
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jul 5, 2021
1 parent 0cbe0a5 commit 5214994
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 27 deletions.
35 changes: 20 additions & 15 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import sys
import itertools
from scipy.linalg import block_diag
from scipy.sparse import block_diag
import multiprocessing as mp
import warnings

Expand Down Expand Up @@ -241,6 +241,8 @@ def set_up(self, model, inputs=None, t_eval=None,
# save sensitivity parameters so we can identify them later on
# (FYI: this is used in the Solution class)
model.calculate_sensitivities = calculate_sensitivites
model.len_rhs_sens = model.len_rhs * len(calculate_sensitivites)
model.len_alg_sens = model.len_alg * len(calculate_sensitivites)

# Only allow solving explicit sensitivity equations with the casadi format for now
if (
Expand Down Expand Up @@ -277,8 +279,6 @@ def set_up(self, model, inputs=None, t_eval=None,
pS_casadi_stacked = casadi.vertcat(
*[p_casadi[name] for name in calculate_sensitivites]
)
model.len_rhs_sens = model.len_rhs * pS_casadi_stacked.shape[0]
model.len_alg_sens = model.len_alg * pS_casadi_stacked.shape[0]
S_x = casadi.MX.sym("S_x", model.len_rhs_sens)
S_z = casadi.MX.sym("S_z", model.len_alg_sens)
y_and_S = casadi.vertcat(y_diff, S_x, y_alg, S_z)
Expand Down Expand Up @@ -615,6 +615,21 @@ def jacp(*args, **kwargs):
interpolant_extrapolation_events_eval
)

# if we have changed the equations to include the explicit sensitivity
# equations, then we also need to update the mass matrix
if self.sensitivity == "explicit forward":
n_inputs = len(calculate_sensitivites)
model.mass_matrix_inv = pybamm.Matrix(
block_diag(
[model.mass_matrix_inv.entries] * (n_inputs + 1), format="csr"
)
)
model.mass_matrix = pybamm.Matrix(
block_diag(
[model.mass_matrix.entries] * (n_inputs + 1), format="csr"
)
)

# Save CasADi functions for the CasADi solver
# Note: when we pass to casadi the ode part of the problem must be in explicit
# form so we pre-multiply by the inverse of the mass matrix
Expand All @@ -623,16 +638,7 @@ def jacp(*args, **kwargs):
):
# can use DAE solver to solve model with algebraic equations only
if len(model.rhs) > 0:
if self.sensitivity == "explicit forward":
# Copy mass matrix blocks diagonally
single_mass_matrix_inv = model.mass_matrix_inv.entries.toarray()
n_inputs = p_casadi_stacked.shape[0]
block_mass_matrix = block_diag(
*[single_mass_matrix_inv] * (n_inputs + 1)
)
mass_matrix_inv = casadi.MX(block_mass_matrix)
else:
mass_matrix_inv = casadi.MX(model.mass_matrix_inv.entries)
mass_matrix_inv = casadi.MX(model.mass_matrix_inv.entries)
explicit_rhs = mass_matrix_inv @ rhs(
t_casadi, y_and_S, p_casadi_stacked
)
Expand Down Expand Up @@ -754,8 +760,7 @@ def calculate_consistent_state(self, model, time=0, inputs=None):
)
pybamm.logger.debug("Found consistent states")

# use all_ys_and_sens in case we are solving the full sensitivity equations
y0 = root_sol.all_ys_and_sens[0]
y0 = root_sol.all_ys[0]
if isinstance(y0, np.ndarray):
y0 = y0.flatten()
return y0
Expand Down
30 changes: 25 additions & 5 deletions pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ def _integrate(self, model, t_eval, inputs_dict=None):
inputs_dict : dict, optional
Any external variables or input parameters to pass to the model when solving
"""


# are we solving explicit forward equations?
explicit_sensitivities = self.sensitivity == 'explicit forward'

# Record whether there are any symbolic inputs
inputs_dict = inputs_dict or {}

Expand Down Expand Up @@ -158,14 +163,15 @@ def _integrate(self, model, t_eval, inputs_dict=None):
# Create integrator without grid to avoid having to create several times
self.create_integrator(model, inputs)
solution = self._run_integrator(
model, model.y0, inputs_dict, inputs, t_eval, use_grid=False
model, model.y0, inputs_dict, inputs, t_eval, use_grid=False,
)
if self.sensitivity == "casadi" and inputs_dict != {}:
# If the solution has already been created, we can reuse it
if model in self.y_sols:
y_sol = self.y_sols[model]
solution = pybamm.Solution(
t_eval, y_sol, model=model, inputs=inputs_dict
t_eval, y_sol, model=model, inputs=inputs_dict,
sensitivities=explicit_sensitivities
)
else:
# Create integrator without grid, which will be called repeatedly
Expand Down Expand Up @@ -212,7 +218,10 @@ def _integrate(self, model, t_eval, inputs_dict=None):
# to avoid having to create several times
self.create_integrator(model, inputs_dict)
# Initialize solution
solution = pybamm.Solution(np.array([t]), y0, model, inputs_dict)
solution = pybamm.Solution(
np.array([t]), y0, model, inputs_dict,
sensitivities=explicit_sensitivities
)
solution.solve_time = 0
solution.integration_time = 0
use_grid = False
Expand Down Expand Up @@ -455,6 +464,7 @@ def integer_bisect():
np.array([t_event]),
y_event[:, np.newaxis],
"event",
sensitivities=explicit_sensitivities
)
solution.integration_time = (
coarse_solution.integration_time + dense_step_sol.integration_time
Expand Down Expand Up @@ -613,6 +623,10 @@ def create_integrator(self, model, inputs, t_eval=None, use_event_switch=False):

def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True):
pybamm.logger.debug("Running CasADi integrator")

# are we solving explicit forward equations?
explicit_sensitivities = self.sensitivity == 'explicit forward'

if use_grid is True:
t_eval_shifted = t_eval - t_eval[0]
t_eval_shifted_rounded = np.round(t_eval_shifted, decimals=12).tobytes()
Expand Down Expand Up @@ -649,7 +663,10 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
)
integration_time = timer.time()
y_sol = casadi.vertcat(casadi_sol["xf"], casadi_sol["zf"])
sol = pybamm.Solution(t_eval, y_sol, model, inputs_dict)
sol = pybamm.Solution(
t_eval, y_sol, model, inputs_dict,
sensitivities=explicit_sensitivities
)
sol.integration_time = integration_time
return sol
else:
Expand Down Expand Up @@ -682,7 +699,10 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
# Save the solution, can just reuse and change the inputs
self.y_sols[model] = y_sol

sol = pybamm.Solution(t_eval, y_sol, model, inputs_dict)
sol = pybamm.Solution(
t_eval, y_sol, model, inputs_dict,
sensitivities=explicit_sensitivities
)
sol.integration_time = integration_time
return sol
except RuntimeError as e:
Expand Down
19 changes: 12 additions & 7 deletions pybamm/solvers/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ class Solution(object):
the event happens.
termination : str
String to indicate why the solution terminated
sensitivities: None or dict
Will be None if there are no sensitivities in this soluion. Otherwise, this is a
dict of parameter names to their calcululated sensitivities
sensitivities: bool or dict
True if sensitivities included as the solution of the explicit forwards
equations. False if no sensitivities included/wanted. Dict if sensitivities are
provided as a dict of {parameter: sensitivities} pairs.
"""

Expand All @@ -56,7 +58,7 @@ def __init__(
t_event=None,
y_event=None,
termination="final time",
sensitivities=None
sensitivities=False
):
if not isinstance(all_ts, list):
all_ts = [all_ts]
Expand All @@ -80,11 +82,12 @@ def __init__(
self.all_inputs = all_inputs

# sensitivities
if sensitivities is None:
if isinstance(sensitivities, bool):
self._sensitivities = {}
# if solution consists of explicit sensitivity equations, extract them
if (
all_models[0] is not None
sensitivities == True
and all_models[0] is not None
and not isinstance(all_ys[0], casadi.Function)
and all_models[0].len_rhs_and_alg != all_ys[0].shape[0]
and all_models[0].len_rhs_and_alg != 0 # for the dummy solver
Expand All @@ -95,8 +98,10 @@ def __init__(
self._extract_explicit_sensitivities(
all_models[0], all_ys[0], all_ts[0], self.all_inputs[0]
)
else:
elif isinstance(sensitivities, dict):
self._sensitivities = sensitivities
else:
raise RuntimeError('sensitivities arg needs to be a bool or dict')

self._t_event = t_event
self._y_event = y_event
Expand Down

0 comments on commit 5214994

Please sign in to comment.