Skip to content

Commit

Permalink
#1477 update test_sensitivities to use a dae
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed May 24, 2021
1 parent 7d97c45 commit 7c39f3f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 17 deletions.
38 changes: 22 additions & 16 deletions tests/unit/test_solvers/test_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,37 +324,43 @@ def test_extrapolation_warnings(self):

def test_sensitivities(self):

def exact_diff_a(v, a, b):
return np.array([v**2 + 2 * a])
def exact_diff_a(y, a, b):
return np.array([
[y[0]**2 + 2 * a],
[y[0]]
])

def exact_diff_b(v, a, b):
return np.array([v])
def exact_diff_b(y, a, b):
return np.array([[y[0]], [0]])

for f in ['', 'python', 'casadi', 'jax']:
for convert_to_format in ['', 'python', 'casadi', 'jax']:
model = pybamm.BaseModel()
v = pybamm.Variable("v")
u = pybamm.Variable("u")
a = pybamm.InputParameter("a")
b = pybamm.InputParameter("b")
model.rhs = {v: a * v**2 + b * v + a**2}
model.initial_conditions = {v: 1}
model.convert_to_format = f
solver = pybamm.ScipySolver()
model.algebraic = {u: a * v - u}
model.initial_conditions = {v: 1, u: a * 1}
model.convert_to_format = convert_to_format
solver = pybamm.CasadiSolver()
solver.set_up(model, calculate_sensitivites=True,
inputs={'a': 0, 'b': 0})
all_inputs = []
for v_value in [0.1, -0.2, 1.5, 8.4]:
for a_value in [0.12, 1.5]:
for b_value in [0.82, 1.9]:
y = np.array([v_value])
t = 0
inputs = {'a': a_value, 'b': b_value}
all_inputs.append((t, y, inputs))
for u_value in [0.13, -0.23, 1.3, 13.4]:
for a_value in [0.12, 1.5]:
for b_value in [0.82, 1.9]:
y = np.array([v_value, u_value])
t = 0
inputs = {'a': a_value, 'b': b_value}
all_inputs.append((t, y, inputs))
for t, y, inputs in all_inputs:
if f == 'casadi':
if model.convert_to_format == 'casadi':
use_inputs = casadi.vertcat(*[x for x in inputs.values()])
else:
use_inputs = inputs
if f == 'jax':
if model.convert_to_format == 'jax':
sens = model.sensitivities_eval(
t, y, use_inputs
)
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/test_solvers/test_idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ def test_ida_roberts_klu_sensitivities(self):

t_eval = np.linspace(0, 3, 100)
a_value = 0.1
sol = solver.solve(model, t_eval, inputs={"a": a_value})
sol = solver.solve(
model, t_eval, inputs={"a": a_value},
calculate_sensitivities=True
)

# test that final time is time of event
# y = 0.1 t + y0 so y=0.2 when t=2
Expand Down

0 comments on commit 7c39f3f

Please sign in to comment.