Skip to content

Commit

Permalink
to and from meta tensor
Browse files Browse the repository at this point in the history
Signed-off-by: Richard Brown <[email protected]>
  • Loading branch information
rijobro committed Apr 12, 2022
1 parent 1880d38 commit deb7951
Show file tree
Hide file tree
Showing 7 changed files with 323 additions and 10 deletions.
12 changes: 9 additions & 3 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

import warnings
from copy import deepcopy
from typing import Callable

import torch
Expand Down Expand Up @@ -54,10 +55,10 @@ class MetaTensor(MetaObj, torch.Tensor):
Notes:
- Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may
not work if `im` is of type `MetaTensor`. This can be resolved with
`torch.jit.trace(net, im.as_tensor())`.
not work if `im` is of type `MetaTensor`. This can be resolved with
`torch.jit.trace(net, im.as_tensor())`.
- A warning will be raised if in the constructor `affine` is not `None` and
`meta` already contains the key `affine`.
`meta` already contains the key `affine`.
"""

@staticmethod
Expand Down Expand Up @@ -89,6 +90,11 @@ def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None
out.affine = out.get_default_affine()
out.affine = out.affine.to(out.device)

# if we are creating a new MetaTensor, then deep copy attributes
if isinstance(x, torch.Tensor) and not isinstance(x, MetaTensor):
out.meta = deepcopy(out.meta)
out.affine = out.affine.to(out.device)

return out

def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None:
Expand Down
8 changes: 8 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,14 @@
from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
from .meta_utility.dictionary import (
FromMetaTensord,
FromMetaTensorD,
FromMetaTensorDict,
ToMetaTensord,
ToMetaTensorD,
ToMetaTensorDict,
)
from .nvtx import (
Mark,
Markd,
Expand Down
10 changes: 10 additions & 0 deletions monai/transforms/meta_utility/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
102 changes: 102 additions & 0 deletions monai/transforms/meta_utility/dictionary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A collection of dictionary-based wrappers for moving between MetaTensor types and dictionaries of data.
These can be used to make backwards compatible code.
Class names are ended with 'd' to denote dictionary-based transforms.
"""

from copy import deepcopy
from typing import Dict, Hashable, Mapping

from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_tensor import MetaTensor
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform
from monai.utils.enums import PostFix, TransformBackends

__all__ = [
"FromMetaTensord",
"FromMetaTensorD",
"FromMetaTensorDict",
"ToMetaTensord",
"ToMetaTensorD",
"ToMetaTensorDict",
]


class FromMetaTensord(MapTransform, InvertibleTransform):
"""
Dictionary-based transform to convert MetaTensor to a dictionary.
If input is `{"a": MetaTensor, "b": MetaTensor}`, then output will
have the form `{"a": torch.Tensor, "a_meta_dict": dict, "b": ...}`.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
self.push_transform(d, key)
im: MetaTensor = d[key] # type: ignore
d.update(im.as_dict(key))
return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
# check transform
_ = self.get_most_recent_transform(d, key)
# do the inverse
im, meta = d[key], d.pop(PostFix.meta(key), None)
im = MetaTensor(im, meta=meta) # type: ignore
d[key] = im
# Remove the applied transform
self.pop_transform(d, key)
return d


class ToMetaTensord(MapTransform, InvertibleTransform):
"""
Dictionary-based transform to convert a dictionary to MetaTensor.
If input is `{"a": torch.Tensor, "a_meta_dict": dict, "b": ...}`, then output will
have the form `{"a": MetaTensor, "b": MetaTensor}`.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
self.push_transform(d, key)
im, meta = d[key], d.pop(PostFix.meta(key), None)
im = MetaTensor(im, meta=meta) # type: ignore
d[key] = im
return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
# check transform
_ = self.get_most_recent_transform(d, key)
# do the inverse
im: MetaTensor = d[key] # type: ignore
d.update(im.as_dict(key))
# Remove the applied transform
self.pop_transform(d, key)
return d


FromMetaTensorD = FromMetaTensorDict = FromMetaTensord
ToMetaTensorD = ToMetaTensorDict = ToMetaTensord
18 changes: 11 additions & 7 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,13 @@ class ForwardMode(Enum):
class TraceKeys:
"""Extra meta data keys used for traceable transforms."""

CLASS_NAME = "class"
ID = "id"
ORIG_SIZE = "orig_size"
EXTRA_INFO = "extra_info"
DO_TRANSFORM = "do_transforms"
KEY_SUFFIX = "_transforms"
NONE = "none"
CLASS_NAME: str = "class"
ID: str = "id"
ORIG_SIZE: str = "orig_size"
EXTRA_INFO: str = "extra_info"
DO_TRANSFORM: str = "do_transforms"
KEY_SUFFIX: str = "_transforms"
NONE: str = "none"


@deprecated(since="0.8.0", msg_suffix="use monai.utils.enums.TraceKeys instead.")
Expand Down Expand Up @@ -287,6 +287,10 @@ def meta(key: Optional[str] = None):
def orig_meta(key: Optional[str] = None):
return PostFix._get_str(key, "orig_meta_dict")

@staticmethod
def transforms(key: Optional[str] = None):
return PostFix._get_str(key, TraceKeys.KEY_SUFFIX[1:])


class TransformBackends(Enum):
"""
Expand Down
1 change: 1 addition & 0 deletions tests/test_module_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def test_transform_api(self):
to_exclude = {"MapTransform"} # except for these transforms
to_exclude_docs = {"Decollate", "Ensemble", "Invert", "SaveClassification", "RandTorchVision"}
to_exclude_docs.update({"DeleteItems", "SelectItems", "CopyItems", "ConcatItems"})
to_exclude_docs.update({"ToMetaTensor", "FromMetaTensor"})
xforms = {
name: obj
for name, obj in monai.transforms.__dict__.items()
Expand Down
182 changes: 182 additions & 0 deletions tests/test_to_from_meta_tensord.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import string
import unittest
from copy import deepcopy
from typing import Optional, Union

import torch
from parameterized import parameterized

from monai.data.meta_tensor import MetaTensor
from monai.transforms import FromMetaTensord, ToMetaTensord
from monai.utils.enums import PostFix
from monai.utils.module import get_torch_version_tuple
from tests.utils import TEST_DEVICES, assert_allclose

PT_VER_MAJ, PT_VER_MIN = get_torch_version_tuple()

DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32]]
TESTS = []
for _device in TEST_DEVICES:
for _dtype in DTYPES:
TESTS.append((*_device, *_dtype))


def rand_string(min_len=5, max_len=10):
str_size = random.randint(min_len, max_len)
chars = string.ascii_letters + string.punctuation
return "".join(random.choice(chars) for _ in range(str_size))


class TestToFromMetaTensord(unittest.TestCase):
@staticmethod
def get_im(shape=None, dtype=None, device=None):
if shape is None:
shape = shape = (1, 10, 8)
affine = torch.randint(0, 10, (4, 4))
meta = {"fname": rand_string()}
t = torch.rand(shape)
if dtype is not None:
t = t.to(dtype)
if device is not None:
t = t.to(device)
m = MetaTensor(t.clone(), affine, meta)
return m

def check_ids(self, a, b, should_match):
comp = self.assertEqual if should_match else self.assertNotEqual
comp(id(a), id(b))

def check(
self,
out: torch.Tensor,
orig: torch.Tensor,
*,
shape: bool = True,
vals: bool = True,
ids: bool = True,
device: Optional[Union[str, torch.device]] = None,
meta: bool = True,
check_ids: bool = True,
**kwargs,
):
if device is None:
device = orig.device

# check the image
self.assertIsInstance(out, type(orig))
if shape:
assert_allclose(torch.as_tensor(out.shape), torch.as_tensor(orig.shape))
if vals:
assert_allclose(out, orig, **kwargs)
if check_ids:
self.check_ids(out, orig, ids)
self.assertTrue(str(device) in str(out.device))

# check meta and affine are equal and affine is on correct device
if isinstance(orig, MetaTensor) and isinstance(out, MetaTensor) and meta:
orig_meta_no_affine = deepcopy(orig.meta)
del orig_meta_no_affine["affine"]
out_meta_no_affine = deepcopy(out.meta)
del out_meta_no_affine["affine"]
self.assertEqual(orig_meta_no_affine, out_meta_no_affine)
assert_allclose(out.affine, orig.affine)
self.assertTrue(str(device) in str(out.affine.device))
if check_ids:
self.check_ids(out.affine, orig.affine, ids)
self.check_ids(out.meta, orig.meta, ids)

@parameterized.expand(TESTS)
def test_from_to_meta_tensord(self, device, dtype):
m1 = self.get_im(device=device, dtype=dtype)
m2 = self.get_im(device=device, dtype=dtype)
m3 = self.get_im(device=device, dtype=dtype)
d_metas = {"m1": m1, "m2": m2, "m3": m3}
m1_meta = {k: v for k, v in m1.meta.items() if k != "affine"}
m1_aff = m1.affine

# FROM -> forward
t_from_meta = FromMetaTensord(["m1", "m2"])
d_dict = t_from_meta(d_metas)

self.assertEqual(
sorted(d_dict.keys()),
[
"m1",
PostFix.meta("m1"),
PostFix.transforms("m1"),
"m2",
PostFix.meta("m2"),
PostFix.transforms("m2"),
"m3",
],
)
self.check(d_dict["m3"], m3, ids=True) # unchanged
self.check(d_dict["m1"], m1.as_tensor(), ids=False)
meta_out = {k: v for k, v in d_dict["m1_meta_dict"].items() if k != "affine"}
aff_out = d_dict["m1_meta_dict"]["affine"]
self.check(aff_out, m1_aff, ids=True)
self.assertEqual(meta_out, m1_meta)

# FROM -> inverse
d_meta_dict_meta = t_from_meta.inverse(d_dict)
self.assertEqual(
sorted(d_meta_dict_meta.keys()), ["m1", PostFix.transforms("m1"), "m2", PostFix.transforms("m2"), "m3"]
)
self.check(d_meta_dict_meta["m3"], m3, ids=False) # unchanged (except deep copy in inverse)
self.check(d_meta_dict_meta["m1"], m1, ids=False)
meta_out = {k: v for k, v in d_meta_dict_meta["m1"].meta.items() if k != "affine"}
aff_out = d_meta_dict_meta["m1"].affine
self.check(aff_out, m1_aff, ids=False)
self.assertEqual(meta_out, m1_meta)

# TO -> Forward
t_to_meta = ToMetaTensord(["m1", "m2"])
del d_dict["m1_transforms"]
del d_dict["m2_transforms"]
d_dict_meta = t_to_meta(d_dict)
self.assertEqual(
sorted(d_dict_meta.keys()), ["m1", PostFix.transforms("m1"), "m2", PostFix.transforms("m2"), "m3"]
)
self.check(d_dict_meta["m3"], m3, ids=True) # unchanged (except deep copy in inverse)
self.check(d_dict_meta["m1"], m1, ids=False)
meta_out = {k: v for k, v in d_dict_meta["m1"].meta.items() if k != "affine"}
aff_out = d_dict_meta["m1"].meta["affine"]
self.check(aff_out, m1_aff, ids=False)
self.assertEqual(meta_out, m1_meta)

# TO -> Inverse
d_dict_meta_dict = t_to_meta.inverse(d_dict_meta)
self.assertEqual(
sorted(d_dict_meta_dict.keys()),
[
"m1",
PostFix.meta("m1"),
PostFix.transforms("m1"),
"m2",
PostFix.meta("m2"),
PostFix.transforms("m2"),
"m3",
],
)
self.check(d_dict_meta_dict["m3"], m3.as_tensor(), ids=False) # unchanged (except deep copy in inverse)
self.check(d_dict_meta_dict["m1"], m1.as_tensor(), ids=False)
meta_out = {k: v for k, v in d_dict_meta_dict["m1_meta_dict"].items() if k != "affine"}
aff_out = d_dict_meta_dict["m1_meta_dict"]["affine"]
self.check(aff_out, m1_aff, ids=False)
self.assertEqual(meta_out, m1_meta)


if __name__ == "__main__":
unittest.main()

0 comments on commit deb7951

Please sign in to comment.