From 577bb65626c4795e932e909d31774b5e15b3488f Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Tue, 21 Jun 2022 12:48:01 +0800
Subject: [PATCH 01/47] [DLMED] adapt Pad transform for MetaTensor

Signed-off-by: Nic Ma

---
 monai/transforms/croppad/array.py | 126 ++++++++++++++++++------------
 monai/utils/__init__.py           |   1 -
 monai/utils/type_conversion.py    | 100 +++++++-----------------
 3 files changed, 104 insertions(+), 123 deletions(-)

diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py
index 6537cf3e21..9109ce6490 100644
--- a/monai/transforms/croppad/array.py
+++ b/monai/transforms/croppad/array.py
@@ -23,11 +23,14 @@
 
 from monai.config import IndexSelection
 from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.meta_obj import get_track_meta
 from monai.data.utils import get_random_patch, get_valid_patch_size
+from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.transform import Randomizable, Transform
 from monai.transforms.utils import (
     compute_divisible_spatial_size,
     convert_pad_mode,
+    create_translate,
     generate_label_classes_crop_centers,
     generate_pos_neg_label_crop_centers,
     generate_spatial_bounding_box,
@@ -46,89 +49,112 @@
     fall_back_tuple,
     look_up_option,
 )
-from monai.utils.enums import TransformBackends
-from monai.utils.type_conversion import convert_data_type, convert_to_dst_type
+from monai.utils.enums import TraceKeys, TransformBackends
+from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor
 
 __all__ = [
-    "Pad",
-    "SpatialPad",
     "BorderPad",
-    "DivisiblePad",
-    "SpatialCrop",
-    "CenterSpatialCrop",
+    "BoundingRect",
     "CenterScaleCrop",
-    "RandSpatialCrop",
+    "CenterSpatialCrop",
+    "CropForeground",
+    "Pad",
+    "RandCropByLabelClasses",
+    "RandCropByPosNegLabel",
     "RandScaleCrop",
+    "RandSpatialCrop",
     "RandSpatialCropSamples",
-    "CropForeground",
     "RandWeightedCrop",
-    "RandCropByPosNegLabel",
-    "RandCropByLabelClasses",
     "ResizeWithPadOrCrop",
-    "BoundingRect",
+    "SpatialCrop",
+    "SpatialPad",
 ]
 
 
-class Pad(Transform):
+class Pad(InvertibleTransform):
     """
     Perform padding given an amount of padding in each dimension.
-    If input is `torch.Tensor`, `torch.nn.functional.pad` will be used, otherwise, `np.pad` will be used.
+
+    `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch,
+    in which case `np.pad` will be used.
 
     Args:
         to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].
-        mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
-            ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
-            available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
+            if None, it must be provided in the `__call__` at runtime.
+        mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
             One of the listed string values or a user supplied function. Defaults to ``"constant"``.
-            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
-            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
-        kwargs: other arguments for the `np.pad` or `torch.pad` function.
-            note that `np.pad` treats channel dimension as the first dimension.
-    """
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html.
+        kwargs: other arguments for the `torch.pad` function.
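+
+    A minimal usage sketch (illustrative only; assumes a channel-first tensor)::
+
+        import torch
+        from monai.transforms import Pad
+
+        # to_pad skips the channel dim: pad H by (1, 1) and W by (2, 2)
+        padder = Pad(to_pad=[(0, 0), (1, 1), (2, 2)], mode="constant")
+        out = padder(torch.zeros(1, 2, 2))  # padded to shape (1, 4, 6)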
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + """ def __init__( self, - to_pad: List[Tuple[int, int]], - mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + to_pad: Optional[List[Tuple[int, int]]] = None, + mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, **kwargs, ) -> None: self.to_pad = to_pad self.mode = mode self.kwargs = kwargs - @staticmethod - def _np_pad(img: np.ndarray, all_pad_width, mode, **kwargs) -> np.ndarray: - return np.pad(img, all_pad_width, mode=mode, **kwargs) # type: ignore - - @staticmethod - def _pt_pad(img: torch.Tensor, all_pad_width, mode, **kwargs) -> torch.Tensor: - pt_pad_width = [val for sublist in all_pad_width[1:] for val in sublist[::-1]][::-1] - # torch.pad expects `[B, C, H, W, [D]]` shape - return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0) - def __call__( - self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None - ) -> NdarrayOrTensor: + self, + img: torch.Tensor, + to_pad: Optional[List[Tuple[int, int]]] = None, + mode: Optional[Union[PytorchPadMode, str]] = None, + **kwargs, + ) -> torch.Tensor: """ Args: - img: data to be transformed, assuming `img` is channel-first and - padding doesn't apply to the channel dim. - mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, - ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"`` or ``"circular"``}. - One of the listed string values or a user supplied function. Defaults to `self.mode`. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. + to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...]. + default to `self.to_pad`. + mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. + default to `self.mode`. + kwargs: other arguments for the `torch.pad` function, will override `self.kwargs`. 
""" - if not np.asarray(self.to_pad).any(): + to_pad_ = self.to_pad if to_pad is None else to_pad + mode_ = self.mode if mode is None else mode + if to_pad is None: + raise ValueError("must provde `to_pad` to execute padding.") + kwargs_ = dict(self.kwargs) + kwargs_.update(kwargs) + + img_t = convert_to_tensor(data=img, track_meta=get_track_meta()) + + if not np.asarray(to_pad_).any(): # all zeros, skip padding - return img - mode = convert_pad_mode(dst=img, mode=mode or self.mode).value - pad = self._pt_pad if isinstance(img, torch.Tensor) else self._np_pad - return pad(img, self.to_pad, mode, **self.kwargs) # type: ignore + return img_t + + mode_ = convert_pad_mode(dst=img_t, mode=mode_).value + pad_width = [val for sublist in to_pad_[1:] for val in sublist[::-1]][::-1] + # torch.pad expects `[B, C, H, W, [D]]` shape + img_t = pad_pt(img_t.unsqueeze(0), pad_width, mode=mode_, **kwargs_).squeeze(0) + + if get_track_meta(): + spatial_rank = max(len(img_t.affine) - 1, 1) + to_shift = [-s[0] for s in to_pad_[1:]] # skipping the channel pad + mat = create_translate(spatial_rank, to_shift) + img_t.meta["affine"] = img_t.affine @ convert_to_dst_type(mat, img_t.affine)[0] + self.push_transform(img_t, extra_info={"padded": to_pad}) + return img_t + + def inverse(self, data: torch.Tensor) -> torch.Tensor: + transform = self.pop_transform(data) + padded = transform[TraceKeys.EXTRA_INFO]["padded"] + if padded[0][0] != 0 or padded[0][1] != 0: + raise NotImplementedError( + "Inverse uses SpatialCrop, which hasn't yet been extended to crop channels. Trivial change." + ) + roi_start = [i[0] for i in padded[1:]] + roi_end = [i - j[1] for i, j in zip(data.shape[1:], padded[1:])] + cropper = SpatialCrop(roi_start=roi_start, roi_end=roi_end) + with cropper.trace_transform(False): + return cropper(data) class SpatialPad(Transform): diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index f53cfdaef0..33b2a5fa2a 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -93,7 +93,6 @@ convert_to_cupy, convert_to_dst_type, convert_to_list, - convert_to_meta_tensor, convert_to_numpy, convert_to_tensor, dtype_numpy_to_torch, diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 2d88a269fe..312e08b807 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -34,14 +34,13 @@ "convert_to_cupy", "convert_to_numpy", "convert_to_tensor", - "convert_to_meta_tensor", "convert_to_dst_type", ] def get_numpy_dtype_from_string(dtype: str) -> np.dtype: """Get a numpy dtype (e.g., `np.float32`) from its string (e.g., `"float32"`).""" - return np.zeros([], dtype=dtype).dtype # type: ignore + return np.empty([], dtype=dtype).dtype # type: ignore def get_torch_dtype_from_string(dtype: str) -> torch.dtype: @@ -51,12 +50,12 @@ def get_torch_dtype_from_string(dtype: str) -> torch.dtype: def dtype_torch_to_numpy(dtype: torch.dtype) -> np.dtype: """Convert a torch dtype to its numpy equivalent.""" - return torch.zeros([], dtype=dtype).numpy().dtype # type: ignore + return torch.empty([], dtype=dtype).numpy().dtype # type: ignore def dtype_numpy_to_torch(dtype: np.dtype) -> torch.dtype: """Convert a numpy dtype to its torch equivalent.""" - return torch.from_numpy(np.zeros([], dtype=dtype)).dtype + return torch.from_numpy(np.empty([], dtype=dtype)).dtype def get_equivalent_dtype(dtype, data_type): @@ -99,11 +98,15 @@ def get_dtype(data: Any): def convert_to_tensor( - data, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, 
wrap_sequence: bool = False
+    data,
+    dtype: Optional[torch.dtype] = None,
+    device: Optional[torch.device] = None,
+    wrap_sequence: bool = False,
+    track_meta: bool = False,
 ):
     """
-    Utility to convert the input data to a PyTorch Tensor. If passing a dictionary, list or tuple,
-    recursively check every item and convert it to PyTorch Tensor.
+    Utility to convert the input data to a PyTorch Tensor; if `track_meta` is `True`, convert it to `MetaTensor`.
+    If passing a dictionary, list or tuple, recursively check every item and convert it to PyTorch Tensor.
 
     Args:
         data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
@@ -113,13 +116,21 @@ def convert_to_tensor(
         device: target device to put the converted Tensor data.
         wrap_sequence: if `False`, then lists will recursively call this function.
             E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`.
-
+        track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`.
+            default to `False`.
     """
+    def _convert_tensor(tensor):
+        if not isinstance(tensor, torch.Tensor):
+            tensor = torch.as_tensor(tensor)
+        if track_meta and not isinstance(tensor, monai.data.MetaTensor):
+            return monai.data.MetaTensor(tensor)
+        if not track_meta and isinstance(tensor, monai.data.MetaTensor):
+            return tensor.as_tensor()
+        return tensor
+
     if isinstance(data, torch.Tensor):
-        if isinstance(data, monai.data.MetaTensor):
-            data = data.as_tensor()
-        return data.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)  # type: ignore
+        return _convert_tensor(data).to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
     if isinstance(data, np.ndarray):
        # skip array of string classes and object, refer to:
        # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/_utils/collate.py#L13
@@ -128,74 +139,21 @@ def convert_to_tensor(
         # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims
         if data.ndim > 0:
             data = np.ascontiguousarray(data)
-        return torch.as_tensor(data, dtype=dtype, device=device)  # type: ignore
+        return _convert_tensor(torch.as_tensor(data, dtype=dtype, device=device))
     elif (has_cp and isinstance(data, cp_ndarray)) or isinstance(data, (float, int, bool)):
-        return torch.as_tensor(data, dtype=dtype, device=device)  # type: ignore
+        return _convert_tensor(torch.as_tensor(data, dtype=dtype, device=device))
     elif isinstance(data, list):
         list_ret = [convert_to_tensor(i, dtype=dtype, device=device) for i in data]
-        return torch.as_tensor(list_ret, dtype=dtype, device=device) if wrap_sequence else list_ret  # type: ignore
+        return _convert_tensor(torch.as_tensor(list_ret, dtype=dtype, device=device)) if wrap_sequence else list_ret
     elif isinstance(data, tuple):
         tuple_ret = tuple(convert_to_tensor(i, dtype=dtype, device=device) for i in data)
-        return torch.as_tensor(tuple_ret, dtype=dtype, device=device) if wrap_sequence else tuple_ret  # type: ignore
+        return _convert_tensor(torch.as_tensor(tuple_ret, dtype=dtype, device=device)) if wrap_sequence else tuple_ret
     elif isinstance(data, dict):
         return {k: convert_to_tensor(v, dtype=dtype, device=device) for k, v in data.items()}
 
     return data
 
 
-def convert_to_meta_tensor(
-    data, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, wrap_sequence: bool = False
-):
-    """
-    Utility to convert the input data to a MetaTensor. If passing a dictionary, list or tuple,
-    recursively check every item and convert it to MetaTensor.
- - Args: - data: input data can be MetaTensor, PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. - will convert Tensor, Numpy array, float, int, bool to Tensor, strings and objects keep the original. - for dictionary, list or tuple, convert every item to a Tensor if applicable. - dtype: target data type to when converting to Tensor. - device: target device to put the converted Tensor data. - wrap_sequence: if `False`, then lists will recursively call this function. - E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`. - - """ - if isinstance(data, torch.Tensor): - out = data.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) # type: ignore - if not isinstance(out, monai.data.MetaTensor): - out = monai.data.MetaTensor(out) - return out - if isinstance(data, np.ndarray): - # skip array of string classes and object, refer to: - # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/_utils/collate.py#L13 - if re.search(r"[SaUO]", data.dtype.str) is None: - # numpy array with 0 dims is also sequence iterable, - # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims - if data.ndim > 0: - data = np.ascontiguousarray(data) - return monai.data.MetaTensor(torch.as_tensor(data, dtype=dtype, device=device)) # type: ignore - elif (has_cp and isinstance(data, cp_ndarray)) or isinstance(data, (float, int, bool)): - return monai.data.MetaTensor(torch.as_tensor(data, dtype=dtype, device=device)) # type: ignore - elif isinstance(data, list): - list_ret = [convert_to_meta_tensor(i, dtype=dtype, device=device) for i in data] - return ( - monai.data.MetaTensor(torch.as_tensor(list_ret, dtype=dtype, device=device)) # type: ignore - if wrap_sequence - else list_ret - ) - elif isinstance(data, tuple): - tuple_ret = tuple(convert_to_meta_tensor(i, dtype=dtype, device=device) for i in data) - return ( - monai.data.MetaTensor(torch.as_tensor(tuple_ret, dtype=dtype, device=device)) # type: ignore - if wrap_sequence - else tuple_ret - ) - elif isinstance(data, dict): - return {k: convert_to_meta_tensor(v, dtype=dtype, device=device) for k, v in data.items()} - - return data - - def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False): """ Utility to convert the input data to a numpy array. 
If passing a dictionary, list or tuple, @@ -329,11 +287,9 @@ def convert_data_type( data_: NdarrayTensor - if issubclass(output_type, monai.data.MetaTensor): - data_ = convert_to_meta_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence) - return data_, orig_type, orig_device if issubclass(output_type, torch.Tensor): - data_ = convert_to_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence) + track_meta = True if issubclass(output_type, monai.data.MetaTensor) else False + data_ = convert_to_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta) return data_, orig_type, orig_device if issubclass(output_type, np.ndarray): data_ = convert_to_numpy(data, dtype=dtype_, wrap_sequence=wrap_sequence) From 0c0c32d1ab906c3b5aed30473aef005fbfd5f646 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 21 Jun 2022 12:53:07 +0800 Subject: [PATCH 02/47] [DLMED] format code Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 19 ++++++++++--------- monai/utils/type_conversion.py | 1 + 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 9109ce6490..84997db8df 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -53,21 +53,22 @@ from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor __all__ = [ + "Pad", + "SpatialPad", "BorderPad", - "BoundingRect", - "CenterScaleCrop", + "DivisiblePad", + "SpatialCrop", "CenterSpatialCrop", - "CropForeground", - "Pad", - "RandCropByLabelClasses", - "RandCropByPosNegLabel", - "RandScaleCrop", + "CenterScaleCrop", "RandSpatialCrop", + "RandScaleCrop", "RandSpatialCropSamples", + "CropForeground", "RandWeightedCrop", + "RandCropByPosNegLabel", + "RandCropByLabelClasses", "ResizeWithPadOrCrop", - "SpatialCrop", - "SpatialPad", + "BoundingRect", ] diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 312e08b807..f2bc61decb 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -120,6 +120,7 @@ def convert_to_tensor( default to `False`. 
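+
+    A minimal sketch (illustrative only)::
+
+        import numpy as np
+        from monai.utils import convert_to_tensor
+
+        x = convert_to_tensor(np.zeros((1, 2, 2)))  # plain torch.Tensor
+        # with track_meta=True the result is a monai.data.MetaTensor
+        m = convert_to_tensor(np.zeros((1, 2, 2)), track_meta=True)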
""" + def _convert_tensor(tensor): if not isinstance(tensor, torch.Tensor): tensor = torch.as_tensor(tensor) From 63e36b6cb41e163024729010534cd9363c6356dc Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Jun 2022 11:24:15 +0800 Subject: [PATCH 03/47] [DLMED] update inverse and spatial_pad Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 55 +++---- monai/transforms/inverse.py | 249 +++++++++++++++++++++++------- tests/test_spatial_pad.py | 22 +-- 3 files changed, 228 insertions(+), 98 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 84997db8df..c019c14731 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -24,6 +24,7 @@ from monai.config import IndexSelection from monai.config.type_definitions import NdarrayOrTensor from monai.data.meta_obj import get_track_meta +from monai.data.meta_tensor import MetaTensor from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import Randomizable, Transform @@ -89,6 +90,8 @@ class Pad(InvertibleTransform): """ + backend = [TransformBackends.TORCH] + def __init__( self, to_pad: Optional[List[Tuple[int, int]]] = None, @@ -137,13 +140,16 @@ def __call__( img_t = pad_pt(img_t.unsqueeze(0), pad_width, mode=mode_, **kwargs_).squeeze(0) if get_track_meta(): - spatial_rank = max(len(img_t.affine) - 1, 1) - to_shift = [-s[0] for s in to_pad_[1:]] # skipping the channel pad - mat = create_translate(spatial_rank, to_shift) - img_t.meta["affine"] = img_t.affine @ convert_to_dst_type(mat, img_t.affine)[0] - self.push_transform(img_t, extra_info={"padded": to_pad}) + self._update_meta(tensor=img_t, to_pad=to_pad_) + self.push_transform(img_t, extra_info={"padded": to_pad_}) return img_t + def _update_meta(self, tensor: MetaTensor, to_pad: List[Tuple[int, int]]): + spatial_rank = max(len(tensor.affine) - 1, 1) + to_shift = [-s[0] for s in to_pad[1:]] # skipping the channel pad + mat = create_translate(spatial_rank, to_shift) + tensor.meta["affine"] = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0] + def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) padded = transform[TraceKeys.EXTRA_INFO]["padded"] @@ -158,16 +164,10 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return cropper(data) -class SpatialPad(Transform): +class SpatialPad(Pad): """ Performs padding to the data, symmetric for all sides or all on one side for each dimension. - If input is `torch.Tensor` and mode is `constant`, `torch.nn.functional.pad` will be used. - Otherwise, `np.pad` will be used (input converted to `np.ndarray` if necessary). - - Uses np.pad so in practice, a mode needs to be provided. See numpy.lib.arraypad.pad - for additional details. - Args: spatial_size: the spatial size of output data after padding, if a dimension of the input data size is bigger than the pad size, will not pad that dimension. @@ -176,30 +176,24 @@ class SpatialPad(Transform): `spatial_size=[32, 25, -1]`, the spatial size of output data will be [32, 30, 30]. method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. 
- mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, - ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html - kwargs: other arguments for the `np.pad` or `torch.pad` function. - note that `np.pad` treats channel dimension as the first dimension. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. + default to `self.mode`. + kwargs: other arguments for the `torch.pad` function, will override `self.kwargs`. """ - backend = Pad.backend - def __init__( self, spatial_size: Union[Sequence[int], int], method: Union[Method, str] = Method.SYMMETRIC, - mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + mode: Union[PytorchPadMode, str] = NumpyPadMode.CONSTANT, **kwargs, ) -> None: self.spatial_size = spatial_size self.method: Method = look_up_option(method, Method) - self.mode = mode - self.kwargs = kwargs + super().__init__(mode=mode, **kwargs) def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int, int]]: spatial_size = fall_back_tuple(self.spatial_size, data_shape) @@ -212,8 +206,11 @@ def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int return [(0, max(sp_i - data_shape[i], 0)) for i, sp_i in enumerate(spatial_size)] def __call__( - self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None - ) -> NdarrayOrTensor: + self, + img: torch.Tensor, + mode: Optional[Union[PytorchPadMode, str]] = None, + **kwargs, + ) -> torch.Tensor: """ Args: img: data to be transformed, assuming `img` is channel-first and @@ -228,12 +225,8 @@ def __call__( """ data_pad_width = self._determine_data_pad_width(img.shape[1:]) all_pad_width = [(0, 0)] + data_pad_width - if not np.asarray(all_pad_width).any(): - # all zeros, skip padding - return img - padder = Pad(to_pad=all_pad_width, mode=mode or self.mode, **self.kwargs) - return padder(img) + return super().__call__(img=img, to_pad=all_pad_width, mode=mode, **kwargs) class BorderPad(Transform): diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index bdaa2f9b40..0ec45ccfb9 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -8,8 +8,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os -from typing import Hashable, Mapping, Optional, Tuple +import warnings +from contextlib import contextmanager +from typing import Any, Hashable, Mapping, Optional, Tuple import torch @@ -26,13 +29,10 @@ class TraceableTransform(Transform): `trace_key: list of transforms` to each data dictionary. The ``__call__`` method of this transform class must be implemented so - that the transformation information for each key is stored when - ``__call__`` is called. 
If the transforms were applied to keys "image" and - "label", there will be two extra keys in the dictionary: "image_transforms" - and "label_transforms" (based on `TraceKeys.KEY_SUFFIX`). Each list - contains a list of the transforms applied to that key. + that the transformation information for each key is stored in ``data.applied_operations`` + when ``__call__`` is called. - The information in ``data[key_transform]`` will be compatible with the + The information in ``data.applied_operations`` will be compatible with the default collate since it only stores strings, numbers and arrays. `tracing` could be enabled by `self.set_tracing` or setting @@ -52,41 +52,206 @@ def trace_key(key: Hashable = None): return TraceKeys.KEY_SUFFIX return str(key) + TraceKeys.KEY_SUFFIX - def push_transform( - self, data: Mapping, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None - ) -> None: - """Push to a stack of applied transforms for that key.""" + def get_transform_info( + self, data, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None + ) -> dict: + """ + Return a dictionary with the relevant information pertaining to an applied + transform. - if not self.tracing: - return + Args: + - data: input data. Can be dictionary or MetaTensor. We can use `shape` to + determine the original size of the object (unless that has been given + explicitly, see `orig_size`). + - key: if data is a dictionary, data[key] will be modified + - extra_info: if desired, any extra information pertaining to the applied + transform can be stored in this dictionary. These are often needed for + computing the inverse transformation. + - orig_size: sometimes during the inverse it is useful to know what the size + of the original image was, in which case it can be supplied here. + + Returns: + Dictionary of data pertaining to the applied transformation. + """ info = {TraceKeys.CLASS_NAME: self.__class__.__name__, TraceKeys.ID: id(self)} if orig_size is not None: info[TraceKeys.ORIG_SIZE] = orig_size - elif key in data and hasattr(data[key], "shape"): + elif isinstance(data, Mapping) and key in data and hasattr(data[key], "shape"): info[TraceKeys.ORIG_SIZE] = data[key].shape[1:] + elif hasattr(data, "shape"): + info[TraceKeys.ORIG_SIZE] = data.shape[1:] if extra_info is not None: info[TraceKeys.EXTRA_INFO] = extra_info # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) if hasattr(self, "_do_transform"): # RandomizableTransform info[TraceKeys.DO_TRANSFORM] = self._do_transform # type: ignore + return info - if key in data and isinstance(data[key], MetaTensor): - data[key].push_applied_operation(info) - else: - # If this is the first, create list - if self.trace_key(key) not in data: - if not isinstance(data, dict): - data = dict(data) - data[self.trace_key(key)] = [] - data[self.trace_key(key)].append(info) - - def pop_transform(self, data: Mapping, key: Hashable = None): - """Remove the most recent applied transform.""" + def push_transform( + self, data, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None + ) -> None: + """ + Push to a stack of applied transforms. + + Data can be one of two types: + 1. A `MetaTensor` + 2. A dictionary of data containing arrays/tensors and auxiliary data. In + this case, a key must be supplied (the dictionary-based approach is deprecated). 
+ + If `data` is of type `MetaTensor`, then the applied transform will be added to its internal list. + + If `data` is a dictionary, then one of two things can happen: + 1. If data[key] is a `MetaTensor`, the applied transform will be added to its internal list. + 2. Else, the applied transform will be appended to an adjacent list using + `trace_key`. If, for example, the key is `image`, then the transform + will be appended to `image_transforms`. (This is deprecated.) + + Hopefully it is clear that there are three total possibilities: + 1. data is `MetaTensor` + 2. data is dictionary, data[key] is `MetaTensor` + 3. data is dictionary, data[key] is not `MetaTensor`. + + Args: + - data: dictionary of data or `MetaTensor` + - key: if data is a dictionary, data[key] will be modified + - extra_info: if desired, any extra information pertaining to the applied + transform can be stored in this dictionary. These are often needed for + computing the inverse transformation. + - orig_size: sometimes during the inverse it is useful to know what the size + of the original image was, in which case it can be supplied here. + + Returns: + None, but data has been updated to store the applied transformation. + """ if not self.tracing: return - if key in data and isinstance(data[key], MetaTensor): - return data[key].pop_applied_operation() - return data.get(self.trace_key(key), []).pop() + info = self.get_transform_info(data, key, extra_info, orig_size) + + if isinstance(data, MetaTensor): + data.push_applied_operation(info) + elif isinstance(data, Mapping): + if key in data and isinstance(data[key], MetaTensor): + data[key].push_applied_operation(info) + else: + # If this is the first, create list + if self.trace_key(key) not in data: + if not isinstance(data, dict): + data = dict(data) + data[self.trace_key(key)] = [] + data[self.trace_key(key)].append(info) + else: + warnings.warn(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}. {info} not tracked.") + + def check_transforms_match(self, transform: Mapping) -> None: + """Check transforms are of same instance.""" + xform_id = transform.get(TraceKeys.ID, "") + if xform_id == id(self): + return + # TraceKeys.NONE to skip the id check + if xform_id == TraceKeys.NONE: + return + xform_name = transform.get(TraceKeys.CLASS_NAME, "") + # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID) + if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__: + return + raise RuntimeError( + f"Error {self.__class__.__name__} getting the most recently " + f"applied invertible transform {xform_name} {xform_id} != {id(self)}." + ) + + def get_most_recent_transform(self, data, key: Hashable = None, check: bool = True, pop: bool = False): + """ + Get most recent transform. + + Data can be one of two things: + 1. A `MetaTensor` + 2. A dictionary of data containing arrays/tensors and auxiliary data. In + this case, a key must be supplied (the dictionary-based approach is deprecated). + + If `data` is of type `MetaTensor`, then the applied transform will be added to its internal list. + + If `data` is a dictionary, then one of two things can happen: + 1. If data[key] is a `MetaTensor`, the applied transform will be added to its internal list. + 2. Else, the applied transform will be appended to an adjacent list using + `trace_key`. If, for example, the key is `image`, then the transform + will be appended to `image_transforms`. (This is deprecated.) 
+
+        Hopefully it is clear that there are three total possibilities:
+          1. data is `MetaTensor`
+          2. data is dictionary, data[key] is `MetaTensor`
+          3. data is dictionary, data[key] is not `MetaTensor`.
+
+        Args:
+            - data: dictionary of data or `MetaTensor`
+            - key: if data is a dictionary, data[key] will be modified
+            - check: if true, check that `self` is the same type as the most recently-applied transform.
+            - pop: if true, remove the transform as it is returned.
+
+        Returns:
+            Dictionary of the most recently applied transform.
+
+        Raises:
+            - RuntimeError: data is neither `MetaTensor` nor dictionary
+        """
+        if not self.tracing:
+            raise RuntimeError("Transform Tracing must be enabled to get the most recent transform.")
+        if isinstance(data, MetaTensor):
+            all_transforms = data.applied_operations
+        elif isinstance(data, Mapping):
+            if key in data and isinstance(data[key], MetaTensor):
+                all_transforms = data[key].applied_operations
+            else:
+                all_transforms = data[self.trace_key(key)]
+        else:
+            raise ValueError(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}.")
+        if check:
+            self.check_transforms_match(all_transforms[-1])
+        return all_transforms.pop() if pop else all_transforms[-1]
+
+    def pop_transform(self, data, key: Hashable = None, check: bool = True):
+        """
+        Return and pop the most recent transform.
+
+        Data can be one of two things:
+          1. A `MetaTensor`
+          2. A dictionary of data containing arrays/tensors and auxiliary data. In
+             this case, a key must be supplied.
+
+        If `data` is of type `MetaTensor`, the most recently applied transform
+        will be popped from its internal list.
+
+        If `data` is a dictionary, then one of two things can happen:
+          1. If data[key] is a `MetaTensor`, the transform will be popped from
+             its internal list.
+          2. Else, the transform will be popped from an adjacent list accessed using
+             `trace_key`. If, for example, the key is `image`, then the transform
+             is popped from `image_transforms`.
+
+        Hopefully it is clear that there are three total possibilities:
+          1. data is `MetaTensor`
+          2. data is dictionary, data[key] is `MetaTensor`
+          3. data is dictionary, data[key] is not `MetaTensor`.
+
+        Args:
+            - data: dictionary of data or `MetaTensor`
+            - key: if data is a dictionary, data[key] will be modified
+            - check: if true, check that `self` is the same type as the most recently-applied transform.
+
+        Returns:
+            Dictionary of the most recently applied transform.
+
+        Raises:
+            - RuntimeError: data is neither `MetaTensor` nor dictionary
+        """
+        return self.get_most_recent_transform(data, key, check, pop=True)
+
+    @contextmanager
+    def trace_transform(self, to_trace: bool):
+        """Temporarily set the tracing status of a transform with a context manager."""
+        prev = self.tracing
+        self.tracing = to_trace
+        yield
+        self.tracing = prev
 
 
 class InvertibleTransform(TraceableTransform):
@@ -103,7 +268,7 @@ class InvertibleTransform(TraceableTransform):
         different parameters being passed to each label (e.g., different interpolation for
         image and label).
 
-    - the inverse transforms are applied in a last- in-first-out order. As
+    - the inverse transforms are applied in a last-in-first-out order. As
         the inverse is applied, its entry is removed from the list detailing
         the applied transformations.
That is to say that during the forward pass, the list of applied transforms grows, and then during the @@ -126,29 +291,7 @@ class InvertibleTransform(TraceableTransform): """ - def check_transforms_match(self, transform: Mapping) -> None: - """Check transforms are of same instance.""" - xform_name = transform.get(TraceKeys.CLASS_NAME, "") - xform_id = transform.get(TraceKeys.ID, "") - if xform_id == id(self): - return - # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID) - if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__: - return - raise RuntimeError(f"Error inverting the most recently applied invertible transform {xform_name} {xform_id}.") - - def get_most_recent_transform(self, data: Mapping, key: Hashable = None): - """Get most recent transform.""" - if not self.tracing: - raise RuntimeError("Transform Tracing must be enabled to get the most recent transform.") - if isinstance(data[key], MetaTensor): - transform = data[key].applied_operations[-1] - else: - transform = data[self.trace_key(key)][-1] - self.check_transforms_match(transform) - return transform - - def inverse(self, data: dict) -> dict: + def inverse(self, data: Any) -> Any: """ Inverse of ``__call__``. diff --git a/tests/test_spatial_pad.py b/tests/test_spatial_pad.py index 4cdeb6d64e..603c407b5d 100644 --- a/tests/test_spatial_pad.py +++ b/tests/test_spatial_pad.py @@ -19,7 +19,7 @@ from monai.transforms import SpatialPad from monai.utils.enums import NumpyPadMode, PytorchPadMode from monai.utils.misc import set_determinism -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL TESTS = [] @@ -69,7 +69,7 @@ def test_pad_shape(self, input_param, input_shape, expected_shape): results_2 = [] input_data = self.get_arr(input_shape) # check result is the same regardless of input type - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: padder = SpatialPad(**input_param) r1 = padder(p(input_data)) r2 = padder(p(input_data), mode=input_param["mode"]) @@ -81,19 +81,13 @@ def test_pad_shape(self, input_param, input_shape, expected_shape): torch.testing.assert_allclose(results[0], results[-1], atol=0, rtol=1e-5) def test_pad_kwargs(self): - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: input_data = p(np.zeros((3, 8, 4))) - if isinstance(input_data, torch.Tensor): - result = ( - SpatialPad(spatial_size=[15, 8], method="end", mode="constant", value=2)(img=input_data) - .cpu() - .numpy() - ) - else: - result = SpatialPad( - spatial_size=[15, 8], method="end", mode="constant", constant_values=((0, 0), (1, 1), (2, 2)) - )(img=input_data) - torch.testing.assert_allclose(result[:, 8:, :4], np.ones((3, 7, 4)), rtol=1e-7, atol=0) + result = ( + SpatialPad(spatial_size=[15, 8], method="end", mode="constant", value=2)(img=input_data) + .cpu() + .numpy() + ) torch.testing.assert_allclose(result[:, :, 4:], np.ones((3, 15, 4)) + 1, rtol=1e-7, atol=0) From 8f9814be08a067f34a753c14c352d0a538826f8f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Jun 2022 12:00:25 +0800 Subject: [PATCH 04/47] [DLMED] update border pad Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 44 ++++++++++++++----------------- tests/test_border_pad.py | 7 ++--- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index c019c14731..6cd2a273a7 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -188,7 +188,7 @@ def 
__init__( self, spatial_size: Union[Sequence[int], int], method: Union[Method, str] = Method.SYMMETRIC, - mode: Union[PytorchPadMode, str] = NumpyPadMode.CONSTANT, + mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, **kwargs, ) -> None: self.spatial_size = spatial_size @@ -215,12 +215,11 @@ def __call__( Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. - mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, - ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. - One of the listed string values or a user supplied function. Defaults to `self.mode`. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. + default to `self.mode`. + kwargs: other arguments for the `torch.pad` function, will override `self.kwargs`. """ data_pad_width = self._determine_data_pad_width(img.shape[1:]) @@ -229,13 +228,12 @@ def __call__( return super().__call__(img=img, to_pad=all_pad_width, mode=mode, **kwargs) -class BorderPad(Transform): +class BorderPad(Pad): """ Pad the input data by adding specified borders to every dimension. Args: spatial_border: specified size for every spatial border. Any -ve values will be set to 0. It can be 3 shapes: - - single int number, pad all the borders with the same size. - length equals the length of image shape, pad every spatial dimension separately. for example, image shape(CHW) is [1, 4, 4], spatial_border is [2, 1], @@ -255,31 +253,30 @@ class BorderPad(Transform): """ - backend = Pad.backend - def __init__( self, spatial_border: Union[Sequence[int], int], - mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, **kwargs, ) -> None: self.spatial_border = spatial_border - self.mode = mode - self.kwargs = kwargs + super().__init__(mode=mode, **kwargs) def __call__( - self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None - ) -> NdarrayOrTensor: + self, + img: torch.Tensor, + mode: Optional[Union[PytorchPadMode, str]] = None, + **kwargs, + ) -> torch.Tensor: """ Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. - mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, - ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. - One of the listed string values or a user supplied function. Defaults to `self.mode`. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. 
+ See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. + default to `self.mode`. + kwargs: other arguments for the `torch.pad` function, will override `self.kwargs`. Raises: ValueError: When ``self.spatial_border`` does not contain ints. @@ -306,8 +303,7 @@ def __call__( ) all_pad_width = [(0, 0)] + data_pad_width - padder = Pad(all_pad_width, mode or self.mode, **self.kwargs) - return padder(img) + return super().__call__(img=img, to_pad=all_pad_width, mode=mode, **kwargs) class DivisiblePad(Transform): diff --git a/tests/test_border_pad.py b/tests/test_border_pad.py index b632ff831f..97e463ec3a 100644 --- a/tests/test_border_pad.py +++ b/tests/test_border_pad.py @@ -16,7 +16,7 @@ from monai.transforms import BorderPad from monai.utils import NumpyPadMode -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL TEST_CASE_1 = [{"spatial_border": 2, "mode": "constant"}, np.zeros((3, 8, 8, 4)), np.zeros((3, 12, 12, 8))] @@ -38,7 +38,7 @@ class TestBorderPad(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_pad_shape(self, input_param, input_data, expected_val): - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: padder = BorderPad(**input_param) r1 = padder(p(input_data)) r2 = padder(input_data, mode=input_param["mode"]) @@ -46,9 +46,10 @@ def test_pad_shape(self, input_param, input_data, expected_val): self.assertAlmostEqual(r2.shape, expected_val.shape) def test_pad_kwargs(self): - padder = BorderPad(spatial_border=2, mode="constant", constant_values=((0, 0), (1, 1), (2, 2))) + padder = BorderPad(spatial_border=2, mode="constant", value=1) result = padder(np.zeros((3, 8, 4))) np.testing.assert_allclose(result[:, :2, 2:6], np.ones((3, 2, 4))) + result = padder(np.zeros((3, 8, 4)), mode="constant", value=2) np.testing.assert_allclose(result[:, :, :2], np.ones((3, 12, 2)) + 1) From 549fe2aae98c26ee03cf06ff014ba43bd006961a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Jun 2022 12:16:59 +0800 Subject: [PATCH 05/47] [DLMED] update divisible pad Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 49 +++++++++++++++---------------- tests/test_divisible_pad.py | 12 +++----- 2 files changed, 28 insertions(+), 33 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 6cd2a273a7..4268bdbb35 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -306,7 +306,7 @@ def __call__( return super().__call__(img=img, to_pad=all_pad_width, mode=mode, **kwargs) -class DivisiblePad(Transform): +class DivisiblePad(Pad): """ Pad the input data, so that the spatial sizes are divisible by `k`. """ @@ -316,8 +316,8 @@ class DivisiblePad(Transform): def __init__( self, k: Union[Sequence[int], int], - mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, method: Union[Method, str] = Method.SYMMETRIC, + mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, **kwargs, ) -> None: """ @@ -325,43 +325,42 @@ def __init__( k: the target k for each spatial dimension. if `k` is negative or 0, the original size is preserved. if `k` is an int, the same `k` be applied to all the input spatial dimensions. 
- mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, - ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. - One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. - kwargs: other arguments for the `np.pad` or `torch.pad` function. - note that `np.pad` treats channel dimension as the first dimension. + mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. + default to `self.mode`. + kwargs: other arguments for the `torch.pad` function, will override `self.kwargs`. See also :py:class:`monai.transforms.SpatialPad` """ self.k = k - self.mode: NumpyPadMode = NumpyPadMode(mode) self.method: Method = Method(method) - self.kwargs = kwargs + super().__init__(mode=mode, **kwargs) def __call__( - self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None - ) -> NdarrayOrTensor: + self, + img: torch.Tensor, + mode: Optional[Union[PytorchPadMode, str]] = None, + **kwargs, + ) -> torch.Tensor: """ Args: - img: data to be transformed, assuming `img` is channel-first - and padding doesn't apply to the channel dim. - mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, - ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. - One of the listed string values or a user supplied function. Defaults to `self.mode`. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + img: data to be transformed, assuming `img` is channel-first and + padding doesn't apply to the channel dim. + mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. + default to `self.mode`. + kwargs: other arguments for the `torch.pad` function, will override `self.kwargs`. 
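+
+            A minimal sketch (illustrative only; assumes a channel-first tensor)::
+
+                import torch
+                from monai.transforms import DivisiblePad
+
+                padder = DivisiblePad(k=16)
+                # spatial dims 20 and 30 are padded up to multiples of 16
+                out = padder(torch.zeros(1, 20, 30))  # shape (1, 32, 32)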
""" new_size = compute_divisible_spatial_size(spatial_shape=img.shape[1:], k=self.k) - spatial_pad = SpatialPad(spatial_size=new_size, method=self.method, mode=mode or self.mode, **self.kwargs) - - return spatial_pad(img) + spatial_pad = SpatialPad(spatial_size=new_size, method=self.method) + data_pad_width = spatial_pad._determine_data_pad_width(img.shape[1:]) + all_pad_width = [(0, 0)] + data_pad_width + return super().__call__(img=img, to_pad=all_pad_width, mode=mode, **kwargs) class SpatialCrop(Transform): diff --git a/tests/test_divisible_pad.py b/tests/test_divisible_pad.py index f940636fa8..05f4c57bf0 100644 --- a/tests/test_divisible_pad.py +++ b/tests/test_divisible_pad.py @@ -16,11 +16,11 @@ from parameterized import parameterized from monai.transforms import DivisiblePad -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL TESTS = [] -for p in TEST_NDARRAYS: +for p in TEST_NDARRAYS_ALL: # pad first dim to be divisible by 7, the second unchanged. TESTS.append([{"k": (7, -1), "mode": "constant"}, p(np.zeros((3, 8, 7))), p(np.zeros((3, 14, 7)))]) @@ -40,13 +40,9 @@ def test_pad_shape(self, input_param, input_data, expected_val): self.assertAlmostEqual(result.shape, expected_val.shape) def test_pad_kwargs(self): - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: input_data = p(np.zeros((3, 8, 4))) - if isinstance(input_data, np.ndarray): - result = DivisiblePad(k=5, mode="constant", constant_values=((0, 0), (1, 1), (2, 2)))(input_data) - np.testing.assert_allclose(result[:, :1, :4], np.ones((3, 1, 4)), rtol=1e-7, atol=0) - else: - result = DivisiblePad(k=5, mode="constant", value=2)(input_data).cpu() + result = DivisiblePad(k=5, mode="constant", value=2)(input_data).cpu() torch.testing.assert_allclose(result[:, :, 4:5], np.ones((3, 10, 1)) + 1, rtol=1e-7, atol=0) From 382b3e20e60a00604cadffa995cf0b369776c67f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Jun 2022 13:23:18 +0800 Subject: [PATCH 06/47] [DLMED] update spatial crop Signed-off-by: Nic Ma --- docs/source/transforms.rst | 6 ++ monai/transforms/__init__.py | 1 + monai/transforms/croppad/array.py | 103 +++++++++++++++++++++++++----- tests/test_spatial_crop.py | 10 +-- 4 files changed, 99 insertions(+), 21 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 2eb2537b49..4c2f0e5901 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -105,6 +105,12 @@ Crop and Pad :members: :special-members: __call__ +`Crop` +"""""" +.. autoclass:: Crop + :members: + :special-members: __call__ + `SpatialCrop` """"""""""""" .. 
image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialCrop.png diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index d4f09474de..68ced7ac6d 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -16,6 +16,7 @@ BoundingRect, CenterScaleCrop, CenterSpatialCrop, + Crop, CropForeground, DivisiblePad, Pad, diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 4268bdbb35..f2ea8fbe4a 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -58,6 +58,7 @@ "SpatialPad", "BorderPad", "DivisiblePad", + "Crop", "SpatialCrop", "CenterSpatialCrop", "CenterScaleCrop", @@ -363,23 +364,15 @@ def __call__( return super().__call__(img=img, to_pad=all_pad_width, mode=mode, **kwargs) -class SpatialCrop(Transform): +class Crop(InvertibleTransform): """ - General purpose cropper to produce sub-volume region of interest (ROI). - If a dimension of the expected ROI size is bigger than the input image size, will not crop that dimension. - So the cropped result may be smaller than the expected ROI, and the cropped results of several images may - not have exactly the same shape. - It can support to crop ND spatial (channel-first) data. + Perform crop operation on the input image. - The cropped region can be parameterised in various ways: - - a list of slices for each spatial dimension (allows for use of -ve indexing and `None`) - - a spatial center and size - - the start and end coordinates of the ROI """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH] - def __init__( + def compute_slices( self, roi_center: Union[Sequence[int], NdarrayOrTensor, None] = None, roi_size: Union[Sequence[int], NdarrayOrTensor, None] = None, @@ -388,6 +381,8 @@ def __init__( roi_slices: Optional[Sequence[slice]] = None, ) -> None: """ + Compute the crop slices based on specified `center & size` or `start & end`. + Args: roi_center: voxel coordinates for center of the crop ROI. roi_size: size of the crop ROI, if a dimension of ROI size is bigger than image size, @@ -396,6 +391,7 @@ def __init__( roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image, use the end coordinate of image. roi_slices: list of slices for each of the spatial dimensions. + """ roi_start_torch: torch.Tensor @@ -427,14 +423,89 @@ def __init__( else: self.slices = [slice(int(s), int(e)) for s, e in zip(roi_start_torch.tolist(), roi_end_torch.tolist())] - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: torch.Tensor) -> torch.Tensor: + """ + Apply the transform to `img`, assuming `img` is channel-first and + slicing doesn't apply to the channel dim. 
+ """ + if self.slices is None: + raise ValueError("must compute the crop slices first.") + orig_size = img.shape[1:] + sd = len(img.shape[1:]) # spatial dims + slices = list(self.slices) + if len(slices) < sd: + slices += [slice(None)] * (sd - len(slices)) + # Add in the channel (no cropping) + slices = [slice(None)] + slices[:sd] + + img_t = convert_to_tensor(data=img, track_meta=get_track_meta()) + img_t = img_t[tuple(slices)] + if get_track_meta(): + self._update_meta(tensor=img_t, slices=slices) + cropped_from_start = np.asarray([s.indices(o)[0] for s, o in zip(slices[1:], orig_size)]) + cropped_from_end = np.asarray(orig_size) - img_t.shape[1:] - cropped_from_start + cropped = list(chain(*zip(cropped_from_start.tolist(), cropped_from_end.tolist()))) + self.push_transform(img_t, extra_info={"cropped": cropped}) + return img_t + + def _update_meta(self, tensor: MetaTensor, slices: List): + spatial_rank = max(len(tensor.affine) - 1, 1) + to_shift = [s.start if s.start is not None else 0 for s in ensure_tuple(slices)[1:]] + mat = create_translate(spatial_rank, to_shift) + tensor.meta["affine"] = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0] + + def inverse(self, img: torch.Tensor) -> torch.Tensor: + transform = self.pop_transform(img) + cropped = transform[TraceKeys.EXTRA_INFO]["cropped"] + # the amount we pad is equal to the amount we cropped in each direction + inverse_transform = BorderPad(cropped) + # Apply inverse transform + with inverse_transform.trace_transform(False): + return inverse_transform(img) + + +class SpatialCrop(Crop): + """ + General purpose cropper to produce sub-volume region of interest (ROI). + If a dimension of the expected ROI size is bigger than the input image size, will not crop that dimension. + So the cropped result may be smaller than the expected ROI, and the cropped results of several images may + not have exactly the same shape. + It can support to crop ND spatial (channel-first) data. + The cropped region can be parameterised in various ways: + - a list of slices for each spatial dimension (allows for use of -ve indexing and `None`) + - a spatial center and size + - the start and end coordinates of the ROI + """ + + def __init__( + self, + roi_center: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_size: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_start: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_end: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_slices: Optional[Sequence[slice]] = None, + ) -> None: + """ + Args: + roi_center: voxel coordinates for center of the crop ROI. + roi_size: size of the crop ROI, if a dimension of ROI size is bigger than image size, + will not crop that dimension of the image. + roi_start: voxel coordinates for start of the crop ROI. + roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image, + use the end coordinate of image. + roi_slices: list of slices for each of the spatial dimensions. + """ + self.compute_slices( + roi_center=roi_center, roi_size=roi_size, roi_start=roi_start, roi_end=roi_end, roi_slices=roi_slices, + ) + + def __call__(self, img: torch.Tensor) -> torch.Tensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. 
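+
+        A minimal sketch (illustrative only; assumes a channel-first tensor)::
+
+            import torch
+            from monai.transforms import SpatialCrop
+
+            img = torch.arange(16).reshape(1, 4, 4)
+            # crop rows/cols 1:3, the channel dim is untouched
+            out = SpatialCrop(roi_start=[1, 1], roi_end=[3, 3])(img)  # shape (1, 2, 2)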
+ """ - sd = min(len(self.slices), len(img.shape[1:])) # spatial dims - slices = [slice(None)] + self.slices[:sd] - return img[tuple(slices)] + return super().__call__(img=img) class CenterSpatialCrop(Transform): diff --git a/tests/test_spatial_crop.py b/tests/test_spatial_crop.py index bf1eb11491..ebf4665a23 100644 --- a/tests/test_spatial_crop.py +++ b/tests/test_spatial_crop.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import SpatialCrop -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TESTS = [ [{"roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], @@ -37,15 +37,15 @@ class TestSpatialCrop(unittest.TestCase): def test_shape(self, input_param, input_shape, expected_shape): input_data = np.random.randint(0, 2, size=input_shape) results = [] - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS + (None,): + for p in TEST_NDARRAYS_ALL: + for q in TEST_NDARRAYS_ALL + (None,): input_param_mod = { k: q(v) if k != "roi_slices" and q is not None else v for k, v in input_param.items() } im = p(input_data) result = SpatialCrop(**input_param_mod)(im) - self.assertEqual(type(im), type(result)) - if isinstance(result, torch.Tensor): + self.assertTrue(isinstance(result, torch.Tensor)) + if isinstance(im, torch.Tensor): self.assertEqual(result.device, im.device) self.assertTupleEqual(result.shape, expected_shape) results.append(result) From d3962e4061664a07a754b362c28d360052a7a852 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Jun 2022 20:33:44 +0800 Subject: [PATCH 07/47] [DLMED] make thread safe Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 32 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index f2ea8fbe4a..8cf71f3a6c 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -372,14 +372,14 @@ class Crop(InvertibleTransform): backend = [TransformBackends.TORCH] + @staticmethod def compute_slices( - self, roi_center: Union[Sequence[int], NdarrayOrTensor, None] = None, roi_size: Union[Sequence[int], NdarrayOrTensor, None] = None, roi_start: Union[Sequence[int], NdarrayOrTensor, None] = None, roi_end: Union[Sequence[int], NdarrayOrTensor, None] = None, roi_slices: Optional[Sequence[slice]] = None, - ) -> None: + ): """ Compute the crop slices based on specified `center & size` or `start & end`. 
@@ -397,8 +397,8 @@ def compute_slices( if roi_slices: if not all(s.step is None or s.step == 1 for s in roi_slices): - raise ValueError("Only slice steps of 1/None are currently supported") - self.slices = list(roi_slices) + raise ValueError("only slice steps of 1/None are currently supported") + return list(roi_slices) else: if roi_center is not None and roi_size is not None: roi_center, *_ = convert_data_type( @@ -406,33 +406,31 @@ def compute_slices( ) roi_size, *_ = convert_to_dst_type(src=roi_size, dst=roi_center, wrap_sequence=True) _zeros = torch.zeros_like(roi_center) - roi_start_torch = maximum(roi_center - floor_divide(roi_size, 2), _zeros) # type: ignore + roi_start_torch = maximum(roi_center - floor_divide(roi_size, 2), _zeros) roi_end_torch = maximum(roi_start_torch + roi_size, roi_start_torch) else: if roi_start is None or roi_end is None: - raise ValueError("Please specify either roi_center, roi_size or roi_start, roi_end.") + raise ValueError("please specify either roi_center, roi_size or roi_start, roi_end.") roi_start_torch, *_ = convert_data_type( data=roi_start, output_type=torch.Tensor, dtype=torch.int16, wrap_sequence=True ) - roi_start_torch = maximum(roi_start_torch, torch.zeros_like(roi_start_torch)) # type: ignore + roi_start_torch = maximum(roi_start_torch, torch.zeros_like(roi_start_torch)) roi_end_torch, *_ = convert_to_dst_type(src=roi_end, dst=roi_start_torch, wrap_sequence=True) roi_end_torch = maximum(roi_end_torch, roi_start_torch) # convert to slices (accounting for 1d) if roi_start_torch.numel() == 1: - self.slices = [slice(int(roi_start_torch.item()), int(roi_end_torch.item()))] + return [slice(int(roi_start_torch.item()), int(roi_end_torch.item()))] else: - self.slices = [slice(int(s), int(e)) for s, e in zip(roi_start_torch.tolist(), roi_end_torch.tolist())] + return [slice(int(s), int(e)) for s, e in zip(roi_start_torch.tolist(), roi_end_torch.tolist())] - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: torch.Tensor, slices: Tuple[slice, ...]) -> torch.Tensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. + """ - if self.slices is None: - raise ValueError("must compute the crop slices first.") orig_size = img.shape[1:] sd = len(img.shape[1:]) # spatial dims - slices = list(self.slices) if len(slices) < sd: slices += [slice(None)] * (sd - len(slices)) # Add in the channel (no cropping) @@ -495,7 +493,7 @@ def __init__( use the end coordinate of image. roi_slices: list of slices for each of the spatial dimensions. """ - self.compute_slices( + self.slices = self.compute_slices( roi_center=roi_center, roi_size=roi_size, roi_start=roi_start, roi_end=roi_end, roi_slices=roi_slices, ) @@ -505,10 +503,10 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: slicing doesn't apply to the channel dim. """ - return super().__call__(img=img) + return super().__call__(img=img, slices=self.slices) -class CenterSpatialCrop(Transform): +class CenterSpatialCrop(Crop): """ Crop at the center of image with specified ROI size. If a dimension of the expected ROI size is bigger than the input image size, will not crop that dimension. @@ -523,8 +521,6 @@ class CenterSpatialCrop(Transform): the spatial size of output data will be [32, 40, 40]. 
""" - backend = SpatialCrop.backend - def __init__(self, roi_size: Union[Sequence[int], int]) -> None: self.roi_size = roi_size From 57f635aa48dabcbfa3fbcd0c0f396605d1706ffd Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Jun 2022 20:42:18 +0800 Subject: [PATCH 08/47] [DLMED] update CenterSpatialCrop Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 9 +++++---- tests/test_center_spatial_crop.py | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 8cf71f3a6c..acf3dcd5bd 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -524,15 +524,16 @@ class CenterSpatialCrop(Crop): def __init__(self, roi_size: Union[Sequence[int], int]) -> None: self.roi_size = roi_size - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: torch.Tensor) -> torch.Tensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. + """ roi_size = fall_back_tuple(self.roi_size, img.shape[1:]) - center = [i // 2 for i in img.shape[1:]] - cropper = SpatialCrop(roi_center=center, roi_size=roi_size) - return cropper(img) + roi_center = [i // 2 for i in img.shape[1:]] + slices = self.compute_slices(roi_center=roi_center, roi_size=roi_size) + return super().__call__(img=img, slices=slices) class CenterScaleCrop(Transform): diff --git a/tests/test_center_spatial_crop.py b/tests/test_center_spatial_crop.py index 09f61be2f1..771ba650a9 100644 --- a/tests/test_center_spatial_crop.py +++ b/tests/test_center_spatial_crop.py @@ -15,6 +15,7 @@ import torch from parameterized import parameterized +from monai.data import MetaTensor from monai.transforms import CenterSpatialCrop TEST_CASE_0 = [{"roi_size": [2, 2, -1]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 3)] @@ -38,13 +39,13 @@ class TestCenterSpatialCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) def test_shape(self, input_param, input_data, expected_shape): result = CenterSpatialCrop(**input_param)(input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) + self.assertTrue(isinstance(result, MetaTensor)) np.testing.assert_allclose(result.shape, expected_shape) @parameterized.expand([TEST_CASE_2]) def test_value(self, input_param, input_data, expected_value): result = CenterSpatialCrop(**input_param)(input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) + self.assertTrue(isinstance(result, MetaTensor)) np.testing.assert_allclose(result, expected_value) From 41c38b0320e8c8085f59a00bec4fa98a5862bc2e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Jun 2022 20:55:31 +0800 Subject: [PATCH 09/47] [DLMED] update scale crop Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 13 ++++++------- tests/test_center_scale_crop.py | 5 +++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index acf3dcd5bd..4f19d774ea 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -536,7 +536,7 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: return super().__call__(img=img, slices=slices) -class CenterScaleCrop(Transform): +class CenterScaleCrop(Crop): """ Crop at the center of image with specified scale of ROI size. 
@@ -546,17 +546,16 @@ class CenterScaleCrop(Transform): """ - backend = CenterSpatialCrop.backend - def __init__(self, roi_scale: Union[Sequence[float], float]): self.roi_scale = roi_scale - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: torch.Tensor) -> torch.Tensor: img_size = img.shape[1:] ndim = len(img_size) - roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] - sp_crop = CenterSpatialCrop(roi_size=roi_size) - return sp_crop(img=img) + roi_size = fall_back_tuple([ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)], img_size) + roi_center = [i // 2 for i in img.shape[1:]] + slices = self.compute_slices(roi_center=roi_center, roi_size=roi_size) + return super().__call__(img=img, slices=slices) class RandSpatialCrop(Randomizable, Transform): diff --git a/tests/test_center_scale_crop.py b/tests/test_center_scale_crop.py index f22651e3e0..5476321165 100644 --- a/tests/test_center_scale_crop.py +++ b/tests/test_center_scale_crop.py @@ -15,6 +15,7 @@ import torch from parameterized import parameterized +from monai.data import MetaTensor from monai.transforms import CenterScaleCrop TEST_CASE_0 = [{"roi_scale": [0.6, 0.3, -1]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 1, 3)] @@ -38,13 +39,13 @@ class TestCenterScaleCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) def test_shape(self, input_param, input_data, expected_shape): result = CenterScaleCrop(**input_param)(input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) + self.assertTrue(isinstance(result, MetaTensor)) np.testing.assert_allclose(result.shape, expected_shape) @parameterized.expand([TEST_CASE_2]) def test_value(self, input_param, input_data, expected_value): result = CenterScaleCrop(**input_param)(input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) + self.assertTrue(isinstance(result, MetaTensor)) np.testing.assert_allclose(result, expected_value) From ad85387673e7e5f9dde025a489ffd1122100501e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Jun 2022 22:39:59 +0800 Subject: [PATCH 10/47] [DLMED] fix flake8 Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 4f19d774ea..d385d3edc6 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -552,7 +552,9 @@ def __init__(self, roi_scale: Union[Sequence[float], float]): def __call__(self, img: torch.Tensor) -> torch.Tensor: img_size = img.shape[1:] ndim = len(img_size) - roi_size = fall_back_tuple([ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)], img_size) + roi_size = fall_back_tuple( + [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)], img_size, + ) roi_center = [i // 2 for i in img.shape[1:]] slices = self.compute_slices(roi_center=roi_center, roi_size=roi_size) return super().__call__(img=img, slices=slices) From 647b239c99bb6556d17d7bb4f4e705a4da1ff01f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 23 Jun 2022 11:04:06 +0800 Subject: [PATCH 11/47] [DLMED] update random spatial crop Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 32 +++++++++++++++---------------- tests/test_rand_spatial_crop.py | 4 ++-- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git 
a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index d385d3edc6..9243a8d8ac 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -430,6 +430,7 @@ def __call__(self, img: torch.Tensor, slices: Tuple[slice, ...]) -> torch.Tensor """ orig_size = img.shape[1:] + slices = list(slices) sd = len(img.shape[1:]) # spatial dims if len(slices) < sd: slices += [slice(None)] * (sd - len(slices)) @@ -524,16 +525,18 @@ class CenterSpatialCrop(Crop): def __init__(self, roi_size: Union[Sequence[int], int]) -> None: self.roi_size = roi_size + def compute_slices(self, spatial_size: Sequence[int]): + roi_size = fall_back_tuple(self.roi_size, spatial_size) + roi_center = [i // 2 for i in spatial_size] + return super().compute_slices(roi_center=roi_center, roi_size=roi_size) + def __call__(self, img: torch.Tensor) -> torch.Tensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ - roi_size = fall_back_tuple(self.roi_size, img.shape[1:]) - roi_center = [i // 2 for i in img.shape[1:]] - slices = self.compute_slices(roi_center=roi_center, roi_size=roi_size) - return super().__call__(img=img, slices=slices) + return super().__call__(img=img, slices=self.compute_slices(img.shape[1:])) class CenterScaleCrop(Crop): @@ -552,15 +555,12 @@ def __init__(self, roi_scale: Union[Sequence[float], float]): def __call__(self, img: torch.Tensor) -> torch.Tensor: img_size = img.shape[1:] ndim = len(img_size) - roi_size = fall_back_tuple( - [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)], img_size, - ) - roi_center = [i // 2 for i in img.shape[1:]] - slices = self.compute_slices(roi_center=roi_center, roi_size=roi_size) - return super().__call__(img=img, slices=slices) + roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] + cropper = CenterSpatialCrop(roi_size=roi_size) + return super().__call__(img=img, slices=cropper.compute_slices(img.shape[1:])) -class RandSpatialCrop(Randomizable, Transform): +class RandSpatialCrop(Randomizable, Crop): """ Crop image with random size or specific size ROI. It can crop at a random position as center or at the image center. And allows to set the minimum and maximum size to limit the randomly generated ROI. @@ -584,8 +584,6 @@ class RandSpatialCrop(Randomizable, Transform): if True, the actual size is sampled from `randint(roi_size, max_roi_size + 1)`. """ - backend = CenterSpatialCrop.backend - def __init__( self, roi_size: Union[Sequence[int], int], @@ -609,9 +607,9 @@ def randomize(self, img_size: Sequence[int]) -> None: self._size = tuple(self.R.randint(low=self._size[i], high=max_size[i] + 1) for i in range(len(img_size))) if self.random_center: valid_size = get_valid_patch_size(img_size, self._size) - self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) + self._slices = get_random_patch(img_size, valid_size, self.R) - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: torch.Tensor) -> torch.Tensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. 
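
For the randomized variant, a brief sketch under the same assumptions (seed and sizes chosen arbitrarily): `randomize` now stores only the spatial patch in `self._slices`, and the channel slice is prepended inside `Crop.__call__`.

    import torch
    from monai.transforms import RandSpatialCrop

    cropper = RandSpatialCrop(roi_size=(4, 4), random_center=True, random_size=False)
    cropper.set_random_state(seed=0)
    # a reproducible 4x4 patch at a seeded random center
    print(cropper(torch.rand(1, 10, 10)).shape)  # (1, 4, 4)
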
@@ -620,9 +618,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: if self._size is None: raise RuntimeError("self._size not specified.") if self.random_center: - return img[self._slices] + return super().__call__(img=img, slices=self._slices) cropper = CenterSpatialCrop(self._size) - return cropper(img) + return super().__call__(img=img, slices=cropper.compute_slices(img.shape[1:])) class RandScaleCrop(RandSpatialCrop): diff --git a/tests/test_rand_spatial_crop.py b/tests/test_rand_spatial_crop.py index 8f4bb0fffa..5521ede350 100644 --- a/tests/test_rand_spatial_crop.py +++ b/tests/test_rand_spatial_crop.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import RandSpatialCrop -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_CASE_0 = [ {"roi_size": [3, 3, -1], "random_center": True}, @@ -57,7 +57,7 @@ def test_shape(self, input_param, input_data, expected_shape): @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: cropper = RandSpatialCrop(**input_param) result = cropper(p(input_data)) roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] From 3c842364045d15829897670b00018b8ebc855d6c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 23 Jun 2022 12:24:13 +0800 Subject: [PATCH 12/47] [DLMED] update random scale crop Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 2 +- tests/test_rand_scale_crop.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 9243a8d8ac..6349f09c95 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -655,7 +655,7 @@ def __init__( self.roi_scale = roi_scale self.max_roi_scale = max_roi_scale - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: torch.Tensor) -> torch.Tensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. 
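
The scale-based variant follows the same pattern; a small sketch with arbitrary numbers: `roi_scale` is a fraction of each spatial dimension, so 0.5 of an (8, 8) image yields a (4, 4) ROI.

    import torch
    from monai.transforms import RandScaleCrop

    cropper = RandScaleCrop(roi_scale=0.5, random_center=True, random_size=False)
    cropper.set_random_state(seed=0)
    print(cropper(torch.rand(1, 8, 8)).shape)  # (1, 4, 4)
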
diff --git a/tests/test_rand_scale_crop.py b/tests/test_rand_scale_crop.py index 5d6312002f..58ed65bf0d 100644 --- a/tests/test_rand_scale_crop.py +++ b/tests/test_rand_scale_crop.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import RandScaleCrop -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_CASE_1 = [ {"roi_scale": [1.0, 1.0, -1.0], "random_center": True}, @@ -56,13 +56,13 @@ class TestRandScaleCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape): - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: result = RandScaleCrop(**input_param)(p(input_data)) self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: cropper = RandScaleCrop(**input_param) result = cropper(p(input_data)) roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] @@ -70,7 +70,7 @@ def test_value(self, input_param, input_data): @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_random_shape(self, input_param, input_data, expected_shape): - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: cropper = RandScaleCrop(**input_param) cropper.set_random_state(seed=123) result = cropper(p(input_data)) From 0e1a322d8f5d97682816415ff259ad9148ef790d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 23 Jun 2022 13:50:36 +0800 Subject: [PATCH 13/47] [DLMED] update random spatial crop samples Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 9 +++++++-- tests/test_rand_spatial_crop_samples.py | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 6349f09c95..40a7ddc523 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -50,6 +50,7 @@ fall_back_tuple, look_up_option, ) +from monai.utils import ImageMetaKey as Key from monai.utils.enums import TraceKeys, TransformBackends from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor @@ -726,12 +727,16 @@ def set_random_state( def randomize(self, data: Optional[Any] = None) -> None: pass - def __call__(self, img: NdarrayOrTensor) -> List[NdarrayOrTensor]: + def __call__(self, img: torch.Tensor) -> List[torch.Tensor]: """ Apply the transform to `img`, assuming `img` is channel-first and cropping doesn't change the channel dim. 
""" - return [self.cropper(img) for _ in range(self.num_samples)] + ret = [self.cropper(img) for _ in range(self.num_samples)] + if get_track_meta(): + for i, r in enumerate(ret): + r.meta[Key.PATCH_INDEX] = i + return ret class CropForeground(Transform): diff --git a/tests/test_rand_spatial_crop_samples.py b/tests/test_rand_spatial_crop_samples.py index 18fdf38773..de96908cc6 100644 --- a/tests/test_rand_spatial_crop_samples.py +++ b/tests/test_rand_spatial_crop_samples.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import RandSpatialCropSamples -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_CASE_1 = [ {"roi_size": [3, 3, 3], "num_samples": 4, "random_center": True, "random_size": False}, @@ -71,7 +71,7 @@ class TestRandSpatialCropSamples(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape, expected_last_item): - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: xform = RandSpatialCropSamples(**input_param) xform.set_random_state(1234) result = xform(p(input_data)) From b6c48114768cf94680e677a4f1a2e1ac630afbb4 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 23 Jun 2022 18:41:46 +0800 Subject: [PATCH 14/47] [DLMED] adjust Pad design Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 106 +++++++++--------------------- 1 file changed, 30 insertions(+), 76 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 40a7ddc523..027d405023 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -104,6 +104,16 @@ def __init__( self.mode = mode self.kwargs = kwargs + def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int]]: + """ + dynamically compute the pad width according to the spatial shape. + + Args: + spatial_shape: spatial shape of the original image. + + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + def __call__( self, img: torch.Tensor, @@ -124,9 +134,9 @@ def __call__( """ to_pad_ = self.to_pad if to_pad is None else to_pad + if to_pad_ is None: + to_pad_ = self.compute_pad_width(img.shape[1:]) mode_ = self.mode if mode is None else mode - if to_pad is None: - raise ValueError("must provde `to_pad` to execute padding.") kwargs_ = dict(self.kwargs) kwargs_.update(kwargs) @@ -197,37 +207,23 @@ def __init__( self.method: Method = look_up_option(method, Method) super().__init__(mode=mode, **kwargs) - def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int, int]]: - spatial_size = fall_back_tuple(self.spatial_size, data_shape) - if self.method == Method.SYMMETRIC: - pad_width = [] - for i, sp_i in enumerate(spatial_size): - width = max(sp_i - data_shape[i], 0) - pad_width.append((width // 2, width - (width // 2))) - return pad_width - return [(0, max(sp_i - data_shape[i], 0)) for i, sp_i in enumerate(spatial_size)] - - def __call__( - self, - img: torch.Tensor, - mode: Optional[Union[PytorchPadMode, str]] = None, - **kwargs, - ) -> torch.Tensor: + def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int]]: """ + dynamically compute the pad width according to the spatial shape. + Args: - img: data to be transformed, assuming `img` is channel-first and - padding doesn't apply to the channel dim. 
- mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. - One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. - default to `self.mode`. - kwargs: other arguments for the `torch.pad` function, will override `self.kwargs`. + spatial_shape: spatial shape of the original image. """ - data_pad_width = self._determine_data_pad_width(img.shape[1:]) - all_pad_width = [(0, 0)] + data_pad_width - - return super().__call__(img=img, to_pad=all_pad_width, mode=mode, **kwargs) + spatial_size = fall_back_tuple(self.spatial_size, spatial_shape) + if self.method == Method.SYMMETRIC: + pad_width = [] + for i, sp_i in enumerate(spatial_size): + width = max(sp_i - spatial_shape[i], 0) + pad_width.append((width // 2, width - (width // 2))) + else: + pad_width = [(0, max(sp_i - spatial_shape[i], 0)) for i, sp_i in enumerate(spatial_size)] + return [(0, 0)] + pad_width class BorderPad(Pad): @@ -264,29 +260,7 @@ def __init__( self.spatial_border = spatial_border super().__init__(mode=mode, **kwargs) - def __call__( - self, - img: torch.Tensor, - mode: Optional[Union[PytorchPadMode, str]] = None, - **kwargs, - ) -> torch.Tensor: - """ - Args: - img: data to be transformed, assuming `img` is channel-first and - padding doesn't apply to the channel dim. - mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. - One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. - default to `self.mode`. - kwargs: other arguments for the `torch.pad` function, will override `self.kwargs`. - - Raises: - ValueError: When ``self.spatial_border`` does not contain ints. - ValueError: When ``self.spatial_border`` length is not one of - [1, len(spatial_shape), 2*len(spatial_shape)]. - - """ - spatial_shape = img.shape[1:] + def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int]]: spatial_border = ensure_tuple(self.spatial_border) if not all(isinstance(b, int) for b in spatial_border): raise ValueError(f"self.spatial_border must contain only ints, got {spatial_border}.") @@ -303,9 +277,7 @@ def __call__( f"Unsupported spatial_border length: {len(spatial_border)}, available options are " f"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2*len(spatial_shape)}]." ) - - all_pad_width = [(0, 0)] + data_pad_width - return super().__call__(img=img, to_pad=all_pad_width, mode=mode, **kwargs) + return [(0, 0)] + data_pad_width class DivisiblePad(Pad): @@ -341,28 +313,10 @@ def __init__( self.method: Method = Method(method) super().__init__(mode=mode, **kwargs) - def __call__( - self, - img: torch.Tensor, - mode: Optional[Union[PytorchPadMode, str]] = None, - **kwargs, - ) -> torch.Tensor: - """ - Args: - img: data to be transformed, assuming `img` is channel-first and - padding doesn't apply to the channel dim. - mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. - One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. - default to `self.mode`. - kwargs: other arguments for the `torch.pad` function, will override `self.kwargs`. 
- - """ - new_size = compute_divisible_spatial_size(spatial_shape=img.shape[1:], k=self.k) + def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int]]: + new_size = compute_divisible_spatial_size(spatial_shape=spatial_shape, k=self.k) spatial_pad = SpatialPad(spatial_size=new_size, method=self.method) - data_pad_width = spatial_pad._determine_data_pad_width(img.shape[1:]) - all_pad_width = [(0, 0)] + data_pad_width - return super().__call__(img=img, to_pad=all_pad_width, mode=mode, **kwargs) + return spatial_pad.compute_pad_width(spatial_shape) class Crop(InvertibleTransform): From e432e0365ce169b7616ea92bdd03c2f931759e69 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 23 Jun 2022 20:52:41 +0800 Subject: [PATCH 15/47] [DLMED] update CropForeground Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 66 ++++++++++++++++++------------- tests/test_crop_foreground.py | 4 +- 2 files changed, 40 insertions(+), 30 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 027d405023..158d6e2afc 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -142,14 +142,12 @@ def __call__( img_t = convert_to_tensor(data=img, track_meta=get_track_meta()) - if not np.asarray(to_pad_).any(): - # all zeros, skip padding - return img_t - - mode_ = convert_pad_mode(dst=img_t, mode=mode_).value - pad_width = [val for sublist in to_pad_[1:] for val in sublist[::-1]][::-1] - # torch.pad expects `[B, C, H, W, [D]]` shape - img_t = pad_pt(img_t.unsqueeze(0), pad_width, mode=mode_, **kwargs_).squeeze(0) + # all zeros, skip padding + if np.asarray(to_pad_).any(): + mode_ = convert_pad_mode(dst=img_t, mode=mode_).value + pad_width = [val for sublist in to_pad_[1:] for val in sublist[::-1]][::-1] + # torch.pad expects `[B, C, H, W, [D]]` shape + img_t = pad_pt(img_t.unsqueeze(0), pad_width, mode=mode_, **kwargs_).squeeze(0) if get_track_meta(): self._update_meta(tensor=img_t, to_pad=to_pad_) @@ -693,7 +691,7 @@ def __call__(self, img: torch.Tensor) -> List[torch.Tensor]: return ret -class CropForeground(Transform): +class CropForeground(Crop): """ Crop an image using a bounding box. The bounding box is generated by selecting foreground using select_fn at channels channel_indices. margin is added in each spatial dimension of the bounding box. @@ -725,8 +723,6 @@ def threshold_at_one(x): """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__( self, select_fn: Callable = is_positive, @@ -735,7 +731,7 @@ def __init__( allow_smaller: bool = True, return_coords: bool = False, k_divisible: Union[Sequence[int], int] = 1, - mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = NumpyPadMode.CONSTANT, + mode: Optional[Union[PytorchPadMode, str]] = PytorchPadMode.CONSTANT, **pad_kwargs, ) -> None: """ @@ -750,14 +746,10 @@ def __init__( return_coords: whether return the coordinates of spatial bounding box for foreground. k_divisible: make each spatial dimension to be divisible by k, default to 1. if `k_divisible` is an int, the same `k` be applied to all the input spatial dimensions. - mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, - ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. 
+ mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html - pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. - note that `np.pad` treats channel dimension as the first dimension. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. + pad_kwargs: other arguments for the `torch.pad` function. """ self.select_fn = select_fn @@ -766,10 +758,9 @@ def __init__( self.allow_smaller = allow_smaller self.return_coords = return_coords self.k_divisible = k_divisible - self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) - self.pad_kwargs = pad_kwargs + self.padder = Pad(mode=mode, **pad_kwargs) - def compute_bounding_box(self, img: NdarrayOrTensor): + def compute_bounding_box(self, img: torch.Tensor): """ Compute the start points and end points of bounding box to crop. And adjust bounding box coords to be divisible by `k`. @@ -790,33 +781,52 @@ def compute_bounding_box(self, img: NdarrayOrTensor): def crop_pad( self, - img: NdarrayOrTensor, + img: torch.Tensor, box_start: np.ndarray, box_end: np.ndarray, - mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, + mode: Optional[Union[PytorchPadMode, str]] = None, + **pad_kwargs, ): """ Crop and pad based on the bounding box. """ - cropped = SpatialCrop(roi_start=box_start, roi_end=box_end)(img) + slices = self.compute_slices(roi_start=box_start, roi_end=box_end) + cropped = super().__call__(img=img, slices=slices) pad_to_start = np.maximum(-box_start, 0) pad_to_end = np.maximum(box_end - np.asarray(img.shape[1:]), 0) pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - return BorderPad(spatial_border=pad, mode=mode or self.mode, **self.pad_kwargs)(cropped) + pad_width = BorderPad(spatial_border=pad).compute_pad_width(cropped.shape[1:]) + ret = self.padder.__call__(img=cropped, to_pad=pad_width, mode=mode, **pad_kwargs) + # combine the traced cropping and padding into one transformation + # by taking the padded info and placing it in a key inside the crop info. + if get_track_meta(): + app_op = ret.applied_operations.pop(-1) + ret.applied_operations[-1][TraceKeys.EXTRA_INFO]["pad_info"] = app_op + return ret - def __call__(self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None): + def __call__(self, img: torch.Tensor, mode: Optional[Union[PytorchPadMode, str]] = None, **pad_kwargs): """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. 
""" box_start, box_end = self.compute_bounding_box(img) - cropped = self.crop_pad(img, box_start, box_end, mode) + cropped = self.crop_pad(img, box_start, box_end, mode, **pad_kwargs) if self.return_coords: return cropped, box_start, box_end return cropped + def inverse(self, img: torch.Tensor) -> torch.Tensor: + transform = self.get_most_recent_transform(img) + # we moved the padding info in the forward, so put it back for the inverse + pad_info = transform[TraceKeys.EXTRA_INFO].pop("pad_info") + img.applied_operations.append(pad_info) + # first inverse the padder + inv = self.padder.inverse(img) + # and then inverse the cropper (self) + return super().inverse(inv) + class RandWeightedCrop(Randomizable, Transform): """ diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index af945673fe..a9a891100c 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -16,11 +16,11 @@ from parameterized import parameterized from monai.transforms import CropForeground -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL TEST_COORDS, TESTS = [], [] -for p in TEST_NDARRAYS: +for p in TEST_NDARRAYS_ALL: TEST_COORDS.append( [ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, From 0807a10721edf27d87b53a2beb7706da8a12d88d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 23 Jun 2022 21:03:29 +0800 Subject: [PATCH 16/47] [DLMED] update random weighted crop Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 10 ++++++---- tests/test_rand_weighted_crop.py | 11 ++++++----- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 158d6e2afc..2637479ec3 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -859,7 +859,7 @@ def randomize(self, weight_map: NdarrayOrTensor) -> None: spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R ) # using only the first channel as weight map - def __call__(self, img: NdarrayOrTensor, weight_map: Optional[NdarrayOrTensor] = None) -> List[NdarrayOrTensor]: + def __call__(self, img: torch.Tensor, weight_map: Optional[NdarrayOrTensor] = None) -> List[torch.Tensor]: """ Args: img: input image to sample patches from. assuming `img` is a channel-first array. 
@@ -880,9 +880,11 @@ def __call__(self, img: NdarrayOrTensor, weight_map: Optional[NdarrayOrTensor] = self.randomize(weight_map) _spatial_size = fall_back_tuple(self.spatial_size, weight_map.shape[1:]) results: List[NdarrayOrTensor] = [] - for center in self.centers: - cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) - results.append(cropper(img)) + for i, center in enumerate(self.centers): + cropped = SpatialCrop(roi_center=center, roi_size=_spatial_size)(img) + if get_track_meta(): + cropped.meta[Key.PATCH_INDEX] = i + results.append(cropped) return results diff --git a/tests/test_rand_weighted_crop.py b/tests/test_rand_weighted_crop.py index dae7f05016..952aff8327 100644 --- a/tests/test_rand_weighted_crop.py +++ b/tests/test_rand_weighted_crop.py @@ -15,8 +15,9 @@ import torch from parameterized.parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms.croppad.array import RandWeightedCrop -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose +from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose def get_data(ndim): @@ -30,8 +31,8 @@ def get_data(ndim): TESTS = [] -for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: +for p in TEST_NDARRAYS_ALL: + for q in TEST_NDARRAYS_ALL: im = SEG1_2D weight = np.zeros_like(im) weight[0, 30, 17] = 1.1 @@ -161,10 +162,10 @@ def test_rand_weighted_crop(self, _, input_params, img, weight, expected_shape, # if desired ROI is larger than image, check image is unchanged if all(s >= i for i, s in zip(img.shape[1:], input_params["spatial_size"])): for res in result: - self.assertEqual(type(img), type(res)) + self.assertEqual(MetaTensor, type(res)) if isinstance(img, torch.Tensor): self.assertEqual(res.device, img.device) - assert_allclose(res, img) + assert_allclose(res, img, type_test=False) if __name__ == "__main__": From 6706d078d2ab6d9903aebdc0f70d000532f38bad Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 23 Jun 2022 21:10:05 +0800 Subject: [PATCH 17/47] [DLMED] update RandCropPosNeg Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 28 +++++++++++++----------- tests/test_rand_crop_by_pos_neg_label.py | 4 ++-- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 2637479ec3..5810c353fe 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -943,16 +943,16 @@ class RandCropByPosNegLabel(Randomizable, Transform): """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = SpatialCrop.backend def __init__( self, spatial_size: Union[Sequence[int], int], - label: Optional[NdarrayOrTensor] = None, + label: Optional[torch.Tensor] = None, pos: float = 1.0, neg: float = 1.0, num_samples: int = 1, - image: Optional[NdarrayOrTensor] = None, + image: Optional[torch.Tensor] = None, image_threshold: float = 0.0, fg_indices: Optional[NdarrayOrTensor] = None, bg_indices: Optional[NdarrayOrTensor] = None, @@ -975,10 +975,10 @@ def __init__( def randomize( self, - label: NdarrayOrTensor, + label: torch.Tensor, fg_indices: Optional[NdarrayOrTensor] = None, bg_indices: Optional[NdarrayOrTensor] = None, - image: Optional[NdarrayOrTensor] = None, + image: Optional[torch.Tensor] = None, ) -> None: if fg_indices is None or bg_indices is None: if self.fg_indices is not None and self.bg_indices is not None: @@ -1002,12 +1002,12 @@ def randomize( def __call__( self, - img: 
NdarrayOrTensor, - label: Optional[NdarrayOrTensor] = None, - image: Optional[NdarrayOrTensor] = None, + img: torch.Tensor, + label: Optional[torch.Tensor] = None, + image: Optional[torch.Tensor] = None, fg_indices: Optional[NdarrayOrTensor] = None, bg_indices: Optional[NdarrayOrTensor] = None, - ) -> List[NdarrayOrTensor]: + ) -> List[torch.Tensor]: """ Args: img: input data to crop samples from based on the pos/neg ratio of `label` and `image`. @@ -1030,12 +1030,14 @@ def __call__( image = self.image self.randomize(label, fg_indices, bg_indices, image) - results: List[NdarrayOrTensor] = [] + results: List[torch.Tensor] = [] if self.centers is not None: - for center in self.centers: + for i, center in enumerate(self.centers): roi_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) - cropper = SpatialCrop(roi_center=center, roi_size=roi_size) - results.append(cropper(img)) + cropped = SpatialCrop(roi_center=center, roi_size=roi_size)(img) + if get_track_meta(): + cropped.meta[Key.PATCH_INDEX] = i + results.append(cropped) return results diff --git a/tests/test_rand_crop_by_pos_neg_label.py b/tests/test_rand_crop_by_pos_neg_label.py index 1d9e2612c7..f6da393ab9 100644 --- a/tests/test_rand_crop_by_pos_neg_label.py +++ b/tests/test_rand_crop_by_pos_neg_label.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import RandCropByPosNegLabel -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL TESTS = [ [ @@ -103,7 +103,7 @@ def convert_data_type(im_type, d, keys=("img", "image", "label")): @parameterized.expand(TESTS) def test_type_shape(self, input_param, input_data, expected_shape): results = [] - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: input_param_mod = self.convert_data_type(p, input_param) input_data_mod = self.convert_data_type(p, input_data) cropper = RandCropByPosNegLabel(**input_param_mod) From d7fa6445e573e3ebd674249b8ae32b8c55d3dbb6 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 23 Jun 2022 21:14:47 +0800 Subject: [PATCH 18/47] [DLMED] update rand crop by label classes Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 26 +++++++++++++----------- tests/test_rand_crop_by_label_classes.py | 4 ++-- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 5810c353fe..4bf73d9fbd 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -1106,16 +1106,16 @@ class RandCropByLabelClasses(Randomizable, Transform): """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = SpatialCrop.backend def __init__( self, spatial_size: Union[Sequence[int], int], ratios: Optional[List[Union[float, int]]] = None, - label: Optional[NdarrayOrTensor] = None, + label: Optional[torch.Tensor] = None, num_classes: Optional[int] = None, num_samples: int = 1, - image: Optional[NdarrayOrTensor] = None, + image: Optional[torch.Tensor] = None, image_threshold: float = 0.0, indices: Optional[List[NdarrayOrTensor]] = None, allow_smaller: bool = False, @@ -1133,9 +1133,9 @@ def __init__( def randomize( self, - label: NdarrayOrTensor, + label: torch.Tensor, indices: Optional[List[NdarrayOrTensor]] = None, - image: Optional[NdarrayOrTensor] = None, + image: Optional[torch.Tensor] = None, ) -> None: indices_: Sequence[NdarrayOrTensor] if indices is None: @@ -1151,11 +1151,11 @@ def randomize( def __call__( self, - img: NdarrayOrTensor, - label: Optional[NdarrayOrTensor] = None, - image: 
Optional[NdarrayOrTensor] = None, + img: torch.Tensor, + label: Optional[torch.Tensor] = None, + image: Optional[torch.Tensor] = None, indices: Optional[List[NdarrayOrTensor]] = None, - ) -> List[NdarrayOrTensor]: + ) -> List[torch.Tensor]: """ Args: img: input data to crop samples from based on the ratios of every class, assumes `img` is a @@ -1176,10 +1176,12 @@ def __call__( self.randomize(label, indices, image) results: List[NdarrayOrTensor] = [] if self.centers is not None: - for center in self.centers: + for i, center in enumerate(self.centers): roi_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) - cropper = SpatialCrop(roi_center=tuple(center), roi_size=roi_size) - results.append(cropper(img)) + cropped = SpatialCrop(roi_center=tuple(center), roi_size=roi_size)(img) + if get_track_meta(): + cropped.meta[Key.PATCH_INDEX] = i + results.append(cropped) return results diff --git a/tests/test_rand_crop_by_label_classes.py b/tests/test_rand_crop_by_label_classes.py index 11d73df74e..b1165e8986 100644 --- a/tests/test_rand_crop_by_label_classes.py +++ b/tests/test_rand_crop_by_label_classes.py @@ -15,10 +15,10 @@ from parameterized import parameterized from monai.transforms import ClassesToIndices, RandCropByLabelClasses -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL TESTS_INDICES, TESTS_SHAPE = [], [] -for p in TEST_NDARRAYS: +for p in TEST_NDARRAYS_ALL: # One-Hot label TESTS_INDICES.append( [ From 0278d7c552705534259a493d7c7645c8dd901b4e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 23 Jun 2022 21:26:47 +0800 Subject: [PATCH 19/47] [DLMED] update ResizeCropOrPad Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 54 +++++++++++++++++---------- tests/test_resize_with_pad_or_crop.py | 6 +-- 2 files changed, 37 insertions(+), 23 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 4bf73d9fbd..368affb23d 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -43,7 +43,6 @@ from monai.transforms.utils_pytorch_numpy_unification import floor_divide, maximum from monai.utils import ( Method, - NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, @@ -1186,7 +1185,7 @@ def __call__( return results -class ResizeWithPadOrCrop(Transform): +class ResizeWithPadOrCrop(InvertibleTransform): """ Resize an image to a target spatial size by either centrally cropping the image or padding it evenly with a user-specified mode. @@ -1196,17 +1195,12 @@ class ResizeWithPadOrCrop(Transform): Args: spatial_size: the spatial size of output data after padding or crop. If has non-positive values, the corresponding size of input image will be used (no padding). - mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, - ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. - One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. - pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. 
- note that `np.pad` treats channel dimension as the first dimension. - + mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. + pad_kwargs: other arguments for the `torch.pad` function. """ backend = list(set(SpatialPad.backend) & set(CenterSpatialCrop.backend)) @@ -1214,28 +1208,48 @@ class ResizeWithPadOrCrop(Transform): def __init__( self, spatial_size: Union[Sequence[int], int], - mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, method: Union[Method, str] = Method.SYMMETRIC, + mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, **pad_kwargs, ): self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **pad_kwargs) self.cropper = CenterSpatialCrop(roi_size=spatial_size) def __call__( - self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None - ) -> NdarrayOrTensor: + self, img: torch.Tensor, mode: Optional[Union[PytorchPadMode, str]] = None, **pad_kwargs, + ) -> torch.Tensor: """ Args: img: data to pad or crop, assuming `img` is channel-first and padding or cropping doesn't apply to the channel dim. - mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, - ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. + pad_kwargs: other arguments for the `torch.pad` function. 
+ """ - return self.padder(self.cropper(img), mode=mode) + ret = self.padder(self.cropper(img), mode=mode, **pad_kwargs) + # remove the individual info and combine + if get_track_meta(): + pad_info = ret.applied_operations.pop(-1) + crop_info = ret.applied_operations.pop(-1) + self.push_transform(ret, extra_info={"pad_info": pad_info, "crop_info": crop_info}) + return ret + + def inverse(self, img: torch.Tensor) -> torch.Tensor: + transform = self.pop_transform(img) + return self.inverse_transform(img, transform) + + def inverse_transform(self, img: torch.Tensor, transform) -> torch.Tensor: + # we joined the cropping and padding, so put them back before calling the inverse + crop_info = transform[TraceKeys.EXTRA_INFO].pop("crop_info") + pad_info = transform[TraceKeys.EXTRA_INFO].pop("pad_info") + img.applied_operations.append(crop_info) + img.applied_operations.append(pad_info) + # first inverse the padder + inv = self.padder.inverse(img) + # and then inverse the cropper (self) + return self.cropper.inverse(inv) class BoundingRect(Transform): diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py index f81e1d4b08..f911a04334 100644 --- a/tests/test_resize_with_pad_or_crop.py +++ b/tests/test_resize_with_pad_or_crop.py @@ -16,12 +16,12 @@ from parameterized import parameterized from monai.transforms import ResizeWithPadOrCrop -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL TEST_CASES = [ [{"spatial_size": [15, 8, 8], "mode": "constant"}, (3, 8, 8, 4), (3, 15, 8, 8)], [ - {"spatial_size": [15, 4, 8], "mode": "constant", "method": "end", "constant_values": 1}, + {"spatial_size": [15, 4, 8], "mode": "constant", "method": "end", "value": 1}, (3, 8, 8, 4), (3, 15, 4, 8), ], @@ -34,7 +34,7 @@ class TestResizeWithPadOrCrop(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_pad_shape(self, input_param, input_shape, expected_shape): - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: if isinstance(p(0), torch.Tensor) and ( "constant_values" in input_param or input_param["mode"] == "reflect" ): From 1a50f68726484576342fb747fca9eb9d96cbed1a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 23 Jun 2022 23:36:25 +0800 Subject: [PATCH 20/47] [DLMED] restore numpy pad Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 131 +++++++++++++++++--------- tests/test_border_pad.py | 3 +- tests/test_divisible_pad.py | 6 +- tests/test_resize_with_pad_or_crop.py | 2 +- tests/test_spatial_pad.py | 16 +++- 5 files changed, 106 insertions(+), 52 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 368affb23d..d6486556d7 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -84,19 +84,23 @@ class Pad(InvertibleTransform): Args: to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...]. if None, must provide in the `__call__` at runtime. - mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. 
- See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. - kwargs: other arguments for the `torch.pad` function. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ - backend = [TransformBackends.TORCH] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( self, to_pad: Optional[List[Tuple[int, int]]] = None, - mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, + mode: str = PytorchPadMode.CONSTANT, **kwargs, ) -> None: self.to_pad = to_pad @@ -113,11 +117,21 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int """ raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + @staticmethod + def _np_pad(img: np.ndarray, pad_width, mode, **kwargs) -> np.ndarray: + return np.pad(img, pad_width, mode=mode, **kwargs) + + @staticmethod + def _pt_pad(img: torch.Tensor, pad_width, mode, **kwargs) -> torch.Tensor: + pt_pad_width = [val for sublist in pad_width[1:] for val in sublist[::-1]][::-1] + # torch.pad expects `[B, C, H, W, [D]]` shape + return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0) + def __call__( self, img: torch.Tensor, to_pad: Optional[List[Tuple[int, int]]] = None, - mode: Optional[Union[PytorchPadMode, str]] = None, + mode: Optional[str] = None, **kwargs, ) -> torch.Tensor: """ @@ -125,11 +139,14 @@ def __call__( img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...]. default to `self.to_pad`. - mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. - default to `self.mode`. - kwargs: other arguments for the `torch.pad` function, will override `self.kwargs`. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. 
""" to_pad_ = self.to_pad if to_pad is None else to_pad @@ -143,15 +160,24 @@ def __call__( # all zeros, skip padding if np.asarray(to_pad_).any(): - mode_ = convert_pad_mode(dst=img_t, mode=mode_).value - pad_width = [val for sublist in to_pad_[1:] for val in sublist[::-1]][::-1] - # torch.pad expects `[B, C, H, W, [D]]` shape - img_t = pad_pt(img_t.unsqueeze(0), pad_width, mode=mode_, **kwargs_).squeeze(0) - + try: + mode_ = convert_pad_mode(dst=img_t, mode=mode_).value + out = self._pt_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_) + # but if mode or args don't exist in pytorch, use numpy instead + except (ValueError, TypeError) as err: + if "Unsupported option" in str(err) or "unexpected keyword" in str(err): + # extract metadata + img_np = img_t.detach().cpu().numpy() + mode = convert_pad_mode(dst=img_np, mode=mode_).value + out = torch.as_tensor(self._np_pad(img_np, pad_width=to_pad_, mode=mode_, **kwargs_)) + if get_track_meta(): + out = MetaTensor(out, meta=img_t.meta, applied_operations=img_t.applied_operations) # type: ignore + else: + out = img_t if get_track_meta(): - self._update_meta(tensor=img_t, to_pad=to_pad_) - self.push_transform(img_t, extra_info={"padded": to_pad_}) - return img_t + self._update_meta(tensor=out, to_pad=to_pad_) + self.push_transform(out, extra_info={"padded": to_pad_}) + return out def _update_meta(self, tensor: MetaTensor, to_pad: List[Tuple[int, int]]): spatial_rank = max(len(tensor.affine) - 1, 1) @@ -185,11 +211,14 @@ class SpatialPad(Pad): `spatial_size=[32, 25, -1]`, the spatial size of output data will be [32, 30, 30]. method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. - mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. - default to `self.mode`. - kwargs: other arguments for the `torch.pad` function, will override `self.kwargs`. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. 
""" @@ -197,7 +226,7 @@ def __init__( self, spatial_size: Union[Sequence[int], int], method: Union[Method, str] = Method.SYMMETRIC, - mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, + mode: str = PytorchPadMode.CONSTANT, **kwargs, ) -> None: self.spatial_size = spatial_size @@ -251,7 +280,7 @@ class BorderPad(Pad): def __init__( self, spatial_border: Union[Sequence[int], int], - mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, + mode: str = PytorchPadMode.CONSTANT, **kwargs, ) -> None: self.spatial_border = spatial_border @@ -288,7 +317,7 @@ def __init__( self, k: Union[Sequence[int], int], method: Union[Method, str] = Method.SYMMETRIC, - mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, + mode: str = PytorchPadMode.CONSTANT, **kwargs, ) -> None: """ @@ -298,11 +327,14 @@ def __init__( if `k` is an int, the same `k` be applied to all the input spatial dimensions. method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. - mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. - default to `self.mode`. - kwargs: other arguments for the `torch.pad` function, will override `self.kwargs`. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. See also :py:class:`monai.transforms.SpatialPad` """ @@ -730,7 +762,7 @@ def __init__( allow_smaller: bool = True, return_coords: bool = False, k_divisible: Union[Sequence[int], int] = 1, - mode: Optional[Union[PytorchPadMode, str]] = PytorchPadMode.CONSTANT, + mode: str = PytorchPadMode.CONSTANT, **pad_kwargs, ) -> None: """ @@ -745,10 +777,14 @@ def __init__( return_coords: whether return the coordinates of spatial bounding box for foreground. k_divisible: make each spatial dimension to be divisible by k, default to 1. if `k_divisible` is an int, the same `k` be applied to all the input spatial dimensions. - mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. - pad_kwargs: other arguments for the `torch.pad` function. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. 
+ note that `np.pad` treats channel dimension as the first dimension. """ self.select_fn = select_fn @@ -783,7 +819,7 @@ def crop_pad( img: torch.Tensor, box_start: np.ndarray, box_end: np.ndarray, - mode: Optional[Union[PytorchPadMode, str]] = None, + mode: Optional[str] = None, **pad_kwargs, ): """ @@ -804,7 +840,7 @@ def crop_pad( ret.applied_operations[-1][TraceKeys.EXTRA_INFO]["pad_info"] = app_op return ret - def __call__(self, img: torch.Tensor, mode: Optional[Union[PytorchPadMode, str]] = None, **pad_kwargs): + def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs): """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. @@ -1197,10 +1233,15 @@ class ResizeWithPadOrCrop(InvertibleTransform): If has non-positive values, the corresponding size of input image will be used (no padding). method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. - mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. - pad_kwargs: other arguments for the `torch.pad` function. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. + """ backend = list(set(SpatialPad.backend) & set(CenterSpatialCrop.backend)) @@ -1209,23 +1250,27 @@ def __init__( self, spatial_size: Union[Sequence[int], int], method: Union[Method, str] = Method.SYMMETRIC, - mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, + mode: str = PytorchPadMode.CONSTANT, **pad_kwargs, ): self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **pad_kwargs) self.cropper = CenterSpatialCrop(roi_size=spatial_size) def __call__( - self, img: torch.Tensor, mode: Optional[Union[PytorchPadMode, str]] = None, **pad_kwargs, + self, img: torch.Tensor, mode: Optional[ str] = None, **pad_kwargs, ) -> torch.Tensor: """ Args: img: data to pad or crop, assuming `img` is channel-first and padding or cropping doesn't apply to the channel dim. - mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. - pad_kwargs: other arguments for the `torch.pad` function. 
+ See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ ret = self.padder(self.cropper(img), mode=mode, **pad_kwargs) diff --git a/tests/test_border_pad.py b/tests/test_border_pad.py index 97e463ec3a..b06aa7c564 100644 --- a/tests/test_border_pad.py +++ b/tests/test_border_pad.py @@ -46,10 +46,9 @@ def test_pad_shape(self, input_param, input_data, expected_val): self.assertAlmostEqual(r2.shape, expected_val.shape) def test_pad_kwargs(self): - padder = BorderPad(spatial_border=2, mode="constant", value=1) + padder = BorderPad(spatial_border=2, mode="constant", constant_values=((0, 0), (1, 1), (2, 2))) result = padder(np.zeros((3, 8, 4))) np.testing.assert_allclose(result[:, :2, 2:6], np.ones((3, 2, 4))) - result = padder(np.zeros((3, 8, 4)), mode="constant", value=2) np.testing.assert_allclose(result[:, :, :2], np.ones((3, 12, 2)) + 1) diff --git a/tests/test_divisible_pad.py b/tests/test_divisible_pad.py index 05f4c57bf0..4428078f40 100644 --- a/tests/test_divisible_pad.py +++ b/tests/test_divisible_pad.py @@ -42,7 +42,11 @@ def test_pad_shape(self, input_param, input_data, expected_val): def test_pad_kwargs(self): for p in TEST_NDARRAYS_ALL: input_data = p(np.zeros((3, 8, 4))) - result = DivisiblePad(k=5, mode="constant", value=2)(input_data).cpu() + if isinstance(input_data, np.ndarray): + result = DivisiblePad(k=5, mode="constant", constant_values=((0, 0), (1, 1), (2, 2)))(input_data) + np.testing.assert_allclose(result[:, :1, :4], np.ones((3, 1, 4)), rtol=1e-7, atol=0) + else: + result = DivisiblePad(k=5, mode="constant", value=2)(input_data).cpu() torch.testing.assert_allclose(result[:, :, 4:5], np.ones((3, 10, 1)) + 1, rtol=1e-7, atol=0) diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py index f911a04334..2eb39bfe4d 100644 --- a/tests/test_resize_with_pad_or_crop.py +++ b/tests/test_resize_with_pad_or_crop.py @@ -21,7 +21,7 @@ TEST_CASES = [ [{"spatial_size": [15, 8, 8], "mode": "constant"}, (3, 8, 8, 4), (3, 15, 8, 8)], [ - {"spatial_size": [15, 4, 8], "mode": "constant", "method": "end", "value": 1}, + {"spatial_size": [15, 4, 8], "mode": "constant", "method": "end", "constant_values": 1}, (3, 8, 8, 4), (3, 15, 4, 8), ], diff --git a/tests/test_spatial_pad.py b/tests/test_spatial_pad.py index 603c407b5d..932760c3d9 100644 --- a/tests/test_spatial_pad.py +++ b/tests/test_spatial_pad.py @@ -83,11 +83,17 @@ def test_pad_shape(self, input_param, input_shape, expected_shape): def test_pad_kwargs(self): for p in TEST_NDARRAYS_ALL: input_data = p(np.zeros((3, 8, 4))) - result = ( - SpatialPad(spatial_size=[15, 8], method="end", mode="constant", value=2)(img=input_data) - .cpu() - .numpy() - ) + if isinstance(input_data, torch.Tensor): + result = ( + SpatialPad(spatial_size=[15, 8], method="end", mode="constant", value=2)(img=input_data) + .cpu() + .numpy() + ) + else: + result = SpatialPad( + spatial_size=[15, 8], method="end", mode="constant", constant_values=((0, 0), (1, 1), (2, 2)) + )(img=input_data) + torch.testing.assert_allclose(result[:, 8:, :4], np.ones((3, 7, 4)), rtol=1e-7, atol=0) torch.testing.assert_allclose(result[:, :, 4:], np.ones((3, 15, 4)) + 1, rtol=1e-7, atol=0) From d97d6afabf4df0941b43da1d9bb4de01d362af0f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 24 Jun 2022 11:49:06 +0800 Subject: [PATCH 
21/47] [DLMED] update dict spatial pad Signed-off-by: Nic Ma --- docs/source/transforms.rst | 6 ++ monai/transforms/__init__.py | 4 +- monai/transforms/croppad/dictionary.py | 94 +++++++++++++++----------- 3 files changed, 63 insertions(+), 41 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 4c2f0e5901..85963651f4 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1001,6 +1001,12 @@ Dictionary Transforms Crop and Pad (Dict) ^^^^^^^^^^^^^^^^^^^ +`Padd` +"""""" +.. autoclass:: Padd + :members: + :special-members: __call__ + `SpatialPadd` """"""""""""" .. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialPadd.png diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 68ced7ac6d..b08cfeaf28 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -50,7 +50,9 @@ DivisiblePadd, DivisiblePadD, DivisiblePadDict, - PadModeSequence, + Padd, + PadD, + PadDict, RandCropByLabelClassesd, RandCropByLabelClassesD, RandCropByLabelClassesDict, diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 50cc767cab..25734e8885 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -23,6 +23,7 @@ from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np +import torch from monai.config import IndexSelection, KeysCollection from monai.config.type_definitions import NdarrayOrTensor @@ -33,6 +34,7 @@ CenterSpatialCrop, CropForeground, DivisiblePad, + Pad, RandCropByLabelClasses, RandCropByPosNegLabel, ResizeWithPadOrCrop, @@ -55,7 +57,6 @@ from monai.utils.enums import PostFix, TraceKeys __all__ = [ - "PadModeSequence", "SpatialPadd", "BorderPadd", "DivisiblePadd", @@ -103,24 +104,64 @@ "RandCropByLabelClassesDict", ] -PadModeSequence = Union[Sequence[Union[NumpyPadMode, PytorchPadMode, str]], NumpyPadMode, PytorchPadMode, str] DEFAULT_POST_FIX = PostFix.meta() -class SpatialPadd(MapTransform, InvertibleTransform): +class Padd(MapTransform, InvertibleTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.Pad`. + + """ + backend = Pad.backend + + def __init__( + self, keys: KeysCollection, padder: Pad, mode: str = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + padder: pad transform for the input image. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + It also can be a sequence of string, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. 
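A minimal sketch of the new wrapper in use (assuming `Padd` and `SpatialPad` are importable from `monai.transforms`, as the `__init__.py` hunk above indicates):

    import torch
    from monai.transforms import Padd, SpatialPad

    padd = Padd(keys=["image", "label"], padder=SpatialPad(spatial_size=[96, 96]), mode="constant")
    data = {"image": torch.zeros(1, 64, 64), "label": torch.zeros(1, 64, 64)}
    out = padd(data)          # both keys padded to (1, 96, 96)
    back = padd.inverse(out)  # Padd.inverse delegates to the wrapped padder's inverse

The patches that follow rebase `SpatialPadd`, `BorderPadd` and `DivisiblePadd` onto this class, so their hand-written `__call__`/`inverse` bodies disappear.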
+ + """ + super().__init__(keys, allow_missing_keys) + self.padder = padder + self.mode = ensure_tuple_rep(mode, len(self.keys)) + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + d = dict(data) + for key, m in self.key_iterator(d, self.mode): + d[key] = self.padder(d[key], mode=m) + return d + + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.padder.inverse(d[key]) + return d + + +class SpatialPadd(Padd): """ Dictionary-based wrapper of :py:class:`monai.transforms.SpatialPad`. Performs padding to the data, symmetric for all sides or all on one side for each dimension. - """ - backend = SpatialPad.backend + """ def __init__( self, keys: KeysCollection, spatial_size: Union[Sequence[int], int], method: Union[Method, str] = Method.SYMMETRIC, - mode: PadModeSequence = NumpyPadMode.CONSTANT, + mode: Union[Sequence[str], str] = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, **kwargs, ) -> None: @@ -147,36 +188,8 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension. """ - super().__init__(keys, allow_missing_keys) - self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = SpatialPad(spatial_size, method, **kwargs) - - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = dict(data) - for key, m in self.key_iterator(d, self.mode): - self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) - d[key] = self.padder(d[key], mode=m) - return d - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = transform[TraceKeys.ORIG_SIZE] - if self.padder.method == Method.SYMMETRIC: - current_size = d[key].shape[1:] - roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)] - else: - roi_center = [floor(r / 2) if r % 2 == 0 else (r - 1) // 2 for r in orig_size] - - inverse_transform = SpatialCrop(roi_center, orig_size) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d + padder = SpatialPad(spatial_size, method, **kwargs) + super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) class BorderPadd(MapTransform, InvertibleTransform): @@ -191,7 +204,7 @@ def __init__( self, keys: KeysCollection, spatial_border: Union[Sequence[int], int], - mode: PadModeSequence = NumpyPadMode.CONSTANT, + mode: Union[Sequence[str], str] = NumpyPadMode.CONSTANT, allow_missing_keys: bool = False, **kwargs, ) -> None: @@ -270,7 +283,7 @@ def __init__( self, keys: KeysCollection, k: Union[Sequence[int], int], - mode: PadModeSequence = NumpyPadMode.CONSTANT, + mode: Union[Sequence[str], str] = NumpyPadMode.CONSTANT, method: Union[Method, str] = Method.SYMMETRIC, allow_missing_keys: bool = False, **kwargs, @@ -835,7 +848,7 @@ def __init__( margin: Union[Sequence[int], int] = 0, allow_smaller: bool = True, k_divisible: Union[Sequence[int], int] = 1, - mode: PadModeSequence = NumpyPadMode.CONSTANT, + mode: Union[Sequence[str], str] = NumpyPadMode.CONSTANT, start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", allow_missing_keys: bool = False, @@ -1435,7 
+1448,7 @@ def __init__( self, keys: KeysCollection, spatial_size: Union[Sequence[int], int], - mode: PadModeSequence = NumpyPadMode.CONSTANT, + mode: Union[Sequence[str], str] = NumpyPadMode.CONSTANT, allow_missing_keys: bool = False, method: Union[Method, str] = Method.SYMMETRIC, **pad_kwargs, @@ -1528,6 +1541,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d +PadD = PadDict = Padd SpatialPadD = SpatialPadDict = SpatialPadd BorderPadD = BorderPadDict = BorderPadd DivisiblePadD = DivisiblePadDict = DivisiblePadd From ab1c7ca3dbf0198e91601917eb789bc379c08296 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 24 Jun 2022 12:06:56 +0800 Subject: [PATCH 22/47] [DLMED] update border pad and divisible pad Signed-off-by: Nic Ma --- monai/transforms/croppad/dictionary.py | 74 +++----------------------- 1 file changed, 8 insertions(+), 66 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 25734e8885..bca7d36e3f 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -192,7 +192,7 @@ def __init__( super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) -class BorderPadd(MapTransform, InvertibleTransform): +class BorderPadd(Padd): """ Pad the input data by adding specified borders to every dimension. Dictionary-based wrapper of :py:class:`monai.transforms.BorderPad`. @@ -204,7 +204,7 @@ def __init__( self, keys: KeysCollection, spatial_border: Union[Sequence[int], int], - mode: Union[Sequence[str], str] = NumpyPadMode.CONSTANT, + mode: Union[Sequence[str], str] = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, **kwargs, ) -> None: @@ -235,43 +235,11 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension. """ - super().__init__(keys, allow_missing_keys) - self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = BorderPad(spatial_border=spatial_border, **kwargs) - - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = dict(data) - for key, m in self.key_iterator(d, self.mode): - self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) - d[key] = self.padder(d[key], mode=m) - return d - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) - roi_start = np.array(self.padder.spatial_border) - # Need to convert single value to [min1,min2,...] - if roi_start.size == 1: - roi_start = np.full((len(orig_size)), roi_start) - # need to convert [min1,max1,min2,...] to [min1,min2,...] - elif roi_start.size == 2 * orig_size.size: - roi_start = roi_start[::2] - roi_end = np.array(transform[TraceKeys.ORIG_SIZE]) + roi_start - - inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d + padder = BorderPad(spatial_border=spatial_border, **kwargs) + super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) -class DivisiblePadd(MapTransform, InvertibleTransform): +class DivisiblePadd(Padd): """ Pad the input data, so that the spatial sizes are divisible by `k`. 
Dictionary-based wrapper of :py:class:`monai.transforms.DivisiblePad`. @@ -283,7 +251,7 @@ def __init__( self, keys: KeysCollection, k: Union[Sequence[int], int], - mode: Union[Sequence[str], str] = NumpyPadMode.CONSTANT, + mode: Union[Sequence[str], str] = PytorchPadMode.CONSTANT, method: Union[Method, str] = Method.SYMMETRIC, allow_missing_keys: bool = False, **kwargs, @@ -311,34 +279,8 @@ def __init__( See also :py:class:`monai.transforms.SpatialPad` """ - super().__init__(keys, allow_missing_keys) - self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = DivisiblePad(k=k, method=method, **kwargs) - - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = dict(data) - for key, m in self.key_iterator(d, self.mode): - self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) - d[key] = self.padder(d[key], mode=m) - return d - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) - current_size = np.array(d[key].shape[1:]) - roi_start = np.floor((current_size - orig_size) / 2) - roi_end = orig_size + roi_start - inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d + padder = DivisiblePad(k=k, method=method, **kwargs) + super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) class SpatialCropd(MapTransform, InvertibleTransform): From e9f995e78f7750eed1a3380cfe086bd22e11d87d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 24 Jun 2022 12:57:35 +0800 Subject: [PATCH 23/47] [DLMED] update spatial crop dict Signed-off-by: Nic Ma --- docs/source/transforms.rst | 12 +++ monai/transforms/__init__.py | 6 ++ monai/transforms/croppad/dictionary.py | 118 ++++++++++++++++++------- tests/test_spatial_cropd.py | 4 +- 4 files changed, 107 insertions(+), 33 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 85963651f4..7eaa17ea43 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1031,6 +1031,18 @@ Crop and Pad (Dict) :members: :special-members: __call__ +`Cropd` +""""""" +.. autoclass:: Cropd + :members: + :special-members: __call__ + +`RandCropd` +""""""""""" +.. autoclass:: RandCropd + :members: + :special-members: __call__ + `SpatialCropd` """""""""""""" .. 
image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialCropd.png diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index b08cfeaf28..f02737a4ef 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -44,6 +44,9 @@ CenterSpatialCropd, CenterSpatialCropD, CenterSpatialCropDict, + Cropd, + CropD, + CropDict, CropForegroundd, CropForegroundD, CropForegroundDict, @@ -53,6 +56,9 @@ Padd, PadD, PadDict, + RandCropd, + RandCropD, + RandCropDict, RandCropByLabelClassesd, RandCropByLabelClassesD, RandCropByLabelClassesDict, diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index bca7d36e3f..4151d6a79d 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -32,6 +32,7 @@ BorderPad, BoundingRect, CenterSpatialCrop, + Crop, CropForeground, DivisiblePad, Pad, @@ -57,9 +58,12 @@ from monai.utils.enums import PostFix, TraceKeys __all__ = [ + "Padd", "SpatialPadd", "BorderPadd", "DivisiblePadd", + "Cropd", + "RandCropd", "SpatialCropd", "CenterSpatialCropd", "CenterScaleCropd", @@ -72,12 +76,18 @@ "ResizeWithPadOrCropd", "BoundingRectd", "RandCropByLabelClassesd", + "PadD", + "PadDict", "SpatialPadD", "SpatialPadDict", "BorderPadD", "BorderPadDict", "DivisiblePadD", "DivisiblePadDict", + "CropD", + "CropDict", + "RandCropD", + "RandCropDict", "SpatialCropD", "SpatialCropDict", "CenterSpatialCropD", @@ -283,7 +293,78 @@ def __init__( super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) -class SpatialCropd(MapTransform, InvertibleTransform): +class Cropd(MapTransform, InvertibleTransform): + """ + Dictionary-based wrapper of abstract class :py:class:`monai.transforms.Crop`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + cropper: crop transform for the input image. + allow_missing_keys: don't raise exception if key is missing. + + """ + + backend = Crop.backend + + def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False): + super().__init__(keys, allow_missing_keys) + self.cropper = cropper + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.cropper(d[key]) + return d + + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.cropper.inverse(d[key]) + return d + + +class RandCropd(Cropd, Randomizable): + """ + Base class for random crop transform. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + cropper: random crop transform for the input image. + allow_missing_keys: don't raise exception if key is missing. 
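A sketch of the intended base-class behavior, using `RandSpatialCropd` (made a `RandCropd` subclass later in this series) as the concrete example; the seed value and shapes are arbitrary:

    import numpy as np
    from monai.transforms import RandSpatialCropd

    t = RandSpatialCropd(keys=["img", "seg"], roi_size=[2, 2], random_size=True)
    t.set_random_state(seed=0)  # forwarded to the wrapped cropper when it is Randomizable
    out = t({"img": np.zeros((1, 4, 4)), "seg": np.zeros((1, 4, 4))})
    assert out["img"].shape == out["seg"].shape  # randomize() runs once, on the first key only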
+ + """ + backend = Crop.backend + + def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False): + super().__init__(keys, allow_missing_keys) + self.cropper = cropper + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandCropd": + super().set_random_state(seed, state) + if isinstance(self.cropper, Randomizable): + self.cropper.set_random_state(seed, state) + return self + + def randomize(self, img_size: Sequence[int]) -> None: + if isinstance(self.cropper, Randomizable): + self.cropper.randomize(img_size) + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + d = dict(data) + # only randomize at start + self.randomize(d[self.first_key(d)].shape[1:]) + + for key in self.key_iterator(d): + kwargs = {"randomize": False} if isinstance(self.cropper, Randomizable) else {} + d[key] = self.cropper(d[key], **kwargs) + return d + + +class SpatialCropd(Cropd): """ Dictionary-based wrapper of :py:class:`monai.transforms.SpatialCrop`. General purpose cropper to produce sub-volume region of interest (ROI). @@ -322,37 +403,10 @@ def __init__( use the end coordinate of image. roi_slices: list of slices for each of the spatial dimensions. allow_missing_keys: don't raise exception if key is missing. - """ - super().__init__(keys, allow_missing_keys) - self.cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices) - - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = dict(data) - for key in self.key_iterator(d): - self.push_transform(d, key) - d[key] = self.cropper(d[key]) - return d - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) - current_size = np.array(d[key].shape[1:]) - # get required pad to start and end - pad_to_start = np.array([s.indices(o)[0] for s, o in zip(self.cropper.slices, orig_size)]) - pad_to_end = orig_size - current_size - pad_to_start - # interleave mins and maxes - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - inverse_transform = BorderPad(pad) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d + """ + cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) class CenterSpatialCropd(MapTransform, InvertibleTransform): @@ -1487,6 +1541,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N SpatialPadD = SpatialPadDict = SpatialPadd BorderPadD = BorderPadDict = BorderPadd DivisiblePadD = DivisiblePadDict = DivisiblePadd +CropD = CropDict = Cropd +RandCropD = RandCropDict = RandCropd SpatialCropD = SpatialCropDict = SpatialCropd CenterSpatialCropD = CenterSpatialCropDict = CenterSpatialCropd CenterScaleCropD = CenterScaleCropDict = CenterScaleCropd diff --git a/tests/test_spatial_cropd.py b/tests/test_spatial_cropd.py index 5b16f460fd..87f11a106d 100644 --- a/tests/test_spatial_cropd.py +++ b/tests/test_spatial_cropd.py @@ -15,10 +15,10 @@ from parameterized import parameterized from monai.transforms import SpatialCropd -from tests.utils import TEST_NDARRAYS +from tests.utils import 
TEST_NDARRAYS_ALL TESTS = [] -for p in TEST_NDARRAYS: +for p in TEST_NDARRAYS_ALL: TESTS.append( [ {"keys": ["img"], "roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, From 08d926cc385300436d1bdfaafa3477c4d6704edc Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 24 Jun 2022 16:12:36 +0800 Subject: [PATCH 24/47] [DLMED] update center spatial crop Signed-off-by: Nic Ma --- monai/transforms/croppad/dictionary.py | 39 ++------------------------ tests/test_center_spatial_cropd.py | 6 ++-- 2 files changed, 6 insertions(+), 39 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 4151d6a79d..aff73a0a71 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -379,8 +379,6 @@ class SpatialCropd(Cropd): - the start and end coordinates of the ROI """ - backend = SpatialCrop.backend - def __init__( self, keys: KeysCollection, @@ -409,7 +407,7 @@ def __init__( super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) -class CenterSpatialCropd(MapTransform, InvertibleTransform): +class CenterSpatialCropd(Cropd): """ Dictionary-based wrapper of :py:class:`monai.transforms.CenterSpatialCrop`. If a dimension of the expected ROI size is bigger than the input image size, will not crop that dimension. @@ -427,42 +425,11 @@ class CenterSpatialCropd(MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ - backend = CenterSpatialCrop.backend - def __init__( self, keys: KeysCollection, roi_size: Union[Sequence[int], int], allow_missing_keys: bool = False ) -> None: - super().__init__(keys, allow_missing_keys) - self.cropper = CenterSpatialCrop(roi_size) - - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = dict(data) - for key in self.key_iterator(d): - orig_size = d[key].shape[1:] - d[key] = self.cropper(d[key]) - self.push_transform(d, key, orig_size=orig_size) - return d - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) - current_size = np.array(d[key].shape[1:]) - pad_to_start = np.floor((orig_size - current_size) / 2).astype(int) - # in each direction, if original size is even and current size is odd, += 1 - pad_to_start[np.logical_and(orig_size % 2 == 0, current_size % 2 == 1)] += 1 - pad_to_end = orig_size - current_size - pad_to_start - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - inverse_transform = BorderPad(pad) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d + cropper = CenterSpatialCrop(roi_size) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) class CenterScaleCropd(MapTransform, InvertibleTransform): diff --git a/tests/test_center_spatial_cropd.py b/tests/test_center_spatial_cropd.py index bdbc1a5031..02e100bb92 100644 --- a/tests/test_center_spatial_cropd.py +++ b/tests/test_center_spatial_cropd.py @@ -15,10 +15,10 @@ from parameterized import parameterized from monai.transforms import CenterSpatialCropd -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_SHAPES = [] -for p in TEST_NDARRAYS: +for p in TEST_NDARRAYS_ALL: 
TEST_SHAPES.append( [{"keys": "img", "roi_size": [2, -1, -1]}, {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, (3, 2, 3, 3)] ) @@ -28,7 +28,7 @@ ) TEST_CASES = [] -for p in TEST_NDARRAYS: +for p in TEST_NDARRAYS_ALL: TEST_CASES.append( [ {"keys": "img", "roi_size": [2, 2]}, From d006038e378210c0d208469c3f2393e54cabd041 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 24 Jun 2022 17:01:05 +0800 Subject: [PATCH 25/47] [DLMED] update rand scale crop dict Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 27 ++-- monai/transforms/croppad/dictionary.py | 172 +++---------------------- tests/test_rand_scale_cropd.py | 8 +- tests/test_rand_spatial_cropd.py | 6 +- 4 files changed, 40 insertions(+), 173 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index d6486556d7..5b48a39b0b 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -593,12 +593,14 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, self._size) self._slices = get_random_patch(img_size, valid_size, self.R) - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. + """ - self.randomize(img.shape[1:]) + if randomize: + self.randomize(img.shape[1:]) if self._size is None: raise RuntimeError("self._size not specified.") if self.random_center: @@ -639,19 +641,26 @@ def __init__( self.roi_scale = roi_scale self.max_roi_scale = max_roi_scale - def __call__(self, img: torch.Tensor) -> torch.Tensor: - """ - Apply the transform to `img`, assuming `img` is channel-first and - slicing doesn't apply to the channel dim. - """ - img_size = img.shape[1:] + def get_max_roi_size(self, img_size): ndim = len(img_size) self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] if self.max_roi_scale is not None: self.max_roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.max_roi_scale, ndim), img_size)] else: self.max_roi_size = None - return super().__call__(img=img) + + def randomize(self, img_size: Sequence[int]) -> None: + self.get_max_roi_size(img_size) + super().randomize(img_size) + + def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: + """ + Apply the transform to `img`, assuming `img` is channel-first and + slicing doesn't apply to the channel dim. + + """ + self.get_max_roi_size(img.shape[1:]) + return super().__call__(img=img, randomize=randomize) class RandSpatialCropSamples(Randomizable, Transform): diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index aff73a0a71..81268f247c 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -15,11 +15,10 @@ Class names are ended with 'd' to denote dictionary-based transforms. 
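The new `randomize` argument lets a caller replay the previously sampled ROI instead of drawing a new one, which is exactly what the dictionary wrappers below rely on; a small sketch (shapes assumed):

    import torch
    from monai.transforms import RandSpatialCrop

    cropper = RandSpatialCrop(roi_size=[8, 8], max_roi_size=[12, 12], random_size=True)
    first = cropper(torch.zeros(1, 16, 16))                    # samples a new size/center
    second = cropper(torch.zeros(1, 16, 16), randomize=False)  # reuses the sampled ROI
    assert first.shape == second.shape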
""" -import contextlib from copy import deepcopy from enum import Enum from itertools import chain -from math import ceil, floor +from math import floor from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np @@ -27,10 +26,10 @@ from monai.config import IndexSelection, KeysCollection from monai.config.type_definitions import NdarrayOrTensor -from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.croppad.array import ( BorderPad, BoundingRect, + CenterScaleCrop, CenterSpatialCrop, Crop, CropForeground, @@ -38,6 +37,8 @@ Pad, RandCropByLabelClasses, RandCropByPosNegLabel, + RandScaleCrop, + RandSpatialCrop, ResizeWithPadOrCrop, SpatialCrop, SpatialPad, @@ -338,8 +339,7 @@ class RandCropd(Cropd, Randomizable): backend = Crop.backend def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False): - super().__init__(keys, allow_missing_keys) - self.cropper = cropper + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -355,9 +355,8 @@ def randomize(self, img_size: Sequence[int]) -> None: def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - # only randomize at start + # the first key must exist to execute random operations self.randomize(d[self.first_key(d)].shape[1:]) - for key in self.key_iterator(d): kwargs = {"randomize": False} if isinstance(self.cropper, Randomizable) else {} d[key] = self.cropper(d[key], **kwargs) @@ -432,7 +431,7 @@ def __init__( super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) -class CenterScaleCropd(MapTransform, InvertibleTransform): +class CenterScaleCropd(Cropd): """ Dictionary-based wrapper of :py:class:`monai.transforms.CenterScaleCrop`. Note: as using the same scaled ROI to crop, all the input data specified by `keys` should have @@ -446,54 +445,14 @@ class CenterScaleCropd(MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. 
""" - backend = CenterSpatialCrop.backend - def __init__( self, keys: KeysCollection, roi_scale: Union[Sequence[float], float], allow_missing_keys: bool = False ) -> None: - super().__init__(keys, allow_missing_keys=allow_missing_keys) - self.roi_scale = roi_scale - - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = dict(data) - first_key: Union[Hashable, List] = self.first_key(d) - if first_key == []: - return d - - # use the spatial size of first image to scale, expect all images have the same spatial size - img_size = d[first_key].shape[1:] # type: ignore - ndim = len(img_size) - roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] - cropper = CenterSpatialCrop(roi_size) - for key in self.key_iterator(d): - self.push_transform(d, key, orig_size=img_size) - d[key] = cropper(d[key]) - - return d - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) - current_size = np.array(d[key].shape[1:]) - pad_to_start = np.floor((orig_size - current_size) / 2).astype(int) - # in each direction, if original size is even and current size is odd, += 1 - pad_to_start[np.logical_and(orig_size % 2 == 0, current_size % 2 == 1)] += 1 - pad_to_end = orig_size - current_size - pad_to_start - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - inverse_transform = BorderPad(pad) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d + cropper = CenterScaleCrop(roi_scale) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) -class RandSpatialCropd(Randomizable, MapTransform, InvertibleTransform): +class RandSpatialCropd(RandCropd): """ Dictionary-based version :py:class:`monai.transforms.RandSpatialCrop`. Crop image with random size or specific size ROI. It can crop at a random position as @@ -523,8 +482,6 @@ class RandSpatialCropd(Randomizable, MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. 
""" - backend = CenterSpatialCrop.backend - def __init__( self, keys: KeysCollection, @@ -534,78 +491,11 @@ def __init__( random_size: bool = True, allow_missing_keys: bool = False, ) -> None: - MapTransform.__init__(self, keys, allow_missing_keys) - self.roi_size = roi_size - self.max_roi_size = max_roi_size - self.random_center = random_center - self.random_size = random_size - self._slices: Optional[Tuple[slice, ...]] = None - self._size: Optional[Sequence[int]] = None - - def randomize(self, img_size: Sequence[int]) -> None: - self._size = fall_back_tuple(self.roi_size, img_size) - if self.random_size: - max_size = img_size if self.max_roi_size is None else fall_back_tuple(self.max_roi_size, img_size) - if any(i > j for i, j in zip(self._size, max_size)): - raise ValueError(f"min ROI size: {self._size} is bigger than max ROI size: {max_size}.") - self._size = [self.R.randint(low=self._size[i], high=max_size[i] + 1) for i in range(len(img_size))] - if self.random_center: - valid_size = get_valid_patch_size(img_size, self._size) - self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) - - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = dict(data) - first_key: Union[Hashable, List] = self.first_key(d) - if first_key == []: - return d - - self.randomize(d[first_key].shape[1:]) # type: ignore - if self._size is None: - raise RuntimeError("self._size not specified.") - for key in self.key_iterator(d): - if self.random_center: - self.push_transform(d, key, {"slices": [(i.start, i.stop) for i in self._slices[1:]]}) # type: ignore - d[key] = d[key][self._slices] - else: - self.push_transform(d, key) - cropper = CenterSpatialCrop(self._size) - d[key] = cropper(d[key]) - return d - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = transform[TraceKeys.ORIG_SIZE] - random_center = self.random_center - pad_to_start = np.empty((len(orig_size)), dtype=np.int32) - pad_to_end = np.empty((len(orig_size)), dtype=np.int32) - if random_center: - for i, _slice in enumerate(transform[TraceKeys.EXTRA_INFO]["slices"]): - pad_to_start[i] = _slice[0] - pad_to_end[i] = orig_size[i] - _slice[1] - else: - current_size = d[key].shape[1:] - for i, (o_s, c_s) in enumerate(zip(orig_size, current_size)): - pad_to_start[i] = pad_to_end[i] = (o_s - c_s) / 2 - if o_s % 2 == 0 and c_s % 2 == 1: - pad_to_start[i] += 1 - elif o_s % 2 == 1 and c_s % 2 == 0: - pad_to_end[i] += 1 - # interleave mins and maxes - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - inverse_transform = BorderPad(pad) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d + cropper = RandSpatialCrop(roi_size, max_roi_size, random_center, random_size) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) -class RandScaleCropd(RandSpatialCropd): +class RandScaleCropd(RandCropd): """ Dictionary-based version :py:class:`monai.transforms.RandScaleCrop`. Crop image with random size or specific size ROI. @@ -630,8 +520,6 @@ class RandScaleCropd(RandSpatialCropd): allow_missing_keys: don't raise exception if key is missing. 
""" - backend = RandSpatialCropd.backend - def __init__( self, keys: KeysCollection, @@ -641,38 +529,8 @@ def __init__( random_size: bool = True, allow_missing_keys: bool = False, ) -> None: - super().__init__( - keys=keys, - roi_size=-1, - max_roi_size=None, - random_center=random_center, - random_size=random_size, - allow_missing_keys=allow_missing_keys, - ) - self.roi_scale = roi_scale - self.max_roi_scale = max_roi_scale - - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - first_key: Union[Hashable, List] = self.first_key(data) # type: ignore - if first_key == []: - return data # type: ignore - - img_size = data[first_key].shape[1:] # type: ignore - ndim = len(img_size) - self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] - if self.max_roi_scale is not None: - self.max_roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.max_roi_scale, ndim), img_size)] - else: - self.max_roi_size = None - return super().__call__(data=data) - - -@contextlib.contextmanager -def _nullcontext(x): - """ - This is just like contextlib.nullcontext but also works in Python 3.6. - """ - yield x + cropper = RandScaleCrop(roi_scale, max_roi_scale, random_center, random_size) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) class RandSpatialCropSamplesd(Randomizable, MapTransform, InvertibleTransform): diff --git a/tests/test_rand_scale_cropd.py b/tests/test_rand_scale_cropd.py index 5e833fef98..177cf7fa49 100644 --- a/tests/test_rand_scale_cropd.py +++ b/tests/test_rand_scale_cropd.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import RandScaleCropd -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_CASE_1 = [ {"keys": "img", "roi_scale": [1.0, 1.0, -1.0], "random_center": True}, @@ -25,7 +25,7 @@ TEST_CASE_2 = [ # test `allow_missing_keys` with key "label" - {"keys": ["label", "img"], "roi_scale": [1.0, 1.0, 1.0], "random_center": False, "allow_missing_keys": True}, + {"keys": ["img", "label"], "roi_scale": [1.0, 1.0, 1.0], "random_center": False, "allow_missing_keys": True}, {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, (3, 3, 3, 3), ] @@ -68,11 +68,11 @@ def test_shape(self, input_param, input_data, expected_shape): @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: cropper = RandScaleCropd(**input_param) input_data["img"] = p(input_data["img"]) result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper.cropper._size] assert_allclose( result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False ) diff --git a/tests/test_rand_spatial_cropd.py b/tests/test_rand_spatial_cropd.py index 9e6e86eea2..29d988562f 100644 --- a/tests/test_rand_spatial_cropd.py +++ b/tests/test_rand_spatial_cropd.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import RandSpatialCropd -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL TEST_CASE_0 = [ {"keys": "img", "roi_size": [3, 3, -1], "random_center": True}, @@ -63,12 +63,12 @@ def test_shape(self, input_param, input_data, expected_shape): def test_value(self, input_param, input_data): cropper = RandSpatialCropd(**input_param) result = cropper(input_data) - roi = 
[(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper.cropper._size] np.testing.assert_allclose(result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) def test_random_shape(self, input_param, input_data, expected_shape): - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: cropper = RandSpatialCropd(**input_param) cropper.set_random_state(seed=123) input_data["img"] = p(input_data["img"]) From aadc19d3ae9a6fecf55a46cfc5d9a8f7bf976de5 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 24 Jun 2022 18:04:30 +0800 Subject: [PATCH 26/47] [DLMED] update rand spatial crop samples dict Signed-off-by: Nic Ma --- monai/transforms/croppad/dictionary.py | 73 +++++++----------------- tests/test_rand_spatial_crop_samplesd.py | 41 +++++++------ 2 files changed, 41 insertions(+), 73 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 81268f247c..2de3a63d3d 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -39,6 +39,7 @@ RandCropByPosNegLabel, RandScaleCrop, RandSpatialCrop, + RandSpatialCropSamples, ResizeWithPadOrCrop, SpatialCrop, SpatialPad, @@ -46,7 +47,6 @@ from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utils import ( - allow_missing_keys_mode, generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, is_positive, @@ -54,8 +54,10 @@ map_classes_to_indices, weighted_patch_samples, ) +from monai.utils import MAX_SEED from monai.utils import ImageMetaKey as Key from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple +from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import PostFix, TraceKeys __all__ = [ @@ -533,7 +535,7 @@ def __init__( super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) -class RandSpatialCropSamplesd(Randomizable, MapTransform, InvertibleTransform): +class RandSpatialCropSamplesd(Randomizable, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandSpatialCropSamples`. Crop image with random size or specific size ROI to generate a list of N samples. @@ -562,15 +564,6 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform, InvertibleTransform): random_center: crop at random position as center or the image center. random_size: crop with random size or specific size ROI. The actual size is sampled from `randint(roi_size, img_size)`. - meta_keys: explicitly indicate the key of the corresponding metadata dictionary. - used to add `patch_index` to the meta dict. - for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the metadata is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the metadata according - to the key data, default is `meta_dict`, the metadata is a dictionary object. - used to add `patch_index` to the meta dict. allow_missing_keys: don't raise exception if key is missing. 
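A sketch of the list-returning call after the rewrite: the same sub-seed is replayed for every key so the `img`/`seg` patches stay aligned, and each sample carries its `patch_index` in the metadata (shapes assumed):

    import numpy as np
    from monai.transforms import RandSpatialCropSamplesd

    t = RandSpatialCropSamplesd(keys=["img", "seg"], roi_size=[2, 2, 2], num_samples=4)
    samples = t({"img": np.zeros((1, 3, 3, 3)), "seg": np.zeros((1, 3, 3, 3))})
    assert len(samples) == 4
    print(samples[0]["img"].meta["patch_index"])  # 0, then 1, 2, 3 for the other samples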
Raises: @@ -578,8 +571,8 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform, InvertibleTransform): """ - backend = RandSpatialCropd.backend - + @deprecated_arg(name="meta_keys", since="0.8") + @deprecated_arg(name="meta_key_postfix", since="0.8") def __init__( self, keys: KeysCollection, @@ -593,57 +586,33 @@ def __init__( allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) - if num_samples < 1: - raise ValueError(f"num_samples must be positive, got {num_samples}.") - self.num_samples = num_samples - self.cropper = RandSpatialCropd(keys, roi_size, max_roi_size, random_center, random_size, allow_missing_keys) - self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) - if len(self.keys) != len(self.meta_keys): - raise ValueError("meta_keys should have the same length as keys.") - self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) + self.cropper = RandSpatialCropSamples(roi_size, num_samples, max_roi_size, random_center, random_size) def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None ) -> "RandSpatialCropSamplesd": super().set_random_state(seed, state) - self.cropper.set_random_state(seed, state) return self def randomize(self, data: Optional[Any] = None) -> None: - pass + self.sub_seed = self.R.randint(MAX_SEED, dtype="uint32") def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: - ret = [] - for i in range(self.num_samples): - d = dict(data) - # deep copy all the unmodified data - for key in set(data.keys()).difference(set(self.keys)): - d[key] = deepcopy(data[key]) - cropped = self.cropper(d) - # self.cropper will have added RandSpatialCropd to the list. 
Change to RandSpatialCropSamplesd - for key in self.key_iterator(cropped): - cropped[self.trace_key(key)][-1][TraceKeys.CLASS_NAME] = self.__class__.__name__ # type: ignore - cropped[self.trace_key(key)][-1][TraceKeys.ID] = id(self) # type: ignore - # add `patch_index` to the metadata - for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): - meta_key = meta_key or f"{key}_{meta_key_postfix}" - if meta_key not in cropped: - cropped[meta_key] = {} # type: ignore - cropped[meta_key][Key.PATCH_INDEX] = i # type: ignore - ret.append(cropped) + # output starts as empty list of dictionaries + ret: List[Dict[Hashable, torch.Tensor]] = [{} for _ in range(self.cropper.num_samples)] + # deep copy all the unmodified data + for key in set(data.keys()).difference(set(self.keys)): + for r in ret: + r[key] = deepcopy(data[key]) + + # for each key we reset the random state to ensure crops are the same + self.randomize() + for key in self.key_iterator(dict(data)): + self.cropper.set_random_state(seed=self.sub_seed) + for i, im in enumerate(self.cropper(data[key])): + ret[i][key] = im return ret - def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: - d = deepcopy(dict(data)) - # We changed the transform name from RandSpatialCropd to RandSpatialCropSamplesd - # Need to revert that since we're calling RandSpatialCropd's inverse - for key in self.key_iterator(d): - d[self.trace_key(key)][-1][TraceKeys.CLASS_NAME] = self.cropper.__class__.__name__ - d[self.trace_key(key)][-1][TraceKeys.ID] = id(self.cropper) - context_manager = allow_missing_keys_mode if self.allow_missing_keys else _nullcontext - with context_manager(self.cropper): - return self.cropper.inverse(d) - class CropForegroundd(MapTransform, InvertibleTransform): """ diff --git a/tests/test_rand_spatial_crop_samplesd.py b/tests/test_rand_spatial_crop_samplesd.py index 0891068488..4da438d2a0 100644 --- a/tests/test_rand_spatial_crop_samplesd.py +++ b/tests/test_rand_spatial_crop_samplesd.py @@ -14,46 +14,45 @@ import numpy as np from parameterized import parameterized -from monai.transforms import Compose, RandSpatialCropSamplesd, ToTensord -from monai.utils.enums import PostFix -from tests.utils import TEST_NDARRAYS, assert_allclose +from monai.transforms import Compose, DivisiblePadd, RandSpatialCropSamplesd +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_CASE_1 = [ {"keys": ["img", "seg"], "num_samples": 4, "roi_size": [2, 2, 2], "random_center": True}, {"img": np.arange(81).reshape(3, 3, 3, 3), "seg": np.arange(81, 0, -1).reshape(3, 3, 3, 3)}, - [(3, 3, 3, 2), (3, 2, 2, 2), (3, 3, 3, 2), (3, 3, 2, 2)], + [(3, 2, 2, 2), (3, 2, 3, 3), (3, 2, 3, 2), (3, 2, 3, 2)], { "img": np.array( [ - [[[0, 1], [3, 4]], [[9, 10], [12, 13]], [[18, 19], [21, 22]]], - [[[27, 28], [30, 31]], [[36, 37], [39, 40]], [[45, 46], [48, 49]]], - [[[54, 55], [57, 58]], [[63, 64], [66, 67]], [[72, 73], [75, 76]]], + [[[1, 2], [4, 5], [7, 8]], [[10, 11], [13, 14], [16, 17]]], + [[[28, 29], [31, 32], [34, 35]], [[37, 38], [40, 41], [43, 44]]], + [[[55, 56], [58, 59], [61, 62]], [[64, 65], [67, 68], [70, 71]]], ] ), "seg": np.array( [ - [[[81, 80], [78, 77]], [[72, 71], [69, 68]], [[63, 62], [60, 59]]], - [[[54, 53], [51, 50]], [[45, 44], [42, 41]], [[36, 35], [33, 32]]], - [[[27, 26], [24, 23]], [[18, 17], [15, 14]], [[9, 8], [6, 5]]], + [[[80, 79], [77, 76], [74, 73]], [[71, 70], [68, 67], [65, 64]]], + [[[53, 52], [50, 49], [47, 46]], [[44, 43], [41, 40], [38, 37]]], + [[[26, 25], [23, 22], [20, 
19]], [[17, 16], [14, 13], [11, 10]]], ] ), }, ] TEST_CASE_2 = [] -for p in TEST_NDARRAYS: +for p in TEST_NDARRAYS_ALL: TEST_CASE_2.append( [ {"keys": ["img", "seg"], "num_samples": 8, "roi_size": [2, 2, 3], "random_center": False}, {"img": p(np.arange(81).reshape(3, 3, 3, 3)), "seg": p(np.arange(81, 0, -1).reshape(3, 3, 3, 3))}, [ - (3, 3, 3, 3), - (3, 2, 3, 3), (3, 2, 2, 3), - (3, 2, 3, 3), + (3, 2, 2, 3), (3, 3, 3, 3), + (3, 2, 3, 3), (3, 3, 3, 3), - (3, 2, 2, 3), + (3, 2, 3, 3), + (3, 2, 3, 3), (3, 3, 2, 3), ], { @@ -90,10 +89,10 @@ def test_shape(self, input_param, input_data, expected_shape, expected_last): self.assertTupleEqual(item["img"].shape, expected) self.assertTupleEqual(item["seg"].shape, expected) for i, item in enumerate(result): - self.assertEqual(item[PostFix.meta("img")]["patch_index"], i) - self.assertEqual(item[PostFix.meta("seg")]["patch_index"], i) - assert_allclose(item["img"], expected_last["img"], type_test=True) - assert_allclose(item["seg"], expected_last["seg"], type_test=True) + self.assertEqual(item["img"].meta["patch_index"], i) + self.assertEqual(item["seg"].meta["patch_index"], i) + assert_allclose(item["img"], expected_last["img"], type_test=False) + assert_allclose(item["seg"], expected_last["seg"], type_test=False) def test_deep_copy(self): data = {"img": np.ones((1, 10, 11, 12))} @@ -101,11 +100,11 @@ def test_deep_copy(self): sampler = RandSpatialCropSamplesd( keys=["img"], roi_size=(3, 3, 3), num_samples=num_samples, random_center=True, random_size=False ) - transform = Compose([ToTensord(keys="img"), sampler]) + transform = Compose([DivisiblePadd(keys="img", k=5), sampler]) samples = transform(data) self.assertEqual(len(samples), num_samples) for sample in samples: - self.assertEqual(len(sample["img_transforms"]), len(transform)) + self.assertEqual(len(sample["img"].applied_operations), len(transform)) if __name__ == "__main__": From 4b4548b5791538149fdc6ad79eda9f3905fc2259 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 24 Jun 2022 18:13:12 +0800 Subject: [PATCH 27/47] [DLMED] update crop foreground dict Signed-off-by: Nic Ma --- monai/transforms/croppad/dictionary.py | 40 ++++---------------------- tests/test_crop_foregroundd.py | 11 +++---- 2 files changed, 12 insertions(+), 39 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 2de3a63d3d..08e199b1a7 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -570,6 +570,7 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform): ValueError: When ``num_samples`` is nonpositive. """ + backend = RandSpatialCropSamples.backend @deprecated_arg(name="meta_keys", since="0.8") @deprecated_arg(name="meta_key_postfix", since="0.8") @@ -614,7 +615,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashab return ret -class CropForegroundd(MapTransform, InvertibleTransform): +class CropForegroundd(Cropd): """ Dictionary-based version :py:class:`monai.transforms.CropForeground`. Crop only the foreground object of the expected images. @@ -627,8 +628,6 @@ class CropForegroundd(MapTransform, InvertibleTransform): channels. And it can also add margin to every dim of the bounding box of foreground object. 
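A usage sketch of the slimmed-down wrapper (the coordinate key names match the defaults shown earlier in this series; data shapes are assumptions):

    import numpy as np
    from monai.transforms import CropForegroundd

    data = {"image": np.zeros((1, 8, 8)), "label": np.zeros((1, 8, 8))}
    data["label"][0, 2:6, 2:6] = 1
    t = CropForegroundd(keys=["image", "label"], source_key="label", margin=1)
    out = t(data)  # both keys cropped to the label bounding box plus the margin
    print(out["foreground_start_coord"], out["foreground_end_coord"])  # recorded box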
""" - backend = CropForeground.backend - def __init__( self, keys: KeysCollection, @@ -676,7 +675,7 @@ def __init__( self.source_key = source_key self.start_coord_key = start_coord_key self.end_coord_key = end_coord_key - self.cropper = CropForeground( + cropper = CropForeground( select_fn=select_fn, channel_indices=channel_indices, margin=margin, @@ -684,46 +683,19 @@ def __init__( k_divisible=k_divisible, **pad_kwargs, ) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) + self.mode = ensure_tuple_rep(mode, len(self.keys)) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key]) d[self.start_coord_key] = box_start d[self.end_coord_key] = box_end for key, m in self.key_iterator(d, self.mode): - self.push_transform(d, key, extra_info={"box_start": box_start, "box_end": box_end}) d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) - cur_size = np.asarray(d[key].shape[1:]) - extra_info = transform[TraceKeys.EXTRA_INFO] - box_start = np.asarray(extra_info["box_start"]) - box_end = np.asarray(extra_info["box_end"]) - # first crop the padding part - roi_start = np.maximum(-box_start, 0) - roi_end = cur_size - np.maximum(box_end - orig_size, 0) - - d[key] = SpatialCrop(roi_start=roi_start, roi_end=roi_end)(d[key]) - - # update bounding box to pad - pad_to_start = np.maximum(box_start, 0) - pad_to_end = orig_size - np.minimum(box_end, orig_size) - # interleave mins and maxes - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - # second pad back the original size - d[key] = BorderPad(pad)(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d - class RandWeightedCropd(Randomizable, MapTransform, InvertibleTransform): """ diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py index fa69143827..1bd08e13e9 100644 --- a/tests/test_crop_foregroundd.py +++ b/tests/test_crop_foregroundd.py @@ -15,11 +15,12 @@ import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import CropForegroundd -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_POSITION, TESTS = [], [] -for p in TEST_NDARRAYS: +for p in TEST_NDARRAYS_ALL: TEST_POSITION.append( [ @@ -153,10 +154,10 @@ class TestCropForegroundd(unittest.TestCase): def test_value(self, argments, input_data, expected_data): result = CropForegroundd(**argments)(input_data) r, i = result["img"], input_data["img"] - self.assertEqual(type(r), type(i)) - if isinstance(r, torch.Tensor): + self.assertEqual(type(r), MetaTensor) + if isinstance(i, torch.Tensor): self.assertEqual(r.device, i.device) - assert_allclose(r, expected_data) + assert_allclose(r, expected_data, type_test=False) @parameterized.expand(TEST_POSITION) def test_foreground_position(self, argments, input_data, _): From 2027dfb19916df4cd5bfdd98cf67b537061753a0 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 24 Jun 
2022 18:44:01 +0800 Subject: [PATCH 28/47] [DLMED] update rand weighted crop dict Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 9 ++- monai/transforms/croppad/dictionary.py | 106 +++++++------------------ tests/test_rand_weighted_cropd.py | 67 ++++++++-------- 3 files changed, 67 insertions(+), 115 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 5b48a39b0b..6d99c39297 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -903,13 +903,16 @@ def randomize(self, weight_map: NdarrayOrTensor) -> None: spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R ) # using only the first channel as weight map - def __call__(self, img: torch.Tensor, weight_map: Optional[NdarrayOrTensor] = None) -> List[torch.Tensor]: + def __call__( + self, img: torch.Tensor, weight_map: Optional[NdarrayOrTensor] = None, randomize: bool = True, + ) -> List[torch.Tensor]: """ Args: img: input image to sample patches from. assuming `img` is a channel-first array. weight_map: weight map used to generate patch samples. The weights must be non-negative. Each element denotes a sampling weight of the spatial location. 0 indicates no sampling. It should be a single-channel array in shape, for example, `(1, spatial_dim_0, spatial_dim_1, ...)` + randomize: whether to execute random operations, default to `True`. Returns: A list of image patches @@ -921,13 +924,15 @@ def __call__(self, img: torch.Tensor, weight_map: Optional[NdarrayOrTensor] = No if img.shape[1:] != weight_map.shape[1:]: raise ValueError(f"image and weight map spatial shape mismatch: {img.shape[1:]} vs {weight_map.shape[1:]}.") - self.randomize(weight_map) + if randomize: + self.randomize(weight_map) _spatial_size = fall_back_tuple(self.spatial_size, weight_map.shape[1:]) results: List[NdarrayOrTensor] = [] for i, center in enumerate(self.centers): cropped = SpatialCrop(roi_center=center, roi_size=_spatial_size)(img) if get_track_meta(): cropped.meta[Key.PATCH_INDEX] = i + cropped.meta["crop_center"] = center results.append(cropped) return results diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 08e199b1a7..cef360e13c 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -40,6 +40,7 @@ RandScaleCrop, RandSpatialCrop, RandSpatialCropSamples, + RandWeightedCrop, ResizeWithPadOrCrop, SpatialCrop, SpatialPad, @@ -52,7 +53,6 @@ is_positive, map_binary_to_indices, map_classes_to_indices, - weighted_patch_samples, ) from monai.utils import MAX_SEED from monai.utils import ImageMetaKey as Key @@ -697,7 +697,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d -class RandWeightedCropd(Randomizable, MapTransform, InvertibleTransform): +class RandWeightedCropd(Randomizable, MapTransform): """ Samples a list of `num_samples` image patches according to the provided `weight_map`. @@ -709,16 +709,6 @@ class RandWeightedCropd(Randomizable, MapTransform, InvertibleTransform): spatial_size: the spatial size of the image patch e.g. [224, 224, 128]. If its components have non-positive values, the corresponding size of `img` will be used. num_samples: number of samples (image patches) to take in the returned list. - center_coord_key: if specified, the actual sampling location will be stored with the corresponding key. - meta_keys: explicitly indicate the key of the corresponding metadata dictionary.
- used to add `patch_index` to the meta dict. - for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the metadata is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the metadata according - to the key data, default is `meta_dict`, the metadata is a dictionary object. - used to add `patch_index` to the meta dict. allow_missing_keys: don't raise exception if key is missing. See Also: @@ -727,6 +717,9 @@ class RandWeightedCropd(Randomizable, MapTransform, InvertibleTransform): backend = SpatialCrop.backend + @deprecated_arg(name="meta_keys", since="0.8") + @deprecated_arg(name="meta_key_postfix", since="0.8") + @deprecated_arg(name="center_coord_key", since="0.8", msg_suffix="coords stored in img.meta['crop_center']") def __init__( self, keys: KeysCollection, @@ -739,78 +732,33 @@ def __init__( allow_missing_keys: bool = False, ): MapTransform.__init__(self, keys, allow_missing_keys) - self.spatial_size = ensure_tuple(spatial_size) self.w_key = w_key - self.num_samples = int(num_samples) - self.center_coord_key = center_coord_key - self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) - if len(self.keys) != len(self.meta_keys): - raise ValueError("meta_keys should have the same length as keys.") - self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self.centers: List[np.ndarray] = [] - - def randomize(self, weight_map: NdarrayOrTensor) -> None: - self.centers = weighted_patch_samples( - spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R - ) + self.cropper = RandWeightedCrop(spatial_size, num_samples) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: - d = dict(data) - self.randomize(d[self.w_key]) - _spatial_size = fall_back_tuple(self.spatial_size, d[self.w_key].shape[1:]) - - # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(data) for _ in range(self.num_samples)] - # fill in the extra keys with unmodified data - for i in range(self.num_samples): - for key in set(data.keys()).difference(set(self.keys)): - results[i][key] = deepcopy(data[key]) - for key in self.key_iterator(d): - img = d[key] - if img.shape[1:] != d[self.w_key].shape[1:]: - raise ValueError( - f"data {key} and weight map {self.w_key} spatial shape mismatch: " - f"{img.shape[1:]} vs {d[self.w_key].shape[1:]}." 
- ) - for i, center in enumerate(self.centers): - cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) - orig_size = img.shape[1:] - results[i][key] = cropper(img) - self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) - if self.center_coord_key: - results[i][self.center_coord_key] = center - # fill in the extra keys with unmodified data - for i in range(self.num_samples): - # add `patch_index` to the metadata - for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): - meta_key = meta_key or f"{key}_{meta_key_postfix}" - if meta_key not in results[i]: - results[i][meta_key] = {} # type: ignore - results[i][meta_key][Key.PATCH_INDEX] = i # type: ignore + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandCropd": + super().set_random_state(seed, state) + if isinstance(self.cropper, Randomizable): + self.cropper.set_random_state(seed, state) + return self - return results + def randomize(self, weight_map: NdarrayOrTensor) -> None: + self.cropper.randomize(weight_map) - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) - current_size = np.asarray(d[key].shape[1:]) - center = transform[TraceKeys.EXTRA_INFO]["center"] - cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size) - # get required pad to start and end - pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) - pad_to_end = orig_size - current_size - pad_to_start - # interleave mins and maxes - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - inverse_transform = BorderPad(pad) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, torch.Tensor]]: # type: ignore + # output starts as empty list of dictionaries + ret: List = [{} for _ in range(self.cropper.num_samples)] + # deep copy all the unmodified data + for key in set(data.keys()).difference(set(self.keys)): + for r in ret: + r[key] = deepcopy(data[key]) - return d + self.randomize(weight_map=data[self.w_key]) + for key in self.key_iterator(data): + for i, im in enumerate(self.cropper(data[key], weight_map=data[self.w_key], randomize=False)): + ret[i][key] = im + return ret class RandCropByPosNegLabeld(Randomizable, MapTransform, InvertibleTransform): diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py index a357398f1c..9ef91fd177 100644 --- a/tests/test_rand_weighted_cropd.py +++ b/tests/test_rand_weighted_cropd.py @@ -14,14 +14,13 @@ import numpy as np from monai.transforms.croppad.dictionary import RandWeightedCropd -from monai.utils.enums import PostFix -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose +from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose class TestRandWeightedCrop(NumpyImageTestCase2D): def test_rand_weighted_crop_small_roi(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: + for q in TEST_NDARRAYS_ALL: img = self.seg1[0] n_samples = 3 crop = RandWeightedCropd("img", "w", (10, 12), 
n_samples) @@ -34,12 +33,12 @@ def test_rand_weighted_crop_small_roi(self): result = crop(d) self.assertTrue(len(result) == n_samples) np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 12)) - for c, e in zip(crop.centers, [[80, 21], [30, 17], [40, 31]]): + for c, e in zip(crop.cropper.centers, [[80, 21], [30, 17], [40, 31]]): assert_allclose(c, e, type_test=False) def test_rand_weighted_crop_default_roi(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: + for q in TEST_NDARRAYS_ALL: img = self.imt[0] n_samples = 3 crop = RandWeightedCropd("im", "weight", (10, -1), n_samples, "coords") @@ -52,13 +51,13 @@ def test_rand_weighted_crop_default_roi(self): result = crop(data) self.assertTrue(len(result) == n_samples) np.testing.assert_allclose(result[0]["im"].shape, (1, 10, 64)) - for c, e in zip(crop.centers, [[14, 32], [105, 32], [20, 32]]): + for c, e in zip(crop.cropper.centers, [[14, 32], [105, 32], [20, 32]]): assert_allclose(c, e, type_test=False) - assert_allclose(result[1]["coords"], [105, 32], type_test=False) + assert_allclose(result[1]["im"].meta["crop_center"], [105, 32], type_test=False) def test_rand_weighted_crop_large_roi(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: + for q in TEST_NDARRAYS_ALL: img = self.segn[0] n_samples = 3 crop = RandWeightedCropd(("img", "seg"), "weight", (10000, 400), n_samples, "location") @@ -71,13 +70,13 @@ def test_rand_weighted_crop_large_roi(self): self.assertTrue(len(result) == n_samples) np.testing.assert_allclose(result[0]["img"].shape, (1, 128, 64)) np.testing.assert_allclose(result[0]["seg"].shape, (1, 128, 64)) - for c, e in zip(crop.centers, [[64, 32], [64, 32], [64, 32]]): + for c, e in zip(crop.cropper.centers, [[64, 32], [64, 32], [64, 32]]): assert_allclose(c, e, type_test=False) - assert_allclose(result[1]["location"], [64, 32], type_test=False) + assert_allclose(result[1]["img"].meta["crop_center"], [64, 32], type_test=False) def test_rand_weighted_crop_bad_w(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: + for q in TEST_NDARRAYS_ALL: img = self.imt[0] n_samples = 3 crop = RandWeightedCropd(("img", "seg"), "w", (20, 40), n_samples) @@ -90,14 +89,14 @@ def test_rand_weighted_crop_bad_w(self): self.assertTrue(len(result) == n_samples) np.testing.assert_allclose(result[0]["img"].shape, (1, 20, 40)) np.testing.assert_allclose(result[0]["seg"].shape, (1, 20, 40)) - for c, e in zip(crop.centers, [[63, 37], [31, 43], [66, 20]]): + for c, e in zip(crop.cropper.centers, [[63, 37], [31, 43], [66, 20]]): assert_allclose(c, e, type_test=False) class TestRandWeightedCrop3D(NumpyImageTestCase3D): def test_rand_weighted_crop_small_roi(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: + for q in TEST_NDARRAYS_ALL: img = self.seg1[0] n_samples = 3 crop = RandWeightedCropd("img", "w", (8, 10, 12), n_samples) @@ -109,12 +108,12 @@ def test_rand_weighted_crop_small_roi(self): result = crop({"img": p(img), "w": q(weight)}) self.assertTrue(len(result) == n_samples) np.testing.assert_allclose(result[0]["img"].shape, (1, 8, 10, 12)) - for c, e in zip(crop.centers, [[11, 23, 21], [5, 30, 17], [8, 40, 31]]): + for c, e in zip(crop.cropper.centers, [[11, 23, 21], [5, 30, 17], [8, 40, 31]]): assert_allclose(c, e, type_test=False) def test_rand_weighted_crop_default_roi(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: + for q in TEST_NDARRAYS_ALL: img = 
self.imt[0] n_samples = 3 crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) @@ -127,12 +126,12 @@ def test_rand_weighted_crop_default_roi(self): self.assertTrue(len(result) == n_samples) np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 64, 80)) np.testing.assert_allclose(result[0]["seg"].shape, (1, 10, 64, 80)) - for c, e in zip(crop.centers, [[14, 32, 40], [41, 32, 40], [20, 32, 40]]): + for c, e in zip(crop.cropper.centers, [[14, 32, 40], [41, 32, 40], [20, 32, 40]]): assert_allclose(c, e, type_test=False) def test_rand_weighted_crop_large_roi(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: + for q in TEST_NDARRAYS_ALL: img = self.segn[0] n_samples = 3 crop = RandWeightedCropd("img", "w", (10000, 400, 80), n_samples) @@ -143,12 +142,12 @@ def test_rand_weighted_crop_large_roi(self): result = crop({"img": p(img), "w": q(weight)}) self.assertTrue(len(result) == n_samples) np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) - for c, e in zip(crop.centers, [[24, 32, 40], [24, 32, 40], [24, 32, 40]]): + for c, e in zip(crop.cropper.centers, [[24, 32, 40], [24, 32, 40], [24, 32, 40]]): assert_allclose(c, e, type_test=False) def test_rand_weighted_crop_bad_w(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: + for q in TEST_NDARRAYS_ALL: img = self.imt[0] n_samples = 3 crop = RandWeightedCropd(("img", "seg"), "w", (48, 64, 80), n_samples) @@ -161,12 +160,12 @@ def test_rand_weighted_crop_bad_w(self): self.assertTrue(len(result) == n_samples) np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) np.testing.assert_allclose(result[0]["seg"].shape, (1, 48, 64, 80)) - for c, e in zip(crop.centers, [[24, 32, 40], [24, 32, 40], [24, 32, 40]]): + for c, e in zip(crop.cropper.centers, [[24, 32, 40], [24, 32, 40], [24, 32, 40]]): assert_allclose(c, e, type_test=False) def test_rand_weighted_crop_patch_index(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: + for q in TEST_NDARRAYS_ALL: img = self.imt[0] n_samples = 3 crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) @@ -176,16 +175,16 @@ def test_rand_weighted_crop_patch_index(self): weight[0, 24, 21] = 1 crop.set_random_state(10) result = crop( - {"img": p(img), "seg": p(self.segn[0]), "w": q(weight), PostFix.meta("img"): {"affine": None}} + {"img": p(img), "seg": p(self.segn[0]), "w": q(weight)} ) self.assertTrue(len(result) == n_samples) - for c, e in zip(crop.centers, [[14, 32, 40], [41, 32, 40], [20, 32, 40]]): + for c, e in zip(crop.cropper.centers, [[14, 32, 40], [41, 32, 40], [20, 32, 40]]): assert_allclose(c, e, type_test=False) for i in range(n_samples): np.testing.assert_allclose(result[i]["img"].shape, (1, 10, 64, 80)) np.testing.assert_allclose(result[i]["seg"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[i][PostFix.meta("img")]["patch_index"], i) - np.testing.assert_allclose(result[i][PostFix.meta("seg")]["patch_index"], i) + np.testing.assert_allclose(result[i]["img"].meta["patch_index"], i) + np.testing.assert_allclose(result[i]["seg"].meta["patch_index"], i) if __name__ == "__main__": From 30edb96e5979146108d40b655c5de52db02eb3fd Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 24 Jun 2022 19:14:45 +0800 Subject: [PATCH 29/47] [DLMED] update pos neg crop dict Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 6 +- monai/transforms/croppad/dictionary.py | 135 ++++++---------------- 
tests/test_rand_crop_by_pos_neg_labeld.py | 13 +-- 3 files changed, 43 insertions(+), 111 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 6d99c39297..8ca55b588d 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -1056,6 +1056,7 @@ def __call__( image: Optional[torch.Tensor] = None, fg_indices: Optional[NdarrayOrTensor] = None, bg_indices: Optional[NdarrayOrTensor] = None, + randomize: bool = True, ) -> List[torch.Tensor]: """ Args: @@ -1069,6 +1070,7 @@ def __call__( need to provide `fg_indices` and `bg_indices` together. bg_indices: background indices to randomly select crop centers, need to provide `fg_indices` and `bg_indices` together. + randomize: whether to execute the random operations, default to `True`. """ if label is None: @@ -1078,7 +1080,8 @@ def __call__( if image is None: image = self.image - self.randomize(label, fg_indices, bg_indices, image) + if randomize: + self.randomize(label, fg_indices, bg_indices, image) results: List[torch.Tensor] = [] if self.centers is not None: for i, center in enumerate(self.centers): @@ -1087,7 +1090,6 @@ def __call__( if get_track_meta(): cropped.meta[Key.PATCH_INDEX] = i results.append(cropped) - return results diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index cef360e13c..60578c4316 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -19,7 +19,7 @@ from enum import Enum from itertools import chain from math import floor -from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Union import numpy as np import torch @@ -49,9 +49,7 @@ from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utils import ( generate_label_classes_crop_centers, - generate_pos_neg_label_crop_centers, is_positive, - map_binary_to_indices, map_classes_to_indices, ) from monai.utils import MAX_SEED @@ -572,8 +570,8 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform): """ backend = RandSpatialCropSamples.backend - @deprecated_arg(name="meta_keys", since="0.8") - @deprecated_arg(name="meta_key_postfix", since="0.8") + @deprecated_arg(name="meta_keys", since="0.9") + @deprecated_arg(name="meta_key_postfix", since="0.9") def __init__( self, keys: KeysCollection, @@ -589,12 +587,6 @@ def __init__( MapTransform.__init__(self, keys, allow_missing_keys) self.cropper = RandSpatialCropSamples(roi_size, num_samples, max_roi_size, random_center, random_size) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandSpatialCropSamplesd": - super().set_random_state(seed, state) - return self - def randomize(self, data: Optional[Any] = None) -> None: self.sub_seed = self.R.randint(MAX_SEED, dtype="uint32") @@ -717,9 +709,9 @@ class RandWeightedCropd(Randomizable, MapTransform): backend = SpatialCrop.backend - @deprecated_arg(name="meta_keys", since="0.8") - @deprecated_arg(name="meta_key_postfix", since="0.8") - @deprecated_arg(name="center_coord_key", since="0.8", msg_suffix="coords stored in img.meta['crop_center']") + @deprecated_arg(name="meta_keys", since="0.9") + @deprecated_arg(name="meta_key_postfix", since="0.9") + @deprecated_arg(name="center_coord_key", since="0.9", msg_suffix="coords stored in img.meta['crop_center']") def __init__( self, keys: KeysCollection, 
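A note on the set_random_state plumbing introduced in these hunks: seeding the dictionary wrapper must also seed the wrapped array transform, otherwise the two would sample different centers. A rough sketch of that contract, not part of the patch, using the signatures shown above with toy shapes:

import torch
from monai.transforms import RandWeightedCrop, RandWeightedCropd

img = torch.rand(1, 64, 64)
weight = torch.zeros(1, 64, 64)
weight[0, 30, 30] = 1.0  # concentrate sampling near one voxel

dict_crop = RandWeightedCropd(keys="img", w_key="w", spatial_size=(8, 8), num_samples=2)
dict_crop.set_random_state(0)  # should propagate to dict_crop.cropper
arr_crop = RandWeightedCrop(spatial_size=(8, 8), num_samples=2)
arr_crop.set_random_state(0)

from_dict = dict_crop({"img": img, "w": weight})
from_array = arr_crop(img, weight_map=weight)
# same seed and weight map -> the same sampling centers either way
print(dict_crop.cropper.centers, arr_crop.centers)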
@@ -739,8 +731,7 @@ def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None ) -> "RandCropd": super().set_random_state(seed, state) - if isinstance(self.cropper, Randomizable): - self.cropper.set_random_state(seed, state) + self.cropper.set_random_state(seed, state) return self def randomize(self, weight_map: NdarrayOrTensor) -> None: @@ -761,7 +752,9 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, return ret -class RandCropByPosNegLabeld(Randomizable, MapTransform, InvertibleTransform): +@deprecated_arg(name="meta_keys", since="0.9") +@deprecated_arg(name="meta_key_postfix", since="0.9") +class RandCropByPosNegLabeld(Randomizable, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandCropByPosNegLabel`. Crop random fixed sized regions with the center being a foreground or background voxel @@ -802,15 +795,6 @@ class RandCropByPosNegLabeld(Randomizable, MapTransform, InvertibleTransform): `image_threshold`, and randomly select crop centers based on them, need to provide `fg_indices_key` and `bg_indices_key` together, expect to be 1 dim array of spatial indices after flattening. a typical usage is to call `FgBgToIndicesd` transform first and cache the results. - meta_keys: explicitly indicate the key of the corresponding metadata dictionary. - used to add `patch_index` to the meta dict. - for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the metadata is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the metadata according - to the key data, default is `meta_dict`, the metadata is a dictionary object. - used to add `patch_index` to the meta dict. allow_smaller: if `False`, an exception will be raised if the image is smaller than the requested ROI in any dimension. If `True`, any smaller dimensions will be set to match the cropped size (i.e., no cropping in that dimension). 
@@ -843,23 +827,24 @@ def __init__( ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) self.label_key = label_key - self.spatial_size: Union[Tuple[int, ...], Sequence[int], int] = spatial_size - if pos < 0 or neg < 0: - raise ValueError(f"pos and neg must be nonnegative, got pos={pos} neg={neg}.") - if pos + neg == 0: - raise ValueError("Incompatible values: pos=0 and neg=0.") - self.pos_ratio = pos / (pos + neg) - self.num_samples = num_samples self.image_key = image_key - self.image_threshold = image_threshold self.fg_indices_key = fg_indices_key self.bg_indices_key = bg_indices_key - self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) - if len(self.keys) != len(self.meta_keys): - raise ValueError("meta_keys should have the same length as keys.") - self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self.centers: Optional[List[List[int]]] = None - self.allow_smaller = allow_smaller + self.cropper = RandCropByPosNegLabel( + spatial_size=spatial_size, + pos=pos, + neg=neg, + num_samples=num_samples, + image_threshold=image_threshold, + allow_smaller=allow_smaller + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandCropd": + super().set_random_state(seed, state) + self.cropper.set_random_state(seed, state) + return self def randomize( self, @@ -868,21 +853,7 @@ def randomize( bg_indices: Optional[NdarrayOrTensor] = None, image: Optional[NdarrayOrTensor] = None, ) -> None: - if fg_indices is None or bg_indices is None: - fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) - else: - fg_indices_ = fg_indices - bg_indices_ = bg_indices - self.centers = generate_pos_neg_label_crop_centers( - self.spatial_size, - self.num_samples, - self.pos_ratio, - label.shape[1:], - fg_indices_, - bg_indices_, - self.R, - self.allow_smaller, - ) + self.cropper.randomize(label=label, fg_indices=fg_indices, bg_indices=bg_indices, image=image) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: d = dict(data) @@ -892,54 +863,18 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashab bg_indices = d.pop(self.bg_indices_key, None) if self.bg_indices_key is not None else None self.randomize(label, fg_indices, bg_indices, image) - if self.centers is None: - raise ValueError("no available ROI centers to crop.") # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(self.num_samples)] - - for i, center in enumerate(self.centers): - # fill in the extra keys with unmodified data - for key in set(d.keys()).difference(set(self.keys)): - results[i][key] = deepcopy(d[key]) - for key in self.key_iterator(d): - img = d[key] - orig_size = img.shape[1:] - roi_size = fall_back_tuple(self.spatial_size, default=orig_size) - cropper = SpatialCrop(roi_center=tuple(center), roi_size=roi_size) - results[i][key] = cropper(img) - self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) - # add `patch_index` to the metadata - for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): - meta_key = meta_key or f"{key}_{meta_key_postfix}" - if meta_key not in results[i]: - results[i][meta_key] = {} # type: ignore - results[i][meta_key][Key.PATCH_INDEX] = i # type: ignore - - return results - - def inverse(self, 
data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) - current_size = np.asarray(d[key].shape[1:]) - center = transform[TraceKeys.EXTRA_INFO]["center"] - roi_size = fall_back_tuple(self.spatial_size, default=orig_size) - cropper = SpatialCrop(roi_center=tuple(center), roi_size=roi_size) # type: ignore - # get required pad to start and end - pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) - pad_to_end = orig_size - current_size - pad_to_start - # interleave mins and maxes - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - inverse_transform = BorderPad(pad) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) + ret: List = [{} for _ in range(self.cropper.num_samples)] + # deep copy all the unmodified data + for key in set(data.keys()).difference(set(self.keys)): + for r in ret: + r[key] = deepcopy(data[key]) - return d + for key in self.key_iterator(data): + for i, im in enumerate(self.cropper(data[key], label=label, randomize=False)): + ret[i][key] = im + return ret class RandCropByLabelClassesd(Randomizable, MapTransform, InvertibleTransform): diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index a2808bd65d..9b9c3636c2 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -17,7 +17,7 @@ from monai.transforms import RandCropByPosNegLabeld from monai.utils.enums import PostFix -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL TESTS = [ [ @@ -35,7 +35,6 @@ "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]), "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - PostFix.meta("image"): {"affine": np.eye(3), "shape": "CHWD"}, }, (3, 3, 2, 2), ], @@ -54,7 +53,6 @@ "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]), "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - PostFix.meta("label"): {"affine": np.eye(3), "shape": "CHWD"}, }, (3, 2, 2, 2), ], @@ -73,7 +71,6 @@ "image": np.zeros([3, 3, 3, 3]) - 1, "extra": np.zeros([3, 3, 3, 3]), "label": np.ones([3, 3, 3, 3]), - PostFix.meta("extra"): {"affine": np.eye(3), "shape": "CHWD"}, }, (3, 2, 2, 2), ], @@ -93,7 +90,6 @@ "image": np.zeros([3, 3, 3, 3]) - 1, "extra": np.zeros([3, 3, 3, 3]), "label": np.ones([3, 3, 3, 3]), - PostFix.meta("extra"): {"affine": np.eye(3), "shape": "CHWD"}, }, (3, 3, 3, 2), ], @@ -113,7 +109,6 @@ "image": np.zeros([3, 3, 3, 3]) - 1, "extra": np.zeros([3, 3, 3, 3]), "label": np.ones([3, 3, 3, 3]), - PostFix.meta("extra"): {"affine": np.eye(3), "shape": "CHWD"}, }, (3, 3, 3, 3), ], @@ -131,13 +126,13 @@ def convert_data_type(im_type, d, keys=("img", "image", "label")): @parameterized.expand(TESTS) def test_type_shape(self, input_param, input_data, expected_shape): - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: input_param_mod = self.convert_data_type(p, input_param) input_data_mod = self.convert_data_type(p, input_data) cropper = RandCropByPosNegLabeld(**input_param_mod) cropper.set_random_state(0) result = cropper(input_data_mod) - self.assertListEqual(cropper.spatial_size, input_param["spatial_size"]) + 
self.assertListEqual(cropper.cropper.spatial_size, input_param["spatial_size"]) self.assertIsInstance(result, list) @@ -146,7 +141,7 @@ def test_type_shape(self, input_param, input_data, expected_shape): for k in ("image", "extra", "label"): self.assertTupleEqual(result[0][k].shape, expected_shape) for i, item in enumerate(result): - self.assertEqual(item[PostFix.meta(k)]["patch_index"], i) + self.assertEqual(item[k].meta["patch_index"], i) def test_correct_center(self): cropper = RandCropByPosNegLabeld(keys="label", label_key="label", spatial_size=[3, 3]) From c261745494964a703a7dfb9c67530f229ee329fc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 24 Jun 2022 11:15:32 +0000 Subject: [PATCH 30/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/croppad/dictionary.py | 2 +- tests/test_rand_crop_by_pos_neg_labeld.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 60578c4316..342de3ba73 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -853,7 +853,7 @@ def randomize( bg_indices: Optional[NdarrayOrTensor] = None, image: Optional[NdarrayOrTensor] = None, ) -> None: - self.cropper.randomize(label=label, fg_indices=fg_indices, bg_indices=bg_indices, image=image) + self.cropper.randomize(label=label, fg_indices=fg_indices, bg_indices=bg_indices, image=image) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: d = dict(data) diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index 9b9c3636c2..a8feb53b0f 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -16,7 +16,6 @@ from parameterized import parameterized from monai.transforms import RandCropByPosNegLabeld -from monai.utils.enums import PostFix from tests.utils import TEST_NDARRAYS_ALL TESTS = [ From ffe858c43d1a3bb19fcfa8981cff87a4d9f0da89 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 24 Jun 2022 19:37:28 +0800 Subject: [PATCH 31/47] [DLMED] update crop by labels dict Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 5 +- monai/transforms/croppad/dictionary.py | 128 +++++++--------------- tests/test_rand_crop_by_label_classesd.py | 4 +- tests/test_rand_crop_by_pos_neg_labeld.py | 1 - 4 files changed, 45 insertions(+), 93 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 8ca55b588d..a407463eb7 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -1206,6 +1206,7 @@ def __call__( label: Optional[torch.Tensor] = None, image: Optional[torch.Tensor] = None, indices: Optional[List[NdarrayOrTensor]] = None, + randomize: bool = True, ) -> List[torch.Tensor]: """ Args: @@ -1215,6 +1216,7 @@ def __call__( image: optional image data to help select valid area, can be same as `img` or another image array. use ``image > image_threshold`` to select the centers only in valid region. if None, use `self.image`. indices: list of indices for every class in the image, used to randomly select crop centers. + randomize: whether to execute the random operations, default to `True`. 
""" if label is None: @@ -1224,7 +1226,8 @@ def __call__( if image is None: image = self.image - self.randomize(label, indices, image) + if randomize: + self.randomize(label, indices, image) results: List[NdarrayOrTensor] = [] if self.centers is not None: for i, center in enumerate(self.centers): diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 60578c4316..1008e4bc85 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -47,14 +47,9 @@ ) from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable -from monai.transforms.utils import ( - generate_label_classes_crop_centers, - is_positive, - map_classes_to_indices, -) +from monai.transforms.utils import is_positive from monai.utils import MAX_SEED -from monai.utils import ImageMetaKey as Key -from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple +from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple_rep from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import PostFix, TraceKeys @@ -729,7 +724,7 @@ def __init__( def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandCropd": + ) -> "RandWeightedCropd": super().set_random_state(seed, state) self.cropper.set_random_state(seed, state) return self @@ -836,26 +831,26 @@ def __init__( neg=neg, num_samples=num_samples, image_threshold=image_threshold, - allow_smaller=allow_smaller + allow_smaller=allow_smaller, ) def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandCropd": + ) -> "RandCropByPosNegLabeld": super().set_random_state(seed, state) self.cropper.set_random_state(seed, state) return self def randomize( self, - label: NdarrayOrTensor, + label: torch.Tensor, fg_indices: Optional[NdarrayOrTensor] = None, bg_indices: Optional[NdarrayOrTensor] = None, - image: Optional[NdarrayOrTensor] = None, + image: Optional[torch.Tensor] = None, ) -> None: self.cropper.randomize(label=label, fg_indices=fg_indices, bg_indices=bg_indices, image=image) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, torch.Tensor]]: d = dict(data) label = d[self.label_key] image = d[self.image_key] if self.image_key else None @@ -877,7 +872,9 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashab return ret -class RandCropByLabelClassesd(Randomizable, MapTransform, InvertibleTransform): +@deprecated_arg(name="meta_keys", since="0.9") +@deprecated_arg(name="meta_key_postfix", since="0.9") +class RandCropByLabelClassesd(Randomizable, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandCropByLabelClasses`. Crop random fixed sized regions with the center being a class based on the specified ratios of every class. @@ -942,15 +939,6 @@ class RandCropByLabelClassesd(Randomizable, MapTransform, InvertibleTransform): `image_threshold`, and randomly select crop centers based on them, expect to be 1 dim array of spatial indices after flattening. a typical usage is to call `ClassesToIndices` transform first and cache the results for better performance. - meta_keys: explicitly indicate the key of the corresponding metadata dictionary. 
- used to add `patch_index` to the meta dict. - for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the metadata is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the metadata according - to the key data, default is `meta_dict`, the metadata is a dictionary object. - used to add `patch_index` to the meta dict. allow_smaller: if `False`, an exception will be raised if the image is smaller than the requested ROI in any dimension. If `True`, any smaller dimensions will remain unchanged. @@ -978,89 +966,51 @@ def __init__( ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) self.label_key = label_key - self.spatial_size = spatial_size - self.ratios = ratios - self.num_classes = num_classes - self.num_samples = num_samples self.image_key = image_key - self.image_threshold = image_threshold self.indices_key = indices_key - self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) - if len(self.keys) != len(self.meta_keys): - raise ValueError("meta_keys should have the same length as keys.") - self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self.centers: Optional[List[List[int]]] = None - self.allow_smaller = allow_smaller + self.cropper = RandCropByLabelClasses( + spatial_size=spatial_size, + ratios=ratios, + num_classes=num_classes, + num_samples=num_samples, + image_threshold=image_threshold, + allow_smaller=allow_smaller, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandCropByLabelClassesd": + super().set_random_state(seed, state) + self.cropper.set_random_state(seed, state) + return self def randomize( self, - label: NdarrayOrTensor, + label: torch.Tensor, indices: Optional[List[NdarrayOrTensor]] = None, - image: Optional[NdarrayOrTensor] = None, + image: Optional[torch.Tensor] = None, ) -> None: - if indices is None: - indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) - else: - indices_ = indices - self.centers = generate_label_classes_crop_centers( - self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R, self.allow_smaller - ) + self.cropper.randomize(label=label, indices=indices, image=image) - def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, NdarrayOrTensor]]: + def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, torch.Tensor]]: d = dict(data) label = d[self.label_key] image = d[self.image_key] if self.image_key else None indices = d.pop(self.indices_key, None) if self.indices_key is not None else None self.randomize(label, indices, image) - if self.centers is None: - raise ValueError("no available ROI centers to crop.") # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(self.num_samples)] - - for i, center in enumerate(self.centers): - # fill in the extra keys with unmodified data - for key in set(d.keys()).difference(set(self.keys)): - results[i][key] = deepcopy(d[key]) - for key in self.key_iterator(d): - img = d[key] - orig_size = img.shape[1:] - roi_size = fall_back_tuple(self.spatial_size, default=orig_size) - cropper = SpatialCrop(roi_center=tuple(center), 
roi_size=roi_size) - results[i][key] = cropper(img) - self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) - # add `patch_index` to the metadata - for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): - meta_key = meta_key or f"{key}_{meta_key_postfix}" - if meta_key not in results[i]: - results[i][meta_key] = {} # type: ignore - results[i][meta_key][Key.PATCH_INDEX] = i # type: ignore - - return results - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) - current_size = np.asarray(d[key].shape[1:]) - center = transform[TraceKeys.EXTRA_INFO]["center"] - roi_size = fall_back_tuple(self.spatial_size, default=orig_size) - cropper = SpatialCrop(roi_center=tuple(center), roi_size=roi_size) # type: ignore - # get required pad to start and end - pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) - pad_to_end = orig_size - current_size - pad_to_start - # interleave mins and maxes - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - inverse_transform = BorderPad(pad) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) + ret: List = [{} for _ in range(self.cropper.num_samples)] + # deep copy all the unmodified data + for key in set(data.keys()).difference(set(self.keys)): + for r in ret: + r[key] = deepcopy(data[key]) - return d + for key in self.key_iterator(data): + for i, im in enumerate(self.cropper(data[key], label=label, randomize=False)): + ret[i][key] = im + return ret class ResizeWithPadOrCropd(MapTransform, InvertibleTransform): diff --git a/tests/test_rand_crop_by_label_classesd.py b/tests/test_rand_crop_by_label_classesd.py index 92780458e0..9a99ebab29 100644 --- a/tests/test_rand_crop_by_label_classesd.py +++ b/tests/test_rand_crop_by_label_classesd.py @@ -15,10 +15,10 @@ from parameterized import parameterized from monai.transforms import ClassesToIndicesd, RandCropByLabelClassesd -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL TESTS = [] -for p in TEST_NDARRAYS: +for p in TEST_NDARRAYS_ALL: TESTS.append( [ # One-Hot label diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index 9b9c3636c2..a8feb53b0f 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -16,7 +16,6 @@ from parameterized import parameterized from monai.transforms import RandCropByPosNegLabeld -from monai.utils.enums import PostFix from tests.utils import TEST_NDARRAYS_ALL TESTS = [ From 94464dcdde594ea33f2b210b616d2b86c47b342c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 24 Jun 2022 19:48:33 +0800 Subject: [PATCH 32/47] [DLMED] update resize with pad or crop dict Signed-off-by: Nic Ma --- monai/transforms/croppad/dictionary.py | 57 ++------------------------ tests/test_resize_with_pad_or_cropd.py | 4 +- 2 files changed, 6 insertions(+), 55 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index cb8cf697a8..362f9bd660 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -16,9 +16,6 @@ """ from copy import deepcopy -from enum import 
Enum -from itertools import chain -from math import floor from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Union import numpy as np @@ -51,7 +48,7 @@ from monai.utils import MAX_SEED from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple_rep from monai.utils.deprecate_utils import deprecated_arg -from monai.utils.enums import PostFix, TraceKeys +from monai.utils.enums import PostFix __all__ = [ "Padd", @@ -1013,7 +1010,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, torch.Te return ret -class ResizeWithPadOrCropd(MapTransform, InvertibleTransform): +class ResizeWithPadOrCropd(Padd): """ Dictionary-based wrapper of :py:class:`monai.transforms.ResizeWithPadOrCrop`. @@ -1037,8 +1034,6 @@ class ResizeWithPadOrCropd(MapTransform, InvertibleTransform): """ - backend = ResizeWithPadOrCrop.backend - def __init__( self, keys: KeysCollection, @@ -1048,52 +1043,8 @@ def __init__( method: Union[Method, str] = Method.SYMMETRIC, **pad_kwargs, ) -> None: - super().__init__(keys, allow_missing_keys) - self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **pad_kwargs) - - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = dict(data) - for key, m in self.key_iterator(d, self.mode): - orig_size = d[key].shape[1:] - d[key] = self.padcropper(d[key], mode=m) - self.push_transform(d, key, orig_size=orig_size, extra_info={"mode": m.value if isinstance(m, Enum) else m}) - return d - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) - current_size = np.array(d[key].shape[1:]) - # Unfortunately, we can't just use ResizeWithPadOrCrop with original size because of odd/even rounding. - # Instead, we first pad any smaller dimensions, and then we crop any larger dimensions. 
- - # First, do pad - if np.any((orig_size - current_size) > 0): - pad_to_start = np.floor((orig_size - current_size) / 2).astype(int) - # in each direction, if original size is even and current size is odd, += 1 - pad_to_start[np.logical_and(orig_size % 2 == 0, current_size % 2 == 1)] += 1 - pad_to_start[pad_to_start < 0] = 0 - pad_to_end = orig_size - current_size - pad_to_start - pad_to_end[pad_to_end < 0] = 0 - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - d[key] = BorderPad(pad)(d[key]) - - # Next crop - if np.any((orig_size - current_size) < 0): - if self.padcropper.padder.method == Method.SYMMETRIC: - roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)] - else: - roi_center = [floor(r / 2) if r % 2 == 0 else (r - 1) // 2 for r in orig_size] - - d[key] = SpatialCrop(roi_center, orig_size)(d[key]) - - # Remove the applied transform - self.pop_transform(d, key) - - return d + padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **pad_kwargs) + super().__init__(keys, padder=padcropper, mode=mode, allow_missing_keys=allow_missing_keys) class BoundingRectd(MapTransform): diff --git a/tests/test_resize_with_pad_or_cropd.py b/tests/test_resize_with_pad_or_cropd.py index 28993a2bf4..6658c76386 100644 --- a/tests/test_resize_with_pad_or_cropd.py +++ b/tests/test_resize_with_pad_or_cropd.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import ResizeWithPadOrCropd -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL TEST_CASES = [ [{"keys": "img", "spatial_size": [15, 8, 8], "mode": "constant"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 8, 8)], @@ -34,7 +34,7 @@ class TestResizeWithPadOrCropd(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_pad_shape(self, input_param, input_data, expected_val): - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: if isinstance(p(0), torch.Tensor) and ( "constant_values" in input_param or input_param["mode"] == "reflect" ): From 33e5093182636b7673e367e0160a268d16bc7f04 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 27 Jun 2022 11:40:37 +0800 Subject: [PATCH 33/47] [DLMED] update format Signed-off-by: Nic Ma --- monai/transforms/__init__.py | 6 +-- monai/transforms/croppad/array.py | 45 +++++------------------ monai/transforms/croppad/dictionary.py | 11 +++--- tests/test_rand_crop_by_pos_neg_labeld.py | 18 ++------- tests/test_rand_weighted_cropd.py | 4 +- 5 files changed, 21 insertions(+), 63 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index f02737a4ef..5cc7747aeb 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -56,15 +56,15 @@ Padd, PadD, PadDict, - RandCropd, - RandCropD, - RandCropDict, RandCropByLabelClassesd, RandCropByLabelClassesD, RandCropByLabelClassesDict, RandCropByPosNegLabeld, RandCropByPosNegLabelD, RandCropByPosNegLabelDict, + RandCropd, + RandCropD, + RandCropDict, RandScaleCropd, RandScaleCropD, RandScaleCropDict, diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index a407463eb7..edd3a441a3 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -41,15 +41,8 @@ weighted_patch_samples, ) from monai.transforms.utils_pytorch_numpy_unification import floor_divide, maximum -from monai.utils import ( - Method, - PytorchPadMode, - ensure_tuple, - ensure_tuple_rep, - fall_back_tuple, - look_up_option, -) from monai.utils import ImageMetaKey 
as Key +from monai.utils import Method, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, look_up_option from monai.utils.enums import TraceKeys, TransformBackends from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor @@ -98,10 +91,7 @@ class Pad(InvertibleTransform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( - self, - to_pad: Optional[List[Tuple[int, int]]] = None, - mode: str = PytorchPadMode.CONSTANT, - **kwargs, + self, to_pad: Optional[List[Tuple[int, int]]] = None, mode: str = PytorchPadMode.CONSTANT, **kwargs ) -> None: self.to_pad = to_pad self.mode = mode @@ -128,11 +118,7 @@ def _pt_pad(img: torch.Tensor, pad_width, mode, **kwargs) -> torch.Tensor: return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0) def __call__( - self, - img: torch.Tensor, - to_pad: Optional[List[Tuple[int, int]]] = None, - mode: Optional[str] = None, - **kwargs, + self, img: torch.Tensor, to_pad: Optional[List[Tuple[int, int]]] = None, mode: Optional[str] = None, **kwargs ) -> torch.Tensor: """ Args: @@ -278,10 +264,7 @@ class BorderPad(Pad): """ def __init__( - self, - spatial_border: Union[Sequence[int], int], - mode: str = PytorchPadMode.CONSTANT, - **kwargs, + self, spatial_border: Union[Sequence[int], int], mode: str = PytorchPadMode.CONSTANT, **kwargs ) -> None: self.spatial_border = spatial_border super().__init__(mode=mode, **kwargs) @@ -479,7 +462,7 @@ def __init__( roi_slices: list of slices for each of the spatial dimensions. """ self.slices = self.compute_slices( - roi_center=roi_center, roi_size=roi_size, roi_start=roi_start, roi_end=roi_end, roi_slices=roi_slices, + roi_center=roi_center, roi_size=roi_size, roi_start=roi_start, roi_end=roi_end, roi_slices=roi_slices ) def __call__(self, img: torch.Tensor) -> torch.Tensor: @@ -824,12 +807,7 @@ def compute_bounding_box(self, img: torch.Tensor): return box_start_, box_end_ def crop_pad( - self, - img: torch.Tensor, - box_start: np.ndarray, - box_end: np.ndarray, - mode: Optional[str] = None, - **pad_kwargs, + self, img: torch.Tensor, box_start: np.ndarray, box_end: np.ndarray, mode: Optional[str] = None, **pad_kwargs ): """ Crop and pad based on the bounding box. 
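The `crop_pad` signature collapsed in the hunk above belongs to CropForeground's two-step API, which the dictionary wrapper earlier in this series calls as `compute_bounding_box` followed by `crop_pad`. A small sketch of that flow (illustrative values only, not part of the patch):

import torch
from monai.transforms import CropForeground

img = torch.zeros(1, 8, 8)
img[0, 2:6, 3:7] = 1.0  # foreground block

cropper = CropForeground(select_fn=lambda x: x > 0, margin=0)
# step 1: locate the foreground bounding box on a source image
box_start, box_end = cropper.compute_bounding_box(img)
# step 2: apply the same box to any aligned image, padding if the box spills over
out = cropper.crop_pad(img, box_start, box_end)
print(box_start, box_end, tuple(out.shape))  # expect roughly [2 3], [6 7], (1, 4, 4)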
@@ -904,7 +882,7 @@ def randomize(self, weight_map: NdarrayOrTensor) -> None: ) # using only the first channel as weight map def __call__( - self, img: torch.Tensor, weight_map: Optional[NdarrayOrTensor] = None, randomize: bool = True, + self, img: torch.Tensor, weight_map: Optional[NdarrayOrTensor] = None, randomize: bool = True ) -> List[torch.Tensor]: """ Args: @@ -1183,10 +1161,7 @@ def __init__( self.allow_smaller = allow_smaller def randomize( - self, - label: torch.Tensor, - indices: Optional[List[NdarrayOrTensor]] = None, - image: Optional[torch.Tensor] = None, + self, label: torch.Tensor, indices: Optional[List[NdarrayOrTensor]] = None, image: Optional[torch.Tensor] = None ) -> None: indices_: Sequence[NdarrayOrTensor] if indices is None: @@ -1275,9 +1250,7 @@ def __init__( self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **pad_kwargs) self.cropper = CenterSpatialCrop(roi_size=spatial_size) - def __call__( - self, img: torch.Tensor, mode: Optional[ str] = None, **pad_kwargs, - ) -> torch.Tensor: + def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs) -> torch.Tensor: """ Args: img: data to pad or crop, assuming `img` is channel-first and diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 362f9bd660..d938e4f877 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -45,8 +45,7 @@ from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utils import is_positive -from monai.utils import MAX_SEED -from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple_rep +from monai.utils import MAX_SEED, Method, NumpyPadMode, PytorchPadMode, ensure_tuple_rep from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import PostFix @@ -115,6 +114,7 @@ class Padd(MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Pad`. """ + backend = Pad.backend def __init__( @@ -328,6 +328,7 @@ class RandCropd(Cropd, Randomizable): allow_missing_keys: don't raise exception if key is missing. """ + backend = Crop.backend def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False): @@ -560,6 +561,7 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform): ValueError: When ``num_samples`` is nonpositive. 
""" + backend = RandSpatialCropSamples.backend @deprecated_arg(name="meta_keys", since="0.9") @@ -982,10 +984,7 @@ def set_random_state( return self def randomize( - self, - label: torch.Tensor, - indices: Optional[List[NdarrayOrTensor]] = None, - image: Optional[torch.Tensor] = None, + self, label: torch.Tensor, indices: Optional[List[NdarrayOrTensor]] = None, image: Optional[torch.Tensor] = None ) -> None: self.cropper.randomize(label=label, indices=indices, image=image) diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index a8feb53b0f..64673bf4bf 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -66,11 +66,7 @@ "image_key": None, "image_threshold": 0, }, - { - "image": np.zeros([3, 3, 3, 3]) - 1, - "extra": np.zeros([3, 3, 3, 3]), - "label": np.ones([3, 3, 3, 3]), - }, + {"image": np.zeros([3, 3, 3, 3]) - 1, "extra": np.zeros([3, 3, 3, 3]), "label": np.ones([3, 3, 3, 3])}, (3, 2, 2, 2), ], [ @@ -85,11 +81,7 @@ "image_threshold": 0, "allow_smaller": True, }, - { - "image": np.zeros([3, 3, 3, 3]) - 1, - "extra": np.zeros([3, 3, 3, 3]), - "label": np.ones([3, 3, 3, 3]), - }, + {"image": np.zeros([3, 3, 3, 3]) - 1, "extra": np.zeros([3, 3, 3, 3]), "label": np.ones([3, 3, 3, 3])}, (3, 3, 3, 2), ], [ @@ -104,11 +96,7 @@ "image_threshold": 0, "allow_smaller": True, }, - { - "image": np.zeros([3, 3, 3, 3]) - 1, - "extra": np.zeros([3, 3, 3, 3]), - "label": np.ones([3, 3, 3, 3]), - }, + {"image": np.zeros([3, 3, 3, 3]) - 1, "extra": np.zeros([3, 3, 3, 3]), "label": np.ones([3, 3, 3, 3])}, (3, 3, 3, 3), ], ] diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py index 9ef91fd177..cb46a892f5 100644 --- a/tests/test_rand_weighted_cropd.py +++ b/tests/test_rand_weighted_cropd.py @@ -174,9 +174,7 @@ def test_rand_weighted_crop_patch_index(self): weight[0, 13, 31] = 1.1 weight[0, 24, 21] = 1 crop.set_random_state(10) - result = crop( - {"img": p(img), "seg": p(self.segn[0]), "w": q(weight)} - ) + result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) self.assertTrue(len(result) == n_samples) for c, e in zip(crop.cropper.centers, [[14, 32, 40], [41, 32, 40], [20, 32, 40]]): assert_allclose(c, e, type_test=False) From 52d18189546826ed773aeb414ed85eb037b1c3c2 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 27 Jun 2022 18:59:40 +0800 Subject: [PATCH 34/47] [DLMED] fix all the mypy errors Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 116 +++++++++++++------------ monai/transforms/croppad/dictionary.py | 35 ++++---- 2 files changed, 78 insertions(+), 73 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index edd3a441a3..8725600d6f 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -40,7 +40,6 @@ map_classes_to_indices, weighted_patch_samples, ) -from monai.transforms.utils_pytorch_numpy_unification import floor_divide, maximum from monai.utils import ImageMetaKey as Key from monai.utils import Method, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, look_up_option from monai.utils.enums import TraceKeys, TransformBackends @@ -117,7 +116,7 @@ def _pt_pad(img: torch.Tensor, pad_width, mode, **kwargs) -> torch.Tensor: # torch.pad expects `[B, C, H, W, [D]]` shape return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0) - def __call__( + def __call__( # type: ignore self, img: torch.Tensor, to_pad: Optional[List[Tuple[int, 
int]]] = None, mode: Optional[str] = None, **kwargs ) -> torch.Tensor: """ @@ -157,11 +156,11 @@ def __call__( mode = convert_pad_mode(dst=img_np, mode=mode_).value out = torch.as_tensor(self._np_pad(img_np, pad_width=to_pad_, mode=mode_, **kwargs_)) if get_track_meta(): - out = MetaTensor(out, meta=img_t.meta, applied_operations=img_t.applied_operations) # type: ignore + out = MetaTensor(out, meta=img_t.meta, applied_operations=img_t.applied_operations) else: out = img_t if get_track_meta(): - self._update_meta(tensor=out, to_pad=to_pad_) + self._update_meta(tensor=out, to_pad=to_pad_) # type: ignore self.push_transform(out, extra_info={"padded": to_pad_}) return out @@ -171,7 +170,7 @@ def _update_meta(self, tensor: MetaTensor, to_pad: List[Tuple[int, int]]): mat = create_translate(spatial_rank, to_shift) tensor.meta["affine"] = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0] - def inverse(self, data: torch.Tensor) -> torch.Tensor: + def inverse(self, data: MetaTensor) -> MetaTensor: transform = self.pop_transform(data) padded = transform[TraceKeys.EXTRA_INFO]["padded"] if padded[0][0] != 0 or padded[0][1] != 0: @@ -182,7 +181,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: roi_end = [i - j[1] for i, j in zip(data.shape[1:], padded[1:])] cropper = SpatialCrop(roi_start=roi_start, roi_end=roi_end) with cropper.trace_transform(False): - return cropper(data) + return cropper(data) # type: ignore class SpatialPad(Pad): @@ -360,7 +359,7 @@ def compute_slices( roi_slices: list of slices for each of the spatial dimensions. """ - roi_start_torch: torch.Tensor + roi_start_t: torch.Tensor if roi_slices: if not all(s.step is None or s.step == 1 for s in roi_slices): @@ -368,44 +367,40 @@ def compute_slices( return list(roi_slices) else: if roi_center is not None and roi_size is not None: - roi_center, *_ = convert_data_type( - data=roi_center, output_type=torch.Tensor, dtype=torch.int16, wrap_sequence=True - ) - roi_size, *_ = convert_to_dst_type(src=roi_size, dst=roi_center, wrap_sequence=True) - _zeros = torch.zeros_like(roi_center) - roi_start_torch = maximum(roi_center - floor_divide(roi_size, 2), _zeros) - roi_end_torch = maximum(roi_start_torch + roi_size, roi_start_torch) + roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True) + roi_size_t = convert_to_tensor(data=roi_size, dtype=torch.int16, wrap_sequence=True) + _zeros = torch.zeros_like(roi_center_t) + roi_start_t = torch.maximum(roi_center_t - torch.floor_divide(roi_size_t, 2), _zeros) + roi_end_t = torch.maximum(roi_start_t + roi_size_t, roi_start_t) else: if roi_start is None or roi_end is None: raise ValueError("please specify either roi_center, roi_size or roi_start, roi_end.") - roi_start_torch, *_ = convert_data_type( - data=roi_start, output_type=torch.Tensor, dtype=torch.int16, wrap_sequence=True - ) - roi_start_torch = maximum(roi_start_torch, torch.zeros_like(roi_start_torch)) - roi_end_torch, *_ = convert_to_dst_type(src=roi_end, dst=roi_start_torch, wrap_sequence=True) - roi_end_torch = maximum(roi_end_torch, roi_start_torch) + roi_start_t = convert_to_tensor(data=roi_start, dtype=torch.int16, wrap_sequence=True) + roi_start_t = torch.maximum(roi_start_t, torch.zeros_like(roi_start_t)) + roi_end_t = convert_to_tensor(data=roi_end, dtype=torch.int16, wrap_sequence=True) + roi_end_t = torch.maximum(roi_end_t, roi_start_t) # convert to slices (accounting for 1d) - if roi_start_torch.numel() == 1: - return [slice(int(roi_start_torch.item()), int(roi_end_torch.item()))] + 
if roi_start_t.numel() == 1: + return [slice(int(roi_start_t.item()), int(roi_end_t.item()))] else: - return [slice(int(s), int(e)) for s, e in zip(roi_start_torch.tolist(), roi_end_torch.tolist())] + return [slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())] - def __call__(self, img: torch.Tensor, slices: Tuple[slice, ...]) -> torch.Tensor: + def __call__(self, img: torch.Tensor, slices: Tuple[slice, ...]) -> torch.Tensor: # type: ignore """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ orig_size = img.shape[1:] - slices = list(slices) + slices_ = list(slices) sd = len(img.shape[1:]) # spatial dims - if len(slices) < sd: - slices += [slice(None)] * (sd - len(slices)) + if len(slices_) < sd: + slices_ += [slice(None)] * (sd - len(slices_)) # Add in the channel (no cropping) - slices = [slice(None)] + slices[:sd] + slices = tuple([slice(None)] + slices_[:sd]) - img_t = convert_to_tensor(data=img, track_meta=get_track_meta()) - img_t = img_t[tuple(slices)] + img_t: MetaTensor = convert_to_tensor(data=img, track_meta=get_track_meta()) + img_t = img_t[slices] # type: ignore if get_track_meta(): self._update_meta(tensor=img_t, slices=slices) cropped_from_start = np.asarray([s.indices(o)[0] for s, o in zip(slices[1:], orig_size)]) @@ -414,20 +409,20 @@ def __call__(self, img: torch.Tensor, slices: Tuple[slice, ...]) -> torch.Tensor self.push_transform(img_t, extra_info={"cropped": cropped}) return img_t - def _update_meta(self, tensor: MetaTensor, slices: List): + def _update_meta(self, tensor: MetaTensor, slices: Tuple[slice, ...]): spatial_rank = max(len(tensor.affine) - 1, 1) to_shift = [s.start if s.start is not None else 0 for s in ensure_tuple(slices)[1:]] mat = create_translate(spatial_rank, to_shift) tensor.meta["affine"] = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0] - def inverse(self, img: torch.Tensor) -> torch.Tensor: + def inverse(self, img: MetaTensor) -> MetaTensor: transform = self.pop_transform(img) cropped = transform[TraceKeys.EXTRA_INFO]["cropped"] # the amount we pad is equal to the amount we cropped in each direction inverse_transform = BorderPad(cropped) # Apply inverse transform with inverse_transform.trace_transform(False): - return inverse_transform(img) + return inverse_transform(img) # type: ignore class SpatialCrop(Crop): @@ -465,7 +460,7 @@ def __init__( roi_center=roi_center, roi_size=roi_size, roi_start=roi_start, roi_end=roi_end, roi_slices=roi_slices ) - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -492,12 +487,12 @@ class CenterSpatialCrop(Crop): def __init__(self, roi_size: Union[Sequence[int], int]) -> None: self.roi_size = roi_size - def compute_slices(self, spatial_size: Sequence[int]): + def compute_slices(self, spatial_size: Sequence[int]): # type: ignore roi_size = fall_back_tuple(self.roi_size, spatial_size) roi_center = [i // 2 for i in spatial_size] return super().compute_slices(roi_center=roi_center, roi_size=roi_size) - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. 
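The `compute_slices` hunk above reduces ROI handling to plain tensor arithmetic: per spatial dimension, `start = max(center - size // 2, 0)` and `end = max(start + size, start)`. A minimal standalone sketch of that computation, with illustrative names (not the MONAI API itself):

    import torch

    def roi_to_slices(roi_center, roi_size):
        # int16 tensor arithmetic, mirroring SpatialCrop.compute_slices
        center = torch.as_tensor(roi_center, dtype=torch.int16)
        size = torch.as_tensor(roi_size, dtype=torch.int16)
        start = torch.maximum(center - torch.div(size, 2, rounding_mode="floor"), torch.zeros_like(center))
        end = torch.maximum(start + size, start)  # end can never precede start
        return [slice(int(s), int(e)) for s, e in zip(start.tolist(), end.tolist())]

    print(roi_to_slices([2, 2], [4, 4]))  # [slice(0, 4), slice(0, 4)]

Clamping `start` at zero means a center too close to the border shifts the ROI inward rather than producing negative indices.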
@@ -519,7 +514,7 @@ class CenterScaleCrop(Crop): def __init__(self, roi_scale: Union[Sequence[float], float]): self.roi_scale = roi_scale - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore img_size = img.shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] @@ -563,7 +558,7 @@ def __init__( self.random_center = random_center self.random_size = random_size self._size: Optional[Sequence[int]] = None - self._slices: Optional[Tuple[slice, ...]] = None + self._slices: Tuple[slice, ...] def randomize(self, img_size: Sequence[int]) -> None: self._size = fall_back_tuple(self.roi_size, img_size) @@ -576,7 +571,7 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, self._size) self._slices = get_random_patch(img_size, valid_size, self.R) - def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: + def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: # type: ignore """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -636,7 +631,7 @@ def randomize(self, img_size: Sequence[int]) -> None: self.get_max_roi_size(img_size) super().randomize(img_size) - def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: + def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: # type: ignore """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -710,7 +705,7 @@ def __call__(self, img: torch.Tensor) -> List[torch.Tensor]: ret = [self.cropper(img) for _ in range(self.num_samples)] if get_track_meta(): for i, r in enumerate(ret): - r.meta[Key.PATCH_INDEX] = i + r.meta[Key.PATCH_INDEX] = i # type: ignore return ret @@ -823,11 +818,12 @@ def crop_pad( # combine the traced cropping and padding into one transformation # by taking the padded info and placing it in a key inside the crop info. if get_track_meta(): - app_op = ret.applied_operations.pop(-1) - ret.applied_operations[-1][TraceKeys.EXTRA_INFO]["pad_info"] = app_op + ret_: MetaTensor = ret # type: ignore + app_op = ret_.applied_operations.pop(-1) + ret_.applied_operations[-1][TraceKeys.EXTRA_INFO]["pad_info"] = app_op return ret - def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs): + def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs): # type: ignore """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. 
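Invertibility in these hunks rests on two pieces of bookkeeping: `_update_meta` composes a translation into the stored affine (a crop shifts the origin forward by each slice's `start`; a pad, conversely, by the negated low padding), and `push_transform` records the cropped/padded amounts so `inverse` can apply the opposite operation (a `BorderPad` for crops, a `SpatialCrop` for pads) under `trace_transform(False)`, leaving no trace of the inversion itself. A small sketch of the affine update, assuming a simplified stand-in for `create_translate` (the real helper lives in `monai.transforms.utils`):

    import numpy as np

    def create_translate(spatial_rank, to_shift):
        # homogeneous (rank + 1) x (rank + 1) translation matrix
        mat = np.eye(spatial_rank + 1)
        mat[:spatial_rank, -1] = to_shift
        return mat

    affine = np.eye(3)  # 2D image with an identity affine
    slices = (slice(2, 5), slice(3, 7))  # a spatial crop
    affine = affine @ create_translate(2, [s.start for s in slices])
    print(affine[:2, -1])  # [2. 3.]: voxel (0, 0) of the crop sits at (2, 3) of the original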
@@ -839,7 +835,7 @@ def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs): return cropped, box_start, box_end return cropped - def inverse(self, img: torch.Tensor) -> torch.Tensor: + def inverse(self, img: MetaTensor) -> MetaTensor: transform = self.get_most_recent_transform(img) # we moved the padding info in the forward, so put it back for the inverse pad_info = transform[TraceKeys.EXTRA_INFO].pop("pad_info") @@ -905,12 +901,13 @@ def __call__( if randomize: self.randomize(weight_map) _spatial_size = fall_back_tuple(self.spatial_size, weight_map.shape[1:]) - results: List[NdarrayOrTensor] = [] + results: List[torch.Tensor] = [] for i, center in enumerate(self.centers): cropped = SpatialCrop(roi_center=center, roi_size=_spatial_size)(img) if get_track_meta(): - cropped.meta[Key.PATCH_INDEX] = i - cropped.meta["crop_center"] = center + ret_: MetaTensor = cropped # type: ignore + ret_.meta[Key.PATCH_INDEX] = i + ret_.meta["crop_center"] = center results.append(cropped) return results @@ -1066,7 +1063,9 @@ def __call__( roi_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) cropped = SpatialCrop(roi_center=center, roi_size=roi_size)(img) if get_track_meta(): - cropped.meta[Key.PATCH_INDEX] = i + ret_: MetaTensor = cropped # type: ignore + ret_.meta[Key.PATCH_INDEX] = i + ret_.meta["crop_center"] = center results.append(cropped) return results @@ -1203,13 +1202,15 @@ def __call__( if randomize: self.randomize(label, indices, image) - results: List[NdarrayOrTensor] = [] + results: List[torch.Tensor] = [] if self.centers is not None: for i, center in enumerate(self.centers): roi_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) cropped = SpatialCrop(roi_center=tuple(center), roi_size=roi_size)(img) if get_track_meta(): - cropped.meta[Key.PATCH_INDEX] = i + ret_: MetaTensor = cropped # type: ignore + ret_.meta[Key.PATCH_INDEX] = i + ret_.meta["crop_center"] = center results.append(cropped) return results @@ -1250,7 +1251,7 @@ def __init__( self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **pad_kwargs) self.cropper = CenterSpatialCrop(roi_size=spatial_size) - def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs) -> torch.Tensor: + def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs) -> torch.Tensor: # type: ignore """ Args: img: data to pad or crop, assuming `img` is channel-first and @@ -1268,16 +1269,17 @@ def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs) ret = self.padder(self.cropper(img), mode=mode, **pad_kwargs) # remove the individual info and combine if get_track_meta(): - pad_info = ret.applied_operations.pop(-1) - crop_info = ret.applied_operations.pop(-1) - self.push_transform(ret, extra_info={"pad_info": pad_info, "crop_info": crop_info}) + ret_: MetaTensor = ret # type: ignore + pad_info = ret_.applied_operations.pop(-1) + crop_info = ret_.applied_operations.pop(-1) + self.push_transform(ret_, extra_info={"pad_info": pad_info, "crop_info": crop_info}) return ret - def inverse(self, img: torch.Tensor) -> torch.Tensor: + def inverse(self, img: MetaTensor) -> MetaTensor: transform = self.pop_transform(img) return self.inverse_transform(img, transform) - def inverse_transform(self, img: torch.Tensor, transform) -> torch.Tensor: + def inverse_transform(self, img: MetaTensor, transform) -> MetaTensor: # we joined the cropping and padding, so put them back before calling the inverse crop_info = 
transform[TraceKeys.EXTRA_INFO].pop("crop_info") pad_info = transform[TraceKeys.EXTRA_INFO].pop("pad_info") diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index d938e4f877..3858e8de30 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -23,6 +23,7 @@ from monai.config import IndexSelection, KeysCollection from monai.config.type_definitions import NdarrayOrTensor +from monai.data.meta_tensor import MetaTensor from monai.transforms.croppad.array import ( BorderPad, BoundingRect, @@ -45,7 +46,7 @@ from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utils import is_positive -from monai.utils import MAX_SEED, Method, NumpyPadMode, PytorchPadMode, ensure_tuple_rep +from monai.utils import MAX_SEED, Method, PytorchPadMode, ensure_tuple_rep from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import PostFix @@ -118,7 +119,11 @@ class Padd(MapTransform, InvertibleTransform): backend = Pad.backend def __init__( - self, keys: KeysCollection, padder: Pad, mode: str = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False + self, + keys: KeysCollection, + padder: Pad, + mode: Union[Sequence[str], str] = PytorchPadMode.CONSTANT, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -145,7 +150,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, MetaTensor]) -> Dict[Hashable, MetaTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.padder.inverse(d[key]) @@ -307,10 +312,10 @@ def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): - d[key] = self.cropper(d[key]) + d[key] = self.cropper(d[key]) # type: ignore return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, MetaTensor]) -> Dict[Hashable, MetaTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.cropper.inverse(d[key]) @@ -352,7 +357,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc self.randomize(d[self.first_key(d)].shape[1:]) for key in self.key_iterator(d): kwargs = {"randomize": False} if isinstance(self.cropper, Randomizable) else {} - d[key] = self.cropper(d[key], **kwargs) + d[key] = self.cropper(d[key], **kwargs) # type: ignore return d @@ -584,7 +589,7 @@ def __init__( def randomize(self, data: Optional[Any] = None) -> None: self.sub_seed = self.R.randint(MAX_SEED, dtype="uint32") - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, torch.Tensor]]: # output starts as empty list of dictionaries ret: List[Dict[Hashable, torch.Tensor]] = [{} for _ in range(self.cropper.num_samples)] # deep copy all the unmodified data @@ -623,7 +628,7 @@ def __init__( margin: Union[Sequence[int], int] = 0, allow_smaller: bool = True, k_divisible: Union[Sequence[int], int] = 1, - mode: Union[Sequence[str], str] = NumpyPadMode.CONSTANT, + mode: Union[Sequence[str], str] = 
PytorchPadMode.CONSTANT, start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", allow_missing_keys: bool = False, @@ -657,7 +662,6 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension. """ - super().__init__(keys, allow_missing_keys) self.source_key = source_key self.start_coord_key = start_coord_key self.end_coord_key = end_coord_key @@ -670,11 +674,11 @@ def __init__( **pad_kwargs, ) super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) - self.mode = ensure_tuple_rep(mode, len(self.keys)) def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) + self.cropper: CropForeground box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key]) d[self.start_coord_key] = box_start d[self.end_coord_key] = box_end @@ -731,7 +735,7 @@ def set_random_state( def randomize(self, weight_map: NdarrayOrTensor) -> None: self.cropper.randomize(weight_map) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, torch.Tensor]]: # type: ignore + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, torch.Tensor]]: # output starts as empty list of dictionaries ret: List = [{} for _ in range(self.cropper.num_samples)] # deep copy all the unmodified data @@ -793,9 +797,8 @@ class RandCropByPosNegLabeld(Randomizable, MapTransform): the requested ROI in any dimension. If `True`, any smaller dimensions will be set to match the cropped size (i.e., no cropping in that dimension). allow_missing_keys: don't raise exception if key is missing. - - Raises: - ValueError: When ``pos`` or ``neg`` are negative. + Raises: + ValueError: When ``pos`` or ``neg`` are negative. ValueError: When ``pos=0`` and ``neg=0``. Incompatible values. """ @@ -1037,13 +1040,13 @@ def __init__( self, keys: KeysCollection, spatial_size: Union[Sequence[int], int], - mode: Union[Sequence[str], str] = NumpyPadMode.CONSTANT, + mode: Union[Sequence[str], str] = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, method: Union[Method, str] = Method.SYMMETRIC, **pad_kwargs, ) -> None: padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **pad_kwargs) - super().__init__(keys, padder=padcropper, mode=mode, allow_missing_keys=allow_missing_keys) + super().__init__(keys, padder=padcropper, mode=mode, allow_missing_keys=allow_missing_keys) # type: ignore class BoundingRectd(MapTransform): From da5195b77aac0d97a52c0b7c936a2144eb4eaee0 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 28 Jun 2022 11:18:44 +0800 Subject: [PATCH 35/47] [DLMED] add crop / pad base tests Signed-off-by: Nic Ma --- tests/croppers.py | 103 ++++++++++++++++++++++++++++++++++++++++++++ tests/padders.py | 106 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 209 insertions(+) create mode 100644 tests/croppers.py create mode 100644 tests/padders.py diff --git a/tests/croppers.py b/tests/croppers.py new file mode 100644 index 0000000000..8f78249d90 --- /dev/null +++ b/tests/croppers.py @@ -0,0 +1,103 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from copy import deepcopy + +import numpy as np + +from monai.data.meta_tensor import MetaTensor +from monai.transforms.transform import MapTransform +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose + + +class CropTest(unittest.TestCase): + @staticmethod + def get_arr(shape): + return np.random.randint(100, size=shape).astype(float) + + def crop_test(self, input_param, input_shape, expected_shape, same_area=None): + base_comparison = None + input_image = self.get_arr(input_shape) + + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + # input parameters, such as roi_start can be numpy, torch, list etc. + for param_type in TEST_NDARRAYS_ALL + (None,): + with self.subTest(param_type=param_type): + input_param_mod = deepcopy(input_param) + if param_type is not None: + for k in ("roi_start", "roi_end", "roi_center", "roi_size", "roi_scale"): + if k in input_param: + input_param_mod[k] = param_type(input_param[k]) + im = im_type(input_image) + cropper = self.Cropper(**input_param_mod) + is_map = isinstance(cropper, MapTransform) + input_data = {"img": im} if is_map else im + result = cropper(input_data) + out_im = result["img"] if is_map else result + self.assertIsInstance(out_im, MetaTensor) + self.assertTupleEqual(out_im.shape, expected_shape) + if same_area is not None: + assert_allclose(out_im, im[same_area], type_test=False) + # check result is the same regardless of input type + if base_comparison is None: + base_comparison = out_im + else: + assert_allclose(out_im, base_comparison) + + # test inverse + inv = cropper.inverse(result) + inv_im = inv["img"] if is_map else inv + self.assertIsInstance(inv_im, MetaTensor) + if same_area is not None: + assert_allclose(inv_im[same_area], im[same_area], type_test=False) + self.assertEqual(inv_im.applied_operations, []) + + def crop_test_value(self, input_param, input_arr, expected_array): + cropper = self.Cropper(**input_param) + is_map = isinstance(cropper, MapTransform) + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + im = im_type(input_arr) + input_data = {"img": im} if is_map else im + result = cropper(input_data) + out_im = result["img"] if is_map else result + self.assertIsInstance(out_im, MetaTensor) + assert_allclose(out_im, expected_array, type_test=False) + + def multi_inverse(self, input_shape, init_params): + input_data = np.arange(np.prod(input_shape)).reshape(*input_shape) + 1 + xform = self.Cropper(**init_params) + xform.set_random_state(1234) + out = xform(input_data) + if "num_samples" in init_params: + self.assertEqual(len(out), init_params["num_samples"]) + inv = xform.inverse(out) + self.assertIsInstance(inv, MetaTensor) + self.assertEqual(inv.applied_operations, []) + self.assertTrue("patch_index" not in inv.meta) + self.assertTupleEqual(inv.shape, input_shape) + inv_np = inv.numpy() + + # get list of all numbers that exist inside the crops + uniques = set() + for o in out: + uniques.update(set(o.flatten().tolist())) + + # make sure that each value is mapped back to the position it was cropped from + for i in uniques: + a = np.where(input_data == i) + b = np.where(inv_np == i) +
self.assertTupleEqual(a, b) + # there should be as many zeros as elements missing from uniques + missing = input_data.size - len(uniques) + self.assertEqual((inv_np == 0).sum(), missing) diff --git a/tests/padders.py b/tests/padders.py new file mode 100644 index 0000000000..42052c6d06 --- /dev/null +++ b/tests/padders.py @@ -0,0 +1,106 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import List + +import numpy as np +import torch + +from monai.data.meta_tensor import MetaTensor +from monai.transforms.transform import MapTransform +from monai.utils.enums import NumpyPadMode, PytorchPadMode +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose + +MODES = [] +# Test modes +NP_MODES: List = [ + "constant", + "edge", + # `reflect` mode is not supported in some PyTorch versions, skip the test + # "reflect", + "wrap", + "median", +] +MODES += NP_MODES +MODES += [NumpyPadMode(i) for i in NP_MODES] + +PT_MODES: list = [ + "constant", + "replicate", + "circular", + # `reflect` mode is not supported in some PyTorch versions, skip the test + # "reflect", +] +MODES += PT_MODES +MODES += [PytorchPadMode(i) for i in PT_MODES] + + +class PadTest(unittest.TestCase): + @staticmethod + def get_arr(shape): + return np.random.randint(100, size=shape).astype(float) + + def pad_test(self, input_param, input_shape, expected_shape, modes=None): + # loop over each mode + for mode in modes or MODES: + with self.subTest(mode=mode): + base_comparison = None + im = self.get_arr(input_shape) + padder = self.Padder(mode=mode, **input_param) + is_map = isinstance(padder, MapTransform) + # check result is the same regardless of input type + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + input_image = im_type(im) + input_data = {"img": im_type(im)} if is_map else im_type(im) + # our array transforms can also take `mode` as an argument to `__call__` + # Check this gives equivalent results + for call_extra_args in [{}] if is_map else [{}, {"mode": mode}]: + with self.subTest(call_extra_args=call_extra_args): + r_out = padder(input_data, **call_extra_args) + r_im = r_out["img"] if is_map else r_out + # check shape, type, etc. 
+ np.testing.assert_allclose(r_im.shape, expected_shape) + self.assertIsInstance(r_im, MetaTensor) + self.assertEqual(len(r_im.applied_operations), 1) + # check results are same regardless of input type + if base_comparison is None: + base_comparison = r_im + torch.testing.assert_allclose(r_im, base_comparison, atol=0, rtol=1e-5) + # test inverse + if isinstance(r_im, MetaTensor): + r_out = padder.inverse(r_out) + r_im = r_out["img"] if is_map else r_out + self.assertIsInstance(r_im, MetaTensor) + assert_allclose(r_im, input_image, type_test=False) + self.assertEqual(r_im.applied_operations, []) + + def pad_test_kwargs(self, unchanged_slices, **input_param): + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + for kwargs in ({"value": 2}, {"constant_values": ((0, 0), (1, 1), (2, 2))}): + with self.subTest(kwargs=kwargs): + im = im_type(np.random.randint(-100, -10, size=(3, 8, 4))) + padder = self.Padder(**input_param, **kwargs) + result = padder(im) + if isinstance(result, torch.Tensor): + result = result.cpu() + assert_allclose(result[unchanged_slices], im, type_test=False) + # we should have the same as the input plus some 2s (if value) or 1s and 2s (if constant_values) + expected_vals = np.unique(im).tolist() + expected_vals += [2] if "value" in kwargs else [1, 2] + assert_allclose(np.unique(result), expected_vals, type_test=False) + # check inverse + if isinstance(result, MetaTensor): + inv = padder.inverse(result) + assert_allclose(im, inv, type_test=False) + self.assertEqual(inv.applied_operations, []) From 7a573fa4dd42532cd3d416a72864bdcdd648ada5 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 28 Jun 2022 11:25:23 +0800 Subject: [PATCH 36/47] [DLMED] update border pad test Signed-off-by: Nic Ma --- tests/padders.py | 3 ++- tests/test_border_pad.py | 47 +++++++++++++----------------------- tests/test_border_padd.py | 50 ++++++++++++--------------------------- 3 files changed, 34 insertions(+), 66 deletions(-) diff --git a/tests/padders.py b/tests/padders.py index 42052c6d06..932a0566cc 100644 --- a/tests/padders.py +++ b/tests/padders.py @@ -75,7 +75,8 @@ def pad_test(self, input_param, input_shape, expected_shape, modes=None): # check results are same regardless of input type if base_comparison is None: base_comparison = r_im - torch.testing.assert_allclose(r_im, base_comparison, atol=0, rtol=1e-5) + else: + assert_allclose(r_im, base_comparison) # test inverse if isinstance(r_im, MetaTensor): r_out = padder.inverse(r_out) diff --git a/tests/test_border_pad.py b/tests/test_border_pad.py index b06aa7c564..1194ae49a6 100644 --- a/tests/test_border_pad.py +++ b/tests/test_border_pad.py @@ -11,45 +11,32 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import BorderPad -from monai.utils import NumpyPadMode -from tests.utils import TEST_NDARRAYS_ALL - -TEST_CASE_1 = [{"spatial_border": 2, "mode": "constant"}, np.zeros((3, 8, 8, 4)), np.zeros((3, 12, 12, 8))] - -TEST_CASE_2 = [{"spatial_border": [1, 2, 3], "mode": "constant"}, np.zeros((3, 8, 8, 4)), np.zeros((3, 10, 12, 10))] - -TEST_CASE_3 = [ - {"spatial_border": [1, 2, 3, 4, 5, 6], "mode": "constant"}, - np.zeros((3, 8, 8, 4)), - np.zeros((3, 11, 15, 15)), +from monai.utils.enums import NumpyPadMode, PytorchPadMode +from tests.padders import PadTest + +TESTS = [ + [{"spatial_border": 2}, (3, 8, 8, 4), (3, 12, 12, 8)], + [{"spatial_border": [1, 2, 3]}, (3, 8, 8, 4), (3, 10, 12, 10)], + [{"spatial_border": [1, 2, 3, 4, 5, 6]}, (3, 8, 8, 4), (3, 11, 15, 15)], + 
[{"spatial_border": [1, 2, 3, 4, 5, 6]}, (3, 8, 8, 4), (3, 11, 15, 15)], ] -TEST_CASE_4 = [ - {"spatial_border": [1, 2, 3, 4, 5, 6], "mode": NumpyPadMode.CONSTANT}, - np.zeros((3, 8, 8, 4)), - np.zeros((3, 11, 15, 15)), -] +class TestBorderPad(PadTest): + Padder = BorderPad -class TestBorderPad(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) - def test_pad_shape(self, input_param, input_data, expected_val): - for p in TEST_NDARRAYS_ALL: - padder = BorderPad(**input_param) - r1 = padder(p(input_data)) - r2 = padder(input_data, mode=input_param["mode"]) - self.assertAlmostEqual(r1.shape, expected_val.shape) - self.assertAlmostEqual(r2.shape, expected_val.shape) + @parameterized.expand(TESTS) + def test_pad(self, input_param, input_shape, expected_shape): + modes = ["constant", NumpyPadMode.CONSTANT, PytorchPadMode.CONSTANT] + self.pad_test(input_param, input_shape, expected_shape, modes) def test_pad_kwargs(self): - padder = BorderPad(spatial_border=2, mode="constant", constant_values=((0, 0), (1, 1), (2, 2))) - result = padder(np.zeros((3, 8, 4))) - np.testing.assert_allclose(result[:, :2, 2:6], np.ones((3, 2, 4))) - np.testing.assert_allclose(result[:, :, :2], np.ones((3, 12, 2)) + 1) + kwargs = {"spatial_border": 2, "mode": "constant"} + unchanged_slices = [slice(None), slice(2, -2), slice(2, -2)] + self.pad_test_kwargs(unchanged_slices, **kwargs) if __name__ == "__main__": diff --git a/tests/test_border_padd.py b/tests/test_border_padd.py index e4b8dd20ea..b8a29a873e 100644 --- a/tests/test_border_padd.py +++ b/tests/test_border_padd.py @@ -11,49 +11,29 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import BorderPadd from monai.utils import NumpyPadMode - -TEST_CASE_1 = [ - {"keys": ["img", "seg"], "spatial_border": 2, "mode": ["constant", "edge"]}, - {"img": np.zeros((3, 8, 8, 4)), "seg": np.zeros((3, 8, 8, 4))}, - np.zeros((3, 12, 12, 8)), -] - -TEST_CASE_2 = [ - {"keys": "img", "spatial_border": [1, 2, 3], "mode": "constant"}, - {"img": np.zeros((3, 8, 8, 4))}, - np.zeros((3, 10, 12, 10)), -] - -TEST_CASE_3 = [ - {"keys": "img", "spatial_border": [1, 2, 3, 4, 5, 6], "mode": "constant"}, - {"img": np.zeros((3, 8, 8, 4))}, - np.zeros((3, 11, 15, 15)), -] - -TEST_CASE_4 = [ - {"keys": ["img", "seg"], "spatial_border": 2, "mode": ["constant", NumpyPadMode.EDGE]}, - {"img": np.zeros((3, 8, 8, 4)), "seg": np.zeros((3, 8, 8, 4))}, - np.zeros((3, 12, 12, 8)), +from monai.utils.enums import PytorchPadMode +from tests.padders import PadTest + +TESTS = [ + [{"keys": "img", "spatial_border": 2}, (3, 8, 8, 4), (3, 12, 12, 8)], + [{"keys": "img", "spatial_border": [1, 2, 3]}, (3, 8, 8, 4), (3, 10, 12, 10)], + [{"keys": "img", "spatial_border": [1, 2, 3, 4, 5, 6]}, (3, 8, 8, 4), (3, 11, 15, 15)], + [{"keys": "img", "spatial_border": 2}, (3, 8, 8, 4), (3, 12, 12, 8)], + [{"keys": "img", "spatial_border": 2}, (3, 8, 8, 4), (3, 12, 12, 8)], ] -TEST_CASE_5 = [ - {"keys": ["img", "seg"], "spatial_border": 2, "mode": [NumpyPadMode.CONSTANT, NumpyPadMode.EDGE]}, - {"img": np.zeros((3, 8, 8, 4)), "seg": np.zeros((3, 8, 8, 4))}, - np.zeros((3, 12, 12, 8)), -] +class TestBorderPadd(PadTest): + Padder = BorderPadd -class TestBorderPadd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_pad_shape(self, input_param, input_data, expected_val): - padder = BorderPadd(**input_param) - result = padder(input_data) - 
self.assertAlmostEqual(result["img"].shape, expected_val.shape) + @parameterized.expand(TESTS) + def test_pad(self, input_param, input_shape, expected_shape): + modes = ["constant", NumpyPadMode.CONSTANT, PytorchPadMode.CONSTANT, "edge", NumpyPadMode.EDGE] + self.pad_test(input_param, input_shape, expected_shape, modes) if __name__ == "__main__": From e8de8e357ac000b0f94790ed8982eb2a5458376e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 28 Jun 2022 11:35:23 +0800 Subject: [PATCH 37/47] [DLMED] update spatial crop Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 2 +- tests/test_center_scale_crop.py | 46 ++++++++++----------- tests/test_center_scale_cropd.py | 47 +++++++++------------ tests/test_center_spatial_crop.py | 45 +++++++++------------ tests/test_center_spatial_cropd.py | 65 +++++++++++++++--------------- 5 files changed, 94 insertions(+), 111 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 8725600d6f..584f95865d 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -370,7 +370,7 @@ def compute_slices( roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True) roi_size_t = convert_to_tensor(data=roi_size, dtype=torch.int16, wrap_sequence=True) _zeros = torch.zeros_like(roi_center_t) - roi_start_t = torch.maximum(roi_center_t - torch.floor_divide(roi_size_t, 2), _zeros) + roi_start_t = torch.maximum(roi_center_t - torch.div(roi_size_t, 2, rounding_mode="floor"), _zeros) roi_end_t = torch.maximum(roi_start_t + roi_size_t, roi_start_t) else: if roi_start is None or roi_end is None: diff --git a/tests/test_center_scale_crop.py b/tests/test_center_scale_crop.py index 5476321165..ab07a44eb5 100644 --- a/tests/test_center_scale_crop.py +++ b/tests/test_center_scale_crop.py @@ -9,44 +9,40 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+ import unittest import numpy as np -import torch from parameterized import parameterized -from monai.data import MetaTensor from monai.transforms import CenterScaleCrop +from tests.croppers import CropTest -TEST_CASE_0 = [{"roi_scale": [0.6, 0.3, -1]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 1, 3)] - -TEST_CASE_1 = [{"roi_scale": 0.6}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 2)] - -TEST_CASE_2 = [ - {"roi_scale": [0.4, 0.4]}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2], [2, 3]]]), +TEST_SHAPES = [ + [{"roi_scale": [0.6, 0.3, -1]}, (3, 3, 3, 3), (3, 2, 1, 3)], + [{"roi_scale": 0.6}, (3, 3, 3, 3), (3, 2, 2, 2)], + [{"roi_scale": 0.5}, (3, 3, 3, 3), (3, 2, 2, 2)], ] -TEST_CASE_3 = [ - {"roi_scale": 0.5}, - torch.randint(0, 2, size=[3, 3, 3, 3], device="cuda" if torch.cuda.is_available() else "cpu"), - (3, 2, 2, 2), +TEST_VALUES = [ + [ + {"roi_scale": [0.4, 0.4]}, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + np.array([[[1, 2], [2, 3]]]), + ] ] -class TestCenterScaleCrop(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) - def test_shape(self, input_param, input_data, expected_shape): - result = CenterScaleCrop(**input_param)(input_data) - self.assertTrue(isinstance(result, MetaTensor)) - np.testing.assert_allclose(result.shape, expected_shape) +class TestCenterSpatialCrop(CropTest): + Cropper = CenterScaleCrop + + @parameterized.expand(TEST_SHAPES) + def test_shape(self, input_param, input_shape, expected_shape): + self.crop_test(input_param, input_shape, expected_shape) - @parameterized.expand([TEST_CASE_2]) - def test_value(self, input_param, input_data, expected_value): - result = CenterScaleCrop(**input_param)(input_data) - self.assertTrue(isinstance(result, MetaTensor)) - np.testing.assert_allclose(result, expected_value) + @parameterized.expand(TEST_VALUES) + def test_value(self, input_param, input_arr, expected_arr): + self.crop_test_value(input_param, input_arr, expected_arr) if __name__ == "__main__": diff --git a/tests/test_center_scale_cropd.py b/tests/test_center_scale_cropd.py index 8aef2dbe5b..894692530d 100644 --- a/tests/test_center_scale_cropd.py +++ b/tests/test_center_scale_cropd.py @@ -12,44 +12,37 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import CenterScaleCropd +from tests.croppers import CropTest -TEST_CASE_0 = [{"keys": "img", "roi_scale": [0.6, 0.3, -1]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 1, 3)] - -TEST_CASE_1 = [{"keys": "img", "roi_scale": 0.6}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 2)] - -TEST_CASE_2 = [ - {"keys": "img", "roi_scale": [0.4, 0.4]}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2], [2, 3]]]), +TESTS = [ + [{"keys": "img", "roi_scale": [0.6, 0.3, -1]}, (3, 3, 3, 3), (3, 2, 1, 3)], + [{"keys": "img", "roi_scale": 0.6}, (3, 3, 3, 3), (3, 2, 2, 2)], + [{"keys": "img", "roi_scale": 0.5}, (3, 3, 3, 3), (3, 2, 2, 2)], ] -TEST_CASE_3 = [ - {"keys": "img", "roi_scale": 0.5}, - torch.randint(0, 2, size=[3, 3, 3, 3], device="cuda" if torch.cuda.is_available() else "cpu"), - (3, 2, 2, 2), -] -TEST_CASE_4 = [ - {"keys": "test", "roi_scale": 0.6, "allow_missing_keys": True}, - np.random.randint(0, 2, size=[3, 3, 3, 3]), - (3, 3, 3, 3), +TEST_VALUES = [ + [ + {"keys": "img", "roi_scale": [0.4, 
0.4]}, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + np.array([[[1, 2], [2, 3]]]), + ] ] -class TestCenterScaleCropd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3, TEST_CASE_4]) - def test_shape(self, input_param, input_data, expected_shape): - result = CenterScaleCropd(**input_param)({"img": input_data}) - np.testing.assert_allclose(result["img"].shape, expected_shape) +class TestCenterScaleCropd(CropTest): + Cropper = CenterScaleCropd + + @parameterized.expand(TESTS) + def test_shape(self, input_param, input_shape, expected_shape): + self.crop_test(input_param, input_shape, expected_shape) - @parameterized.expand([TEST_CASE_2]) - def test_value(self, input_param, input_data, expected_value): - result = CenterScaleCropd(**input_param)({"img": input_data}) - np.testing.assert_allclose(result["img"], expected_value) + @parameterized.expand(TEST_VALUES) + def test_value(self, input_param, input_arr, expected_arr): + self.crop_test_value(input_param, input_arr, expected_arr) if __name__ == "__main__": diff --git a/tests/test_center_spatial_crop.py b/tests/test_center_spatial_crop.py index 771ba650a9..7b5b19107d 100644 --- a/tests/test_center_spatial_crop.py +++ b/tests/test_center_spatial_crop.py @@ -12,41 +12,36 @@ import unittest import numpy as np -import torch from parameterized import parameterized -from monai.data import MetaTensor from monai.transforms import CenterSpatialCrop +from tests.croppers import CropTest -TEST_CASE_0 = [{"roi_size": [2, 2, -1]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 3)] - -TEST_CASE_1 = [{"roi_size": [2, 2, 2]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 2)] - -TEST_CASE_2 = [ - {"roi_size": [2, 2]}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2], [2, 3]]]), +TEST_SHAPES = [ + [{"roi_size": [2, 2, -1]}, (3, 3, 3, 3), (3, 2, 2, 3)], + [{"roi_size": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], + [{"roi_size": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], ] -TEST_CASE_3 = [ - {"roi_size": [2, 2, 2]}, - torch.randint(0, 2, size=[3, 3, 3, 3], device="cuda" if torch.cuda.is_available() else "cpu"), - (3, 2, 2, 2), +TEST_VALUES = [ + [ + {"roi_size": [2, 2]}, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + np.array([[[1, 2], [2, 3]]]), + ] ] -class TestCenterSpatialCrop(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) - def test_shape(self, input_param, input_data, expected_shape): - result = CenterSpatialCrop(**input_param)(input_data) - self.assertTrue(isinstance(result, MetaTensor)) - np.testing.assert_allclose(result.shape, expected_shape) +class TestCenterSpatialCrop(CropTest): + Cropper = CenterSpatialCrop + + @parameterized.expand(TEST_SHAPES) + def test_shape(self, input_param, input_shape, expected_shape): + self.crop_test(input_param, input_shape, expected_shape) - @parameterized.expand([TEST_CASE_2]) - def test_value(self, input_param, input_data, expected_value): - result = CenterSpatialCrop(**input_param)(input_data) - self.assertTrue(isinstance(result, MetaTensor)) - np.testing.assert_allclose(result, expected_value) + @parameterized.expand(TEST_VALUES) + def test_value(self, input_param, input_arr, expected_arr): + self.crop_test_value(input_param, input_arr, expected_arr) if __name__ == "__main__": diff --git a/tests/test_center_spatial_cropd.py 
b/tests/test_center_spatial_cropd.py index 02e100bb92..fa7bc8c8fa 100644 --- a/tests/test_center_spatial_cropd.py +++ b/tests/test_center_spatial_cropd.py @@ -15,43 +15,42 @@ from parameterized import parameterized from monai.transforms import CenterSpatialCropd -from tests.utils import TEST_NDARRAYS_ALL, assert_allclose - -TEST_SHAPES = [] -for p in TEST_NDARRAYS_ALL: - TEST_SHAPES.append( - [{"keys": "img", "roi_size": [2, -1, -1]}, {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, (3, 2, 3, 3)] - ) - - TEST_SHAPES.append( - [{"keys": "img", "roi_size": [2, 2, 2]}, {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, (3, 2, 2, 2)] - ) - -TEST_CASES = [] -for p in TEST_NDARRAYS_ALL: - TEST_CASES.append( - [ - {"keys": "img", "roi_size": [2, 2]}, - { - "img": p( - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) - ) - }, - p(np.array([[[1, 2], [2, 3]]])), - ] - ) - - -class TestCenterSpatialCropd(unittest.TestCase): +from tests.croppers import CropTest + +TEST_SHAPES = [ + [ + {"keys": "img", "roi_size": [2, -1, -1]}, + (3, 3, 3, 3), + (3, 2, 3, 3), + (slice(None), slice(None, -1), slice(None), slice(None)), + ], + [ + {"keys": "img", "roi_size": [2, 2, 2]}, + (3, 3, 3, 3), + (3, 2, 2, 2), + (slice(None), slice(None, -1), slice(None, -1), slice(None, -1)), + ], +] + +TEST_CASES = [ + [ + {"keys": "img", "roi_size": [2, 2]}, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + np.array([[[1, 2], [2, 3]]]), + ] +] + + +class TestCenterSpatialCropd(CropTest): + Cropper = CenterSpatialCropd + @parameterized.expand(TEST_SHAPES) - def test_shape(self, input_param, input_data, expected_shape): - result = CenterSpatialCropd(**input_param)(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + def test_shape(self, input_param, input_shape, expected_shape, same_area): + self.crop_test(input_param, input_shape, expected_shape, same_area) @parameterized.expand(TEST_CASES) def test_value(self, input_param, input_data, expected_value): - result = CenterSpatialCropd(**input_param)(input_data) - assert_allclose(result["img"], expected_value, type_test=False) + self.crop_test_value(input_param, input_data, expected_value) if __name__ == "__main__": From a91cfa38d934b5de1f39986c4481f04c09bc2cad Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 28 Jun 2022 11:57:03 +0800 Subject: [PATCH 38/47] [DLMED] update pad transforms Signed-off-by: Nic Ma --- tests/test_border_padd.py | 3 +- tests/test_crop_foreground.py | 15 +++-- tests/test_crop_foregroundd.py | 17 ++--- tests/test_divisible_pad.py | 39 ++++------- tests/test_divisible_padd.py | 29 ++++----- tests/test_rand_scale_crop.py | 85 +++++++++++------------- tests/test_rand_scale_cropd.py | 115 +++++++++++++++++---------------- 7 files changed, 143 insertions(+), 160 deletions(-) diff --git a/tests/test_border_padd.py b/tests/test_border_padd.py index b8a29a873e..ca55c8b09d 100644 --- a/tests/test_border_padd.py +++ b/tests/test_border_padd.py @@ -14,8 +14,7 @@ from parameterized import parameterized from monai.transforms import BorderPadd -from monai.utils import NumpyPadMode -from monai.utils.enums import PytorchPadMode +from monai.utils.enums import NumpyPadMode, PytorchPadMode from tests.padders import PadTest TESTS = [ diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index a9a891100c..e400406e4d 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -12,11 +12,11 @@ import 
unittest import numpy as np -import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import CropForeground -from tests.utils import TEST_NDARRAYS_ALL +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_COORDS, TESTS = [], [] @@ -89,8 +89,15 @@ class TestCropForeground(unittest.TestCase): @parameterized.expand(TEST_COORDS + TESTS) def test_value(self, argments, image, expected_data): - result = CropForeground(**argments)(image) - torch.testing.assert_allclose(result, expected_data, rtol=1e-7, atol=0) + cropper = CropForeground(**argments) + result = cropper(image) + assert_allclose(result, expected_data, type_test=False) + self.assertIsInstance(result, MetaTensor) + self.assertEqual(len(result.applied_operations), 1) + inv = cropper.inverse(result) + self.assertIsInstance(inv, MetaTensor) + self.assertEqual(inv.applied_operations, []) + self.assertTupleEqual(inv.shape, image.shape) @parameterized.expand(TEST_COORDS) def test_return_coords(self, argments, image, _): diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py index 1bd08e13e9..d641c5a376 100644 --- a/tests/test_crop_foregroundd.py +++ b/tests/test_crop_foregroundd.py @@ -12,10 +12,8 @@ import unittest import numpy as np -import torch from parameterized import parameterized -from monai.data.meta_tensor import MetaTensor from monai.transforms import CropForegroundd from tests.utils import TEST_NDARRAYS_ALL, assert_allclose @@ -152,12 +150,15 @@ class TestCropForegroundd(unittest.TestCase): @parameterized.expand(TEST_POSITION + TESTS) def test_value(self, argments, input_data, expected_data): - result = CropForegroundd(**argments)(input_data) - r, i = result["img"], input_data["img"] - self.assertEqual(type(r), MetaTensor) - if isinstance(i, torch.Tensor): - self.assertEqual(r.device, i.device) - assert_allclose(r, expected_data, type_test=False) + cropper = CropForegroundd(**argments) + result = cropper(input_data) + assert_allclose(result["img"], expected_data, type_test=False) + if "label" in input_data and "img" in input_data: + self.assertTupleEqual(result["img"].shape, result["label"].shape) + inv = cropper.inverse(result) + self.assertTupleEqual(inv["img"].shape, input_data["img"].shape) + if "label" in input_data: + self.assertTupleEqual(inv["label"].shape, input_data["label"].shape) @parameterized.expand(TEST_POSITION) def test_foreground_position(self, argments, input_data, _): diff --git a/tests/test_divisible_pad.py b/tests/test_divisible_pad.py index 4428078f40..df610c4939 100644 --- a/tests/test_divisible_pad.py +++ b/tests/test_divisible_pad.py @@ -11,43 +11,32 @@ import unittest -import numpy as np -import torch from parameterized import parameterized from monai.transforms import DivisiblePad -from tests.utils import TEST_NDARRAYS_ALL +from monai.utils.enums import NumpyPadMode, PytorchPadMode +from tests.padders import PadTest TESTS = [] -for p in TEST_NDARRAYS_ALL: - # pad first dim to be divisible by 7, the second unchanged. - TESTS.append([{"k": (7, -1), "mode": "constant"}, p(np.zeros((3, 8, 7))), p(np.zeros((3, 14, 7)))]) +# pad first dim to be divisible by 7, the second unchanged. 
+TESTS.append([{"k": (7, -1)}, (3, 8, 7), (3, 14, 7)]) +# pad all dimensions to be divisible by 5 +TESTS.append([{"k": 5, "method": "end"}, (3, 10, 5, 17), (3, 10, 5, 20)]) - # pad all dimensions to be divisible by 5 - TESTS.append( - [{"k": 5, "mode": "constant", "method": "end"}, p(np.zeros((3, 10, 5, 17))), p(np.zeros((3, 10, 5, 20)))] - ) +class TestDivisiblePad(PadTest): + Padder = DivisiblePad -class TestDivisiblePad(unittest.TestCase): @parameterized.expand(TESTS) - def test_pad_shape(self, input_param, input_data, expected_val): - padder = DivisiblePad(**input_param) - result = padder(input_data) - self.assertAlmostEqual(result.shape, expected_val.shape) - result = padder(input_data, mode=input_param["mode"]) - self.assertAlmostEqual(result.shape, expected_val.shape) + def test_pad(self, input_param, input_shape, expected_shape): + modes = ["constant", NumpyPadMode.CONSTANT, PytorchPadMode.CONSTANT] + self.pad_test(input_param, input_shape, expected_shape, modes) def test_pad_kwargs(self): - for p in TEST_NDARRAYS_ALL: - input_data = p(np.zeros((3, 8, 4))) - if isinstance(input_data, np.ndarray): - result = DivisiblePad(k=5, mode="constant", constant_values=((0, 0), (1, 1), (2, 2)))(input_data) - np.testing.assert_allclose(result[:, :1, :4], np.ones((3, 1, 4)), rtol=1e-7, atol=0) - else: - result = DivisiblePad(k=5, mode="constant", value=2)(input_data).cpu() - torch.testing.assert_allclose(result[:, :, 4:5], np.ones((3, 10, 1)) + 1, rtol=1e-7, atol=0) + kwargs = {"k": 5, "method": "end"} + unchanged_slices = [slice(None), slice(None, 8), slice(None, 4)] + self.pad_test_kwargs(unchanged_slices, **kwargs) if __name__ == "__main__": diff --git a/tests/test_divisible_padd.py b/tests/test_divisible_padd.py index 61fe917421..93e5a879f0 100644 --- a/tests/test_divisible_padd.py +++ b/tests/test_divisible_padd.py @@ -11,32 +11,25 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import DivisiblePadd +from monai.utils.enums import NumpyPadMode, PytorchPadMode +from tests.padders import PadTest -TEST_CASE_1 = [ - {"keys": ["img"], "k": [4, 3, 2], "mode": "constant"}, - {"img": np.zeros((3, 8, 8, 4))}, - np.zeros((3, 8, 9, 4)), +TESTS = [ + [{"keys": "img", "k": [4, 3, 2]}, (3, 8, 8, 4), (3, 8, 9, 4)], + [{"keys": "img", "k": 7, "method": "end"}, (3, 8, 7), (3, 14, 7)], ] -TEST_CASE_2 = [ - {"keys": ["img"], "k": 7, "mode": "constant", "method": "end"}, - {"img": np.zeros((3, 8, 7))}, - np.zeros((3, 14, 7)), -] - -TEST_CASE_3 = [{"keys": ["img"], "k": 0, "mode": {"constant"}}, {"img": np.zeros((3, 8))}, np.zeros((3, 8))] +class TestDivisiblePadd(PadTest): + Padder = DivisiblePadd -class TestDivisiblePadd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_pad_shape(self, input_param, input_data, expected_val): - padder = DivisiblePadd(**input_param) - result = padder(input_data) - np.testing.assert_allclose(result["img"], expected_val) + @parameterized.expand(TESTS) + def test_pad(self, input_param, input_shape, expected_shape): + modes = ["constant", NumpyPadMode.CONSTANT, PytorchPadMode.CONSTANT, "edge", NumpyPadMode.EDGE] + self.pad_test(input_param, input_shape, expected_shape, modes) if __name__ == "__main__": diff --git a/tests/test_rand_scale_crop.py b/tests/test_rand_scale_crop.py index 58ed65bf0d..aea26d62bb 100644 --- a/tests/test_rand_scale_crop.py +++ b/tests/test_rand_scale_crop.py @@ -15,66 +15,57 @@ from parameterized import parameterized from monai.transforms import RandScaleCrop 
+from tests.croppers import CropTest from tests.utils import TEST_NDARRAYS_ALL, assert_allclose -TEST_CASE_1 = [ - {"roi_scale": [1.0, 1.0, -1.0], "random_center": True}, - np.random.randint(0, 2, size=[3, 3, 3, 4]), - (3, 3, 3, 4), +TEST_SHAPES = [ + [{"roi_scale": [1.0, 1.0, -1.0], "random_center": True}, (3, 3, 3, 4), (3, 3, 3, 4)], + [{"roi_scale": [1.0, 1.0, 1.0], "random_center": False}, (3, 3, 3, 3), (3, 3, 3, 3)], ] -TEST_CASE_2 = [ - {"roi_scale": [1.0, 1.0, 1.0], "random_center": False}, - np.random.randint(0, 2, size=[3, 3, 3, 3]), - (3, 3, 3, 3), +TEST_VALUES = [ + [ + {"roi_scale": [0.6, 0.6], "random_center": False}, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + ] ] -TEST_CASE_3 = [ - {"roi_scale": [0.6, 0.6], "random_center": False}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), +TEST_RANDOM_SHAPES = [ + [ + {"roi_scale": [0.75, 0.6, 0.5], "max_roi_scale": [1.0, -1.0, 0.6], "random_center": True, "random_size": True}, + (1, 4, 5, 6), + (1, 3, 4, 3), + ], + [{"roi_scale": 0.6, "max_roi_scale": 0.8, "random_center": True, "random_size": True}, (1, 4, 5, 6), (1, 3, 4, 4)], + [{"roi_scale": 0.2, "max_roi_scale": 0.8, "random_center": True, "random_size": True}, (1, 4, 5, 6), (1, 3, 2, 4)], ] -TEST_CASE_4 = [ - {"roi_scale": [0.75, 0.6, 0.5], "max_roi_scale": [1.0, -1.0, 0.6], "random_center": True, "random_size": True}, - np.random.randint(0, 2, size=[1, 4, 5, 6]), - (1, 3, 4, 3), -] - -TEST_CASE_5 = [ - {"roi_scale": 0.6, "max_roi_scale": 0.8, "random_center": True, "random_size": True}, - np.random.randint(0, 2, size=[1, 4, 5, 6]), - (1, 3, 4, 4), -] - -TEST_CASE_6 = [ - {"roi_scale": 0.2, "max_roi_scale": 0.8, "random_center": True, "random_size": True}, - np.random.randint(0, 2, size=[1, 4, 5, 6]), - (1, 3, 2, 4), -] +class TestRandScaleCrop(CropTest): + Cropper = RandScaleCrop -class TestRandScaleCrop(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, input_param, input_data, expected_shape): - for p in TEST_NDARRAYS_ALL: - result = RandScaleCrop(**input_param)(p(input_data)) - self.assertTupleEqual(result.shape, expected_shape) + @parameterized.expand(TEST_SHAPES) + def test_shape(self, input_param, input_shape, expected_shape): + self.crop_test(input_param, input_shape, expected_shape) - @parameterized.expand([TEST_CASE_3]) + @parameterized.expand(TEST_VALUES) def test_value(self, input_param, input_data): - for p in TEST_NDARRAYS_ALL: - cropper = RandScaleCrop(**input_param) - result = cropper(p(input_data)) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + cropper = RandScaleCrop(**input_param) + result = cropper(im_type(input_data)) + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) - @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) - def test_random_shape(self, input_param, input_data, expected_shape): - for p in TEST_NDARRAYS_ALL: - cropper = RandScaleCrop(**input_param) - cropper.set_random_state(seed=123) - result = cropper(p(input_data)) - self.assertTupleEqual(result.shape, expected_shape) + @parameterized.expand(TEST_RANDOM_SHAPES) + def test_random_shape(self, input_param, input_shape, 
expected_shape): + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + cropper = RandScaleCrop(**input_param) + cropper.set_random_state(seed=123) + input_data = im_type(np.random.randint(0, 2, input_shape)) + result = cropper(input_data) + self.assertTupleEqual(result.shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_rand_scale_cropd.py b/tests/test_rand_scale_cropd.py index 177cf7fa49..645c058dfb 100644 --- a/tests/test_rand_scale_cropd.py +++ b/tests/test_rand_scale_cropd.py @@ -15,74 +15,77 @@ from parameterized import parameterized from monai.transforms import RandScaleCropd +from tests.croppers import CropTest from tests.utils import TEST_NDARRAYS_ALL, assert_allclose -TEST_CASE_1 = [ - {"keys": "img", "roi_scale": [1.0, 1.0, -1.0], "random_center": True}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 4])}, - (3, 3, 3, 4), +TEST_SHAPES = [ + [{"keys": "img", "roi_scale": [1.0, 1.0, -1.0], "random_center": True}, (3, 3, 3, 4), (3, 3, 3, 4)], + [ + # test `allow_missing_keys` with key "label" + {"keys": ["label", "img"], "roi_scale": [1.0, 1.0, 1.0], "random_center": False, "allow_missing_keys": True}, + (3, 3, 3, 3), + (3, 3, 3, 3), + ], ] -TEST_CASE_2 = [ - # test `allow_missing_keys` with key "label" - {"keys": ["img", "label"], "roi_scale": [1.0, 1.0, 1.0], "random_center": False, "allow_missing_keys": True}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 3, 3, 3), +TEST_VALUES = [ + [ + {"keys": "img", "roi_scale": [0.6, 0.6], "random_center": False}, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + ] ] -TEST_CASE_3 = [ - {"keys": "img", "roi_scale": [0.6, 0.6], "random_center": False}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])}, +TEST_RANDOM_SHAPES = [ + [ + { + "keys": "img", + "roi_scale": [0.75, 0.6, 0.5], + "max_roi_scale": [1.0, -1.0, 0.6], + "random_center": True, + "random_size": True, + }, + (1, 4, 5, 6), + (1, 3, 4, 3), + ], + [ + {"keys": "img", "roi_scale": 0.6, "max_roi_scale": 0.8, "random_center": True, "random_size": True}, + (1, 4, 5, 6), + (1, 3, 4, 4), + ], + [ + {"keys": "img", "roi_scale": 0.2, "max_roi_scale": 0.8, "random_center": True, "random_size": True}, + (1, 4, 5, 6), + (1, 3, 2, 4), + ], ] -TEST_CASE_4 = [ - { - "keys": "img", - "roi_scale": [0.75, 0.6, 0.5], - "max_roi_scale": [1.0, -1.0, 0.6], - "random_center": True, - "random_size": True, - }, - {"img": np.random.randint(0, 2, size=[1, 4, 5, 6])}, - (1, 3, 4, 3), -] - -TEST_CASE_5 = [ - {"keys": "img", "roi_scale": 0.6, "max_roi_scale": 0.8, "random_center": True, "random_size": True}, - {"img": np.random.randint(0, 2, size=[1, 4, 5, 6])}, - (1, 3, 4, 4), -] - -TEST_CASE_6 = [ - {"keys": "img", "roi_scale": 0.2, "max_roi_scale": 0.8, "random_center": True, "random_size": True}, - {"img": np.random.randint(0, 2, size=[1, 4, 5, 6])}, - (1, 3, 2, 4), -] +class TestRandScaleCropd(CropTest): + Cropper = RandScaleCropd -class TestRandScaleCropd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, input_param, input_data, expected_shape): - result = RandScaleCropd(**input_param)(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + @parameterized.expand(TEST_SHAPES) + def test_shape(self, input_param, input_shape, expected_shape): + self.crop_test(input_param, input_shape, expected_shape) - @parameterized.expand([TEST_CASE_3]) - def test_value(self, input_param, 
input_data): - for p in TEST_NDARRAYS_ALL: - cropper = RandScaleCropd(**input_param) - input_data["img"] = p(input_data["img"]) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper.cropper._size] - assert_allclose( - result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False - ) + @parameterized.expand(TEST_VALUES) + def test_value(self, input_param, input_im): + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + cropper = self.Cropper(**input_param) + input_data = {"img": im_type(input_im)} + result = cropper(input_data)["img"] + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper.cropper._size] + assert_allclose(result, input_im[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) - @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) - def test_random_shape(self, input_param, input_data, expected_shape): - cropper = RandScaleCropd(**input_param) - cropper.set_random_state(seed=123) - result = cropper(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + @parameterized.expand(TEST_RANDOM_SHAPES) + def test_random_shape(self, input_param, input_shape, expected_shape): + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + cropper = self.Cropper(**input_param) + cropper.set_random_state(seed=123) + input_data = {"img": im_type(np.random.randint(0, 2, input_shape))} + result = cropper(input_data)["img"] + self.assertTupleEqual(result.shape, expected_shape) if __name__ == "__main__": From bf108554652fe10be5d8a66b5f7905a1aea292df Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 28 Jun 2022 15:11:55 +0800 Subject: [PATCH 39/47] [DLMED] update samples crop Signed-off-by: Nic Ma --- tests/test_rand_spatial_crop.py | 79 +++--- tests/test_rand_spatial_crop_samples.py | 21 +- tests/test_rand_spatial_cropd.py | 93 ++++---- tests/test_rand_weighted_crop.py | 11 +- tests/test_rand_weighted_cropd.py | 305 +++++++++++------------- 5 files changed, 242 insertions(+), 267 deletions(-) diff --git a/tests/test_rand_spatial_crop.py b/tests/test_rand_spatial_crop.py index 5521ede350..0c8d4ab132 100644 --- a/tests/test_rand_spatial_crop.py +++ b/tests/test_rand_spatial_crop.py @@ -15,60 +15,57 @@ from parameterized import parameterized from monai.transforms import RandSpatialCrop +from tests.croppers import CropTest from tests.utils import TEST_NDARRAYS_ALL, assert_allclose -TEST_CASE_0 = [ - {"roi_size": [3, 3, -1], "random_center": True}, - np.random.randint(0, 2, size=[3, 3, 3, 4]), - (3, 3, 3, 4), +TEST_SHAPES = [ + [{"roi_size": [3, 3, -1], "random_center": True}, (3, 3, 3, 4), (3, 3, 3, 4)], + [{"roi_size": [3, 3, 3], "random_center": True}, (3, 3, 3, 3), (3, 3, 3, 3)], + [{"roi_size": [3, 3, 3], "random_center": False}, (3, 3, 3, 3), (3, 3, 3, 3)], ] -TEST_CASE_1 = [{"roi_size": [3, 3, 3], "random_center": True}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 3, 3, 3)] - -TEST_CASE_2 = [ - {"roi_size": [3, 3, 3], "random_center": False}, - np.random.randint(0, 2, size=[3, 3, 3, 3]), - (3, 3, 3, 3), -] - -TEST_CASE_3 = [ - {"roi_size": [3, 3], "random_center": False}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), +TEST_VALUES = [ + [ + {"roi_size": [3, 3], "random_center": False}, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + ] ] -TEST_CASE_4 = [ - {"roi_size": [3, 3, 3], "max_roi_size": [5, -1, 4], "random_center": True, 
"random_size": True}, - np.random.randint(0, 2, size=[1, 4, 5, 6]), - (1, 4, 4, 3), +TEST_RANDOM_SHAPES = [ + [ + {"roi_size": [3, 3, 3], "max_roi_size": [5, -1, 4], "random_center": True, "random_size": True}, + (1, 4, 5, 6), + (1, 4, 4, 3), + ], + [{"roi_size": 3, "max_roi_size": 4, "random_center": True, "random_size": True}, (1, 4, 5, 6), (1, 3, 4, 3)], ] -TEST_CASE_5 = [ - {"roi_size": 3, "max_roi_size": 4, "random_center": True, "random_size": True}, - np.random.randint(0, 2, size=[1, 4, 5, 6]), - (1, 3, 4, 3), -] +class TestRandSpatialCrop(CropTest): + Cropper = RandSpatialCrop -class TestRandSpatialCrop(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, input_param, input_data, expected_shape): - result = RandSpatialCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) + @parameterized.expand(TEST_SHAPES) + def test_shape(self, input_param, input_shape, expected_shape): + self.crop_test(input_param, input_shape, expected_shape) - @parameterized.expand([TEST_CASE_3]) + @parameterized.expand(TEST_VALUES) def test_value(self, input_param, input_data): - for p in TEST_NDARRAYS_ALL: - cropper = RandSpatialCrop(**input_param) - result = cropper(p(input_data)) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + cropper = RandSpatialCrop(**input_param) + result = cropper(im_type(input_data)) + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) - @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) - def test_random_shape(self, input_param, input_data, expected_shape): - cropper = RandSpatialCrop(**input_param) - cropper.set_random_state(seed=123) - result = cropper(input_data) - self.assertTupleEqual(result.shape, expected_shape) + @parameterized.expand(TEST_RANDOM_SHAPES) + def test_random_shape(self, input_param, input_shape, expected_shape): + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + cropper = RandSpatialCrop(**input_param) + cropper.set_random_state(seed=123) + input_data = im_type(np.random.randint(0, 2, input_shape)) + result = cropper(input_data) + self.assertTupleEqual(result.shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_rand_spatial_crop_samples.py b/tests/test_rand_spatial_crop_samples.py index de96908cc6..537c7b4e4e 100644 --- a/tests/test_rand_spatial_crop_samples.py +++ b/tests/test_rand_spatial_crop_samples.py @@ -15,11 +15,12 @@ from parameterized import parameterized from monai.transforms import RandSpatialCropSamples +from tests.croppers import CropTest from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_CASE_1 = [ {"roi_size": [3, 3, 3], "num_samples": 4, "random_center": True, "random_size": False}, - np.arange(192).reshape(3, 4, 4, 4), + (3, 4, 4, 4), [(3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3)], np.array( [ @@ -44,7 +45,7 @@ TEST_CASE_2 = [ {"roi_size": [3, 3, 3], "num_samples": 8, "random_center": False, "random_size": True}, - np.arange(192).reshape(3, 4, 4, 4), + (3, 4, 4, 4), [(3, 4, 4, 3), (3, 4, 3, 3), (3, 3, 4, 4), (3, 4, 4, 4), (3, 3, 3, 4), (3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3)], np.array( [ @@ -67,10 +68,22 @@ ), ] +TEST_INVERSE_LIST = [ + [(1, 2, 2), {"roi_size": (1, 1), "num_samples": 4, 
"random_size": False}], + [(1, 3, 2), {"roi_size": (1, 1), "num_samples": 100, "random_size": False}], + [(3, 10, 11, 12), {"roi_size": (3, 5, 4), "num_samples": 7, "random_size": False}], + [(3, 10, 11, 12), {"roi_size": (10, 11, 12), "num_samples": 3, "random_size": False}], + [(3, 10, 11, 12), {"roi_size": (3, 4, 5), "num_samples": 100, "random_size": False}], +] + + +class TestRandSpatialCropSamples(CropTest): + Cropper = RandSpatialCropSamples -class TestRandSpatialCropSamples(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, input_param, input_data, expected_shape, expected_last_item): + def test_shape(self, input_param, input_shape, expected_shape, expected_last_item): + input_data = np.arange(192).reshape(*input_shape) + for p in TEST_NDARRAYS_ALL: xform = RandSpatialCropSamples(**input_param) xform.set_random_state(1234) diff --git a/tests/test_rand_spatial_cropd.py b/tests/test_rand_spatial_cropd.py index 29d988562f..c6a0fbe5e7 100644 --- a/tests/test_rand_spatial_cropd.py +++ b/tests/test_rand_spatial_cropd.py @@ -15,65 +15,62 @@ from parameterized import parameterized from monai.transforms import RandSpatialCropd -from tests.utils import TEST_NDARRAYS_ALL +from tests.croppers import CropTest +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose -TEST_CASE_0 = [ - {"keys": "img", "roi_size": [3, 3, -1], "random_center": True}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 5])}, - (3, 3, 3, 5), +TEST_SHAPES = [ + [{"keys": "img", "roi_size": [3, 3, -1], "random_center": True}, (3, 3, 3, 5), (3, 3, 3, 5)], + [{"keys": "img", "roi_size": [3, 3, 3], "random_center": True}, (3, 3, 3, 3), (3, 3, 3, 3)], + [{"keys": "img", "roi_size": [3, 3, 3], "random_center": False}, (3, 3, 3, 3), (3, 3, 3, 3)], ] -TEST_CASE_1 = [ - {"keys": "img", "roi_size": [3, 3, 3], "random_center": True}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 3, 3, 3), +TEST_VALUES = [ + [ + {"keys": "img", "roi_size": [3, 3], "random_center": False}, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + ] ] -TEST_CASE_2 = [ - {"keys": "img", "roi_size": [3, 3, 3], "random_center": False}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 3, 3, 3), +TEST_RANDOM_SHAPES = [ + [ + {"keys": "img", "roi_size": [3, 3, 3], "max_roi_size": [5, -1, 4], "random_center": True, "random_size": True}, + (1, 4, 5, 6), + (1, 4, 4, 3), + ], + [ + {"keys": "img", "roi_size": 3, "max_roi_size": 4, "random_center": True, "random_size": True}, + (1, 4, 5, 6), + (1, 3, 4, 3), + ], ] -TEST_CASE_3 = [ - {"keys": "img", "roi_size": [3, 3], "random_center": False}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])}, -] - -TEST_CASE_4 = [ - {"keys": "img", "roi_size": [3, 3, 3], "max_roi_size": [5, -1, 4], "random_center": True, "random_size": True}, - {"img": np.random.randint(0, 2, size=[1, 4, 5, 6])}, - (1, 4, 4, 3), -] - -TEST_CASE_5 = [ - {"keys": "img", "roi_size": 3, "max_roi_size": 4, "random_center": True, "random_size": True}, - {"img": np.random.randint(0, 2, size=[1, 4, 5, 6])}, - (1, 3, 4, 3), -] +class TestRandSpatialCropd(CropTest): + Cropper = RandSpatialCropd -class TestRandSpatialCropd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, input_param, input_data, expected_shape): - result = RandSpatialCropd(**input_param)(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) 
+ @parameterized.expand(TEST_SHAPES) + def test_shape(self, input_param, input_shape, expected_shape): + self.crop_test(input_param, input_shape, expected_shape) - @parameterized.expand([TEST_CASE_3]) - def test_value(self, input_param, input_data): - cropper = RandSpatialCropd(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper.cropper._size] - np.testing.assert_allclose(result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + @parameterized.expand(TEST_VALUES) + def test_value(self, input_param, input_im): + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + cropper = self.Cropper(**input_param) + input_data = {"img": im_type(input_im)} + result = cropper(input_data)["img"] + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper.cropper._size] + assert_allclose(result, input_im[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) - @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) - def test_random_shape(self, input_param, input_data, expected_shape): - for p in TEST_NDARRAYS_ALL: - cropper = RandSpatialCropd(**input_param) - cropper.set_random_state(seed=123) - input_data["img"] = p(input_data["img"]) - result = cropper(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + @parameterized.expand(TEST_RANDOM_SHAPES) + def test_random_shape(self, input_param, input_shape, expected_shape): + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + cropper = self.Cropper(**input_param) + cropper.set_random_state(seed=123) + input_data = {"img": im_type(np.random.randint(0, 2, input_shape))} + result = cropper(input_data)["img"] + self.assertTupleEqual(result.shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_rand_weighted_crop.py b/tests/test_rand_weighted_crop.py index 952aff8327..696de9c05e 100644 --- a/tests/test_rand_weighted_crop.py +++ b/tests/test_rand_weighted_crop.py @@ -12,11 +12,11 @@ import unittest import numpy as np -import torch from parameterized.parameterized import parameterized from monai.data.meta_tensor import MetaTensor from monai.transforms.croppad.array import RandWeightedCrop +from tests.croppers import CropTest from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose @@ -149,7 +149,9 @@ def get_data(ndim): ) -class TestRandWeightedCrop(unittest.TestCase): +class TestRandWeightedCrop(CropTest): + Cropper = RandWeightedCrop + @parameterized.expand(TESTS) def test_rand_weighted_crop(self, _, input_params, img, weight, expected_shape, expected_vals): crop = RandWeightedCrop(**input_params) @@ -162,10 +164,9 @@ def test_rand_weighted_crop(self, _, input_params, img, weight, expected_shape, # if desired ROI is larger than image, check image is unchanged if all(s >= i for i, s in zip(img.shape[1:], input_params["spatial_size"])): for res in result: - self.assertEqual(MetaTensor, type(res)) - if isinstance(img, torch.Tensor): - self.assertEqual(res.device, img.device) + self.assertIsInstance(res, MetaTensor) assert_allclose(res, img, type_test=False) + self.assertEqual(len(res.applied_operations), 1) if __name__ == "__main__": diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py index cb46a892f5..b3fc92b445 100644 --- a/tests/test_rand_weighted_cropd.py +++ b/tests/test_rand_weighted_cropd.py @@ -12,177 +12,144 @@ import unittest import numpy as np +from parameterized import parameterized from monai.transforms.croppad.dictionary 
import RandWeightedCropd -from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose - - -class TestRandWeightedCrop(NumpyImageTestCase2D): - def test_rand_weighted_crop_small_roi(self): - for p in TEST_NDARRAYS_ALL: - for q in TEST_NDARRAYS_ALL: - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - d = {"img": p(img), "w": q(weight)} - result = crop(d) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 12)) - for c, e in zip(crop.cropper.centers, [[80, 21], [30, 17], [40, 31]]): - assert_allclose(c, e, type_test=False) - - def test_rand_weighted_crop_default_roi(self): - for p in TEST_NDARRAYS_ALL: - for q in TEST_NDARRAYS_ALL: - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd("im", "weight", (10, -1), n_samples, "coords") - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - data = {"im": p(img), "weight": q(weight), "others": np.nan} - result = crop(data) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["im"].shape, (1, 10, 64)) - for c, e in zip(crop.cropper.centers, [[14, 32], [105, 32], [20, 32]]): - assert_allclose(c, e, type_test=False) - assert_allclose(result[1]["im"].meta["crop_center"], [105, 32], type_test=False) - - def test_rand_weighted_crop_large_roi(self): - for p in TEST_NDARRAYS_ALL: - for q in TEST_NDARRAYS_ALL: - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "weight", (10000, 400), n_samples, "location") - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 10, 1] = 1 - crop.set_random_state(10) - data = {"img": p(img), "seg": p(self.imt[0]), "weight": q(weight)} - result = crop(data) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 128, 64)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 128, 64)) - for c, e in zip(crop.cropper.centers, [[64, 32], [64, 32], [64, 32]]): - assert_allclose(c, e, type_test=False) - assert_allclose(result[1]["img"].meta["crop_center"], [64, 32], type_test=False) - - def test_rand_weighted_crop_bad_w(self): - for p in TEST_NDARRAYS_ALL: - for q in TEST_NDARRAYS_ALL: - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (20, 40), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = np.inf - weight[0, 10, 1] = -np.inf - weight[0, 10, 20] = -np.nan - crop.set_random_state(10) - result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 20, 40)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 20, 40)) - for c, e in zip(crop.cropper.centers, [[63, 37], [31, 43], [66, 20]]): - assert_allclose(c, e, type_test=False) - - -class TestRandWeightedCrop3D(NumpyImageTestCase3D): - def test_rand_weighted_crop_small_roi(self): - for p in TEST_NDARRAYS_ALL: - for q in TEST_NDARRAYS_ALL: - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (8, 10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 5, 30, 17] = 1.1 - weight[0, 8, 40, 31] = 1 - weight[0, 11, 23, 21] = 1 - crop.set_random_state(10) - result = crop({"img": p(img), "w": q(weight)}) - 
self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 8, 10, 12)) - for c, e in zip(crop.cropper.centers, [[11, 23, 21], [5, 30, 17], [8, 40, 31]]): - assert_allclose(c, e, type_test=False) - - def test_rand_weighted_crop_default_roi(self): - for p in TEST_NDARRAYS_ALL: - for q in TEST_NDARRAYS_ALL: - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) - weight = np.zeros_like(img) - weight[0, 7, 17] = 1.1 - weight[0, 13, 31] = 1.1 - weight[0, 24, 21] = 1 - crop.set_random_state(10) - result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 10, 64, 80)) - for c, e in zip(crop.cropper.centers, [[14, 32, 40], [41, 32, 40], [20, 32, 40]]): - assert_allclose(c, e, type_test=False) - - def test_rand_weighted_crop_large_roi(self): - for p in TEST_NDARRAYS_ALL: - for q in TEST_NDARRAYS_ALL: - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (10000, 400, 80), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17, 20] = 1.1 - weight[0, 10, 1, 17] = 1 - crop.set_random_state(10) - result = crop({"img": p(img), "w": q(weight)}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) - for c, e in zip(crop.cropper.centers, [[24, 32, 40], [24, 32, 40], [24, 32, 40]]): - assert_allclose(c, e, type_test=False) - - def test_rand_weighted_crop_bad_w(self): - for p in TEST_NDARRAYS_ALL: - for q in TEST_NDARRAYS_ALL: - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (48, 64, 80), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = np.inf - weight[0, 10, 1] = -np.inf - weight[0, 10, 20] = -np.nan - crop.set_random_state(10) - result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 48, 64, 80)) - for c, e in zip(crop.cropper.centers, [[24, 32, 40], [24, 32, 40], [24, 32, 40]]): - assert_allclose(c, e, type_test=False) - - def test_rand_weighted_crop_patch_index(self): - for p in TEST_NDARRAYS_ALL: - for q in TEST_NDARRAYS_ALL: - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) - weight = np.zeros_like(img) - weight[0, 7, 17] = 1.1 - weight[0, 13, 31] = 1.1 - weight[0, 24, 21] = 1 - crop.set_random_state(10) - result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) - self.assertTrue(len(result) == n_samples) - for c, e in zip(crop.cropper.centers, [[14, 32, 40], [41, 32, 40], [20, 32, 40]]): - assert_allclose(c, e, type_test=False) - for i in range(n_samples): - np.testing.assert_allclose(result[i]["img"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[i]["seg"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[i]["img"].meta["patch_index"], i) - np.testing.assert_allclose(result[i]["seg"].meta["patch_index"], i) +from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D + + +def get_data(ndim): + im_gen = NumpyImageTestCase2D() if ndim == 2 else NumpyImageTestCase3D() + im_gen.setUp() + return im_gen.imt[0], im_gen.seg1[0], im_gen.segn[0] + + +IMT_2D, SEG1_2D, SEGN_2D = get_data(ndim=2) +IMT_3D, 
SEG1_3D, SEGN_3D = get_data(ndim=3)
+
+TESTS = []
+for p in TEST_NDARRAYS_ALL:
+    for q in TEST_NDARRAYS_ALL:
+        im = IMT_2D
+        weight = np.zeros_like(im)
+        weight[0, 30, 17] = 1.1
+        weight[0, 40, 31] = 1
+        weight[0, 80, 21] = 1
+        TESTS.append(
+            [
+                "small roi 2d",
+                dict(keys="img", w_key="w", spatial_size=(10, 12), num_samples=3),
+                {"img": p(im), "w": q(weight)},
+                (1, 10, 12),
+                [[80, 21], [30, 17], [40, 31]],
+            ]
+        )
+
+        weight = np.zeros_like(im)
+        weight[0, 30, 17] = 1.1
+        weight[0, 40, 31] = 1
+        weight[0, 80, 21] = 1
+        TESTS.append(
+            [
+                "default roi 2d",
+                dict(keys="img", w_key="w", spatial_size=(10, -1), num_samples=3),
+                {"img": p(im), "w": q(weight), "others": np.nan},
+                (1, 10, 64),
+                [[14, 32], [105, 32], [20, 32]],
+            ]
+        )
+
+        weight = np.zeros_like(im)
+        weight[0, 30, 17] = 1.1
+        weight[0, 10, 1] = 1
+        TESTS.append(
+            [
+                "large roi 2d",
+                dict(keys=("img", "seg"), w_key="weight", spatial_size=(10000, 400), num_samples=3),
+                {"img": p(im), "seg": p(SEGN_2D), "weight": q(weight)},
+                (1, 128, 64),
+                [[64, 32], [64, 32], [64, 32]],
+            ]
+        )
+
+        weight = np.zeros_like(im)
+        weight[0, 30, 17] = np.inf
+        weight[0, 10, 1] = -np.inf
+        weight[0, 10, 20] = -np.nan
+        TESTS.append(
+            [
+                "bad w roi 2d",
+                dict(keys=("img", "seg"), w_key="w", spatial_size=(20, 40), num_samples=3),
+                {"img": p(im), "seg": p(SEGN_2D), "w": q(weight)},
+                (1, 20, 40),
+                [[63, 37], [31, 43], [66, 20]],
+            ]
+        )
+
+        im = IMT_3D
+        weight = np.zeros_like(im)
+        weight[0, 5, 30, 17] = 1.1
+        weight[0, 8, 40, 31] = 1
+        weight[0, 11, 23, 21] = 1
+        TESTS.append(
+            [
+                "small roi 3d",
+                dict(keys="img", w_key="w", spatial_size=(8, 10, 12), num_samples=3),
+                {"img": p(im), "w": q(weight)},
+                (1, 8, 10, 12),
+                [[11, 23, 21], [5, 30, 17], [8, 40, 31]],
+            ]
+        )
+
+        weight = np.zeros_like(im)
+        weight[0, 5, 30, 17] = 1.1
+        weight[0, 8, 40, 31] = 1
+        weight[0, 11, 23, 21] = 1
+        TESTS.append(
+            [
+                "default roi 3d",
+                dict(keys=("img", "seg"), w_key="w", spatial_size=(10, -1, -1), num_samples=3),
+                {"img": p(im), "seg": p(SEGN_3D), "w": q(weight)},
+                (1, 10, 64, 80),
+                [[14, 32, 40], [41, 32, 40], [20, 32, 40]],
+            ]
+        )
+
+        weight = np.zeros_like(im)
+        weight[0, 30, 17, 20] = 1.1
+        weight[0, 10, 1, 17] = 1
+        TESTS.append(
+            [
+                "large roi 3d",
+                dict(keys="img", w_key="w", spatial_size=(10000, 400, 80), num_samples=3),
+                {"img": p(im), "w": q(weight)},
+                (1, 48, 64, 80),
+                [[24, 32, 40], [24, 32, 40], [24, 32, 40]],
+            ]
+        )
+
+        weight = np.zeros_like(im)
+        weight[0, 30, 17] = np.inf
+        weight[0, 10, 1] = -np.inf
+        weight[0, 10, 20] = -np.nan
+        TESTS.append(
+            [
+                "bad w roi 3d",
+                dict(keys=("img", "seg"), w_key="w", spatial_size=(48, 64, 80), num_samples=3),
+                {"img": p(im), "seg": p(SEGN_3D), "w": q(weight)},
+                (1, 48, 64, 80),
+                [[24, 32, 40], [24, 32, 40], [24, 32, 40]],
+            ]
+        )
+
+
+class TestRandWeightedCrop(unittest.TestCase):
+    @parameterized.expand(TESTS)
+    def test_rand_weighted_cropd(self, _, init_params, input_data, expected_shape, expected_centers):
+        crop = RandWeightedCropd(**init_params)
+        crop.set_random_state(10)
+        result = crop(input_data)
+        self.assertTrue(len(result) == init_params["num_samples"])
+        # check the cropped shape and the sampled centers against the expected values
+        np.testing.assert_allclose(result[0]["img"].shape, expected_shape)
+        for c, e in zip(crop.cropper.centers, expected_centers):
+            np.testing.assert_allclose([int(i) for i in c], e)
 
 
 if __name__ == "__main__":

From c43a87df8ec6782bb1c0e25e764fb8c7d8c44304 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Tue, 28 Jun 2022 15:26:00 +0800
Subject: [PATCH 40/47] [DLMED] update crop tests

Signed-off-by: Nic Ma
---
 tests/test_resize_with_pad_or_crop.py  | 27 ++++++--
 tests/test_resize_with_pad_or_cropd.py | 21 +++++-
 tests/test_spatial_crop.py             | 37 +++++-----
 tests/test_spatial_cropd.py            | 87
+++++++++++++------------- tests/test_spatial_pad.py | 83 +++--------------------- tests/test_spatial_padd.py | 40 ++++-------- 6 files changed, 119 insertions(+), 176 deletions(-) diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py index 2eb39bfe4d..4e097cd3d4 100644 --- a/tests/test_resize_with_pad_or_crop.py +++ b/tests/test_resize_with_pad_or_crop.py @@ -15,8 +15,9 @@ import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import ResizeWithPadOrCrop -from tests.utils import TEST_NDARRAYS_ALL +from tests.utils import TEST_NDARRAYS_ALL, pytorch_after TEST_CASES = [ [{"spatial_size": [15, 8, 8], "mode": "constant"}, (3, 8, 8, 4), (3, 15, 8, 8)], @@ -26,8 +27,16 @@ (3, 15, 4, 8), ], [{"spatial_size": [15, 4, -1], "mode": "constant"}, (3, 8, 8, 4), (3, 15, 4, 4)], - [{"spatial_size": [15, 4, -1], "mode": "reflect"}, (3, 8, 8, 4), (3, 15, 4, 4)], - [{"spatial_size": [-1, -1, -1], "mode": "reflect"}, (3, 8, 8, 4), (3, 8, 8, 4)], + [ + {"spatial_size": [15, 4, -1], "mode": "reflect" if pytorch_after(1, 11) else "constant"}, + (3, 8, 8, 4), + (3, 15, 4, 4), + ], + [ + {"spatial_size": [-1, -1, -1], "mode": "reflect" if pytorch_after(1, 11) else "constant"}, + (3, 8, 8, 4), + (3, 8, 8, 4), + ], ] @@ -39,11 +48,17 @@ def test_pad_shape(self, input_param, input_shape, expected_shape): "constant_values" in input_param or input_param["mode"] == "reflect" ): continue - paddcroper = ResizeWithPadOrCrop(**input_param) - result = paddcroper(p(np.zeros(input_shape))) + padcropper = ResizeWithPadOrCrop(**input_param) + result = padcropper(p(np.zeros(input_shape))) np.testing.assert_allclose(result.shape, expected_shape) - result = paddcroper(p(np.zeros(input_shape)), mode="constant") + result = padcropper(p(np.zeros(input_shape)), mode="constant") np.testing.assert_allclose(result.shape, expected_shape) + self.assertIsInstance(result, MetaTensor) + self.assertEqual(len(result.applied_operations), 1) + inv = padcropper.inverse(result) + self.assertTupleEqual(inv.shape, input_shape) + self.assertIsInstance(inv, MetaTensor) + self.assertEqual(inv.applied_operations, []) if __name__ == "__main__": diff --git a/tests/test_resize_with_pad_or_cropd.py b/tests/test_resize_with_pad_or_cropd.py index 6658c76386..eb4e5f09cc 100644 --- a/tests/test_resize_with_pad_or_cropd.py +++ b/tests/test_resize_with_pad_or_cropd.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import ResizeWithPadOrCropd -from tests.utils import TEST_NDARRAYS_ALL +from tests.utils import TEST_NDARRAYS_ALL, pytorch_after TEST_CASES = [ [{"keys": "img", "spatial_size": [15, 8, 8], "mode": "constant"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 8, 8)], @@ -26,8 +26,16 @@ (3, 15, 4, 8), ], [{"keys": "img", "spatial_size": [15, 4, -1], "mode": "constant"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 4, 4)], - [{"keys": "img", "spatial_size": [15, 4, -1], "mode": "reflect"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 4, 4)], - [{"keys": "img", "spatial_size": [-1, -1, -1], "mode": "reflect"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 8, 8, 4)], + [ + {"keys": "img", "spatial_size": [15, 4, -1], "mode": "reflect" if pytorch_after(1, 11) else "constant"}, + {"img": np.zeros((3, 8, 8, 4))}, + (3, 15, 4, 4), + ], + [ + {"keys": "img", "spatial_size": [-1, -1, -1], "mode": "reflect" if pytorch_after(1, 11) else "constant"}, + {"img": np.zeros((3, 8, 8, 4))}, + (3, 8, 8, 4), + ], ] @@ -39,10 +47,13 @@ def test_pad_shape(self, 
input_param, input_data, expected_val): "constant_values" in input_param or input_param["mode"] == "reflect" ): continue - paddcroper = ResizeWithPadOrCropd(**input_param) + padcropper = ResizeWithPadOrCropd(**input_param) input_data["img"] = p(input_data["img"]) - result = paddcroper(input_data) + result = padcropper(input_data) np.testing.assert_allclose(result["img"].shape, expected_val) + inv = padcropper.inverse(result) + for k in input_data: + self.assertTupleEqual(inv[k].shape, input_data[k].shape) if __name__ == "__main__": diff --git a/tests/test_spatial_crop.py b/tests/test_spatial_crop.py index ebf4665a23..6fdfbd3f70 100644 --- a/tests/test_spatial_crop.py +++ b/tests/test_spatial_crop.py @@ -11,12 +11,10 @@ import unittest -import numpy as np -import torch from parameterized import parameterized from monai.transforms import SpatialCrop -from tests.utils import TEST_NDARRAYS_ALL, assert_allclose +from tests.croppers import CropTest TESTS = [ [{"roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], @@ -26,31 +24,28 @@ [{"roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], [{"roi_start": [0, 0, 0, 0, 0], "roi_end": [8, 8, 8, 2, 2]}, (3, 3, 3, 3), (3, 3, 3, 3)], [{"roi_start": [1, 0, 0], "roi_end": [1, 8, 8]}, (3, 3, 3, 3), (3, 0, 3, 3)], - [{"roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, (3, 3, 3, 3), (3, 1, 2, 2)], + [ + {"roi_slices": [slice(s, e) for s, e in zip([None, None, None], [None, None, None])]}, + (3, 11, 12, 15), + (3, 11, 12, 15), + ], + [{"roi_slices": [slice(s, e) for s, e in zip([1, None, 0], [None, None, None])]}, (3, 7, 9, 11), (3, 6, 9, 11)], + [{"roi_slices": [slice(s, e) for s, e in zip([0, None, None], [-1, None, None])]}, (3, 7, 9, 11), (3, 6, 9, 11)], + [{"roi_slices": [slice(s, e) for s, e in zip([1, None, None], [None, None, None])]}, (3, 10, 8, 6), (3, 9, 8, 6)], + [{"roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, (3, 15, 17, 8), (3, 1, 2, 2)], + [{"roi_slices": [slice(s, e) for s, e in zip([None, None, None], [-2, -1, 2])]}, (3, 13, 8, 6), (3, 11, 7, 2)], + [{"roi_start": [-1, 0], "roi_end": [5, 5]}, (1, 5, 5), (1, 5, 5)], ] TEST_ERRORS = [[{"roi_slices": [slice(s, e, 2) for s, e in zip([-1, -2, 0], [None, None, 2])]}]] -class TestSpatialCrop(unittest.TestCase): +class TestSpatialCrop(CropTest): + Cropper = SpatialCrop + @parameterized.expand(TESTS) def test_shape(self, input_param, input_shape, expected_shape): - input_data = np.random.randint(0, 2, size=input_shape) - results = [] - for p in TEST_NDARRAYS_ALL: - for q in TEST_NDARRAYS_ALL + (None,): - input_param_mod = { - k: q(v) if k != "roi_slices" and q is not None else v for k, v in input_param.items() - } - im = p(input_data) - result = SpatialCrop(**input_param_mod)(im) - self.assertTrue(isinstance(result, torch.Tensor)) - if isinstance(im, torch.Tensor): - self.assertEqual(result.device, im.device) - self.assertTupleEqual(result.shape, expected_shape) - results.append(result) - if len(results) > 1: - assert_allclose(results[0], results[-1], type_test=False) + self.crop_test(input_param, input_shape, expected_shape) @parameterized.expand(TEST_ERRORS) def test_error(self, input_param): diff --git a/tests/test_spatial_cropd.py b/tests/test_spatial_cropd.py index 87f11a106d..11f6da0811 100644 --- a/tests/test_spatial_cropd.py +++ b/tests/test_spatial_cropd.py @@ -11,56 +11,57 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import 
SpatialCropd -from tests.utils import TEST_NDARRAYS_ALL +from tests.croppers import CropTest -TESTS = [] -for p in TEST_NDARRAYS_ALL: - TESTS.append( - [ - {"keys": ["img"], "roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, - {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, - (3, 2, 2, 2), - ] - ) - TESTS.append( - [ - {"keys": ["img"], "roi_start": [0, 0, 0], "roi_end": [2, 2, 2]}, - {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, - (3, 2, 2, 2), - ] - ) - TESTS.append( - [ - {"keys": ["img"], "roi_start": [0, 0], "roi_end": [2, 2]}, - {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, - (3, 2, 2, 3), - ] - ) - TESTS.append( - [ - {"keys": ["img"], "roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, - {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, - (3, 2, 2, 2), - ] - ) - TESTS.append( - [ - {"keys": ["img"], "roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, - {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, - (3, 1, 2, 2), - ] - ) +TESTS = [ + [ + {"keys": ["img"], "roi_center": [1, 1], "roi_size": [2, 2]}, + (1, 3, 3), + (1, 2, 2), + (slice(None), slice(None, 2), slice(None, 2)), + ], + [ + {"keys": ["img"], "roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, + (3, 3, 3, 3), + (3, 2, 2, 2), + (slice(None), slice(None, 2), slice(None, 2), slice(None, 2)), + ], + [ + {"keys": ["img"], "roi_start": [0, 0, 0], "roi_end": [2, 2, 2]}, + (3, 3, 3, 3), + (3, 2, 2, 2), + (slice(None), slice(None, 2), slice(None, 2), slice(None, 2)), + ], + [ + {"keys": ["img"], "roi_start": [0, 0], "roi_end": [2, 2]}, + (3, 3, 3, 3), + (3, 2, 2, 3), + (slice(None), slice(None, 2), slice(None, 2), slice(None)), + ], + [ + {"keys": ["img"], "roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, + (3, 3, 3, 3), + (3, 2, 2, 2), + (slice(None), slice(None, 2), slice(None, 2), slice(None, 2)), + ], + [ + {"keys": ["img"], "roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, + (3, 3, 3, 3), + (3, 1, 2, 2), + (slice(None), slice(-1, None), slice(-2, None), slice(0, 2)), + ], +] -class TestSpatialCropd(unittest.TestCase): +class TestSpatialCropd(CropTest): + Cropper = SpatialCropd + @parameterized.expand(TESTS) - def test_shape(self, input_param, input_data, expected_shape): - result = SpatialCropd(**input_param)(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + def test_shape(self, input_param, input_shape, expected_shape, same_area): + self.crop_test(input_param, input_shape, expected_shape, same_area) if __name__ == "__main__": diff --git a/tests/test_spatial_pad.py b/tests/test_spatial_pad.py index 932760c3d9..5a70c10686 100644 --- a/tests/test_spatial_pad.py +++ b/tests/test_spatial_pad.py @@ -10,91 +10,28 @@ # limitations under the License. 
import unittest -from typing import List -import numpy as np -import torch from parameterized import parameterized from monai.transforms import SpatialPad -from monai.utils.enums import NumpyPadMode, PytorchPadMode -from monai.utils.misc import set_determinism -from tests.utils import TEST_NDARRAYS_ALL +from tests.padders import PadTest TESTS = [] +TESTS.append([{"spatial_size": [3, 4], "method": "end"}, (1, 2, 3), (1, 3, 4)]) +TESTS.append([{"spatial_size": [15, 4, -1], "method": "symmetric"}, (3, 8, 8, 4), (3, 15, 8, 4)]) -MODES = [] -# Test modes -NP_MODES: List = [ - "constant", - "edge", - # `reflect` mode is not supported in some PyTorch versions, skip the test - # "reflect", - "wrap", -] -MODES += NP_MODES -MODES += [NumpyPadMode(i) for i in NP_MODES] - -PT_MODES: list = [ - "constant", - "replicate", - "circular", - # `reflect` mode is not supported in some PyTorch versions, skip the test - # "reflect", -] -MODES += PT_MODES -MODES += [PytorchPadMode(i) for i in PT_MODES] - -for mode in MODES: - TESTS.append([{"spatial_size": [3, 4], "method": "end", "mode": mode}, (1, 2, 3), (1, 3, 4)]) - - TESTS.append([{"spatial_size": [15, 4, -1], "method": "symmetric", "mode": mode}, (3, 8, 8, 4), (3, 15, 8, 4)]) - - -class TestSpatialPad(unittest.TestCase): - def setUp(self) -> None: - set_determinism(seed=0) - - def tearDown(self) -> None: - set_determinism(None) - - @staticmethod - def get_arr(shape): - return np.random.randint(100, size=shape).astype(float) +class TestSpatialPad(PadTest): + Padder = SpatialPad @parameterized.expand(TESTS) - def test_pad_shape(self, input_param, input_shape, expected_shape): - results_1 = [] - results_2 = [] - input_data = self.get_arr(input_shape) - # check result is the same regardless of input type - for p in TEST_NDARRAYS_ALL: - padder = SpatialPad(**input_param) - r1 = padder(p(input_data)) - r2 = padder(p(input_data), mode=input_param["mode"]) - results_1.append(r1.cpu() if isinstance(r1, torch.Tensor) else r1) - results_2.append(r2.cpu() if isinstance(r2, torch.Tensor) else r2) - for results in (results_1, results_2): - np.testing.assert_allclose(results[-1].shape, expected_shape) - if input_param["mode"] not in ("empty", NumpyPadMode.EMPTY): - torch.testing.assert_allclose(results[0], results[-1], atol=0, rtol=1e-5) + def test_pad(self, input_param, input_shape, expected_shape): + self.pad_test(input_param, input_shape, expected_shape) def test_pad_kwargs(self): - for p in TEST_NDARRAYS_ALL: - input_data = p(np.zeros((3, 8, 4))) - if isinstance(input_data, torch.Tensor): - result = ( - SpatialPad(spatial_size=[15, 8], method="end", mode="constant", value=2)(img=input_data) - .cpu() - .numpy() - ) - else: - result = SpatialPad( - spatial_size=[15, 8], method="end", mode="constant", constant_values=((0, 0), (1, 1), (2, 2)) - )(img=input_data) - torch.testing.assert_allclose(result[:, 8:, :4], np.ones((3, 7, 4)), rtol=1e-7, atol=0) - torch.testing.assert_allclose(result[:, :, 4:], np.ones((3, 15, 4)) + 1, rtol=1e-7, atol=0) + kwargs = {"spatial_size": [15, 8], "method": "end", "mode": "constant"} + unchanged_slices = [slice(None), slice(None, 8), slice(None, 4)] + self.pad_test_kwargs(unchanged_slices, **kwargs) if __name__ == "__main__": diff --git a/tests/test_spatial_padd.py b/tests/test_spatial_padd.py index 762a1145f5..656a731de0 100644 --- a/tests/test_spatial_padd.py +++ b/tests/test_spatial_padd.py @@ -11,42 +11,26 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import SpatialPadd +from 
tests.padders import PadTest
 
-TEST_CASE_1 = [
-    {"keys": ["img"], "spatial_size": [15, 8, 8], "method": "symmetric", "mode": "constant"},
-    {"img": np.zeros((3, 8, 8, 4))},
-    np.zeros((3, 15, 8, 8)),
+TESTS = [
+    [{"keys": ["img"], "spatial_size": [15, 8, 8], "method": "symmetric"}, (3, 8, 8, 4), (3, 15, 8, 8)],
+    [{"keys": ["img"], "spatial_size": [15, 8, 8], "method": "end"}, (3, 8, 8, 4), (3, 15, 8, 8)],
+    [{"keys": ["img"], "spatial_size": [15, 8, -1], "method": "end"}, (3, 8, 4, 4), (3, 15, 8, 4)],
 ]
 
-TEST_CASE_2 = [
-    {"keys": ["img"], "spatial_size": [15, 8, 8], "method": "end", "mode": "constant"},
-    {"img": np.zeros((3, 8, 8, 4))},
-    np.zeros((3, 15, 8, 8)),
-]
-
-TEST_CASE_3 = [
-    {"keys": ["img"], "spatial_size": [15, 8, 8], "method": "end", "mode": {"constant"}},
-    {"img": np.zeros((3, 8, 8, 4))},
-    np.zeros((3, 15, 8, 8)),
-]
-
-TEST_CASE_4 = [
-    {"keys": ["img"], "spatial_size": [15, 8, -1], "method": "end", "mode": {"constant"}},
-    {"img": np.zeros((3, 8, 4, 4))},
-    np.zeros((3, 15, 8, 4)),
-]
+class TestSpatialPadd(PadTest):
+    Padder = SpatialPadd
 
-class TestSpatialPadd(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
-    def test_pad_shape(self, input_param, input_data, expected_val):
-        padder = SpatialPadd(**input_param)
-        result = padder(input_data)
-        np.testing.assert_allclose(result["img"].shape, expected_val.shape)
+    @parameterized.expand(TESTS)
+    def test_pad(self, input_param, input_shape, expected_shape):
+        modes = ["constant", {"constant"}]
+        self.pad_test(input_param, input_shape, expected_shape, modes)
 
 
 if __name__ == "__main__":

From e65c609a1e6ec5be250dfdfe8711ea8cea8d67dd Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Tue, 28 Jun 2022 22:33:26 +0800
Subject: [PATCH 41/47] [DLMED] update according to comments

Signed-off-by: Nic Ma
---
 monai/transforms/croppad/array.py      | 14 +++++++-------
 monai/transforms/croppad/dictionary.py |  6 +++---
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py
index 584f95865d..91183d6749 100644
--- a/monai/transforms/croppad/array.py
+++ b/monai/transforms/croppad/array.py
@@ -160,11 +160,11 @@ def __call__(  # type: ignore
         else:
             out = img_t
         if get_track_meta():
-            self._update_meta(tensor=out, to_pad=to_pad_)  # type: ignore
+            self.update_meta(tensor=out, to_pad=to_pad_)  # type: ignore
             self.push_transform(out, extra_info={"padded": to_pad_})
         return out
 
-    def _update_meta(self, tensor: MetaTensor, to_pad: List[Tuple[int, int]]):
+    def update_meta(self, tensor: MetaTensor, to_pad: List[Tuple[int, int]]):
         spatial_rank = max(len(tensor.affine) - 1, 1)
         to_shift = [-s[0] for s in to_pad[1:]]  # skipping the channel pad
         mat = create_translate(spatial_rank, to_shift)
@@ -210,7 +210,7 @@ class SpatialPad(Pad):
     def __init__(
         self,
         spatial_size: Union[Sequence[int], int],
-        method: Union[Method, str] = Method.SYMMETRIC,
+        method: str = Method.SYMMETRIC,
         mode: str = PytorchPadMode.CONSTANT,
         **kwargs,
     ) -> None:
@@ -298,7 +298,7 @@ class DivisiblePad(Pad):
     def __init__(
         self,
         k: Union[Sequence[int], int],
-        method: Union[Method, str] = Method.SYMMETRIC,
+        method: str = Method.SYMMETRIC,
         mode: str = PytorchPadMode.CONSTANT,
         **kwargs,
     ) -> None:
@@ -402,14 +402,14 @@ def __call__(self, img: torch.Tensor, slices: Tuple[slice, ...]) -> torch.Tensor
         img_t: MetaTensor = convert_to_tensor(data=img, track_meta=get_track_meta())
         img_t =
img_t[slices] # type: ignore if get_track_meta(): - self._update_meta(tensor=img_t, slices=slices) + self.update_meta(tensor=img_t, slices=slices) cropped_from_start = np.asarray([s.indices(o)[0] for s, o in zip(slices[1:], orig_size)]) cropped_from_end = np.asarray(orig_size) - img_t.shape[1:] - cropped_from_start cropped = list(chain(*zip(cropped_from_start.tolist(), cropped_from_end.tolist()))) self.push_transform(img_t, extra_info={"cropped": cropped}) return img_t - def _update_meta(self, tensor: MetaTensor, slices: Tuple[slice, ...]): + def update_meta(self, tensor: MetaTensor, slices: Tuple[slice, ...]): spatial_rank = max(len(tensor.affine) - 1, 1) to_shift = [s.start if s.start is not None else 0 for s in ensure_tuple(slices)[1:]] mat = create_translate(spatial_rank, to_shift) @@ -1244,7 +1244,7 @@ class ResizeWithPadOrCrop(InvertibleTransform): def __init__( self, spatial_size: Union[Sequence[int], int], - method: Union[Method, str] = Method.SYMMETRIC, + method: str = Method.SYMMETRIC, mode: str = PytorchPadMode.CONSTANT, **pad_kwargs, ): diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 3858e8de30..ae8a74d7c7 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -168,7 +168,7 @@ def __init__( self, keys: KeysCollection, spatial_size: Union[Sequence[int], int], - method: Union[Method, str] = Method.SYMMETRIC, + method: str = Method.SYMMETRIC, mode: Union[Sequence[str], str] = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, **kwargs, @@ -260,7 +260,7 @@ def __init__( keys: KeysCollection, k: Union[Sequence[int], int], mode: Union[Sequence[str], str] = PytorchPadMode.CONSTANT, - method: Union[Method, str] = Method.SYMMETRIC, + method: str = Method.SYMMETRIC, allow_missing_keys: bool = False, **kwargs, ) -> None: @@ -1042,7 +1042,7 @@ def __init__( spatial_size: Union[Sequence[int], int], mode: Union[Sequence[str], str] = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, - method: Union[Method, str] = Method.SYMMETRIC, + method: str = Method.SYMMETRIC, **pad_kwargs, ) -> None: padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **pad_kwargs) From 65bda07f904e38997087b53083affc38ff311728 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 28 Jun 2022 22:35:30 +0800 Subject: [PATCH 42/47] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 91183d6749..4bdc3e6127 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -298,8 +298,8 @@ class DivisiblePad(Pad): def __init__( self, k: Union[Sequence[int], int], - method: str = Method.SYMMETRIC, mode: str = PytorchPadMode.CONSTANT, + method: str = Method.SYMMETRIC, **kwargs, ) -> None: """ @@ -307,14 +307,14 @@ def __init__( k: the target k for each spatial dimension. if `k` is negative or 0, the original size is preserved. if `k` is an int, the same `k` be applied to all the input spatial dimensions. - method: {``"symmetric"``, ``"end"``} - Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. 
mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
             ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
             available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
             One of the listed string values or a user supplied function. Defaults to ``"constant"``.
             See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
             https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+        method: {``"symmetric"``, ``"end"``}
+            Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``.
         kwargs: other arguments for the `np.pad` or `torch.pad` function.
             note that `np.pad` treats channel dimension as the first dimension.

From 5514c690cc7a6bef5ba414e847cc62f5705d0317 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Tue, 28 Jun 2022 22:46:23 +0800
Subject: [PATCH 43/47] [DLMED] update according to comments

Signed-off-by: Nic Ma
---
 monai/transforms/croppad/dictionary.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py
index ae8a74d7c7..70a775578b 100644
--- a/monai/transforms/croppad/dictionary.py
+++ b/monai/transforms/croppad/dictionary.py
@@ -107,8 +107,6 @@
     "RandCropByLabelClassesDict",
 ]
 
-DEFAULT_POST_FIX = PostFix.meta()
-
 
 class Padd(MapTransform, InvertibleTransform):
     """
@@ -580,7 +578,7 @@ def __init__(
         random_center: bool = True,
         random_size: bool = True,
         meta_keys: Optional[KeysCollection] = None,
-        meta_key_postfix: str = DEFAULT_POST_FIX,
+        meta_key_postfix: str = PostFix.meta(),
         allow_missing_keys: bool = False,
     ) -> None:
         MapTransform.__init__(self, keys, allow_missing_keys)
@@ -718,7 +716,7 @@ def __init__(
         num_samples: int = 1,
         center_coord_key: Optional[str] = None,
         meta_keys: Optional[KeysCollection] = None,
-        meta_key_postfix: str = DEFAULT_POST_FIX,
+        meta_key_postfix: str = PostFix.meta(),
         allow_missing_keys: bool = False,
     ):
         MapTransform.__init__(self, keys, allow_missing_keys)
@@ -797,8 +795,9 @@ class RandCropByPosNegLabeld(Randomizable, MapTransform):
         the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
         match the cropped size (i.e., no cropping in that dimension).
         allow_missing_keys: don't raise exception if key is missing.
-    padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **pad_kwargs)
-    super().__init__(keys, padder=padcropper, mg`` are negative.
+
+    Raises:
+        ValueError: When ``pos`` or ``neg`` are negative.
         ValueError: When ``pos=0`` and ``neg=0``. Incompatible values.
 
""" @@ -818,7 +817,7 @@ def __init__( fg_indices_key: Optional[str] = None, bg_indices_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = DEFAULT_POST_FIX, + meta_key_postfix: str = PostFix.meta(), allow_smaller: bool = False, allow_missing_keys: bool = False, ) -> None: @@ -962,7 +961,7 @@ def __init__( image_threshold: float = 0.0, indices_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = DEFAULT_POST_FIX, + meta_key_postfix: str = PostFix.meta(), allow_smaller: bool = False, allow_missing_keys: bool = False, ) -> None: From f8e23794476070b54ddcc0ca76549e69ac6679ea Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 29 Jun 2022 10:09:48 +0800 Subject: [PATCH 44/47] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 31 +++++++++++++++++-------------- tests/padders.py | 1 + 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 4bdc3e6127..7904da0355 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -107,8 +107,13 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") @staticmethod - def _np_pad(img: np.ndarray, pad_width, mode, **kwargs) -> np.ndarray: - return np.pad(img, pad_width, mode=mode, **kwargs) + def _np_pad(img: torch.Tensor, pad_width, mode, **kwargs) -> torch.Tensor: + img_np = img.detach().cpu().numpy() if isinstance(img, torch.Tensor) else img + mode = convert_pad_mode(dst=img_np, mode=mode).value + out = torch.as_tensor(np.pad(img, pad_width, mode=mode, **kwargs)) + if isinstance(img, MetaTensor): + out = MetaTensor(out, meta=img.meta, applied_operations=img.applied_operations) + return out @staticmethod def _pt_pad(img: torch.Tensor, pad_width, mode, **kwargs) -> torch.Tensor: @@ -145,18 +150,16 @@ def __call__( # type: ignore # all zeros, skip padding if np.asarray(to_pad_).any(): - try: - mode_ = convert_pad_mode(dst=img_t, mode=mode_).value - out = self._pt_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_) - # but if mode or args don't exist in pytorch, use numpy instead - except (ValueError, TypeError) as err: - if "Unsupported option" in str(err) or "unexpected keyword" in str(err): - # extract metadata - img_np = img_t.detach().cpu().numpy() - mode = convert_pad_mode(dst=img_np, mode=mode_).value - out = torch.as_tensor(self._np_pad(img_np, pad_width=to_pad_, mode=mode_, **kwargs_)) - if get_track_meta(): - out = MetaTensor(out, meta=img_t.meta, applied_operations=img_t.applied_operations) + if mode in ["linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"]: + out = self._np_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_) + else: + try: + mode_ = convert_pad_mode(dst=img_t, mode=mode_).value + out = self._pt_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_) + # but if mode or args don't exist in pytorch, use numpy instead + except (ValueError, TypeError) as err: + if "Unsupported option" in str(err) or "unexpected keyword" in str(err): + out = self._np_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_) else: out = img_t if get_track_meta(): diff --git a/tests/padders.py b/tests/padders.py index 932a0566cc..3fa3280cb5 100644 --- a/tests/padders.py +++ b/tests/padders.py @@ -29,6 +29,7 @@ # "reflect", "wrap", "median", + "mean", ] MODES += NP_MODES MODES 
+= [NumpyPadMode(i) for i in NP_MODES] From ca7e11dbc77419230b82bf2da1fee7e6c005df9e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 29 Jun 2022 10:45:10 +0800 Subject: [PATCH 45/47] [DLMED] add test for deepcopy Signed-off-by: Nic Ma --- monai/transforms/croppad/array.py | 10 ++++++---- tests/test_rand_spatial_crop_samples.py | 3 ++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 7904da0355..ce464e036f 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -705,10 +705,12 @@ def __call__(self, img: torch.Tensor) -> List[torch.Tensor]: Apply the transform to `img`, assuming `img` is channel-first and cropping doesn't change the channel dim. """ - ret = [self.cropper(img) for _ in range(self.num_samples)] - if get_track_meta(): - for i, r in enumerate(ret): - r.meta[Key.PATCH_INDEX] = i # type: ignore + ret = [] + for i in range(self.num_samples): + cropped = self.cropper(img) + if get_track_meta(): + cropped.meta[Key.PATCH_INDEX] = i # type: ignore + ret.append(cropped) return ret diff --git a/tests/test_rand_spatial_crop_samples.py b/tests/test_rand_spatial_crop_samples.py index 537c7b4e4e..50571b5955 100644 --- a/tests/test_rand_spatial_crop_samples.py +++ b/tests/test_rand_spatial_crop_samples.py @@ -90,8 +90,9 @@ def test_shape(self, input_param, input_shape, expected_shape, expected_last_ite result = xform(p(input_data)) np.testing.assert_equal(len(result), input_param["num_samples"]) - for item, expected in zip(result, expected_shape): + for i, (item, expected) in enumerate(zip(result, expected_shape)): self.assertTupleEqual(item.shape, expected) + self.assertEqual(item.meta["patch_index"], i) assert_allclose(result[-1], expected_last_item, type_test=False) From b0ff94395e03fa62713db8082b5c15f86d96a379 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 29 Jun 2022 14:35:19 +0800 Subject: [PATCH 46/47] [DLMED] fix typo Signed-off-by: Nic Ma --- monai/transforms/__init__.py | 3 +++ monai/transforms/croppad/array.py | 2 ++ tests/test_module_list.py | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 5cc7747aeb..a5c9ed05eb 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -56,6 +56,9 @@ Padd, PadD, PadDict, + RandCropd, + RandCropD, + RandCropDict, RandCropByLabelClassesd, RandCropByLabelClassesD, RandCropByLabelClassesDict, diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index ce464e036f..0483839759 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -246,6 +246,7 @@ class BorderPad(Pad): Args: spatial_border: specified size for every spatial border. Any -ve values will be set to 0. It can be 3 shapes: + - single int number, pad all the borders with the same size. - length equals the length of image shape, pad every spatial dimension separately. for example, image shape(CHW) is [1, 4, 4], spatial_border is [2, 1], @@ -435,6 +436,7 @@ class SpatialCrop(Crop): So the cropped result may be smaller than the expected ROI, and the cropped results of several images may not have exactly the same shape. It can support to crop ND spatial (channel-first) data. 
+ The cropped region can be parameterised in various ways: - a list of slices for each spatial dimension (allows for use of -ve indexing and `None`) - a spatial center and size diff --git a/tests/test_module_list.py b/tests/test_module_list.py index d81d067c58..d0b5aaf26b 100644 --- a/tests/test_module_list.py +++ b/tests/test_module_list.py @@ -38,7 +38,7 @@ def test_public_api(self): def test_transform_api(self): """monai subclasses of MapTransforms must have alias names ending with 'd', 'D', 'Dict'""" to_exclude = {"MapTransform"} # except for these transforms - to_exclude_docs = {"Decollate", "Ensemble", "Invert", "SaveClassification", "RandTorchVision"} + to_exclude_docs = {"Decollate", "Ensemble", "Invert", "SaveClassification", "RandTorchVision", "RandCrop"} to_exclude_docs.update({"DeleteItems", "SelectItems", "CopyItems", "ConcatItems"}) to_exclude_docs.update({"ToMetaTensor", "FromMetaTensor"}) xforms = { From bb37a2bb8d966c55b8ce22d793f0ecf88950cce1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 29 Jun 2022 14:53:26 +0800 Subject: [PATCH 47/47] [DLMED] update docs Signed-off-by: Nic Ma --- monai/transforms/croppad/dictionary.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 70a775578b..ad739c0fcd 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -48,7 +48,6 @@ from monai.transforms.utils import is_positive from monai.utils import MAX_SEED, Method, PytorchPadMode, ensure_tuple_rep from monai.utils.deprecate_utils import deprecated_arg -from monai.utils.enums import PostFix __all__ = [ "Padd", @@ -578,7 +577,7 @@ def __init__( random_center: bool = True, random_size: bool = True, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = PostFix.meta(), + meta_key_postfix: str = "meta_dict", allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) @@ -716,7 +715,7 @@ def __init__( num_samples: int = 1, center_coord_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = PostFix.meta(), + meta_key_postfix: str = "meta_dict", allow_missing_keys: bool = False, ): MapTransform.__init__(self, keys, allow_missing_keys) @@ -817,7 +816,7 @@ def __init__( fg_indices_key: Optional[str] = None, bg_indices_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = PostFix.meta(), + meta_key_postfix: str = "meta_dict", allow_smaller: bool = False, allow_missing_keys: bool = False, ) -> None: @@ -961,7 +960,7 @@ def __init__( image_threshold: float = 0.0, indices_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = PostFix.meta(), + meta_key_postfix: str = "meta_dict", allow_smaller: bool = False, allow_missing_keys: bool = False, ) -> None:
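
Taken together, the pad and crop changes in this series make every transform record an
invertible operation on the output MetaTensor. A minimal end-to-end sketch of the round
trip, mirroring the assertions added in tests/test_resize_with_pad_or_crop.py; it assumes
a MONAI build that already contains these patches, and the shapes are illustrative only:

    import torch

    from monai.data.meta_tensor import MetaTensor
    from monai.transforms import ResizeWithPadOrCrop

    img = MetaTensor(torch.zeros(3, 8, 8, 4))  # channel-first image
    padcropper = ResizeWithPadOrCrop(spatial_size=[15, 4, -1], mode="constant")

    out = padcropper(img)  # pads dim 0, crops dim 1, leaves dim 2 (-1) unchanged
    assert tuple(out.shape) == (3, 15, 4, 4)
    assert len(out.applied_operations) == 1  # one invertible op recorded

    inv = padcropper.inverse(out)  # replays the recorded op backwards
    assert tuple(inv.shape) == (3, 8, 8, 4)
    assert inv.applied_operations == []  # the op is popped after inversion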
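
The dispatch that PATCH 44 introduces into Pad.__call__ can be summarised as: numpy-only
pad modes go straight to np.pad, and everything else is tried with
torch.nn.functional.pad first, falling back to numpy on failure. The helper below is an
illustrative standalone restatement of that logic, not the MONAI implementation itself;
the names pad_nd and NUMPY_ONLY_MODES are hypothetical:

    from typing import List, Tuple

    import numpy as np
    import torch
    import torch.nn.functional as F

    # modes that exist in np.pad but have no torch.nn.functional.pad equivalent
    NUMPY_ONLY_MODES = {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}

    def pad_nd(img: torch.Tensor, to_pad: List[Tuple[int, int]], mode: str = "constant") -> torch.Tensor:
        """Pad a channel-first tensor; `to_pad` is [(low_C, high_C), (low_H, high_H), ...]."""
        if mode in NUMPY_ONLY_MODES:
            out = np.pad(img.detach().cpu().numpy(), to_pad, mode=mode)
            return torch.as_tensor(out)
        # torch.pad wants a flat pad list over the spatial dims, last dimension first
        pt_pad = [p for pair in to_pad[1:] for p in pair[::-1]][::-1]
        return F.pad(img.unsqueeze(0), pt_pad, mode=mode).squeeze(0)

This is also why the real _np_pad re-wraps its result as a MetaTensor carrying the input's
meta and applied_operations: the numpy round trip would otherwise drop the metadata that
the inverse pass depends on.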