diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 0e213f130b..e38e009e96 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -111,6 +111,7 @@ class MetaObj: def __init__(self): self._meta: dict = self.get_default_meta() + self._is_batch: bool = False @staticmethod def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]: @@ -176,6 +177,7 @@ def _copy_meta(self, input_objs: list[MetaObj]) -> None: id_in = id(input_objs[0]) if len(input_objs) > 0 else None deep_copy = id(self) != id_in self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy) + self.is_batch = input_objs[0].is_batch if len(input_objs) > 0 else False def get_default_meta(self) -> dict: """Get the default meta. @@ -194,6 +196,7 @@ def __repr__(self) -> str: out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items()) else: out += "None" + out += f"\nIs batch?: {self.is_batch}" return out @@ -206,3 +209,13 @@ def meta(self) -> dict: def meta(self, d: dict) -> None: """Set the meta.""" self._meta = d + + @property + def is_batch(self) -> bool: + """Return whether object is part of batch or not.""" + return self._is_batch + + @is_batch.setter + def is_batch(self, val: bool) -> None: + """Set whether object is part of batch or not.""" + self._is_batch = val diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index ba80f93e74..9196f0186c 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -13,11 +13,12 @@ import warnings from copy import deepcopy -from typing import Callable +from typing import Any, Callable, Sequence import torch from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms +from monai.data.utils import decollate_batch, list_data_collate from monai.utils.enums import PostFix __all__ = ["MetaTensor"] @@ -54,11 +55,20 @@ class MetaTensor(MetaObj, torch.Tensor): assert m2.affine == affine Notes: + - Requires pytorch 1.9 or newer for full compatibility. - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may not work if `im` is of type `MetaTensor`. This can be resolved with `torch.jit.trace(net, im.as_tensor())`. - A warning will be raised if in the constructor `affine` is not `None` and `meta` already contains the key `affine`. + - You can query whether the `MetaTensor` is a batch with the `is_batch` attribute. + - With a batch of data, `batch[0]` will return the 0th image + with the 0th metadata. When the batch dimension is non-singleton, e.g., + `batch[:, 0]`, `batch[..., -1]` and `batch[1:3]`, then all (or a subset in the + last example) of the metadata will be returned, and `is_batch` will return `True`. + - When creating a batch with this class, use `monai.data.DataLoader` as opposed + to `torch.utils.data.DataLoader`, as this will take care of collating the + metadata properly. """ @staticmethod @@ -101,21 +111,93 @@ def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Call if isinstance(val, torch.Tensor): setattr(self, attribute, val.to(self.device)) + @staticmethod + def update_meta(rets: Sequence, func, args, kwargs): + """Update the metadata from the output of `__torch_function__`. + The output could be a single object, or a sequence of them. Hence, they get + converted to a sequence if necessary and then processed by iterating across them. + + For each element, if not of type `MetaTensor`, then nothing to do + """ + out = [] + metas = None + for idx, ret in enumerate(rets): + # if not `MetaTensor`, nothing to do. + if not isinstance(ret, MetaTensor): + pass + # if not tracking, convert to `torch.Tensor`. + elif not (get_track_meta() or get_track_transforms()): + ret = ret.as_tensor() + # else, handle the `MetaTensor` metadata. + else: + meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values())) + ret._copy_meta(meta_args) + + # If we have a batch of data, then we need to be careful if a slice of + # the data is returned. Depending on how the data are indexed, we return + # some or all of the metadata, and the return object may or may not be a + # batch of data (e.g., `batch[:,-1]` versus `batch[0]`). + if ret.is_batch: + # only decollate metadata once + if metas is None: + metas = decollate_batch(ret.meta) + # if indexing e.g., `batch[0]` + if func == torch.Tensor.__getitem__: + idx = args[1] + if isinstance(idx, Sequence): + idx = idx[0] + # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the + # first element will be `slice(None, None, None)` and `Ellipsis`, + # respectively. Don't need to do anything with the metadata. + if idx not in (slice(None, None, None), Ellipsis): + meta = metas[idx] + # if using e.g., `batch[0:2]`, then `is_batch` should still be + # `True`. Also re-collate the remaining elements. + if isinstance(meta, list) and len(meta) > 1: + ret.meta = list_data_collate(meta) + # if using e.g., `batch[0]` or `batch[0, 1]`, then return single + # element from batch, and set `is_batch` to `False`. + else: + ret.meta = meta + ret.is_batch = False + # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`. + # But we only want to split the batch if the `unbind` is along the 0th + # dimension. + elif func == torch.Tensor.unbind: + if len(args) > 1: + dim = args[1] + elif "dim" in kwargs: + dim = kwargs["dim"] + else: + dim = 0 + if dim == 0: + ret.meta = metas[idx] + ret.is_batch = False + + ret.affine = ret.affine.to(ret.device) + out.append(ret) + # if the input was a tuple, then return it as a tuple + return tuple(out) if isinstance(rets, tuple) else out + @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor: + def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any: """Wraps all torch functions.""" if kwargs is None: kwargs = {} - ret: MetaTensor = super().__torch_function__(func, types, args, kwargs) - # e.g., __repr__ returns a string - if not isinstance(ret, torch.Tensor): + ret = super().__torch_function__(func, types, args, kwargs) + # if `out` has been used as argument, metadata is not copied, nothing to do. + if "out" in kwargs: return ret - if not (get_track_meta() or get_track_transforms()): - return ret.as_tensor() - meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values())) - ret._copy_meta(meta_args) - ret.affine = ret.affine.to(ret.device) - return ret + # we might have 1 or multiple outputs. Might be MetaTensor, might be something + # else (e.g., `__repr__` returns a string). + # Convert to list (if necessary), process, and at end remove list if one was added. + if not isinstance(ret, Sequence): + ret = [ret] + unpack = True + else: + unpack = False + ret = MetaTensor.update_meta(ret, func, args, kwargs) + return ret[0] if unpack else ret def get_default_affine(self, dtype=torch.float64) -> torch.Tensor: return torch.eye(4, device=self.device, dtype=dtype) diff --git a/monai/data/utils.py b/monai/data/utils.py index 495daf15e2..2bd7b49731 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -28,6 +28,7 @@ from torch.utils.data._utils.collate import default_collate from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike +from monai.data.meta_obj import MetaObj from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( MAX_SEED, @@ -346,9 +347,17 @@ def list_data_collate(batch: Sequence): ret = {} for k in elem: key = k - ret[key] = default_collate([d[key] for d in data]) - return ret - return default_collate(data) + data_for_batch = [d[key] for d in data] + ret[key] = default_collate(data_for_batch) + if isinstance(ret[key], MetaObj) and all(isinstance(d, MetaObj) for d in data_for_batch): + ret[key].meta = list_data_collate([i.meta for i in data_for_batch]) + ret[key].is_batch = True + else: + ret = default_collate(data) + if isinstance(ret, MetaObj) and all(isinstance(d, MetaObj) for d in data): + ret.meta = list_data_collate([i.meta for i in data]) + ret.is_batch = True + return ret except RuntimeError as re: re_str = str(re) if "equal size" in re_str: @@ -466,6 +475,12 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): if batch.ndim == 0: return batch.item() if detach else batch out_list = torch.unbind(batch, dim=0) + # if of type MetaObj, decollate the metadata + if isinstance(batch, MetaObj) and all(isinstance(i, MetaObj) for i in out_list): + metas = decollate_batch(batch.meta) + for i in range(len(out_list)): + out_list[i].meta = metas[i] # type: ignore + out_list[i].is_batch = False # type: ignore if out_list[0].ndim == 0 and detach: return [t.item() for t in out_list] return list(out_list) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index c18ef08b85..05356fcc84 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -21,11 +21,13 @@ import torch from parameterized import parameterized +from monai.data import DataLoader, Dataset from monai.data.meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms from monai.data.meta_tensor import MetaTensor +from monai.data.utils import decollate_batch, list_data_collate from monai.utils.enums import PostFix from monai.utils.module import pytorch_after -from tests.utils import TEST_DEVICES, assert_allclose, skip_if_no_cuda +from tests.utils import TEST_DEVICES, SkipIfBeforePyTorchVersion, assert_allclose, skip_if_no_cuda DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32]] TESTS = [] @@ -59,6 +61,17 @@ def check_ids(self, a, b, should_match): comp = self.assertEqual if should_match else self.assertNotEqual comp(id(a), id(b)) + def check_meta(self, a: MetaTensor, b: MetaTensor) -> None: + self.assertEqual(a.is_batch, b.is_batch) + meta_a, meta_b = a.meta, b.meta + # need to split affine from rest of metadata + aff_a = meta_a.get("affine", None) + aff_b = meta_b.get("affine", None) + assert_allclose(aff_a, aff_b) + meta_a = {k: v for k, v in meta_a.items() if k != "affine"} + meta_b = {k: v for k, v in meta_b.items() if k != "affine"} + self.assertEqual(meta_a, meta_b) + def check( self, out: torch.Tensor, @@ -87,12 +100,7 @@ def check( # check meta and affine are equal and affine is on correct device if isinstance(orig, MetaTensor) and isinstance(out, MetaTensor) and meta: - orig_meta_no_affine = deepcopy(orig.meta) - del orig_meta_no_affine["affine"] - out_meta_no_affine = deepcopy(out.meta) - del out_meta_no_affine["affine"] - self.assertEqual(orig_meta_no_affine, out_meta_no_affine) - assert_allclose(out.affine, orig.affine) + self.check_meta(orig, out) self.assertTrue(str(device) in str(out.affine.device)) if check_ids: self.check_ids(out.affine, orig.affine, ids) @@ -261,12 +269,146 @@ def test_amp(self): im_conv2 = conv(im) self.check(im_conv2, im_conv, ids=False, rtol=1e-4, atol=1e-3) - # TODO - # collate - # decollate - # dataset - # dataloader - # matplotlib + def test_out(self): + """Test when `out` is given as an argument.""" + m1, _ = self.get_im() + m1_orig = deepcopy(m1) + m2, _ = self.get_im() + m3, _ = self.get_im() + torch.add(m2, m3, out=m1) + m1_add = m2 + m3 + + assert_allclose(m1, m1_add) + self.check_meta(m1, m1_orig) + + @parameterized.expand(TESTS) + def test_collate(self, device, dtype): + numel = 3 + ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(numel)] + collated = list_data_collate(ims) + # tensor + self.assertIsInstance(collated, MetaTensor) + expected_shape = (numel,) + tuple(ims[0].shape) + self.assertTupleEqual(tuple(collated.shape), expected_shape) + for i, im in enumerate(ims): + self.check(im, ims[i], ids=True) + # affine + self.assertIsInstance(collated.affine, torch.Tensor) + expected_shape = (numel,) + tuple(ims[0].affine.shape) + self.assertTupleEqual(tuple(collated.affine.shape), expected_shape) + + @parameterized.expand(TESTS) + def test_dataset(self, device, dtype): + ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(4)] + ds = Dataset(ims) + for i, im in enumerate(ds): + self.check(im, ims[i], ids=True) + + @parameterized.expand(DTYPES) + @SkipIfBeforePyTorchVersion((1, 8)) + def test_dataloader(self, dtype): + batch_size = 5 + ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)] + ds = Dataset(ims) + im_shape = tuple(ims[0].shape) + affine_shape = tuple(ims[0].affine.shape) + expected_im_shape = (batch_size,) + im_shape + expected_affine_shape = (batch_size,) + affine_shape + dl = DataLoader(ds, num_workers=batch_size, batch_size=batch_size) + for batch in dl: + self.assertIsInstance(batch, MetaTensor) + self.assertTupleEqual(tuple(batch.shape), expected_im_shape) + self.assertTupleEqual(tuple(batch.affine.shape), expected_affine_shape) + + @SkipIfBeforePyTorchVersion((1, 9)) + def test_indexing(self): + """ + Check the metadata is returned in the expected format depending on whether + the input `MetaTensor` is a batch of data or not. + """ + ims = [self.get_im()[0] for _ in range(5)] + data = list_data_collate(ims) + + # check that when using non-batch data, metadata is copied wholly when indexing + # or iterating across data. + im = ims[0] + self.check_meta(im[0], im) + self.check_meta(next(iter(im)), im) + + # index + d = data[0] + self.check(d, ims[0], ids=False) + + # iter + d = next(iter(data)) + self.check(d, ims[0], ids=False) + + # complex indexing + + # `is_batch==True`, should have subset of image and metadata. + d = data[1:3] + self.check(d, list_data_collate(ims[1:3]), ids=False) + + # is_batch==True, should have subset of image and same metadata as `[1:3]`. + d = data[1:3, 0] + self.check(d, list_data_collate([i[0] for i in ims[1:3]]), ids=False) + + # `is_batch==False`, should have first metadata and subset of first image. + d = data[0, 0] + self.check(d, ims[0][0], ids=False) + + # `is_batch==True`, should have all metadata and subset of all images. + d = data[:, 0] + self.check(d, list_data_collate([i[0] for i in ims]), ids=False) + + # `is_batch==True`, should have all metadata and subset of all images. + d = data[..., -1] + self.check(d, list_data_collate([i[..., -1] for i in ims]), ids=False) + + # `is_batch==False`, tuple split along batch dim. Should have individual + # metadata. + d = data.unbind(0) + self.assertIsInstance(d, tuple) + self.assertEqual(len(d), len(ims)) + for _d, _im in zip(d, ims): + self.check(_d, _im, ids=False) + + # `is_batch==False`, tuple split along batch dim. Should have individual + # metadata. + d = data.unbind(dim=0) + self.assertIsInstance(d, tuple) + self.assertEqual(len(d), len(ims)) + for _d, _im in zip(d, ims): + self.check(_d, _im, ids=False) + + # `is_batch==True`, tuple split along non-batch dim. Should have all metadata. + d = data.unbind(-1) + self.assertIsInstance(d, tuple) + self.assertEqual(len(d), ims[0].shape[-1]) + for _d in d: + self.check_meta(_d, data) + + # `is_batch==True`, tuple split along non-batch dim. Should have all metadata. + d = data.unbind(dim=-1) + self.assertIsInstance(d, tuple) + self.assertEqual(len(d), ims[0].shape[-1]) + for _d in d: + self.check_meta(_d, data) + + @parameterized.expand(DTYPES) + @SkipIfBeforePyTorchVersion((1, 8)) + def test_decollate(self, dtype): + batch_size = 3 + ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)] + ds = Dataset(ims) + dl = DataLoader(ds, num_workers=batch_size, batch_size=batch_size) + batch = next(iter(dl)) + decollated = decollate_batch(batch) + self.assertIsInstance(decollated, list) + self.assertEqual(len(decollated), batch_size) + for elem, im in zip(decollated, ims): + self.assertIsInstance(elem, MetaTensor) + self.check(elem, im, ids=False) if __name__ == "__main__":