Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/dev' into 4855-lazy-resampling…
Browse files Browse the repository at this point in the history
…-impl

Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli committed Jan 23, 2023
2 parents 7518371 + e279463 commit 5e58297
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 1 deletion.
1 change: 1 addition & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@
Lambda,
MapLabelValue,
RandCuCIM,
RandIdentity,
RandImageFilter,
RandLambda,
RemoveRepeatedChannel,
Expand Down
15 changes: 14 additions & 1 deletion monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from collections.abc import Mapping, Sequence
from copy import deepcopy
from functools import partial
from typing import Callable
from typing import Any, Callable

import numpy as np
import torch
Expand Down Expand Up @@ -75,6 +75,7 @@

__all__ = [
"Identity",
"RandIdentity",
"AsChannelFirst",
"AsChannelLast",
"AddChannel",
Expand Down Expand Up @@ -128,6 +129,18 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
return img


class RandIdentity(RandomizableTrait):
"""
Do nothing to the data. This transform is random, so can be used to stop the caching of any
subsequent transforms.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, data: Any) -> Any:
return data


@deprecated(since="0.8", msg_suffix="please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.")
class AsChannelFirst(Transform):
"""
Expand Down
49 changes: 49 additions & 0 deletions tests/test_randidentity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import monai.transforms as mt
from monai.data import CacheDataset
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


class T(mt.Transform):
def __call__(self, x):
return x * 2


class TestIdentity(NumpyImageTestCase2D):
def test_identity(self):
for p in TEST_NDARRAYS:
img = p(self.imt)
identity = mt.RandIdentity()
assert_allclose(img, identity(img))

def test_caching(self, init=1, expect=4, expect_pre_cache=2):
# check that we get the correct result (two lots of T so should get 4)
x = init
transforms = mt.Compose([T(), mt.RandIdentity(), T()])
self.assertEqual(transforms(x), expect)

# check we get correct result with CacheDataset
x = [init]
ds = CacheDataset(x, transforms)
self.assertEqual(ds[0], expect)

# check that the cached value is correct
self.assertEqual(ds._cache[0], expect_pre_cache)


if __name__ == "__main__":
unittest.main()

0 comments on commit 5e58297

Please sign in to comment.