diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index 0f34458697..75cbec5607 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -660,6 +660,7 @@
     rescale_instance_array,
     reset_ops_id,
     resize_center,
+    resolves_modes,
     sync_meta_info,
     weighted_patch_samples,
     zero_margins,
diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py
index f263e89152..6fe433a0bc 100644
--- a/monai/transforms/spatial/array.py
+++ b/monai/transforms/spatial/array.py
@@ -15,7 +15,6 @@
 
 from __future__ import annotations
 
-import functools
 import warnings
 from collections.abc import Callable
 from copy import deepcopy
@@ -54,6 +53,7 @@
     create_shear,
     create_translate,
     map_spatial_axes,
+    resolves_modes,
     scale_affine,
 )
 from monai.transforms.utils_pytorch_numpy_unification import argsort, argwhere, linalg_inv, moveaxis
@@ -61,9 +61,7 @@
     GridSampleMode,
     GridSamplePadMode,
     InterpolateMode,
-    NdimageMode,
     NumpyPadMode,
-    SplineMode,
     convert_to_cupy,
     convert_to_dst_type,
     convert_to_numpy,
@@ -695,7 +693,7 @@ def __init__(
     ) -> None:
         self.size_mode = look_up_option(size_mode, ["all", "longest"])
         self.spatial_size = spatial_size
-        self.mode: InterpolateMode = look_up_option(mode, InterpolateMode)
+        self.mode = mode
         self.align_corners = align_corners
         self.anti_aliasing = anti_aliasing
         self.anti_aliasing_sigma = anti_aliasing_sigma
@@ -759,7 +757,7 @@ def __call__(
             scale = self.spatial_size / max(img_size)
             sp_size = tuple(int(round(s * scale)) for s in img_size)
 
-        _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode)
+        _mode = self.mode if mode is None else mode
         _align_corners = self.align_corners if align_corners is None else align_corners
         _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)
         return resize(  # type: ignore
@@ -831,8 +829,8 @@ def __init__(
     ) -> None:
         self.angle = angle
         self.keep_size = keep_size
-        self.mode: str = look_up_option(mode, GridSampleMode)
-        self.padding_mode: str = look_up_option(padding_mode, GridSamplePadMode)
+        self.mode: str = mode
+        self.padding_mode: str = padding_mode
         self.align_corners = align_corners
         self.dtype = dtype
@@ -867,8 +865,8 @@ def __call__(
         """
         img = convert_to_tensor(img, track_meta=get_track_meta())
         _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)
-        _mode = look_up_option(mode or self.mode, GridSampleMode)
-        _padding_mode = look_up_option(padding_mode or self.padding_mode, GridSamplePadMode)
+        _mode = mode or self.mode
+        _padding_mode = padding_mode or self.padding_mode
         _align_corners = self.align_corners if align_corners is None else align_corners
         im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
         output_shape = im_shape if self.keep_size else None
@@ -888,10 +886,11 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor:
         dtype = transform[TraceKeys.EXTRA_INFO]["dtype"]
         inv_rot_mat = linalg_inv(convert_to_numpy(fwd_rot_mat))
+        _, _m, _p, _ = resolves_modes(mode, padding_mode)
         xform = AffineTransform(
             normalized=False,
-            mode=mode,
-            padding_mode=padding_mode,
+            mode=_m,
+            padding_mode=_p,
             align_corners=False if align_corners == TraceKeys.NONE else align_corners,
             reverse_indexing=True,
         )
@@ -953,7 +952,7 @@ def __init__(
         **kwargs,
     ) -> None:
         self.zoom = zoom
-        self.mode: InterpolateMode = InterpolateMode(mode)
+        self.mode = mode
         self.padding_mode = padding_mode
         self.align_corners = align_corners
         self.dtype = dtype
@@ -991,7 +990,7 @@ def __call__(
         """
         img = convert_to_tensor(img, track_meta=get_track_meta())
         _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1)  # match the spatial image dim
-        _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode).value
+        _mode = self.mode if mode is None else mode
         _padding_mode = padding_mode or self.padding_mode
         _align_corners = self.align_corners if align_corners is None else align_corners
         _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)
@@ -1181,8 +1180,8 @@ def __init__(
             self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]]))
 
         self.keep_size = keep_size
-        self.mode: str = look_up_option(mode, GridSampleMode)
-        self.padding_mode: str = look_up_option(padding_mode, GridSamplePadMode)
+        self.mode: str = mode
+        self.padding_mode: str = padding_mode
         self.align_corners = align_corners
         self.dtype = dtype
@@ -1231,8 +1230,8 @@ def __call__(
         rotator = Rotate(
             angle=self.x if ndim == 2 else (self.x, self.y, self.z),
             keep_size=self.keep_size,
-            mode=look_up_option(mode or self.mode, GridSampleMode),
-            padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode),
+            mode=mode or self.mode,
+            padding_mode=padding_mode or self.padding_mode,
             align_corners=self.align_corners if align_corners is None else align_corners,
             dtype=dtype or self.dtype or img.dtype,
         )
@@ -1406,7 +1405,7 @@ def __init__(
             raise ValueError(
                 f"min_zoom and max_zoom must have same length, got {len(self.min_zoom)} and {len(self.max_zoom)}."
             )
-        self.mode: InterpolateMode = look_up_option(mode, InterpolateMode)
+        self.mode = mode
         self.padding_mode = padding_mode
         self.align_corners = align_corners
         self.dtype = dtype
@@ -1467,7 +1466,7 @@ def __call__(
         xform = Zoom(
             self._zoom,
             keep_size=self.keep_size,
-            mode=look_up_option(mode or self.mode, InterpolateMode),
+            mode=mode or self.mode,
             padding_mode=padding_mode or self.padding_mode,
             align_corners=self.align_corners if align_corners is None else align_corners,
             dtype=dtype or self.dtype,
@@ -1815,35 +1814,6 @@ def __init__(
         self.align_corners = align_corners
         self.dtype = dtype
 
-    @staticmethod
-    @functools.lru_cache(None)
-    def resolve_modes(interp_mode, padding_mode):
-        """compute the backend and the corresponding mode for the given interpolation mode and padding mode."""
-        _interp_mode = None
-        _padding_mode = None
-        if look_up_option(str(interp_mode), SplineMode, default=None) is not None:
-            backend = TransformBackends.NUMPY
-        else:
-            backend = TransformBackends.TORCH
-
-        if (not USE_COMPILED) and (backend == TransformBackends.TORCH):
-            if str(interp_mode).lower().endswith("linear"):
-                _interp_mode = GridSampleMode("bilinear")
-            _interp_mode = GridSampleMode(interp_mode)
-            _padding_mode = GridSamplePadMode(padding_mode)
-        elif USE_COMPILED and backend == TransformBackends.TORCH:  # compiled is using torch backend param name
-            _padding_mode = 1 if padding_mode == "reflection" else padding_mode  # type: ignore
-            if interp_mode == "bicubic":
-                _interp_mode = 3  # type: ignore
-            elif interp_mode == "bilinear":
-                _interp_mode = 1  # type: ignore
-            else:
-                _interp_mode = GridSampleMode(interp_mode)
-        else:  # TransformBackends.NUMPY
-            _interp_mode = int(interp_mode)  # type: ignore
-            _padding_mode = look_up_option(padding_mode, NdimageMode)
-        return backend, _interp_mode, _padding_mode
-
     def __call__(
         self,
         img: torch.Tensor,
@@ -1894,8 +1864,11 @@ def __call__(
         _align_corners = self.align_corners if align_corners is None else align_corners
         img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype, device=_device)
         sr = min(len(img_t.peek_pending_shape() if isinstance(img_t, MetaTensor) else img_t.shape[1:]), 3)
-        backend, _interp_mode, _padding_mode = Resample.resolve_modes(
-            self.mode if mode is None else mode, self.padding_mode if padding_mode is None else padding_mode
+        backend, _interp_mode, _padding_mode, _ = resolves_modes(
+            self.mode if mode is None else mode,
+            self.padding_mode if padding_mode is None else padding_mode,
+            backend=None,
+            use_compiled=USE_COMPILED,
         )
 
         if USE_COMPILED or backend == TransformBackends.NUMPY:
diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py
index e78ee75cb7..591ebbb489 100644
--- a/monai/transforms/spatial/functional.py
+++ b/monai/transforms/spatial/functional.py
@@ -32,7 +32,7 @@
 from monai.transforms.croppad.array import ResizeWithPadOrCrop
 from monai.transforms.intensity.array import GaussianSmooth
 from monai.transforms.inverse import TraceableTransform
-from monai.transforms.utils import create_rotate, create_translate, scale_affine
+from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
 from monai.transforms.utils_pytorch_numpy_unification import allclose
 from monai.utils import (
     LazyAttr,
@@ -172,8 +172,9 @@ def spatial_resample(
         with affine_xform.trace_transform(False):
             img = affine_xform(img, mode=mode, padding_mode=padding_mode)
     else:
+        _, _m, _p, _ = resolves_modes(mode, padding_mode)
         affine_xform = AffineTransform(  # type: ignore
-            normalized=False, mode=mode, padding_mode=padding_mode, align_corners=align_corners, reverse_indexing=True
+            normalized=False, mode=_m, padding_mode=_p, align_corners=align_corners, reverse_indexing=True
         )
         img = affine_xform(img.unsqueeze(0), theta=xform.to(img), spatial_size=spatial_size).squeeze(0)  # type: ignore
     if additional_dims:
@@ -331,8 +332,9 @@ def resize(img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing,
                 anti_aliasing_sigma[axis] = anti_aliasing_sigma[axis] * int(factors[axis] > 1)
         anti_aliasing_filter = GaussianSmooth(sigma=anti_aliasing_sigma)
         img_ = convert_to_tensor(anti_aliasing_filter(img_), track_meta=False)
+    _, _m, _, _ = resolves_modes(mode, torch_interpolate_spatial_nd=len(img_.shape) - 1)
     resized = torch.nn.functional.interpolate(
-        input=img_.unsqueeze(0), size=out_size, mode=mode, align_corners=align_corners
+        input=img_.unsqueeze(0), size=out_size, mode=_m, align_corners=align_corners
     )
     out, *_ = convert_to_dst_type(resized.squeeze(0), out, dtype=torch.float32)
     return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
@@ -396,8 +398,9 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, t
     out = _maybe_new_metatensor(img)
     if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
         return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
+    _, _m, _p, _ = resolves_modes(mode, padding_mode)
     xform = AffineTransform(
-        normalized=False, mode=mode, padding_mode=padding_mode, align_corners=align_corners, reverse_indexing=True
+        normalized=False, mode=_m, padding_mode=_p, align_corners=align_corners, reverse_indexing=True
     )
     img_t = out.to(dtype)
     transform_t, *_ = convert_to_dst_type(transform, img_t)
@@ -468,11 +471,12 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype,
     if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
         return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
     img_t = out.to(dtype)
+    _, _m, _, _ = resolves_modes(mode, torch_interpolate_spatial_nd=len(img_t.shape) - 1)
     zoomed: NdarrayOrTensor = torch.nn.functional.interpolate(
         recompute_scale_factor=True,
         input=img_t.unsqueeze(0),
         scale_factor=list(scale_factor),
-        mode=mode,
+        mode=_m,
         align_corners=align_corners,
     ).squeeze(0)
     out, *_ = convert_to_dst_type(zoomed, dst=out, dtype=torch.float32)
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index 5699995ba0..6dc75141af 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -16,7 +16,7 @@
 import warnings
 from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence
 from contextlib import contextmanager
-from functools import wraps
+from functools import lru_cache, wraps
 from inspect import getmembers, isclass
 from typing import Any
@@ -44,10 +44,13 @@
 )
 from monai.utils import (
     GridSampleMode,
+    GridSamplePadMode,
     InterpolateMode,
+    NdimageMode,
     NumpyPadMode,
     PostFix,
     PytorchPadMode,
+    SplineMode,
     TraceKeys,
     ensure_tuple,
     ensure_tuple_rep,
@@ -58,6 +61,7 @@
     look_up_option,
     min_version,
     optional_import,
+    pytorch_after,
 )
 from monai.utils.enums import TransformBackends
 from monai.utils.type_conversion import convert_data_type, convert_to_cupy, convert_to_dst_type, convert_to_tensor
@@ -116,6 +120,7 @@
     "attach_hook",
     "sync_meta_info",
     "reset_ops_id",
+    "resolves_modes",
 ]
@@ -1843,5 +1848,124 @@ def squarepulse(sig, duty: float = 0.5):
     return y
 
 
+def _to_numpy_resample_interp_mode(interp_mode):
+    ret = look_up_option(str(interp_mode), SplineMode, default=None)
+    if ret is not None:
+        return int(ret)
+    _mapping = {
+        InterpolateMode.NEAREST: SplineMode.ZERO,
+        InterpolateMode.NEAREST_EXACT: SplineMode.ZERO,
+        InterpolateMode.LINEAR: SplineMode.ONE,
+        InterpolateMode.BILINEAR: SplineMode.ONE,
+        InterpolateMode.TRILINEAR: SplineMode.ONE,
+        InterpolateMode.BICUBIC: SplineMode.THREE,
+        InterpolateMode.AREA: SplineMode.ZERO,
+    }
+    ret = look_up_option(str(interp_mode), _mapping, default=None)
+    if ret is not None:
+        return ret
+    return look_up_option(str(interp_mode), list(_mapping) + list(SplineMode))  # for better error msg
+
+
+def _to_torch_resample_interp_mode(interp_mode):
+    ret = look_up_option(str(interp_mode), InterpolateMode, default=None)
+    if ret is not None:
+        return ret
+    _mapping = {
+        SplineMode.ZERO: InterpolateMode.NEAREST_EXACT if pytorch_after(1, 11) else InterpolateMode.NEAREST,
+        SplineMode.ONE: InterpolateMode.LINEAR,
+        SplineMode.THREE: InterpolateMode.BICUBIC,
+    }
+    ret = look_up_option(str(interp_mode), _mapping, default=None)
+    if ret is not None:
+        return ret
+    return look_up_option(str(interp_mode), list(_mapping) + list(InterpolateMode))
+
+
+def _to_numpy_resample_padding_mode(m):
+    ret = look_up_option(str(m), NdimageMode, default=None)
+    if ret is not None:
+        return ret
+    _mapping = {
+        GridSamplePadMode.ZEROS: NdimageMode.CONSTANT,
+        GridSamplePadMode.BORDER: NdimageMode.NEAREST,
+        GridSamplePadMode.REFLECTION: NdimageMode.REFLECT,
+    }
+    ret = look_up_option(str(m), _mapping, default=None)
+    if ret is not None:
+        return ret
+    return look_up_option(str(m), list(_mapping) + list(NdimageMode))
+
+
+def _to_torch_resample_padding_mode(m):
+    ret = look_up_option(str(m), GridSamplePadMode, default=None)
+    if ret is not None:
+        return ret
+    _mapping = {
+        NdimageMode.CONSTANT: GridSamplePadMode.ZEROS,
+        NdimageMode.GRID_CONSTANT: GridSamplePadMode.ZEROS,
+        NdimageMode.NEAREST: GridSamplePadMode.BORDER,
+        NdimageMode.REFLECT: GridSamplePadMode.REFLECTION,
+        NdimageMode.WRAP: GridSamplePadMode.REFLECTION,
+        NdimageMode.GRID_WRAP: GridSamplePadMode.REFLECTION,
+        NdimageMode.GRID_MIRROR: GridSamplePadMode.REFLECTION,
+    }
+    ret = look_up_option(str(m), _mapping, default=None)
+    if ret is not None:
+        return ret
+    return look_up_option(str(m), list(_mapping) + list(GridSamplePadMode))
+
+
+@lru_cache(None)
+def resolves_modes(
+    interp_mode: str | None = "constant", padding_mode="zeros", backend=TransformBackends.TORCH, **kwargs
+):
+    """
+    Automatically adjust the resampling interpolation mode and padding mode,
+    so that they are compatible with the corresponding API of the `backend`.
+    Depending on the availability of the backends, when there's no exact
+    equivalent, a similar mode is returned.
+
+    Args:
+        interp_mode: interpolation mode.
+        padding_mode: padding mode.
+        backend: optional backend of `TransformBackends`. If None, the backend will be decided from `interp_mode`.
+        kwargs: additional keyword arguments. currently support ``torch_interpolate_spatial_nd``, to provide
+            additional information to determine ``linear``, ``bilinear`` and ``trilinear``;
+            ``use_compiled`` to use MONAI's precompiled backend (pytorch c++ extensions), default to ``False``.
+    """
+    _interp_mode, _padding_mode, _kwargs = None, None, (kwargs or {}).copy()
+    if backend is None:  # infer backend
+        backend = (
+            TransformBackends.NUMPY
+            if look_up_option(str(interp_mode), SplineMode, default=None) is not None
+            else TransformBackends.TORCH
+        )
+    if backend == TransformBackends.NUMPY:
+        _interp_mode = _to_numpy_resample_interp_mode(interp_mode)
+        _padding_mode = _to_numpy_resample_padding_mode(padding_mode)
+        return backend, _interp_mode, _padding_mode, _kwargs
+    _interp_mode = _to_torch_resample_interp_mode(interp_mode)
+    _padding_mode = _to_torch_resample_padding_mode(padding_mode)
+    if str(_interp_mode).endswith("linear"):
+        nd = _kwargs.pop("torch_interpolate_spatial_nd", 2)
+        if nd == 1:
+            _interp_mode = InterpolateMode.LINEAR
+        elif nd == 3:
+            _interp_mode = InterpolateMode.TRILINEAR
+        else:
+            _interp_mode = InterpolateMode.BILINEAR  # torch grid_sample bilinear is trilinear in 3D
+    if not _kwargs.pop("use_compiled", False):
+        return backend, _interp_mode, _padding_mode, _kwargs
+    _padding_mode = 1 if _padding_mode == "reflection" else _padding_mode
+    if _interp_mode == "bicubic":
+        _interp_mode = 3
+    elif str(_interp_mode).endswith("linear"):
+        _interp_mode = 1
+    else:
+        _interp_mode = GridSampleMode(_interp_mode)
+    return backend, _interp_mode, _padding_mode, _kwargs
+
+
 if __name__ == "__main__":
     print_transform_backends()
diff --git a/tests/test_resize.py b/tests/test_resize.py
index 4925c441de..97a8f8dab2 100644
--- a/tests/test_resize.py
+++ b/tests/test_resize.py
@@ -29,6 +29,8 @@
 TEST_CASE_2 = [{"spatial_size": 6, "mode": "trilinear", "align_corners": True}, (2, 4, 6)]
 
+TEST_CASE_2_1 = [{"spatial_size": 6, "mode": 1, "align_corners": True}, (2, 4, 6)]
+
 TEST_CASE_3 = [{"spatial_size": 15, "anti_aliasing": True}, (6, 10, 15)]
 
 TEST_CASE_4 = [{"spatial_size": 6, "anti_aliasing": True, "anti_aliasing_sigma": 2.0}, (2, 4, 6)]
@@ -108,7 +110,7 @@ def test_correct_results(self, spatial_size, mode, anti_aliasing):
             np.abs(good - expected.size) / float(expected.size), diff_t, f"at most {diff_t} percent mismatch "
         )
 
-    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
+    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_2_1, TEST_CASE_3, TEST_CASE_4])
     def test_longest_shape(self, input_param, expected_shape):
         input_data = np.random.randint(0, 2, size=[3, 4, 7, 10])
input_param["size_mode"] = "longest" diff --git a/tests/test_rotate.py b/tests/test_rotate.py index 969db401ed..95c63e65f7 100644 --- a/tests/test_rotate.py +++ b/tests/test_rotate.py @@ -22,7 +22,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import Rotate from tests.lazy_transforms_utils import test_resampler_lazy -from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, test_local_inversion +from tests.utils import HAS_CUPY, TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, test_local_inversion TEST_CASES_2D: list[tuple] = [] for p in TEST_NDARRAYS_ALL: @@ -31,6 +31,8 @@ TEST_CASES_2D.append((p, -np.pi / 4.5, True, "nearest", "border" if USE_COMPILED else "reflection", False)) TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", False)) TEST_CASES_2D.append((p, -np.pi / 2, False, "bilinear", "zeros", True)) + if HAS_CUPY: # 1 and cuda image requires cupy + TEST_CASES_2D.append((p, -np.pi / 2, False, 1, "constant", True)) TEST_CASES_3D: list[tuple] = [] for p in TEST_NDARRAYS_ALL: @@ -39,6 +41,8 @@ TEST_CASES_3D.append((p, -np.pi / 4.5, True, "nearest", "border" if USE_COMPILED else "reflection", False)) TEST_CASES_3D.append((p, np.pi, False, "nearest", "zeros", False)) TEST_CASES_3D.append((p, -np.pi / 2, False, "bilinear", "zeros", False)) + if HAS_CUPY: + TEST_CASES_3D.append((p, -np.pi / 2, False, 1, "zeros", False)) TEST_CASES_SHAPE_3D: list[tuple] = [] for p in TEST_NDARRAYS_ALL: @@ -143,7 +147,7 @@ def test_ill_case(self): rotate_fn = Rotate(10, keep_size=False) with self.assertRaises(ValueError): # wrong mode - rotate_fn(p(self.imt[0]), mode="trilinear") + rotate_fn(p(self.imt[0]), mode="trilinear_spell_error") if __name__ == "__main__": diff --git a/tests/test_spacing.py b/tests/test_spacing.py index 5f3ee85aa1..0c3f67713d 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -238,7 +238,7 @@ ) TESTS.append( # 5D input [ - {"pixdim": 0.5, "padding_mode": "zeros", "mode": "nearest", "scale_extent": True}, + {"pixdim": 0.5, "padding_mode": "constant", "mode": "nearest", "scale_extent": True}, torch.ones((1, 368, 336, 368)), # data torch.tensor( [ diff --git a/tests/test_zoom.py b/tests/test_zoom.py index cb25abc29f..b614acc9e4 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -33,6 +33,7 @@ (1.5, "nearest", True), (1.5, "nearest", False), (0.8, "bilinear"), + (0.8, 1), (0.8, "area"), (1.5, "nearest", False, True), (0.8, "area", False, True), @@ -70,7 +71,7 @@ def test_correct_results(self, zoom, mode, *_): zoomed = zoom_fn(im) test_local_inversion(zoom_fn, zoomed, im) _order = 0 - if mode.endswith("linear"): + if mode == 1 or mode.endswith("linear"): _order = 1 expected = [] for channel in self.imt[0]: