Skip to content

Commit

Permalink
meta tensor (#4077)
Browse files Browse the repository at this point in the history
* meta tensor

Signed-off-by: Richard Brown <[email protected]>
  • Loading branch information
rijobro authored Apr 12, 2022
1 parent 4e7ca44 commit 1880d38
Show file tree
Hide file tree
Showing 6 changed files with 650 additions and 2 deletions.
11 changes: 11 additions & 0 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,14 @@ ThreadDataLoader
TestTimeAugmentation
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: monai.data.TestTimeAugmentation


Meta Object
-----------
.. automodule:: monai.data.meta_obj
:members:

MetaTensor
----------
.. autoclass:: monai.data.MetaTensor
:members:
2 changes: 2 additions & 0 deletions monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
resolve_writer,
)
from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer
from .meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms
from .meta_tensor import MetaTensor
from .nifti_saver import NiftiSaver
from .nifti_writer import write_nifti
from .png_saver import PNGSaver
Expand Down
207 changes: 207 additions & 0 deletions monai/data/meta_obj.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from copy import deepcopy
from typing import Any, Callable, Sequence

# Module-level flags controlling whether `MetaObj` subclasses attach metadata
# and applied-transform information to data. Both default to enabled and are
# mutated only via `set_track_meta` / `set_track_transforms` below.
_TRACK_META = True
_TRACK_TRANSFORMS = True

# Public API of this module.
__all__ = ["get_track_meta", "get_track_transforms", "set_track_meta", "set_track_transforms", "MetaObj"]


def set_track_meta(val: bool) -> None:
    """
    Enable or disable global metadata tracking.

    When enabled (the default), metadata is attached to data via subclasses of
    `MetaObj`; when disabled, data is returned with empty metadata. If this
    flag and the transform-tracking flag are both off, plain data objects
    (e.g., `torch.Tensor` and `np.ndarray`) are returned instead of the
    enhanced types. Most users should leave this enabled; disable it only if
    metadata handling causes problems and preserving metadata is not needed.

    Args:
        val: `True` to track metadata, `False` to stop tracking it.
    """
    global _TRACK_META
    _TRACK_META = val


def set_track_transforms(val: bool) -> None:
    """
    Enable or disable global transform tracking.

    When enabled (the default), applied transforms are recorded against data
    via subclasses of `MetaObj`; when disabled, transforms are not tracked.
    If this flag and the metadata-tracking flag are both off, plain data
    objects (e.g., `torch.Tensor` and `np.ndarray`) are returned instead of
    the enhanced types. Most users should leave this enabled; disable it only
    if transform tracking causes problems and preserving it is not needed.

    Args:
        val: `True` to track applied transforms, `False` to stop tracking them.
    """
    global _TRACK_TRANSFORMS
    _TRACK_TRANSFORMS = val


def get_track_meta() -> bool:
    """
    Report whether global metadata tracking is currently enabled.

    When enabled (the default), metadata is attached to data via subclasses of
    `MetaObj`; when disabled, data is returned with empty metadata. If this
    flag and the transform-tracking flag are both off, plain data objects
    (e.g., `torch.Tensor` and `np.ndarray`) are returned instead of the
    enhanced types.

    Returns:
        `True` if metadata tracking is enabled, `False` otherwise.
    """
    return _TRACK_META


def get_track_transforms() -> bool:
    """
    Report whether global transform tracking is currently enabled.

    When enabled (the default), applied transforms are recorded against data
    via subclasses of `MetaObj`; when disabled, transforms are not tracked.
    If this flag and the metadata-tracking flag are both off, plain data
    objects (e.g., `torch.Tensor` and `np.ndarray`) are returned instead of
    the enhanced types.

    Returns:
        `True` if transform tracking is enabled, `False` otherwise.
    """
    return _TRACK_TRANSFORMS


class MetaObj:
    """
    Abstract base class that stores data as well as any extra metadata.

    This allows for subclassing `torch.Tensor` and `np.ndarray` through multiple
    inheritance.

    Metadata is stored in the form of a dictionary.

    Behavior should be the same as the extended class (e.g., `torch.Tensor` or
    `np.ndarray`) aside from the extended meta functionality.

    Copying of information:

        * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the
          first instance of `MetaObj`.
    """

    # Metadata dictionary; read and written through the `meta` property.
    _meta: dict

    @staticmethod
    def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]:
        """
        Recursively flatten input and return all instances of `MetaObj` as a single
        list. This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and
        their numpy equivalents), we return `[a, b]` if both `a` and `b` are of type
        `MetaObj`.

        Args:
            args: Sequence of inputs to be flattened.
        Returns:
            list of nested `MetaObj` from input.
        """
        out = []
        for a in args:
            if isinstance(a, (list, tuple)):
                # descend into nested sequences, e.g., the list in `torch.stack([a, b])`
                out += MetaObj.flatten_meta_objs(a)
            elif isinstance(a, MetaObj):
                out.append(a)
        return out

    def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None:
        """
        Copy an attribute from the first in a list of `MetaObj` that has it. In the
        case of `torch.add(a, b)`, both `a` and `b` could be `MetaObj` or something
        else, so check them all. Copy the first match to `self`.

        We also perform a deep copy of the data if desired.

        Args:
            attribute: string corresponding to attribute to be copied (e.g., `meta`).
            input_objs: List of `MetaObj`. We'll copy the attribute from the first one
                that contains that particular attribute.
            default_fn: If none of `input_objs` have the attribute that we're
                interested in, then use this default function (e.g., `lambda: {}`.)
            deep_copy: Should the attribute be deep copied? See `_copy_meta`.
        Returns:
            Returns `None`, but `self` should be updated to have the copied attribute.
        """
        # Bug fix: honor the documented contract and take the attribute from the
        # first object that actually *has* it. The previous implementation called
        # `getattr(i, attribute)` unconditionally, raising `AttributeError` when
        # an earlier object in the list lacked the attribute even though a later
        # one had it (and `default_fn` was only reached for an empty list).
        attributes = [getattr(i, attribute) for i in input_objs if hasattr(i, attribute)]
        if len(attributes) > 0:
            val = attributes[0]
            if deep_copy:
                val = deepcopy(val)
            setattr(self, attribute, val)
        else:
            setattr(self, attribute, default_fn())

    def _copy_meta(self, input_objs: list[MetaObj]) -> None:
        """
        Copy metadata from a list of `MetaObj`. For a given attribute, we copy the
        adjunct data from the first element in the list containing that attribute.

        If there has been a change in `id` (e.g., `a=b+c`), then deepcopy. Else (e.g.,
        `a+=1`), then don't.

        Args:
            input_objs: list of `MetaObj` to copy data from.
        """
        id_in = id(input_objs[0]) if len(input_objs) > 0 else None
        # in-place ops (e.g., `a += 1`) keep the same object, so skip the deep copy
        deep_copy = id(self) != id_in
        self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy)

    def get_default_meta(self) -> dict:
        """Get the default meta.

        Returns:
            default metadata (an empty dictionary).
        """
        return {}

    def __repr__(self) -> str:
        """String representation of class."""
        out: str = super().__repr__()

        out += "\nMetaData\n"
        if self.meta is not None:
            out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items())
        else:
            out += "None"

        return out

    @property
    def meta(self) -> dict:
        """Get the meta."""
        return self._meta

    @meta.setter
    def meta(self, d: dict) -> None:
        """Set the meta."""
        self._meta = d
148 changes: 148 additions & 0 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from typing import Callable

import torch

from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms
from monai.utils.enums import PostFix

__all__ = ["MetaTensor"]


class MetaTensor(MetaObj, torch.Tensor):
    """
    Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for metadata.

    Metadata is stored in the form of a dictionary. Nested within it, an affine
    matrix will be stored. This should be in the form of `torch.Tensor`.

    Behavior should be the same as `torch.Tensor` aside from the extended
    meta functionality.

    Copying of information:

        * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the
          first instance of `MetaTensor`.

    Example:
        .. code-block:: python

            import torch
            from monai.data import MetaTensor

            t = torch.tensor([1,2,3])
            affine = torch.eye(4) * 100
            meta = {"some": "info"}
            m = MetaTensor(t, affine=affine, meta=meta)
            m2 = m+m
            assert isinstance(m2, MetaTensor)
            assert m2.meta["some"] == "info"
            assert m2.affine == affine

    Notes:
        - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may
          not work if `im` is of type `MetaTensor`. This can be resolved with
          `torch.jit.trace(net, im.as_tensor())`.
        - A warning will be raised if in the constructor `affine` is not `None` and
          `meta` already contains the key `affine`.
    """

    @staticmethod
    def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None, *args, **kwargs) -> MetaTensor:
        """
        If `meta` is given, use it. Else, if `meta` exists in the input tensor, use it.
        Else, use the default value. Similar for the affine, except this could come from
        four places.
        Priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`.

        Args:
            x: array-like input, converted via `torch.as_tensor` then re-typed to `cls`.
            affine: optional affine matrix, stored under `meta["affine"]`.
            meta: optional metadata dictionary.
            *args, **kwargs: forwarded to `torch.as_tensor` (e.g., `dtype`, `device`).
        """
        out: MetaTensor = torch.as_tensor(x, *args, **kwargs).as_subclass(cls)  # type: ignore
        # set meta
        if meta is not None:
            out.meta = meta
        elif isinstance(x, MetaObj):
            # NOTE(review): copied by reference here; `x` and `out` share the dict
            out.meta = x.meta
        else:
            out.meta = out.get_default_meta()
        # set the affine
        if affine is not None:
            if "affine" in out.meta:
                warnings.warn("Setting affine, but the applied meta contains an affine. " "This will be overwritten.")
            out.affine = affine
        elif "affine" in out.meta:
            pass  # nothing to do
        elif isinstance(x, MetaTensor):
            out.affine = x.affine
        else:
            out.affine = out.get_default_affine()
        # keep the affine on the same device as the tensor data
        out.affine = out.affine.to(out.device)

        return out

    def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None:
        """Copy an attribute (see `MetaObj._copy_attr`), then move tensor-valued
        attributes onto this tensor's device so they stay co-located."""
        super()._copy_attr(attribute, input_objs, default_fn, deep_copy)
        val = getattr(self, attribute)
        if isinstance(val, torch.Tensor):
            setattr(self, attribute, val.to(self.device))

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor:
        """Wraps all torch functions.

        Any torch operation on a `MetaTensor` goes through here; the result gets
        metadata copied from the first `MetaObj` found among the arguments.
        """
        if kwargs is None:
            kwargs = {}
        ret: MetaTensor = super().__torch_function__(func, types, args, kwargs)
        # e.g., __repr__ returns a string
        if not isinstance(ret, torch.Tensor):
            return ret
        # if all tracking is disabled, downgrade the result to a plain torch.Tensor
        if not (get_track_meta() or get_track_transforms()):
            return ret.as_tensor()
        meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values()))
        ret._copy_meta(meta_args)
        # the result may live on a different device than the source metadata
        ret.affine = ret.affine.to(ret.device)
        return ret

    def get_default_affine(self) -> torch.Tensor:
        """Return the default affine: a 4x4 identity on this tensor's device."""
        return torch.eye(4, device=self.device)

    def as_tensor(self) -> torch.Tensor:
        """
        Return the `MetaTensor` as a `torch.Tensor`.
        It is OS dependent as to whether this will be a deep copy or not.
        """
        return self.as_subclass(torch.Tensor)  # type: ignore

    def as_dict(self, key: str) -> dict:
        """
        Get the object as a dictionary for backwards compatibility.

        Args:
            key: Base key to store main data. The key for the metadata will be
                determined using `PostFix.meta`.
        Return:
            A dictionary consisting of two keys, the main data (stored under `key`) and
            the metadata.
        """
        return {key: self.as_tensor(), PostFix.meta(key): self.meta}

    @property
    def affine(self) -> torch.Tensor:
        """Get the affine. Stored in the metadata under the `"affine"` key."""
        return self.meta["affine"]  # type: ignore

    @affine.setter
    def affine(self, d: torch.Tensor) -> None:
        """Set the affine."""
        self.meta["affine"] = d
Loading

0 comments on commit 1880d38

Please sign in to comment.