Skip to content

Commit

Permalink
#664 add some more functions
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Oct 18, 2019
1 parent a5236b7 commit 69aac74
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 13 deletions.
3 changes: 2 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ exclude=
lib,
lib64,
share,
pyvenv.cfg
pyvenv.cfg,
third-party
ignore=
# False positive for white space before ':' on list slice
# black should format these correctly
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/compare_lithium_ion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
# load models
options = {"thermal": "isothermal"}
models = [
pybamm.lithium_ion.SPM(options),
# pybamm.lithium_ion.SPM(options),
# pybamm.lithium_ion.SPMe(options),
# pybamm.lithium_ion.DFN(options),
pybamm.lithium_ion.DFN(options)
]


Expand Down
38 changes: 32 additions & 6 deletions pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
import pybamm
import casadi
import numpy as np


class CasadiConverter(object):
Expand Down Expand Up @@ -51,7 +52,7 @@ def _convert(self, symbol, t, y):
converted_left = self.convert(left, t, y)
converted_right = self.convert(right, t, y)
if isinstance(symbol, pybamm.Outer):
return casadi.outer_prod(converted_left, converted_right)
return casadi.kron(converted_left, converted_right)
else:
# _binary_evaluate defined in derived classes for specific rules
return symbol._binary_evaluate(converted_left, converted_right)
Expand All @@ -63,16 +64,41 @@ def _convert(self, symbol, t, y):
return symbol._unary_evaluate(converted_child)

elif isinstance(symbol, pybamm.Function):
converted_children = [None] * len(symbol.children)
for i, child in enumerate(symbol.children):
converted_children[i] = self.convert(child, t, y)
return symbol._function_evaluate(converted_children)
converted_children = [
self.convert(child, t, y) for child in symbol.children
]
if symbol.function == np.min:
return casadi.mmin(*converted_children)
elif symbol.function == np.max:
return casadi.mmax(*converted_children)
else:
return symbol._function_evaluate(converted_children)

elif isinstance(symbol, pybamm.Concatenation):
converted_children = [
self.convert(child, t, y) for child in symbol.children
]
return symbol._concatenation_evaluate(converted_children)
if isinstance(symbol, (pybamm.NumpyConcatenation, pybamm.SparseStack)):
return casadi.vertcat(*converted_children)
# DomainConcatenation specifies a particular ordering for the concatenation,
# which we must follow
elif isinstance(symbol, pybamm.DomainConcatenation):
slice_starts = []
all_child_vectors = []
for i in range(symbol.secondary_dimensions_npts):
child_vectors = []
for child_var, slices in zip(
converted_children, symbol._children_slices
):
for child_dom, child_slice in slices.items():
slice_starts.append(symbol._slices[child_dom][i].start)
child_vectors.append(
child_var[child_slice[i].start : child_slice[i].stop]
)
all_child_vectors.extend(
[v for _, v in sorted(zip(slice_starts, child_vectors))]
)
return casadi.vertcat(*all_child_vectors)

else:
raise TypeError(
Expand Down
10 changes: 7 additions & 3 deletions pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,14 @@ def set_up(self, model):

t = casadi.SX.sym("t")
y_diff = casadi.SX.sym("y_diff", len(model.concatenated_rhs.evaluate(0, y0)))
# y_alg = casadi.SX.sym("y_alg")
y_alg = casadi.SX.sym(
"y_alg", len(model.concatenated_algebraic.evaluate(0, y0))
)
y = casadi.vertcat(y_diff, y_alg)
# create simplified rhs and event expressions
concatenated_rhs = model.concatenated_rhs.to_casadi(t, y_diff)
events = model.events.to_casadi(t, y_diff)
concatenated_rhs = model.concatenated_rhs.to_casadi(t, y)
concatenated_algebraic = model.concatenated_algebraic.to_casadi(t, y)
events = {name: event.to_casadi(t, y) for name, event in model.events.items()}

# Create function to evaluate rhs
def dydt(t, y):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pybamm
import unittest
from tests import get_discretisation_for_testing
from tests import get_mesh_for_testing, get_1p1d_discretisation_for_testing


class TestCasadiConverter(unittest.TestCase):
Expand Down Expand Up @@ -68,6 +68,45 @@ def test_convert_array_symbols(self):
casadi.is_equal(pybamm_y.to_casadi(casadi_t, casadi_y), casadi_y)
)

# outer product
outer = pybamm.Outer(pybamm_a, pybamm_a)
self.assertTrue(casadi.is_equal(outer.to_casadi(), casadi.SX(outer.evaluate())))

def test_special_functions(self):
a = np.array([1, 2, 3, 4, 5])
pybamm_a = pybamm.Array(a)
self.assertEqual(pybamm.min(pybamm_a).to_casadi(), casadi.SX(1))

def test_concatenations(self):
y = np.linspace(0, 1, 10)[:, np.newaxis]
a = pybamm.Vector(y)
b = pybamm.Scalar(16)
c = pybamm.Scalar(3)
conc = pybamm.NumpyConcatenation(a, b, c)
self.assertTrue(casadi.is_equal(conc.to_casadi(), casadi.SX(conc.evaluate())))

# Domain concatenation
mesh = get_mesh_for_testing()
a_dom = ["negative electrode"]
b_dom = ["positive electrode"]
a = 2 * pybamm.Vector(np.ones_like(mesh[a_dom[0]][0].nodes), domain=a_dom)
b = pybamm.Vector(np.ones_like(mesh[b_dom[0]][0].nodes), domain=b_dom)
conc = pybamm.DomainConcatenation([b, a], mesh)
self.assertTrue(casadi.is_equal(conc.to_casadi(), casadi.SX(conc.evaluate())))

# 2d
disc = get_1p1d_discretisation_for_testing()
a = pybamm.Variable("a", domain=a_dom)
b = pybamm.Variable("b", domain=b_dom)
conc = pybamm.Concatenation(a, b)
disc.set_variable_slices([conc])
expr = disc.process_symbol(conc)
y = casadi.SX.sym("y", expr.size)
x = expr.to_casadi(None, y)
f = casadi.Function("f", [x], [x])
y_eval = np.linspace(0, 1, expr.size)
self.assertTrue(casadi.is_equal(f(y_eval), casadi.SX(expr.evaluate(y=y_eval))))


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down

0 comments on commit 69aac74

Please sign in to comment.