diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index fd0ec805c9..ff8bf92853 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -685,7 +685,10 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None): result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals) result = result.reshape(result.shape[0], -1) - return result + if known_evals is not None: + return result, known_evals + else: + return result class EvaluatorJaxSensitivities: @@ -704,4 +707,7 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None): # execute code result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals) - return result + if known_evals is not None: + return result, known_evals + else: + return result