Summary of this patch (free text before the first "diff --git" header):

* pytensor/tensor/einsum.py
  - Adds the helper `_right_to_left_path(n)`, which returns
    `tuple(pairwise(reversed(range(n))))`, e.g. ((4, 3), (3, 2), (2, 1), (1, 0))
    for n = 5, and uses it for the default contraction path.
  - After calling `np.einsum_path`, if the returned path is a single step
    that involves more than two operands, falls back to the right-to-left
    pairwise path and rebuilds `contraction_list` from it, because the
    contraction loop only accepts steps with 1 or 2 operands.
  - Includes the offending `path` in the "1 or 2 operands" ValueError message.
* tests/tensor/test_einsum.py
  - Adds `test_threeway_mul`, a regression test for
    https://github.com/pymc-devs/pytensor/issues/1184, parametrized over
    static and unknown operand lengths.

diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py
index 657d195ad6..cba40ec6f8 100644
--- a/pytensor/tensor/einsum.py
+++ b/pytensor/tensor/einsum.py
@@ -410,6 +410,12 @@ def _contraction_list_from_path(
     return contraction_list
 
 
+def _right_to_left_path(n: int) -> tuple[tuple[int, int], ...]:
+    # Create a right to left contraction path
+    # if n = 5, out = ((4, 3), (3, 2), (2, 1), (1, 0))
+    return tuple(pairwise(reversed(range(n))))
+
+
 def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVariable:
     """
     Multiplication and summation of tensors using the Einstein summation convention.
@@ -563,7 +569,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
         else:
             # By default, we try right to left because we assume that most graphs
             # have a lower dimensional rightmost operand
-            path = tuple(pairwise(reversed(range(len(tensor_operands)))))
+            path = _right_to_left_path(len(tensor_operands))
         contraction_list = _contraction_list_from_path(
             subscripts, tensor_operands, path
         )
@@ -581,7 +587,18 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
             einsum_call=True,  # Not part of public API
             optimize="optimal",
         )  # type: ignore
-        path = tuple(contraction[0] for contraction in contraction_list)
+        np_path = tuple(contraction[0] for contraction in contraction_list)
+
+        if len(np_path) == 1 and len(np_path[0]) > 2:
+            # When there's nothing to optimize, einsum_path reduces all entries simultaneously instead of doing
+            # pairwise reductions, which our implementation below demands.
+            path = _right_to_left_path(len(tensor_operands))
+            contraction_list = _contraction_list_from_path(
+                subscripts, tensor_operands, path
+            )
+        else:
+            path = np_path
+
         optimized = True
 
     def removechars(s, chars):
@@ -744,7 +761,7 @@ def filter_singleton_dims(operand, names, other_operand, other_names):
             )
         else:
             raise ValueError(
-                f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}"
+                f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}, {path=}."
             )
 
         # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
diff --git a/tests/tensor/test_einsum.py b/tests/tensor/test_einsum.py
index b359f050df..ba8e354518 100644
--- a/tests/tensor/test_einsum.py
+++ b/tests/tensor/test_einsum.py
@@ -262,3 +262,22 @@ def test_broadcastable_dims():
     atol = 1e-12 if config.floatX == "float64" else 1e-2
     np.testing.assert_allclose(suboptimal_eval, np_eval, atol=atol)
     np.testing.assert_allclose(optimal_eval, np_eval, atol=atol)
+
+
+@pytest.mark.parametrize("static_length", [False, True])
+def test_threeway_mul(static_length):
+    # Regression test for https://github.com/pymc-devs/pytensor/issues/1184
+    # x, y, z = vectors("x", "y", "z")
+    sh = (3,) if static_length else (None,)
+    x = tensor("x", shape=sh)
+    y = tensor("y", shape=sh)
+    z = tensor("z", shape=sh)
+    out = einsum("..., ..., ... -> ...", x, y, z)
+
+    x_test = np.ones((3,), dtype=x.dtype)
+    y_test = x_test + 1
+    z_test = x_test + 2
+    np.testing.assert_allclose(
+        out.eval({x: x_test, y: y_test, z: z_test}),
+        np.full((3,), fill_value=6),
+    )