Skip to content

Commit

Permalink
#1429 concatenation simplifications
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Mar 18, 2021
1 parent 1da9f11 commit e29986b
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 25 deletions.
91 changes: 66 additions & 25 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,14 +680,57 @@ def simplify_elementwise_binary_broadcasts(left, right):
return left, right


def simplified_binary_broadcast_concatenation(left, right, operator):
"""
Check if there are concatenations or broadcasts that we can commute the operator
with
"""
# Broadcast commutes with elementwise operators
if isinstance(left, pybamm.Broadcast) and right.domain == []:
return left._unary_new_copy(operator(left.orphans[0], right))
elif isinstance(right, pybamm.Broadcast) and left.domain == []:
return right._unary_new_copy(operator(left, right.orphans[0]))

# Concatenation commutes with elementwise operators
# If one of the sides is constant then commute concatenation with the operator
# Don't do this if left has any Variable or StateVector objects as these will
# be simplified differently later on
if isinstance(left, pybamm.Concatenation) and not any(
isinstance(child, (pybamm.Variable, pybamm.StateVector))
for child in left.children
):
if right.evaluates_to_constant_number():
return pybamm.concatenation(
*[operator(child, right) for child in left.orphans]
)
elif (
isinstance(right, pybamm.Concatenation)
and all(child.is_constant() for child in left.children)
and all(child.is_constant() for child in right.children)
):
return pybamm.concatenation(
*[
operator(left_child, right_child)
for left_child, right_child in zip(left.orphans, right.orphans)
]
)
if isinstance(right, pybamm.Concatenation) and not any(
isinstance(child, (pybamm.Variable, pybamm.StateVector))
for child in right.children
):
if left.evaluates_to_constant_number():
return pybamm.concatenation(
*[operator(left, child) for child in right.orphans]
)


def simplified_power(left, right):
left, right = simplify_elementwise_binary_broadcasts(left, right)

# Broadcast commutes with power operator
if isinstance(left, pybamm.Broadcast) and right.domain == []:
return left._unary_new_copy(left.orphans[0] ** right)
elif isinstance(right, pybamm.Broadcast) and left.domain == []:
return right._unary_new_copy(left ** right.orphans[0])
# Check for Concatenations and Broadcasts
out = simplified_binary_broadcast_concatenation(left, right, simplified_power)
if out is not None:
return out

# anything to the power of zero is one
if pybamm.is_scalar_zero(right):
Expand Down Expand Up @@ -733,11 +776,10 @@ def simplified_addition(left, right):
"""
left, right = simplify_elementwise_binary_broadcasts(left, right)

# Broadcast commutes with addition operator
if isinstance(left, pybamm.Broadcast) and right.domain == []:
return left._unary_new_copy(left.orphans[0] + right)
elif isinstance(right, pybamm.Broadcast) and left.domain == []:
return right._unary_new_copy(left + right.orphans[0])
# Check for Concatenations and Broadcasts
out = simplified_binary_broadcast_concatenation(left, right, simplified_addition)
if out is not None:
return out

# anything added by a scalar zero returns the other child
elif pybamm.is_scalar_zero(left):
Expand Down Expand Up @@ -828,11 +870,10 @@ def simplified_subtraction(left, right):
"""
left, right = simplify_elementwise_binary_broadcasts(left, right)

# Broadcast commutes with subtraction operator
if isinstance(left, pybamm.Broadcast) and right.domain == []:
return left._unary_new_copy(left.orphans[0] - right)
elif isinstance(right, pybamm.Broadcast) and left.domain == []:
return right._unary_new_copy(left - right.orphans[0])
# Check for Concatenations and Broadcasts
out = simplified_binary_broadcast_concatenation(left, right, simplified_subtraction)
if out is not None:
return out

# anything added by a scalar zero returns the other child
if pybamm.is_scalar_zero(left):
Expand Down Expand Up @@ -879,11 +920,12 @@ def simplified_subtraction(left, right):
def simplified_multiplication(left, right):
left, right = simplify_elementwise_binary_broadcasts(left, right)

# Broadcast commutes with multiplication operator
if isinstance(left, pybamm.Broadcast) and right.domain == []:
return left._unary_new_copy(left.orphans[0] * right)
elif isinstance(right, pybamm.Broadcast) and left.domain == []:
return right._unary_new_copy(left * right.orphans[0])
# Check for Concatenations and Broadcasts
out = simplified_binary_broadcast_concatenation(
left, right, simplified_multiplication
)
if out is not None:
return out

# simplify multiply by scalar zero, being careful about shape
if pybamm.is_scalar_zero(left):
Expand Down Expand Up @@ -1048,11 +1090,10 @@ def simplified_multiplication(left, right):
def simplified_division(left, right):
left, right = simplify_elementwise_binary_broadcasts(left, right)

# Broadcast commutes with division operator
if isinstance(left, pybamm.Broadcast) and right.domain == []:
return left._unary_new_copy(left.orphans[0] / right)
elif isinstance(right, pybamm.Broadcast) and left.domain == []:
return right._unary_new_copy(left / right.orphans[0])
# Check for Concatenations and Broadcasts
out = simplified_binary_broadcast_concatenation(left, right, simplified_division)
if out is not None:
return out

# zero divided by anything returns zero (being careful about shape)
if pybamm.is_scalar_zero(left):
Expand Down
4 changes: 4 additions & 0 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,10 @@ def __neg__(self):
# Move negation inside the broadcast
# Apply recursively
return self._unary_new_copy(-self.orphans[0])
elif isinstance(self, pybamm.Concatenation) and all(
child.is_constant() for child in self.children
):
return pybamm.concatenation(*[-child for child in self.orphans])
else:
return pybamm.simplify_if_constant(pybamm.Negate(self))

Expand Down
23 changes: 23 additions & 0 deletions tests/unit/test_expression_tree/test_binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,29 @@ def test_binary_simplifications(self):
with self.assertRaises(ZeroDivisionError):
b / a

def test_binary_simplifications_concatenations(self):
def conc_broad(x, y, z):
return pybamm.concatenation(
pybamm.PrimaryBroadcast(x, "negative electrode"),
pybamm.PrimaryBroadcast(y, "separator"),
pybamm.PrimaryBroadcast(z, "positive electrode"),
)

# Test that concatenations get simplified correctly
a = conc_broad(1, 2, 3)
b = conc_broad(11, 12, 13)
self.assertEqual((a + 4).id, conc_broad(5, 6, 7).id)
self.assertEqual((4 + a).id, conc_broad(5, 6, 7).id)
self.assertEqual((a + b).id, conc_broad(12, 14, 16).id)

# No simplifications if there are Variable or StateVector objects
v = pybamm.concatenation(
pybamm.Variable("x", "negative electrode"),
pybamm.Variable("y", "separator"),
pybamm.Variable("z", "positive electrode"),
)
self.assertIsInstance((a * v), pybamm.Multiplication)

def test_advanced_binary_simplifications(self):
# MatMul simplifications that often appear when discretising spatial operators
A = pybamm.Matrix(np.random.rand(10, 10))
Expand Down

0 comments on commit e29986b

Please sign in to comment.