Skip to content

Commit

Permalink
Merge pull request #794 from pybamm-team/issue-793-solution-evals-wit…
Browse files Browse the repository at this point in the history
…h-inputs

Fix issue 793 solution evaluation with Electrode height input
  • Loading branch information
valentinsulzer authored Jan 23, 2020
2 parents ae1e6ae + 30661f4 commit f772517
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def evaluate(self, t=None, y=None, u=None, known_evals=None):
evaluated_children = [None] * len(self.children)
for i, child in enumerate(self.children):
evaluated_children[i], known_evals = child.evaluate(
t, y, known_evals=known_evals
t, y, u, known_evals=known_evals
)
known_evals[self.id] = self._function_evaluate(evaluated_children)
return known_evals[self.id], known_evals
Expand Down
24 changes: 14 additions & 10 deletions pybamm/processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ def __init__(self, base_variable, solution, known_evals=None):
self.base_eval, self.known_evals[solution.t[0]] = base_variable.evaluate(
solution.t[0],
solution.y[:, 0],
solution.inputs,
{name: inp[0] for name, inp in solution.inputs.items()},
known_evals=self.known_evals[solution.t[0]],
)
else:
self.base_eval = base_variable.evaluate(
solution.t[0], solution.y[:, 0], solution.inputs
solution.t[0],
solution.y[:, 0],
{name: inp[0] for name, inp in solution.inputs.items()},
)

# handle 2D (in space) finite element variables differently
Expand Down Expand Up @@ -90,14 +92,14 @@ def initialise_1D(self):
# Evaluate the base_variable index-by-index
for idx in range(len(self.t_sol)):
t = self.t_sol[idx]
u = self.u_sol[:, idx]
inputs = {name: inp[0] for name, inp in self.inputs.items()}
if self.known_evals:
entries[idx], self.known_evals[t] = self.base_variable.evaluate(
t, self.u_sol[:, idx], self.inputs, known_evals=self.known_evals[t]
t, u, inputs, known_evals=self.known_evals[t]
)
else:
entries[idx] = self.base_variable.evaluate(
t, self.u_sol[:, idx], self.inputs
)
entries[idx] = self.base_variable.evaluate(t, u, inputs)

# No discretisation provided, or variable has no domain (function of t only)
self._interpolation_function = interp.interp1d(
Expand All @@ -115,14 +117,15 @@ def initialise_2D(self):
for idx in range(len(self.t_sol)):
t = self.t_sol[idx]
u = self.u_sol[:, idx]
inputs = {name: inp[0] for name, inp in self.inputs.items()}
if self.known_evals:
eval_and_known_evals = self.base_variable.evaluate(
t, u, self.inputs, known_evals=self.known_evals[t]
t, u, inputs, known_evals=self.known_evals[t]
)
entries[:, idx] = eval_and_known_evals[0][:, 0]
self.known_evals[t] = eval_and_known_evals[1]
else:
entries[:, idx] = self.base_variable.evaluate(t, u, self.inputs)[:, 0]
entries[:, idx] = self.base_variable.evaluate(t, u, inputs)[:, 0]

# Process the discretisation to get x values
nodes = self.mesh[0].nodes
Expand Down Expand Up @@ -218,9 +221,10 @@ def initialise_3D(self):
for idx in range(len(self.t_sol)):
t = self.t_sol[idx]
u = self.u_sol[:, idx]
inputs = {name: inp[0] for name, inp in self.inputs.items()}
if self.known_evals:
eval_and_known_evals = self.base_variable.evaluate(
t, u, self.inputs, known_evals=self.known_evals[t]
t, u, inputs, known_evals=self.known_evals[t]
)
entries[:, :, idx] = np.reshape(
eval_and_known_evals[0],
Expand All @@ -230,7 +234,7 @@ def initialise_3D(self):
self.known_evals[t] = eval_and_known_evals[1]
else:
entries[:, :, idx] = np.reshape(
self.base_variable.evaluate(t, u, self.inputs),
self.base_variable.evaluate(t, u, inputs),
[first_dim_size, second_dim_size],
order="F",
)
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/test_solvers/test_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,30 @@ def test_save(self):
self.assertEqual(solution.model.name, solution_load.model.name)
np.testing.assert_array_equal(solution["c"].entries, solution_load["c"].entries)

def test_solution_evals_with_inputs(self):
model = pybamm.lithium_ion.SPM()
geometry = model.default_geometry
param = model.default_parameter_values
param.update({"Electrode height [m]": "[input]"})
param.process_model(model)
param.process_geometry(geometry)
var = pybamm.standard_spatial_vars
var_pts = {var.x_n: 5, var.x_s: 5, var.x_p: 5, var.r_n: 10, var.r_p: 10}
spatial_methods = model.default_spatial_methods
solver = model.default_solver
sim = pybamm.Simulation(
model=model,
geometry=geometry,
parameter_values=param,
var_pts=var_pts,
spatial_methods=spatial_methods,
solver=solver,
)
inputs = {"Electrode height [m]": 0.1}
sim.solve(t_eval=np.linspace(0, 0.01, 10), inputs=inputs)
time = sim.solution["Time [h]"](sim.solution.t)
self.assertEqual(len(time), 10)


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down

0 comments on commit f772517

Please sign in to comment.