diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index f9667402c9..a2f63a7482 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -91,24 +91,33 @@ class RandGaussianNoise(RandomizableTransform): mean: Mean or “centre” of the distribution. std: Standard deviation (spread) of distribution. dtype: output data type, if None, same as input image. defaults to float32. + sample_std: If True, sample the spread of the Gaussian distribution uniformly from 0 to std. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, prob: float = 0.1, mean: float = 0.0, std: float = 0.1, dtype: DtypeLike = np.float32) -> None: + def __init__( + self, + prob: float = 0.1, + mean: float = 0.0, + std: float = 0.1, + dtype: DtypeLike = np.float32, + sample_std: bool = True, + ) -> None: RandomizableTransform.__init__(self, prob) self.mean = mean self.std = std self.dtype = dtype self.noise: np.ndarray | None = None + self.sample_std = sample_std def randomize(self, img: NdarrayOrTensor, mean: float | None = None) -> None: super().randomize(None) if not self._do_transform: return None - rand_std = self.R.uniform(0, self.std) - noise = self.R.normal(self.mean if mean is None else mean, rand_std, size=img.shape) + std = self.R.uniform(0, self.std) if self.sample_std else self.std + noise = self.R.normal(self.mean if mean is None else mean, std, size=img.shape) # noise is float64 array, convert to the output dtype to save memory self.noise, *_ = convert_data_type(noise, dtype=self.dtype) diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 058ef87b95..7e93464e64 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -172,7 +172,7 @@ class RandGaussianNoised(RandomizableTransform, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandGaussianNoise`. - Add Gaussian noise to image. This transform assumes all the expected fields have same shape, if want to add + Add Gaussian noise to image. This transform assumes all the expected fields have same shape, if you want to add different noise for every field, please use this transform separately. Args: @@ -183,6 +183,7 @@ class RandGaussianNoised(RandomizableTransform, MapTransform): std: Standard deviation (spread) of distribution. dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. + sample_std: If True, sample the spread of the Gaussian distribution uniformly from 0 to std. """ backend = RandGaussianNoise.backend @@ -195,10 +196,11 @@ def __init__( std: float = 0.1, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, + sample_std: bool = True, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.rand_gaussian_noise = RandGaussianNoise(mean=mean, std=std, prob=1.0, dtype=dtype) + self.rand_gaussian_noise = RandGaussianNoise(mean=mean, std=std, prob=1.0, dtype=dtype, sample_std=sample_std) def set_random_state( self, seed: int | None = None, state: np.random.RandomState | None = None diff --git a/tests/test_rand_gaussian_noise.py b/tests/test_rand_gaussian_noise.py index a56e54fe31..233b4dd1b6 100644 --- a/tests/test_rand_gaussian_noise.py +++ b/tests/test_rand_gaussian_noise.py @@ -22,22 +22,24 @@ TESTS = [] for p in TEST_NDARRAYS: - TESTS.append(("test_zero_mean", p, 0, 0.1)) - TESTS.append(("test_non_zero_mean", p, 1, 0.5)) + TESTS.append(("test_zero_mean", p, 0, 0.1, True)) + TESTS.append(("test_non_zero_mean", p, 1, 0.5, True)) + TESTS.append(("test_no_sample_std", p, 1, 0.5, False)) class TestRandGaussianNoise(NumpyImageTestCase2D): @parameterized.expand(TESTS) - def test_correct_results(self, _, im_type, mean, std): + def test_correct_results(self, _, im_type, mean, std, sample_std): seed = 0 - gaussian_fn = RandGaussianNoise(prob=1.0, mean=mean, std=std) + gaussian_fn = RandGaussianNoise(prob=1.0, mean=mean, std=std, sample_std=sample_std) gaussian_fn.set_random_state(seed) im = im_type(self.imt) noised = gaussian_fn(im) np.random.seed(seed) np.random.random() - expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape) + _std = np.random.uniform(0, std) if sample_std else std + expected = self.imt + np.random.normal(mean, _std, size=self.imt.shape) if isinstance(noised, torch.Tensor): noised = noised.cpu() np.testing.assert_allclose(expected, noised, atol=1e-5) diff --git a/tests/test_rand_gaussian_noised.py b/tests/test_rand_gaussian_noised.py index bcbed98b5a..e3df196be2 100644 --- a/tests/test_rand_gaussian_noised.py +++ b/tests/test_rand_gaussian_noised.py @@ -22,8 +22,9 @@ TESTS = [] for p in TEST_NDARRAYS: - TESTS.append(["test_zero_mean", p, ["img1", "img2"], 0, 0.1]) - TESTS.append(["test_non_zero_mean", p, ["img1", "img2"], 1, 0.5]) + TESTS.append(["test_zero_mean", p, ["img1", "img2"], 0, 0.1, True]) + TESTS.append(["test_non_zero_mean", p, ["img1", "img2"], 1, 0.5, True]) + TESTS.append(["test_no_sample_std", p, ["img1", "img2"], 1, 0.5, False]) seed = 0 @@ -31,15 +32,18 @@ class TestRandGaussianNoised(NumpyImageTestCase2D): @parameterized.expand(TESTS) - def test_correct_results(self, _, im_type, keys, mean, std): - gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64) + def test_correct_results(self, _, im_type, keys, mean, std, sample_std): + gaussian_fn = RandGaussianNoised( + keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64, sample_std=sample_std + ) gaussian_fn.set_random_state(seed) im = im_type(self.imt) noised = gaussian_fn({k: im for k in keys}) np.random.seed(seed) # simulate the randomize() of transform np.random.random() - noise = np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape) + _std = np.random.uniform(0, std) if sample_std else std + noise = np.random.normal(mean, _std, size=self.imt.shape) for k in keys: expected = self.imt + noise if isinstance(noised[k], torch.Tensor):