From 3d98c8eda4d21a686d98d108ad0076d2214fea46 Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Thu, 14 Apr 2022 16:05:31 +0100 Subject: [PATCH] sync feature branch with dev (#4138) * update citation (#4133) Signed-off-by: Wenqi Li * `ToMetaTensor` and `FromMetaTensor` transforms (#4115) to and from meta Co-authored-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- CITATION.cff | 12 +- docs/source/transforms.rst | 15 ++ monai/data/meta_tensor.py | 5 + monai/transforms/__init__.py | 8 + monai/transforms/meta_utility/__init__.py | 10 ++ monai/transforms/meta_utility/dictionary.py | 102 +++++++++++ monai/utils/enums.py | 18 +- tests/test_module_list.py | 1 + tests/test_to_from_meta_tensord.py | 182 ++++++++++++++++++++ 9 files changed, 342 insertions(+), 11 deletions(-) create mode 100644 monai/transforms/meta_utility/__init__.py create mode 100644 monai/transforms/meta_utility/dictionary.py create mode 100644 tests/test_to_from_meta_tensord.py diff --git a/CITATION.cff b/CITATION.cff index 8cbf686ce60..dcce4af3776 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -6,11 +6,15 @@ title: "MONAI: Medical Open Network for AI" abstract: "AI Toolkit for Healthcare Imaging" authors: - name: "MONAI Consortium" -date-released: 2020-03-28 -version: "0.6.0" -doi: "10.5281/zenodo.4323058" +date-released: 2022-02-16 +version: "0.8.1" +identifiers: + - description: "This DOI represents all versions of MONAI, and will always resolve to the latest one." + type: doi + value: "10.5281/zenodo.4323058" license: "Apache-2.0" repository-code: "https://github.com/Project-MONAI/MONAI" -cff-version: "1.1.0" +url: "https://monai.io" +cff-version: "1.2.0" message: "If you use this software, please cite it using these metadata." ... diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 78fb3030931..676e0274fe2 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1842,6 +1842,21 @@ Utility (Dict) :members: :special-members: __call__ +MetaTensor +^^^^^^^^^^ + +`ToMetaTensord` +""""""""""""""" +.. autoclass:: ToMetaTensord + :members: + :special-members: __call__ + +`FromMetaTensord` +""""""""""""""""" +.. autoclass:: FromMetaTensord + :members: + :special-members: __call__ + Transform Adaptors ------------------ .. automodule:: monai.transforms.adaptors diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 30270d89e2b..ba80f93e74e 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -12,6 +12,7 @@ from __future__ import annotations import warnings +from copy import deepcopy from typing import Callable import torch @@ -88,6 +89,10 @@ def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = No self.affine = x.affine else: self.affine = self.get_default_affine() + + # if we are creating a new MetaTensor, then deep copy attributes + if isinstance(x, torch.Tensor) and not isinstance(x, MetaTensor): + self.meta = deepcopy(self.meta) self.affine = self.affine.to(self.device) def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None: diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 8e6ccc8b94e..581e368ba0c 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -206,6 +206,14 @@ from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict from .io.array import SUPPORTED_READERS, LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict +from .meta_utility.dictionary import ( + FromMetaTensord, + FromMetaTensorD, + FromMetaTensorDict, + ToMetaTensord, + ToMetaTensorD, + ToMetaTensorDict, +) from .nvtx import ( Mark, Markd, diff --git a/monai/transforms/meta_utility/__init__.py b/monai/transforms/meta_utility/__init__.py new file mode 100644 index 00000000000..1e97f894078 --- /dev/null +++ b/monai/transforms/meta_utility/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/transforms/meta_utility/dictionary.py b/monai/transforms/meta_utility/dictionary.py new file mode 100644 index 00000000000..1a9cf4c6319 --- /dev/null +++ b/monai/transforms/meta_utility/dictionary.py @@ -0,0 +1,102 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A collection of dictionary-based wrappers for moving between MetaTensor types and dictionaries of data. +These can be used to make backwards compatible code. + +Class names are ended with 'd' to denote dictionary-based transforms. +""" + +from copy import deepcopy +from typing import Dict, Hashable, Mapping + +from monai.config.type_definitions import NdarrayOrTensor +from monai.data.meta_tensor import MetaTensor +from monai.transforms.inverse import InvertibleTransform +from monai.transforms.transform import MapTransform +from monai.utils.enums import PostFix, TransformBackends + +__all__ = [ + "FromMetaTensord", + "FromMetaTensorD", + "FromMetaTensorDict", + "ToMetaTensord", + "ToMetaTensorD", + "ToMetaTensorDict", +] + + +class FromMetaTensord(MapTransform, InvertibleTransform): + """ + Dictionary-based transform to convert MetaTensor to a dictionary. + + If input is `{"a": MetaTensor, "b": MetaTensor}`, then output will + have the form `{"a": torch.Tensor, "a_meta_dict": dict, "b": ...}`. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + self.push_transform(d, key) + im: MetaTensor = d[key] # type: ignore + d.update(im.as_dict(key)) + return d + + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + # check transform + _ = self.get_most_recent_transform(d, key) + # do the inverse + im, meta = d[key], d.pop(PostFix.meta(key), None) + im = MetaTensor(im, meta=meta) # type: ignore + d[key] = im + # Remove the applied transform + self.pop_transform(d, key) + return d + + +class ToMetaTensord(MapTransform, InvertibleTransform): + """ + Dictionary-based transform to convert a dictionary to MetaTensor. + + If input is `{"a": torch.Tensor, "a_meta_dict": dict, "b": ...}`, then output will + have the form `{"a": MetaTensor, "b": MetaTensor}`. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + self.push_transform(d, key) + im, meta = d[key], d.pop(PostFix.meta(key), None) + im = MetaTensor(im, meta=meta) # type: ignore + d[key] = im + return d + + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + # check transform + _ = self.get_most_recent_transform(d, key) + # do the inverse + im: MetaTensor = d[key] # type: ignore + d.update(im.as_dict(key)) + # Remove the applied transform + self.pop_transform(d, key) + return d + + +FromMetaTensorD = FromMetaTensorDict = FromMetaTensord +ToMetaTensorD = ToMetaTensorDict = ToMetaTensord diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 4bc3d6ee842..1bfbdf824b7 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -227,13 +227,13 @@ class ForwardMode(Enum): class TraceKeys: """Extra meta data keys used for traceable transforms.""" - CLASS_NAME = "class" - ID = "id" - ORIG_SIZE = "orig_size" - EXTRA_INFO = "extra_info" - DO_TRANSFORM = "do_transforms" - KEY_SUFFIX = "_transforms" - NONE = "none" + CLASS_NAME: str = "class" + ID: str = "id" + ORIG_SIZE: str = "orig_size" + EXTRA_INFO: str = "extra_info" + DO_TRANSFORM: str = "do_transforms" + KEY_SUFFIX: str = "_transforms" + NONE: str = "none" @deprecated(since="0.8.0", msg_suffix="use monai.utils.enums.TraceKeys instead.") @@ -287,6 +287,10 @@ def meta(key: Optional[str] = None): def orig_meta(key: Optional[str] = None): return PostFix._get_str(key, "orig_meta_dict") + @staticmethod + def transforms(key: Optional[str] = None): + return PostFix._get_str(key, TraceKeys.KEY_SUFFIX[1:]) + class TransformBackends(Enum): """ diff --git a/tests/test_module_list.py b/tests/test_module_list.py index 83c6979f308..d81d067c58e 100644 --- a/tests/test_module_list.py +++ b/tests/test_module_list.py @@ -40,6 +40,7 @@ def test_transform_api(self): to_exclude = {"MapTransform"} # except for these transforms to_exclude_docs = {"Decollate", "Ensemble", "Invert", "SaveClassification", "RandTorchVision"} to_exclude_docs.update({"DeleteItems", "SelectItems", "CopyItems", "ConcatItems"}) + to_exclude_docs.update({"ToMetaTensor", "FromMetaTensor"}) xforms = { name: obj for name, obj in monai.transforms.__dict__.items() diff --git a/tests/test_to_from_meta_tensord.py b/tests/test_to_from_meta_tensord.py new file mode 100644 index 00000000000..9bbf4592ab4 --- /dev/null +++ b/tests/test_to_from_meta_tensord.py @@ -0,0 +1,182 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import string +import unittest +from copy import deepcopy +from typing import Optional, Union + +import torch +from parameterized import parameterized + +from monai.data.meta_tensor import MetaTensor +from monai.transforms import FromMetaTensord, ToMetaTensord +from monai.utils.enums import PostFix +from monai.utils.module import get_torch_version_tuple +from tests.utils import TEST_DEVICES, assert_allclose + +PT_VER_MAJ, PT_VER_MIN = get_torch_version_tuple() + +DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32]] +TESTS = [] +for _device in TEST_DEVICES: + for _dtype in DTYPES: + TESTS.append((*_device, *_dtype)) + + +def rand_string(min_len=5, max_len=10): + str_size = random.randint(min_len, max_len) + chars = string.ascii_letters + string.punctuation + return "".join(random.choice(chars) for _ in range(str_size)) + + +class TestToFromMetaTensord(unittest.TestCase): + @staticmethod + def get_im(shape=None, dtype=None, device=None): + if shape is None: + shape = shape = (1, 10, 8) + affine = torch.randint(0, 10, (4, 4)) + meta = {"fname": rand_string()} + t = torch.rand(shape) + if dtype is not None: + t = t.to(dtype) + if device is not None: + t = t.to(device) + m = MetaTensor(t.clone(), affine, meta) + return m + + def check_ids(self, a, b, should_match): + comp = self.assertEqual if should_match else self.assertNotEqual + comp(id(a), id(b)) + + def check( + self, + out: torch.Tensor, + orig: torch.Tensor, + *, + shape: bool = True, + vals: bool = True, + ids: bool = True, + device: Optional[Union[str, torch.device]] = None, + meta: bool = True, + check_ids: bool = True, + **kwargs, + ): + if device is None: + device = orig.device + + # check the image + self.assertIsInstance(out, type(orig)) + if shape: + assert_allclose(torch.as_tensor(out.shape), torch.as_tensor(orig.shape)) + if vals: + assert_allclose(out, orig, **kwargs) + if check_ids: + self.check_ids(out, orig, ids) + self.assertTrue(str(device) in str(out.device)) + + # check meta and affine are equal and affine is on correct device + if isinstance(orig, MetaTensor) and isinstance(out, MetaTensor) and meta: + orig_meta_no_affine = deepcopy(orig.meta) + del orig_meta_no_affine["affine"] + out_meta_no_affine = deepcopy(out.meta) + del out_meta_no_affine["affine"] + self.assertEqual(orig_meta_no_affine, out_meta_no_affine) + assert_allclose(out.affine, orig.affine) + self.assertTrue(str(device) in str(out.affine.device)) + if check_ids: + self.check_ids(out.affine, orig.affine, ids) + self.check_ids(out.meta, orig.meta, ids) + + @parameterized.expand(TESTS) + def test_from_to_meta_tensord(self, device, dtype): + m1 = self.get_im(device=device, dtype=dtype) + m2 = self.get_im(device=device, dtype=dtype) + m3 = self.get_im(device=device, dtype=dtype) + d_metas = {"m1": m1, "m2": m2, "m3": m3} + m1_meta = {k: v for k, v in m1.meta.items() if k != "affine"} + m1_aff = m1.affine + + # FROM -> forward + t_from_meta = FromMetaTensord(["m1", "m2"]) + d_dict = t_from_meta(d_metas) + + self.assertEqual( + sorted(d_dict.keys()), + [ + "m1", + PostFix.meta("m1"), + PostFix.transforms("m1"), + "m2", + PostFix.meta("m2"), + PostFix.transforms("m2"), + "m3", + ], + ) + self.check(d_dict["m3"], m3, ids=True) # unchanged + self.check(d_dict["m1"], m1.as_tensor(), ids=False) + meta_out = {k: v for k, v in d_dict["m1_meta_dict"].items() if k != "affine"} + aff_out = d_dict["m1_meta_dict"]["affine"] + self.check(aff_out, m1_aff, ids=True) + self.assertEqual(meta_out, m1_meta) + + # FROM -> inverse + d_meta_dict_meta = t_from_meta.inverse(d_dict) + self.assertEqual( + sorted(d_meta_dict_meta.keys()), ["m1", PostFix.transforms("m1"), "m2", PostFix.transforms("m2"), "m3"] + ) + self.check(d_meta_dict_meta["m3"], m3, ids=False) # unchanged (except deep copy in inverse) + self.check(d_meta_dict_meta["m1"], m1, ids=False) + meta_out = {k: v for k, v in d_meta_dict_meta["m1"].meta.items() if k != "affine"} + aff_out = d_meta_dict_meta["m1"].affine + self.check(aff_out, m1_aff, ids=False) + self.assertEqual(meta_out, m1_meta) + + # TO -> Forward + t_to_meta = ToMetaTensord(["m1", "m2"]) + del d_dict["m1_transforms"] + del d_dict["m2_transforms"] + d_dict_meta = t_to_meta(d_dict) + self.assertEqual( + sorted(d_dict_meta.keys()), ["m1", PostFix.transforms("m1"), "m2", PostFix.transforms("m2"), "m3"] + ) + self.check(d_dict_meta["m3"], m3, ids=True) # unchanged (except deep copy in inverse) + self.check(d_dict_meta["m1"], m1, ids=False) + meta_out = {k: v for k, v in d_dict_meta["m1"].meta.items() if k != "affine"} + aff_out = d_dict_meta["m1"].meta["affine"] + self.check(aff_out, m1_aff, ids=False) + self.assertEqual(meta_out, m1_meta) + + # TO -> Inverse + d_dict_meta_dict = t_to_meta.inverse(d_dict_meta) + self.assertEqual( + sorted(d_dict_meta_dict.keys()), + [ + "m1", + PostFix.meta("m1"), + PostFix.transforms("m1"), + "m2", + PostFix.meta("m2"), + PostFix.transforms("m2"), + "m3", + ], + ) + self.check(d_dict_meta_dict["m3"], m3.as_tensor(), ids=False) # unchanged (except deep copy in inverse) + self.check(d_dict_meta_dict["m1"], m1.as_tensor(), ids=False) + meta_out = {k: v for k, v in d_dict_meta_dict["m1_meta_dict"].items() if k != "affine"} + aff_out = d_dict_meta_dict["m1_meta_dict"]["affine"] + self.check(aff_out, m1_aff, ids=False) + self.assertEqual(meta_out, m1_meta) + + +if __name__ == "__main__": + unittest.main()