Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

meta tensor #4077

Merged
merged 33 commits into from
Apr 12, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
e15ec97
meta tensor
rijobro Apr 5, 2022
20f1b47
Merge remote-tracking branch 'MONAI/dev' into MetaTensor_1st_PR
rijobro Apr 7, 2022
accca59
fixes
rijobro Apr 7, 2022
a15d462
Merge remote-tracking branch 'MONAI/dev' into MetaTensor_1st_PR
rijobro Apr 7, 2022
36a1e75
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2022
f71d60f
deep_copy
rijobro Apr 7, 2022
51e6b3b
fixes
rijobro Apr 7, 2022
1e9f370
torchscript
rijobro Apr 7, 2022
76e0cce
test pickling torchscript amp
rijobro Apr 7, 2022
de02d7f
fixes
rijobro Apr 7, 2022
8a60df1
test fixes
rijobro Apr 7, 2022
70bbcdb
fixes
rijobro Apr 7, 2022
eabd3bf
typos
rijobro Apr 8, 2022
4af4d50
fixes
rijobro Apr 8, 2022
4a2c211
fix test
rijobro Apr 8, 2022
6373711
fix
rijobro Apr 8, 2022
07f5117
fix?
rijobro Apr 8, 2022
723f475
affine lives inside meta
rijobro Apr 8, 2022
b25440c
move affine in meta
rijobro Apr 8, 2022
3e2eb02
flake8
rijobro Apr 11, 2022
b6a1f56
Merge remote-tracking branch 'MONAI/dev' into MetaTensor_1st_PR
rijobro Apr 11, 2022
32cc5b3
fixes
rijobro Apr 11, 2022
7fe23fd
fixes
rijobro Apr 11, 2022
7f43b00
fixes
rijobro Apr 11, 2022
840e7df
pytorch min version 1.7
rijobro Apr 11, 2022
e2256da
Revert "pytorch min version 1.7"
rijobro Apr 11, 2022
79ce908
Merge remote-tracking branch 'MONAI/dev' into MetaTensor_transforms
rijobro Apr 11, 2022
e3e567c
torch 1.9
rijobro Apr 11, 2022
87b572e
Merge branch 'dev' into MetaTensor_1st_PR
wyli Apr 12, 2022
56bc6d5
remove __init__, correct docstring
rijobro Apr 12, 2022
09eb602
adds docs
wyli Apr 12, 2022
7064843
Merge pull request #3 from wyli/adds-docs
rijobro Apr 12, 2022
81404ed
Merge branch 'dev' into MetaTensor_1st_PR
rijobro Apr 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
resolve_writer,
)
from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer
from .meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms
from .meta_tensor import MetaTensor
from .nifti_saver import NiftiSaver
from .nifti_writer import write_nifti
from .png_saver import PNGSaver
Expand Down
172 changes: 172 additions & 0 deletions monai/data/meta_obj.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from copy import deepcopy
from typing import Any, Callable, Sequence

import numpy as np
import torch

_TRACK_META = True
_TRACK_TRANSFORMS = True

__all__ = ["get_track_meta", "get_track_transforms", "set_track_meta", "set_track_transforms", "MetaObj"]


def set_track_meta(val: bool) -> None:
    """
    Set whether metadata is tracked globally.

    If ``True``, `MetaTensor` will be returned where appropriate.
    If ``False``, `torch.Tensor` will be returned instead.

    Args:
        val: new value for the global metadata-tracking flag.
    """
    global _TRACK_META
    _TRACK_META = val


def set_track_transforms(val: bool) -> None:
    """
    Set whether applied transforms are tracked.

    Args:
        val: new value for the global transform-tracking flag.
    """
    global _TRACK_TRANSFORMS
    _TRACK_TRANSFORMS = val


def get_track_meta() -> bool:
    """Return the current value of the global metadata-tracking flag."""
    # reading a module-level global needs no `global` declaration
    return _TRACK_META


def get_track_transforms() -> bool:
    """Return the current value of the global transform-tracking flag."""
    # reading a module-level global needs no `global` declaration
    return _TRACK_TRANSFORMS


class MetaObj:
    """
    Base class that stores a metadata dictionary and an affine matrix.

    We store the affine as its own element, so that this can be updated by
    transforms. All other metadata that we don't plan on touching until we
    need to save the image to file lives in ``meta``.

    This allows for subclassing ``np.ndarray`` and ``torch.Tensor``.

    Copying metadata:
        * For ``c = a + b``, the metadata will be copied from the first
          instance of `MetaObj` among the inputs.
    """

    _meta: dict
    _affine: torch.Tensor

    def set_initial_val(self, attribute: str, input_arg: Any, input_tensor: MetaObj, default_fn: Callable) -> None:
        """
        Set the initial value of ``attribute``.

        Try to use ``input_arg``, but if this is ``None`` and the input object
        carries the attribute, copy that. Failing both of these, fall back to
        ``default_fn``.

        Args:
            attribute: name of the attribute to set (e.g., ``"affine"``).
            input_arg: explicitly supplied value; used when not ``None``.
            input_tensor: object the attribute may be copied from.
            default_fn: zero-argument callable producing the fallback value
                (e.g., the bound method ``self.get_default_meta``).
        """
        if input_arg is None:
            input_arg = getattr(input_tensor, attribute, None)
        if input_arg is None:
            # `default_fn` is a bound method; call it with no arguments,
            # consistent with `_copy_attr` below. (Calling `default_fn(self)`
            # would pass `self` twice and raise a TypeError.)
            input_arg = default_fn()
        setattr(self, attribute, input_arg)

    @staticmethod
    def get_tensors_or_arrays(args: Sequence[Any]) -> list[MetaObj]:
        """
        Recursively extract all instances of `MetaObj` from ``args``.

        Works for ``torch.add(a, b)``, ``torch.stack([a, b])`` and numpy
        equivalents, where `MetaObj` inputs may be nested in lists/tuples.
        """
        out = []
        for a in args:
            if isinstance(a, (list, tuple)):
                out += MetaObj.get_tensors_or_arrays(a)
            elif isinstance(a, MetaObj):
                out.append(a)
        return out

    def _copy_attr(
        self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deepcopy_required: bool
    ) -> None:
        """
        Copy an attribute from the first in a list of `MetaObj`.

        In the cases ``torch.add(a, b)`` and ``torch.add(input=a, other=b)``,
        both ``a`` and ``b`` could be `MetaObj` or ``torch.Tensor``, so check
        them all. Copy the first to the output, and make sure it is on the
        correct device. The `MetaObj` might be nested in a list, e.g.,
        ``torch.stack([a, b])``.

        Args:
            attribute: name of the attribute to copy.
            input_objs: candidate sources; the first one wins.
            default_fn: zero-argument callable producing the fallback value
                when ``input_objs`` is empty.
            deepcopy_required: whether the copied value must be deep-copied
                (needed when the output is a new object, e.g., ``a + b``).
        """
        attributes = [getattr(i, attribute) for i in input_objs]
        if len(attributes) > 0:
            val = attributes[0]
            if deepcopy_required:
                val = deepcopy(val)
            if isinstance(self, torch.Tensor) and isinstance(val, torch.Tensor):
                val = val.to(self.device)
            setattr(self, attribute, val)
        else:
            setattr(self, attribute, default_fn())

    def _copy_meta(self, input_meta_objs: list[MetaObj]) -> None:
        """
        Copy metadata (``affine`` and ``meta``) from a list of `MetaObj`.

        If there has been a change in ``id`` (e.g., ``a + b``), then deepcopy.
        Else (e.g., ``a += 1``), don't.
        """
        id_in = id(input_meta_objs[0]) if len(input_meta_objs) > 0 else None
        deepcopy_required = id(self) != id_in
        attributes = ("affine", "meta")
        default_fns: tuple[Callable, ...] = (self.get_default_affine, self.get_default_meta)
        for attribute, default_fn in zip(attributes, default_fns):
            self._copy_attr(attribute, input_meta_objs, default_fn, deepcopy_required)

    def get_default_meta(self) -> dict:
        """Return the default metadata: an empty dictionary."""
        return {}

    def get_default_affine(self) -> torch.Tensor | np.ndarray:
        """Return the default affine. Must be implemented by subclasses."""
        raise NotImplementedError()

    def __repr__(self) -> str:
        """String representation of class, including affine and metadata."""
        out: str = super().__repr__()

        out += f"\nAffine\n{self.affine}"

        out += "\nMetaData\n"
        if self.meta is not None:
            out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items())
        else:
            out += "None"

        return out

    @property
    def affine(self) -> torch.Tensor:
        """Affine matrix associated with the data."""
        return self._affine

    @affine.setter
    def affine(self, d: torch.Tensor) -> None:
        self._affine = d

    @property
    def meta(self) -> dict:
        """Dictionary of metadata."""
        return self._meta

    @meta.setter
    def meta(self, d: dict):
        self._meta = d
74 changes: 74 additions & 0 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings

import torch

from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms

__all__ = ["MetaTensor"]


class MetaTensor(MetaObj, torch.Tensor):
    """
    Class that extends upon `torch.Tensor`, adding support for metadata.

    We store the affine as its own element, so that this can be updated by
    transforms. All other metadata that we don't plan on touching until we
    need to save the image to file lives in ``meta``.

    Behavior should be the same as ``torch.Tensor`` aside from the extended
    functionality.

    Copying metadata:
        * For ``c = a + b``, the metadata will be copied from the first
          instance of `MetaTensor`.
    """

    @staticmethod
    def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None, *args, **kwargs) -> MetaTensor:
        # `affine` and `meta` are consumed by `__init__`; any extra arguments
        # are forwarded to `torch.as_tensor`.
        return torch.as_tensor(x, *args, **kwargs).as_subclass(cls)  # type: ignore

    def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = None) -> None:
        """
        If ``affine`` is given, use it. Else, if ``affine`` exists in the input
        tensor, use it. Else, use the default value. The same is true for
        ``meta`` and ``transforms``.
        """
        self.set_initial_val("affine", affine, x, self.get_default_affine)
        self.set_initial_val("meta", meta, x, self.get_default_meta)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor:
        """Wraps all torch functions, propagating metadata to tensor outputs."""
        if kwargs is None:
            kwargs = {}
        ret: MetaTensor = super().__torch_function__(func, types, args, kwargs)
        # e.g., `__repr__` returns a string; nothing to propagate
        if not isinstance(ret, torch.Tensor):
            return ret
        # when neither meta nor transforms are tracked, return a plain tensor
        if not (get_track_meta() or get_track_transforms()):
            return ret.as_tensor()
        # copy metadata from any MetaObj found among the (possibly nested) inputs
        meta_args = MetaObj.get_tensors_or_arrays(list(args) + list(kwargs.values()))
        ret._copy_meta(meta_args)
        return ret

    def get_default_affine(self) -> torch.Tensor:
        """Default affine: 4x4 identity matrix on the same device as the data."""
        return torch.eye(4, device=self.device)

    def as_tensor(self) -> torch.Tensor:
        """
        Return the `MetaTensor` as a ``torch.Tensor``.
        It is OS dependent as to whether this will be a deep copy or not.
        """
        with warnings.catch_warnings():
            # `torch.tensor` warns when constructing from an existing tensor;
            # the copy is intentional here, so silence the warning.
            warnings.simplefilter("ignore")
            return torch.tensor(self)
Loading