Skip to content

Commit

Permalink
Fixes slicing of a batched MetaTensor
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli committed Jan 11, 2023
1 parent 7eddd2d commit f32306a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
17 changes: 13 additions & 4 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,20 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
# respectively. Don't need to do anything with the metadata.
if batch_idx not in (slice(None, None, None), Ellipsis, None) and idx == 0:
ret_meta = decollate_batch(args[0], detach=False)[batch_idx]
if isinstance(ret_meta, list): # e.g. batch[0:2], re-collate
ret_meta = list_data_collate(ret_meta)
else: # e.g. `batch[0]` or `batch[0, 1]`, batch index is an integer
if isinstance(ret_meta, list) and ret_meta: # e.g. batch[0:2], re-collate
try:
ret_meta = list_data_collate(ret_meta)
except (TypeError, ValueError, RuntimeError, IndexError) as e:
raise ValueError(
                        "Inconsistent metadata when slicing a MetaTensor, please convert it into "
"a torch Tensor using `x.as_tensor()` or into a numpy array using `x.array`."
) from e
elif isinstance(
ret_meta, MetaObj
): # e.g. `batch[0]` or `batch[0, 1]`, batch index is an integer
ret_meta.is_batch = False
ret.__dict__ = ret_meta.__dict__.copy()
if hasattr(ret_meta, "__dict__"):
ret.__dict__ = ret_meta.__dict__.copy()
# `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
# But we only want to split the batch if the `unbind` is along the 0th
# dimension.
Expand Down
7 changes: 7 additions & 0 deletions tests/test_meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,13 @@ def test_indexing(self):
for _d in d:
self.check_meta(_d, data)

def test_slicing(self):
    """Slicing a MetaTensor: an empty slice yields a zero-length leading dim;
    slicing while ``is_batch`` is set raises ValueError (metadata cannot be
    re-collated for this tensor — see the error path in ``update_meta``)."""
    tensor = MetaTensor(np.zeros((10, 3, 4)))
    # slice(4, 1) selects nothing, so the first dimension collapses to 0
    empty = tensor[slice(4, 1)]
    self.assertEqual(empty.shape[0], 0)
    # marking the tensor as a batch routes slicing through metadata
    # re-collation, which is expected to fail here with ValueError
    tensor.is_batch = True
    with self.assertRaises(ValueError):
        tensor[slice(0, 8)]

@parameterized.expand(DTYPES)
@SkipIfBeforePyTorchVersion((1, 8))
def test_decollate(self, dtype):
Expand Down

0 comments on commit f32306a

Please sign in to comment.