Skip to content

Commit

Permalink
#1429 more hacky test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Mar 14, 2021
1 parent 9aad14a commit 397736c
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
12 changes: 8 additions & 4 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,11 +922,14 @@ def simplified_multiplication(left, right):
):
l_left, l_right = left.orphans
new_left = right * l_left
# Special hack for the case where l_left is a matrix one
# because of weird domain errors otherwise
if new_left == right and isinstance(right, pybamm.Array):
new_left = right.new_copy()
# be careful about domains to avoid weird errors
new_left.clear_domains()
new_mul = new_left @ l_right
# Keep the domain of the old left
new_left.copy_domains(left)
new_mul.copy_domains(left)
return new_mul

Expand Down Expand Up @@ -956,11 +959,14 @@ def simplified_multiplication(left, right):
):
r_left, r_right = right.orphans
new_left = left * r_left
# Special hack for the case where r_left is a matrix one
# because of weird domain errors otherwise
if new_left == left and isinstance(left, pybamm.Array):
new_left = left.new_copy()
# be careful about domains to avoid weird errors
new_left.clear_domains()
new_mul = new_left @ r_right
# Keep the domain of the old right
new_left.copy_domains(left)
new_mul.copy_domains(right)
return new_mul

Expand Down Expand Up @@ -1024,7 +1030,6 @@ def simplified_division(left, right):
new_left.clear_domains()
new_division = new_left @ l_right
# Keep the domain of the old left
new_left.copy_domains(left)
new_division.copy_domains(left)
return new_division

Expand Down Expand Up @@ -1082,7 +1087,6 @@ def simplified_matrix_multiplication(left, right):
new_left.clear_domains()
new_mul = new_left @ r_right
# Keep the domain of the old right
new_left.copy_domains(left)
new_mul.copy_domains(right)
return new_mul

Expand Down
5 changes: 4 additions & 1 deletion pybamm/expression_tree/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
import pybamm
import numpy as np
from scipy.sparse import issparse
from scipy.sparse import issparse, csr_matrix


class Matrix(pybamm.Array):
Expand All @@ -27,4 +27,7 @@ def __init__(
name = "Matrix {!s}".format(entries.shape)
if issparse(entries):
name = "Sparse " + name
# Convert all sparse matrices to csr
if issparse(entries) and not isinstance(entries, csr_matrix):
entries = csr_matrix(entries)
super().__init__(entries, name, domain, auxiliary_domains, entries_string)
2 changes: 1 addition & 1 deletion tests/unit/test_expression_tree/test_unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def test_evaluates_on_edges(self):
a = pybamm.StateVector(slice(0, 10), domain="test")
self.assertFalse(pybamm.Index(a, slice(1)).evaluates_on_edges("primary"))
self.assertFalse(pybamm.Laplacian(a).evaluates_on_edges("primary"))
self.assertTrue(pybamm.GradientSquared(a).evaluates_on_edges("primary"))
self.assertFalse(pybamm.GradientSquared(a).evaluates_on_edges("primary"))
self.assertFalse(pybamm.BoundaryIntegral(a).evaluates_on_edges("primary"))
self.assertTrue(pybamm.Upwind(a).evaluates_on_edges("primary"))
self.assertTrue(pybamm.Downwind(a).evaluates_on_edges("primary"))
Expand Down

0 comments on commit 397736c

Please sign in to comment.