Slicing a MetaTensor: TypeError: only integer tensors of a single element can be converted to an index
#5844
Thanks for reporting, it's related to the …

>>> x = MetaTensor(np.random.rand(128, 1024, 8, 64))
>>> x.is_batch = True
>>> x[slice(0, 64, None), ...]
2023-01-11 15:45:21,510 - > collate/stack a list of tensors
Traceback (most recent call last):

It seems the error message I get is different from yours, @jeanmonet. Could you please share the full traceback if it's convenient?
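For context, the `TypeError` in the title is PyTorch's own message: it is raised when a tensor that is not a single-element integer tensor is converted to a Python integer index through the `__index__` protocol, for example when it is used to index a plain list. The sketch below reproduces the message in plain PyTorch, independent of MONAI; the helper `index_error_message` and the list `items` are hypothetical names used only for illustration.

```python
import operator

import torch


def index_error_message(index: torch.Tensor) -> str:
    """Hypothetical helper: return the TypeError message raised when `index`
    is converted to a Python int index, or "" if the conversion succeeds."""
    try:
        operator.index(index)  # calls torch.Tensor.__index__
    except TypeError as err:
        return str(err)
    return ""


# A plain Python list, standing in for any container that needs an int index.
items = ["a", "b", "c", "d"]

# A single-element integer tensor converts cleanly to an int index.
print(items[torch.tensor(1)])  # b

# A multi-element tensor cannot be converted to a single index.
print(index_error_message(torch.tensor([0, 1])))
```

This only shows where the message comes from; which value reaches such a conversion while a batched `MetaTensor` is sliced is the MONAI-specific part of the issue.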
Thanks for your fast reply. In the meantime, I can temporarily work around the issue by using PyTorch's DataLoader instead of MONAI's:

# Replace:
from monai.data import DataLoader
# With:
from torch.utils.data import DataLoader

import numpy as np
from monai.data import MetaTensor

x = MetaTensor(np.random.rand(128, 1024, 8, 64))
x.is_batch = True  # https://github.com/Project-MONAI/MONAI/issues/5844
print(x.shape, x.is_batch)
print(x[slice(0, 64, None), ...].shape)

Full traceback
However, the traceback from training (pointing to this line: https://github.com/pashtari/factorizer/blob/fa7f03dfe8c390f0993cbc0d161922d2de819a1d/factorizer/factorization/operations.py#L332) is:

Full traceback
Thanks, I still couldn't get the … You can still track the meta info using:

import factorizer as ft
import torch
from monai.data import MetaTensor
from monai.data.utils import collate_meta_tensor

# Build 128 MetaTensors and collate them into one batched MetaTensor.
x = [MetaTensor(torch.rand(1024, 8, 64)) for _ in range(128)]
x = collate_meta_tensor(x)
m = x.meta.copy()  # keep a copy of the metadata
sw_matricize = ft.SWMatricize((128, 1024, 8, 64), head_dim=8, patch_size=8)
x = x.as_tensor()  # work on a plain torch.Tensor, without metadata handling
y = sw_matricize(x)  # (262144, 8, 64)
z = sw_matricize.inverse_forward(y)
z.meta = m  # reattach the saved metadata
print(torch.equal(x, z))  # True

I'll make a PR to improve the error message about the usage.
Thank you for investigating and for the suggestion; it is well noted and very helpful.
Signed-off-by: Wenqi Li

Fixes #5844

- improves the error message when slicing data with inconsistent metadata
- fixes an indexing case:

```
x = MetaTensor(np.zeros((10, 3, 4)))
x[slice(1, 0)]  # should return zero length data
```

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
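The indexing case listed in the PR follows ordinary Python/PyTorch slicing rules: with a positive (default) step, a slice whose `stop` is not greater than its `start` selects nothing, so the result has length 0 along that dimension while the other dimensions are unchanged. A minimal sketch on a plain `torch.Tensor` showing the zero-length result the PR describes (an illustration only, not the PR's own test):

```python
import torch

x = torch.zeros((10, 3, 4))

# slice(1, 0): start=1, stop=0, step=None -> selects no rows along dim 0.
empty = x[slice(1, 0)]
print(empty.shape)  # torch.Size([0, 3, 4])

# Python lists follow the same rule: the slice is empty rather than an error.
print(list(range(10))[slice(1, 0)])  # []
```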
The goal is to slice the tensor in half. For some reason that I cannot comprehend, it doesn't work on a particular tensor:
What could be going on here?
If it is of any use, the error occurs at this specific line: https://github.com/pashtari/factorizer/blob/fa7f03dfe8c390f0993cbc0d161922d2de819a1d/factorizer/factorization/operations.py#L331
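For the stated goal of slicing the tensor in half along the first dimension, here is a minimal sketch on a plain `torch.Tensor` (for example the result of `MetaTensor.as_tensor()`, which avoids the metadata handling during indexing). It assumes `torch.tensor_split`, which always returns two pieces for `sections=2` and gives the first piece the extra element when the size is odd; the helper name `split_in_half` and the small stand-in shape are hypothetical, not taken from factorizer.

```python
from typing import Tuple

import torch


def split_in_half(t: torch.Tensor, dim: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: split `t` into two halves along `dim`.

    torch.tensor_split with sections=2 always returns two views;
    for an odd size the first view gets the extra element.
    """
    first, second = torch.tensor_split(t, 2, dim=dim)
    return first, second


# Small stand-in for the (128, 1024, 8, 64) tensor in the issue.
x = torch.rand(128, 4, 2, 3)
a, b = split_in_half(x)
print(a.shape, b.shape)  # torch.Size([64, 4, 2, 3]) torch.Size([64, 4, 2, 3])

# For an even size this matches basic slicing at the midpoint.
mid = x.shape[0] // 2
print(torch.equal(a, x[:mid]) and torch.equal(b, x[mid:]))  # True
```

`tensor_split` is used instead of `torch.chunk` because `chunk` may return fewer than two pieces (for example for a size-1 dimension), whereas `tensor_split` always returns exactly two.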
To Reproduce
How is it possible that a specific tensor would cause such an error?
Environment