Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GaussianSmooth as antialiasing filter in Resize #4249

Merged
merged 1 commit into from
May 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
from monai.networks.utils import meshgrid_ij, normalize_transform
from monai.transforms.croppad.array import CenterSpatialCrop, Pad
from monai.transforms.intensity.array import GaussianSmooth
from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform
from monai.transforms.utils import (
create_control_grid,
Expand Down Expand Up @@ -622,6 +623,15 @@ class Resize(Transform):
align_corners: This only has an effect when mode is
'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
anti_aliasing: bool
Whether to apply a Gaussian filter to smooth the image prior
to downsampling. It is crucial to filter when downsampling
the image to avoid aliasing artifacts. See also ``skimage.transform.resize``
anti_aliasing_sigma: {float, tuple of floats}, optional
Standard deviation for Gaussian filtering used when anti-aliasing.
By default, this value is chosen as (s - 1) / 2, where s is the
downsampling factor (s > 1). When upsampling (s < 1), no
anti-aliasing is performed prior to rescaling.
"""

backend = [TransformBackends.TORCH]
Expand All @@ -632,17 +642,23 @@ def __init__(
size_mode: str = "all",
mode: Union[InterpolateMode, str] = InterpolateMode.AREA,
align_corners: Optional[bool] = None,
anti_aliasing: bool = False,
anti_aliasing_sigma: Union[Sequence[float], float, None] = None,
) -> None:
self.size_mode = look_up_option(size_mode, ["all", "longest"])
self.spatial_size = spatial_size
self.mode: InterpolateMode = look_up_option(mode, InterpolateMode)
self.align_corners = align_corners
self.anti_aliasing = anti_aliasing
self.anti_aliasing_sigma = anti_aliasing_sigma

def __call__(
self,
img: NdarrayOrTensor,
mode: Optional[Union[InterpolateMode, str]] = None,
align_corners: Optional[bool] = None,
anti_aliasing: Optional[bool] = None,
anti_aliasing_sigma: Union[Sequence[float], float, None] = None,
) -> NdarrayOrTensor:
"""
Args:
Expand All @@ -653,11 +669,23 @@ def __call__(
align_corners: This only has an effect when mode is
'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
anti_aliasing: bool, optional
Whether to apply a Gaussian filter to smooth the image prior
to downsampling. It is crucial to filter when downsampling
the image to avoid aliasing artifacts. See also ``skimage.transform.resize``
anti_aliasing_sigma: {float, tuple of floats}, optional
Standard deviation for Gaussian filtering used when anti-aliasing.
By default, this value is chosen as (s - 1) / 2, where s is the
downsampling factor (s > 1). When upsampling (s < 1), no
anti-aliasing is performed prior to rescaling.

Raises:
ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions.

"""
anti_aliasing = self.anti_aliasing if anti_aliasing is None else anti_aliasing
anti_aliasing_sigma = self.anti_aliasing_sigma if anti_aliasing_sigma is None else anti_aliasing_sigma

img_, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)
if self.size_mode == "all":
input_ndim = img_.ndim - 1 # spatial ndim
Expand All @@ -677,6 +705,20 @@ def __call__(
raise ValueError("spatial_size must be an int number if size_mode is 'longest'.")
scale = self.spatial_size / max(img_size)
spatial_size_ = tuple(int(round(s * scale)) for s in img_size)

if anti_aliasing and any(x < y for x, y in zip(spatial_size_, img_.shape[1:])):
factors = torch.div(torch.Tensor(list(img_.shape[1:])), torch.Tensor(spatial_size_))
if anti_aliasing_sigma is None:
# if sigma is not given, use the default sigma in skimage.transform.resize
anti_aliasing_sigma = torch.maximum(torch.zeros(factors.shape), (factors - 1) / 2).tolist()
else:
# if sigma is given, use the given value for downsampling axis
anti_aliasing_sigma = list(ensure_tuple_rep(anti_aliasing_sigma, len(spatial_size_)))
for axis in range(len(spatial_size_)):
anti_aliasing_sigma[axis] = anti_aliasing_sigma[axis] * int(factors[axis] > 1)
anti_aliasing_filter = GaussianSmooth(sigma=anti_aliasing_sigma)
img_ = anti_aliasing_filter(img_)

resized = torch.nn.functional.interpolate(
input=img_.unsqueeze(0),
size=spatial_size_,
Expand Down
33 changes: 27 additions & 6 deletions tests/test_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import numpy as np
import skimage.transform
import torch
from parameterized import parameterized

from monai.transforms import Resize
Expand All @@ -24,6 +25,10 @@

TEST_CASE_2 = [{"spatial_size": 6, "mode": "trilinear", "align_corners": True}, (2, 4, 6)]

TEST_CASE_3 = [{"spatial_size": 15, "anti_aliasing": True}, (6, 10, 15)]

TEST_CASE_4 = [{"spatial_size": 6, "anti_aliasing": True, "anti_aliasing_sigma": 2.0}, (2, 4, 6)]


class TestResize(NumpyImageTestCase2D):
def test_invalid_inputs(self):
Expand All @@ -36,28 +41,44 @@ def test_invalid_inputs(self):
resize(self.imt[0])

@parameterized.expand(
[((32, -1), "area"), ((32, 32), "area"), ((32, 32, 32), "trilinear"), ((256, 256), "bilinear")]
[
((32, -1), "area", True),
((32, 32), "area", False),
((32, 32, 32), "trilinear", True),
((256, 256), "bilinear", False),
]
)
def test_correct_results(self, spatial_size, mode):
resize = Resize(spatial_size, mode=mode)
def test_correct_results(self, spatial_size, mode, anti_aliasing):
resize = Resize(spatial_size, mode=mode, anti_aliasing=anti_aliasing)
_order = 0
if mode.endswith("linear"):
_order = 1
if spatial_size == (32, -1):
spatial_size = (32, 64)
expected = [
skimage.transform.resize(
channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=False
channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=anti_aliasing
)
for channel in self.imt[0]
]

expected = np.stack(expected).astype(np.float32)
for p in TEST_NDARRAYS:
out = resize(p(self.imt[0]))
assert_allclose(out, expected, type_test=False, atol=0.9)
if not anti_aliasing:
assert_allclose(out, expected, type_test=False, atol=0.9)
else:
# skimage uses reflect padding for its anti-aliasing filter.
# Our implementation reuses GaussianSmooth() as the anti-aliasing filter, which uses zero padding instead.
# Thus the two results will differ near the image boundary.
if isinstance(out, torch.Tensor):
out = out.cpu().detach().numpy()
good = np.sum(np.isclose(expected, out, atol=0.9))
self.assertLessEqual(
np.abs(good - expected.size) / float(expected.size), 0.2, "at most 20 percent mismatch "
)

@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_longest_shape(self, input_param, expected_shape):
input_data = np.random.randint(0, 2, size=[3, 4, 7, 10])
input_param["size_mode"] = "longest"
Expand Down