Skip to content

Commit

Permalink
Merge pull request #1416 from pybamm-team/issue-1414-experiment-bug
Browse files Browse the repository at this point in the history
#1414 fix bug in set_initial_conditions_from
  • Loading branch information
valentinsulzer authored Mar 5, 2021
2 parents 900e52e + d41d557 commit e1c51a2
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
4 changes: 1 addition & 3 deletions pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,6 @@ def set_initial_conditions_from(self, solution, inplace=True):

if isinstance(solution, pybamm.Solution):
solution = solution.last_state
else:
solution = pybamm.FuzzyDict(solution)
for var, equation in model.initial_conditions.items():
if isinstance(var, pybamm.Variable):
try:
Expand All @@ -404,7 +402,7 @@ def set_initial_conditions_from(self, solution, inplace=True):
elif final_state.ndim == 2:
final_state_eval = final_state[:, -1]
elif final_state.ndim == 3:
final_state_eval = final_state[:, :, -1].flatten()
final_state_eval = final_state[:, :, -1].flatten(order="F")
else:
raise NotImplementedError("Variable must be 0D, 1D, or 2D")
model.initial_conditions[var] = pybamm.Vector(final_state_eval)
Expand Down
4 changes: 3 additions & 1 deletion pybamm/solvers/casadi_algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ def _integrate(self, model, t_eval, inputs_dict=None):
y_diff = casadi.horzcat(*[y0_diff] * len(t_eval))
y_sol = casadi.vertcat(y_diff, y_alg)
# Return solution object (no events, so pass None to t_event, y_event)
sol = pybamm.Solution(t_eval, y_sol, model, inputs_dict, termination="success")
sol = pybamm.Solution(
[t_eval], y_sol, model, inputs_dict, termination="success"
)
sol.integration_time = integration_time
return sol
13 changes: 12 additions & 1 deletion tests/unit/test_experiments/test_simulation_with_experiment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#
# Test setting up a simulation with an experiment
#
import casadi
import pybamm
import numpy as np
import unittest
Expand Down Expand Up @@ -82,12 +83,22 @@ def test_run_experiment(self):
)
]
)
model = pybamm.lithium_ion.SPM()
model = pybamm.lithium_ion.DFN()
sim = pybamm.Simulation(model, experiment=experiment)
sol = sim.solve()
self.assertEqual(sol.termination, "final time")
self.assertEqual(len(sol.cycles), 1)

for i, step in enumerate(sol.cycles[0].steps[:-1]):
len_rhs = sol.all_models[0].concatenated_rhs.size
y_left = step.all_ys[-1][:len_rhs, -1]
if isinstance(y_left, casadi.DM):
y_left = y_left.full()
y_right = sol.cycles[0].steps[i + 1].all_ys[0][:len_rhs, 0]
if isinstance(y_right, casadi.DM):
y_right = y_right.full()
np.testing.assert_array_equal(y_left.flatten(), y_right.flatten())

# Solve again starting from solution
sol2 = sim.solve(starting_solution=sol)
self.assertEqual(sol2.termination, "final time")
Expand Down

0 comments on commit e1c51a2

Please sign in to comment.