diff --git a/einops/einops.py b/einops/einops.py index f4ae207a..2b20f481 100644 --- a/einops/einops.py +++ b/einops/einops.py @@ -687,9 +687,19 @@ def parse_shape(x, pattern: str) -> dict: else: composition = exp.composition result = {} - for (axis_name,), axis_length in zip(composition, shape): # type: ignore - if axis_name != "_": - result[axis_name] = axis_length + for axes, axis_length in zip(composition, shape): # type: ignore + # axes either [], or [AnonymousAxis] or ['axis_name'] + if len(axes) == 0: + if axis_length != 1: + raise RuntimeError(f"Length of axis is not 1: {pattern} {shape}") + else: + [axis] = axes + if isinstance(axis, str): + if axis != "_": + result[axis] = axis_length + else: + if axis.value != axis_length: + raise RuntimeError(f"Length of anonymous axis does not match: {pattern} {shape}") return result diff --git a/tests/test_other.py b/tests/test_other.py index 87b08157..313d5d2b 100644 --- a/tests/test_other.py +++ b/tests/test_other.py @@ -121,8 +121,9 @@ def test_repeating(self): with pytest.raises(einops.EinopsError): parse_shape(self.backend.from_numpy(self.x), "a a b b") - @parameterized.expand( - [ + + def test_ellipsis(self): + for shape, pattern, expected in [ ([10, 20], "...", dict()), ([10], "... a", dict(a=10)), ([10, 20], "... a", dict(a=20)), @@ -134,13 +135,37 @@ def test_repeating(self): ([10, 20, 30, 40], "a ...", dict(a=10)), ([10, 20, 30, 40], " a ... b", dict(a=10, b=40)), ([10, 40], " a ... b", dict(a=10, b=40)), - ] - ) - def test_ellipsis(self, shape: List[int], pattern: str, expected: Dict[str, int]): - x = numpy.ones(shape) - parsed1 = parse_shape(x, pattern) - parsed2 = parse_shape(self.backend.from_numpy(x), pattern) - assert parsed1 == parsed2 == expected + ]: + x = numpy.ones(shape) + parsed1 = parse_shape(x, pattern) + parsed2 = parse_shape(self.backend.from_numpy(x), pattern) + assert parsed1 == parsed2 == expected + + def test_parse_with_anonymous_axes(self): + for shape, pattern, expected in [ + ([1, 2, 3, 4], "1 2 3 a", dict(a=4)), + ([10, 1, 2], "a 1 2", dict(a=10)), + ([10, 1, 2], "a () 2", dict(a=10)), + ]: + x = numpy.ones(shape) + parsed1 = parse_shape(x, pattern) + parsed2 = parse_shape(self.backend.from_numpy(x), pattern) + assert parsed1 == parsed2 == expected + + + def test_failures(self): + # every test should fail + for shape, pattern in [ + ([1, 2, 3, 4], "a b c"), + ([1, 2, 3, 4], "2 a b c"), + ([1, 2, 3, 4], "a b c ()"), + ([1, 2, 3, 4], "a b c d e"), + ([1, 2, 3, 4], "a b c d e ..."), + ([1, 2, 3, 4], "a b c ()"), + ]: + with pytest.raises(RuntimeError): + x = numpy.ones(shape) + parse_shape(self.backend.from_numpy(x), pattern) _SYMBOLIC_BACKENDS = [