Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

meta tensor #4077

Merged
merged 33 commits into from
Apr 12, 2022
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
e15ec97
meta tensor
rijobro Apr 5, 2022
20f1b47
Merge remote-tracking branch 'MONAI/dev' into MetaTensor_1st_PR
rijobro Apr 7, 2022
accca59
fixes
rijobro Apr 7, 2022
a15d462
Merge remote-tracking branch 'MONAI/dev' into MetaTensor_1st_PR
rijobro Apr 7, 2022
36a1e75
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2022
f71d60f
deep_copy
rijobro Apr 7, 2022
51e6b3b
fixes
rijobro Apr 7, 2022
1e9f370
torchscript
rijobro Apr 7, 2022
76e0cce
test pickling torchscript amp
rijobro Apr 7, 2022
de02d7f
fixes
rijobro Apr 7, 2022
8a60df1
test fixes
rijobro Apr 7, 2022
70bbcdb
fixes
rijobro Apr 7, 2022
eabd3bf
typos
rijobro Apr 8, 2022
4af4d50
fixes
rijobro Apr 8, 2022
4a2c211
fix test
rijobro Apr 8, 2022
6373711
fix
rijobro Apr 8, 2022
07f5117
fix?
rijobro Apr 8, 2022
723f475
affine lives inside meta
rijobro Apr 8, 2022
b25440c
move affine in meta
rijobro Apr 8, 2022
3e2eb02
flake8
rijobro Apr 11, 2022
b6a1f56
Merge remote-tracking branch 'MONAI/dev' into MetaTensor_1st_PR
rijobro Apr 11, 2022
32cc5b3
fixes
rijobro Apr 11, 2022
7fe23fd
fixes
rijobro Apr 11, 2022
7f43b00
fixes
rijobro Apr 11, 2022
840e7df
pytorch min version 1.7
rijobro Apr 11, 2022
e2256da
Revert "pytorch min version 1.7"
rijobro Apr 11, 2022
79ce908
Merge remote-tracking branch 'MONAI/dev' into MetaTensor_transforms
rijobro Apr 11, 2022
e3e567c
torch 1.9
rijobro Apr 11, 2022
87b572e
Merge branch 'dev' into MetaTensor_1st_PR
wyli Apr 12, 2022
56bc6d5
remove __init__, correct docstring
rijobro Apr 12, 2022
09eb602
adds docs
wyli Apr 12, 2022
7064843
Merge pull request #3 from wyli/adds-docs
rijobro Apr 12, 2022
81404ed
Merge branch 'dev' into MetaTensor_1st_PR
rijobro Apr 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
resolve_writer,
)
from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer
from .meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms
from .meta_tensor import MetaTensor
from .nifti_saver import NiftiSaver
from .nifti_writer import write_nifti
from .png_saver import PNGSaver
Expand Down
205 changes: 205 additions & 0 deletions monai/data/meta_obj.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from copy import deepcopy
from typing import Any, Callable, Sequence

# Module-level switches controlling whether metadata and applied transforms are
# tracked. Read via `get_track_meta`/`get_track_transforms`; toggled via
# `set_track_meta`/`set_track_transforms`.
_TRACK_META = True
_TRACK_TRANSFORMS = True

# Public API of this module.
__all__ = ["get_track_meta", "get_track_transforms", "set_track_meta", "set_track_transforms", "MetaObj"]


def set_track_meta(val: bool) -> None:
    """
    Enable or disable metadata tracking.

    When enabled, metadata is attached to data via subclasses of `MetaObj`;
    when disabled, data is returned with empty metadata.

    If both `set_track_meta` and `set_track_transforms` are set to `False`,
    standard data objects (e.g., `torch.Tensor` and `np.ndarray`) are returned
    instead of our enhanced objects.

    By default this is `True`, and most users will want to leave it that way.
    Disable it only if metadata handling causes problems and preserving
    metadata is not required.
    """
    global _TRACK_META
    _TRACK_META = val


def set_track_transforms(val: bool) -> None:
    """
    Enable or disable transform tracking.

    When enabled, applied transforms are recorded alongside data via
    subclasses of `MetaObj`; when disabled, transforms are not tracked.

    If both `set_track_meta` and `set_track_transforms` are set to `False`,
    standard data objects (e.g., `torch.Tensor` and `np.ndarray`) are returned
    instead of our enhanced objects.

    By default this is `True`, and most users will want to leave it that way.
    Disable it only if transform tracking causes problems and preserving the
    applied transforms is not required.
    """
    global _TRACK_TRANSFORMS
    _TRACK_TRANSFORMS = val


def get_track_meta() -> bool:
    """
    Report whether metadata tracking is currently enabled.

    When `True`, metadata is attached to data via subclasses of `MetaObj`;
    when `False`, data is returned with empty metadata.

    If both `set_track_meta` and `set_track_transforms` are set to `False`,
    standard data objects (e.g., `torch.Tensor` and `np.ndarray`) are returned
    instead of our enhanced objects.

    By default this is `True`, and most users will want to leave it that way.
    Disable it only if metadata handling causes problems and preserving
    metadata is not required.
    """
    return _TRACK_META


def get_track_transforms() -> bool:
    """
    Report whether transform tracking is currently enabled.

    When `True`, applied transforms are recorded alongside data via subclasses
    of `MetaObj`; when `False`, transforms are not tracked.

    If both `set_track_meta` and `set_track_transforms` are set to `False`,
    standard data objects (e.g., `torch.Tensor` and `np.ndarray`) are returned
    instead of our enhanced objects.

    By default this is `True`, and most users will want to leave it that way.
    Disable it only if transform tracking causes problems and preserving the
    applied transforms is not required.
    """
    return _TRACK_TRANSFORMS


class MetaObj:
    """
    Abstract base class holding data alongside arbitrary metadata.

    Designed so that `torch.Tensor` and `np.ndarray` can be extended through
    multiple inheritance.

    Metadata is kept as a plain dictionary, and subclasses otherwise behave
    like the class they extend (e.g., `torch.Tensor` or `np.ndarray`), aside
    from the extended meta functionality.

    Copying of information:
        * For `c = a + b`, auxiliary data (e.g., metadata) is copied from the
          first instance of `MetaObj` among the inputs.
    """

    _meta: dict

    @staticmethod
    def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]:
        """
        Recursively flatten the input, collecting every instance of `MetaObj`
        into a single list. For both `torch.add(a, b)` and `torch.stack([a, b])`
        (and their numpy equivalents), this yields `[a, b]` when both `a` and
        `b` are of type `MetaObj`.

        Args:
            args: Sequence of inputs to be flattened.
        Returns:
            list of nested `MetaObj` from input.
        """
        found: list[MetaObj] = []
        for item in args:
            if isinstance(item, (list, tuple)):
                found.extend(MetaObj.flatten_meta_objs(item))
            elif isinstance(item, MetaObj):
                found.append(item)
        return found

    def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None:
        """
        Copy an attribute from the first in a list of `MetaObj` onto `self`.
        In the case of `torch.add(a, b)`, both `a` and `b` could be `MetaObj`
        or something else, so check them all.

        A deep copy of the value is performed when requested.

        Args:
            attribute: name of the attribute to be copied (e.g., `meta`).
            input_objs: list of `MetaObj`; the attribute is taken from the
                first one that carries it.
            default_fn: factory used when `input_objs` is empty
                (e.g., `lambda: {}`).
            deep_copy: whether to deep copy the value. See `_copy_meta`.

        Returns:
            `None`; `self` is updated in place with the copied attribute.
        """
        candidates = [getattr(obj, attribute) for obj in input_objs]
        if not candidates:
            setattr(self, attribute, default_fn())
        else:
            chosen = candidates[0]
            setattr(self, attribute, deepcopy(chosen) if deep_copy else chosen)

    def _copy_meta(self, input_objs: list[MetaObj]) -> None:
        """
        Copy metadata from a list of `MetaObj`. The adjunct data is taken from
        the first element of the list that carries it.

        If there has been a change in `id` (e.g., `a = b + c`), a deep copy is
        made; otherwise (e.g., `a += 1`) it is not.

        Args:
            input_objs: list of `MetaObj` to copy data from.
        """
        first_id = id(input_objs[0]) if input_objs else None
        self._copy_attr("meta", input_objs, self.get_default_meta, id(self) != first_id)

    def get_default_meta(self) -> dict:
        """Get the default meta.

        Returns:
            default metadata (an empty dictionary).
        """
        return {}

    def __repr__(self) -> str:
        """String representation: the parent repr followed by the metadata."""
        text: str = super().__repr__() + "\nMetaData\n"
        if self.meta is not None:
            for key, value in self.meta.items():
                text += f"\t{key}: {value}\n"
        else:
            text += "None"
        return text

    @property
    def meta(self) -> dict:
        """The metadata dictionary."""
        return self._meta

    @meta.setter
    def meta(self, d: dict) -> None:
        """Set the meta."""
        self._meta = d
149 changes: 149 additions & 0 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from typing import Callable

import torch

from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms
from monai.utils.enums import PostFix

__all__ = ["MetaTensor"]


class MetaTensor(MetaObj, torch.Tensor):
    """
    Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for meta
    data.

    Metadata is stored in the form of a dictionary. Nested, an affine matrix will be
    stored. This should be in the form of `torch.Tensor`.

    Behavior should be the same as `torch.Tensor` aside from the extended
    meta functionality.

    Copying of information:
        * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the
          first instance of `MetaTensor`.

    Example:
        .. code-block:: python

            import torch
            from monai.data import MetaTensor

            t = torch.tensor([1,2,3])
            affine = torch.eye(4)
            meta = {"some": "info"}
            m = MetaTensor(t, affine=affine, meta=meta)
            m2 = m+m
            assert isinstance(m2, MetaTensor)
            assert m2.meta == meta

    Notes:
        - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may
          not work if `im` is of type `MetaTensor`. This can be resolved with
          `torch.jit.trace(net, im.as_tensor())`.
        - For older versions of pytorch (<=1.7), `torch.save(m, fname); m=torch.load(fname)`
          may return a `torch.Tensor` instead of `MetaTensor`.
        - A warning will be raised if in the constructor `affine` is not `None` and
          `meta` already contains the key `affine`.
    """

    @staticmethod
    def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None, *args, **kwargs) -> MetaTensor:
        # `affine` and `meta` are consumed by `__init__`; any remaining args
        # are forwarded to `torch.as_tensor`. `as_subclass` re-tags the tensor
        # as this class without copying the underlying data.
        return torch.as_tensor(x, *args, **kwargs).as_subclass(cls)  # type: ignore

    def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = None) -> None:
        """
        If `meta` is given, use it. Else, if `meta` exists in the input tensor, use it.
        Else, use the default value. Similar for the affine, except this could come from
        four places.
        Priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`.
        """
        # set meta: explicit argument > input's meta > default (empty dict)
        if meta is not None:
            self.meta = meta
        elif isinstance(x, MetaObj):
            self.meta = x.meta
        else:
            self.meta = self.get_default_meta()
        # set the affine, by the priority documented above
        if affine is not None:
            if "affine" in self.meta:
                warnings.warn("Setting affine, but the applied meta contains an affine. " "This will be overwritten.")
            self.affine = affine
        elif "affine" in self.meta:
            pass  # nothing to do
        elif isinstance(x, MetaTensor):
            self.affine = x.affine
        else:
            self.affine = self.get_default_affine()
        # keep the affine on the same device as the tensor data
        self.affine = self.affine.to(self.device)

    def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None:
        # Same as `MetaObj._copy_attr`, but additionally moves a copied tensor
        # attribute onto this tensor's device.
        super()._copy_attr(attribute, input_objs, default_fn, deep_copy)
        val = getattr(self, attribute)
        if isinstance(val, torch.Tensor):
            setattr(self, attribute, val.to(self.device))

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor:
        """Wraps all torch functions."""
        if kwargs is None:
            kwargs = {}
        ret: MetaTensor = super().__torch_function__(func, types, args, kwargs)
        # e.g., __repr__ returns a string, not a tensor: pass it through as-is
        if not isinstance(ret, torch.Tensor):
            return ret
        # if tracking is fully disabled, return a plain torch.Tensor
        if not (get_track_meta() or get_track_transforms()):
            return ret.as_tensor()
        # copy metadata from the first `MetaObj` found among the inputs
        meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values()))
        ret._copy_meta(meta_args)
        # NOTE(review): if no input carried an affine, `ret.affine` raises
        # KeyError here; assumes at least one MetaTensor input — TODO confirm.
        ret.affine = ret.affine.to(ret.device)
        return ret

    def get_default_affine(self) -> torch.Tensor:
        # Identity affine on the same device as the data.
        return torch.eye(4, device=self.device)

    def as_tensor(self) -> torch.Tensor:
        """
        Return the `MetaTensor` as a `torch.Tensor`.
        It is OS dependent as to whether this will be a deep copy or not.
        """
        return self.as_subclass(torch.Tensor)  # type: ignore

    def as_dict(self, key: str) -> dict:
        """
        Get the object as a dictionary for backwards compatibility.

        Args:
            key: Base key to store main data. The key for the metadata will be
                determined using `PostFix.meta`.

        Return:
            A dictionary consisting of two keys, the main data (stored under `key`) and
            the metadata.
        """
        return {key: self.as_tensor(), PostFix.meta(key): self.meta}

    @property
    def affine(self) -> torch.Tensor:
        """Get the affine. Stored in `meta` under the key ``"affine"``."""
        return self.meta["affine"]  # type: ignore

    @affine.setter
    def affine(self, d: torch.Tensor) -> None:
        """Set the affine."""
        self.meta["affine"] = d
Loading