Skip to content

Commit

Permalink
#1429 a few more simplifications
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Mar 14, 2021
1 parent 1a4a487 commit 4a0b73d
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 6 deletions.
69 changes: 63 additions & 6 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,23 +657,22 @@ def simplify_elementwise_binary_broadcasts(left, right):
# No need to broadcast if the other symbol already has the shape that is being
# broadcasted to
# Also check for broadcast of a broadcast
if left.domains == right.domains and all(
left.evaluates_on_edges(dim) == right.evaluates_on_edges(dim)
for dim in ["primary", "secondary", "tertiary"]
):
if isinstance(left, pybamm.Broadcast):
if left.domains == right.domains:
if isinstance(left, pybamm.Broadcast) and left.broadcasts_to_nodes:
if left.child.domain == []:
left = left.orphans[0]
elif (
isinstance(left.child, pybamm.Broadcast)
and left.child.broadcasts_to_nodes
and left.child.child.domain == []
):
left = left.child.orphans[0]
elif isinstance(right, pybamm.Broadcast):
elif isinstance(right, pybamm.Broadcast) and right.broadcasts_to_nodes:
if right.child.domain == []:
right = right.orphans[0]
elif (
isinstance(right.child, pybamm.Broadcast)
and right.child.broadcasts_to_nodes
and right.child.child.domain == []
):
right = right.child.orphans[0]
Expand Down Expand Up @@ -777,6 +776,10 @@ def simplified_addition(left, right):
):
return left

# Return constant if both sides are constant
if left.is_constant() and right.is_constant():
return pybamm.simplify_if_constant(pybamm.Addition(left, right))

# Simplify A @ c + B @ c to (A + B) @ c if (A + B) is constant
# This is a common construction that appears from discretisation of spatial
# operators
Expand All @@ -793,6 +796,25 @@ def simplified_addition(left, right):
new_sum.copy_domains(pybamm.Addition(left, right))
return new_sum

if isinstance(right, pybamm.Addition) and left.is_constant():
# Simplify a + (b + c) to (a + b) + c if (a + b) is constant
if right.left.is_constant():
r_left, r_right = right.orphans
return (left + r_left) + r_right
# Simplify a + (b + c) to (a + c) + b if (a + c) is constant
elif right.right.is_constant():
r_left, r_right = right.orphans
return (left + r_right) + r_left
if isinstance(left, pybamm.Addition) and right.is_constant():
# Simplify (a + b) + c to a + (b + c) if (b + c) is constant
if left.right.is_constant():
l_left, l_right = left.orphans
return l_left + (l_right + right)
# Simplify (a + b) + c to (a + c) + b if (a + c) is constant
elif left.left.is_constant():
l_left, l_right = left.orphans
return (l_left + right) + l_right

return pybamm.simplify_if_constant(pybamm.Addition(left, right))


Expand Down Expand Up @@ -988,6 +1010,33 @@ def simplified_multiplication(left, right):
new_left = left / r_right
return new_left * r_left

# Simplify a * (b + c) to (a * b) + (a * c) if (a * b) or (a * c) is constant
# This is a common construction that appears from discretisation of spatial
# operators
# Also do this for cases like a * (b @ c + d) where (a * b) is constant
elif isinstance(right, Addition):
mul_classes = (
pybamm.Multiplication,
pybamm.MatrixMultiplication,
pybamm.Division,
)
if (
right.left.is_constant()
or right.right.is_constant()
or (isinstance(right.left, mul_classes) and right.left.left.is_constant())
or (isinstance(right.right, mul_classes) and right.right.left.is_constant())
):
r_left, r_right = right.orphans
return (left * r_left) + (left * r_right)

# Negation simplifications
if isinstance(left, pybamm.Negate) and right.is_constant():
# Simplify (-a) * b to a * (-b) if (-b) is constant
return left.orphans[0] * (-right)
elif isinstance(right, pybamm.Negate) and left.is_constant():
# Simplify a * (-b) to (-a) * b if (-a) is constant
return (-left) * right.orphans[0]

return pybamm.Multiplication(left, right)


Expand Down Expand Up @@ -1047,6 +1096,14 @@ def simplified_division(left, right):
if new_right.is_constant():
return l_left * new_right

# Negation simplifications
elif isinstance(left, pybamm.Negate) and right.is_constant():
# Simplify (-a) / b to a / (-b) if (-b) is constant
return left.orphans[0] / (-right)
elif isinstance(right, pybamm.Negate) and left.is_constant():
# Simplify a / (-b) to (-a) / b if (-a) is constant
return (-left) / right.orphans[0]

return pybamm.simplify_if_constant(pybamm.Division(left, right))


Expand Down
7 changes: 7 additions & 0 deletions pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ def __init__(
self.broadcast_domain = broadcast_domain
super().__init__(name, child, domain, auxiliary_domains)

@property
def broadcasts_to_nodes(self):
if self.broadcast_type.endswith("nodes"):
return True
else:
return False

def reduce_one_dimension(self):
"""
Reduce the broadcast by one dimension. See specific broadcast classes
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/test_expression_tree/test_binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,11 @@ def test_binary_simplifications(self):
self.assertIsInstance((c + a), pybamm.Parameter)
self.assertIsInstance((c + b), pybamm.Addition)
self.assertIsInstance((b + c), pybamm.Addition)
# rearranging additions
self.assertEqual(((c + 1) + 2).id, (c + 3).id)
self.assertEqual(((1 + c) + 2).id, (3 + c).id)
self.assertEqual((2 + (c + 1)).id, (3 + c).id)
self.assertEqual((2 + (1 + c)).id, (3 + c).id)
# addition with broadcast zero
self.assertIsInstance((b + broad0), pybamm.PrimaryBroadcast)
np.testing.assert_array_equal((b + broad0).child.evaluate(), 1)
Expand Down Expand Up @@ -495,6 +500,9 @@ def test_binary_simplifications(self):
# multiplication with -1
self.assertEqual((c * -1).id, (-c).id)
self.assertEqual((-1 * c).id, (-c).id)
# multiplication with a negation
self.assertEqual((-c * 4).id, (c * -4).id)
self.assertEqual((4 * -c).id, (-4 * c).id)
# multiplication with broadcasts
self.assertEqual((c * broad2).id, pybamm.PrimaryBroadcast(c * 2, "domain").id)
self.assertEqual((broad2 * c).id, pybamm.PrimaryBroadcast(2 * c, "domain").id)
Expand Down Expand Up @@ -522,6 +530,9 @@ def test_binary_simplifications(self):
# division by itself
self.assertEqual((c / c).id, pybamm.Scalar(1).id)
self.assertEqual((broad2 / broad2).id, broad1.id)
# division with a negation
self.assertEqual((-c / 4).id, (c / -4).id)
self.assertEqual((4 / -c).id, (-4 / c).id)
# division with broadcasts
self.assertEqual((c / broad2).id, pybamm.PrimaryBroadcast(c / 2, "domain").id)
self.assertEqual((broad2 / c).id, pybamm.PrimaryBroadcast(2 / c, "domain").id)
Expand Down

0 comments on commit 4a0b73d

Please sign in to comment.