Skip to content

Commit

Permalink
#759 change models to support new Event class
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jan 27, 2020
1 parent 4618751 commit 9e9db38
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 32 deletions.
6 changes: 6 additions & 0 deletions pybamm/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ def __init__(self, name, expression, event_type=EventType.TERMINATION):
self._expression = expression
self._event_type = event_type

def evaluate(self, t=None, y=None, u=None, known_evals=None):
"""
Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate`
"""
return self._expression(t, y, u, known_evals)

def __str__(self):
return self._name

Expand Down
4 changes: 2 additions & 2 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,9 @@ def get_termination_reason(self, solution, events):
elif solution.termination == "event":
# Get final event value
final_event_values = {}
for name, event in events.items():
for event in events:
y_event = add_external(solution.y_event, self.y_pad, self.y_ext)
final_event_values[name] = abs(
final_event_values[event.name] = abs(
event.evaluate(solution.t_event, y_event)
)
termination_event = min(final_event_values, key=final_event_values.get)
Expand Down
38 changes: 24 additions & 14 deletions pybamm/solvers/dae_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def compute_solution(self, model, t_eval, inputs=None):
solution.inputs = inputs

# Identify the event that caused termination
termination = self.get_termination_reason(solution, self.events)
termination = self.get_termination_reason(solution, self.termination_events)

return solution, solve_time, termination

Expand Down Expand Up @@ -135,7 +135,8 @@ def set_up(self, model, inputs=None):
pybamm.logger.info("Simplifying algebraic")
concatenated_algebraic = simp.simplify(concatenated_algebraic)
pybamm.logger.info("Simplifying events")
events = {name: simp.simplify(event) for name, event in events.items()}
for event in events:
event.expression = simp.simplify(event.expression)

if model.use_jacobian:
# Create Jacobian from concatenated rhs and algebraic
Expand Down Expand Up @@ -177,9 +178,8 @@ def set_up(self, model, inputs=None):
pybamm.logger.info("Converting algebraic to python")
concatenated_algebraic = pybamm.EvaluatorPython(concatenated_algebraic)
pybamm.logger.info("Converting events to python")
events = {
name: pybamm.EvaluatorPython(event) for name, event in events.items()
}
for event in events:
event.expression = pybamm.EvaluatorPython(event.expression)

# Calculate consistent initial conditions for the algebraic equations
rhs = Rhs(concatenated_rhs.evaluate)
Expand Down Expand Up @@ -212,8 +212,11 @@ def get_event_class(event):
self.residuals = Residuals(
model, concatenated_rhs.evaluate, concatenated_algebraic.evaluate
)
self.events = events
self.event_funs = [get_event_class(event) for event in events.values()]
self.termination_events = [
events for event in events
if event.event_type == pybamm.EventType.TERMINATION
]
self.termination_funs = [get_event_class(event) for event in termination_events]
self.jacobian = jacobian

pybamm.logger.info("Finish solver set-up")
Expand Down Expand Up @@ -261,10 +264,11 @@ def set_up_casadi(self, model, inputs=None):
)
all_states = casadi.vertcat(concatenated_rhs, concatenated_algebraic)
pybamm.logger.info("Converting events to CasADi")
casadi_events = {
name: event.to_casadi(t_casadi, y_casadi_w_ext, u_casadi)
for name, event in model.events.items()
}
casadi_termination_events = [
event.expression.to_casadi(t_casadi, y_casadi_w_ext, u_casadi)
for event in model.events
if event.event_type == pybamm.EventType.TERMINATION
]

# Create functions to evaluate rhs and algebraic
u_casadi_stacked = casadi.vertcat(*[u for u in u_casadi.values()])
Expand Down Expand Up @@ -337,8 +341,14 @@ def get_event_class(event):
self.rhs = rhs
self.algebraic = algebraic
self.residuals = ResidualsCasadi(model, all_states_fn)
self.events = model.events
self.event_funs = [get_event_class(event) for event in casadi_events.values()]
self.termination_events = [
event for event in model.events
if event.event_type == pybamm.EventType.TERMINATION
]

self.termination_funs = [
get_event_class(event) for event in casadi_termination_events
]
self.jacobian = jacobian

# Save CasADi functions for the CasADi solver
Expand Down Expand Up @@ -371,7 +381,7 @@ def set_inputs_and_external(self, inputs):
self.algebraic.set_inputs(inputs)
self.residuals.set_pad_ext(self.y_pad, self.y_ext)
self.residuals.set_inputs(inputs)
for evnt in self.event_funs:
for evnt in self.termination_funs:
evnt.set_pad_ext(self.y_pad, self.y_ext)
evnt.set_inputs(inputs)
if self.jacobian:
Expand Down
43 changes: 27 additions & 16 deletions pybamm/solvers/ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def compute_solution(self, model, t_eval, inputs=None):
self.dydt,
self.y0,
t_eval,
events=self.event_funs,
events=self.termination_funs,
mass_matrix=model.mass_matrix.entries,
jacobian=self.jacobian,
)
Expand All @@ -61,7 +61,7 @@ def compute_solution(self, model, t_eval, inputs=None):
solution.inputs = inputs

# Identify the event that caused termination
termination = self.get_termination_reason(solution, self.events)
termination = self.get_termination_reason(solution, self.termination_events)

return solution, solve_time, termination

Expand Down Expand Up @@ -104,7 +104,8 @@ def set_up(self, model, inputs=None):
concatenated_rhs = simp.simplify(concatenated_rhs)

pybamm.logger.info("Simplifying events")
events = {name: simp.simplify(event) for name, event in events.items()}
for event in events:
event.expression = simp.simplify(event.expression)

y0 = model.concatenated_initial_conditions[:, 0]

Expand Down Expand Up @@ -132,9 +133,8 @@ def set_up(self, model, inputs=None):
pybamm.logger.info("Converting RHS to python")
concatenated_rhs = pybamm.EvaluatorPython(concatenated_rhs)
pybamm.logger.info("Converting events to python")
events = {
name: pybamm.EvaluatorPython(event) for name, event in events.items()
}
for event in events:
event.expression = pybamm.EvaluatorPython(event.expression)

# Create event-dependent function to evaluate events
def get_event_class(event):
Expand All @@ -151,8 +151,13 @@ def get_event_class(event):
# etc. The expression tree versions of these are attributes of the model
self.y0 = y0
self.dydt = Dydt(model, concatenated_rhs.evaluate)
self.events = events
self.event_funs = [get_event_class(event) for event in events.values()]
self.termination_events = [
event for event in events
if event.event_type == pybamm.EventType.TERMINATION
]
self.termination_funs = [
get_event_class(event) for event in self.termination_events
]
self.jacobian = jacobian

pybamm.logger.info("Finish solver set-up")
Expand Down Expand Up @@ -195,11 +200,12 @@ def set_up_casadi(self, model, inputs=None):
concatenated_rhs = model.concatenated_rhs.to_casadi(
t_casadi, y_casadi_w_ext, u_casadi
)
pybamm.logger.info("Converting events to CasADi")
casadi_events = {
name: event.to_casadi(t_casadi, y_casadi_w_ext, u_casadi)
for name, event in model.events.items()
}
pybamm.logger.info("Converting termination events to CasADi")
casadi_termination_events = [
event.expression.to_casadi(t_casadi, y_casadi_w_ext, u_casadi)
for event in model.events
if event.event_type == pybamm.EventType.TERMINATION
]

# Create function to evaluate rhs
u_casadi_stacked = casadi.vertcat(*[u for u in u_casadi.values()])
Expand Down Expand Up @@ -230,8 +236,13 @@ def get_event_class(event):
# Add the solver attributes
self.y0 = y0
self.dydt = DydtCasadi(model, concatenated_rhs_fn)
self.events = model.events
self.event_funs = [get_event_class(event) for event in casadi_events.values()]
self.termination_events = [
event for event in model.events
if event.event_type == pybamm.EventType.TERMINATION
]
self.termination_funs = [
get_event_class(event) for event in casadi_termination_events
]
self.jacobian = jacobian

def set_inputs_and_external(self, inputs):
Expand All @@ -247,7 +258,7 @@ def set_inputs_and_external(self, inputs):
"""
self.dydt.set_pad_ext(self.y_pad, self.y_ext)
self.dydt.set_inputs(inputs)
for evnt in self.event_funs:
for evnt in self.termination_funs:
evnt.set_pad_ext(self.y_pad, self.y_ext)
evnt.set_inputs(inputs)
if self.jacobian:
Expand Down

0 comments on commit 9e9db38

Please sign in to comment.