update meta tensor api #4131

Merged: 2 commits, Apr 14, 2022
2 changes: 1 addition & 1 deletion monai/data/__init__.py
@@ -46,7 +46,7 @@
     resolve_writer,
 )
 from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer
-from .meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms
+from .meta_obj import MetaObj, get_track_meta, get_track_transforms, set_track_meta, set_track_transforms
 from .meta_tensor import MetaTensor
 from .nifti_saver import NiftiSaver
 from .nifti_writer import write_nifti
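With this re-export, `MetaObj` becomes importable from the top-level `monai.data` namespace alongside `MetaTensor`. A minimal sketch, assuming a MONAI build that contains this change:

```python
# Previously MetaObj had to be imported from monai.data.meta_obj;
# after this change it is re-exported by monai.data itself.
from monai.data import MetaObj, MetaTensor

print(issubclass(MetaTensor, MetaObj))  # True: MetaTensor builds on MetaObj
```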
3 changes: 2 additions & 1 deletion monai/data/meta_obj.py
@@ -109,7 +109,8 @@ class MetaObj:
 
     """
 
-    _meta: dict
+    def __init__(self):
+        self._meta: dict = self.get_default_meta()
 
     @staticmethod
     def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]:
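Since `_meta` is now assigned in `MetaObj.__init__` rather than declared as a bare class-level annotation, metadata is available on every instance as soon as it is constructed. A minimal sketch under the assumption that `get_default_meta()` supplies the initial dict (its exact contents are not shown in this diff):

```python
import torch
from monai.data import MetaTensor  # MetaTensor inherits the new MetaObj.__init__

# Construction runs MetaObj.__init__, so _meta is populated immediately
# via get_default_meta(); no meta dict needs to be passed in.
img = MetaTensor(torch.zeros(1, 4, 4))
print(isinstance(img.meta, dict))  # True: default metadata exists per instance
```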
31 changes: 15 additions & 16 deletions monai/data/meta_tensor.py
@@ -62,34 +62,33 @@ class MetaTensor(MetaObj, torch.Tensor):
 
     @staticmethod
     def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None, *args, **kwargs) -> MetaTensor:
+        return torch.as_tensor(x, *args, **kwargs).as_subclass(cls)  # type: ignore
+
+    def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = None) -> None:
         """
         If `meta` is given, use it. Else, if `meta` exists in the input tensor, use it.
         Else, use the default value. Similar for the affine, except this could come from
         four places.
         Priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`.
         """
-        out: MetaTensor = torch.as_tensor(x, *args, **kwargs).as_subclass(cls)  # type: ignore
+        super().__init__()
         # set meta
         if meta is not None:
-            out.meta = meta
+            self.meta = meta
         elif isinstance(x, MetaObj):
-            out.meta = x.meta
-        else:
-            out.meta = out.get_default_meta()
+            self.meta = x.meta
         # set the affine
         if affine is not None:
-            if "affine" in out.meta:
-                warnings.warn("Setting affine, but the applied meta contains an affine. " "This will be overwritten.")
-            out.affine = affine
-        elif "affine" in out.meta:
+            if "affine" in self.meta:
+                warnings.warn("Setting affine, but the applied meta contains an affine. This will be overwritten.")
+            self.affine = affine
+        elif "affine" in self.meta:
             pass  # nothing to do
         elif isinstance(x, MetaTensor):
-            out.affine = x.affine
+            self.affine = x.affine
         else:
-            out.affine = out.get_default_affine()
-        out.affine = out.affine.to(out.device)
-
-        return out
+            self.affine = self.get_default_affine()
+        self.affine = self.affine.to(self.device)
 
     def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None:
         super()._copy_attr(attribute, input_objs, default_fn, deep_copy)
@@ -113,8 +112,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor:
             ret.affine = ret.affine.to(ret.device)
         return ret
 
-    def get_default_affine(self) -> torch.Tensor:
-        return torch.eye(4, device=self.device)
+    def get_default_affine(self, dtype=torch.float64) -> torch.Tensor:
+        return torch.eye(4, device=self.device, dtype=dtype)
 
     def as_tensor(self) -> torch.Tensor:
         """
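Taken together, these hunks split construction into `__new__` (plain tensor subclassing) and `__init__` (metadata and affine handling), while keeping the documented priority order for the affine: the explicit `affine` argument, then `meta["affine"]`, then the input's own affine when it is a `MetaTensor`, then the default identity, which now accepts a `dtype` and defaults to `torch.float64`. A minimal usage sketch of that behaviour, assuming this branch of MONAI; the exact contents of the default metadata are not shown in the diff:

```python
import torch
from monai.data import MetaTensor

img = torch.rand(1, 10, 8)

# 1. An explicit `affine` argument takes priority (a warning is emitted
#    if the supplied meta dict also carries an affine).
t1 = MetaTensor(img, affine=2.0 * torch.eye(4))

# 2. Otherwise an affine stored in the supplied meta dict is used as-is.
t2 = MetaTensor(img, meta={"affine": 3.0 * torch.eye(4)})

# 3. Otherwise, constructing from another MetaTensor reuses its metadata,
#    and with it the source affine.
t3 = MetaTensor(t1)

# 4. Otherwise the default identity affine is used; get_default_affine
#    now takes a dtype and defaults to torch.float64.
t4 = MetaTensor(img)
print(t4.get_default_affine(dtype=torch.float32).dtype)  # torch.float32
```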
2 changes: 1 addition & 1 deletion tests/test_meta_tensor.py
@@ -44,7 +44,7 @@ class TestMetaTensor(unittest.TestCase):
     @staticmethod
     def get_im(shape=None, dtype=None, device=None):
         if shape is None:
-            shape = shape = (1, 10, 8)
+            shape = (1, 10, 8)
         affine = torch.randint(0, 10, (4, 4))
         meta = {"fname": rand_string()}
         t = torch.rand(shape)