From 4a1ce9fe244ff75676919a64fda815c8b9609564 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 13 Apr 2022 21:20:25 +0100 Subject: [PATCH 1/2] update meta tensor api Signed-off-by: Wenqi Li --- monai/data/__init__.py | 2 +- monai/data/meta_obj.py | 3 ++- monai/data/meta_tensor.py | 33 ++++++++++++++++++--------------- tests/test_meta_tensor.py | 2 +- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index cdab2a1037..19ca29eafa 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -46,7 +46,7 @@ resolve_writer, ) from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer -from .meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms +from .meta_obj import MetaObj, get_track_meta, get_track_transforms, set_track_meta, set_track_transforms from .meta_tensor import MetaTensor from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index d60ec6e473..0e213f130b 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -109,7 +109,8 @@ class MetaObj: """ - _meta: dict + def __init__(self): + self._meta: dict = self.get_default_meta() @staticmethod def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]: diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index c5b95f8d08..94991c2e42 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -62,34 +62,35 @@ class MetaTensor(MetaObj, torch.Tensor): @staticmethod def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None, *args, **kwargs) -> MetaTensor: + return torch.as_tensor(x, *args, **kwargs).as_subclass(cls) # type: ignore + + def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = None) -> None: """ If `meta` is given, use it. Else, if `meta` exists in the input tensor, use it. Else, use the default value. Similar for the affine, except this could come from four places. Priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`. """ - out: MetaTensor = torch.as_tensor(x, *args, **kwargs).as_subclass(cls) # type: ignore + super().__init__() # set meta if meta is not None: - out.meta = meta + self.meta = meta elif isinstance(x, MetaObj): - out.meta = x.meta + self.meta = x.meta else: - out.meta = out.get_default_meta() + self.meta = self.get_default_meta() # set the affine if affine is not None: - if "affine" in out.meta: - warnings.warn("Setting affine, but the applied meta contains an affine. " "This will be overwritten.") - out.affine = affine - elif "affine" in out.meta: + if "affine" in self.meta: + warnings.warn("Setting affine, but the applied meta contains an affine. This will be overwritten.") + self.affine = affine + elif "affine" in self.meta: pass # nothing to do elif isinstance(x, MetaTensor): - out.affine = x.affine + self.affine = x.affine else: - out.affine = out.get_default_affine() - out.affine = out.affine.to(out.device) - - return out + self.affine = self.get_default_affine() + self.affine = self.affine.to(self.device) def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None: super()._copy_attr(attribute, input_objs, default_fn, deep_copy) @@ -113,8 +114,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor: ret.affine = ret.affine.to(ret.device) return ret - def get_default_affine(self) -> torch.Tensor: - return torch.eye(4, device=self.device) + def get_default_affine(self, dtype=torch.float64) -> torch.Tensor: + return torch.eye(4, device=self.device, dtype=dtype) def as_tensor(self) -> torch.Tensor: """ @@ -140,6 +141,8 @@ def as_dict(self, key: str) -> dict: @property def affine(self) -> torch.Tensor: """Get the affine.""" + if "affine" not in self.meta: + self.meta["affine"] = self.get_default_affine() return self.meta["affine"] # type: ignore @affine.setter diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 7688950a4b..c18ef08b85 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -44,7 +44,7 @@ class TestMetaTensor(unittest.TestCase): @staticmethod def get_im(shape=None, dtype=None, device=None): if shape is None: - shape = shape = (1, 10, 8) + shape = (1, 10, 8) affine = torch.randint(0, 10, (4, 4)) meta = {"fname": rand_string()} t = torch.rand(shape) From f1ebae3d893f2fcedc7922d893cf7253394f78a6 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 13 Apr 2022 22:30:12 +0100 Subject: [PATCH 2/2] update based on comments Signed-off-by: Wenqi Li --- monai/data/meta_tensor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 94991c2e42..30270d89e2 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -77,8 +77,6 @@ def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = No self.meta = meta elif isinstance(x, MetaObj): self.meta = x.meta - else: - self.meta = self.get_default_meta() # set the affine if affine is not None: if "affine" in self.meta: @@ -141,8 +139,6 @@ def as_dict(self, key: str) -> dict: @property def affine(self) -> torch.Tensor: """Get the affine.""" - if "affine" not in self.meta: - self.meta["affine"] = self.get_default_affine() return self.meta["affine"] # type: ignore @affine.setter