Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Meta tensor channel #4222

Merged
merged 34 commits into from
May 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
0dd6c87
MetaTensor channel transforms
rijobro May 4, 2022
e89e2b2
fixes
rijobro May 4, 2022
d89f01c
typo
rijobro May 4, 2022
4f27daf
fixes
rijobro May 4, 2022
a0bcbc0
fix
rijobro May 4, 2022
4383f24
remove deepcopy
rijobro May 4, 2022
403986a
fixes
rijobro May 6, 2022
16aa46e
fix
rijobro May 6, 2022
18f651d
fix
rijobro May 6, 2022
9ea5331
Merge branch 'feature/MetaTensor' into MetaTensor_channel
rijobro May 6, 2022
13f7c4d
Merge branch 'feature/MetaTensor' into MetaTensor_channel
rijobro May 6, 2022
6a5f888
Merge remote-tracking branch 'MONAI/dev' into MetaTensor_channel
rijobro May 9, 2022
d256143
metatensor convert helper
rijobro May 10, 2022
04466a4
autofix
rijobro May 10, 2022
f66b19b
Merge branch 'feature/MetaTensor' into MetaTensor_channel
wyli May 10, 2022
c8240d2
Revert "Merge remote-tracking branch 'MONAI/dev' into MetaTensor_chan…
rijobro May 9, 2022
14f50ec
Merge remote-tracking branch 'MONAI/feature/MetaTensor' into MetaTens…
rijobro May 10, 2022
6905640
fix
rijobro May 10, 2022
83c11a0
FIXME
rijobro May 10, 2022
a4488dd
update_meta docstring
rijobro May 11, 2022
dc18f06
fix test
rijobro May 11, 2022
cb11fb2
Merge branch 'MetaTensor_channel' of github.com:rijobro/MONAI into Me…
rijobro May 11, 2022
ce70dac
fix
rijobro May 11, 2022
d7d9d7d
fix str
rijobro May 11, 2022
9ee79c5
format, update str
wyli May 11, 2022
ac06a89
fixes comp. torch.solve for metatensor
wyli May 11, 2022
93f1eba
fixes inverse collation
wyli May 11, 2022
02c95ef
fixes resample to match
wyli May 11, 2022
f1d5608
fixes resample to matchd
wyli May 11, 2022
9095905
fixes integration bundle run
wyli May 11, 2022
ad4d78f
fixes image dataset test
wyli May 11, 2022
ed700d6
[MONAI] python code formatting
monai-bot May 11, 2022
f74a784
fixes integration
wyli May 11, 2022
0f23384
fixes mypy
wyli May 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/apps/deepgrow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ def _default_transforms(image_key, label_key, pixdim):
return Compose(
[
LoadImaged(keys=keys),
AsChannelFirstd(keys=keys),
FromMetaTensord(keys=keys),
ToNumpyd(keys=keys + [PostFix.meta(k) for k in keys]),
AsChannelFirstd(keys=keys),
Orientationd(keys=keys, axcodes="RAS"),
Spacingd(keys=keys, pixdim=pixdim, mode=mode),
]
Expand Down
9 changes: 8 additions & 1 deletion monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torch.utils.data._utils.collate import np_str_obj_array_pattern

from monai.config import DtypeLike, KeysCollection, PathLike
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import correct_nifti_header_if_necessary, is_supported_format, orientation_ras_lps
from monai.transforms.utility.array import EnsureChannelFirst
from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import, require_pkg
Expand Down Expand Up @@ -833,8 +834,14 @@ def get_data(
metadata["spatial_shape"] = np.asarray(region.shape[:-1])
metadata["original_channel_dim"] = -1

# combine image and metadata
if not isinstance(region, MetaTensor):
rijobro marked this conversation as resolved.
Show resolved Hide resolved
region = MetaTensor(region, meta=metadata)

# Make it channel first
region = EnsureChannelFirst()(region, metadata)
# FIXME: would be nice not to convert -> MetaTensor -> numpy
region = EnsureChannelFirst()(region).numpy()
rijobro marked this conversation as resolved.
Show resolved Hide resolved
del metadata["affine"] # automatically created but not needed

# Split into patches
if patch_size is None:
Expand Down
40 changes: 33 additions & 7 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,35 @@ def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Call
setattr(self, attribute, val.to(self.device))

@staticmethod
def update_meta(rets: Sequence, func, args, kwargs):
"""Update the metadata from the output of `__torch_function__`.
The output could be a single object, or a sequence of them. Hence, they get
converted to a sequence if necessary and then processed by iterating across them.
def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
"""
Update the metadata from the output of `MetaTensor.__torch_function__`.

The output of `torch.Tensor.__torch_function__` could be a single object or a
sequence of them. Hence, in `MetaTensor.__torch_function__` we convert them to a
list of not already, and then we loop across each element, processing metadata
as necessary. For each element, if not of type `MetaTensor`, then nothing to do.

Args:
rets: the output from `torch.Tensor.__torch_function__`, which has been
converted to a list in `MetaTensor.__torch_function__` if it wasn't
already a `Sequence`.
func: the torch function that was applied. Examples might be `torch.squeeze`
or `torch.Tensor.__add__`. We need this since the metadata need to be
treated differently if a batch of data is considered. For example,
slicing (`torch.Tensor.__getitem__`) the ith element of the 0th
dimension of a batch of data should return a ith tensor with the ith
metadata.
args: positional arguments that were passed to `func`.
kwargs: keyword arguments that were passed to `func`.

For each element, if not of type `MetaTensor`, then nothing to do
Returns:
A sequence with the same number of elements as `rets`. For each element, if
the input type was not `MetaTensor`, then no modifications will have been
made. If global parameters have been set to false (e.g.,
`not get_track_meta()`), then any `MetaTensor` will be converted to
`torch.Tensor`. Else, metadata will be propogated as necessary (see
:py:func:`MetaTensor._copy_meta`).
"""
out = []
metas = None
Expand Down Expand Up @@ -192,8 +215,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:
# return ret
# we might have 1 or multiple outputs. Might be MetaTensor, might be something
# else (e.g., `__repr__` returns a string).
# Convert to list (if necessary), process, and at end remove list if one was added.
if not isinstance(ret, Sequence):
# Convert to list (if necessary), process, and at end remove list if one was added.s
if isinstance(ret, (str, bytes)) or not isinstance(ret, Sequence):
ret = [ret]
unpack = True
else:
Expand Down Expand Up @@ -275,3 +298,6 @@ def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict):

# return the `MetaTensor`
return MetaTensor(img, meta=meta)

def __repr__(self, *, tensor_contents=None):
return self.as_tensor().__repr__() + super().__repr__()
2 changes: 1 addition & 1 deletion monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def __call__(self, prob_map: NdarrayOrTensor):
if self.sigma != 0:
if not isinstance(prob_map, torch.Tensor):
prob_map = torch.as_tensor(prob_map, dtype=torch.float)
self.filter.to(prob_map)
self.filter.to(prob_map.device)
prob_map = self.filter(prob_map)

prob_map_shape = prob_map.shape
Expand Down
5 changes: 5 additions & 0 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from monai.config import USE_COMPILED, DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, compute_shape_offset, reorient_spatial_axes, to_affine_nd, zoom_affine
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
from monai.networks.utils import meshgrid_ij, normalize_transform
Expand Down Expand Up @@ -216,6 +217,10 @@ def __call__(
if isinstance(src_affine, np.ndarray):
xform = np.linalg.solve(src_affine, dst_affine)
else:
if isinstance(src_affine, MetaTensor):
src_affine = src_affine.as_tensor()
if isinstance(dst_affine, MetaTensor):
dst_affine = dst_affine.as_tensor()
xform = (
torch.linalg.solve(src_affine, dst_affine)
if pytorch_after(1, 8, 0)
Expand Down
14 changes: 6 additions & 8 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import sys
import time
import warnings
from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -207,20 +207,18 @@ def __init__(self, strict_check: bool = True):
self.strict_check = strict_check
self.add_channel = AddChannel()

def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> NdarrayOrTensor:
def __call__(self, img: torch.Tensor) -> torch.Tensor:
"""
Apply the transform to `img`.
"""
if isinstance(img, MetaTensor):
meta_dict = img.meta
if not isinstance(meta_dict, Mapping):
if not isinstance(img, MetaTensor):
msg = "meta_dict not available, EnsureChannelFirst is not in use."
if self.strict_check:
raise ValueError(msg)
warnings.warn(msg)
return img

channel_dim = meta_dict.get("original_channel_dim")
channel_dim = img.meta.get("original_channel_dim")

if channel_dim is None:
msg = "Unknown original_channel_dim in the meta_dict, EnsureChannelFirst is not in use."
Expand All @@ -229,8 +227,8 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) ->
warnings.warn(msg)
return img
if channel_dim == "no_channel":
return self.add_channel(img)
return AsChannelFirst(channel_dim=channel_dim)(img)
return self.add_channel(img) # type: ignore
return AsChannelFirst(channel_dim=channel_dim)(img) # type: ignore


class RepeatChannel(Transform):
Expand Down
47 changes: 17 additions & 30 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from monai.config import DtypeLike, KeysCollection
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import no_collation
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
Expand Down Expand Up @@ -291,6 +292,8 @@ class EnsureChannelFirstd(MapTransform):

backend = EnsureChannelFirst.backend

@deprecated_arg(name="meta_keys", since="0.8", msg_suffix="not needed if image is type `MetaTensor`.")
@deprecated_arg(name="meta_key_postfix", since="0.8", msg_suffix="not needed if image is type `MetaTensor`.")
def __init__(
self,
keys: KeysCollection,
Expand All @@ -302,26 +305,16 @@ def __init__(
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
meta_keys: explicitly indicate the key of the corresponding meta data dictionary.
for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
the meta data is a dictionary object which contains: filename, original_shape, etc.
it can be a sequence of string, map to the `keys`.
if None, will try to construct meta_keys by `key_{meta_key_postfix}`.
meta_key_postfix: if meta_keys is None and `key_{postfix}` was used to store the metadata in `LoadImaged`.
So need the key to extract metadata for channel dim information, default is `meta_dict`.
For example, for data with key `image`, metadata by default is in `image_meta_dict`.
strict_check: whether to raise an error when the meta information is insufficient.

"""
super().__init__(keys)
self.adjuster = EnsureChannelFirst(strict_check=strict_check)
self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys))
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))

def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)
for key, meta_key, meta_key_postfix in zip(self.keys, self.meta_keys, self.meta_key_postfix):
d[key] = self.adjuster(d[key], d[meta_key or f"{key}_{meta_key_postfix}"])
for key in self.key_iterator(d):
d[key] = self.adjuster(d[key])
return d


Expand Down Expand Up @@ -416,24 +409,18 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
split_key = f"{key}_{postfixes[i]}"
if split_key in d:
raise RuntimeError(f"input data already contains key {split_key}.")
d[split_key] = r

if self.update_meta:
orig_meta = d.get(PostFix.meta(key), None)
if orig_meta is not None:
split_meta_key = PostFix.meta(split_key)
d[split_meta_key] = deepcopy(orig_meta)
dim = self.splitter.dim
if dim > 0: # don't update affine if channel dim
affine = d[split_meta_key]["affine"] # type: ignore
ndim = len(affine)
shift: NdarrayOrTensor
if isinstance(affine, torch.Tensor):
shift = torch.eye(ndim, device=affine.device, dtype=affine.dtype)
else:
shift = np.eye(ndim)
shift[dim - 1, -1] = i # type: ignore
d[split_meta_key]["affine"] = d[split_meta_key]["affine"] @ shift # type: ignore
if self.update_meta and isinstance(r, MetaTensor):
r.meta = deepcopy(r.meta)
dim = self.splitter.dim
if dim > 0: # don't update affine if channel dim
affine = r.affine
ndim = len(r.affine)
shift = torch.eye(ndim, device=affine.device, dtype=affine.dtype)
shift[dim - 1, -1] = i
r.affine = r.affine @ shift

d[split_key] = r

return d

Expand Down
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
convert_data_type,
convert_to_cupy,
convert_to_dst_type,
convert_to_meta_tensor,
convert_to_numpy,
convert_to_tensor,
dtype_numpy_to_torch,
Expand Down
78 changes: 75 additions & 3 deletions monai/utils/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"convert_to_cupy",
"convert_to_numpy",
"convert_to_tensor",
"convert_to_meta_tensor",
"convert_to_dst_type",
]

Expand Down Expand Up @@ -74,7 +75,7 @@ def get_equivalent_dtype(dtype, data_type):
"""
if dtype is None:
return None
if data_type is torch.Tensor:
if data_type is torch.Tensor or data_type.__name__ == "MetaTensor":
if isinstance(dtype, torch.dtype):
# already a torch dtype and target `data_type` is torch.Tensor
return dtype
Expand Down Expand Up @@ -116,7 +117,12 @@ def convert_to_tensor(
E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`.

"""
# avoids circular import
from monai.data.meta_tensor import MetaTensor

if isinstance(data, torch.Tensor):
if isinstance(data, MetaTensor):
data = data.as_tensor()
return data.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) # type: ignore
if isinstance(data, np.ndarray):
# skip array of string classes and object, refer to:
Expand All @@ -141,6 +147,58 @@ def convert_to_tensor(
return data


def convert_to_meta_tensor(
data, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, wrap_sequence: bool = False
):
"""
Utility to convert the input data to a MetaTensor. If passing a dictionary, list or tuple,
recursively check every item and convert it to MetaTensor.

Args:
data: input data can be MetaTensor, PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
will convert Tensor, Numpy array, float, int, bool to Tensor, strings and objects keep the original.
for dictionary, list or tuple, convert every item to a Tensor if applicable.
dtype: target data type to when converting to Tensor.
device: target device to put the converted Tensor data.
wrap_sequence: if `False`, then lists will recursively call this function.
E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`.

"""
# avoids circular import
from monai.data.meta_tensor import MetaTensor

if isinstance(data, torch.Tensor):
out = data.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) # type: ignore
if not isinstance(out, MetaTensor):
out = MetaTensor(out)
return out
if isinstance(data, np.ndarray):
# skip array of string classes and object, refer to:
# https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/_utils/collate.py#L13
if re.search(r"[SaUO]", data.dtype.str) is None:
# numpy array with 0 dims is also sequence iterable,
# `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims
if data.ndim > 0:
data = np.ascontiguousarray(data)
return MetaTensor(torch.as_tensor(data, dtype=dtype, device=device)) # type: ignore
elif (has_cp and isinstance(data, cp_ndarray)) or isinstance(data, (float, int, bool)):
return MetaTensor(torch.as_tensor(data, dtype=dtype, device=device)) # type: ignore
elif isinstance(data, list):
list_ret = [convert_to_meta_tensor(i, dtype=dtype, device=device) for i in data]
return (
MetaTensor(torch.as_tensor(list_ret, dtype=dtype, device=device)) if wrap_sequence else list_ret # type: ignore
) # type: ignore
elif isinstance(data, tuple):
tuple_ret = tuple(convert_to_meta_tensor(i, dtype=dtype, device=device) for i in data)
return (
MetaTensor(torch.as_tensor(tuple_ret, dtype=dtype, device=device)) if wrap_sequence else tuple_ret # type: ignore
) # type: ignore
elif isinstance(data, dict):
return {k: convert_to_meta_tensor(v, dtype=dtype, device=device) for k, v in data.items()}

return data


def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False):
"""
Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple,
Expand Down Expand Up @@ -241,8 +299,13 @@ def convert_data_type(
(1.0, <class 'torch.Tensor'>, None)

"""
# avoids circular import
from monai.data.meta_tensor import MetaTensor

orig_type: type
if isinstance(data, torch.Tensor):
if isinstance(data, MetaTensor):
orig_type = MetaTensor
elif isinstance(data, torch.Tensor):
orig_type = torch.Tensor
elif isinstance(data, np.ndarray):
orig_type = np.ndarray
Expand All @@ -258,6 +321,10 @@ def convert_data_type(
dtype_ = get_equivalent_dtype(dtype, output_type)

data_: NdarrayTensor

if issubclass(output_type, MetaTensor):
data_ = convert_to_meta_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence)
return data_, orig_type, orig_device
if issubclass(output_type, torch.Tensor):
data_ = convert_to_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence)
return data_, orig_type, orig_device
Expand Down Expand Up @@ -289,12 +356,17 @@ def convert_to_dst_type(
See Also:
:func:`convert_data_type`
"""
# avoids circular import
from monai.data.meta_tensor import MetaTensor

device = dst.device if isinstance(dst, torch.Tensor) else None
if dtype is None:
dtype = dst.dtype

output_type: Any
if isinstance(dst, torch.Tensor):
if isinstance(dst, MetaTensor):
output_type = MetaTensor
elif isinstance(dst, torch.Tensor):
output_type = torch.Tensor
elif isinstance(dst, np.ndarray):
output_type = np.ndarray
Expand Down
Loading