-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* meta tensor Signed-off-by: Richard Brown <[email protected]>
- Loading branch information
Showing
6 changed files
with
650 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
from copy import deepcopy | ||
from typing import Any, Callable, Sequence | ||
|
||
_TRACK_META = True | ||
_TRACK_TRANSFORMS = True | ||
|
||
__all__ = ["get_track_meta", "get_track_transforms", "set_track_meta", "set_track_transforms", "MetaObj"] | ||
|
||
|
||
def set_track_meta(val: bool) -> None: | ||
""" | ||
Boolean to set whether metadata is tracked. If `True`, metadata will be associated | ||
its data by using subclasses of `MetaObj`. If `False`, then data will be returned | ||
with empty metadata. | ||
If both `set_track_meta` and `set_track_transforms` are set to | ||
`False`, then standard data objects will be returned (e.g., `torch.Tensor` and | ||
`np.ndarray`) as opposed to our enhanced objects. | ||
By default, this is `True`, and most users will want to leave it this way. However, | ||
if you are experiencing any problems regarding metadata, and aren't interested in | ||
preserving metadata, then you can disable it. | ||
""" | ||
global _TRACK_META | ||
_TRACK_META = val | ||
|
||
|
||
def set_track_transforms(val: bool) -> None: | ||
""" | ||
Boolean to set whether transforms are tracked. If `True`, applied transforms will be | ||
associated its data by using subclasses of `MetaObj`. If `False`, then transforms | ||
won't be tracked. | ||
If both `set_track_meta` and `set_track_transforms` are set to | ||
`False`, then standard data objects will be returned (e.g., `torch.Tensor` and | ||
`np.ndarray`) as opposed to our enhanced objects. | ||
By default, this is `True`, and most users will want to leave it this way. However, | ||
if you are experiencing any problems regarding transforms, and aren't interested in | ||
preserving transforms, then you can disable it. | ||
""" | ||
global _TRACK_TRANSFORMS | ||
_TRACK_TRANSFORMS = val | ||
|
||
|
||
def get_track_meta() -> bool: | ||
""" | ||
Return the boolean as to whether metadata is tracked. If `True`, metadata will be | ||
associated its data by using subclasses of `MetaObj`. If `False`, then data will be | ||
returned with empty metadata. | ||
If both `set_track_meta` and `set_track_transforms` are set to | ||
`False`, then standard data objects will be returned (e.g., `torch.Tensor` and | ||
`np.ndarray`) as opposed to our enhanced objects. | ||
By default, this is `True`, and most users will want to leave it this way. However, | ||
if you are experiencing any problems regarding metadata, and aren't interested in | ||
preserving metadata, then you can disable it. | ||
""" | ||
return _TRACK_META | ||
|
||
|
||
def get_track_transforms() -> bool: | ||
""" | ||
Return the boolean as to whether transforms are tracked. If `True`, applied | ||
transforms will be associated its data by using subclasses of `MetaObj`. If `False`, | ||
then transforms won't be tracked. | ||
If both `set_track_meta` and `set_track_transforms` are set to | ||
`False`, then standard data objects will be returned (e.g., `torch.Tensor` and | ||
`np.ndarray`) as opposed to our enhanced objects. | ||
By default, this is `True`, and most users will want to leave it this way. However, | ||
if you are experiencing any problems regarding transforms, and aren't interested in | ||
preserving transforms, then you can disable it. | ||
""" | ||
return _TRACK_TRANSFORMS | ||
|
||
|
||
class MetaObj: | ||
""" | ||
Abstract base class that stores data as well as any extra metadata. | ||
This allows for subclassing `torch.Tensor` and `np.ndarray` through multiple | ||
inheritance. | ||
Metadata is stored in the form of a dictionary. | ||
Behavior should be the same as extended class (e.g., `torch.Tensor` or `np.ndarray`) | ||
aside from the extended meta functionality. | ||
Copying of information: | ||
* For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the | ||
first instance of `MetaObj`. | ||
""" | ||
|
||
_meta: dict | ||
|
||
@staticmethod | ||
def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]: | ||
""" | ||
Recursively flatten input and return all instances of `MetaObj` as a single | ||
list. This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and | ||
their numpy equivalents), we return `[a, b]` if both `a` and `b` are of type | ||
`MetaObj`. | ||
Args: | ||
args: Sequence of inputs to be flattened. | ||
Returns: | ||
list of nested `MetaObj` from input. | ||
""" | ||
out = [] | ||
for a in args: | ||
if isinstance(a, (list, tuple)): | ||
out += MetaObj.flatten_meta_objs(a) | ||
elif isinstance(a, MetaObj): | ||
out.append(a) | ||
return out | ||
|
||
def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None: | ||
""" | ||
Copy an attribute from the first in a list of `MetaObj`. In the case of | ||
`torch.add(a, b)`, both `a` and `b` could be `MetaObj` or something else, so | ||
check them all. Copy the first to `self`. | ||
We also perform a deep copy of the data if desired. | ||
Args: | ||
attribute: string corresponding to attribute to be copied (e.g., `meta`). | ||
input_objs: List of `MetaObj`. We'll copy the attribute from the first one | ||
that contains that particular attribute. | ||
default_fn: If none of `input_objs` have the attribute that we're | ||
interested in, then use this default function (e.g., `lambda: {}`.) | ||
deep_copy: Should the attribute be deep copied? See `_copy_meta`. | ||
Returns: | ||
Returns `None`, but `self` should be updated to have the copied attribute. | ||
""" | ||
attributes = [getattr(i, attribute) for i in input_objs] | ||
if len(attributes) > 0: | ||
val = attributes[0] | ||
if deep_copy: | ||
val = deepcopy(val) | ||
setattr(self, attribute, val) | ||
else: | ||
setattr(self, attribute, default_fn()) | ||
|
||
def _copy_meta(self, input_objs: list[MetaObj]) -> None: | ||
""" | ||
Copy metadata from a list of `MetaObj`. For a given attribute, we copy the | ||
adjunct data from the first element in the list containing that attribute. | ||
If there has been a change in `id` (e.g., `a=b+c`), then deepcopy. Else (e.g., | ||
`a+=1`), then don't. | ||
Args: | ||
input_objs: list of `MetaObj` to copy data from. | ||
""" | ||
id_in = id(input_objs[0]) if len(input_objs) > 0 else None | ||
deep_copy = id(self) != id_in | ||
self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy) | ||
|
||
def get_default_meta(self) -> dict: | ||
"""Get the default meta. | ||
Returns: | ||
default metadata. | ||
""" | ||
return {} | ||
|
||
def __repr__(self) -> str: | ||
"""String representation of class.""" | ||
out: str = super().__repr__() | ||
|
||
out += "\nMetaData\n" | ||
if self.meta is not None: | ||
out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items()) | ||
else: | ||
out += "None" | ||
|
||
return out | ||
|
||
@property | ||
def meta(self) -> dict: | ||
"""Get the meta.""" | ||
return self._meta | ||
|
||
@meta.setter | ||
def meta(self, d: dict) -> None: | ||
"""Set the meta.""" | ||
self._meta = d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import warnings | ||
from typing import Callable | ||
|
||
import torch | ||
|
||
from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms | ||
from monai.utils.enums import PostFix | ||
|
||
__all__ = ["MetaTensor"] | ||
|
||
|
||
class MetaTensor(MetaObj, torch.Tensor): | ||
""" | ||
Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for metadata. | ||
Metadata is stored in the form of a dictionary. Nested, an affine matrix will be | ||
stored. This should be in the form of `torch.Tensor`. | ||
Behavior should be the same as `torch.Tensor` aside from the extended | ||
meta functionality. | ||
Copying of information: | ||
* For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the | ||
first instance of `MetaTensor`. | ||
Example: | ||
.. code-block:: python | ||
import torch | ||
from monai.data import MetaTensor | ||
t = torch.tensor([1,2,3]) | ||
affine = torch.eye(4) * 100 | ||
meta = {"some": "info"} | ||
m = MetaTensor(t, affine=affine, meta=meta) | ||
m2 = m+m | ||
assert isinstance(m2, MetaTensor) | ||
assert m2.meta["some"] == "info" | ||
assert m2.affine == affine | ||
Notes: | ||
- Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may | ||
not work if `im` is of type `MetaTensor`. This can be resolved with | ||
`torch.jit.trace(net, im.as_tensor())`. | ||
- A warning will be raised if in the constructor `affine` is not `None` and | ||
`meta` already contains the key `affine`. | ||
""" | ||
|
||
@staticmethod | ||
def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None, *args, **kwargs) -> MetaTensor: | ||
""" | ||
If `meta` is given, use it. Else, if `meta` exists in the input tensor, use it. | ||
Else, use the default value. Similar for the affine, except this could come from | ||
four places. | ||
Priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`. | ||
""" | ||
out: MetaTensor = torch.as_tensor(x, *args, **kwargs).as_subclass(cls) # type: ignore | ||
# set meta | ||
if meta is not None: | ||
out.meta = meta | ||
elif isinstance(x, MetaObj): | ||
out.meta = x.meta | ||
else: | ||
out.meta = out.get_default_meta() | ||
# set the affine | ||
if affine is not None: | ||
if "affine" in out.meta: | ||
warnings.warn("Setting affine, but the applied meta contains an affine. " "This will be overwritten.") | ||
out.affine = affine | ||
elif "affine" in out.meta: | ||
pass # nothing to do | ||
elif isinstance(x, MetaTensor): | ||
out.affine = x.affine | ||
else: | ||
out.affine = out.get_default_affine() | ||
out.affine = out.affine.to(out.device) | ||
|
||
return out | ||
|
||
def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None: | ||
super()._copy_attr(attribute, input_objs, default_fn, deep_copy) | ||
val = getattr(self, attribute) | ||
if isinstance(val, torch.Tensor): | ||
setattr(self, attribute, val.to(self.device)) | ||
|
||
@classmethod | ||
def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor: | ||
"""Wraps all torch functions.""" | ||
if kwargs is None: | ||
kwargs = {} | ||
ret: MetaTensor = super().__torch_function__(func, types, args, kwargs) | ||
# e.g., __repr__ returns a string | ||
if not isinstance(ret, torch.Tensor): | ||
return ret | ||
if not (get_track_meta() or get_track_transforms()): | ||
return ret.as_tensor() | ||
meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values())) | ||
ret._copy_meta(meta_args) | ||
ret.affine = ret.affine.to(ret.device) | ||
return ret | ||
|
||
def get_default_affine(self) -> torch.Tensor: | ||
return torch.eye(4, device=self.device) | ||
|
||
def as_tensor(self) -> torch.Tensor: | ||
""" | ||
Return the `MetaTensor` as a `torch.Tensor`. | ||
It is OS dependent as to whether this will be a deep copy or not. | ||
""" | ||
return self.as_subclass(torch.Tensor) # type: ignore | ||
|
||
def as_dict(self, key: str) -> dict: | ||
""" | ||
Get the object as a dictionary for backwards compatibility. | ||
Args: | ||
key: Base key to store main data. The key for the metadata will be | ||
determined using `PostFix.meta`. | ||
Return: | ||
A dictionary consisting of two keys, the main data (stored under `key`) and | ||
the metadata. | ||
""" | ||
return {key: self.as_tensor(), PostFix.meta(key): self.meta} | ||
|
||
@property | ||
def affine(self) -> torch.Tensor: | ||
"""Get the affine.""" | ||
return self.meta["affine"] # type: ignore | ||
|
||
@affine.setter | ||
def affine(self, d: torch.Tensor) -> None: | ||
"""Set the affine.""" | ||
self.meta["affine"] = d |
Oops, something went wrong.