Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MetaTensor: collate; decollate; dataset; dataloader; out=; indexing and iterating across batches #4137

Merged
merged 13 commits into from
Apr 20, 2022
3 changes: 3 additions & 0 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor:
if kwargs is None:
kwargs = {}
ret: MetaTensor = super().__torch_function__(func, types, args, kwargs)
# if `out` has been used as argument, metadata is not copied, nothing to do.
if "out" in kwargs:
return ret
# e.g., __repr__ returns a string
if not isinstance(ret, torch.Tensor):
return ret
Expand Down
18 changes: 15 additions & 3 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from torch.utils.data._utils.collate import default_collate

from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
from monai.data.meta_tensor import MetaTensor
from monai.networks.layers.simplelayers import GaussianFilter
from monai.utils import (
MAX_SEED,
Expand Down Expand Up @@ -346,9 +347,15 @@ def list_data_collate(batch: Sequence):
ret = {}
for k in elem:
key = k
ret[key] = default_collate([d[key] for d in data])
return ret
return default_collate(data)
data_for_batch = [d[key] for d in data]
ret[key] = default_collate(data_for_batch)
if isinstance(ret[key], MetaTensor) and all(isinstance(d, MetaTensor) for d in data_for_batch):
rijobro marked this conversation as resolved.
Show resolved Hide resolved
ret[key].meta = list_data_collate([i.meta for i in data_for_batch])
else:
ret = default_collate(data)
if isinstance(ret, MetaTensor) and all(isinstance(d, MetaTensor) for d in data):
ret.meta = list_data_collate([i.meta for i in data])
return ret
except RuntimeError as re:
re_str = str(re)
if "equal size" in re_str:
Expand Down Expand Up @@ -466,6 +473,11 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
if batch.ndim == 0:
return batch.item() if detach else batch
out_list = torch.unbind(batch, dim=0)
# if of type MetaTensor, decollate the metadata
if isinstance(batch, MetaTensor):
metas = decollate_batch(batch.meta)
for i in range(len(out_list)):
out_list[i].meta = metas[i] # type: ignore
if out_list[0].ndim == 0 and detach:
return [t.item() for t in out_list]
return list(out_list)
Expand Down
77 changes: 70 additions & 7 deletions tests/test_meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
import torch
from parameterized import parameterized

from monai.data import DataLoader, Dataset
from monai.data.meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import decollate_batch, list_data_collate
from monai.utils.enums import PostFix
from monai.utils.module import pytorch_after
from tests.utils import TEST_DEVICES, assert_allclose, skip_if_no_cuda
from tests.utils import TEST_DEVICES, SkipIfBeforePyTorchVersion, assert_allclose, skip_if_no_cuda

DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32]]
TESTS = []
Expand Down Expand Up @@ -261,12 +263,73 @@ def test_amp(self):
im_conv2 = conv(im)
self.check(im_conv2, im_conv, ids=False, rtol=1e-4, atol=1e-3)

# TODO
# collate
# decollate
# dataset
# dataloader
# matplotlib
def test_out(self):
    """`out=` writes the result in-place; the destination's metadata is untouched."""
    dest, _ = self.get_im()
    dest_before = deepcopy(dest)
    a, _ = self.get_im()
    b, _ = self.get_im()
    torch.add(a, b, out=dest)
    expected_sum = a + b

    # the data was overwritten with the sum
    assert_allclose(dest, expected_sum)
    # the affine survived the in-place op
    assert_allclose(dest.affine, dest_before.affine)
    # ... and so did the rest of the metadata
    remaining = {key: val for key, val in dest.meta.items() if key != "affine"}
    remaining_before = {key: val for key, val in dest_before.meta.items() if key != "affine"}
    self.assertEqual(remaining, remaining_before)

@parameterized.expand(TESTS)
def test_collate(self, device, dtype):
    """Collating MetaTensors stacks data and affines along a new leading batch dim.

    Also verifies that collation leaves the input tensors unchanged.
    """
    numel = 3
    ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(numel)]
    # snapshot the inputs so we can check collation does not mutate them
    # (the previous loop compared each element with itself, asserting nothing)
    ims_orig = [deepcopy(im) for im in ims]
    collated = list_data_collate(ims)
    # tensor: a single MetaTensor with a leading batch dimension
    self.assertIsInstance(collated, MetaTensor)
    expected_shape = (numel,) + tuple(ims[0].shape)
    self.assertTupleEqual(tuple(collated.shape), expected_shape)
    # inputs are value-equal to their pre-collation copies
    for im, im_orig in zip(ims, ims_orig):
        self.check(im, im_orig, ids=False)
    # affine: also stacked with a leading batch dimension
    self.assertIsInstance(collated.affine, torch.Tensor)
    expected_shape = (numel,) + tuple(ims[0].affine.shape)
    self.assertTupleEqual(tuple(collated.affine.shape), expected_shape)

@parameterized.expand(TESTS)
def test_dataset(self, device, dtype):
    """Iterating a Dataset of MetaTensors yields the original elements."""
    num_samples = 4
    images = [self.get_im(device=device, dtype=dtype)[0] for _ in range(num_samples)]
    dataset = Dataset(images)
    for idx, sample in enumerate(dataset):
        self.check(sample, images[idx], ids=True)

@parameterized.expand(DTYPES)
@SkipIfBeforePyTorchVersion((1, 8))
def test_dataloader(self, dtype):
    """Each DataLoader batch is a MetaTensor with batched data and affines."""
    batch_size = 5
    images = [self.get_im(dtype=dtype)[0] for _ in range(2 * batch_size)]
    dataset = Dataset(images)
    # expected shapes gain a leading batch dimension
    im_shape = (batch_size, *images[0].shape)
    affine_shape = (batch_size, *images[0].affine.shape)
    loader = DataLoader(dataset, num_workers=batch_size, batch_size=batch_size)
    for batch in loader:
        self.assertIsInstance(batch, MetaTensor)
        self.assertTupleEqual(tuple(batch.shape), im_shape)
        self.assertTupleEqual(tuple(batch.affine.shape), affine_shape)

@parameterized.expand(DTYPES)
@SkipIfBeforePyTorchVersion((1, 8))
def test_decollate(self, dtype):
    """Decollating a batched MetaTensor recovers the individual images."""
    batch_size = 3
    images = [self.get_im(dtype=dtype)[0] for _ in range(2 * batch_size)]
    loader = DataLoader(Dataset(images), num_workers=batch_size, batch_size=batch_size)
    first_batch = next(iter(loader))
    unbatched = decollate_batch(first_batch)
    self.assertIsInstance(unbatched, list)
    self.assertEqual(len(unbatched), batch_size)
    # zip truncates to the first batch's worth of images
    for element, original in zip(unbatched, images):
        self.assertIsInstance(element, MetaTensor)
        self.check(element, original, ids=False)


if __name__ == "__main__":
Expand Down