From 14a37dc8fe04527c6c70ddcd3ecaae0dd20487b1 Mon Sep 17 00:00:00 2001 From: Can Zhao Date: Tue, 10 May 2022 05:25:06 -0400 Subject: [PATCH] Add GaussianSmooth as antialiasing filter in Resize reformat Signed-off-by: Can Zhao --- monai/transforms/spatial/array.py | 42 +++++++++++++++++++++++++++++++ tests/test_resize.py | 33 +++++++++++++++++++----- 2 files changed, 69 insertions(+), 6 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 6b67762b95..65df5d2b1b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -26,6 +26,7 @@ from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.networks.utils import meshgrid_ij, normalize_transform from monai.transforms.croppad.array import CenterSpatialCrop, Pad +from monai.transforms.intensity.array import GaussianSmooth from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform from monai.transforms.utils import ( create_control_grid, @@ -622,6 +623,15 @@ class Resize(Transform): align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + anti_aliasing: bool + Whether to apply a Gaussian filter to smooth the image prior + to downsampling. It is crucial to filter when downsampling + the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` + anti_aliasing_sigma: {float, tuple of floats}, optional + Standard deviation for Gaussian filtering used when anti-aliasing. + By default, this value is chosen as (s - 1) / 2 where s is the + downsampling factor, where s > 1. For the up-size case, s < 1, no + anti-aliasing is performed prior to rescaling. """ backend = [TransformBackends.TORCH] @@ -632,17 +642,23 @@ def __init__( size_mode: str = "all", mode: Union[InterpolateMode, str] = InterpolateMode.AREA, align_corners: Optional[bool] = None, + anti_aliasing: bool = False, + anti_aliasing_sigma: Union[Sequence[float], float, None] = None, ) -> None: self.size_mode = look_up_option(size_mode, ["all", "longest"]) self.spatial_size = spatial_size self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) self.align_corners = align_corners + self.anti_aliasing = anti_aliasing + self.anti_aliasing_sigma = anti_aliasing_sigma def __call__( self, img: NdarrayOrTensor, mode: Optional[Union[InterpolateMode, str]] = None, align_corners: Optional[bool] = None, + anti_aliasing: Optional[bool] = None, + anti_aliasing_sigma: Union[Sequence[float], float, None] = None, ) -> NdarrayOrTensor: """ Args: @@ -653,11 +669,23 @@ def __call__( align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + anti_aliasing: bool, optional + Whether to apply a Gaussian filter to smooth the image prior + to downsampling. It is crucial to filter when downsampling + the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` + anti_aliasing_sigma: {float, tuple of floats}, optional + Standard deviation for Gaussian filtering used when anti-aliasing. + By default, this value is chosen as (s - 1) / 2 where s is the + downsampling factor, where s > 1. For the up-size case, s < 1, no + anti-aliasing is performed prior to rescaling. Raises: ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. """ + anti_aliasing = self.anti_aliasing if anti_aliasing is None else anti_aliasing + anti_aliasing_sigma = self.anti_aliasing_sigma if anti_aliasing_sigma is None else anti_aliasing_sigma + img_, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) if self.size_mode == "all": input_ndim = img_.ndim - 1 # spatial ndim @@ -677,6 +705,20 @@ def __call__( raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") scale = self.spatial_size / max(img_size) spatial_size_ = tuple(int(round(s * scale)) for s in img_size) + + if anti_aliasing and any(x < y for x, y in zip(spatial_size_, img_.shape[1:])): + factors = torch.div(torch.Tensor(list(img_.shape[1:])), torch.Tensor(spatial_size_)) + if anti_aliasing_sigma is None: + # if sigma is not given, use the default sigma in skimage.transform.resize + anti_aliasing_sigma = torch.maximum(torch.zeros(factors.shape), (factors - 1) / 2).tolist() + else: + # if sigma is given, use the given value for downsampling axis + anti_aliasing_sigma = list(ensure_tuple_rep(anti_aliasing_sigma, len(spatial_size_))) + for axis in range(len(spatial_size_)): + anti_aliasing_sigma[axis] = anti_aliasing_sigma[axis] * int(factors[axis] > 1) + anti_aliasing_filter = GaussianSmooth(sigma=anti_aliasing_sigma) + img_ = anti_aliasing_filter(img_) + resized = torch.nn.functional.interpolate( input=img_.unsqueeze(0), size=spatial_size_, diff --git a/tests/test_resize.py b/tests/test_resize.py index 06246b2358..cb24cf2cc3 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -13,6 +13,7 @@ import numpy as np import skimage.transform +import torch from parameterized import parameterized from monai.transforms import Resize @@ -24,6 +25,10 @@ TEST_CASE_2 = [{"spatial_size": 6, "mode": "trilinear", "align_corners": True}, (2, 4, 6)] +TEST_CASE_3 = [{"spatial_size": 15, "anti_aliasing": True}, (6, 10, 15)] + +TEST_CASE_4 = [{"spatial_size": 6, "anti_aliasing": True, "anti_aliasing_sigma": 2.0}, (2, 4, 6)] + class TestResize(NumpyImageTestCase2D): def test_invalid_inputs(self): @@ -36,10 +41,15 @@ def test_invalid_inputs(self): resize(self.imt[0]) @parameterized.expand( - [((32, -1), "area"), ((32, 32), "area"), ((32, 32, 32), "trilinear"), ((256, 256), "bilinear")] + [ + ((32, -1), "area", True), + ((32, 32), "area", False), + ((32, 32, 32), "trilinear", True), + ((256, 256), "bilinear", False), + ] ) - def test_correct_results(self, spatial_size, mode): - resize = Resize(spatial_size, mode=mode) + def test_correct_results(self, spatial_size, mode, anti_aliasing): + resize = Resize(spatial_size, mode=mode, anti_aliasing=anti_aliasing) _order = 0 if mode.endswith("linear"): _order = 1 @@ -47,7 +57,7 @@ def test_correct_results(self, spatial_size, mode): spatial_size = (32, 64) expected = [ skimage.transform.resize( - channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=False + channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=anti_aliasing ) for channel in self.imt[0] ] @@ -55,9 +65,20 @@ def test_correct_results(self, spatial_size, mode): expected = np.stack(expected).astype(np.float32) for p in TEST_NDARRAYS: out = resize(p(self.imt[0])) - assert_allclose(out, expected, type_test=False, atol=0.9) + if not anti_aliasing: + assert_allclose(out, expected, type_test=False, atol=0.9) + else: + # skimage uses reflect padding for anti-aliasing filter. + # Our implementation reuses GaussianSmooth() as anti-aliasing filter, which uses zero padding instead. + # Thus their results near the image boundary will be different. + if isinstance(out, torch.Tensor): + out = out.cpu().detach().numpy() + good = np.sum(np.isclose(expected, out, atol=0.9)) + self.assertLessEqual( + np.abs(good - expected.size) / float(expected.size), 0.2, "at most 20 percent mismatch " + ) - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_longest_shape(self, input_param, expected_shape): input_data = np.random.randint(0, 2, size=[3, 4, 7, 10]) input_param["size_mode"] = "longest"