[Feature suggestion] fixup/support anonymous axes in parse_shape #302

Closed
rehno-lindeque opened this issue Jan 11, 2024 · 2 comments

rehno-lindeque commented Jan 11, 2024

It would be handy to support the second example below:

>>> parse_shape(torch.tensor([[0]]), "b n")
{'b': 1, 'n': 1}
>>> parse_shape(torch.tensor([[0]]), "b 1")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/nix/store/2x7c1xn335dvgia34gz2dqz6vy3a9pj0-python3-3.11.6-env/lib/python3.11/site-packages/einops/einops.py", line 690, in parse_shape
    for (axis_name,), axis_length in zip(composition, shape):  # type: ignore
        ^^^^^^^^^^^^
ValueError: not enough values to unpack (expected 1, got 0)
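The failure can be reproduced without einops or torch. The sketch below assumes that the parsed pattern is a list with one group of axis names per dimension, and that the literal `1` parses to an empty group; that is consistent with the "expected 1, got 0" message above but is not taken from the library source. The names `composition` and `unpack_named_axes` are hypothetical.

```python
# Hypothetical stand-in for the parsed pattern "b 1": one group of axis
# names per dimension. Assumption: the literal "1" yields an empty group.
composition = [["b"], []]
shape = (1, 1)


def unpack_named_axes(composition, shape):
    result = {}
    # Same unpacking form as the line shown in the traceback: every group
    # is expected to hold exactly one axis name.
    for (axis_name,), axis_length in zip(composition, shape):
        result[axis_name] = axis_length
    return result


try:
    unpack_named_axes(composition, shape)
except ValueError as error:
    # The empty group cannot be unpacked into a single name.
    message = str(error)
    print(message)
```

With only named axes, for example `[["b"], ["n"]]`, the same loop returns `{'b': 1, 'n': 1}`, which matches the first call above.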

Example use case:

import torch
from einops import parse_shape


def assert_shapes(*tensor_patterns, **dim_sizes):
    # Collect all shapes, including explicit dimension sizes
    shapes = [
        dict(**dim_sizes),
        *(parse_shape(tensor, pattern) for tensor, pattern in tensor_patterns),
    ]

    # Check for consistency in dimensions across all tensors
    for dim_name in set(dim for shape in shapes for dim in shape):
        sizes = set(shape.get(dim_name, None) for shape in shapes)
        sizes.discard(None)
        assert len(sizes) <= 1, f"Inconsistent size for dimension '{dim_name}': {sizes}"

tensor1 = torch.randn(3, 1, 64)
tensor2 = torch.randn(3, 2, 64)

# I want:
assert_shapes(
  (tensor1, 'batch 1 h'),
  (tensor2, 'batch 2 h'),
  h=64
)

# Instead of:
assert_shapes(
  (tensor1, 'batch c1 h'),
  (tensor2, 'batch c2 h'),
  h=64,
  c1=1,
  c2=2
)

I imagine it would have similar behavior to the ellipsis / underscore syntax:

>>> parse_shape(torch.tensor([[[1,2,3]]]), "b _ k")
{'b': 1, 'k': 3}
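
A minimal sketch of the behavior I have in mind, written against plain shape tuples so it runs without einops or torch. `parse_shape_sketch` is a hypothetical helper, not the einops implementation: it assumes whitespace-separated tokens, ignores ellipsis and parentheses, and assumes that a numeric token should be checked against the actual axis length.

```python
def parse_shape_sketch(shape, pattern):
    """Map the named axes in `pattern` to their lengths in `shape`.

    Hypothetical helper for illustration; `shape` is a plain tuple of ints.
    """
    tokens = pattern.split()
    if len(tokens) != len(shape):
        raise ValueError(f"pattern {pattern!r} does not match shape {shape}")

    result = {}
    for token, length in zip(tokens, shape):
        if token == "_":
            # Skipped axis: any length is accepted and nothing is recorded.
            continue
        if token.isdecimal():
            # Anonymous axis: the literal must equal the actual length.
            if int(token) != length:
                raise ValueError(
                    f"anonymous axis {token} does not match length {length}"
                )
            continue
        # Named axis: record its length.
        result[token] = length
    return result


shape1 = (3, 1, 64)  # same shape as tensor1 above
shape2 = (3, 2, 64)  # same shape as tensor2 above

parsed1 = parse_shape_sketch(shape1, "batch 1 h")
parsed2 = parse_shape_sketch(shape2, "batch 2 h")
skipped = parse_shape_sketch((1, 1, 3), "b _ k")
```

Here `parsed1` and `parsed2` both come out as `{'batch': 3, 'h': 64}`, so no placeholder names like `c1` and `c2` are needed, and a pattern such as `'batch 2 h'` applied to `shape1` raises instead of passing silently.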

Thanks!

arogozhnikov added a commit that referenced this issue Jan 11, 2024
Allow anonymous axes in parse_shape, fix #302
@arogozhnikov
Owner

Makes sense, please try current master:

pip install git+https://github.com/arogozhnikov/einops.git

PS: seems einops is nix-able now 😉

@rehno-lindeque
Author

Thank you, the ergonomics of einops just keeps getting better!

PS: seems einops is nix-able now 😉

Indeed, since 0.3.2 actually 😄
