Skip to content

Commit

Permalink
Fix bug in einsum
Browse files Browse the repository at this point in the history
A shortcut in the numpy implementation of einsum_path when there's nothing to optimize, creates a default path that can combine more than 2 operands. Our implementation only works with 2 or 1 operand operations at each step.

https://github.com/numpy/numpy/blob/cc5851e654bfd82a23f2758be4bd224be84fc1c3/numpy/_core/einsumfunc.py#L945-L951
  • Loading branch information
ricardoV94 committed Feb 3, 2025
1 parent 8bb2038 commit c22e79e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
23 changes: 20 additions & 3 deletions pytensor/tensor/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,12 @@ def _contraction_list_from_path(
return contraction_list


def _right_to_left_path(n: int) -> tuple[tuple[int, int], ...]:
# Create a right to left contraction path
# if n = 5, out = ((4, 3), (3, 2), (2, 1), (1, 0))
return tuple(pairwise(reversed(range(n))))


def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVariable:
"""
Multiplication and summation of tensors using the Einstein summation convention.
Expand Down Expand Up @@ -563,7 +569,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
else:
# By default, we try right to left because we assume that most graphs
# have a lower dimensional rightmost operand
path = tuple(pairwise(reversed(range(len(tensor_operands)))))
path = _right_to_left_path(len(tensor_operands))
contraction_list = _contraction_list_from_path(
subscripts, tensor_operands, path
)
Expand All @@ -581,7 +587,18 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
einsum_call=True, # Not part of public API
optimize="optimal",
) # type: ignore
path = tuple(contraction[0] for contraction in contraction_list)
np_path = tuple(contraction[0] for contraction in contraction_list)

if len(np_path) == 1 and len(np_path[0]) > 2:
# When there's nothing to optimize, einsum_path reduces all entries simultaneously instead of doing
# pairwise reductions, which our implementation below demands.
path = _right_to_left_path(len(tensor_operands))
contraction_list = _contraction_list_from_path(
subscripts, tensor_operands, path
)
else:
path = np_path

optimized = True

def removechars(s, chars):
Expand Down Expand Up @@ -744,7 +761,7 @@ def filter_singleton_dims(operand, names, other_operand, other_names):
)
else:
raise ValueError(
f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}"
f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}, {path=}."
)

# the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
Expand Down
19 changes: 19 additions & 0 deletions tests/tensor/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,22 @@ def test_broadcastable_dims():
atol = 1e-12 if config.floatX == "float64" else 1e-2
np.testing.assert_allclose(suboptimal_eval, np_eval, atol=atol)
np.testing.assert_allclose(optimal_eval, np_eval, atol=atol)


@pytest.mark.parametrize("static_length", [False, True])
def test_threeway_mul(static_length):
# Regression test for https://github.com/pymc-devs/pytensor/issues/1184
# x, y, z = vectors("x", "y", "z")
sh = (3,) if static_length else (None,)
x = tensor("x", shape=sh)
y = tensor("y", shape=sh)
z = tensor("z", shape=sh)
out = einsum("..., ..., ... -> ...", x, y, z)

x_test = np.ones((3,), dtype=x.dtype)
y_test = x_test + 1
z_test = x_test + 2
np.testing.assert_allclose(
out.eval({x: x_test, y: y_test, z: z_test}),
np.full((3,), fill_value=6),
)

0 comments on commit c22e79e

Please sign in to comment.