Skip to content

Commit

Permalink
MetaTensor: collate; decollate; dataset; dataloader; out=; indexing…
Browse files Browse the repository at this point in the history
… and iterating across batches (#4137)

`MetaTensor`: collate; decollate; dataset; dataloader; out=; indexing and iterating across batches (#4137)
  • Loading branch information
rijobro authored Apr 20, 2022
1 parent cb9040e commit 6dfb6a8
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 27 deletions.
13 changes: 13 additions & 0 deletions monai/data/meta_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class MetaObj:

def __init__(self):
self._meta: dict = self.get_default_meta()
self._is_batch: bool = False

@staticmethod
def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]:
Expand Down Expand Up @@ -176,6 +177,7 @@ def _copy_meta(self, input_objs: list[MetaObj]) -> None:
id_in = id(input_objs[0]) if len(input_objs) > 0 else None
deep_copy = id(self) != id_in
self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy)
self.is_batch = input_objs[0].is_batch if len(input_objs) > 0 else False

def get_default_meta(self) -> dict:
"""Get the default meta.
Expand All @@ -194,6 +196,7 @@ def __repr__(self) -> str:
out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items())
else:
out += "None"
out += f"\nIs batch?: {self.is_batch}"

return out

Expand All @@ -206,3 +209,13 @@ def meta(self) -> dict:
def meta(self, d: dict) -> None:
"""Set the meta."""
self._meta = d

@property
def is_batch(self) -> bool:
"""Return whether object is part of batch or not."""
return self._is_batch

@is_batch.setter
def is_batch(self, val: bool) -> None:
"""Set whether object is part of batch or not."""
self._is_batch = val
104 changes: 93 additions & 11 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@

import warnings
from copy import deepcopy
from typing import Callable
from typing import Any, Callable, Sequence

import torch

from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms
from monai.data.utils import decollate_batch, list_data_collate
from monai.utils.enums import PostFix

__all__ = ["MetaTensor"]
Expand Down Expand Up @@ -54,11 +55,20 @@ class MetaTensor(MetaObj, torch.Tensor):
assert m2.affine == affine
Notes:
- Requires pytorch 1.9 or newer for full compatibility.
- Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may
not work if `im` is of type `MetaTensor`. This can be resolved with
`torch.jit.trace(net, im.as_tensor())`.
- A warning will be raised if in the constructor `affine` is not `None` and
`meta` already contains the key `affine`.
- You can query whether the `MetaTensor` is a batch with the `is_batch` attribute.
- With a batch of data, `batch[0]` will return the 0th image
with the 0th metadata. When the batch dimension is non-singleton, e.g.,
`batch[:, 0]`, `batch[..., -1]` and `batch[1:3]`, then all (or a subset in the
last example) of the metadata will be returned, and `is_batch` will return `True`.
- When creating a batch with this class, use `monai.data.DataLoader` as opposed
to `torch.utils.data.DataLoader`, as this will take care of collating the
metadata properly.
"""

@staticmethod
Expand Down Expand Up @@ -101,21 +111,93 @@ def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Call
if isinstance(val, torch.Tensor):
setattr(self, attribute, val.to(self.device))

@staticmethod
def update_meta(rets: Sequence, func, args, kwargs):
"""Update the metadata from the output of `__torch_function__`.
The output could be a single object, or a sequence of them. Hence, they get
converted to a sequence if necessary and then processed by iterating across them.
For each element, if not of type `MetaTensor`, then nothing to do
"""
out = []
metas = None
for idx, ret in enumerate(rets):
# if not `MetaTensor`, nothing to do.
if not isinstance(ret, MetaTensor):
pass
# if not tracking, convert to `torch.Tensor`.
elif not (get_track_meta() or get_track_transforms()):
ret = ret.as_tensor()
# else, handle the `MetaTensor` metadata.
else:
meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values()))
ret._copy_meta(meta_args)

# If we have a batch of data, then we need to be careful if a slice of
# the data is returned. Depending on how the data are indexed, we return
# some or all of the metadata, and the return object may or may not be a
# batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
if ret.is_batch:
# only decollate metadata once
if metas is None:
metas = decollate_batch(ret.meta)
# if indexing e.g., `batch[0]`
if func == torch.Tensor.__getitem__:
idx = args[1]
if isinstance(idx, Sequence):
idx = idx[0]
# if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
# first element will be `slice(None, None, None)` and `Ellipsis`,
# respectively. Don't need to do anything with the metadata.
if idx not in (slice(None, None, None), Ellipsis):
meta = metas[idx]
# if using e.g., `batch[0:2]`, then `is_batch` should still be
# `True`. Also re-collate the remaining elements.
if isinstance(meta, list) and len(meta) > 1:
ret.meta = list_data_collate(meta)
# if using e.g., `batch[0]` or `batch[0, 1]`, then return single
# element from batch, and set `is_batch` to `False`.
else:
ret.meta = meta
ret.is_batch = False
# `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
# But we only want to split the batch if the `unbind` is along the 0th
# dimension.
elif func == torch.Tensor.unbind:
if len(args) > 1:
dim = args[1]
elif "dim" in kwargs:
dim = kwargs["dim"]
else:
dim = 0
if dim == 0:
ret.meta = metas[idx]
ret.is_batch = False

ret.affine = ret.affine.to(ret.device)
out.append(ret)
# if the input was a tuple, then return it as a tuple
return tuple(out) if isinstance(rets, tuple) else out

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor:
def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:
"""Wraps all torch functions."""
if kwargs is None:
kwargs = {}
ret: MetaTensor = super().__torch_function__(func, types, args, kwargs)
# e.g., __repr__ returns a string
if not isinstance(ret, torch.Tensor):
ret = super().__torch_function__(func, types, args, kwargs)
# if `out` has been used as argument, metadata is not copied, nothing to do.
if "out" in kwargs:

This comment has been minimized.

Copy link
@wyli

wyli May 3, 2022

Contributor
return ret
if not (get_track_meta() or get_track_transforms()):
return ret.as_tensor()
meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values()))
ret._copy_meta(meta_args)
ret.affine = ret.affine.to(ret.device)
return ret
# we might have 1 or multiple outputs. Might be MetaTensor, might be something
# else (e.g., `__repr__` returns a string).
# Convert to list (if necessary), process, and at end remove list if one was added.
if not isinstance(ret, Sequence):
ret = [ret]
unpack = True
else:
unpack = False
ret = MetaTensor.update_meta(ret, func, args, kwargs)
return ret[0] if unpack else ret

def get_default_affine(self, dtype=torch.float64) -> torch.Tensor:
return torch.eye(4, device=self.device, dtype=dtype)
Expand Down
21 changes: 18 additions & 3 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from torch.utils.data._utils.collate import default_collate

from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
from monai.data.meta_obj import MetaObj
from monai.networks.layers.simplelayers import GaussianFilter
from monai.utils import (
MAX_SEED,
Expand Down Expand Up @@ -346,9 +347,17 @@ def list_data_collate(batch: Sequence):
ret = {}
for k in elem:
key = k
ret[key] = default_collate([d[key] for d in data])
return ret
return default_collate(data)
data_for_batch = [d[key] for d in data]
ret[key] = default_collate(data_for_batch)
if isinstance(ret[key], MetaObj) and all(isinstance(d, MetaObj) for d in data_for_batch):
ret[key].meta = list_data_collate([i.meta for i in data_for_batch])
ret[key].is_batch = True
else:
ret = default_collate(data)
if isinstance(ret, MetaObj) and all(isinstance(d, MetaObj) for d in data):
ret.meta = list_data_collate([i.meta for i in data])
ret.is_batch = True
return ret
except RuntimeError as re:
re_str = str(re)
if "equal size" in re_str:
Expand Down Expand Up @@ -466,6 +475,12 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
if batch.ndim == 0:
return batch.item() if detach else batch
out_list = torch.unbind(batch, dim=0)
# if of type MetaObj, decollate the metadata
if isinstance(batch, MetaObj) and all(isinstance(i, MetaObj) for i in out_list):
metas = decollate_batch(batch.meta)
for i in range(len(out_list)):
out_list[i].meta = metas[i] # type: ignore
out_list[i].is_batch = False # type: ignore
if out_list[0].ndim == 0 and detach:
return [t.item() for t in out_list]
return list(out_list)
Expand Down
Loading

0 comments on commit 6dfb6a8

Please sign in to comment.