From 306b98e595b31ecbe24f86968b7609a31f37232d Mon Sep 17 00:00:00 2001 From: Martin Robinson Date: Fri, 7 Feb 2020 10:54:56 +0000 Subject: [PATCH] #759 solver checks for heaviside functions and adds appropriate discontinuity events --- pybamm/solvers/base_solver.py | 13 +++ tests/unit/test_solvers/test_scipy_solver.py | 110 +++++++++++-------- 2 files changed, 80 insertions(+), 43 deletions(-) diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 3561d8f889..2038c59649 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -212,6 +212,19 @@ def report(string): jac_call = None return func, func_call, jac_call + # Check for heaviside functions in rhs and algebraic and add discontinuity + # events if these exist. + # Note: only checks for the case of t < X, t <= X, X < t, or X <= t + for symbol in model.concatenated_rhs.pre_order(): + if isinstance(symbol, pybamm.Heaviside): + if symbol.right.id == pybamm.t.id: + expr = symbol.left + elif symbol.left.id == pybamm.t.id: + expr = symbol.right + + model.events.append(pybamm.Event(str(symbol), expr.new_copy(), + pybamm.EventType.DISCONTINUITY)) + # Process rhs, algebraic and event expressions rhs, rhs_eval, jac_rhs = process(model.concatenated_rhs, "RHS") algebraic, algebraic_eval, jac_algebraic = process( diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py index 22a11eff93..0068835b15 100644 --- a/tests/unit/test_solvers/test_scipy_solver.py +++ b/tests/unit/test_solvers/test_scipy_solver.py @@ -139,65 +139,89 @@ def jacobian(t, y): ) def test_model_solver_ode_nonsmooth(self): - model = pybamm.BaseModel() whole_cell = ["negative electrode", "separator", "positive electrode"] var1 = pybamm.Variable("var1", domain=whole_cell) discontinuity = 0.6 + # Create three different models with the same solution, each expressing the + # discontinuity in a different way + + # first model explicitly adds a discontinuity event def nonsmooth_rate(t): return 0.1 * (t < discontinuity) + 0.1 rate = pybamm.Function(nonsmooth_rate, pybamm.t) - model.rhs = {var1: rate * var1} - model.initial_conditions = {var1: 1} - model.events = [ + model1 = pybamm.BaseModel() + model1.rhs = {var1: rate * var1} + model1.initial_conditions = {var1: 1} + model1.events = [ pybamm.Event("var1 = 1.5", pybamm.min(var1 - 1.5)), pybamm.Event("nonsmooth rate", pybamm.Scalar(discontinuity), pybamm.EventType.DISCONTINUITY ), ] - disc = get_discretisation_for_testing() - disc.process_model(model) - # Solve - solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8) + # second model implicitly adds a discontinuity event via a heaviside function + model2 = pybamm.BaseModel() + model2.rhs = {var1: (0.1 * (pybamm.t < discontinuity) + 0.1) * var1} + model2.initial_conditions = {var1: 1} + model2.events = [ + pybamm.Event("var1 = 1.5", pybamm.min(var1 - 1.5)), + ] - # create two time series, one without a time point on the discontinuity, - # and one with - t_eval1 = np.linspace(0, 5, 10) - t_eval2 = np.insert(t_eval1, - np.searchsorted(t_eval1, discontinuity), - discontinuity) - solution1 = solver.solve(model, t_eval1) - solution2 = solver.solve(model, t_eval2) - - # check time vectors - for solution in [solution1, solution2]: - # time vectors are ordered - self.assertTrue(np.all(solution.t[:-1] <= solution.t[1:])) - - # time value before and after discontinuity is an epsilon away - dindex = np.searchsorted(solution.t, discontinuity) - value_before = solution.t[dindex - 1] - value_after = solution.t[dindex] - self.assertEqual(value_before + sys.float_info.epsilon, discontinuity) - self.assertEqual(value_after - sys.float_info.epsilon, discontinuity) - - # both solution time vectors should have same number of points - self.assertEqual(len(solution1.t), len(solution2.t)) - - # check solution - for solution in [solution1, solution2]: - np.testing.assert_array_less(solution.y[0], 1.5) - np.testing.assert_array_less(solution.y[-1], 2.5) - var1_soln = np.exp(0.2 * solution.t) - y0 = np.exp(0.2 * discontinuity) - var1_soln[solution.t > discontinuity] = \ - y0 * np.exp( - 0.1 * (solution.t[solution.t > discontinuity] - discontinuity) - ) - np.testing.assert_allclose(solution.y[0], var1_soln, rtol=1e-06) + # third model implicitly adds a discontinuity event via another heaviside + # function + model3 = pybamm.BaseModel() + model3.rhs = {var1: (-0.1 * (discontinuity < pybamm.t) + 0.2) * var1} + model3.initial_conditions = {var1: 1} + model3.events = [ + pybamm.Event("var1 = 1.5", pybamm.min(var1 - 1.5)), + ] + + for model in [model1, model2, model3]: + + disc = get_discretisation_for_testing() + disc.process_model(model) + + # Solve + solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8) + + # create two time series, one without a time point on the discontinuity, + # and one with + t_eval1 = np.linspace(0, 5, 10) + t_eval2 = np.insert(t_eval1, + np.searchsorted(t_eval1, discontinuity), + discontinuity) + solution1 = solver.solve(model, t_eval1) + solution2 = solver.solve(model, t_eval2) + + # check time vectors + for solution in [solution1, solution2]: + # time vectors are ordered + self.assertTrue(np.all(solution.t[:-1] <= solution.t[1:])) + + # time value before and after discontinuity is an epsilon away + dindex = np.searchsorted(solution.t, discontinuity) + value_before = solution.t[dindex - 1] + value_after = solution.t[dindex] + self.assertEqual(value_before + sys.float_info.epsilon, discontinuity) + self.assertEqual(value_after - sys.float_info.epsilon, discontinuity) + + # both solution time vectors should have same number of points + self.assertEqual(len(solution1.t), len(solution2.t)) + + # check solution + for solution in [solution1, solution2]: + np.testing.assert_array_less(solution.y[0], 1.5) + np.testing.assert_array_less(solution.y[-1], 2.5) + var1_soln = np.exp(0.2 * solution.t) + y0 = np.exp(0.2 * discontinuity) + var1_soln[solution.t > discontinuity] = \ + y0 * np.exp( + 0.1 * (solution.t[solution.t > discontinuity] - discontinuity) + ) + np.testing.assert_allclose(solution.y[0], var1_soln, rtol=1e-06) def test_model_step_python(self): # Create model