Skip to content

Commit

Permalink
#1477 add some tests and remove uncovered lines not neccessary
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Aug 2, 2021
1 parent 4c7bbe5 commit 03528da
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 34 deletions.
12 changes: 2 additions & 10 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,11 +685,7 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)
result = result.reshape(result.shape[0], -1)

# don't need known_evals, but need to reproduce Symbol.evaluate signature
if known_evals is not None:
return result, known_evals
else:
return result
return result


class EvaluatorJaxSensitivities:
Expand All @@ -708,8 +704,4 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
# execute code
result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)

# don't need known_evals, but need to reproduce Symbol.evaluate signature
if known_evals is not None:
return result, known_evals
else:
return result
return result
21 changes: 2 additions & 19 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,6 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
if model.calculate_sensitivities and not isinstance(self, pybamm.IDAKLUSolver):
calculate_sensitivities_explicit = True

if calculate_sensitivities_explicit and model.convert_to_format != 'casadi':
raise NotImplementedError(
"Sensitivities only supported for:\n"
" - model.convert_to_format = 'casadi'\n"
" - IDAKLUSolver (any convert_to_format)"
)

# if we are calculating sensitivities explicitly then the number of
# equations will change
if calculate_sensitivities_explicit:
Expand Down Expand Up @@ -284,12 +277,7 @@ def report(string):
report(f"Converting {name} to jax")
func = pybamm.EvaluatorJax(func)
jacp = None
if calculate_sensitivities_explicit:
raise NotImplementedError(
"explicit sensitivity equations not supported for "
"convert_to_format='jax'"
)
elif model.calculate_sensitivities:
if model.calculate_sensitivities:
report((
f"Calculating sensitivities for {name} with respect "
f"to parameters {model.calculate_sensitivities} using jax"
Expand All @@ -308,12 +296,7 @@ def report(string):
elif model.convert_to_format != "casadi":
# Process with pybamm functions, optionally converting
# to python evaluator
if calculate_sensitivities_explicit:
raise NotImplementedError(
"explicit sensitivity equations not supported for "
"convert_to_format='{}'".format(model.convert_to_format)
)
elif model.calculate_sensitivities:
if model.calculate_sensitivities:
report((
f"Calculating sensitivities for {name} with respect "
f"to parameters {model.calculate_sensitivities}"
Expand Down
6 changes: 1 addition & 5 deletions pybamm/solvers/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(
elif isinstance(sensitivities, dict):
self._sensitivities = sensitivities
else:
raise RuntimeError('sensitivities arg needs to be a bool or dict')
raise TypeError('sensitivities arg needs to be a bool or dict')

self._t_event = t_event
self._y_event = y_event
Expand Down Expand Up @@ -304,10 +304,6 @@ def all_ts(self):
def all_ys(self):
return self._all_ys

@property
def all_ys_and_sens(self):
return self._all_ys_and_sens

@property
def all_models(self):
"""Model(s) used for solution"""
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/test_solvers/test_processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,48 @@ def test_processed_variable_0D(self):
)
np.testing.assert_array_equal(processed_var.entries, y_sol[0])

# check empty sensitivity works

def test_processed_variable_0D_no_sensitivity(self):
# without space
t = pybamm.t
y = pybamm.StateVector(slice(0, 1))
var = t * y
var.mesh = None
t_sol = np.linspace(0, 1)
y_sol = np.array([np.linspace(0, 5)])
var_casadi = to_casadi(var, y_sol)
processed_var = pybamm.ProcessedVariable(
[var],
[var_casadi],
pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
warn=False,
)

# test no inputs (i.e. no sensitivity)
self.assertDictEqual(processed_var.sensitivities, {})

# with parameter
t = pybamm.t
y = pybamm.StateVector(slice(0, 1))
a = pybamm.InputParameter('a')
var = t * y * a
var.mesh = None
t_sol = np.linspace(0, 1)
y_sol = np.array([np.linspace(0, 5)])
inputs = {'a': np.array([1.0])}
var_casadi = to_casadi(var, y_sol, inputs=inputs)
processed_var = pybamm.ProcessedVariable(
[var],
[var_casadi],
pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), inputs),
warn=False,
)

# test no sensitivity raises error
with self.assertRaisesRegex(ValueError, 'Cannot compute sensitivities'):
print(processed_var.sensitivities)

def test_processed_variable_1D(self):
t = pybamm.t
var = pybamm.Variable("var", domain=["negative electrode", "separator"])
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/test_solvers/test_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ def test_init(self):
self.assertEqual(sol.all_inputs, [{}])
self.assertIsInstance(sol.all_models[0], pybamm.BaseModel)

def test_sensitivities(self):
t = np.linspace(0, 1)
y = np.tile(t, (20, 1))
with self.assertRaises(TypeError):
pybamm.Solution(t, y, pybamm.BaseModel(), {}, sensitivities=1.0)

def test_errors(self):
bad_ts = [np.array([1, 2, 3]), np.array([3, 4, 5])]
sol = pybamm.Solution(
Expand Down

0 comments on commit 03528da

Please sign in to comment.