Skip to content

Commit

Permalink
#1477 do sensitivity integration tests using a processed variable
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Aug 4, 2021
1 parent bf59b7d commit 88ecb3f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 15 deletions.
23 changes: 12 additions & 11 deletions pybamm/solvers/processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,17 +537,18 @@ def initialise_sensitivity_explicit_forward(self):
dvar_dp_func = casadi.Function(
"dvar_dp", [t_casadi, y_casadi, p_casadi_stacked], [dvar_dp]
)
for idx in range(len(self.all_ts[0])):
t = self.all_ts[0][idx]
u = self.all_ys[0][:, idx]
next_dvar_dy_eval = dvar_dy_func(t, u, inputs_stacked)
next_dvar_dp_eval = dvar_dp_func(t, u, inputs_stacked)
if idx == 0:
dvar_dy_eval = next_dvar_dy_eval
dvar_dp_eval = next_dvar_dp_eval
else:
dvar_dy_eval = casadi.diagcat(dvar_dy_eval, next_dvar_dy_eval)
dvar_dp_eval = casadi.vertcat(dvar_dp_eval, next_dvar_dp_eval)
for index, (ts, ys) in enumerate(zip(self.all_ts, self.all_ys)):
for idx in range(len(ts)):
t = ts[idx]
u = ys[:, idx]
next_dvar_dy_eval = dvar_dy_func(t, u, inputs_stacked)
next_dvar_dp_eval = dvar_dp_func(t, u, inputs_stacked)
if index == 0 and idx == 0:
dvar_dy_eval = next_dvar_dy_eval
dvar_dp_eval = next_dvar_dp_eval
else:
dvar_dy_eval = casadi.diagcat(dvar_dy_eval, next_dvar_dy_eval)
dvar_dp_eval = casadi.vertcat(dvar_dp_eval, next_dvar_dp_eval)

# Compute sensitivity
dy_dp = self.solution_sensitivities["all"]
Expand Down
12 changes: 8 additions & 4 deletions tests/integration/test_models/standard_model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def test_outputs(self):
)
std_out_test.test_all()

def test_sensitivities(self, param_name, param_value):
def test_sensitivities(self, param_name, param_value,
output_name='Terminal voltage [V]'):
self.parameter_values.update({param_name: "[input]"})
inputs = {param_name: param_value}

Expand All @@ -113,6 +114,7 @@ def test_sensitivities(self, param_name, param_value):
self.model, t_eval, inputs=inputs,
calculate_sensitivities=True
)
output_sens = self.solution[output_name].sensitivities[param_name]

# check via finite differencing
h = 1e-6 * param_value
Expand All @@ -121,14 +123,16 @@ def test_sensitivities(self, param_name, param_value):
sol_plus = self.solver.solve(
self.model, t_eval, inputs=inputs_plus,
)
output_plus = sol_plus[output_name](t=t_eval)
sol_neg = self.solver.solve(
self.model, t_eval, inputs=inputs_neg
)
fd = ((np.array(sol_plus.y) - np.array(sol_neg.y)) / h)
output_neg = sol_neg[output_name](t=t_eval)
fd = ((np.array(output_plus) - np.array(output_neg)) / h)
fd = fd.transpose().reshape(-1, 1)
np.testing.assert_allclose(
self.solution.sensitivities[param_name], fd,
rtol=1e-1, atol=1e-5,
output_sens, fd,
rtol=1e-2, atol=1e-6,
)

def test_all(
Expand Down

0 comments on commit 88ecb3f

Please sign in to comment.