Skip to content

Commit

Permalink
#1477 flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jun 14, 2021
1 parent 1f3bc9d commit 47c5345
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions pybamm/expression_tree/operations/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,11 +626,10 @@ def get_sensitivities(self):
jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=3 + n)

self._sens_evaluate = jax.jit(jacobian_evaluate,
static_argnums=self._static_argnums)
static_argnums=self._static_argnums)

return EvaluatorJaxSensitivities(self._sens_evaluate, self._constants)


def debug(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
# generated code assumes y is a column vector
if y is not None and y.ndim == 1:
Expand Down Expand Up @@ -688,6 +687,7 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
else:
return result


class EvaluatorJaxSensitivities:
def __init__(self, jac_evaluate, constants):
self._jac_evaluate = jac_evaluate
Expand Down
3 changes: 3 additions & 0 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def jacp(*args, **kwargs):
[p_diff]
)
# jacp should be a function that returns a dict of sensitivities

def jacp(*args, **kwargs):
return {k: v(*args, **kwargs)
for k, v in jacp_dict.items()}
Expand Down Expand Up @@ -1327,6 +1328,7 @@ def function(self, t, y, inputs):
else:
return self._function(t, y, inputs=inputs, known_evals={})[0]


class SensitivityCallable:
"""A class that will be called by the solver when integrating"""

Expand Down Expand Up @@ -1355,6 +1357,7 @@ def function(self, t, y, inputs):
self._function(t, y, inputs=inputs, known_evals={})
return {k: v[0] for k, v in ret_with_known_evals.items()}


class Residuals(SolverCallable):
"""Returns information about residuals at time t and state y"""

Expand Down
2 changes: 0 additions & 2 deletions pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,6 @@ def sensfn(resvalS, t, y, yp, yS, ypS):
"""

np = len(resvalS)
n = resvalS[0].shape[0]
dFdy = model.jacobian_eval(t, y, inputs)
dFdyd = mass_matrix
dFdp = model.sensitivities_eval(t, y, inputs)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def exact_diff_b(y, a, b):
use_inputs = inputs

sens = model.sensitivities_eval(
t, y, use_inputs
t, y, use_inputs
)
np.testing.assert_allclose(
sens['a'],
Expand Down

0 comments on commit 47c5345

Please sign in to comment.