From fc6ec22c438a5803a565f48c7450e7cf1634a630 Mon Sep 17 00:00:00 2001 From: tomtranter Date: Thu, 23 Jan 2020 18:23:32 +0000 Subject: [PATCH 1/2] Add test for solution evaluation with Electrode height input --- tests/unit/test_solvers/test_solution.py | 30 ++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/unit/test_solvers/test_solution.py b/tests/unit/test_solvers/test_solution.py index 32e81cfbc9..8ca99249a6 100644 --- a/tests/unit/test_solvers/test_solution.py +++ b/tests/unit/test_solvers/test_solution.py @@ -95,6 +95,36 @@ def test_save(self): self.assertEqual(solution.model.name, solution_load.model.name) np.testing.assert_array_equal(solution["c"].entries, solution_load["c"].entries) + def test_solution_evals_with_inputs(self): + model = pybamm.lithium_ion.SPM() + geometry = model.default_geometry + param = model.default_parameter_values + param.update( + { + "Electrode height [m]": "[input]", + } + ) + param.process_model(model) + param.process_geometry(geometry) + var = pybamm.standard_spatial_vars + var_pts = {var.x_n: 5, var.x_s: 5, var.x_p: 5, var.r_n: 10, var.r_p: 10} + spatial_methods = model.default_spatial_methods + solver = model.default_solver + sim = pybamm.Simulation( + model=model, + geometry=geometry, + parameter_values=param, + var_pts=var_pts, + spatial_methods=spatial_methods, + solver=solver, + ) + inputs = { + 'Electrode height [m]': 0.1 + } + sim.solve(t_eval=np.linspace(0, 0.01, 10), inputs=inputs) + time = sim.solution['Time [h]'](sim.solution.t) + self.assertEqual(len(time), 10) + if __name__ == "__main__": print("Add -v for more debug output") From 30661f48fca014b3a6c0f9af5aae0ca8c6ab54f9 Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Thu, 23 Jan 2020 13:35:36 -0500 Subject: [PATCH 2/2] #793 fix test --- pybamm/expression_tree/functions.py | 2 +- pybamm/processed_variable.py | 24 ++++++++++++++---------- tests/unit/test_solvers/test_solution.py | 12 +++--------- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index be98461a73..5be8e15a1a 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -159,7 +159,7 @@ def evaluate(self, t=None, y=None, u=None, known_evals=None): evaluated_children = [None] * len(self.children) for i, child in enumerate(self.children): evaluated_children[i], known_evals = child.evaluate( - t, y, known_evals=known_evals + t, y, u, known_evals=known_evals ) known_evals[self.id] = self._function_evaluate(evaluated_children) return known_evals[self.id], known_evals diff --git a/pybamm/processed_variable.py b/pybamm/processed_variable.py index 77aef5bfef..847fdef809 100644 --- a/pybamm/processed_variable.py +++ b/pybamm/processed_variable.py @@ -41,12 +41,14 @@ def __init__(self, base_variable, solution, known_evals=None): self.base_eval, self.known_evals[solution.t[0]] = base_variable.evaluate( solution.t[0], solution.y[:, 0], - solution.inputs, + {name: inp[0] for name, inp in solution.inputs.items()}, known_evals=self.known_evals[solution.t[0]], ) else: self.base_eval = base_variable.evaluate( - solution.t[0], solution.y[:, 0], solution.inputs + solution.t[0], + solution.y[:, 0], + {name: inp[0] for name, inp in solution.inputs.items()}, ) # handle 2D (in space) finite element variables differently @@ -90,14 +92,14 @@ def initialise_1D(self): # Evaluate the base_variable index-by-index for idx in range(len(self.t_sol)): t = self.t_sol[idx] + u = self.u_sol[:, idx] + inputs = {name: inp[0] for name, inp in self.inputs.items()} if self.known_evals: entries[idx], self.known_evals[t] = self.base_variable.evaluate( - t, self.u_sol[:, idx], self.inputs, known_evals=self.known_evals[t] + t, u, inputs, known_evals=self.known_evals[t] ) else: - entries[idx] = self.base_variable.evaluate( - t, self.u_sol[:, idx], self.inputs - ) + entries[idx] = self.base_variable.evaluate(t, u, inputs) # No discretisation provided, or variable has no domain (function of t only) self._interpolation_function = interp.interp1d( @@ -115,14 +117,15 @@ def initialise_2D(self): for idx in range(len(self.t_sol)): t = self.t_sol[idx] u = self.u_sol[:, idx] + inputs = {name: inp[0] for name, inp in self.inputs.items()} if self.known_evals: eval_and_known_evals = self.base_variable.evaluate( - t, u, self.inputs, known_evals=self.known_evals[t] + t, u, inputs, known_evals=self.known_evals[t] ) entries[:, idx] = eval_and_known_evals[0][:, 0] self.known_evals[t] = eval_and_known_evals[1] else: - entries[:, idx] = self.base_variable.evaluate(t, u, self.inputs)[:, 0] + entries[:, idx] = self.base_variable.evaluate(t, u, inputs)[:, 0] # Process the discretisation to get x values nodes = self.mesh[0].nodes @@ -218,9 +221,10 @@ def initialise_3D(self): for idx in range(len(self.t_sol)): t = self.t_sol[idx] u = self.u_sol[:, idx] + inputs = {name: inp[0] for name, inp in self.inputs.items()} if self.known_evals: eval_and_known_evals = self.base_variable.evaluate( - t, u, self.inputs, known_evals=self.known_evals[t] + t, u, inputs, known_evals=self.known_evals[t] ) entries[:, :, idx] = np.reshape( eval_and_known_evals[0], @@ -230,7 +234,7 @@ def initialise_3D(self): self.known_evals[t] = eval_and_known_evals[1] else: entries[:, :, idx] = np.reshape( - self.base_variable.evaluate(t, u, self.inputs), + self.base_variable.evaluate(t, u, inputs), [first_dim_size, second_dim_size], order="F", ) diff --git a/tests/unit/test_solvers/test_solution.py b/tests/unit/test_solvers/test_solution.py index 8ca99249a6..7aa4db34fe 100644 --- a/tests/unit/test_solvers/test_solution.py +++ b/tests/unit/test_solvers/test_solution.py @@ -99,11 +99,7 @@ def test_solution_evals_with_inputs(self): model = pybamm.lithium_ion.SPM() geometry = model.default_geometry param = model.default_parameter_values - param.update( - { - "Electrode height [m]": "[input]", - } - ) + param.update({"Electrode height [m]": "[input]"}) param.process_model(model) param.process_geometry(geometry) var = pybamm.standard_spatial_vars @@ -118,11 +114,9 @@ def test_solution_evals_with_inputs(self): spatial_methods=spatial_methods, solver=solver, ) - inputs = { - 'Electrode height [m]': 0.1 - } + inputs = {"Electrode height [m]": 0.1} sim.solve(t_eval=np.linspace(0, 0.01, 10), inputs=inputs) - time = sim.solution['Time [h]'](sim.solution.t) + time = sim.solution["Time [h]"](sim.solution.t) self.assertEqual(len(time), 10)