From f40f8ce90815c2fcbbcbcf8856824cf979e63920 Mon Sep 17 00:00:00 2001
From: Wenqi Li <831580+wyli@users.noreply.github.com>
Date: Thu, 23 Mar 2023 19:54:28 +0000
Subject: [PATCH] 4855 lazy resampling impl -- Compose (#5860)

Part of https://github.com/Project-MONAI/MONAI/issues/4855.
Upgrades https://github.com/Project-MONAI/MONAI/pull/4911 to use the latest dev API.

### Description

Example usage, for a sequence of spatial transforms:

```py
xforms = [
    mt.LoadImageD(keys, ensure_channel_first=True),
    mt.Orientationd(keys, "RAS"),
    mt.SpacingD(keys, (1.5, 1.5, 1.5)),
    mt.CenterScaleCropD(keys, roi_scale=0.9),
    # mt.CropForegroundD(keys, source_key="seg", k_divisible=5),
    mt.RandRotateD(keys, prob=1.0, range_y=np.pi / 2, range_x=np.pi / 3),
    mt.RandSpatialCropD(keys, roi_size=(76, 87, 73)),
    mt.RandScaleCropD(keys, roi_scale=0.9),
    mt.Resized(keys, (30, 40, 60)),
    # mt.NormalizeIntensityd(keys),
    mt.ZoomD(keys, 1.3, keep_size=False),
    mt.FlipD(keys),
    mt.Rotate90D(keys),
    mt.RandAffined(keys),
    mt.ResizeWithPadOrCropd(keys, spatial_size=(32, 43, 54)),
    mt.DivisiblePadD(keys, k=3),
]
lazy_kwargs = dict(
    mode=("bilinear", 0), padding_mode=("border", "nearest"), dtype=(torch.float32, torch.uint8)
)
xform = mt.Compose(xforms, lazy_evaluation=True, overrides=lazy_kwargs, override_keys=keys)
xform.set_random_state(0)
```

`lazy_evaluation=True` preserves more details:

![Screenshot 2023-01-17 at 00 31 40](https://user-images.githubusercontent.com/831580/212784981-ea39833b-54ab-42fb-bc03-38b012281857.png)

compared with the regular compose:

![Screenshot 2023-01-17 at 00 31 43](https://user-images.githubusercontent.com/831580/212785016-ba3be8ff-f17f-47b4-8025-cd351a637a82.png)

A sketch of how tuple-valued overrides such as `lazy_kwargs` map to individual keys follows the checklist below.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
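For dictionary data, tuple-valued overrides are split positionally across `override_keys`, and scalars are broadcast to every key. A minimal sketch of that mapping, using the `to_tuple_of_dictionaries` helper added by this PR (the key names below are illustrative):

```py
import torch

from monai.utils.misc import to_tuple_of_dictionaries

# each tuple-valued override contributes one entry per key; scalar overrides apply to all keys
per_key = to_tuple_of_dictionaries({"mode": ("bilinear", 0), "dtype": torch.float32}, ("img", "seg"))
# per_key == ({'mode': 'bilinear', 'dtype': torch.float32}, {'mode': 0, 'dtype': torch.float32})
```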
---------

Signed-off-by: Wenqi Li
Signed-off-by: Yiheng Wang
Signed-off-by: KumoLiu
Signed-off-by: Ben Murray
Co-authored-by: Ben Murray
Co-authored-by: binliu
Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Yiheng Wang
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: KumoLiu
---
 docs/source/transforms.rst             |   6 +
 monai/data/dataset.py                  |  12 ++
 monai/data/meta_obj.py                 |   9 ++
 monai/data/meta_tensor.py              |   2 +-
 monai/transforms/compose.py            | 193 +++++++++++++++++++++--
 monai/transforms/lazy/__init__.py      |   5 +
 monai/transforms/lazy/functional.py    |  87 ++++++-----
 monai/transforms/lazy/utils.py         |  40 +++--
 monai/utils/enums.py                   |   2 +
 monai/utils/misc.py                    |  31 ++++
 tests/min_tests.py                     |   1 +
 tests/test_integration_lazy_samples.py | 205 +++++++++++++++++++++++++
 tests/test_monai_utils_misc.py         |  54 +++++++
 tests/test_one_of.py                   |  15 +-
 tests/test_random_order.py             |  16 +-
 tests/test_resample.py                 |   2 +-
 16 files changed, 578 insertions(+), 102 deletions(-)
 create mode 100644 tests/test_integration_lazy_samples.py
 create mode 100644 tests/test_monai_utils_misc.py

diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 56fe4bc1e77..584f67bc623 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -2206,3 +2206,9 @@ Utilities
 
 .. automodule:: monai.transforms.utils_pytorch_numpy_unification
     :members:
+
+Lazy
+----
+.. automodule:: monai.transforms.lazy
+    :members:
+    :imported-members:
diff --git a/monai/data/dataset.py b/monai/data/dataset.py
index 040d583b0d7..5ef8d7e9031 100644
--- a/monai/data/dataset.py
+++ b/monai/data/dataset.py
@@ -322,7 +322,9 @@ def _pre_transform(self, item_transformed):
                 break
             # this is to be consistent with CacheDataset even though it's not in a multi-thread situation.
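+            # evaluate pending lazy operations here when the upcoming transform requires it;
+            # this is a no-op unless the pipeline's Compose enables lazy_evaluation (see monai.transforms.compose)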
             _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
+            item_transformed = self.transform.evaluate_with_overrides(item_transformed, _xform)
             item_transformed = apply_transform(_xform, item_transformed)
+        item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
         if self.reset_ops_id:
             reset_ops_id(item_transformed)
         return item_transformed
@@ -348,7 +350,9 @@ def _post_transform(self, item_transformed):
                 or not isinstance(_transform, Transform)
             ):
                 start_post_randomize_run = True
+                item_transformed = self.transform.evaluate_with_overrides(item_transformed, _transform)
                 item_transformed = apply_transform(_transform, item_transformed)
+        item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
         return item_transformed
 
     def _cachecheck(self, item_transformed):
@@ -496,7 +500,9 @@ def _pre_transform(self, item_transformed):
             if i == self.cache_n_trans:
                 break
             _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
+            item_transformed = self.transform.evaluate_with_overrides(item_transformed, _xform)
             item_transformed = apply_transform(_xform, item_transformed)
+        item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
         reset_ops_id(item_transformed)
         return item_transformed
@@ -514,7 +520,9 @@ def _post_transform(self, item_transformed):
             raise ValueError("transform must be an instance of monai.transforms.Compose.")
         for i, _transform in enumerate(self.transform.transforms):
             if i >= self.cache_n_trans:
+                item_transformed = self.transform.evaluate_with_overrides(item_transformed, _transform)
                 item_transformed = apply_transform(_transform, item_transformed)
+        item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
         return item_transformed
@@ -884,7 +892,9 @@ def _load_cache_item(self, idx: int):
             if isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
                 break
             _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
+            item = self.transform.evaluate_with_overrides(item, _xform)
             item = apply_transform(_xform, item)
+        item = self.transform.evaluate_with_overrides(item, None)
         if self.as_contiguous:
             item = convert_to_contiguous(item, memory_format=torch.contiguous_format)
         return item
@@ -921,7 +931,9 @@ def _transform(self, index: int):
                 start_run = True
                 if self.copy_cache:
                     data = deepcopy(data)
+            data = self.transform.evaluate_with_overrides(data, _transform)
             data = apply_transform(_transform, data)
+        data = self.transform.evaluate_with_overrides(data, None)
         return data
 
diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py
index 86ce7e33fbe..0dccaa9e1c4 100644
--- a/monai/data/meta_obj.py
+++ b/monai/data/meta_obj.py
@@ -214,6 +214,15 @@ def pending_operations(self) -> list[dict]:
             return self._pending_operations
         return MetaObj.get_default_applied_operations()  # the same default as applied_ops
 
+    @property
+    def has_pending_operations(self) -> bool:
+        """
+        Determine whether there are pending operations.
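+
+        For example (illustrative), ``img.has_pending_operations`` is ``True`` after a lazily
+        executed spatial transform has queued an operation on ``img``, and ``False`` once the
+        pending queue has been applied (e.g. via ``monai.transforms.lazy.apply_transforms``).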
+
+        Returns:
+            True if there are pending operations; False if not
+        """
+        return self.pending_operations is not None and len(self.pending_operations) > 0
+
     def push_pending_operation(self, t: Any) -> None:
         self._pending_operations.append(t)
 
diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py
index 48b9320f995..5a7eb1bbb45 100644
--- a/monai/data/meta_tensor.py
+++ b/monai/data/meta_tensor.py
@@ -492,7 +492,7 @@ def peek_pending_affine(self):
                 continue
             res = convert_to_dst_type(res, next_matrix)[0]
             next_matrix = monai.data.utils.to_affine_nd(r, next_matrix)
-            res = monai.transforms.lazy.utils.combine_transforms(res, next_matrix)
+            res = monai.transforms.lazy.utils.combine_transforms(res, next_matrix)  # type: ignore
         return res
 
     def peek_pending_rank(self):
diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py
index 45e706e143b..0997d53dadb 100644
--- a/monai/transforms/compose.py
+++ b/monai/transforms/compose.py
@@ -21,20 +21,95 @@
 import numpy as np
 
 import monai
+import monai.transforms as mt
+from monai.apps.utils import get_logger
 from monai.transforms.inverse import InvertibleTransform
 
 # For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform)
 from monai.transforms.transform import (  # noqa: F401
+    LazyTransform,
     MapTransform,
     Randomizable,
     RandomizableTransform,
     Transform,
     apply_transform,
 )
-from monai.utils import MAX_SEED, ensure_tuple, get_seed
-from monai.utils.enums import TraceKeys
+from monai.utils import MAX_SEED, TraceKeys, ensure_tuple, get_seed
+from monai.utils.misc import to_tuple_of_dictionaries
 
-__all__ = ["Compose", "OneOf", "RandomOrder"]
+logger = get_logger(__name__)
+
+__all__ = ["Compose", "OneOf", "RandomOrder", "evaluate_with_overrides"]
+
+
+def evaluate_with_overrides(
+    data,
+    upcoming,
+    lazy_evaluation: bool | None = False,
+    overrides: dict | None = None,
+    override_keys: Sequence[str] | None = None,
+    verbose: bool = False,
+):
+    """
+    The previously applied transform may have been applied lazily to the MetaTensor `data`,
+    setting `data.has_pending_operations` to True. Given the upcoming transform ``upcoming``,
+    this function determines whether `data.pending_operations` should be evaluated. If so, it
+    evaluates the lazily applied transforms.
+
+    Currently, the conditions for evaluation are:
+
+    - ``lazy_evaluation`` is ``True``, AND
+    - the data is a ``MetaTensor`` and has pending operations, AND
+    - the upcoming transform is an instance of ``Identity`` or ``IdentityD``, or is ``None``.
+
+    The returned `data` will then be ready for the ``upcoming`` transform.
+
+    Args:
+        data: data to be evaluated.
+        upcoming: the upcoming transform.
+        lazy_evaluation: whether to evaluate the pending operations.
+        overrides: keyword arguments to apply when evaluating the pending transforms.
+        override_keys: the keys to which the override arguments are applied.
+        verbose: whether to print debugging info when evaluating a MetaTensor with pending operations.
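+
+    Example (a hypothetical sketch, not a doctest; assumes ``data`` is a ``MetaTensor`` with
+    pending operations accumulated by lazily executed transforms):
+
+        data = evaluate_with_overrides(data, None, lazy_evaluation=True, overrides={"mode": "nearest"})
+        assert not data.has_pending_operations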
+
+    """
+    if not lazy_evaluation:
+        return data  # eager evaluation
+    overrides = (overrides or {}).copy()
+    if isinstance(data, monai.data.MetaTensor):
+        if data.has_pending_operations and ((isinstance(upcoming, (mt.Identityd, mt.Identity))) or upcoming is None):
+            data, _ = mt.apply_transforms(data, None, overrides=overrides)
+            if verbose:
+                next_name = "final output" if upcoming is None else f"'{upcoming.__class__.__name__}'"
+                logger.info(f"Evaluated - '{override_keys}' - up-to-date for - {next_name}")
+        elif verbose:
+            logger.info(
+                f"Lazy - '{override_keys}' - upcoming: '{upcoming.__class__.__name__}'"
+                f"- pending {len(data.pending_operations)}"
+            )
+        return data
+    override_keys = ensure_tuple(override_keys)
+    if isinstance(data, dict):
+        if isinstance(upcoming, MapTransform):
+            applied_keys = {k for k in data if k in upcoming.keys}
+            if not applied_keys:
+                return data
+        else:
+            applied_keys = set(data.keys())
+
+        keys_to_override = {k for k in applied_keys if k in override_keys}
+        # generate a list of dictionaries with the appropriate override value per key
+        dict_overrides = to_tuple_of_dictionaries(overrides, override_keys)
+        for k in data:
+            if k in keys_to_override:
+                dict_for_key = dict_overrides[override_keys.index(k)]
+                data[k] = evaluate_with_overrides(data[k], upcoming, lazy_evaluation, dict_for_key, k, verbose)
+            else:
+                data[k] = evaluate_with_overrides(data[k], upcoming, lazy_evaluation, None, k, verbose)
+
+    if isinstance(data, (list, tuple)):
+        return [evaluate_with_overrides(v, upcoming, lazy_evaluation, overrides, override_keys, verbose) for v in data]
+    return data
 
 
 class Compose(Randomizable, InvertibleTransform):
@@ -114,7 +189,21 @@ class Compose(Randomizable, InvertibleTransform):
         log_stats: whether to log the detailed information of data and applied transform when error happened,
             for NumPy array and PyTorch Tensor, log the data shape and value range,
             for other metadata, log the values directly. default to `False`.
-
+        lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If False, transforms will be
+            carried out on a transform by transform basis. If True, all lazy transforms will
+            be executed by accumulating changes and resampling as few times as possible.
+            A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of
+            the pending operations and make the primary data up-to-date.
+        overrides: this optional parameter allows you to specify a dictionary of parameters that should be
+            overridden when executing a pipeline. Each parameter in the dictionary that is compatible with a
+            given transform is applied to that transform before it is executed. Note that overrides are
+            currently only applied when ``lazy_evaluation`` is True; if ``lazy_evaluation`` is False they are
+            ignored. Currently supported args are:
+            {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``"device"``},
+            please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details.
+        override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If
+            ``overrides`` is set, ``override_keys`` must also be set.
+        verbose: whether to print debugging info when lazy_evaluation=True.
     """

     def __init__(
@@ -123,6 +212,10 @@ def __init__(
         map_items: bool = True,
         unpack_items: bool = False,
         log_stats: bool = False,
+        lazy_evaluation: bool | None = None,
+        overrides: dict | None = None,
+        override_keys: Sequence[str] | None = None,
+        verbose: bool = False,
     ) -> None:
         if transforms is None:
             transforms = []
@@ -132,6 +225,16 @@
         self.log_stats = log_stats
         self.set_random_state(seed=get_seed())
 
+        self.lazy_evaluation = lazy_evaluation
+        self.overrides = overrides
+        self.override_keys = override_keys
+        self.verbose = verbose
+
+        if self.lazy_evaluation is not None:
+            for t in self.flatten().transforms:  # TODO: test Compose of Compose/OneOf
+                if isinstance(t, LazyTransform):
+                    t.lazy_evaluation = self.lazy_evaluation
+
     def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Compose:
         super().set_random_state(seed=seed, state=state)
         for _transform in self.transforms:
@@ -172,9 +275,26 @@ def __len__(self):
         """Return number of transformations."""
         return len(self.flatten().transforms)
 
+    def evaluate_with_overrides(self, input_, upcoming_xform):
+        """
+        Args:
+            input_: input data to be transformed.
+            upcoming_xform: a transform used to determine whether to evaluate with overrides.
+        """
+        return evaluate_with_overrides(
+            input_,
+            upcoming_xform,
+            lazy_evaluation=self.lazy_evaluation,
+            overrides=self.overrides,
+            override_keys=self.override_keys,
+            verbose=self.verbose,
+        )
+
     def __call__(self, input_):
         for _transform in self.transforms:
+            input_ = self.evaluate_with_overrides(input_, _transform)
             input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items, self.log_stats)
+        input_ = self.evaluate_with_overrides(input_, None)
         return input_
 
     def inverse(self, data):
@@ -204,7 +324,21 @@ class OneOf(Compose):
         log_stats: whether to log the detailed information of data and applied transform when error happened,
             for NumPy array and PyTorch Tensor, log the data shape and value range,
             for other metadata, log the values directly. default to `False`.
-
+        lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If True, all lazy transforms will
+            be executed by accumulating changes and resampling as few times as possible. If False, transforms
+            will be carried out on a transform by transform basis.
+            A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of
+            the pending operations and make the primary data up-to-date.
+        overrides: this optional parameter allows you to specify a dictionary of parameters that should be
+            overridden when executing a pipeline. Each parameter in the dictionary that is compatible with a
+            given transform is applied to that transform before it is executed. Note that overrides are
+            currently only applied when ``lazy_evaluation`` is True; if ``lazy_evaluation`` is False they are
+            ignored. Currently supported args are:
+            {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``"device"``},
+            please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details.
+        override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If
+            ``overrides`` is set, ``override_keys`` must also be set.
+        verbose: whether to print debugging info when lazy_evaluation=True.
     """

     def __init__(
@@ -214,8 +348,14 @@ def __init__(
         map_items: bool = True,
         unpack_items: bool = False,
         log_stats: bool = False,
+        lazy_evaluation: bool | None = None,
+        overrides: dict | None = None,
+        override_keys: Sequence[str] | None = None,
+        verbose: bool = False,
     ) -> None:
-        super().__init__(transforms, map_items, unpack_items, log_stats)
+        super().__init__(
+            transforms, map_items, unpack_items, log_stats, lazy_evaluation, overrides, override_keys, verbose
+        )
         if len(self.transforms) == 0:
             weights = []
         elif weights is None or isinstance(weights, float):
@@ -265,8 +405,8 @@ def __call__(self, data):
             self.push_transform(data, extra_info={"index": index})
         elif isinstance(data, Mapping):
             for key in data:  # dictionary not change size during iteration
-                if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
-                    self.push_transform(data, key, extra_info={"index": index})
+                if isinstance(data[key], monai.data.MetaTensor):
+                    self.push_transform(data[key], extra_info={"index": index})
         return data
 
     def inverse(self, data):
@@ -278,7 +418,7 @@ def inverse(self, data):
             index = self.pop_transform(data)[TraceKeys.EXTRA_INFO]["index"]
         elif isinstance(data, Mapping):
             for key in data:
-                if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
+                if isinstance(data[key], monai.data.MetaTensor):
                     index = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO]["index"]
         else:
             raise RuntimeError(
@@ -306,7 +446,21 @@ class RandomOrder(Compose):
         log_stats: whether to log the detailed information of data and applied transform when error happened,
             for NumPy array and PyTorch Tensor, log the data shape and value range,
             for other metadata, log the values directly. default to `False`.
-
+        lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If True, all lazy transforms will
+            be executed by accumulating changes and resampling as few times as possible. If False, transforms
+            will be carried out on a transform by transform basis.
+            A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of
+            the pending operations and make the primary data up-to-date.
+        overrides: this optional parameter allows you to specify a dictionary of parameters that should be
+            overridden when executing a pipeline. Each parameter in the dictionary that is compatible with a
+            given transform is applied to that transform before it is executed. Note that overrides are
+            currently only applied when ``lazy_evaluation`` is True; if ``lazy_evaluation`` is False they are
+            ignored. Currently supported args are:
+            {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``"device"``},
+            please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details.
+        override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If
+            ``overrides`` is set, ``override_keys`` must also be set.
+        verbose: whether to print debugging info when lazy_evaluation=True.
     """

     def __init__(
@@ -315,8 +469,14 @@ def __init__(
         map_items: bool = True,
         unpack_items: bool = False,
         log_stats: bool = False,
+        lazy_evaluation: bool | None = None,
+        overrides: dict | None = None,
+        override_keys: Sequence[str] | None = None,
+        verbose: bool = False,
     ) -> None:
-        super().__init__(transforms, map_items, unpack_items, log_stats)
+        super().__init__(
+            transforms, map_items, unpack_items, log_stats, lazy_evaluation, overrides, override_keys, verbose
+        )
 
     def __call__(self, input_):
         if len(self.transforms) == 0:
@@ -331,8 +491,8 @@ def __call__(self, input_):
             self.push_transform(input_, extra_info={"applied_order": applied_order})
         elif isinstance(input_, Mapping):
             for key in input_:  # dictionary not change size during iteration
-                if isinstance(input_[key], monai.data.MetaTensor) or self.trace_key(key) in input_:
-                    self.push_transform(input_, key, extra_info={"applied_order": applied_order})
+                if isinstance(input_[key], monai.data.MetaTensor):
+                    self.push_transform(input_[key], extra_info={"applied_order": applied_order})
         return input_
 
     def inverse(self, data):
@@ -344,7 +504,7 @@ def inverse(self, data):
             applied_order = self.pop_transform(data)[TraceKeys.EXTRA_INFO]["applied_order"]
         elif isinstance(data, Mapping):
             for key in data:
-                if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
+                if isinstance(data[key], monai.data.MetaTensor):
                     applied_order = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO]["applied_order"]
         else:
             raise RuntimeError(
@@ -356,5 +516,8 @@ def inverse(self, data):
         # loop backwards over transforms
         for o in reversed(applied_order):
-            data = apply_transform(self.transforms[o].inverse, data, self.map_items, self.unpack_items, self.log_stats)
+            if isinstance(self.transforms[o], InvertibleTransform):
+                data = apply_transform(
+                    self.transforms[o].inverse, data, self.map_items, self.unpack_items, self.log_stats
+                )
         return data
diff --git a/monai/transforms/lazy/__init__.py b/monai/transforms/lazy/__init__.py
index 1e97f894078..02349dd0f2c 100644
--- a/monai/transforms/lazy/__init__.py
+++ b/monai/transforms/lazy/__init__.py
@@ -8,3 +8,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from __future__ import annotations
+
+from .functional import apply_transforms
+from .utils import combine_transforms, resample
diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py
index 0a2517cf878..334f271e052 100644
--- a/monai/transforms/lazy/functional.py
+++ b/monai/transforms/lazy/functional.py
@@ -13,7 +13,6 @@
 
 from typing import Any
 
-import numpy as np
 import torch
 
 from monai.data.meta_tensor import MetaTensor
@@ -24,19 +23,15 @@
     kwargs_from_pending,
     resample,
 )
-from monai.utils import LazyAttr
+from monai.utils import LazyAttr, look_up_option
 
 __all__ = ["apply_transforms"]
 
+__override_keywords = {"mode", "padding_mode", "dtype", "align_corners", "resample_mode", "device"}
+
 
 def apply_transforms(
-    data: torch.Tensor | MetaTensor,
-    pending: list | None = None,
-    mode: str | int | None = None,
-    padding_mode: str | None = None,
-    dtype=np.float64,
-    align_corners: bool = False,
-    resample_mode: str | None = None,
+    data: torch.Tensor | MetaTensor, pending: list | None = None, overrides: dict | None = None, **kwargs: Any
 ):
     """
     This method applies pending transforms to `data` tensors.
 
     Args:
         data: A torch Tensor or a monai MetaTensor.
         pending: pending transforms. This must be set if data is a Tensor, but is optional if data is a MetaTensor.
-        mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
-            Interpolation mode to calculate output values. Defaults to None.
-            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
-            When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
-            and the value represents the order of the spline interpolation.
-            See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
-        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
-            Padding mode for outside grid values. Defaults to None.
-            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
-            When `mode` is an integer, using numpy/cupy backends, this argument accepts
-            {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
-            See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
-        dtype: data type for resampling computation. Defaults to ``float64``.
-            If ``None``, use the data type of input data.
-        align_corners: Geometrically, we consider the pixels of the input as squares rather than points, when using
-            the PyTorch resampling backend. Defaults to ``False``.
-            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
-        resample_mode: the mode of resampling, currently support ``"auto"``. Setting to other values will use the
-            `monai.transforms.SpatialResample` for resampling (instead of potentially crop/pad).
+        overrides: a dictionary of overrides for the transform arguments. The keys must be one of:
+
+            - mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order ``0-5`` (integers).
+              Interpolation mode to calculate output values. Defaults to None.
+              See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+              When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+              and the value represents the order of the spline interpolation.
+              See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+            - padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
+              Padding mode for outside grid values. Defaults to None.
+              See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+              When `mode` is an integer, using numpy/cupy backends, this argument accepts
+              {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+              See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+            - dtype: data type for resampling computation. Defaults to ``float64``.
+              If ``None``, use the data type of the input data; this option may not be compatible with the
+              resampling backend.
+            - align_corners: Geometrically, we consider the pixels of the input as squares rather than points,
+              when using the PyTorch resampling backend. Defaults to ``False``.
+              See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+            - device: device for resampling computation. Defaults to ``None``.
+            - resample_mode: the mode of resampling, currently supports ``"auto"``. Setting to other values will
+              use the :py:class:`monai.transforms.SpatialResample` for resampling (instead of potentially crop/pad).
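+
+    Example (a hypothetical sketch, not a doctest; assumes ``img`` is a MetaTensor carrying
+    pending lazy operations):
+
+        out, applied = apply_transforms(img, overrides={"mode": "bilinear", "padding_mode": "border"})
+        assert not out.has_pending_operations  # pending ops were resampled and pushed to applied_operations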
+
     """
+    overrides = (overrides or {}).copy()
+    overrides.update((kwargs or {}).copy())
+    for k in overrides:
+        look_up_option(k, __override_keywords)  # check existence of the key
+
     if isinstance(data, MetaTensor) and pending is None:
         pending = data.pending_operations.copy()
         data.clear_pending_operations()
@@ -76,15 +80,17 @@
     cumulative_xform = affine_from_pending(pending[0])
     cur_kwargs = kwargs_from_pending(pending[0])
     override_kwargs: dict[str, Any] = {}
-    if mode is not None:
-        override_kwargs[LazyAttr.INTERP_MODE] = mode
-    if padding_mode is not None:
-        override_kwargs[LazyAttr.PADDING_MODE] = padding_mode
-    if align_corners is not None:
-        override_kwargs[LazyAttr.ALIGN_CORNERS] = align_corners
-    if resample_mode is not None:
-        override_kwargs["resample_mode"] = resample_mode
-    override_kwargs[LazyAttr.DTYPE] = data.dtype if dtype is None else dtype
+    if "mode" in overrides:
+        override_kwargs[LazyAttr.INTERP_MODE] = overrides["mode"]
+    if "padding_mode" in overrides:
+        override_kwargs[LazyAttr.PADDING_MODE] = overrides["padding_mode"]
+    if "align_corners" in overrides:
+        override_kwargs[LazyAttr.ALIGN_CORNERS] = overrides["align_corners"]
+    if "resample_mode" in overrides:
+        override_kwargs[LazyAttr.RESAMPLE_MODE] = overrides["resample_mode"]
+    override_dtype = overrides.get("dtype", torch.float64)
+    override_kwargs[LazyAttr.DTYPE] = data.dtype if override_dtype is None else override_dtype
+    device = overrides.get("device")
 
     for p in pending[1:]:
         new_kwargs = kwargs_from_pending(p)
@@ -92,16 +98,13 @@
             # carry out an intermediate resample here due to incompatibility between arguments
             _cur_kwargs = cur_kwargs.copy()
             _cur_kwargs.update(override_kwargs)
-            sp_size = _cur_kwargs.pop(LazyAttr.SHAPE, None)
-            data = resample(data, cumulative_xform, sp_size, _cur_kwargs)
+            data = resample(data.to(device), cumulative_xform, _cur_kwargs)
 
         next_matrix = affine_from_pending(p)
         cumulative_xform = combine_transforms(cumulative_xform, next_matrix)
         cur_kwargs.update(new_kwargs)
     cur_kwargs.update(override_kwargs)
-    sp_size = cur_kwargs.pop(LazyAttr.SHAPE, None)
-    data = resample(data, cumulative_xform, sp_size, cur_kwargs)
+    data = resample(data.to(device), cumulative_xform, cur_kwargs)
     if isinstance(data, MetaTensor):
         for p in pending:
             data.push_applied_operation(p)
-
     return data, pending
diff --git a/monai/transforms/lazy/utils.py b/monai/transforms/lazy/utils.py
index 1cdd4066353..61973fdab6f 100644
--- a/monai/transforms/lazy/utils.py
+++ b/monai/transforms/lazy/utils.py
@@ -20,7 +20,7 @@
 from monai.config import NdarrayOrTensor
 from monai.data.utils import AFFINE_TOL
 from monai.transforms.utils_pytorch_numpy_unification import allclose
-from monai.utils import LazyAttr, convert_to_numpy, convert_to_tensor
+from monai.utils import LazyAttr, convert_to_numpy, convert_to_tensor, look_up_option
 
 __all__ = ["resample", "combine_transforms"]
 
@@ -135,30 +135,33 @@ def requires_interp(matrix, atol=AFFINE_TOL):
             y_channel = y + 1  # the returned axis index starting with channel dim
             if x in ox or y_channel in oy:
                 return None
-            else:
-                ox.append(x)
-                oy.append(y_channel)
+            ox.append(x)
+            oy.append(y_channel)
         elif not np.isclose(c, 0.0, atol=atol):
             return None
     return oy
 
 
-def resample(data: torch.Tensor, matrix: NdarrayOrTensor, spatial_size, kwargs: dict | None = None):
+__override_lazy_keywords = {*list(LazyAttr), "atol"}
+
+
+def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = None):
     """
-    Resample `data` using the affine transformation defined by ``matrix`` and output spatial size ``spatial_size``.
+    Resample `data` using the affine transformation defined by ``matrix``.
 
     Args:
         data: input data to be resampled.
         matrix: affine transformation matrix.
-        spatial_size: output spatial size.
         kwargs: currently supports (see also: ``monai.utils.enums.LazyAttr``)
-            - "lazy_dtype"
+
+            - "lazy_shape" for output spatial shape
             - "lazy_padding_mode"
             - "lazy_interpolation_mode"  (this option might be ignored when ``mode="auto"``.)
             - "lazy_align_corners"
+            - "lazy_dtype"
             - "atol" for tolerance for matrix floating point comparison.
-            - "resample_mode" for resampling backend, default to `"auto"`. Setting to other values will use the
-              `monai.transforms.SpatialResample` for resampling.
+            - "lazy_resample_mode" for the resampling backend, defaults to `"auto"`. Setting to other values will
+              use the `monai.transforms.SpatialResample` for resampling.
 
     See Also:
         :py:class:`monai.transforms.SpatialResample`
@@ -167,24 +170,27 @@ def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = None):
         raise NotImplementedError(f"Calling the dense grid resample API directly not implemented, {matrix.shape}.")
     if isinstance(data, monai.data.MetaTensor) and data.pending_operations:
         warnings.warn("data.pending_operations is not empty, the resampling output may be incorrect.")
-    kwargs = {} if kwargs is None else kwargs
-    atol = kwargs.pop("atol", AFFINE_TOL)
-    mode = kwargs.pop("resample_mode", "auto")
+    kwargs = kwargs or {}
+    for k in kwargs:
+        look_up_option(k, __override_lazy_keywords)
+    atol = kwargs.get("atol", AFFINE_TOL)
+    mode = kwargs.get(LazyAttr.RESAMPLE_MODE, "auto")
     init_kwargs = {
-        "dtype": kwargs.pop(LazyAttr.DTYPE, data.dtype),
-        "align_corners": kwargs.pop(LazyAttr.ALIGN_CORNERS, False),
+        "dtype": kwargs.get(LazyAttr.DTYPE, data.dtype),
+        "align_corners": kwargs.get(LazyAttr.ALIGN_CORNERS, False),
     }
     ndim = len(matrix) - 1
     img = convert_to_tensor(data=data, track_meta=monai.data.get_track_meta())
     init_affine = monai.data.to_affine_nd(ndim, img.affine)
+    spatial_size = kwargs.get(LazyAttr.SHAPE, None)
     out_spatial_size = img.peek_pending_shape() if spatial_size is None else spatial_size
     out_spatial_size = convert_to_numpy(out_spatial_size, wrap_sequence=True)
     call_kwargs = {
         "spatial_size": out_spatial_size,
         "dst_affine": init_affine @ monai.utils.convert_to_dst_type(matrix, init_affine)[0],
-        "mode": kwargs.pop(LazyAttr.INTERP_MODE, None),
-        "padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None),
+        "mode": kwargs.get(LazyAttr.INTERP_MODE),
+        "padding_mode": kwargs.get(LazyAttr.PADDING_MODE),
     }
 
     axes = requires_interp(matrix, atol=atol)
diff --git a/monai/utils/enums.py b/monai/utils/enums.py
index 8fd79a24da7..6b01e43b47f 100644
--- a/monai/utils/enums.py
+++ b/monai/utils/enums.py
@@ -631,6 +631,7 @@ class LazyAttr(StrEnum):
     MetaTensor with pending operations requires some key attributes tracked especially when the primary array
     is not up-to-date due to lazy evaluation.
     This class specifies the set of key attributes to be tracked for each MetaTensor.
+    See also: :py:func:`monai.transforms.lazy.utils.resample` for more details.
     """

     SHAPE = "lazy_shape"  # spatial shape
@@ -639,6 +640,7 @@ class LazyAttr(StrEnum):
     INTERP_MODE = "lazy_interpolation_mode"
     DTYPE = "lazy_dtype"
     ALIGN_CORNERS = "lazy_align_corners"
+    RESAMPLE_MODE = "lazy_resample_mode"
 
 
 class BundleProperty(StrEnum):
diff --git a/monai/utils/misc.py b/monai/utils/misc.py
index a729688209f..f22716a3768 100644
--- a/monai/utils/misc.py
+++ b/monai/utils/misc.py
@@ -191,6 +191,37 @@ def ensure_tuple_rep(tup: Any, dim: int) -> tuple[Any, ...]:
     raise ValueError(f"Sequence must have length {dim}, got {len(tup)}.")
 
 
+def to_tuple_of_dictionaries(dictionary_of_tuples: dict, keys: Any) -> tuple[dict[Any, Any], ...]:
+    """
+    Given a dictionary whose values are scalars or tuples (with the same length as ``keys``),
+    create a dictionary for each key containing the scalar values mapped to that key.
+
+    Args:
+        dictionary_of_tuples: a dictionary whose values are scalars or tuples whose length is
+            the length of ``keys``
+        keys: a tuple of string values representing the keys in question
+
+    Returns:
+        a tuple of dictionaries that contain scalar values, one dictionary for each key
+
+    Raises:
+        ValueError: when values in the dictionary are tuples whose length differs from the
+            length of ``keys``
+
+    Examples:
+        >>> to_tuple_of_dictionaries({'a': 1, 'b': (2, 3), 'c': (4, 4)}, ("x", "y"))
+        ({'a': 1, 'b': 2, 'c': 4}, {'a': 1, 'b': 3, 'c': 4})
+
+    """
+
+    keys = ensure_tuple(keys)
+    if len(keys) == 0:
+        return tuple({})
+
+    dict_overrides = {k: ensure_tuple_rep(v, len(keys)) for k, v in dictionary_of_tuples.items()}
+    return tuple({k: v[ik] for (k, v) in dict_overrides.items()} for ik in range(len(keys)))
+
+
 def fall_back_tuple(
     user_provided: Any, default: Sequence | NdarrayTensor, func: Callable = lambda x: x and x > 0
 ) -> tuple[Any, ...]:
diff --git a/tests/min_tests.py b/tests/min_tests.py
index 36a6c11adc4..05f117013ef 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -111,6 +111,7 @@ def run_testsuit():
         "test_integration_fast_train",
         "test_integration_gpu_customization",
         "test_integration_segmentation_3d",
+        "test_integration_lazy_samples",
         "test_integration_sliding_window",
         "test_integration_unet_2d",
         "test_integration_workflows",
diff --git a/tests/test_integration_lazy_samples.py b/tests/test_integration_lazy_samples.py
new file mode 100644
index 00000000000..807ab23f086
--- /dev/null
+++ b/tests/test_integration_lazy_samples.py
@@ -0,0 +1,205 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+import shutil
+import tempfile
+import unittest
+from glob import glob
+
+import nibabel as nib
+import numpy as np
+import torch
+
+import monai
+import monai.transforms as mt
+from monai.data import create_test_image_3d
+from monai.utils import set_determinism
+from tests.utils import DistTestCase, SkipIfBeforePyTorchVersion, skip_if_quick
+
+
+def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, None), num_workers=4, lazy=True):
+    print(f"test case: {locals()}")
+    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
+    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
+    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    num_workers = 0 if torch.cuda.is_available() else num_workers
+
+    # define transforms for image and segmentation
+    lazy_kwargs = dict(
+        mode=("bilinear", 0), device=device, padding_mode=("border", "nearest"), dtype=(torch.float32, torch.uint8)
+    )
+    train_transforms = mt.Compose(
+        [
+            mt.LoadImaged(keys=["img", "seg"], reader=readers[0], image_only=True),
+            mt.EnsureChannelFirstd(keys=["img", "seg"]),
+            mt.Spacingd(
+                keys=["img", "seg"],
+                pixdim=[1.2, 0.8, 0.7],
+                mode=["bilinear", 0],
+                padding_mode=("border", "nearest"),
+                dtype=np.float32,
+            ),
+            mt.Orientationd(keys=["img", "seg"], axcodes="ARS"),
+            mt.RandRotate90d(keys=["img", "seg"], prob=1.0, spatial_axes=(1, 2)),
+            mt.ScaleIntensityd(keys="img"),
+            mt.IdentityD(keys=["seg"]),
+            mt.RandCropByPosNegLabeld(
+                keys=["img", "seg"], label_key="seg", spatial_size=[76, 82, 80], pos=1, neg=1, num_samples=4
+            ),
+            mt.RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=(0, 2)),
+            mt.ResizeWithPadOrCropD(keys=["img", "seg"], spatial_size=[80, 72, 80]),
+            mt.Rotated(keys=["img", "seg"], angle=[np.pi / 2, np.pi / 2, 0], mode="nearest", keep_size=False),
+        ],
+        lazy_evaluation=lazy,
+        overrides=lazy_kwargs,
+        override_keys=("img", "seg"),
+        verbose=num_workers > 0,  # testing both flags
+    )
+
+    # create a training data loader
+    if cachedataset == 2:
+        train_ds = monai.data.CacheDataset(
+            data=train_files, transform=train_transforms, cache_rate=0.8, runtime_cache=False, num_workers=0
+        )
+    elif cachedataset == 3:
+        train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms, cache_dir=root_dir)
+    else:
+        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
+
+    # create UNet, DiceLoss and Adam optimizer
+    model = monai.networks.nets.UNet(
+        spatial_dims=3, in_channels=1, out_channels=1, channels=(2, 2, 2, 2), strides=(2, 2, 2), num_res_units=2
+    ).to(device)
+    optimizer = torch.optim.Adam(model.parameters(), 5e-4)
+    loss_function = monai.losses.DiceLoss(sigmoid=True)
+
+    saver = mt.SaveImage(
+        output_dir=os.path.join(root_dir, "output"),
+        dtype=np.float32,
+        output_ext=".nii.gz",
+        output_postfix=f"seg_{lazy}_{num_workers}",
+        mode="bilinear",
+        resample=False,
+        separate_folder=False,
+        print_log=False,
+    )
+
+    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
+    _g = torch.Generator()
+    _g.manual_seed(0)
+    set_determinism(0)
+    train_loader = monai.data.DataLoader(
+        train_ds, batch_size=1, shuffle=True, num_workers=num_workers, generator=_g, persistent_workers=num_workers > 0
+    )
+    all_coords = set()
+    for epoch in range(5):
+        print("-" * 10)
+        print(f"Epoch {epoch + 1}/5")
+        step = 0
+        for batch_data in train_loader:
+            step += 1
+            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
+            optimizer.zero_grad()
+            outputs = model(inputs)
+            loss = loss_function(outputs, labels)
+            loss.backward()
+            optimizer.step()
+            epoch_len = len(train_ds) // train_loader.batch_size
+            print(f"{step}/{epoch_len}, train_loss:{loss.item():0.4f}")
+
+            for item, in_img, in_seg in zip(outputs, inputs, labels):  # this decollates the batch, pt 1.9+
+                item.copy_meta_from(in_img)
+                np.testing.assert_array_equal(item.pending_operations, [])
+                np.testing.assert_array_equal(in_seg.pending_operations, [])
+                ops = [0]
+                if len(item.applied_operations) > 1:
+                    found = False
+                    for idx, n in enumerate(item.applied_operations):  # noqa
+                        if n["class"] == "RandCropByPosNegLabel":
+                            found = True
+                            break
+                    if found:
+                        ops = item.applied_operations[idx]["extra_info"]["extra_info"]["cropped"]
+                img_name = os.path.basename(item.meta["filename_or_obj"])
+                coords = f"{img_name} - {ops}"
+                print(coords)
+                # np.testing.assert_allclose(coords in all_coords, False)
+                all_coords.add(coords)
+                saver(item)  # just testing the saving
+                saver(in_img)
+                saver(in_seg)
+    return ops
+
+
+@skip_if_quick
+@SkipIfBeforePyTorchVersion((1, 11))
+class IntegrationLazyResampling(DistTestCase):
+    def setUp(self):
+        monai.config.print_config()
+        set_determinism(seed=0)
+
+        self.data_dir = tempfile.mkdtemp()
+        for i in range(3):
+            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
+            n = nib.Nifti1Image(im, np.eye(4))
+            nib.save(n, os.path.join(self.data_dir, f"img{i:d}.nii.gz"))
+            n = nib.Nifti1Image(seg, np.eye(4))
+            nib.save(n, os.path.join(self.data_dir, f"seg{i:d}.nii.gz"))
+
+        self.device = "cuda:0" if torch.cuda.is_available() else "cpu:0"
+
+    def tearDown(self):
+        set_determinism(seed=None)
+        shutil.rmtree(self.data_dir)
+
+    def train_and_infer(self, idx=0):
+        results = []
+        _readers = (None, None)
+        _w = 2
+        if idx == 1:
+            _readers = ("itkreader", "itkreader")
+            _w = 1
+        elif idx == 2:
+            _readers = ("itkreader", "nibabelreader")
+            _w = 0
+        results = run_training_test(
+            self.data_dir, device=self.device, cachedataset=idx, readers=_readers, num_workers=_w, lazy=True
+        )
+        results_expected = run_training_test(
+            self.data_dir, device=self.device, cachedataset=0, readers=_readers, num_workers=_w, lazy=False
+        )
+        self.assertFalse(np.allclose(results, [0]))
+        self.assertFalse(np.allclose(results_expected, [0]))
+        np.testing.assert_allclose(results, results_expected)
+        lazy_files = glob(os.path.join(self.data_dir, "output", "*_True_*.nii.gz"))
+        regular_files = glob(os.path.join(self.data_dir, "output", "*_False_*.nii.gz"))
+        diffs = []
+        for a, b in zip(sorted(lazy_files), sorted(regular_files)):
+            img_lazy = mt.LoadImage(image_only=True)(a)
+            img_regular = mt.LoadImage(image_only=True)(b)
+            diff = np.size(img_lazy) - np.sum(np.isclose(img_lazy, img_regular, atol=1e-4))
+            diff_rate = diff / np.size(img_lazy)
+            diffs.append(diff_rate)
+            np.testing.assert_allclose(diff_rate, 0.0, atol=0.03)
+        print("volume diff:", diffs)
+        return results
+
+    def test_training(self):
+        for i in range(4):
+            self.train_and_infer(i)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_monai_utils_misc.py b/tests/test_monai_utils_misc.py
new file mode 100644
index 00000000000..46633e85ab0
--- /dev/null
+++ b/tests/test_monai_utils_misc.py
@@ -0,0 +1,54 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+from parameterized import parameterized
+
+from monai.utils.misc import to_tuple_of_dictionaries
+
+TO_TUPLE_OF_DICTIONARIES_TEST_CASES = [
+    ({}, tuple(), tuple()),
+    ({}, ("x",), ({},)),
+    ({}, ("x", "y"), ({}, {})),
+    ({"a": 1}, tuple(), tuple()),
+    ({"a": 1}, ("x",), ({"a": 1},)),
+    ({"a": (1,)}, ("x",), ({"a": 1},)),
+    ({"a": (1,)}, ("x", "y"), ValueError()),
+    ({"a": 1}, ("x", "y"), ({"a": 1}, {"a": 1})),
+    ({"a": (1, 2)}, tuple(), tuple()),
+    ({"a": (1, 2)}, ("x", "y"), ({"a": 1}, {"a": 2})),
+    ({"a": (1, 2, 3)}, ("x", "y"), ValueError()),
+    ({"b": (2,), "a": 1}, tuple(), tuple()),
+    ({"b": (2,), "a": 1}, ("x",), ({"b": 2, "a": 1},)),
+    ({"b": (2,), "a": 1}, ("x", "y"), ValueError()),
+    ({"b": (3, 2), "a": 1}, tuple(), tuple()),
+    ({"b": (3, 2), "a": 1}, ("x",), ValueError()),
+    ({"b": (3, 2), "a": 1}, ("x", "y"), ({"b": 3, "a": 1}, {"b": 2, "a": 1})),
+]
+
+
+class TestToTupleOfDictionaries(unittest.TestCase):
+    @parameterized.expand(TO_TUPLE_OF_DICTIONARIES_TEST_CASES)
+    def test_to_tuple_of_dictionaries(self, dictionary, keys, expected):
+        self._test_to_tuple_of_dictionaries(dictionary, keys, expected)
+
+    def _test_to_tuple_of_dictionaries(self, dictionary, keys, expected):
+        if isinstance(expected, Exception):
+            with self.assertRaises(type(expected)):
+                to_tuple_of_dictionaries(dictionary, keys)
+            print(type(expected))
+        else:
+            actual = to_tuple_of_dictionaries(dictionary, keys)
+            print(actual, expected)
+            self.assertTupleEqual(actual, expected)
diff --git a/tests/test_one_of.py b/tests/test_one_of.py
index 687ec71aad0..36980c23a76 100644
--- a/tests/test_one_of.py
+++ b/tests/test_one_of.py
@@ -27,7 +27,6 @@
     RandShiftIntensityd,
     Resize,
     Resized,
-    TraceableTransform,
     Transform,
 )
 from monai.transforms.compose import Compose
@@ -156,16 +155,12 @@ def _match(a, b):
     @parameterized.expand(TEST_INVERSES)
     def test_inverse(self, transform, invertible, use_metatensor):
-        data = {k: (i + 1) * 10.0 if not use_metatensor else MetaTensor((i + 1) * 10.0) for i, k in enumerate(KEYS)}
+        data = {k: MetaTensor((i + 1) * 10.0) for i, k in enumerate(KEYS)}
         fwd_data = transform(data)
 
         if invertible:
             for k in KEYS:
-                t = (
-                    fwd_data[TraceableTransform.trace_key(k)][-1]
-                    if not use_metatensor
-                    else fwd_data[k].applied_operations[-1]
-                )
+                t = fwd_data[k].applied_operations[-1]
                 # make sure the OneOf index was stored
                 self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__)
                 # make sure index exists and is in bounds
@@ -176,12 +171,6 @@ def test_inverse(self, transform, invertible, use_metatensor):
         if invertible:
             for k in KEYS:
-                # check transform was removed
-                if not use_metatensor:
-                    self.assertTrue(
-                        len(fwd_inv_data[TraceableTransform.trace_key(k)])
-                        < len(fwd_data[TraceableTransform.trace_key(k)])
-                    )
                 # check data is same as original (and different from forward)
                 self.assertEqual(fwd_inv_data[k], data[k])
                 self.assertNotEqual(fwd_inv_data[k], fwd_data[k])
diff --git a/tests/test_random_order.py b/tests/test_random_order.py
index a60202dd788..9ed22d30aee 100644
--- a/tests/test_random_order.py
+++ b/tests/test_random_order.py
@@ -16,7 +16,7 @@
 from parameterized import parameterized
 
 from monai.data import MetaTensor
-from monai.transforms import RandomOrder, TraceableTransform
+from monai.transforms import RandomOrder
 from monai.transforms.compose import Compose
 from monai.utils import set_determinism
 from monai.utils.enums import TraceKeys
@@ -70,18 +70,14 @@ def _match(a, b):
     @parameterized.expand(TEST_INVERSES)
     def test_inverse(self, transform, invertible, use_metatensor):
-        data = {k: (i + 1) * 10.0 if not use_metatensor else MetaTensor((i + 1) * 10.0) for i, k in enumerate(KEYS)}
+        data = {k: MetaTensor((i + 1) * 10.0) for i, k in enumerate(KEYS)}
         fwd_data1 = transform(data)
         # test call twice won't affect inverse
         fwd_data2 = transform(data)
 
         if invertible:
             for k in KEYS:
-                t = (
-                    fwd_data1[TraceableTransform.trace_key(k)][-1]
-                    if not use_metatensor
-                    else fwd_data1[k].applied_operations[-1]
-                )
+                t = fwd_data1[k].applied_operations[-1]
                 # make sure the RandomOrder applied_order was stored
                 self.assertEqual(t[TraceKeys.CLASS_NAME], RandomOrder.__name__)
@@ -94,12 +90,6 @@ def test_inverse(self, transform, invertible, use_metatensor):
         for i, _fwd_inv_data in enumerate(fwd_inv_data):
             if invertible:
                 for k in KEYS:
-                    # check transform was removed
-                    if not use_metatensor:
-                        self.assertTrue(
-                            len(_fwd_inv_data[TraceableTransform.trace_key(k)])
-                            < len(fwd_data[i][TraceableTransform.trace_key(k)])
-                        )
                     # check data is same as original (and different from forward)
                     self.assertEqual(_fwd_inv_data[k], data[k])
                     self.assertNotEqual(_fwd_inv_data[k], fwd_data[i][k])
diff --git a/tests/test_resample.py b/tests/test_resample.py
index 2df1b7a3ff0..4f9436f8ce6 100644
--- a/tests/test_resample.py
+++ b/tests/test_resample.py
@@ -34,7 +34,7 @@ def rotate_90_2d():
 class TestResampleFunction(unittest.TestCase):
     @parameterized.expand(RESAMPLE_FUNCTION_CASES)
     def test_resample_function_impl(self, img, matrix, expected):
-        out = resample(convert_to_tensor(img), matrix, img.shape[1:], {"lazy_padding_mode": "border"})
+        out = resample(convert_to_tensor(img), matrix, {"lazy_shape": img.shape[1:], "lazy_padding_mode": "border"})
         assert_allclose(out[0], expected, type_test=False)