Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

4922 adding a minimal lazy transform interface #5407

Merged
merged 5 commits into from
Oct 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions monai/data/meta_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class MetaObj:
def __init__(self):
self._meta: dict = MetaObj.get_default_meta()
self._applied_operations: list = MetaObj.get_default_applied_operations()
self._pending_operations: list = MetaObj.get_default_applied_operations() # the same default as applied_ops
self._is_batch: bool = False

@staticmethod
Expand Down Expand Up @@ -199,6 +200,19 @@ def push_applied_operation(self, t: Any) -> None:
def pop_applied_operation(self) -> Any:
return self._applied_operations.pop()

@property
def pending_operations(self) -> list[dict]:
"""Get the pending operations. Defaults to ``[]``."""
if hasattr(self, "_pending_operations"):
return self._pending_operations
return MetaObj.get_default_applied_operations() # the same default as applied_ops

def push_pending_operation(self, t: Any) -> None:
self._pending_operations.append(t)

def pop_pending_operation(self) -> Any:
return self._pending_operations.pop()

@property
def is_batch(self) -> bool:
"""Return whether object is part of batch or not."""
Expand Down
18 changes: 16 additions & 2 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from monai.data.meta_obj import MetaObj, get_track_meta
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
from monai.utils import look_up_option
from monai.utils.enums import MetaKeys, PostFix, SpaceKeys
from monai.utils.type_conversion import convert_data_type, convert_to_tensor
from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
from monai.utils.type_conversion import convert_data_type, convert_to_numpy, convert_to_tensor

__all__ = ["MetaTensor"]

Expand Down Expand Up @@ -445,6 +445,20 @@ def pixdim(self):
return [affine_to_spacing(a) for a in self.affine]
return affine_to_spacing(self.affine)

def peek_pending_shape(self):
"""Get the currently expected spatial shape as if all the pending operations are executed."""
res = None
if self.pending_operations:
res = self.pending_operations[-1].get(LazyAttr.SHAPE, None)
# default to spatial shape (assuming channel-first input)
return tuple(convert_to_numpy(self.shape, wrap_sequence=True).tolist()[1:]) if res is None else res

def peek_pending_affine(self):
res = None
if self.pending_operations:
res = self.pending_operations[-1].get(LazyAttr.AFFINE, None)
return self.affine if res is None else res

def new_empty(self, size, dtype=None, device=None, requires_grad=False):
"""
must be defined for deepcopy to work
Expand Down
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
InterpolateMode,
InverseKeys,
JITMetadataKeys,
LazyAttr,
LossReduction,
MetaKeys,
Method,
Expand Down
14 changes: 14 additions & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"AlgoEnsembleKeys",
"HoVerNetMode",
"HoVerNetBranch",
"LazyAttr",
]


Expand Down Expand Up @@ -616,3 +617,16 @@ class HoVerNetBranch(StrEnum):
HV = "horizontal_vertical"
NP = "nucleus_prediction"
NC = "type_prediction"


class LazyAttr(StrEnum):
"""
MetaTensor with pending operations requires some key attributes tracked especially when the primary array
is not up-to-date due to lazy evaluation.
This class specifies the set of key attributes to be tracked for each MetaTensor.
"""

SHAPE = "lazy_shape" # spatial shape
AFFINE = "lazy_affine"
PADDING_MODE = "lazy_padding_mode"
INTERP_MODE = "lazy_interpolation_mode"
9 changes: 9 additions & 0 deletions tests/test_meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,15 @@ def test_construct_with_pre_applied_transforms(self):
m = MetaTensor(im, applied_operations=data["im"].applied_operations)
self.assertEqual(len(m.applied_operations), len(tr.transforms))

def test_pending_ops(self):
m, _ = self.get_im()
self.assertEqual(m.pending_operations, [])
self.assertEqual(m.peek_pending_shape(), (10, 8))
self.assertIsInstance(m.peek_pending_affine(), torch.Tensor)
m.push_pending_operation({})
self.assertEqual(m.peek_pending_shape(), (10, 8))
self.assertIsInstance(m.peek_pending_affine(), torch.Tensor)

@parameterized.expand(TESTS)
def test_multiprocessing(self, device=None, dtype=None):
"""multiprocessing sharing with 'device' and 'dtype'"""
Expand Down