diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index f759599b21..eb1bf0d226 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -680,14 +680,57 @@ def simplify_elementwise_binary_broadcasts(left, right): return left, right +def simplified_binary_broadcast_concatenation(left, right, operator): + """ + Check if there are concatenations or broadcasts that we can commute the operator + with + """ + # Broadcast commutes with elementwise operators + if isinstance(left, pybamm.Broadcast) and right.domain == []: + return left._unary_new_copy(operator(left.orphans[0], right)) + elif isinstance(right, pybamm.Broadcast) and left.domain == []: + return right._unary_new_copy(operator(left, right.orphans[0])) + + # Concatenation commutes with elementwise operators + # If one of the sides is constant then commute concatenation with the operator + # Don't do this if left has any Variable or StateVector objects as these will + # be simplified differently later on + if isinstance(left, pybamm.Concatenation) and not any( + isinstance(child, (pybamm.Variable, pybamm.StateVector)) + for child in left.children + ): + if right.evaluates_to_constant_number(): + return pybamm.concatenation( + *[operator(child, right) for child in left.orphans] + ) + elif ( + isinstance(right, pybamm.Concatenation) + and all(child.is_constant() for child in left.children) + and all(child.is_constant() for child in right.children) + ): + return pybamm.concatenation( + *[ + operator(left_child, right_child) + for left_child, right_child in zip(left.orphans, right.orphans) + ] + ) + if isinstance(right, pybamm.Concatenation) and not any( + isinstance(child, (pybamm.Variable, pybamm.StateVector)) + for child in right.children + ): + if left.evaluates_to_constant_number(): + return pybamm.concatenation( + *[operator(left, child) for child in right.orphans] + ) + + def simplified_power(left, right): left, right = simplify_elementwise_binary_broadcasts(left, right) - # Broadcast commutes with power operator - if isinstance(left, pybamm.Broadcast) and right.domain == []: - return left._unary_new_copy(left.orphans[0] ** right) - elif isinstance(right, pybamm.Broadcast) and left.domain == []: - return right._unary_new_copy(left ** right.orphans[0]) + # Check for Concatenations and Broadcasts + out = simplified_binary_broadcast_concatenation(left, right, simplified_power) + if out is not None: + return out # anything to the power of zero is one if pybamm.is_scalar_zero(right): @@ -733,11 +776,10 @@ def simplified_addition(left, right): """ left, right = simplify_elementwise_binary_broadcasts(left, right) - # Broadcast commutes with addition operator - if isinstance(left, pybamm.Broadcast) and right.domain == []: - return left._unary_new_copy(left.orphans[0] + right) - elif isinstance(right, pybamm.Broadcast) and left.domain == []: - return right._unary_new_copy(left + right.orphans[0]) + # Check for Concatenations and Broadcasts + out = simplified_binary_broadcast_concatenation(left, right, simplified_addition) + if out is not None: + return out # anything added by a scalar zero returns the other child elif pybamm.is_scalar_zero(left): @@ -828,11 +870,10 @@ def simplified_subtraction(left, right): """ left, right = simplify_elementwise_binary_broadcasts(left, right) - # Broadcast commutes with subtraction operator - if isinstance(left, pybamm.Broadcast) and right.domain == []: - return left._unary_new_copy(left.orphans[0] - right) - elif isinstance(right, pybamm.Broadcast) and left.domain == []: - return right._unary_new_copy(left - right.orphans[0]) + # Check for Concatenations and Broadcasts + out = simplified_binary_broadcast_concatenation(left, right, simplified_subtraction) + if out is not None: + return out # anything added by a scalar zero returns the other child if pybamm.is_scalar_zero(left): @@ -879,11 +920,12 @@ def simplified_subtraction(left, right): def simplified_multiplication(left, right): left, right = simplify_elementwise_binary_broadcasts(left, right) - # Broadcast commutes with multiplication operator - if isinstance(left, pybamm.Broadcast) and right.domain == []: - return left._unary_new_copy(left.orphans[0] * right) - elif isinstance(right, pybamm.Broadcast) and left.domain == []: - return right._unary_new_copy(left * right.orphans[0]) + # Check for Concatenations and Broadcasts + out = simplified_binary_broadcast_concatenation( + left, right, simplified_multiplication + ) + if out is not None: + return out # simplify multiply by scalar zero, being careful about shape if pybamm.is_scalar_zero(left): @@ -1048,11 +1090,10 @@ def simplified_multiplication(left, right): def simplified_division(left, right): left, right = simplify_elementwise_binary_broadcasts(left, right) - # Broadcast commutes with division operator - if isinstance(left, pybamm.Broadcast) and right.domain == []: - return left._unary_new_copy(left.orphans[0] / right) - elif isinstance(right, pybamm.Broadcast) and left.domain == []: - return right._unary_new_copy(left / right.orphans[0]) + # Check for Concatenations and Broadcasts + out = simplified_binary_broadcast_concatenation(left, right, simplified_division) + if out is not None: + return out # zero divided by anything returns zero (being careful about shape) if pybamm.is_scalar_zero(left): diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 67f63378c8..ccfd519617 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -588,6 +588,10 @@ def __neg__(self): # Move negation inside the broadcast # Apply recursively return self._unary_new_copy(-self.orphans[0]) + elif isinstance(self, pybamm.Concatenation) and all( + child.is_constant() for child in self.children + ): + return pybamm.concatenation(*[-child for child in self.orphans]) else: return pybamm.simplify_if_constant(pybamm.Negate(self)) diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index c132f17358..1fac526651 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -543,6 +543,29 @@ def test_binary_simplifications(self): with self.assertRaises(ZeroDivisionError): b / a + def test_binary_simplifications_concatenations(self): + def conc_broad(x, y, z): + return pybamm.concatenation( + pybamm.PrimaryBroadcast(x, "negative electrode"), + pybamm.PrimaryBroadcast(y, "separator"), + pybamm.PrimaryBroadcast(z, "positive electrode"), + ) + + # Test that concatenations get simplified correctly + a = conc_broad(1, 2, 3) + b = conc_broad(11, 12, 13) + self.assertEqual((a + 4).id, conc_broad(5, 6, 7).id) + self.assertEqual((4 + a).id, conc_broad(5, 6, 7).id) + self.assertEqual((a + b).id, conc_broad(12, 14, 16).id) + + # No simplifications if there are Variable or StateVector objects + v = pybamm.concatenation( + pybamm.Variable("x", "negative electrode"), + pybamm.Variable("y", "separator"), + pybamm.Variable("z", "positive electrode"), + ) + self.assertIsInstance((a * v), pybamm.Multiplication) + def test_advanced_binary_simplifications(self): # MatMul simplifications that often appear when discretising spatial operators A = pybamm.Matrix(np.random.rand(10, 10))