diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 2eb2537b49..7eaa17ea43 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -105,6 +105,12 @@ Crop and Pad
     :members:
     :special-members: __call__
 
+`Crop`
+""""""
+.. autoclass:: Crop
+    :members:
+    :special-members: __call__
+
 `SpatialCrop`
 """""""""""""
 .. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialCrop.png
@@ -995,6 +1001,12 @@ Dictionary Transforms
 Crop and Pad (Dict)
 ^^^^^^^^^^^^^^^^^^^
 
+`Padd`
+""""""
+.. autoclass:: Padd
+    :members:
+    :special-members: __call__
+
 `SpatialPadd`
 """""""""""""
 .. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialPadd.png
@@ -1019,6 +1031,18 @@ Crop and Pad (Dict)
     :members:
     :special-members: __call__
 
+`Cropd`
+"""""""
+.. autoclass:: Cropd
+    :members:
+    :special-members: __call__
+
+`RandCropd`
+"""""""""""
+.. autoclass:: RandCropd
+    :members:
+    :special-members: __call__
+
 `SpatialCropd`
 """"""""""""""
 .. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialCropd.png
diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index d4f09474de..a5c9ed05eb 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -16,6 +16,7 @@
     BoundingRect,
     CenterScaleCrop,
     CenterSpatialCrop,
+    Crop,
     CropForeground,
     DivisiblePad,
     Pad,
@@ -43,19 +44,27 @@
     CenterSpatialCropd,
     CenterSpatialCropD,
     CenterSpatialCropDict,
+    Cropd,
+    CropD,
+    CropDict,
     CropForegroundd,
     CropForegroundD,
     CropForegroundDict,
     DivisiblePadd,
     DivisiblePadD,
     DivisiblePadDict,
-    PadModeSequence,
+    Padd,
+    PadD,
+    PadDict,
     RandCropByLabelClassesd,
     RandCropByLabelClassesD,
     RandCropByLabelClassesDict,
     RandCropByPosNegLabeld,
     RandCropByPosNegLabelD,
     RandCropByPosNegLabelDict,
+    RandCropd,
+    RandCropD,
+    RandCropDict,
     RandScaleCropd,
     RandScaleCropD,
     RandScaleCropDict,
diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py
index 6537cf3e21..0483839759 100644
--- a/monai/transforms/croppad/array.py
+++ b/monai/transforms/croppad/array.py
@@ -23,11 +23,15 @@
 
 from monai.config import IndexSelection
 from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.meta_obj import get_track_meta
+from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import get_random_patch, get_valid_patch_size
+from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.transform import Randomizable, Transform
 from monai.transforms.utils import (
     compute_divisible_spatial_size,
     convert_pad_mode,
+    create_translate,
     generate_label_classes_crop_centers,
     generate_pos_neg_label_crop_centers,
     generate_spatial_bounding_box,
@@ -36,24 +40,17 @@
     map_classes_to_indices,
     weighted_patch_samples,
 )
-from monai.transforms.utils_pytorch_numpy_unification import floor_divide, maximum
-from monai.utils import (
-    Method,
-    NumpyPadMode,
-    PytorchPadMode,
-    ensure_tuple,
-    ensure_tuple_rep,
-    fall_back_tuple,
-    look_up_option,
-)
-from monai.utils.enums import TransformBackends
-from monai.utils.type_conversion import convert_data_type, convert_to_dst_type
+from monai.utils import ImageMetaKey as Key
+from monai.utils import Method, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, look_up_option
+from monai.utils.enums import TraceKeys, TransformBackends
+from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor
 
 __all__ = [
     "Pad",
     "SpatialPad",
     "BorderPad",
"DivisiblePad", + "Crop", "SpatialCrop", "CenterSpatialCrop", "CenterScaleCrop", @@ -69,13 +66,16 @@ ] -class Pad(Transform): +class Pad(InvertibleTransform): """ Perform padding for a given an amount of padding in each dimension. - If input is `torch.Tensor`, `torch.nn.functional.pad` will be used, otherwise, `np.pad` will be used. + + `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch, + in which case `np.pad` will be used. Args: to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...]. + if None, must provide in the `__call__` at runtime. mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. @@ -84,63 +84,113 @@ class Pad(Transform): https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( - self, - to_pad: List[Tuple[int, int]], - mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, - **kwargs, + self, to_pad: Optional[List[Tuple[int, int]]] = None, mode: str = PytorchPadMode.CONSTANT, **kwargs ) -> None: self.to_pad = to_pad self.mode = mode self.kwargs = kwargs + def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int]]: + """ + dynamically compute the pad width according to the spatial shape. + + Args: + spatial_shape: spatial shape of the original image. + + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + @staticmethod - def _np_pad(img: np.ndarray, all_pad_width, mode, **kwargs) -> np.ndarray: - return np.pad(img, all_pad_width, mode=mode, **kwargs) # type: ignore + def _np_pad(img: torch.Tensor, pad_width, mode, **kwargs) -> torch.Tensor: + img_np = img.detach().cpu().numpy() if isinstance(img, torch.Tensor) else img + mode = convert_pad_mode(dst=img_np, mode=mode).value + out = torch.as_tensor(np.pad(img, pad_width, mode=mode, **kwargs)) + if isinstance(img, MetaTensor): + out = MetaTensor(out, meta=img.meta, applied_operations=img.applied_operations) + return out @staticmethod - def _pt_pad(img: torch.Tensor, all_pad_width, mode, **kwargs) -> torch.Tensor: - pt_pad_width = [val for sublist in all_pad_width[1:] for val in sublist[::-1]][::-1] + def _pt_pad(img: torch.Tensor, pad_width, mode, **kwargs) -> torch.Tensor: + pt_pad_width = [val for sublist in pad_width[1:] for val in sublist[::-1]][::-1] # torch.pad expects `[B, C, H, W, [D]]` shape return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0) - def __call__( - self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None - ) -> NdarrayOrTensor: + def __call__( # type: ignore + self, img: torch.Tensor, to_pad: Optional[List[Tuple[int, int]]] = None, mode: Optional[str] = None, **kwargs + ) -> torch.Tensor: """ Args: - img: data to be transformed, assuming `img` is channel-first and - padding doesn't apply to the channel dim. 
- mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, - ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"`` or ``"circular"``}. - One of the listed string values or a user supplied function. Defaults to `self.mode`. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. + to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...]. + default to `self.to_pad`. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ - if not np.asarray(self.to_pad).any(): - # all zeros, skip padding - return img - mode = convert_pad_mode(dst=img, mode=mode or self.mode).value - pad = self._pt_pad if isinstance(img, torch.Tensor) else self._np_pad - return pad(img, self.to_pad, mode, **self.kwargs) # type: ignore + to_pad_ = self.to_pad if to_pad is None else to_pad + if to_pad_ is None: + to_pad_ = self.compute_pad_width(img.shape[1:]) + mode_ = self.mode if mode is None else mode + kwargs_ = dict(self.kwargs) + kwargs_.update(kwargs) + + img_t = convert_to_tensor(data=img, track_meta=get_track_meta()) + + # all zeros, skip padding + if np.asarray(to_pad_).any(): + if mode in ["linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"]: + out = self._np_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_) + else: + try: + mode_ = convert_pad_mode(dst=img_t, mode=mode_).value + out = self._pt_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_) + # but if mode or args don't exist in pytorch, use numpy instead + except (ValueError, TypeError) as err: + if "Unsupported option" in str(err) or "unexpected keyword" in str(err): + out = self._np_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_) + else: + out = img_t + if get_track_meta(): + self.update_meta(tensor=out, to_pad=to_pad_) # type: ignore + self.push_transform(out, extra_info={"padded": to_pad_}) + return out + + def update_meta(self, tensor: MetaTensor, to_pad: List[Tuple[int, int]]): + spatial_rank = max(len(tensor.affine) - 1, 1) + to_shift = [-s[0] for s in to_pad[1:]] # skipping the channel pad + mat = create_translate(spatial_rank, to_shift) + tensor.meta["affine"] = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0] + + def inverse(self, data: MetaTensor) -> MetaTensor: + transform = self.pop_transform(data) + padded = transform[TraceKeys.EXTRA_INFO]["padded"] + if padded[0][0] != 0 or padded[0][1] != 0: + raise NotImplementedError( + "Inverse uses SpatialCrop, which hasn't yet been extended to crop channels. Trivial change." 
+ ) + roi_start = [i[0] for i in padded[1:]] + roi_end = [i - j[1] for i, j in zip(data.shape[1:], padded[1:])] + cropper = SpatialCrop(roi_start=roi_start, roi_end=roi_end) + with cropper.trace_transform(False): + return cropper(data) # type: ignore -class SpatialPad(Transform): +class SpatialPad(Pad): """ Performs padding to the data, symmetric for all sides or all on one side for each dimension. - If input is `torch.Tensor` and mode is `constant`, `torch.nn.functional.pad` will be used. - Otherwise, `np.pad` will be used (input converted to `np.ndarray` if necessary). - - Uses np.pad so in practice, a mode needs to be provided. See numpy.lib.arraypad.pad - for additional details. - Args: spatial_size: the spatial size of output data after padding, if a dimension of the input data size is bigger than the pad size, will not pad that dimension. @@ -160,56 +210,37 @@ class SpatialPad(Transform): """ - backend = Pad.backend - def __init__( self, spatial_size: Union[Sequence[int], int], - method: Union[Method, str] = Method.SYMMETRIC, - mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + method: str = Method.SYMMETRIC, + mode: str = PytorchPadMode.CONSTANT, **kwargs, ) -> None: self.spatial_size = spatial_size self.method: Method = look_up_option(method, Method) - self.mode = mode - self.kwargs = kwargs - - def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int, int]]: - spatial_size = fall_back_tuple(self.spatial_size, data_shape) - if self.method == Method.SYMMETRIC: - pad_width = [] - for i, sp_i in enumerate(spatial_size): - width = max(sp_i - data_shape[i], 0) - pad_width.append((width // 2, width - (width // 2))) - return pad_width - return [(0, max(sp_i - data_shape[i], 0)) for i, sp_i in enumerate(spatial_size)] + super().__init__(mode=mode, **kwargs) - def __call__( - self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None - ) -> NdarrayOrTensor: + def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int]]: """ + dynamically compute the pad width according to the spatial shape. + Args: - img: data to be transformed, assuming `img` is channel-first and - padding doesn't apply to the channel dim. - mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, - ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. - One of the listed string values or a user supplied function. Defaults to `self.mode`. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + spatial_shape: spatial shape of the original image. 
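+
+        For example (an illustrative sketch)::
+
+            # with spatial_size=(32, 32), a (30, 30) image gives
+            # [(0, 0), (1, 1), (1, 1)]: no channel padding, one voxel on each spatial side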
""" - data_pad_width = self._determine_data_pad_width(img.shape[1:]) - all_pad_width = [(0, 0)] + data_pad_width - if not np.asarray(all_pad_width).any(): - # all zeros, skip padding - return img - - padder = Pad(to_pad=all_pad_width, mode=mode or self.mode, **self.kwargs) - return padder(img) + spatial_size = fall_back_tuple(self.spatial_size, spatial_shape) + if self.method == Method.SYMMETRIC: + pad_width = [] + for i, sp_i in enumerate(spatial_size): + width = max(sp_i - spatial_shape[i], 0) + pad_width.append((width // 2, width - (width // 2))) + else: + pad_width = [(0, max(sp_i - spatial_shape[i], 0)) for i, sp_i in enumerate(spatial_size)] + return [(0, 0)] + pad_width -class BorderPad(Transform): +class BorderPad(Pad): """ Pad the input data by adding specified borders to every dimension. @@ -235,39 +266,13 @@ class BorderPad(Transform): """ - backend = Pad.backend - def __init__( - self, - spatial_border: Union[Sequence[int], int], - mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, - **kwargs, + self, spatial_border: Union[Sequence[int], int], mode: str = PytorchPadMode.CONSTANT, **kwargs ) -> None: self.spatial_border = spatial_border - self.mode = mode - self.kwargs = kwargs - - def __call__( - self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None - ) -> NdarrayOrTensor: - """ - Args: - img: data to be transformed, assuming `img` is channel-first and - padding doesn't apply to the channel dim. - mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, - ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. - One of the listed string values or a user supplied function. Defaults to `self.mode`. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + super().__init__(mode=mode, **kwargs) - Raises: - ValueError: When ``self.spatial_border`` does not contain ints. - ValueError: When ``self.spatial_border`` length is not one of - [1, len(spatial_shape), 2*len(spatial_shape)]. - - """ - spatial_shape = img.shape[1:] + def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int]]: spatial_border = ensure_tuple(self.spatial_border) if not all(isinstance(b, int) for b in spatial_border): raise ValueError(f"self.spatial_border must contain only ints, got {spatial_border}.") @@ -284,13 +289,10 @@ def __call__( f"Unsupported spatial_border length: {len(spatial_border)}, available options are " f"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2*len(spatial_shape)}]." ) + return [(0, 0)] + data_pad_width - all_pad_width = [(0, 0)] + data_pad_width - padder = Pad(all_pad_width, mode or self.mode, **self.kwargs) - return padder(img) - -class DivisiblePad(Transform): +class DivisiblePad(Pad): """ Pad the input data, so that the spatial sizes are divisible by `k`. 
""" @@ -300,8 +302,8 @@ class DivisiblePad(Transform): def __init__( self, k: Union[Sequence[int], int], - mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, - method: Union[Method, str] = Method.SYMMETRIC, + mode: str = PytorchPadMode.CONSTANT, + method: str = Method.SYMMETRIC, **kwargs, ) -> None: """ @@ -323,32 +325,111 @@ def __init__( See also :py:class:`monai.transforms.SpatialPad` """ self.k = k - self.mode: NumpyPadMode = NumpyPadMode(mode) self.method: Method = Method(method) - self.kwargs = kwargs + super().__init__(mode=mode, **kwargs) - def __call__( - self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None - ) -> NdarrayOrTensor: + def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int]]: + new_size = compute_divisible_spatial_size(spatial_shape=spatial_shape, k=self.k) + spatial_pad = SpatialPad(spatial_size=new_size, method=self.method) + return spatial_pad.compute_pad_width(spatial_shape) + + +class Crop(InvertibleTransform): + """ + Perform crop operation on the input image. + + """ + + backend = [TransformBackends.TORCH] + + @staticmethod + def compute_slices( + roi_center: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_size: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_start: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_end: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_slices: Optional[Sequence[slice]] = None, + ): """ + Compute the crop slices based on specified `center & size` or `start & end`. + Args: - img: data to be transformed, assuming `img` is channel-first - and padding doesn't apply to the channel dim. - mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, - ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. - One of the listed string values or a user supplied function. Defaults to `self.mode`. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + roi_center: voxel coordinates for center of the crop ROI. + roi_size: size of the crop ROI, if a dimension of ROI size is bigger than image size, + will not crop that dimension of the image. + roi_start: voxel coordinates for start of the crop ROI. + roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image, + use the end coordinate of image. + roi_slices: list of slices for each of the spatial dimensions. 
""" - new_size = compute_divisible_spatial_size(spatial_shape=img.shape[1:], k=self.k) - spatial_pad = SpatialPad(spatial_size=new_size, method=self.method, mode=mode or self.mode, **self.kwargs) + roi_start_t: torch.Tensor - return spatial_pad(img) + if roi_slices: + if not all(s.step is None or s.step == 1 for s in roi_slices): + raise ValueError("only slice steps of 1/None are currently supported") + return list(roi_slices) + else: + if roi_center is not None and roi_size is not None: + roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True) + roi_size_t = convert_to_tensor(data=roi_size, dtype=torch.int16, wrap_sequence=True) + _zeros = torch.zeros_like(roi_center_t) + roi_start_t = torch.maximum(roi_center_t - torch.div(roi_size_t, 2, rounding_mode="floor"), _zeros) + roi_end_t = torch.maximum(roi_start_t + roi_size_t, roi_start_t) + else: + if roi_start is None or roi_end is None: + raise ValueError("please specify either roi_center, roi_size or roi_start, roi_end.") + roi_start_t = convert_to_tensor(data=roi_start, dtype=torch.int16, wrap_sequence=True) + roi_start_t = torch.maximum(roi_start_t, torch.zeros_like(roi_start_t)) + roi_end_t = convert_to_tensor(data=roi_end, dtype=torch.int16, wrap_sequence=True) + roi_end_t = torch.maximum(roi_end_t, roi_start_t) + # convert to slices (accounting for 1d) + if roi_start_t.numel() == 1: + return [slice(int(roi_start_t.item()), int(roi_end_t.item()))] + else: + return [slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())] + def __call__(self, img: torch.Tensor, slices: Tuple[slice, ...]) -> torch.Tensor: # type: ignore + """ + Apply the transform to `img`, assuming `img` is channel-first and + slicing doesn't apply to the channel dim. -class SpatialCrop(Transform): + """ + orig_size = img.shape[1:] + slices_ = list(slices) + sd = len(img.shape[1:]) # spatial dims + if len(slices_) < sd: + slices_ += [slice(None)] * (sd - len(slices_)) + # Add in the channel (no cropping) + slices = tuple([slice(None)] + slices_[:sd]) + + img_t: MetaTensor = convert_to_tensor(data=img, track_meta=get_track_meta()) + img_t = img_t[slices] # type: ignore + if get_track_meta(): + self.update_meta(tensor=img_t, slices=slices) + cropped_from_start = np.asarray([s.indices(o)[0] for s, o in zip(slices[1:], orig_size)]) + cropped_from_end = np.asarray(orig_size) - img_t.shape[1:] - cropped_from_start + cropped = list(chain(*zip(cropped_from_start.tolist(), cropped_from_end.tolist()))) + self.push_transform(img_t, extra_info={"cropped": cropped}) + return img_t + + def update_meta(self, tensor: MetaTensor, slices: Tuple[slice, ...]): + spatial_rank = max(len(tensor.affine) - 1, 1) + to_shift = [s.start if s.start is not None else 0 for s in ensure_tuple(slices)[1:]] + mat = create_translate(spatial_rank, to_shift) + tensor.meta["affine"] = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0] + + def inverse(self, img: MetaTensor) -> MetaTensor: + transform = self.pop_transform(img) + cropped = transform[TraceKeys.EXTRA_INFO]["cropped"] + # the amount we pad is equal to the amount we cropped in each direction + inverse_transform = BorderPad(cropped) + # Apply inverse transform + with inverse_transform.trace_transform(False): + return inverse_transform(img) # type: ignore + + +class SpatialCrop(Crop): """ General purpose cropper to produce sub-volume region of interest (ROI). If a dimension of the expected ROI size is bigger than the input image size, will not crop that dimension. 
@@ -362,8 +443,6 @@ class SpatialCrop(Transform): - the start and end coordinates of the ROI """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__( self, roi_center: Union[Sequence[int], NdarrayOrTensor, None] = None, @@ -382,47 +461,20 @@ def __init__( use the end coordinate of image. roi_slices: list of slices for each of the spatial dimensions. """ - roi_start_torch: torch.Tensor - - if roi_slices: - if not all(s.step is None or s.step == 1 for s in roi_slices): - raise ValueError("Only slice steps of 1/None are currently supported") - self.slices = list(roi_slices) - else: - if roi_center is not None and roi_size is not None: - roi_center, *_ = convert_data_type( - data=roi_center, output_type=torch.Tensor, dtype=torch.int16, wrap_sequence=True - ) - roi_size, *_ = convert_to_dst_type(src=roi_size, dst=roi_center, wrap_sequence=True) - _zeros = torch.zeros_like(roi_center) - roi_start_torch = maximum(roi_center - floor_divide(roi_size, 2), _zeros) # type: ignore - roi_end_torch = maximum(roi_start_torch + roi_size, roi_start_torch) - else: - if roi_start is None or roi_end is None: - raise ValueError("Please specify either roi_center, roi_size or roi_start, roi_end.") - roi_start_torch, *_ = convert_data_type( - data=roi_start, output_type=torch.Tensor, dtype=torch.int16, wrap_sequence=True - ) - roi_start_torch = maximum(roi_start_torch, torch.zeros_like(roi_start_torch)) # type: ignore - roi_end_torch, *_ = convert_to_dst_type(src=roi_end, dst=roi_start_torch, wrap_sequence=True) - roi_end_torch = maximum(roi_end_torch, roi_start_torch) - # convert to slices (accounting for 1d) - if roi_start_torch.numel() == 1: - self.slices = [slice(int(roi_start_torch.item()), int(roi_end_torch.item()))] - else: - self.slices = [slice(int(s), int(e)) for s, e in zip(roi_start_torch.tolist(), roi_end_torch.tolist())] + self.slices = self.compute_slices( + roi_center=roi_center, roi_size=roi_size, roi_start=roi_start, roi_end=roi_end, roi_slices=roi_slices + ) - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. + """ - sd = min(len(self.slices), len(img.shape[1:])) # spatial dims - slices = [slice(None)] + self.slices[:sd] - return img[tuple(slices)] + return super().__call__(img=img, slices=self.slices) -class CenterSpatialCrop(Transform): +class CenterSpatialCrop(Crop): """ Crop at the center of image with specified ROI size. If a dimension of the expected ROI size is bigger than the input image size, will not crop that dimension. @@ -437,23 +489,24 @@ class CenterSpatialCrop(Transform): the spatial size of output data will be [32, 40, 40]. """ - backend = SpatialCrop.backend - def __init__(self, roi_size: Union[Sequence[int], int]) -> None: self.roi_size = roi_size - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def compute_slices(self, spatial_size: Sequence[int]): # type: ignore + roi_size = fall_back_tuple(self.roi_size, spatial_size) + roi_center = [i // 2 for i in spatial_size] + return super().compute_slices(roi_center=roi_center, roi_size=roi_size) + + def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. 
+ """ - roi_size = fall_back_tuple(self.roi_size, img.shape[1:]) - center = [i // 2 for i in img.shape[1:]] - cropper = SpatialCrop(roi_center=center, roi_size=roi_size) - return cropper(img) + return super().__call__(img=img, slices=self.compute_slices(img.shape[1:])) -class CenterScaleCrop(Transform): +class CenterScaleCrop(Crop): """ Crop at the center of image with specified scale of ROI size. @@ -463,20 +516,18 @@ class CenterScaleCrop(Transform): """ - backend = CenterSpatialCrop.backend - def __init__(self, roi_scale: Union[Sequence[float], float]): self.roi_scale = roi_scale - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore img_size = img.shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] - sp_crop = CenterSpatialCrop(roi_size=roi_size) - return sp_crop(img=img) + cropper = CenterSpatialCrop(roi_size=roi_size) + return super().__call__(img=img, slices=cropper.compute_slices(img.shape[1:])) -class RandSpatialCrop(Randomizable, Transform): +class RandSpatialCrop(Randomizable, Crop): """ Crop image with random size or specific size ROI. It can crop at a random position as center or at the image center. And allows to set the minimum and maximum size to limit the randomly generated ROI. @@ -500,8 +551,6 @@ class RandSpatialCrop(Randomizable, Transform): if True, the actual size is sampled from `randint(roi_size, max_roi_size + 1)`. """ - backend = CenterSpatialCrop.backend - def __init__( self, roi_size: Union[Sequence[int], int], @@ -514,7 +563,7 @@ def __init__( self.random_center = random_center self.random_size = random_size self._size: Optional[Sequence[int]] = None - self._slices: Optional[Tuple[slice, ...]] = None + self._slices: Tuple[slice, ...] def randomize(self, img_size: Sequence[int]) -> None: self._size = fall_back_tuple(self.roi_size, img_size) @@ -525,20 +574,22 @@ def randomize(self, img_size: Sequence[int]) -> None: self._size = tuple(self.R.randint(low=self._size[i], high=max_size[i] + 1) for i in range(len(img_size))) if self.random_center: valid_size = get_valid_patch_size(img_size, self._size) - self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) + self._slices = get_random_patch(img_size, valid_size, self.R) - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: # type: ignore """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. + """ - self.randomize(img.shape[1:]) + if randomize: + self.randomize(img.shape[1:]) if self._size is None: raise RuntimeError("self._size not specified.") if self.random_center: - return img[self._slices] + return super().__call__(img=img, slices=self._slices) cropper = CenterSpatialCrop(self._size) - return cropper(img) + return super().__call__(img=img, slices=cropper.compute_slices(img.shape[1:])) class RandScaleCrop(RandSpatialCrop): @@ -573,19 +624,26 @@ def __init__( self.roi_scale = roi_scale self.max_roi_scale = max_roi_scale - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: - """ - Apply the transform to `img`, assuming `img` is channel-first and - slicing doesn't apply to the channel dim. 
- """ - img_size = img.shape[1:] + def get_max_roi_size(self, img_size): ndim = len(img_size) self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] if self.max_roi_scale is not None: self.max_roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.max_roi_scale, ndim), img_size)] else: self.max_roi_size = None - return super().__call__(img=img) + + def randomize(self, img_size: Sequence[int]) -> None: + self.get_max_roi_size(img_size) + super().randomize(img_size) + + def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: # type: ignore + """ + Apply the transform to `img`, assuming `img` is channel-first and + slicing doesn't apply to the channel dim. + + """ + self.get_max_roi_size(img.shape[1:]) + return super().__call__(img=img, randomize=randomize) class RandSpatialCropSamples(Randomizable, Transform): @@ -644,15 +702,21 @@ def set_random_state( def randomize(self, data: Optional[Any] = None) -> None: pass - def __call__(self, img: NdarrayOrTensor) -> List[NdarrayOrTensor]: + def __call__(self, img: torch.Tensor) -> List[torch.Tensor]: """ Apply the transform to `img`, assuming `img` is channel-first and cropping doesn't change the channel dim. """ - return [self.cropper(img) for _ in range(self.num_samples)] + ret = [] + for i in range(self.num_samples): + cropped = self.cropper(img) + if get_track_meta(): + cropped.meta[Key.PATCH_INDEX] = i # type: ignore + ret.append(cropped) + return ret -class CropForeground(Transform): +class CropForeground(Crop): """ Crop an image using a bounding box. The bounding box is generated by selecting foreground using select_fn at channels channel_indices. margin is added in each spatial dimension of the bounding box. @@ -684,8 +748,6 @@ def threshold_at_one(x): """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__( self, select_fn: Callable = is_positive, @@ -694,7 +756,7 @@ def __init__( allow_smaller: bool = True, return_coords: bool = False, k_divisible: Union[Sequence[int], int] = 1, - mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = NumpyPadMode.CONSTANT, + mode: str = PytorchPadMode.CONSTANT, **pad_kwargs, ) -> None: """ @@ -725,10 +787,9 @@ def __init__( self.allow_smaller = allow_smaller self.return_coords = return_coords self.k_divisible = k_divisible - self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) - self.pad_kwargs = pad_kwargs + self.padder = Pad(mode=mode, **pad_kwargs) - def compute_bounding_box(self, img: NdarrayOrTensor): + def compute_bounding_box(self, img: torch.Tensor): """ Compute the start points and end points of bounding box to crop. And adjust bounding box coords to be divisible by `k`. @@ -748,34 +809,49 @@ def compute_bounding_box(self, img: NdarrayOrTensor): return box_start_, box_end_ def crop_pad( - self, - img: NdarrayOrTensor, - box_start: np.ndarray, - box_end: np.ndarray, - mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, + self, img: torch.Tensor, box_start: np.ndarray, box_end: np.ndarray, mode: Optional[str] = None, **pad_kwargs ): """ Crop and pad based on the bounding box. 
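+
+        If the bounding box goes beyond the image boundary, the out-of-bound part is
+        padded rather than cropped; e.g. (an illustrative sketch) ``box_start=(-2,)`` and
+        ``box_end=(12,)`` on a length-10 axis keep the whole axis and pad two voxels on
+        each side, so the output spatial size equals ``box_end - box_start``.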
""" - cropped = SpatialCrop(roi_start=box_start, roi_end=box_end)(img) + slices = self.compute_slices(roi_start=box_start, roi_end=box_end) + cropped = super().__call__(img=img, slices=slices) pad_to_start = np.maximum(-box_start, 0) pad_to_end = np.maximum(box_end - np.asarray(img.shape[1:]), 0) pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - return BorderPad(spatial_border=pad, mode=mode or self.mode, **self.pad_kwargs)(cropped) - - def __call__(self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None): + pad_width = BorderPad(spatial_border=pad).compute_pad_width(cropped.shape[1:]) + ret = self.padder.__call__(img=cropped, to_pad=pad_width, mode=mode, **pad_kwargs) + # combine the traced cropping and padding into one transformation + # by taking the padded info and placing it in a key inside the crop info. + if get_track_meta(): + ret_: MetaTensor = ret # type: ignore + app_op = ret_.applied_operations.pop(-1) + ret_.applied_operations[-1][TraceKeys.EXTRA_INFO]["pad_info"] = app_op + return ret + + def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs): # type: ignore """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. """ box_start, box_end = self.compute_bounding_box(img) - cropped = self.crop_pad(img, box_start, box_end, mode) + cropped = self.crop_pad(img, box_start, box_end, mode, **pad_kwargs) if self.return_coords: return cropped, box_start, box_end return cropped + def inverse(self, img: MetaTensor) -> MetaTensor: + transform = self.get_most_recent_transform(img) + # we moved the padding info in the forward, so put it back for the inverse + pad_info = transform[TraceKeys.EXTRA_INFO].pop("pad_info") + img.applied_operations.append(pad_info) + # first inverse the padder + inv = self.padder.inverse(img) + # and then inverse the cropper (self) + return super().inverse(inv) + class RandWeightedCrop(Randomizable, Transform): """ @@ -808,13 +884,16 @@ def randomize(self, weight_map: NdarrayOrTensor) -> None: spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R ) # using only the first channel as weight map - def __call__(self, img: NdarrayOrTensor, weight_map: Optional[NdarrayOrTensor] = None) -> List[NdarrayOrTensor]: + def __call__( + self, img: torch.Tensor, weight_map: Optional[NdarrayOrTensor] = None, randomize: bool = True + ) -> List[torch.Tensor]: """ Args: img: input image to sample patches from. assuming `img` is a channel-first array. weight_map: weight map used to generate patch samples. The weights must be non-negative. Each element denotes a sampling weight of the spatial location. 0 indicates no sampling. It should be a single-channel array in shape, for example, `(1, spatial_dim_0, spatial_dim_1, ...)` + randomize: whether to execute random operations, defautl to `True`. 
Returns: A list of image patches @@ -826,12 +905,17 @@ def __call__(self, img: NdarrayOrTensor, weight_map: Optional[NdarrayOrTensor] = if img.shape[1:] != weight_map.shape[1:]: raise ValueError(f"image and weight map spatial shape mismatch: {img.shape[1:]} vs {weight_map.shape[1:]}.") - self.randomize(weight_map) + if randomize: + self.randomize(weight_map) _spatial_size = fall_back_tuple(self.spatial_size, weight_map.shape[1:]) - results: List[NdarrayOrTensor] = [] - for center in self.centers: - cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) - results.append(cropper(img)) + results: List[torch.Tensor] = [] + for i, center in enumerate(self.centers): + cropped = SpatialCrop(roi_center=center, roi_size=_spatial_size)(img) + if get_track_meta(): + ret_: MetaTensor = cropped # type: ignore + ret_.meta[Key.PATCH_INDEX] = i + ret_.meta["crop_center"] = center + results.append(cropped) return results @@ -890,16 +974,16 @@ class RandCropByPosNegLabel(Randomizable, Transform): """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = SpatialCrop.backend def __init__( self, spatial_size: Union[Sequence[int], int], - label: Optional[NdarrayOrTensor] = None, + label: Optional[torch.Tensor] = None, pos: float = 1.0, neg: float = 1.0, num_samples: int = 1, - image: Optional[NdarrayOrTensor] = None, + image: Optional[torch.Tensor] = None, image_threshold: float = 0.0, fg_indices: Optional[NdarrayOrTensor] = None, bg_indices: Optional[NdarrayOrTensor] = None, @@ -922,10 +1006,10 @@ def __init__( def randomize( self, - label: NdarrayOrTensor, + label: torch.Tensor, fg_indices: Optional[NdarrayOrTensor] = None, bg_indices: Optional[NdarrayOrTensor] = None, - image: Optional[NdarrayOrTensor] = None, + image: Optional[torch.Tensor] = None, ) -> None: if fg_indices is None or bg_indices is None: if self.fg_indices is not None and self.bg_indices is not None: @@ -949,12 +1033,13 @@ def randomize( def __call__( self, - img: NdarrayOrTensor, - label: Optional[NdarrayOrTensor] = None, - image: Optional[NdarrayOrTensor] = None, + img: torch.Tensor, + label: Optional[torch.Tensor] = None, + image: Optional[torch.Tensor] = None, fg_indices: Optional[NdarrayOrTensor] = None, bg_indices: Optional[NdarrayOrTensor] = None, - ) -> List[NdarrayOrTensor]: + randomize: bool = True, + ) -> List[torch.Tensor]: """ Args: img: input data to crop samples from based on the pos/neg ratio of `label` and `image`. @@ -967,6 +1052,7 @@ def __call__( need to provide `fg_indices` and `bg_indices` together. bg_indices: background indices to randomly select crop centers, need to provide `fg_indices` and `bg_indices` together. + randomize: whether to execute the random operations, default to `True`. 
""" if label is None: @@ -976,14 +1062,18 @@ def __call__( if image is None: image = self.image - self.randomize(label, fg_indices, bg_indices, image) - results: List[NdarrayOrTensor] = [] + if randomize: + self.randomize(label, fg_indices, bg_indices, image) + results: List[torch.Tensor] = [] if self.centers is not None: - for center in self.centers: + for i, center in enumerate(self.centers): roi_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) - cropper = SpatialCrop(roi_center=center, roi_size=roi_size) - results.append(cropper(img)) - + cropped = SpatialCrop(roi_center=center, roi_size=roi_size)(img) + if get_track_meta(): + ret_: MetaTensor = cropped # type: ignore + ret_.meta[Key.PATCH_INDEX] = i + ret_.meta["crop_center"] = center + results.append(cropped) return results @@ -1051,16 +1141,16 @@ class RandCropByLabelClasses(Randomizable, Transform): """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = SpatialCrop.backend def __init__( self, spatial_size: Union[Sequence[int], int], ratios: Optional[List[Union[float, int]]] = None, - label: Optional[NdarrayOrTensor] = None, + label: Optional[torch.Tensor] = None, num_classes: Optional[int] = None, num_samples: int = 1, - image: Optional[NdarrayOrTensor] = None, + image: Optional[torch.Tensor] = None, image_threshold: float = 0.0, indices: Optional[List[NdarrayOrTensor]] = None, allow_smaller: bool = False, @@ -1077,10 +1167,7 @@ def __init__( self.allow_smaller = allow_smaller def randomize( - self, - label: NdarrayOrTensor, - indices: Optional[List[NdarrayOrTensor]] = None, - image: Optional[NdarrayOrTensor] = None, + self, label: torch.Tensor, indices: Optional[List[NdarrayOrTensor]] = None, image: Optional[torch.Tensor] = None ) -> None: indices_: Sequence[NdarrayOrTensor] if indices is None: @@ -1096,11 +1183,12 @@ def randomize( def __call__( self, - img: NdarrayOrTensor, - label: Optional[NdarrayOrTensor] = None, - image: Optional[NdarrayOrTensor] = None, + img: torch.Tensor, + label: Optional[torch.Tensor] = None, + image: Optional[torch.Tensor] = None, indices: Optional[List[NdarrayOrTensor]] = None, - ) -> List[NdarrayOrTensor]: + randomize: bool = True, + ) -> List[torch.Tensor]: """ Args: img: input data to crop samples from based on the ratios of every class, assumes `img` is a @@ -1109,6 +1197,7 @@ def __call__( image: optional image data to help select valid area, can be same as `img` or another image array. use ``image > image_threshold`` to select the centers only in valid region. if None, use `self.image`. indices: list of indices for every class in the image, used to randomly select crop centers. + randomize: whether to execute the random operations, default to `True`. 
""" if label is None: @@ -1118,18 +1207,23 @@ def __call__( if image is None: image = self.image - self.randomize(label, indices, image) - results: List[NdarrayOrTensor] = [] + if randomize: + self.randomize(label, indices, image) + results: List[torch.Tensor] = [] if self.centers is not None: - for center in self.centers: + for i, center in enumerate(self.centers): roi_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) - cropper = SpatialCrop(roi_center=tuple(center), roi_size=roi_size) - results.append(cropper(img)) + cropped = SpatialCrop(roi_center=tuple(center), roi_size=roi_size)(img) + if get_track_meta(): + ret_: MetaTensor = cropped # type: ignore + ret_.meta[Key.PATCH_INDEX] = i + ret_.meta["crop_center"] = center + results.append(cropped) return results -class ResizeWithPadOrCrop(Transform): +class ResizeWithPadOrCrop(InvertibleTransform): """ Resize an image to a target spatial size by either centrally cropping the image or padding it evenly with a user-specified mode. @@ -1139,14 +1233,14 @@ class ResizeWithPadOrCrop(Transform): Args: spatial_size: the spatial size of output data after padding or crop. If has non-positive values, the corresponding size of input image will be used (no padding). + method: {``"symmetric"``, ``"end"``} + Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html - method: {``"symmetric"``, ``"end"``} - Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -1157,16 +1251,14 @@ class ResizeWithPadOrCrop(Transform): def __init__( self, spatial_size: Union[Sequence[int], int], - mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, - method: Union[Method, str] = Method.SYMMETRIC, + method: str = Method.SYMMETRIC, + mode: str = PytorchPadMode.CONSTANT, **pad_kwargs, ): self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **pad_kwargs) self.cropper = CenterSpatialCrop(roi_size=spatial_size) - def __call__( - self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None - ) -> NdarrayOrTensor: + def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs) -> torch.Tensor: # type: ignore """ Args: img: data to pad or crop, assuming `img` is channel-first and @@ -1177,8 +1269,33 @@ def __call__( One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. 
+ """ - return self.padder(self.cropper(img), mode=mode) + ret = self.padder(self.cropper(img), mode=mode, **pad_kwargs) + # remove the individual info and combine + if get_track_meta(): + ret_: MetaTensor = ret # type: ignore + pad_info = ret_.applied_operations.pop(-1) + crop_info = ret_.applied_operations.pop(-1) + self.push_transform(ret_, extra_info={"pad_info": pad_info, "crop_info": crop_info}) + return ret + + def inverse(self, img: MetaTensor) -> MetaTensor: + transform = self.pop_transform(img) + return self.inverse_transform(img, transform) + + def inverse_transform(self, img: MetaTensor, transform) -> MetaTensor: + # we joined the cropping and padding, so put them back before calling the inverse + crop_info = transform[TraceKeys.EXTRA_INFO].pop("crop_info") + pad_info = transform[TraceKeys.EXTRA_INFO].pop("pad_info") + img.applied_operations.append(crop_info) + img.applied_operations.append(pad_info) + # first inverse the padder + inv = self.padder.inverse(img) + # and then inverse the cropper (self) + return self.cropper.inverse(inv) class BoundingRect(Transform): diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 50cc767cab..ad739c0fcd 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -15,50 +15,47 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -import contextlib from copy import deepcopy -from enum import Enum -from itertools import chain -from math import ceil, floor -from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Union import numpy as np +import torch from monai.config import IndexSelection, KeysCollection from monai.config.type_definitions import NdarrayOrTensor -from monai.data.utils import get_random_patch, get_valid_patch_size +from monai.data.meta_tensor import MetaTensor from monai.transforms.croppad.array import ( BorderPad, BoundingRect, + CenterScaleCrop, CenterSpatialCrop, + Crop, CropForeground, DivisiblePad, + Pad, RandCropByLabelClasses, RandCropByPosNegLabel, + RandScaleCrop, + RandSpatialCrop, + RandSpatialCropSamples, + RandWeightedCrop, ResizeWithPadOrCrop, SpatialCrop, SpatialPad, ) from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable -from monai.transforms.utils import ( - allow_missing_keys_mode, - generate_label_classes_crop_centers, - generate_pos_neg_label_crop_centers, - is_positive, - map_binary_to_indices, - map_classes_to_indices, - weighted_patch_samples, -) -from monai.utils import ImageMetaKey as Key -from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple -from monai.utils.enums import PostFix, TraceKeys +from monai.transforms.utils import is_positive +from monai.utils import MAX_SEED, Method, PytorchPadMode, ensure_tuple_rep +from monai.utils.deprecate_utils import deprecated_arg __all__ = [ - "PadModeSequence", + "Padd", "SpatialPadd", "BorderPadd", "DivisiblePadd", + "Cropd", + "RandCropd", "SpatialCropd", "CenterSpatialCropd", "CenterScaleCropd", @@ -71,12 +68,18 @@ "ResizeWithPadOrCropd", "BoundingRectd", "RandCropByLabelClassesd", + "PadD", + "PadDict", "SpatialPadD", "SpatialPadDict", "BorderPadD", "BorderPadDict", "DivisiblePadD", "DivisiblePadDict", + "CropD", + "CropDict", + "RandCropD", + "RandCropDict", "SpatialCropD", "SpatialCropDict", 
"CenterSpatialCropD", @@ -103,24 +106,67 @@ "RandCropByLabelClassesDict", ] -PadModeSequence = Union[Sequence[Union[NumpyPadMode, PytorchPadMode, str]], NumpyPadMode, PytorchPadMode, str] -DEFAULT_POST_FIX = PostFix.meta() +class Padd(MapTransform, InvertibleTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.Pad`. -class SpatialPadd(MapTransform, InvertibleTransform): + """ + + backend = Pad.backend + + def __init__( + self, + keys: KeysCollection, + padder: Pad, + mode: Union[Sequence[str], str] = PytorchPadMode.CONSTANT, + allow_missing_keys: bool = False, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + padder: pad transform for the input image. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + It also can be a sequence of string, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. + + """ + super().__init__(keys, allow_missing_keys) + self.padder = padder + self.mode = ensure_tuple_rep(mode, len(self.keys)) + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + d = dict(data) + for key, m in self.key_iterator(d, self.mode): + d[key] = self.padder(d[key], mode=m) + return d + + def inverse(self, data: Mapping[Hashable, MetaTensor]) -> Dict[Hashable, MetaTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.padder.inverse(d[key]) + return d + + +class SpatialPadd(Padd): """ Dictionary-based wrapper of :py:class:`monai.transforms.SpatialPad`. Performs padding to the data, symmetric for all sides or all on one side for each dimension. - """ - backend = SpatialPad.backend + """ def __init__( self, keys: KeysCollection, spatial_size: Union[Sequence[int], int], - method: Union[Method, str] = Method.SYMMETRIC, - mode: PadModeSequence = NumpyPadMode.CONSTANT, + method: str = Method.SYMMETRIC, + mode: Union[Sequence[str], str] = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, **kwargs, ) -> None: @@ -147,39 +193,11 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension. 
""" - super().__init__(keys, allow_missing_keys) - self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = SpatialPad(spatial_size, method, **kwargs) + padder = SpatialPad(spatial_size, method, **kwargs) + super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = dict(data) - for key, m in self.key_iterator(d, self.mode): - self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) - d[key] = self.padder(d[key], mode=m) - return d - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = transform[TraceKeys.ORIG_SIZE] - if self.padder.method == Method.SYMMETRIC: - current_size = d[key].shape[1:] - roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)] - else: - roi_center = [floor(r / 2) if r % 2 == 0 else (r - 1) // 2 for r in orig_size] - - inverse_transform = SpatialCrop(roi_center, orig_size) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - return d - - -class BorderPadd(MapTransform, InvertibleTransform): +class BorderPadd(Padd): """ Pad the input data by adding specified borders to every dimension. Dictionary-based wrapper of :py:class:`monai.transforms.BorderPad`. @@ -191,7 +209,7 @@ def __init__( self, keys: KeysCollection, spatial_border: Union[Sequence[int], int], - mode: PadModeSequence = NumpyPadMode.CONSTANT, + mode: Union[Sequence[str], str] = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, **kwargs, ) -> None: @@ -222,43 +240,11 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension. """ - super().__init__(keys, allow_missing_keys) - self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = BorderPad(spatial_border=spatial_border, **kwargs) - - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = dict(data) - for key, m in self.key_iterator(d, self.mode): - self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) - d[key] = self.padder(d[key], mode=m) - return d + padder = BorderPad(spatial_border=spatial_border, **kwargs) + super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) - roi_start = np.array(self.padder.spatial_border) - # Need to convert single value to [min1,min2,...] - if roi_start.size == 1: - roi_start = np.full((len(orig_size)), roi_start) - # need to convert [min1,max1,min2,...] to [min1,min2,...] 
- elif roi_start.size == 2 * orig_size.size: - roi_start = roi_start[::2] - roi_end = np.array(transform[TraceKeys.ORIG_SIZE]) + roi_start - - inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d - - -class DivisiblePadd(MapTransform, InvertibleTransform): +class DivisiblePadd(Padd): """ Pad the input data, so that the spatial sizes are divisible by `k`. Dictionary-based wrapper of :py:class:`monai.transforms.DivisiblePad`. @@ -270,8 +256,8 @@ def __init__( self, keys: KeysCollection, k: Union[Sequence[int], int], - mode: PadModeSequence = NumpyPadMode.CONSTANT, - method: Union[Method, str] = Method.SYMMETRIC, + mode: Union[Sequence[str], str] = PytorchPadMode.CONSTANT, + method: str = Method.SYMMETRIC, allow_missing_keys: bool = False, **kwargs, ) -> None: @@ -298,37 +284,81 @@ def __init__( See also :py:class:`monai.transforms.SpatialPad` """ + padder = DivisiblePad(k=k, method=method, **kwargs) + super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + + +class Cropd(MapTransform, InvertibleTransform): + """ + Dictionary-based wrapper of abstract class :py:class:`monai.transforms.Crop`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + cropper: crop transform for the input image. + allow_missing_keys: don't raise exception if key is missing. + + """ + + backend = Crop.backend + + def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False): super().__init__(keys, allow_missing_keys) - self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = DivisiblePad(k=k, method=method, **kwargs) + self.cropper = cropper - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - for key, m in self.key_iterator(d, self.mode): - self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) - d[key] = self.padder(d[key], mode=m) + for key in self.key_iterator(d): + d[key] = self.cropper(d[key]) # type: ignore return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - + def inverse(self, data: Mapping[Hashable, MetaTensor]) -> Dict[Hashable, MetaTensor]: + d = dict(data) for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) - current_size = np.array(d[key].shape[1:]) - roi_start = np.floor((current_size - orig_size) / 2) - roi_end = orig_size + roi_start - inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) + d[key] = self.cropper.inverse(d[key]) + return d + +class RandCropd(Cropd, Randomizable): + """ + Base class for random crop transform. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + cropper: random crop transform for the input image. + allow_missing_keys: don't raise exception if key is missing. 
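+
+    A minimal usage sketch (the wrapped cropper below is illustrative)::
+
+        cropper = RandCropd(keys=["img", "seg"], cropper=RandSpatialCrop(roi_size=(64, 64)))
+        out = cropper({"img": image, "seg": seg})  # the same random ROI is applied to every key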
+ + """ + + backend = Crop.backend + + def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False): + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandCropd": + super().set_random_state(seed, state) + if isinstance(self.cropper, Randomizable): + self.cropper.set_random_state(seed, state) + return self + + def randomize(self, img_size: Sequence[int]) -> None: + if isinstance(self.cropper, Randomizable): + self.cropper.randomize(img_size) + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + d = dict(data) + # the first key must exist to execute random operations + self.randomize(d[self.first_key(d)].shape[1:]) + for key in self.key_iterator(d): + kwargs = {"randomize": False} if isinstance(self.cropper, Randomizable) else {} + d[key] = self.cropper(d[key], **kwargs) # type: ignore return d -class SpatialCropd(MapTransform, InvertibleTransform): +class SpatialCropd(Cropd): """ Dictionary-based wrapper of :py:class:`monai.transforms.SpatialCrop`. General purpose cropper to produce sub-volume region of interest (ROI). @@ -343,8 +373,6 @@ class SpatialCropd(MapTransform, InvertibleTransform): - the start and end coordinates of the ROI """ - backend = SpatialCrop.backend - def __init__( self, keys: KeysCollection, @@ -367,40 +395,13 @@ def __init__( use the end coordinate of image. roi_slices: list of slices for each of the spatial dimensions. allow_missing_keys: don't raise exception if key is missing. - """ - super().__init__(keys, allow_missing_keys) - self.cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices) - - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = dict(data) - for key in self.key_iterator(d): - self.push_transform(d, key) - d[key] = self.cropper(d[key]) - return d - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) - current_size = np.array(d[key].shape[1:]) - # get required pad to start and end - pad_to_start = np.array([s.indices(o)[0] for s, o in zip(self.cropper.slices, orig_size)]) - pad_to_end = orig_size - current_size - pad_to_start - # interleave mins and maxes - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - inverse_transform = BorderPad(pad) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d + """ + cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) -class CenterSpatialCropd(MapTransform, InvertibleTransform): +class CenterSpatialCropd(Cropd): """ Dictionary-based wrapper of :py:class:`monai.transforms.CenterSpatialCrop`. If a dimension of the expected ROI size is bigger than the input image size, will not crop that dimension. @@ -418,45 +419,14 @@ class CenterSpatialCropd(MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. 
""" - backend = CenterSpatialCrop.backend - def __init__( self, keys: KeysCollection, roi_size: Union[Sequence[int], int], allow_missing_keys: bool = False ) -> None: - super().__init__(keys, allow_missing_keys) - self.cropper = CenterSpatialCrop(roi_size) - - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = dict(data) - for key in self.key_iterator(d): - orig_size = d[key].shape[1:] - d[key] = self.cropper(d[key]) - self.push_transform(d, key, orig_size=orig_size) - return d - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) - current_size = np.array(d[key].shape[1:]) - pad_to_start = np.floor((orig_size - current_size) / 2).astype(int) - # in each direction, if original size is even and current size is odd, += 1 - pad_to_start[np.logical_and(orig_size % 2 == 0, current_size % 2 == 1)] += 1 - pad_to_end = orig_size - current_size - pad_to_start - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - inverse_transform = BorderPad(pad) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d + cropper = CenterSpatialCrop(roi_size) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) -class CenterScaleCropd(MapTransform, InvertibleTransform): +class CenterScaleCropd(Cropd): """ Dictionary-based wrapper of :py:class:`monai.transforms.CenterScaleCrop`. Note: as using the same scaled ROI to crop, all the input data specified by `keys` should have @@ -470,54 +440,14 @@ class CenterScaleCropd(MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. 
""" - backend = CenterSpatialCrop.backend - def __init__( self, keys: KeysCollection, roi_scale: Union[Sequence[float], float], allow_missing_keys: bool = False ) -> None: - super().__init__(keys, allow_missing_keys=allow_missing_keys) - self.roi_scale = roi_scale + cropper = CenterScaleCrop(roi_scale) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = dict(data) - first_key: Union[Hashable, List] = self.first_key(d) - if first_key == []: - return d - - # use the spatial size of first image to scale, expect all images have the same spatial size - img_size = d[first_key].shape[1:] # type: ignore - ndim = len(img_size) - roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] - cropper = CenterSpatialCrop(roi_size) - for key in self.key_iterator(d): - self.push_transform(d, key, orig_size=img_size) - d[key] = cropper(d[key]) - return d - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) - current_size = np.array(d[key].shape[1:]) - pad_to_start = np.floor((orig_size - current_size) / 2).astype(int) - # in each direction, if original size is even and current size is odd, += 1 - pad_to_start[np.logical_and(orig_size % 2 == 0, current_size % 2 == 1)] += 1 - pad_to_end = orig_size - current_size - pad_to_start - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - inverse_transform = BorderPad(pad) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d - - -class RandSpatialCropd(Randomizable, MapTransform, InvertibleTransform): +class RandSpatialCropd(RandCropd): """ Dictionary-based version :py:class:`monai.transforms.RandSpatialCrop`. Crop image with random size or specific size ROI. It can crop at a random position as @@ -547,8 +477,6 @@ class RandSpatialCropd(Randomizable, MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. 
""" - backend = CenterSpatialCrop.backend - def __init__( self, keys: KeysCollection, @@ -558,78 +486,11 @@ def __init__( random_size: bool = True, allow_missing_keys: bool = False, ) -> None: - MapTransform.__init__(self, keys, allow_missing_keys) - self.roi_size = roi_size - self.max_roi_size = max_roi_size - self.random_center = random_center - self.random_size = random_size - self._slices: Optional[Tuple[slice, ...]] = None - self._size: Optional[Sequence[int]] = None + cropper = RandSpatialCrop(roi_size, max_roi_size, random_center, random_size) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) - def randomize(self, img_size: Sequence[int]) -> None: - self._size = fall_back_tuple(self.roi_size, img_size) - if self.random_size: - max_size = img_size if self.max_roi_size is None else fall_back_tuple(self.max_roi_size, img_size) - if any(i > j for i, j in zip(self._size, max_size)): - raise ValueError(f"min ROI size: {self._size} is bigger than max ROI size: {max_size}.") - self._size = [self.R.randint(low=self._size[i], high=max_size[i] + 1) for i in range(len(img_size))] - if self.random_center: - valid_size = get_valid_patch_size(img_size, self._size) - self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = dict(data) - first_key: Union[Hashable, List] = self.first_key(d) - if first_key == []: - return d - - self.randomize(d[first_key].shape[1:]) # type: ignore - if self._size is None: - raise RuntimeError("self._size not specified.") - for key in self.key_iterator(d): - if self.random_center: - self.push_transform(d, key, {"slices": [(i.start, i.stop) for i in self._slices[1:]]}) # type: ignore - d[key] = d[key][self._slices] - else: - self.push_transform(d, key) - cropper = CenterSpatialCrop(self._size) - d[key] = cropper(d[key]) - return d - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = transform[TraceKeys.ORIG_SIZE] - random_center = self.random_center - pad_to_start = np.empty((len(orig_size)), dtype=np.int32) - pad_to_end = np.empty((len(orig_size)), dtype=np.int32) - if random_center: - for i, _slice in enumerate(transform[TraceKeys.EXTRA_INFO]["slices"]): - pad_to_start[i] = _slice[0] - pad_to_end[i] = orig_size[i] - _slice[1] - else: - current_size = d[key].shape[1:] - for i, (o_s, c_s) in enumerate(zip(orig_size, current_size)): - pad_to_start[i] = pad_to_end[i] = (o_s - c_s) / 2 - if o_s % 2 == 0 and c_s % 2 == 1: - pad_to_start[i] += 1 - elif o_s % 2 == 1 and c_s % 2 == 0: - pad_to_end[i] += 1 - # interleave mins and maxes - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - inverse_transform = BorderPad(pad) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d - - -class RandScaleCropd(RandSpatialCropd): +class RandScaleCropd(RandCropd): """ Dictionary-based version :py:class:`monai.transforms.RandScaleCrop`. Crop image with random size or specific size ROI. @@ -654,8 +515,6 @@ class RandScaleCropd(RandSpatialCropd): allow_missing_keys: don't raise exception if key is missing. 
""" - backend = RandSpatialCropd.backend - def __init__( self, keys: KeysCollection, @@ -665,41 +524,11 @@ def __init__( random_size: bool = True, allow_missing_keys: bool = False, ) -> None: - super().__init__( - keys=keys, - roi_size=-1, - max_roi_size=None, - random_center=random_center, - random_size=random_size, - allow_missing_keys=allow_missing_keys, - ) - self.roi_scale = roi_scale - self.max_roi_scale = max_roi_scale + cropper = RandScaleCrop(roi_scale, max_roi_scale, random_center, random_size) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - first_key: Union[Hashable, List] = self.first_key(data) # type: ignore - if first_key == []: - return data # type: ignore - - img_size = data[first_key].shape[1:] # type: ignore - ndim = len(img_size) - self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] - if self.max_roi_scale is not None: - self.max_roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.max_roi_scale, ndim), img_size)] - else: - self.max_roi_size = None - return super().__call__(data=data) - - -@contextlib.contextmanager -def _nullcontext(x): - """ - This is just like contextlib.nullcontext but also works in Python 3.6. - """ - yield x - -class RandSpatialCropSamplesd(Randomizable, MapTransform, InvertibleTransform): +class RandSpatialCropSamplesd(Randomizable, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandSpatialCropSamples`. Crop image with random size or specific size ROI to generate a list of N samples. @@ -728,15 +557,6 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform, InvertibleTransform): random_center: crop at random position as center or the image center. random_size: crop with random size or specific size ROI. The actual size is sampled from `randint(roi_size, img_size)`. - meta_keys: explicitly indicate the key of the corresponding metadata dictionary. - used to add `patch_index` to the meta dict. - for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the metadata is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the metadata according - to the key data, default is `meta_dict`, the metadata is a dictionary object. - used to add `patch_index` to the meta dict. allow_missing_keys: don't raise exception if key is missing. 
Raises: @@ -744,8 +564,10 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform, InvertibleTransform): """ - backend = RandSpatialCropd.backend + backend = RandSpatialCropSamples.backend + @deprecated_arg(name="meta_keys", since="0.9") + @deprecated_arg(name="meta_key_postfix", since="0.9") def __init__( self, keys: KeysCollection, @@ -755,63 +577,33 @@ def __init__( random_center: bool = True, random_size: bool = True, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = DEFAULT_POST_FIX, + meta_key_postfix: str = "meta_dict", allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) - if num_samples < 1: - raise ValueError(f"num_samples must be positive, got {num_samples}.") - self.num_samples = num_samples - self.cropper = RandSpatialCropd(keys, roi_size, max_roi_size, random_center, random_size, allow_missing_keys) - self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) - if len(self.keys) != len(self.meta_keys): - raise ValueError("meta_keys should have the same length as keys.") - self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandSpatialCropSamplesd": - super().set_random_state(seed, state) - self.cropper.set_random_state(seed, state) - return self + self.cropper = RandSpatialCropSamples(roi_size, num_samples, max_roi_size, random_center, random_size) def randomize(self, data: Optional[Any] = None) -> None: - pass - - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: - ret = [] - for i in range(self.num_samples): - d = dict(data) - # deep copy all the unmodified data - for key in set(data.keys()).difference(set(self.keys)): - d[key] = deepcopy(data[key]) - cropped = self.cropper(d) - # self.cropper will have added RandSpatialCropd to the list. 
Change to RandSpatialCropSamplesd - for key in self.key_iterator(cropped): - cropped[self.trace_key(key)][-1][TraceKeys.CLASS_NAME] = self.__class__.__name__ # type: ignore - cropped[self.trace_key(key)][-1][TraceKeys.ID] = id(self) # type: ignore - # add `patch_index` to the metadata - for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): - meta_key = meta_key or f"{key}_{meta_key_postfix}" - if meta_key not in cropped: - cropped[meta_key] = {} # type: ignore - cropped[meta_key][Key.PATCH_INDEX] = i # type: ignore - ret.append(cropped) + self.sub_seed = self.R.randint(MAX_SEED, dtype="uint32") + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, torch.Tensor]]: + # output starts as empty list of dictionaries + ret: List[Dict[Hashable, torch.Tensor]] = [{} for _ in range(self.cropper.num_samples)] + # deep copy all the unmodified data + for key in set(data.keys()).difference(set(self.keys)): + for r in ret: + r[key] = deepcopy(data[key]) + + # for each key we reset the random state to ensure crops are the same + self.randomize() + for key in self.key_iterator(dict(data)): + self.cropper.set_random_state(seed=self.sub_seed) + for i, im in enumerate(self.cropper(data[key])): + ret[i][key] = im return ret - def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: - d = deepcopy(dict(data)) - # We changed the transform name from RandSpatialCropd to RandSpatialCropSamplesd - # Need to revert that since we're calling RandSpatialCropd's inverse - for key in self.key_iterator(d): - d[self.trace_key(key)][-1][TraceKeys.CLASS_NAME] = self.cropper.__class__.__name__ - d[self.trace_key(key)][-1][TraceKeys.ID] = id(self.cropper) - context_manager = allow_missing_keys_mode if self.allow_missing_keys else _nullcontext - with context_manager(self.cropper): - return self.cropper.inverse(d) - -class CropForegroundd(MapTransform, InvertibleTransform): +class CropForegroundd(Cropd): """ Dictionary-based version :py:class:`monai.transforms.CropForeground`. Crop only the foreground object of the expected images. @@ -824,8 +616,6 @@ class CropForegroundd(MapTransform, InvertibleTransform): channels. And it can also add margin to every dim of the bounding box of foreground object. """ - backend = CropForeground.backend - def __init__( self, keys: KeysCollection, @@ -835,7 +625,7 @@ def __init__( margin: Union[Sequence[int], int] = 0, allow_smaller: bool = True, k_divisible: Union[Sequence[int], int] = 1, - mode: PadModeSequence = NumpyPadMode.CONSTANT, + mode: Union[Sequence[str], str] = PytorchPadMode.CONSTANT, start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", allow_missing_keys: bool = False, @@ -869,11 +659,10 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension. 
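A hedged sketch of the rewritten sampler above: the wrapped RandSpatialCropSamples is reseeded with the same sub-seed before each key, so the i-th patch lines up across keys, and the call returns `num_samples` dictionaries:

    import torch
    from monai.transforms import RandSpatialCropSamplesd

    data = {"img": torch.rand(1, 32, 32, 16), "seg": torch.rand(1, 32, 32, 16)}
    sampler = RandSpatialCropSamplesd(
        keys=["img", "seg"], roi_size=(16, 16, 8), num_samples=4, random_size=False
    )
    samples = sampler(data)  # a list of 4 dictionaries
    first = samples[0]       # "img" and "seg" share the same ROI thanks to the shared sub-seed
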
""" - super().__init__(keys, allow_missing_keys) self.source_key = source_key self.start_coord_key = start_coord_key self.end_coord_key = end_coord_key - self.cropper = CropForeground( + cropper = CropForeground( select_fn=select_fn, channel_indices=channel_indices, margin=margin, @@ -881,48 +670,21 @@ def __init__( k_divisible=k_divisible, **pad_kwargs, ) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) + self.cropper: CropForeground box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key]) d[self.start_coord_key] = box_start d[self.end_coord_key] = box_end for key, m in self.key_iterator(d, self.mode): - self.push_transform(d, key, extra_info={"box_start": box_start, "box_end": box_end}) d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) - cur_size = np.asarray(d[key].shape[1:]) - extra_info = transform[TraceKeys.EXTRA_INFO] - box_start = np.asarray(extra_info["box_start"]) - box_end = np.asarray(extra_info["box_end"]) - # first crop the padding part - roi_start = np.maximum(-box_start, 0) - roi_end = cur_size - np.maximum(box_end - orig_size, 0) - - d[key] = SpatialCrop(roi_start=roi_start, roi_end=roi_end)(d[key]) - - # update bounding box to pad - pad_to_start = np.maximum(box_start, 0) - pad_to_end = orig_size - np.minimum(box_end, orig_size) - # interleave mins and maxes - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - # second pad back the original size - d[key] = BorderPad(pad)(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d - -class RandWeightedCropd(Randomizable, MapTransform, InvertibleTransform): +class RandWeightedCropd(Randomizable, MapTransform): """ Samples a list of `num_samples` image patches according to the provided `weight_map`. @@ -934,16 +696,6 @@ class RandWeightedCropd(Randomizable, MapTransform, InvertibleTransform): spatial_size: the spatial size of the image patch e.g. [224, 224, 128]. If its components have non-positive values, the corresponding size of `img` will be used. num_samples: number of samples (image patches) to take in the returned list. - center_coord_key: if specified, the actual sampling location will be stored with the corresponding key. - meta_keys: explicitly indicate the key of the corresponding metadata dictionary. - used to add `patch_index` to the meta dict. - for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the metadata is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the metadata according - to the key data, default is `meta_dict`, the metadata is a dictionary object. - used to add `patch_index` to the meta dict. allow_missing_keys: don't raise exception if key is missing. 
See Also: @@ -952,6 +704,9 @@ class RandWeightedCropd(Randomizable, MapTransform, InvertibleTransform): backend = SpatialCrop.backend + @deprecated_arg(name="meta_keys", since="0.9") + @deprecated_arg(name="meta_key_postfix", since="0.9") + @deprecated_arg(name="center_coord_key", since="0.9", msg_suffix="coords stored in img.meta['crop_center']") def __init__( self, keys: KeysCollection, @@ -960,85 +715,41 @@ def __init__( num_samples: int = 1, center_coord_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = DEFAULT_POST_FIX, + meta_key_postfix: str = "meta_dict", allow_missing_keys: bool = False, ): MapTransform.__init__(self, keys, allow_missing_keys) - self.spatial_size = ensure_tuple(spatial_size) self.w_key = w_key - self.num_samples = int(num_samples) - self.center_coord_key = center_coord_key - self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) - if len(self.keys) != len(self.meta_keys): - raise ValueError("meta_keys should have the same length as keys.") - self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self.centers: List[np.ndarray] = [] + self.cropper = RandWeightedCrop(spatial_size, num_samples) - def randomize(self, weight_map: NdarrayOrTensor) -> None: - self.centers = weighted_patch_samples( - spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R - ) - - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: - d = dict(data) - self.randomize(d[self.w_key]) - _spatial_size = fall_back_tuple(self.spatial_size, d[self.w_key].shape[1:]) - - # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(data) for _ in range(self.num_samples)] - # fill in the extra keys with unmodified data - for i in range(self.num_samples): - for key in set(data.keys()).difference(set(self.keys)): - results[i][key] = deepcopy(data[key]) - for key in self.key_iterator(d): - img = d[key] - if img.shape[1:] != d[self.w_key].shape[1:]: - raise ValueError( - f"data {key} and weight map {self.w_key} spatial shape mismatch: " - f"{img.shape[1:]} vs {d[self.w_key].shape[1:]}." 
- ) - for i, center in enumerate(self.centers): - cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) - orig_size = img.shape[1:] - results[i][key] = cropper(img) - self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) - if self.center_coord_key: - results[i][self.center_coord_key] = center - # fill in the extra keys with unmodified data - for i in range(self.num_samples): - # add `patch_index` to the metadata - for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): - meta_key = meta_key or f"{key}_{meta_key_postfix}" - if meta_key not in results[i]: - results[i][meta_key] = {} # type: ignore - results[i][meta_key][Key.PATCH_INDEX] = i # type: ignore - - return results - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) - current_size = np.asarray(d[key].shape[1:]) - center = transform[TraceKeys.EXTRA_INFO]["center"] - cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size) - # get required pad to start and end - pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) - pad_to_end = orig_size - current_size - pad_to_start - # interleave mins and maxes - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - inverse_transform = BorderPad(pad) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandWeightedCropd": + super().set_random_state(seed, state) + self.cropper.set_random_state(seed, state) + return self - return d + def randomize(self, weight_map: NdarrayOrTensor) -> None: + self.cropper.randomize(weight_map) + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, torch.Tensor]]: + # output starts as empty list of dictionaries + ret: List = [{} for _ in range(self.cropper.num_samples)] + # deep copy all the unmodified data + for key in set(data.keys()).difference(set(self.keys)): + for r in ret: + r[key] = deepcopy(data[key]) + + self.randomize(weight_map=data[self.w_key]) + for key in self.key_iterator(data): + for i, im in enumerate(self.cropper(data[key], weight_map=data[self.w_key], randomize=False)): + ret[i][key] = im + return ret -class RandCropByPosNegLabeld(Randomizable, MapTransform, InvertibleTransform): +@deprecated_arg(name="meta_keys", since="0.9") +@deprecated_arg(name="meta_key_postfix", since="0.9") +class RandCropByPosNegLabeld(Randomizable, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandCropByPosNegLabel`. Crop random fixed sized regions with the center being a foreground or background voxel @@ -1079,15 +790,6 @@ class RandCropByPosNegLabeld(Randomizable, MapTransform, InvertibleTransform): `image_threshold`, and randomly select crop centers based on them, need to provide `fg_indices_key` and `bg_indices_key` together, expect to be 1 dim array of spatial indices after flattening. a typical usage is to call `FgBgToIndicesd` transform first and cache the results. - meta_keys: explicitly indicate the key of the corresponding metadata dictionary. - used to add `patch_index` to the meta dict. 
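A hedged usage sketch of the delegating RandWeightedCropd above (per the deprecation notes, the sampled centre is expected in the output's metadata rather than under `center_coord_key`; the weight map and shapes below are illustrative only):

    import torch
    from monai.transforms import RandWeightedCropd

    w = torch.zeros(1, 64, 64, 32)
    w[0, 32, 32, 16] = 1.0  # concentrate all sampling weight on one voxel
    data = {"img": torch.rand(1, 64, 64, 32), "w": w}
    crop = RandWeightedCropd(keys="img", w_key="w", spatial_size=(16, 16, 8), num_samples=2)
    patches = crop(data)  # list of 2 dicts; "w" is not in `keys`, so it is deep-copied unchanged
    print(patches[0]["img"].shape)  # expected (1, 16, 16, 8)
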
- for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the metadata is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the metadata according - to the key data, default is `meta_dict`, the metadata is a dictionary object. - used to add `patch_index` to the meta dict. allow_smaller: if `False`, an exception will be raised if the image is smaller than the requested ROI in any dimension. If `True`, any smaller dimensions will be set to match the cropped size (i.e., no cropping in that dimension). @@ -1114,54 +816,41 @@ def __init__( fg_indices_key: Optional[str] = None, bg_indices_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = DEFAULT_POST_FIX, + meta_key_postfix: str = "meta_dict", allow_smaller: bool = False, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) self.label_key = label_key - self.spatial_size: Union[Tuple[int, ...], Sequence[int], int] = spatial_size - if pos < 0 or neg < 0: - raise ValueError(f"pos and neg must be nonnegative, got pos={pos} neg={neg}.") - if pos + neg == 0: - raise ValueError("Incompatible values: pos=0 and neg=0.") - self.pos_ratio = pos / (pos + neg) - self.num_samples = num_samples self.image_key = image_key - self.image_threshold = image_threshold self.fg_indices_key = fg_indices_key self.bg_indices_key = bg_indices_key - self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) - if len(self.keys) != len(self.meta_keys): - raise ValueError("meta_keys should have the same length as keys.") - self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self.centers: Optional[List[List[int]]] = None - self.allow_smaller = allow_smaller + self.cropper = RandCropByPosNegLabel( + spatial_size=spatial_size, + pos=pos, + neg=neg, + num_samples=num_samples, + image_threshold=image_threshold, + allow_smaller=allow_smaller, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandCropByPosNegLabeld": + super().set_random_state(seed, state) + self.cropper.set_random_state(seed, state) + return self def randomize( self, - label: NdarrayOrTensor, + label: torch.Tensor, fg_indices: Optional[NdarrayOrTensor] = None, bg_indices: Optional[NdarrayOrTensor] = None, - image: Optional[NdarrayOrTensor] = None, + image: Optional[torch.Tensor] = None, ) -> None: - if fg_indices is None or bg_indices is None: - fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) - else: - fg_indices_ = fg_indices - bg_indices_ = bg_indices - self.centers = generate_pos_neg_label_crop_centers( - self.spatial_size, - self.num_samples, - self.pos_ratio, - label.shape[1:], - fg_indices_, - bg_indices_, - self.R, - self.allow_smaller, - ) + self.cropper.randomize(label=label, fg_indices=fg_indices, bg_indices=bg_indices, image=image) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, torch.Tensor]]: d = dict(data) label = d[self.label_key] image = d[self.image_key] if self.image_key else None @@ -1169,57 +858,23 @@ def __call__(self, data: Mapping[Hashable, 
NdarrayOrTensor]) -> List[Dict[Hashab bg_indices = d.pop(self.bg_indices_key, None) if self.bg_indices_key is not None else None self.randomize(label, fg_indices, bg_indices, image) - if self.centers is None: - raise ValueError("no available ROI centers to crop.") # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(self.num_samples)] - - for i, center in enumerate(self.centers): - # fill in the extra keys with unmodified data - for key in set(d.keys()).difference(set(self.keys)): - results[i][key] = deepcopy(d[key]) - for key in self.key_iterator(d): - img = d[key] - orig_size = img.shape[1:] - roi_size = fall_back_tuple(self.spatial_size, default=orig_size) - cropper = SpatialCrop(roi_center=tuple(center), roi_size=roi_size) - results[i][key] = cropper(img) - self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) - # add `patch_index` to the metadata - for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): - meta_key = meta_key or f"{key}_{meta_key_postfix}" - if meta_key not in results[i]: - results[i][meta_key] = {} # type: ignore - results[i][meta_key][Key.PATCH_INDEX] = i # type: ignore - - return results - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) - current_size = np.asarray(d[key].shape[1:]) - center = transform[TraceKeys.EXTRA_INFO]["center"] - roi_size = fall_back_tuple(self.spatial_size, default=orig_size) - cropper = SpatialCrop(roi_center=tuple(center), roi_size=roi_size) # type: ignore - # get required pad to start and end - pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) - pad_to_end = orig_size - current_size - pad_to_start - # interleave mins and maxes - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - inverse_transform = BorderPad(pad) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d + ret: List = [{} for _ in range(self.cropper.num_samples)] + # deep copy all the unmodified data + for key in set(data.keys()).difference(set(self.keys)): + for r in ret: + r[key] = deepcopy(data[key]) + + for key in self.key_iterator(data): + for i, im in enumerate(self.cropper(data[key], label=label, randomize=False)): + ret[i][key] = im + return ret -class RandCropByLabelClassesd(Randomizable, MapTransform, InvertibleTransform): +@deprecated_arg(name="meta_keys", since="0.9") +@deprecated_arg(name="meta_key_postfix", since="0.9") +class RandCropByLabelClassesd(Randomizable, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandCropByLabelClasses`. Crop random fixed sized regions with the center being a class based on the specified ratios of every class. @@ -1284,15 +939,6 @@ class RandCropByLabelClassesd(Randomizable, MapTransform, InvertibleTransform): `image_threshold`, and randomly select crop centers based on them, expect to be 1 dim array of spatial indices after flattening. a typical usage is to call `ClassesToIndices` transform first and cache the results for better performance. - meta_keys: explicitly indicate the key of the corresponding metadata dictionary. 
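A hedged sketch of the slimmed-down RandCropByPosNegLabeld: the pos/neg ratio logic now lives entirely in the array RandCropByPosNegLabel, and the wrapper re-applies the same centres to every key with `randomize=False` (data below is synthetic and illustrative):

    import torch
    from monai.transforms import RandCropByPosNegLabeld

    data = {
        "image": torch.rand(1, 64, 64, 32),
        "label": (torch.rand(1, 64, 64, 32) > 0.95).float(),  # sparse foreground
    }
    crop = RandCropByPosNegLabeld(
        keys=["image", "label"], label_key="label", spatial_size=(16, 16, 8),
        pos=1, neg=1, num_samples=4,
    )
    samples = crop(data)  # 4 dicts, each with matching "image"/"label" patches
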
- used to add `patch_index` to the meta dict. - for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the metadata is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the metadata according - to the key data, default is `meta_dict`, the metadata is a dictionary object. - used to add `patch_index` to the meta dict. allow_smaller: if `False`, an exception will be raised if the image is smaller than the requested ROI in any dimension. If `True`, any smaller dimensions will remain unchanged. @@ -1314,98 +960,57 @@ def __init__( image_threshold: float = 0.0, indices_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = DEFAULT_POST_FIX, + meta_key_postfix: str = "meta_dict", allow_smaller: bool = False, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) self.label_key = label_key - self.spatial_size = spatial_size - self.ratios = ratios - self.num_classes = num_classes - self.num_samples = num_samples self.image_key = image_key - self.image_threshold = image_threshold self.indices_key = indices_key - self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) - if len(self.keys) != len(self.meta_keys): - raise ValueError("meta_keys should have the same length as keys.") - self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self.centers: Optional[List[List[int]]] = None - self.allow_smaller = allow_smaller + self.cropper = RandCropByLabelClasses( + spatial_size=spatial_size, + ratios=ratios, + num_classes=num_classes, + num_samples=num_samples, + image_threshold=image_threshold, + allow_smaller=allow_smaller, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandCropByLabelClassesd": + super().set_random_state(seed, state) + self.cropper.set_random_state(seed, state) + return self def randomize( - self, - label: NdarrayOrTensor, - indices: Optional[List[NdarrayOrTensor]] = None, - image: Optional[NdarrayOrTensor] = None, + self, label: torch.Tensor, indices: Optional[List[NdarrayOrTensor]] = None, image: Optional[torch.Tensor] = None ) -> None: - if indices is None: - indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) - else: - indices_ = indices - self.centers = generate_label_classes_crop_centers( - self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R, self.allow_smaller - ) + self.cropper.randomize(label=label, indices=indices, image=image) - def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, NdarrayOrTensor]]: + def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, torch.Tensor]]: d = dict(data) label = d[self.label_key] image = d[self.image_key] if self.image_key else None indices = d.pop(self.indices_key, None) if self.indices_key is not None else None self.randomize(label, indices, image) - if self.centers is None: - raise ValueError("no available ROI centers to crop.") # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(self.num_samples)] - - for i, center in enumerate(self.centers): - # fill in the extra keys with 
unmodified data - for key in set(d.keys()).difference(set(self.keys)): - results[i][key] = deepcopy(d[key]) - for key in self.key_iterator(d): - img = d[key] - orig_size = img.shape[1:] - roi_size = fall_back_tuple(self.spatial_size, default=orig_size) - cropper = SpatialCrop(roi_center=tuple(center), roi_size=roi_size) - results[i][key] = cropper(img) - self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) - # add `patch_index` to the metadata - for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): - meta_key = meta_key or f"{key}_{meta_key_postfix}" - if meta_key not in results[i]: - results[i][meta_key] = {} # type: ignore - results[i][meta_key][Key.PATCH_INDEX] = i # type: ignore - - return results - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) - current_size = np.asarray(d[key].shape[1:]) - center = transform[TraceKeys.EXTRA_INFO]["center"] - roi_size = fall_back_tuple(self.spatial_size, default=orig_size) - cropper = SpatialCrop(roi_center=tuple(center), roi_size=roi_size) # type: ignore - # get required pad to start and end - pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) - pad_to_end = orig_size - current_size - pad_to_start - # interleave mins and maxes - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - inverse_transform = BorderPad(pad) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d + ret: List = [{} for _ in range(self.cropper.num_samples)] + # deep copy all the unmodified data + for key in set(data.keys()).difference(set(self.keys)): + for r in ret: + r[key] = deepcopy(data[key]) + + for key in self.key_iterator(data): + for i, im in enumerate(self.cropper(data[key], label=label, randomize=False)): + ret[i][key] = im + return ret -class ResizeWithPadOrCropd(MapTransform, InvertibleTransform): +class ResizeWithPadOrCropd(Padd): """ Dictionary-based wrapper of :py:class:`monai.transforms.ResizeWithPadOrCrop`. 
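Returning to the RandCropByLabelClassesd rewrite just above, a hedged usage sketch (class-ratio sampling is handled by the wrapped array transform; the ratios below are illustrative only):

    import torch
    from monai.transforms import RandCropByLabelClassesd

    label = torch.randint(0, 3, (1, 64, 64, 32))  # three classes: 0, 1, 2
    data = {"img": torch.rand(1, 64, 64, 32), "label": label}
    crop = RandCropByLabelClassesd(
        keys=["img", "label"], label_key="label", spatial_size=(16, 16, 8),
        ratios=[0, 1, 1], num_classes=3, num_samples=4,
    )
    samples = crop(data)  # 4 dicts; crop centres are drawn only from classes 1 and 2
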
@@ -1429,63 +1034,17 @@ class ResizeWithPadOrCropd(MapTransform, InvertibleTransform): """ - backend = ResizeWithPadOrCrop.backend - def __init__( self, keys: KeysCollection, spatial_size: Union[Sequence[int], int], - mode: PadModeSequence = NumpyPadMode.CONSTANT, + mode: Union[Sequence[str], str] = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, - method: Union[Method, str] = Method.SYMMETRIC, + method: str = Method.SYMMETRIC, **pad_kwargs, ) -> None: - super().__init__(keys, allow_missing_keys) - self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **pad_kwargs) - - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = dict(data) - for key, m in self.key_iterator(d, self.mode): - orig_size = d[key].shape[1:] - d[key] = self.padcropper(d[key], mode=m) - self.push_transform(d, key, orig_size=orig_size, extra_info={"mode": m.value if isinstance(m, Enum) else m}) - return d - - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) - current_size = np.array(d[key].shape[1:]) - # Unfortunately, we can't just use ResizeWithPadOrCrop with original size because of odd/even rounding. - # Instead, we first pad any smaller dimensions, and then we crop any larger dimensions. - - # First, do pad - if np.any((orig_size - current_size) > 0): - pad_to_start = np.floor((orig_size - current_size) / 2).astype(int) - # in each direction, if original size is even and current size is odd, += 1 - pad_to_start[np.logical_and(orig_size % 2 == 0, current_size % 2 == 1)] += 1 - pad_to_start[pad_to_start < 0] = 0 - pad_to_end = orig_size - current_size - pad_to_start - pad_to_end[pad_to_end < 0] = 0 - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - d[key] = BorderPad(pad)(d[key]) - - # Next crop - if np.any((orig_size - current_size) < 0): - if self.padcropper.padder.method == Method.SYMMETRIC: - roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)] - else: - roi_center = [floor(r / 2) if r % 2 == 0 else (r - 1) // 2 for r in orig_size] - - d[key] = SpatialCrop(roi_center, orig_size)(d[key]) - - # Remove the applied transform - self.pop_transform(d, key) - - return d + padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **pad_kwargs) + super().__init__(keys, padder=padcropper, mode=mode, allow_missing_keys=allow_missing_keys) # type: ignore class BoundingRectd(MapTransform): @@ -1528,9 +1087,12 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d +PadD = PadDict = Padd SpatialPadD = SpatialPadDict = SpatialPadd BorderPadD = BorderPadDict = BorderPadd DivisiblePadD = DivisiblePadDict = DivisiblePadd +CropD = CropDict = Cropd +RandCropD = RandCropDict = RandCropd SpatialCropD = SpatialCropDict = SpatialCropd CenterSpatialCropD = CenterSpatialCropDict = CenterSpatialCropd CenterScaleCropD = CenterScaleCropDict = CenterScaleCropd diff --git a/tests/croppers.py b/tests/croppers.py new file mode 100644 index 0000000000..8f78249d90 --- /dev/null +++ b/tests/croppers.py @@ -0,0 +1,103 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from copy import deepcopy + +import numpy as np + +from monai.data.meta_tensor import MetaTensor +from monai.transforms.transform import MapTransform +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose + + +class CropTest(unittest.TestCase): + @staticmethod + def get_arr(shape): + return np.random.randint(100, size=shape).astype(float) + + def crop_test(self, input_param, input_shape, expected_shape, same_area=None): + base_comparison = None + input_image = self.get_arr(input_shape) + + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + # input parameters, such as roi_start can be numpy, torch, list etc. + for param_type in TEST_NDARRAYS_ALL + (None,): + with self.subTest(param_type=param_type): + input_param_mod = deepcopy(input_param) + if param_type is not None: + for k in ("roi_start", "roi_end", "roi_center", "roi_size", "roi_scale"): + if k in input_param: + input_param_mod[k] = param_type(input_param[k]) + im = im_type(input_image) + cropper = self.Cropper(**input_param_mod) + is_map = isinstance(cropper, MapTransform) + input_data = {"img": im} if is_map else im + result = cropper(input_data) + out_im = result["img"] if is_map else result + self.assertIsInstance(out_im, MetaTensor) + self.assertTupleEqual(out_im.shape, expected_shape) + if same_area is not None: + assert_allclose(out_im, im[same_area], type_test=False) + # check result is the same regardless of input type + if base_comparison is None: + base_comparison = out_im + else: + assert_allclose(out_im, base_comparison) + + # test inverse + inv = cropper.inverse(result) + inv_im = inv["img"] if is_map else inv + self.assertIsInstance(inv_im, MetaTensor) + if same_area is not None: + assert_allclose(inv_im[same_area], im[same_area], type_test=False) + self.assertEqual(inv_im.applied_operations, []) + + def crop_test_value(self, input_param, input_arr, expected_array): + cropper = self.Cropper(**input_param) + is_map = isinstance(cropper, MapTransform) + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + im = im_type(input_arr) + input_data = {"img": im} if is_map else im + result = self.Cropper(**input_param)(input_data) + out_im = result["img"] if is_map else result + self.assertIsInstance(out_im, MetaTensor) + assert_allclose(out_im, expected_array, type_test=False) + + def multi_inverse(self, input_shape, init_params): + input_data = np.arange(np.prod(input_shape)).reshape(*input_shape) + 1 + xform = self.Cropper(**init_params) + xform.set_random_state(1234) + out = xform(input_data) + if "num_samples" in init_params: + self.assertEqual(len(out), init_params["num_samples"]) + inv = xform.inverse(out) + self.assertIsInstance(inv, MetaTensor) + self.assertEqual(inv.applied_operations, []) + self.assertTrue("patch_index" not in inv.meta) + self.assertTupleEqual(inv.shape, input_shape) + inv_np = inv.numpy() + + # get list of all numbers that exist inside the crops + uniques = set() + for o in out: + uniques.update(set(o.flatten().tolist())) + + # make sure that + for i in uniques: + a = np.where(input_data == i) 
+ b = np.where(inv_np == i) + self.assertTupleEqual(a, b) + # there should be as many zeros as elements missing from uniques + missing = input_data.size - len(uniques) + self.assertEqual((inv_np == 0).sum(), missing) diff --git a/tests/padders.py b/tests/padders.py new file mode 100644 index 0000000000..3fa3280cb5 --- /dev/null +++ b/tests/padders.py @@ -0,0 +1,108 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import List + +import numpy as np +import torch + +from monai.data.meta_tensor import MetaTensor +from monai.transforms.transform import MapTransform +from monai.utils.enums import NumpyPadMode, PytorchPadMode +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose + +MODES = [] +# Test modes +NP_MODES: List = [ + "constant", + "edge", + # `reflect` mode is not supported in some PyTorch versions, skip the test + # "reflect", + "wrap", + "median", + "mean", +] +MODES += NP_MODES +MODES += [NumpyPadMode(i) for i in NP_MODES] + +PT_MODES: list = [ + "constant", + "replicate", + "circular", + # `reflect` mode is not supported in some PyTorch versions, skip the test + # "reflect", +] +MODES += PT_MODES +MODES += [PytorchPadMode(i) for i in PT_MODES] + + +class PadTest(unittest.TestCase): + @staticmethod + def get_arr(shape): + return np.random.randint(100, size=shape).astype(float) + + def pad_test(self, input_param, input_shape, expected_shape, modes=None): + # loop over each mode + for mode in modes or MODES: + with self.subTest(mode=mode): + base_comparison = None + im = self.get_arr(input_shape) + padder = self.Padder(mode=mode, **input_param) + is_map = isinstance(padder, MapTransform) + # check result is the same regardless of input type + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + input_image = im_type(im) + input_data = {"img": im_type(im)} if is_map else im_type(im) + # our array transforms can also take `mode` as an argument to `__call__` + # Check this gives equivalent results + for call_extra_args in [{}] if is_map else [{}, {"mode": mode}]: + with self.subTest(call_extra_args=call_extra_args): + r_out = padder(input_data, **call_extra_args) + r_im = r_out["img"] if is_map else r_out + # check shape, type, etc. 
+ np.testing.assert_allclose(r_im.shape, expected_shape) + self.assertIsInstance(r_im, MetaTensor) + self.assertEqual(len(r_im.applied_operations), 1) + # check results are same regardless of input type + if base_comparison is None: + base_comparison = r_im + else: + assert_allclose(r_im, base_comparison) + # test inverse + if isinstance(r_im, MetaTensor): + r_out = padder.inverse(r_out) + r_im = r_out["img"] if is_map else r_out + self.assertIsInstance(r_im, MetaTensor) + assert_allclose(r_im, input_image, type_test=False) + self.assertEqual(r_im.applied_operations, []) + + def pad_test_kwargs(self, unchanged_slices, **input_param): + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + for kwargs in ({"value": 2}, {"constant_values": ((0, 0), (1, 1), (2, 2))}): + with self.subTest(kwargs=kwargs): + im = im_type(np.random.randint(-100, -10, size=(3, 8, 4))) + padder = self.Padder(**input_param, **kwargs) + result = padder(im) + if isinstance(result, torch.Tensor): + result = result.cpu() + assert_allclose(result[unchanged_slices], im, type_test=False) + # we should have the same as the input plus some 2s (if value) or 1s and 2s (if constant_values) + expected_vals = np.unique(im).tolist() + expected_vals += [2] if "value" in kwargs else [1, 2] + assert_allclose(np.unique(result), expected_vals, type_test=False) + # check inverse + if isinstance(result, MetaTensor): + inv = padder.inverse(result) + assert_allclose(im, inv, type_test=False) + self.assertEqual(inv.applied_operations, []) diff --git a/tests/test_border_pad.py b/tests/test_border_pad.py index b632ff831f..1194ae49a6 100644 --- a/tests/test_border_pad.py +++ b/tests/test_border_pad.py @@ -11,45 +11,32 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import BorderPad -from monai.utils import NumpyPadMode -from tests.utils import TEST_NDARRAYS - -TEST_CASE_1 = [{"spatial_border": 2, "mode": "constant"}, np.zeros((3, 8, 8, 4)), np.zeros((3, 12, 12, 8))] - -TEST_CASE_2 = [{"spatial_border": [1, 2, 3], "mode": "constant"}, np.zeros((3, 8, 8, 4)), np.zeros((3, 10, 12, 10))] - -TEST_CASE_3 = [ - {"spatial_border": [1, 2, 3, 4, 5, 6], "mode": "constant"}, - np.zeros((3, 8, 8, 4)), - np.zeros((3, 11, 15, 15)), +from monai.utils.enums import NumpyPadMode, PytorchPadMode +from tests.padders import PadTest + +TESTS = [ + [{"spatial_border": 2}, (3, 8, 8, 4), (3, 12, 12, 8)], + [{"spatial_border": [1, 2, 3]}, (3, 8, 8, 4), (3, 10, 12, 10)], + [{"spatial_border": [1, 2, 3, 4, 5, 6]}, (3, 8, 8, 4), (3, 11, 15, 15)], + [{"spatial_border": [1, 2, 3, 4, 5, 6]}, (3, 8, 8, 4), (3, 11, 15, 15)], ] -TEST_CASE_4 = [ - {"spatial_border": [1, 2, 3, 4, 5, 6], "mode": NumpyPadMode.CONSTANT}, - np.zeros((3, 8, 8, 4)), - np.zeros((3, 11, 15, 15)), -] +class TestBorderPad(PadTest): + Padder = BorderPad -class TestBorderPad(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) - def test_pad_shape(self, input_param, input_data, expected_val): - for p in TEST_NDARRAYS: - padder = BorderPad(**input_param) - r1 = padder(p(input_data)) - r2 = padder(input_data, mode=input_param["mode"]) - self.assertAlmostEqual(r1.shape, expected_val.shape) - self.assertAlmostEqual(r2.shape, expected_val.shape) + @parameterized.expand(TESTS) + def test_pad(self, input_param, input_shape, expected_shape): + modes = ["constant", NumpyPadMode.CONSTANT, PytorchPadMode.CONSTANT] + self.pad_test(input_param, input_shape, expected_shape, modes) def test_pad_kwargs(self): 
- padder = BorderPad(spatial_border=2, mode="constant", constant_values=((0, 0), (1, 1), (2, 2))) - result = padder(np.zeros((3, 8, 4))) - np.testing.assert_allclose(result[:, :2, 2:6], np.ones((3, 2, 4))) - np.testing.assert_allclose(result[:, :, :2], np.ones((3, 12, 2)) + 1) + kwargs = {"spatial_border": 2, "mode": "constant"} + unchanged_slices = [slice(None), slice(2, -2), slice(2, -2)] + self.pad_test_kwargs(unchanged_slices, **kwargs) if __name__ == "__main__": diff --git a/tests/test_border_padd.py b/tests/test_border_padd.py index e4b8dd20ea..ca55c8b09d 100644 --- a/tests/test_border_padd.py +++ b/tests/test_border_padd.py @@ -11,49 +11,28 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import BorderPadd -from monai.utils import NumpyPadMode - -TEST_CASE_1 = [ - {"keys": ["img", "seg"], "spatial_border": 2, "mode": ["constant", "edge"]}, - {"img": np.zeros((3, 8, 8, 4)), "seg": np.zeros((3, 8, 8, 4))}, - np.zeros((3, 12, 12, 8)), -] - -TEST_CASE_2 = [ - {"keys": "img", "spatial_border": [1, 2, 3], "mode": "constant"}, - {"img": np.zeros((3, 8, 8, 4))}, - np.zeros((3, 10, 12, 10)), -] - -TEST_CASE_3 = [ - {"keys": "img", "spatial_border": [1, 2, 3, 4, 5, 6], "mode": "constant"}, - {"img": np.zeros((3, 8, 8, 4))}, - np.zeros((3, 11, 15, 15)), -] - -TEST_CASE_4 = [ - {"keys": ["img", "seg"], "spatial_border": 2, "mode": ["constant", NumpyPadMode.EDGE]}, - {"img": np.zeros((3, 8, 8, 4)), "seg": np.zeros((3, 8, 8, 4))}, - np.zeros((3, 12, 12, 8)), +from monai.utils.enums import NumpyPadMode, PytorchPadMode +from tests.padders import PadTest + +TESTS = [ + [{"keys": "img", "spatial_border": 2}, (3, 8, 8, 4), (3, 12, 12, 8)], + [{"keys": "img", "spatial_border": [1, 2, 3]}, (3, 8, 8, 4), (3, 10, 12, 10)], + [{"keys": "img", "spatial_border": [1, 2, 3, 4, 5, 6]}, (3, 8, 8, 4), (3, 11, 15, 15)], + [{"keys": "img", "spatial_border": 2}, (3, 8, 8, 4), (3, 12, 12, 8)], + [{"keys": "img", "spatial_border": 2}, (3, 8, 8, 4), (3, 12, 12, 8)], ] -TEST_CASE_5 = [ - {"keys": ["img", "seg"], "spatial_border": 2, "mode": [NumpyPadMode.CONSTANT, NumpyPadMode.EDGE]}, - {"img": np.zeros((3, 8, 8, 4)), "seg": np.zeros((3, 8, 8, 4))}, - np.zeros((3, 12, 12, 8)), -] +class TestBorderPadd(PadTest): + Padder = BorderPadd -class TestBorderPadd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_pad_shape(self, input_param, input_data, expected_val): - padder = BorderPadd(**input_param) - result = padder(input_data) - self.assertAlmostEqual(result["img"].shape, expected_val.shape) + @parameterized.expand(TESTS) + def test_pad(self, input_param, input_shape, expected_shape): + modes = ["constant", NumpyPadMode.CONSTANT, PytorchPadMode.CONSTANT, "edge", NumpyPadMode.EDGE] + self.pad_test(input_param, input_shape, expected_shape, modes) if __name__ == "__main__": diff --git a/tests/test_center_scale_crop.py b/tests/test_center_scale_crop.py index f22651e3e0..ab07a44eb5 100644 --- a/tests/test_center_scale_crop.py +++ b/tests/test_center_scale_crop.py @@ -9,43 +9,40 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+
 import unittest
 
 import numpy as np
-import torch
 from parameterized import parameterized
 
 from monai.transforms import CenterScaleCrop
+from tests.croppers import CropTest
 
-TEST_CASE_0 = [{"roi_scale": [0.6, 0.3, -1]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 1, 3)]
-
-TEST_CASE_1 = [{"roi_scale": 0.6}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 2)]
-
-TEST_CASE_2 = [
-    {"roi_scale": [0.4, 0.4]},
-    np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]),
-    np.array([[[1, 2], [2, 3]]]),
+TEST_SHAPES = [
+    [{"roi_scale": [0.6, 0.3, -1]}, (3, 3, 3, 3), (3, 2, 1, 3)],
+    [{"roi_scale": 0.6}, (3, 3, 3, 3), (3, 2, 2, 2)],
+    [{"roi_scale": 0.5}, (3, 3, 3, 3), (3, 2, 2, 2)],
 ]
 
-TEST_CASE_3 = [
-    {"roi_scale": 0.5},
-    torch.randint(0, 2, size=[3, 3, 3, 3], device="cuda" if torch.cuda.is_available() else "cpu"),
-    (3, 2, 2, 2),
+TEST_VALUES = [
+    [
+        {"roi_scale": [0.4, 0.4]},
+        np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]),
+        np.array([[[1, 2], [2, 3]]]),
+    ]
 ]
 
-class TestCenterScaleCrop(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3])
-    def test_shape(self, input_param, input_data, expected_shape):
-        result = CenterScaleCrop(**input_param)(input_data)
-        self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor))
-        np.testing.assert_allclose(result.shape, expected_shape)
+class TestCenterScaleCrop(CropTest):
+    Cropper = CenterScaleCrop
+
+    @parameterized.expand(TEST_SHAPES)
+    def test_shape(self, input_param, input_shape, expected_shape):
+        self.crop_test(input_param, input_shape, expected_shape)
 
-    @parameterized.expand([TEST_CASE_2])
-    def test_value(self, input_param, input_data, expected_value):
-        result = CenterScaleCrop(**input_param)(input_data)
-        self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor))
-        np.testing.assert_allclose(result, expected_value)
+    @parameterized.expand(TEST_VALUES)
+    def test_value(self, input_param, input_arr, expected_arr):
+        self.crop_test_value(input_param, input_arr, expected_arr)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_center_scale_cropd.py b/tests/test_center_scale_cropd.py
index 8aef2dbe5b..894692530d 100644
--- a/tests/test_center_scale_cropd.py
+++ b/tests/test_center_scale_cropd.py
@@ -12,44 +12,37 @@
 import unittest
 
 import numpy as np
-import torch
 from parameterized import parameterized
 
 from monai.transforms import CenterScaleCropd
+from tests.croppers import CropTest
 
-TEST_CASE_0 = [{"keys": "img", "roi_scale": [0.6, 0.3, -1]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 1, 3)]
-
-TEST_CASE_1 = [{"keys": "img", "roi_scale": 0.6}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 2)]
-
-TEST_CASE_2 = [
-    {"keys": "img", "roi_scale": [0.4, 0.4]},
-    np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]),
-    np.array([[[1, 2], [2, 3]]]),
+TESTS = [
+    [{"keys": "img", "roi_scale": [0.6, 0.3, -1]}, (3, 3, 3, 3), (3, 2, 1, 3)],
+    [{"keys": "img", "roi_scale": 0.6}, (3, 3, 3, 3), (3, 2, 2, 2)],
+    [{"keys": "img", "roi_scale": 0.5}, (3, 3, 3, 3), (3, 2, 2, 2)],
 ]
 
-TEST_CASE_3 = [
-    {"keys": "img", "roi_scale": 0.5},
-    torch.randint(0, 2, size=[3, 3, 3, 3], device="cuda" if torch.cuda.is_available() else "cpu"),
-    (3, 2, 2, 2),
-]
-TEST_CASE_4 = [
-    {"keys": "test", "roi_scale": 0.6, "allow_missing_keys": True},
-    np.random.randint(0, 2, size=[3, 3, 3, 3]),
-    (3, 3, 3, 3),
+TEST_VALUES = [
+ [ + {"keys": "img", "roi_scale": [0.4, 0.4]}, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + np.array([[[1, 2], [2, 3]]]), + ] ] -class TestCenterScaleCropd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3, TEST_CASE_4]) - def test_shape(self, input_param, input_data, expected_shape): - result = CenterScaleCropd(**input_param)({"img": input_data}) - np.testing.assert_allclose(result["img"].shape, expected_shape) +class TestCenterScaleCropd(CropTest): + Cropper = CenterScaleCropd + + @parameterized.expand(TESTS) + def test_shape(self, input_param, input_shape, expected_shape): + self.crop_test(input_param, input_shape, expected_shape) - @parameterized.expand([TEST_CASE_2]) - def test_value(self, input_param, input_data, expected_value): - result = CenterScaleCropd(**input_param)({"img": input_data}) - np.testing.assert_allclose(result["img"], expected_value) + @parameterized.expand(TEST_VALUES) + def test_value(self, input_param, input_arr, expected_arr): + self.crop_test_value(input_param, input_arr, expected_arr) if __name__ == "__main__": diff --git a/tests/test_center_spatial_crop.py b/tests/test_center_spatial_crop.py index 09f61be2f1..7b5b19107d 100644 --- a/tests/test_center_spatial_crop.py +++ b/tests/test_center_spatial_crop.py @@ -12,40 +12,36 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import CenterSpatialCrop +from tests.croppers import CropTest -TEST_CASE_0 = [{"roi_size": [2, 2, -1]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 3)] - -TEST_CASE_1 = [{"roi_size": [2, 2, 2]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 2)] - -TEST_CASE_2 = [ - {"roi_size": [2, 2]}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2], [2, 3]]]), +TEST_SHAPES = [ + [{"roi_size": [2, 2, -1]}, (3, 3, 3, 3), (3, 2, 2, 3)], + [{"roi_size": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], + [{"roi_size": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], ] -TEST_CASE_3 = [ - {"roi_size": [2, 2, 2]}, - torch.randint(0, 2, size=[3, 3, 3, 3], device="cuda" if torch.cuda.is_available() else "cpu"), - (3, 2, 2, 2), +TEST_VALUES = [ + [ + {"roi_size": [2, 2]}, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + np.array([[[1, 2], [2, 3]]]), + ] ] -class TestCenterSpatialCrop(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) - def test_shape(self, input_param, input_data, expected_shape): - result = CenterSpatialCrop(**input_param)(input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) - np.testing.assert_allclose(result.shape, expected_shape) +class TestCenterSpatialCrop(CropTest): + Cropper = CenterSpatialCrop + + @parameterized.expand(TEST_SHAPES) + def test_shape(self, input_param, input_shape, expected_shape): + self.crop_test(input_param, input_shape, expected_shape) - @parameterized.expand([TEST_CASE_2]) - def test_value(self, input_param, input_data, expected_value): - result = CenterSpatialCrop(**input_param)(input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) - np.testing.assert_allclose(result, expected_value) + @parameterized.expand(TEST_VALUES) + def test_value(self, input_param, input_arr, expected_arr): + self.crop_test_value(input_param, input_arr, expected_arr) if __name__ == 
"__main__": diff --git a/tests/test_center_spatial_cropd.py b/tests/test_center_spatial_cropd.py index bdbc1a5031..fa7bc8c8fa 100644 --- a/tests/test_center_spatial_cropd.py +++ b/tests/test_center_spatial_cropd.py @@ -15,43 +15,42 @@ from parameterized import parameterized from monai.transforms import CenterSpatialCropd -from tests.utils import TEST_NDARRAYS, assert_allclose - -TEST_SHAPES = [] -for p in TEST_NDARRAYS: - TEST_SHAPES.append( - [{"keys": "img", "roi_size": [2, -1, -1]}, {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, (3, 2, 3, 3)] - ) - - TEST_SHAPES.append( - [{"keys": "img", "roi_size": [2, 2, 2]}, {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, (3, 2, 2, 2)] - ) - -TEST_CASES = [] -for p in TEST_NDARRAYS: - TEST_CASES.append( - [ - {"keys": "img", "roi_size": [2, 2]}, - { - "img": p( - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) - ) - }, - p(np.array([[[1, 2], [2, 3]]])), - ] - ) - - -class TestCenterSpatialCropd(unittest.TestCase): +from tests.croppers import CropTest + +TEST_SHAPES = [ + [ + {"keys": "img", "roi_size": [2, -1, -1]}, + (3, 3, 3, 3), + (3, 2, 3, 3), + (slice(None), slice(None, -1), slice(None), slice(None)), + ], + [ + {"keys": "img", "roi_size": [2, 2, 2]}, + (3, 3, 3, 3), + (3, 2, 2, 2), + (slice(None), slice(None, -1), slice(None, -1), slice(None, -1)), + ], +] + +TEST_CASES = [ + [ + {"keys": "img", "roi_size": [2, 2]}, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + np.array([[[1, 2], [2, 3]]]), + ] +] + + +class TestCenterSpatialCropd(CropTest): + Cropper = CenterSpatialCropd + @parameterized.expand(TEST_SHAPES) - def test_shape(self, input_param, input_data, expected_shape): - result = CenterSpatialCropd(**input_param)(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + def test_shape(self, input_param, input_shape, expected_shape, same_area): + self.crop_test(input_param, input_shape, expected_shape, same_area) @parameterized.expand(TEST_CASES) def test_value(self, input_param, input_data, expected_value): - result = CenterSpatialCropd(**input_param)(input_data) - assert_allclose(result["img"], expected_value, type_test=False) + self.crop_test_value(input_param, input_data, expected_value) if __name__ == "__main__": diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index af945673fe..e400406e4d 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -12,15 +12,15 @@ import unittest import numpy as np -import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import CropForeground -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_COORDS, TESTS = [], [] -for p in TEST_NDARRAYS: +for p in TEST_NDARRAYS_ALL: TEST_COORDS.append( [ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, @@ -89,8 +89,15 @@ class TestCropForeground(unittest.TestCase): @parameterized.expand(TEST_COORDS + TESTS) def test_value(self, argments, image, expected_data): - result = CropForeground(**argments)(image) - torch.testing.assert_allclose(result, expected_data, rtol=1e-7, atol=0) + cropper = CropForeground(**argments) + result = cropper(image) + assert_allclose(result, expected_data, type_test=False) + self.assertIsInstance(result, MetaTensor) + self.assertEqual(len(result.applied_operations), 1) + inv = cropper.inverse(result) + 
self.assertIsInstance(inv, MetaTensor) + self.assertEqual(inv.applied_operations, []) + self.assertTupleEqual(inv.shape, image.shape) @parameterized.expand(TEST_COORDS) def test_return_coords(self, argments, image, _): diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py index fa69143827..d641c5a376 100644 --- a/tests/test_crop_foregroundd.py +++ b/tests/test_crop_foregroundd.py @@ -12,14 +12,13 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import CropForegroundd -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_POSITION, TESTS = [], [] -for p in TEST_NDARRAYS: +for p in TEST_NDARRAYS_ALL: TEST_POSITION.append( [ @@ -151,12 +150,15 @@ class TestCropForegroundd(unittest.TestCase): @parameterized.expand(TEST_POSITION + TESTS) def test_value(self, argments, input_data, expected_data): - result = CropForegroundd(**argments)(input_data) - r, i = result["img"], input_data["img"] - self.assertEqual(type(r), type(i)) - if isinstance(r, torch.Tensor): - self.assertEqual(r.device, i.device) - assert_allclose(r, expected_data) + cropper = CropForegroundd(**argments) + result = cropper(input_data) + assert_allclose(result["img"], expected_data, type_test=False) + if "label" in input_data and "img" in input_data: + self.assertTupleEqual(result["img"].shape, result["label"].shape) + inv = cropper.inverse(result) + self.assertTupleEqual(inv["img"].shape, input_data["img"].shape) + if "label" in input_data: + self.assertTupleEqual(inv["label"].shape, input_data["label"].shape) @parameterized.expand(TEST_POSITION) def test_foreground_position(self, argments, input_data, _): diff --git a/tests/test_divisible_pad.py b/tests/test_divisible_pad.py index f940636fa8..df610c4939 100644 --- a/tests/test_divisible_pad.py +++ b/tests/test_divisible_pad.py @@ -11,43 +11,32 @@ import unittest -import numpy as np -import torch from parameterized import parameterized from monai.transforms import DivisiblePad -from tests.utils import TEST_NDARRAYS +from monai.utils.enums import NumpyPadMode, PytorchPadMode +from tests.padders import PadTest TESTS = [] -for p in TEST_NDARRAYS: - # pad first dim to be divisible by 7, the second unchanged. - TESTS.append([{"k": (7, -1), "mode": "constant"}, p(np.zeros((3, 8, 7))), p(np.zeros((3, 14, 7)))]) +# pad first dim to be divisible by 7, the second unchanged. 
+TESTS.append([{"k": (7, -1)}, (3, 8, 7), (3, 14, 7)]) +# pad all dimensions to be divisible by 5 +TESTS.append([{"k": 5, "method": "end"}, (3, 10, 5, 17), (3, 10, 5, 20)]) - # pad all dimensions to be divisible by 5 - TESTS.append( - [{"k": 5, "mode": "constant", "method": "end"}, p(np.zeros((3, 10, 5, 17))), p(np.zeros((3, 10, 5, 20)))] - ) +class TestDivisiblePad(PadTest): + Padder = DivisiblePad -class TestDivisiblePad(unittest.TestCase): @parameterized.expand(TESTS) - def test_pad_shape(self, input_param, input_data, expected_val): - padder = DivisiblePad(**input_param) - result = padder(input_data) - self.assertAlmostEqual(result.shape, expected_val.shape) - result = padder(input_data, mode=input_param["mode"]) - self.assertAlmostEqual(result.shape, expected_val.shape) + def test_pad(self, input_param, input_shape, expected_shape): + modes = ["constant", NumpyPadMode.CONSTANT, PytorchPadMode.CONSTANT] + self.pad_test(input_param, input_shape, expected_shape, modes) def test_pad_kwargs(self): - for p in TEST_NDARRAYS: - input_data = p(np.zeros((3, 8, 4))) - if isinstance(input_data, np.ndarray): - result = DivisiblePad(k=5, mode="constant", constant_values=((0, 0), (1, 1), (2, 2)))(input_data) - np.testing.assert_allclose(result[:, :1, :4], np.ones((3, 1, 4)), rtol=1e-7, atol=0) - else: - result = DivisiblePad(k=5, mode="constant", value=2)(input_data).cpu() - torch.testing.assert_allclose(result[:, :, 4:5], np.ones((3, 10, 1)) + 1, rtol=1e-7, atol=0) + kwargs = {"k": 5, "method": "end"} + unchanged_slices = [slice(None), slice(None, 8), slice(None, 4)] + self.pad_test_kwargs(unchanged_slices, **kwargs) if __name__ == "__main__": diff --git a/tests/test_divisible_padd.py b/tests/test_divisible_padd.py index 61fe917421..93e5a879f0 100644 --- a/tests/test_divisible_padd.py +++ b/tests/test_divisible_padd.py @@ -11,32 +11,25 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import DivisiblePadd +from monai.utils.enums import NumpyPadMode, PytorchPadMode +from tests.padders import PadTest -TEST_CASE_1 = [ - {"keys": ["img"], "k": [4, 3, 2], "mode": "constant"}, - {"img": np.zeros((3, 8, 8, 4))}, - np.zeros((3, 8, 9, 4)), +TESTS = [ + [{"keys": "img", "k": [4, 3, 2]}, (3, 8, 8, 4), (3, 8, 9, 4)], + [{"keys": "img", "k": 7, "method": "end"}, (3, 8, 7), (3, 14, 7)], ] -TEST_CASE_2 = [ - {"keys": ["img"], "k": 7, "mode": "constant", "method": "end"}, - {"img": np.zeros((3, 8, 7))}, - np.zeros((3, 14, 7)), -] - -TEST_CASE_3 = [{"keys": ["img"], "k": 0, "mode": {"constant"}}, {"img": np.zeros((3, 8))}, np.zeros((3, 8))] +class TestDivisiblePadd(PadTest): + Padder = DivisiblePadd -class TestDivisiblePadd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_pad_shape(self, input_param, input_data, expected_val): - padder = DivisiblePadd(**input_param) - result = padder(input_data) - np.testing.assert_allclose(result["img"], expected_val) + @parameterized.expand(TESTS) + def test_pad(self, input_param, input_shape, expected_shape): + modes = ["constant", NumpyPadMode.CONSTANT, PytorchPadMode.CONSTANT, "edge", NumpyPadMode.EDGE] + self.pad_test(input_param, input_shape, expected_shape, modes) if __name__ == "__main__": diff --git a/tests/test_module_list.py b/tests/test_module_list.py index d81d067c58..d0b5aaf26b 100644 --- a/tests/test_module_list.py +++ b/tests/test_module_list.py @@ -38,7 +38,7 @@ def test_public_api(self): def test_transform_api(self): """monai subclasses of MapTransforms must have 
alias names ending with 'd', 'D', 'Dict'""" to_exclude = {"MapTransform"} # except for these transforms - to_exclude_docs = {"Decollate", "Ensemble", "Invert", "SaveClassification", "RandTorchVision"} + to_exclude_docs = {"Decollate", "Ensemble", "Invert", "SaveClassification", "RandTorchVision", "RandCrop"} to_exclude_docs.update({"DeleteItems", "SelectItems", "CopyItems", "ConcatItems"}) to_exclude_docs.update({"ToMetaTensor", "FromMetaTensor"}) xforms = { diff --git a/tests/test_rand_crop_by_label_classes.py b/tests/test_rand_crop_by_label_classes.py index 11d73df74e..b1165e8986 100644 --- a/tests/test_rand_crop_by_label_classes.py +++ b/tests/test_rand_crop_by_label_classes.py @@ -15,10 +15,10 @@ from parameterized import parameterized from monai.transforms import ClassesToIndices, RandCropByLabelClasses -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL TESTS_INDICES, TESTS_SHAPE = [], [] -for p in TEST_NDARRAYS: +for p in TEST_NDARRAYS_ALL: # One-Hot label TESTS_INDICES.append( [ diff --git a/tests/test_rand_crop_by_label_classesd.py b/tests/test_rand_crop_by_label_classesd.py index 92780458e0..9a99ebab29 100644 --- a/tests/test_rand_crop_by_label_classesd.py +++ b/tests/test_rand_crop_by_label_classesd.py @@ -15,10 +15,10 @@ from parameterized import parameterized from monai.transforms import ClassesToIndicesd, RandCropByLabelClassesd -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL TESTS = [] -for p in TEST_NDARRAYS: +for p in TEST_NDARRAYS_ALL: TESTS.append( [ # One-Hot label diff --git a/tests/test_rand_crop_by_pos_neg_label.py b/tests/test_rand_crop_by_pos_neg_label.py index 1d9e2612c7..f6da393ab9 100644 --- a/tests/test_rand_crop_by_pos_neg_label.py +++ b/tests/test_rand_crop_by_pos_neg_label.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import RandCropByPosNegLabel -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL TESTS = [ [ @@ -103,7 +103,7 @@ def convert_data_type(im_type, d, keys=("img", "image", "label")): @parameterized.expand(TESTS) def test_type_shape(self, input_param, input_data, expected_shape): results = [] - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: input_param_mod = self.convert_data_type(p, input_param) input_data_mod = self.convert_data_type(p, input_data) cropper = RandCropByPosNegLabel(**input_param_mod) diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index a2808bd65d..64673bf4bf 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -16,8 +16,7 @@ from parameterized import parameterized from monai.transforms import RandCropByPosNegLabeld -from monai.utils.enums import PostFix -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL TESTS = [ [ @@ -35,7 +34,6 @@ "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]), "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - PostFix.meta("image"): {"affine": np.eye(3), "shape": "CHWD"}, }, (3, 3, 2, 2), ], @@ -54,7 +52,6 @@ "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]), "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - PostFix.meta("label"): {"affine": np.eye(3), "shape": "CHWD"}, }, (3, 2, 2, 2), ], @@ -69,12 +66,7 @@ "image_key": None, "image_threshold": 0, }, - { - "image": np.zeros([3, 3, 3, 3]) - 1, - "extra": np.zeros([3, 3, 3, 
3]), - "label": np.ones([3, 3, 3, 3]), - PostFix.meta("extra"): {"affine": np.eye(3), "shape": "CHWD"}, - }, + {"image": np.zeros([3, 3, 3, 3]) - 1, "extra": np.zeros([3, 3, 3, 3]), "label": np.ones([3, 3, 3, 3])}, (3, 2, 2, 2), ], [ @@ -89,12 +81,7 @@ "image_threshold": 0, "allow_smaller": True, }, - { - "image": np.zeros([3, 3, 3, 3]) - 1, - "extra": np.zeros([3, 3, 3, 3]), - "label": np.ones([3, 3, 3, 3]), - PostFix.meta("extra"): {"affine": np.eye(3), "shape": "CHWD"}, - }, + {"image": np.zeros([3, 3, 3, 3]) - 1, "extra": np.zeros([3, 3, 3, 3]), "label": np.ones([3, 3, 3, 3])}, (3, 3, 3, 2), ], [ @@ -109,12 +96,7 @@ "image_threshold": 0, "allow_smaller": True, }, - { - "image": np.zeros([3, 3, 3, 3]) - 1, - "extra": np.zeros([3, 3, 3, 3]), - "label": np.ones([3, 3, 3, 3]), - PostFix.meta("extra"): {"affine": np.eye(3), "shape": "CHWD"}, - }, + {"image": np.zeros([3, 3, 3, 3]) - 1, "extra": np.zeros([3, 3, 3, 3]), "label": np.ones([3, 3, 3, 3])}, (3, 3, 3, 3), ], ] @@ -131,13 +113,13 @@ def convert_data_type(im_type, d, keys=("img", "image", "label")): @parameterized.expand(TESTS) def test_type_shape(self, input_param, input_data, expected_shape): - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: input_param_mod = self.convert_data_type(p, input_param) input_data_mod = self.convert_data_type(p, input_data) cropper = RandCropByPosNegLabeld(**input_param_mod) cropper.set_random_state(0) result = cropper(input_data_mod) - self.assertListEqual(cropper.spatial_size, input_param["spatial_size"]) + self.assertListEqual(cropper.cropper.spatial_size, input_param["spatial_size"]) self.assertIsInstance(result, list) @@ -146,7 +128,7 @@ def test_type_shape(self, input_param, input_data, expected_shape): for k in ("image", "extra", "label"): self.assertTupleEqual(result[0][k].shape, expected_shape) for i, item in enumerate(result): - self.assertEqual(item[PostFix.meta(k)]["patch_index"], i) + self.assertEqual(item[k].meta["patch_index"], i) def test_correct_center(self): cropper = RandCropByPosNegLabeld(keys="label", label_key="label", spatial_size=[3, 3]) diff --git a/tests/test_rand_scale_crop.py b/tests/test_rand_scale_crop.py index 5d6312002f..aea26d62bb 100644 --- a/tests/test_rand_scale_crop.py +++ b/tests/test_rand_scale_crop.py @@ -15,66 +15,57 @@ from parameterized import parameterized from monai.transforms import RandScaleCrop -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.croppers import CropTest +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose -TEST_CASE_1 = [ - {"roi_scale": [1.0, 1.0, -1.0], "random_center": True}, - np.random.randint(0, 2, size=[3, 3, 3, 4]), - (3, 3, 3, 4), +TEST_SHAPES = [ + [{"roi_scale": [1.0, 1.0, -1.0], "random_center": True}, (3, 3, 3, 4), (3, 3, 3, 4)], + [{"roi_scale": [1.0, 1.0, 1.0], "random_center": False}, (3, 3, 3, 3), (3, 3, 3, 3)], ] -TEST_CASE_2 = [ - {"roi_scale": [1.0, 1.0, 1.0], "random_center": False}, - np.random.randint(0, 2, size=[3, 3, 3, 3]), - (3, 3, 3, 3), +TEST_VALUES = [ + [ + {"roi_scale": [0.6, 0.6], "random_center": False}, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + ] ] -TEST_CASE_3 = [ - {"roi_scale": [0.6, 0.6], "random_center": False}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), +TEST_RANDOM_SHAPES = [ + [ + {"roi_scale": [0.75, 0.6, 0.5], "max_roi_scale": [1.0, -1.0, 0.6], "random_center": True, "random_size": True}, + (1, 4, 5, 6), + (1, 3, 4, 3), + ], + [{"roi_scale": 0.6, 
"max_roi_scale": 0.8, "random_center": True, "random_size": True}, (1, 4, 5, 6), (1, 3, 4, 4)], + [{"roi_scale": 0.2, "max_roi_scale": 0.8, "random_center": True, "random_size": True}, (1, 4, 5, 6), (1, 3, 2, 4)], ] -TEST_CASE_4 = [ - {"roi_scale": [0.75, 0.6, 0.5], "max_roi_scale": [1.0, -1.0, 0.6], "random_center": True, "random_size": True}, - np.random.randint(0, 2, size=[1, 4, 5, 6]), - (1, 3, 4, 3), -] - -TEST_CASE_5 = [ - {"roi_scale": 0.6, "max_roi_scale": 0.8, "random_center": True, "random_size": True}, - np.random.randint(0, 2, size=[1, 4, 5, 6]), - (1, 3, 4, 4), -] - -TEST_CASE_6 = [ - {"roi_scale": 0.2, "max_roi_scale": 0.8, "random_center": True, "random_size": True}, - np.random.randint(0, 2, size=[1, 4, 5, 6]), - (1, 3, 2, 4), -] +class TestRandScaleCrop(CropTest): + Cropper = RandScaleCrop -class TestRandScaleCrop(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, input_param, input_data, expected_shape): - for p in TEST_NDARRAYS: - result = RandScaleCrop(**input_param)(p(input_data)) - self.assertTupleEqual(result.shape, expected_shape) + @parameterized.expand(TEST_SHAPES) + def test_shape(self, input_param, input_shape, expected_shape): + self.crop_test(input_param, input_shape, expected_shape) - @parameterized.expand([TEST_CASE_3]) + @parameterized.expand(TEST_VALUES) def test_value(self, input_param, input_data): - for p in TEST_NDARRAYS: - cropper = RandScaleCrop(**input_param) - result = cropper(p(input_data)) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + cropper = RandScaleCrop(**input_param) + result = cropper(im_type(input_data)) + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) - @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) - def test_random_shape(self, input_param, input_data, expected_shape): - for p in TEST_NDARRAYS: - cropper = RandScaleCrop(**input_param) - cropper.set_random_state(seed=123) - result = cropper(p(input_data)) - self.assertTupleEqual(result.shape, expected_shape) + @parameterized.expand(TEST_RANDOM_SHAPES) + def test_random_shape(self, input_param, input_shape, expected_shape): + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + cropper = RandScaleCrop(**input_param) + cropper.set_random_state(seed=123) + input_data = im_type(np.random.randint(0, 2, input_shape)) + result = cropper(input_data) + self.assertTupleEqual(result.shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_rand_scale_cropd.py b/tests/test_rand_scale_cropd.py index 5e833fef98..645c058dfb 100644 --- a/tests/test_rand_scale_cropd.py +++ b/tests/test_rand_scale_cropd.py @@ -15,74 +15,77 @@ from parameterized import parameterized from monai.transforms import RandScaleCropd -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.croppers import CropTest +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose -TEST_CASE_1 = [ - {"keys": "img", "roi_scale": [1.0, 1.0, -1.0], "random_center": True}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 4])}, - (3, 3, 3, 4), +TEST_SHAPES = [ + [{"keys": "img", "roi_scale": [1.0, 1.0, -1.0], "random_center": True}, (3, 3, 3, 4), (3, 3, 3, 4)], + [ + # test `allow_missing_keys` with key "label" + {"keys": 
["label", "img"], "roi_scale": [1.0, 1.0, 1.0], "random_center": False, "allow_missing_keys": True}, + (3, 3, 3, 3), + (3, 3, 3, 3), + ], ] -TEST_CASE_2 = [ - # test `allow_missing_keys` with key "label" - {"keys": ["label", "img"], "roi_scale": [1.0, 1.0, 1.0], "random_center": False, "allow_missing_keys": True}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 3, 3, 3), +TEST_VALUES = [ + [ + {"keys": "img", "roi_scale": [0.6, 0.6], "random_center": False}, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + ] ] -TEST_CASE_3 = [ - {"keys": "img", "roi_scale": [0.6, 0.6], "random_center": False}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])}, +TEST_RANDOM_SHAPES = [ + [ + { + "keys": "img", + "roi_scale": [0.75, 0.6, 0.5], + "max_roi_scale": [1.0, -1.0, 0.6], + "random_center": True, + "random_size": True, + }, + (1, 4, 5, 6), + (1, 3, 4, 3), + ], + [ + {"keys": "img", "roi_scale": 0.6, "max_roi_scale": 0.8, "random_center": True, "random_size": True}, + (1, 4, 5, 6), + (1, 3, 4, 4), + ], + [ + {"keys": "img", "roi_scale": 0.2, "max_roi_scale": 0.8, "random_center": True, "random_size": True}, + (1, 4, 5, 6), + (1, 3, 2, 4), + ], ] -TEST_CASE_4 = [ - { - "keys": "img", - "roi_scale": [0.75, 0.6, 0.5], - "max_roi_scale": [1.0, -1.0, 0.6], - "random_center": True, - "random_size": True, - }, - {"img": np.random.randint(0, 2, size=[1, 4, 5, 6])}, - (1, 3, 4, 3), -] - -TEST_CASE_5 = [ - {"keys": "img", "roi_scale": 0.6, "max_roi_scale": 0.8, "random_center": True, "random_size": True}, - {"img": np.random.randint(0, 2, size=[1, 4, 5, 6])}, - (1, 3, 4, 4), -] - -TEST_CASE_6 = [ - {"keys": "img", "roi_scale": 0.2, "max_roi_scale": 0.8, "random_center": True, "random_size": True}, - {"img": np.random.randint(0, 2, size=[1, 4, 5, 6])}, - (1, 3, 2, 4), -] +class TestRandScaleCropd(CropTest): + Cropper = RandScaleCropd -class TestRandScaleCropd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, input_param, input_data, expected_shape): - result = RandScaleCropd(**input_param)(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + @parameterized.expand(TEST_SHAPES) + def test_shape(self, input_param, input_shape, expected_shape): + self.crop_test(input_param, input_shape, expected_shape) - @parameterized.expand([TEST_CASE_3]) - def test_value(self, input_param, input_data): - for p in TEST_NDARRAYS: - cropper = RandScaleCropd(**input_param) - input_data["img"] = p(input_data["img"]) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - assert_allclose( - result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False - ) + @parameterized.expand(TEST_VALUES) + def test_value(self, input_param, input_im): + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + cropper = self.Cropper(**input_param) + input_data = {"img": im_type(input_im)} + result = cropper(input_data)["img"] + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper.cropper._size] + assert_allclose(result, input_im[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) - @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) - def test_random_shape(self, input_param, input_data, expected_shape): - cropper = RandScaleCropd(**input_param) - cropper.set_random_state(seed=123) - result = cropper(input_data) - 
self.assertTupleEqual(result["img"].shape, expected_shape) + @parameterized.expand(TEST_RANDOM_SHAPES) + def test_random_shape(self, input_param, input_shape, expected_shape): + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + cropper = self.Cropper(**input_param) + cropper.set_random_state(seed=123) + input_data = {"img": im_type(np.random.randint(0, 2, input_shape))} + result = cropper(input_data)["img"] + self.assertTupleEqual(result.shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_rand_spatial_crop.py b/tests/test_rand_spatial_crop.py index 8f4bb0fffa..0c8d4ab132 100644 --- a/tests/test_rand_spatial_crop.py +++ b/tests/test_rand_spatial_crop.py @@ -15,60 +15,57 @@ from parameterized import parameterized from monai.transforms import RandSpatialCrop -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.croppers import CropTest +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose -TEST_CASE_0 = [ - {"roi_size": [3, 3, -1], "random_center": True}, - np.random.randint(0, 2, size=[3, 3, 3, 4]), - (3, 3, 3, 4), +TEST_SHAPES = [ + [{"roi_size": [3, 3, -1], "random_center": True}, (3, 3, 3, 4), (3, 3, 3, 4)], + [{"roi_size": [3, 3, 3], "random_center": True}, (3, 3, 3, 3), (3, 3, 3, 3)], + [{"roi_size": [3, 3, 3], "random_center": False}, (3, 3, 3, 3), (3, 3, 3, 3)], ] -TEST_CASE_1 = [{"roi_size": [3, 3, 3], "random_center": True}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 3, 3, 3)] - -TEST_CASE_2 = [ - {"roi_size": [3, 3, 3], "random_center": False}, - np.random.randint(0, 2, size=[3, 3, 3, 3]), - (3, 3, 3, 3), -] - -TEST_CASE_3 = [ - {"roi_size": [3, 3], "random_center": False}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), +TEST_VALUES = [ + [ + {"roi_size": [3, 3], "random_center": False}, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + ] ] -TEST_CASE_4 = [ - {"roi_size": [3, 3, 3], "max_roi_size": [5, -1, 4], "random_center": True, "random_size": True}, - np.random.randint(0, 2, size=[1, 4, 5, 6]), - (1, 4, 4, 3), +TEST_RANDOM_SHAPES = [ + [ + {"roi_size": [3, 3, 3], "max_roi_size": [5, -1, 4], "random_center": True, "random_size": True}, + (1, 4, 5, 6), + (1, 4, 4, 3), + ], + [{"roi_size": 3, "max_roi_size": 4, "random_center": True, "random_size": True}, (1, 4, 5, 6), (1, 3, 4, 3)], ] -TEST_CASE_5 = [ - {"roi_size": 3, "max_roi_size": 4, "random_center": True, "random_size": True}, - np.random.randint(0, 2, size=[1, 4, 5, 6]), - (1, 3, 4, 3), -] +class TestRandSpatialCrop(CropTest): + Cropper = RandSpatialCrop -class TestRandSpatialCrop(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, input_param, input_data, expected_shape): - result = RandSpatialCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) + @parameterized.expand(TEST_SHAPES) + def test_shape(self, input_param, input_shape, expected_shape): + self.crop_test(input_param, input_shape, expected_shape) - @parameterized.expand([TEST_CASE_3]) + @parameterized.expand(TEST_VALUES) def test_value(self, input_param, input_data): - for p in TEST_NDARRAYS: - cropper = RandSpatialCrop(**input_param) - result = cropper(p(input_data)) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) + for im_type in TEST_NDARRAYS_ALL: + with 
self.subTest(im_type=im_type): + cropper = RandSpatialCrop(**input_param) + result = cropper(im_type(input_data)) + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) - @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) - def test_random_shape(self, input_param, input_data, expected_shape): - cropper = RandSpatialCrop(**input_param) - cropper.set_random_state(seed=123) - result = cropper(input_data) - self.assertTupleEqual(result.shape, expected_shape) + @parameterized.expand(TEST_RANDOM_SHAPES) + def test_random_shape(self, input_param, input_shape, expected_shape): + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + cropper = RandSpatialCrop(**input_param) + cropper.set_random_state(seed=123) + input_data = im_type(np.random.randint(0, 2, input_shape)) + result = cropper(input_data) + self.assertTupleEqual(result.shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_rand_spatial_crop_samples.py b/tests/test_rand_spatial_crop_samples.py index 18fdf38773..50571b5955 100644 --- a/tests/test_rand_spatial_crop_samples.py +++ b/tests/test_rand_spatial_crop_samples.py @@ -15,11 +15,12 @@ from parameterized import parameterized from monai.transforms import RandSpatialCropSamples -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.croppers import CropTest +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_CASE_1 = [ {"roi_size": [3, 3, 3], "num_samples": 4, "random_center": True, "random_size": False}, - np.arange(192).reshape(3, 4, 4, 4), + (3, 4, 4, 4), [(3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3)], np.array( [ @@ -44,7 +45,7 @@ TEST_CASE_2 = [ {"roi_size": [3, 3, 3], "num_samples": 8, "random_center": False, "random_size": True}, - np.arange(192).reshape(3, 4, 4, 4), + (3, 4, 4, 4), [(3, 4, 4, 3), (3, 4, 3, 3), (3, 3, 4, 4), (3, 4, 4, 4), (3, 3, 3, 4), (3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3)], np.array( [ @@ -67,18 +68,31 @@ ), ] +TEST_INVERSE_LIST = [ + [(1, 2, 2), {"roi_size": (1, 1), "num_samples": 4, "random_size": False}], + [(1, 3, 2), {"roi_size": (1, 1), "num_samples": 100, "random_size": False}], + [(3, 10, 11, 12), {"roi_size": (3, 5, 4), "num_samples": 7, "random_size": False}], + [(3, 10, 11, 12), {"roi_size": (10, 11, 12), "num_samples": 3, "random_size": False}], + [(3, 10, 11, 12), {"roi_size": (3, 4, 5), "num_samples": 100, "random_size": False}], +] + + +class TestRandSpatialCropSamples(CropTest): + Cropper = RandSpatialCropSamples -class TestRandSpatialCropSamples(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, input_param, input_data, expected_shape, expected_last_item): - for p in TEST_NDARRAYS: + def test_shape(self, input_param, input_shape, expected_shape, expected_last_item): + input_data = np.arange(192).reshape(*input_shape) + + for p in TEST_NDARRAYS_ALL: xform = RandSpatialCropSamples(**input_param) xform.set_random_state(1234) result = xform(p(input_data)) np.testing.assert_equal(len(result), input_param["num_samples"]) - for item, expected in zip(result, expected_shape): + for i, (item, expected) in enumerate(zip(result, expected_shape)): self.assertTupleEqual(item.shape, expected) + self.assertEqual(item.meta["patch_index"], i) assert_allclose(result[-1], expected_last_item, type_test=False) diff --git a/tests/test_rand_spatial_crop_samplesd.py b/tests/test_rand_spatial_crop_samplesd.py index 0891068488..4da438d2a0 100644 
--- a/tests/test_rand_spatial_crop_samplesd.py +++ b/tests/test_rand_spatial_crop_samplesd.py @@ -14,46 +14,45 @@ import numpy as np from parameterized import parameterized -from monai.transforms import Compose, RandSpatialCropSamplesd, ToTensord -from monai.utils.enums import PostFix -from tests.utils import TEST_NDARRAYS, assert_allclose +from monai.transforms import Compose, DivisiblePadd, RandSpatialCropSamplesd +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_CASE_1 = [ {"keys": ["img", "seg"], "num_samples": 4, "roi_size": [2, 2, 2], "random_center": True}, {"img": np.arange(81).reshape(3, 3, 3, 3), "seg": np.arange(81, 0, -1).reshape(3, 3, 3, 3)}, - [(3, 3, 3, 2), (3, 2, 2, 2), (3, 3, 3, 2), (3, 3, 2, 2)], + [(3, 2, 2, 2), (3, 2, 3, 3), (3, 2, 3, 2), (3, 2, 3, 2)], { "img": np.array( [ - [[[0, 1], [3, 4]], [[9, 10], [12, 13]], [[18, 19], [21, 22]]], - [[[27, 28], [30, 31]], [[36, 37], [39, 40]], [[45, 46], [48, 49]]], - [[[54, 55], [57, 58]], [[63, 64], [66, 67]], [[72, 73], [75, 76]]], + [[[1, 2], [4, 5], [7, 8]], [[10, 11], [13, 14], [16, 17]]], + [[[28, 29], [31, 32], [34, 35]], [[37, 38], [40, 41], [43, 44]]], + [[[55, 56], [58, 59], [61, 62]], [[64, 65], [67, 68], [70, 71]]], ] ), "seg": np.array( [ - [[[81, 80], [78, 77]], [[72, 71], [69, 68]], [[63, 62], [60, 59]]], - [[[54, 53], [51, 50]], [[45, 44], [42, 41]], [[36, 35], [33, 32]]], - [[[27, 26], [24, 23]], [[18, 17], [15, 14]], [[9, 8], [6, 5]]], + [[[80, 79], [77, 76], [74, 73]], [[71, 70], [68, 67], [65, 64]]], + [[[53, 52], [50, 49], [47, 46]], [[44, 43], [41, 40], [38, 37]]], + [[[26, 25], [23, 22], [20, 19]], [[17, 16], [14, 13], [11, 10]]], ] ), }, ] TEST_CASE_2 = [] -for p in TEST_NDARRAYS: +for p in TEST_NDARRAYS_ALL: TEST_CASE_2.append( [ {"keys": ["img", "seg"], "num_samples": 8, "roi_size": [2, 2, 3], "random_center": False}, {"img": p(np.arange(81).reshape(3, 3, 3, 3)), "seg": p(np.arange(81, 0, -1).reshape(3, 3, 3, 3))}, [ - (3, 3, 3, 3), - (3, 2, 3, 3), (3, 2, 2, 3), - (3, 2, 3, 3), + (3, 2, 2, 3), (3, 3, 3, 3), + (3, 2, 3, 3), (3, 3, 3, 3), - (3, 2, 2, 3), + (3, 2, 3, 3), + (3, 2, 3, 3), (3, 3, 2, 3), ], { @@ -90,10 +89,10 @@ def test_shape(self, input_param, input_data, expected_shape, expected_last): self.assertTupleEqual(item["img"].shape, expected) self.assertTupleEqual(item["seg"].shape, expected) for i, item in enumerate(result): - self.assertEqual(item[PostFix.meta("img")]["patch_index"], i) - self.assertEqual(item[PostFix.meta("seg")]["patch_index"], i) - assert_allclose(item["img"], expected_last["img"], type_test=True) - assert_allclose(item["seg"], expected_last["seg"], type_test=True) + self.assertEqual(item["img"].meta["patch_index"], i) + self.assertEqual(item["seg"].meta["patch_index"], i) + assert_allclose(item["img"], expected_last["img"], type_test=False) + assert_allclose(item["seg"], expected_last["seg"], type_test=False) def test_deep_copy(self): data = {"img": np.ones((1, 10, 11, 12))} @@ -101,11 +100,11 @@ def test_deep_copy(self): sampler = RandSpatialCropSamplesd( keys=["img"], roi_size=(3, 3, 3), num_samples=num_samples, random_center=True, random_size=False ) - transform = Compose([ToTensord(keys="img"), sampler]) + transform = Compose([DivisiblePadd(keys="img", k=5), sampler]) samples = transform(data) self.assertEqual(len(samples), num_samples) for sample in samples: - self.assertEqual(len(sample["img_transforms"]), len(transform)) + self.assertEqual(len(sample["img"].applied_operations), len(transform)) if __name__ == "__main__": diff --git 
a/tests/test_rand_spatial_cropd.py b/tests/test_rand_spatial_cropd.py index 9e6e86eea2..c6a0fbe5e7 100644 --- a/tests/test_rand_spatial_cropd.py +++ b/tests/test_rand_spatial_cropd.py @@ -15,65 +15,62 @@ from parameterized import parameterized from monai.transforms import RandSpatialCropd -from tests.utils import TEST_NDARRAYS +from tests.croppers import CropTest +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose -TEST_CASE_0 = [ - {"keys": "img", "roi_size": [3, 3, -1], "random_center": True}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 5])}, - (3, 3, 3, 5), +TEST_SHAPES = [ + [{"keys": "img", "roi_size": [3, 3, -1], "random_center": True}, (3, 3, 3, 5), (3, 3, 3, 5)], + [{"keys": "img", "roi_size": [3, 3, 3], "random_center": True}, (3, 3, 3, 3), (3, 3, 3, 3)], + [{"keys": "img", "roi_size": [3, 3, 3], "random_center": False}, (3, 3, 3, 3), (3, 3, 3, 3)], ] -TEST_CASE_1 = [ - {"keys": "img", "roi_size": [3, 3, 3], "random_center": True}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 3, 3, 3), +TEST_VALUES = [ + [ + {"keys": "img", "roi_size": [3, 3], "random_center": False}, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + ] ] -TEST_CASE_2 = [ - {"keys": "img", "roi_size": [3, 3, 3], "random_center": False}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 3, 3, 3), +TEST_RANDOM_SHAPES = [ + [ + {"keys": "img", "roi_size": [3, 3, 3], "max_roi_size": [5, -1, 4], "random_center": True, "random_size": True}, + (1, 4, 5, 6), + (1, 4, 4, 3), + ], + [ + {"keys": "img", "roi_size": 3, "max_roi_size": 4, "random_center": True, "random_size": True}, + (1, 4, 5, 6), + (1, 3, 4, 3), + ], ] -TEST_CASE_3 = [ - {"keys": "img", "roi_size": [3, 3], "random_center": False}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])}, -] - -TEST_CASE_4 = [ - {"keys": "img", "roi_size": [3, 3, 3], "max_roi_size": [5, -1, 4], "random_center": True, "random_size": True}, - {"img": np.random.randint(0, 2, size=[1, 4, 5, 6])}, - (1, 4, 4, 3), -] - -TEST_CASE_5 = [ - {"keys": "img", "roi_size": 3, "max_roi_size": 4, "random_center": True, "random_size": True}, - {"img": np.random.randint(0, 2, size=[1, 4, 5, 6])}, - (1, 3, 4, 3), -] +class TestRandSpatialCropd(CropTest): + Cropper = RandSpatialCropd -class TestRandSpatialCropd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, input_param, input_data, expected_shape): - result = RandSpatialCropd(**input_param)(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + @parameterized.expand(TEST_SHAPES) + def test_shape(self, input_param, input_shape, expected_shape): + self.crop_test(input_param, input_shape, expected_shape) - @parameterized.expand([TEST_CASE_3]) - def test_value(self, input_param, input_data): - cropper = RandSpatialCropd(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - np.testing.assert_allclose(result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + @parameterized.expand(TEST_VALUES) + def test_value(self, input_param, input_im): + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + cropper = self.Cropper(**input_param) + input_data = {"img": im_type(input_im)} + result = cropper(input_data)["img"] + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper.cropper._size] + assert_allclose(result, input_im[:, roi[0][0] : 
roi[0][1], roi[1][0] : roi[1][1]], type_test=False) - @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) - def test_random_shape(self, input_param, input_data, expected_shape): - for p in TEST_NDARRAYS: - cropper = RandSpatialCropd(**input_param) - cropper.set_random_state(seed=123) - input_data["img"] = p(input_data["img"]) - result = cropper(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + @parameterized.expand(TEST_RANDOM_SHAPES) + def test_random_shape(self, input_param, input_shape, expected_shape): + for im_type in TEST_NDARRAYS_ALL: + with self.subTest(im_type=im_type): + cropper = self.Cropper(**input_param) + cropper.set_random_state(seed=123) + input_data = {"img": im_type(np.random.randint(0, 2, input_shape))} + result = cropper(input_data)["img"] + self.assertTupleEqual(result.shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_rand_weighted_crop.py b/tests/test_rand_weighted_crop.py index dae7f05016..696de9c05e 100644 --- a/tests/test_rand_weighted_crop.py +++ b/tests/test_rand_weighted_crop.py @@ -12,11 +12,12 @@ import unittest import numpy as np -import torch from parameterized.parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms.croppad.array import RandWeightedCrop -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose +from tests.croppers import CropTest +from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose def get_data(ndim): @@ -30,8 +31,8 @@ def get_data(ndim): TESTS = [] -for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: +for p in TEST_NDARRAYS_ALL: + for q in TEST_NDARRAYS_ALL: im = SEG1_2D weight = np.zeros_like(im) weight[0, 30, 17] = 1.1 @@ -148,7 +149,9 @@ def get_data(ndim): ) -class TestRandWeightedCrop(unittest.TestCase): +class TestRandWeightedCrop(CropTest): + Cropper = RandWeightedCrop + @parameterized.expand(TESTS) def test_rand_weighted_crop(self, _, input_params, img, weight, expected_shape, expected_vals): crop = RandWeightedCrop(**input_params) @@ -161,10 +164,9 @@ def test_rand_weighted_crop(self, _, input_params, img, weight, expected_shape, # if desired ROI is larger than image, check image is unchanged if all(s >= i for i, s in zip(img.shape[1:], input_params["spatial_size"])): for res in result: - self.assertEqual(type(img), type(res)) - if isinstance(img, torch.Tensor): - self.assertEqual(res.device, img.device) - assert_allclose(res, img) + self.assertIsInstance(res, MetaTensor) + assert_allclose(res, img, type_test=False) + self.assertEqual(len(res.applied_operations), 1) if __name__ == "__main__": diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py index a357398f1c..b3fc92b445 100644 --- a/tests/test_rand_weighted_cropd.py +++ b/tests/test_rand_weighted_cropd.py @@ -12,180 +12,144 @@ import unittest import numpy as np +from parameterized import parameterized from monai.transforms.croppad.dictionary import RandWeightedCropd -from monai.utils.enums import PostFix -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose - - -class TestRandWeightedCrop(NumpyImageTestCase2D): - def test_rand_weighted_crop_small_roi(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - 
crop.set_random_state(10) - d = {"img": p(img), "w": q(weight)} - result = crop(d) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 12)) - for c, e in zip(crop.centers, [[80, 21], [30, 17], [40, 31]]): - assert_allclose(c, e, type_test=False) - - def test_rand_weighted_crop_default_roi(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd("im", "weight", (10, -1), n_samples, "coords") - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - data = {"im": p(img), "weight": q(weight), "others": np.nan} - result = crop(data) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["im"].shape, (1, 10, 64)) - for c, e in zip(crop.centers, [[14, 32], [105, 32], [20, 32]]): - assert_allclose(c, e, type_test=False) - assert_allclose(result[1]["coords"], [105, 32], type_test=False) - - def test_rand_weighted_crop_large_roi(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "weight", (10000, 400), n_samples, "location") - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 10, 1] = 1 - crop.set_random_state(10) - data = {"img": p(img), "seg": p(self.imt[0]), "weight": q(weight)} - result = crop(data) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 128, 64)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 128, 64)) - for c, e in zip(crop.centers, [[64, 32], [64, 32], [64, 32]]): - assert_allclose(c, e, type_test=False) - assert_allclose(result[1]["location"], [64, 32], type_test=False) - - def test_rand_weighted_crop_bad_w(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (20, 40), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = np.inf - weight[0, 10, 1] = -np.inf - weight[0, 10, 20] = -np.nan - crop.set_random_state(10) - result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 20, 40)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 20, 40)) - for c, e in zip(crop.centers, [[63, 37], [31, 43], [66, 20]]): - assert_allclose(c, e, type_test=False) - - -class TestRandWeightedCrop3D(NumpyImageTestCase3D): - def test_rand_weighted_crop_small_roi(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (8, 10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 5, 30, 17] = 1.1 - weight[0, 8, 40, 31] = 1 - weight[0, 11, 23, 21] = 1 - crop.set_random_state(10) - result = crop({"img": p(img), "w": q(weight)}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 8, 10, 12)) - for c, e in zip(crop.centers, [[11, 23, 21], [5, 30, 17], [8, 40, 31]]): - assert_allclose(c, e, type_test=False) - - def test_rand_weighted_crop_default_roi(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) - weight = np.zeros_like(img) - weight[0, 7, 17] = 1.1 - weight[0, 13, 31] = 1.1 - weight[0, 24, 21] = 1 - crop.set_random_state(10) - result = crop({"img": p(img), "seg": 
p(self.segn[0]), "w": q(weight)}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 10, 64, 80)) - for c, e in zip(crop.centers, [[14, 32, 40], [41, 32, 40], [20, 32, 40]]): - assert_allclose(c, e, type_test=False) - - def test_rand_weighted_crop_large_roi(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (10000, 400, 80), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17, 20] = 1.1 - weight[0, 10, 1, 17] = 1 - crop.set_random_state(10) - result = crop({"img": p(img), "w": q(weight)}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) - for c, e in zip(crop.centers, [[24, 32, 40], [24, 32, 40], [24, 32, 40]]): - assert_allclose(c, e, type_test=False) - - def test_rand_weighted_crop_bad_w(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (48, 64, 80), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = np.inf - weight[0, 10, 1] = -np.inf - weight[0, 10, 20] = -np.nan - crop.set_random_state(10) - result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 48, 64, 80)) - for c, e in zip(crop.centers, [[24, 32, 40], [24, 32, 40], [24, 32, 40]]): - assert_allclose(c, e, type_test=False) - - def test_rand_weighted_crop_patch_index(self): - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) - weight = np.zeros_like(img) - weight[0, 7, 17] = 1.1 - weight[0, 13, 31] = 1.1 - weight[0, 24, 21] = 1 - crop.set_random_state(10) - result = crop( - {"img": p(img), "seg": p(self.segn[0]), "w": q(weight), PostFix.meta("img"): {"affine": None}} - ) - self.assertTrue(len(result) == n_samples) - for c, e in zip(crop.centers, [[14, 32, 40], [41, 32, 40], [20, 32, 40]]): - assert_allclose(c, e, type_test=False) - for i in range(n_samples): - np.testing.assert_allclose(result[i]["img"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[i]["seg"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[i][PostFix.meta("img")]["patch_index"], i) - np.testing.assert_allclose(result[i][PostFix.meta("seg")]["patch_index"], i) +from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D + + +def get_data(ndim): + im_gen = NumpyImageTestCase2D() if ndim == 2 else NumpyImageTestCase3D() + im_gen.setUp() + return im_gen.imt[0], im_gen.seg1[0], im_gen.segn[0] + + +IMT_2D, SEG1_2D, SEGN_2D = get_data(ndim=2) +IMT_3D, SEG1_3D, SEGN_3D = get_data(ndim=3) + +TESTS = [] +for p in TEST_NDARRAYS_ALL: + for q in TEST_NDARRAYS_ALL: + im = IMT_2D + weight = np.zeros_like(im) + weight[0, 30, 17] = 1.1 + weight[0, 40, 31] = 1 + weight[0, 80, 21] = 1 + TESTS.append( + [ + "small roi 2d", + dict(keys="img", w_key="w", spatial_size=(10, 12), num_samples=3), + {"img": p(im), "w": q(weight)}, + (1, 10, 12), + [[80, 21], [30, 17], [40, 31]], + ] + ) + + weight = np.zeros_like(im) + weight[0, 30, 17] = 1.1 + weight[0, 40, 31] = 1 + weight[0, 80, 21] = 1 + TESTS.append( + [ + "default roi 2d", + dict(keys="img", w_key="w", spatial_size=(10, 
-1), num_samples=3), + {"img": p(im), "w": q(weight), "others": np.nan}, + (1, 10, 64), + [[14, 32], [105, 32], [20, 32]], + ] + ) + + weight = np.zeros_like(im) + weight[0, 30, 17] = 1.1 + weight[0, 10, 1] = 1 + TESTS.append( + [ + "large roi 2d", + dict(keys=("img", "seg"), w_key="weight", spatial_size=(10000, 400), num_samples=3), + {"img": p(im), "seg": p(SEGN_2D), "weight": q(weight)}, + (1, 128, 64), + [[64, 32], [64, 32], [64, 32]], + ] + ) + + weight = np.zeros_like(im) + weight[0, 30, 17] = np.inf + weight[0, 10, 1] = -np.inf + weight[0, 10, 20] = -np.nan + TESTS.append( + [ + "bad w roi 2d", + dict(keys=("img", "seg"), w_key="w", spatial_size=(20, 40), num_samples=3), + {"img": p(im), "seg": p(SEGN_2D), "w": q(weight)}, + (1, 20, 40), + [[63, 37], [31, 43], [66, 20]], + ] + ) + + im = IMT_3D + weight = np.zeros_like(im) + weight[0, 5, 30, 17] = 1.1 + weight[0, 8, 40, 31] = 1 + weight[0, 11, 23, 21] = 1 + TESTS.append( + [ + "small roi 3d", + dict(keys="img", w_key="w", spatial_size=(8, 10, 12), num_samples=3), + {"img": p(im), "w": q(weight)}, + (1, 8, 10, 12), + [[11, 23, 21], [5, 30, 17], [8, 40, 31]], + ] + ) + + weight = np.zeros_like(im) + weight[0, 5, 30, 17] = 1.1 + weight[0, 8, 40, 31] = 1 + weight[0, 11, 23, 21] = 1 + TESTS.append( + [ + "default roi 3d", + dict(keys=("img", "seg"), w_key="w", spatial_size=(10, -1, -1), num_samples=3), + {"img": p(im), "seg": p(SEGN_3D), "w": q(weight)}, + (1, 10, 64, 80), + [[14, 32, 40], [41, 32, 40], [20, 32, 40]], + ] + ) + + weight = np.zeros_like(im) + weight[0, 30, 17, 20] = 1.1 + weight[0, 10, 1, 17] = 1 + TESTS.append( + [ + "large roi 3d", + dict(keys="img", w_key="w", spatial_size=(10000, 400, 80), num_samples=3), + {"img": p(im), "w": q(weight)}, + (1, 48, 64, 80), + [[24, 32, 40], [24, 32, 40], [24, 32, 40]], + ] + ) + + weight = np.zeros_like(im) + weight[0, 30, 17] = np.inf + weight[0, 10, 1] = -np.inf + weight[0, 10, 20] = -np.nan + TESTS.append( + [ + "bad w roi 3d", + dict(keys=("img", "seg"), w_key="w", spatial_size=(48, 64, 80), num_samples=3), + {"img": p(im), "seg": p(SEGN_3D), "w": q(weight)}, + (1, 48, 64, 80), + [[24, 32, 40], [24, 32, 40], [24, 32, 40]], + ] + ) + + +class TestRandWeightedCrop(unittest.TestCase): + @parameterized.expand(TESTS) + def test_rand_weighted_cropd(self, _, init_params, input_data, expected_shape, expected_centers): + crop = RandWeightedCropd(**init_params) + crop.set_random_state(10) + result = crop(input_data) + self.assertTrue(len(result) == init_params["num_samples"]) if __name__ == "__main__": diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py index f81e1d4b08..4e097cd3d4 100644 --- a/tests/test_resize_with_pad_or_crop.py +++ b/tests/test_resize_with_pad_or_crop.py @@ -15,8 +15,9 @@ import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import ResizeWithPadOrCrop -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL, pytorch_after TEST_CASES = [ [{"spatial_size": [15, 8, 8], "mode": "constant"}, (3, 8, 8, 4), (3, 15, 8, 8)], @@ -26,24 +27,38 @@ (3, 15, 4, 8), ], [{"spatial_size": [15, 4, -1], "mode": "constant"}, (3, 8, 8, 4), (3, 15, 4, 4)], - [{"spatial_size": [15, 4, -1], "mode": "reflect"}, (3, 8, 8, 4), (3, 15, 4, 4)], - [{"spatial_size": [-1, -1, -1], "mode": "reflect"}, (3, 8, 8, 4), (3, 8, 8, 4)], + [ + {"spatial_size": [15, 4, -1], "mode": "reflect" if pytorch_after(1, 11) else "constant"}, + (3, 8, 8, 4), + (3, 15, 4, 4), + ], + [ + 
{"spatial_size": [-1, -1, -1], "mode": "reflect" if pytorch_after(1, 11) else "constant"}, + (3, 8, 8, 4), + (3, 8, 8, 4), + ], ] class TestResizeWithPadOrCrop(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_pad_shape(self, input_param, input_shape, expected_shape): - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: if isinstance(p(0), torch.Tensor) and ( "constant_values" in input_param or input_param["mode"] == "reflect" ): continue - paddcroper = ResizeWithPadOrCrop(**input_param) - result = paddcroper(p(np.zeros(input_shape))) + padcropper = ResizeWithPadOrCrop(**input_param) + result = padcropper(p(np.zeros(input_shape))) np.testing.assert_allclose(result.shape, expected_shape) - result = paddcroper(p(np.zeros(input_shape)), mode="constant") + result = padcropper(p(np.zeros(input_shape)), mode="constant") np.testing.assert_allclose(result.shape, expected_shape) + self.assertIsInstance(result, MetaTensor) + self.assertEqual(len(result.applied_operations), 1) + inv = padcropper.inverse(result) + self.assertTupleEqual(inv.shape, input_shape) + self.assertIsInstance(inv, MetaTensor) + self.assertEqual(inv.applied_operations, []) if __name__ == "__main__": diff --git a/tests/test_resize_with_pad_or_cropd.py b/tests/test_resize_with_pad_or_cropd.py index 28993a2bf4..eb4e5f09cc 100644 --- a/tests/test_resize_with_pad_or_cropd.py +++ b/tests/test_resize_with_pad_or_cropd.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import ResizeWithPadOrCropd -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS_ALL, pytorch_after TEST_CASES = [ [{"keys": "img", "spatial_size": [15, 8, 8], "mode": "constant"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 8, 8)], @@ -26,23 +26,34 @@ (3, 15, 4, 8), ], [{"keys": "img", "spatial_size": [15, 4, -1], "mode": "constant"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 4, 4)], - [{"keys": "img", "spatial_size": [15, 4, -1], "mode": "reflect"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 4, 4)], - [{"keys": "img", "spatial_size": [-1, -1, -1], "mode": "reflect"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 8, 8, 4)], + [ + {"keys": "img", "spatial_size": [15, 4, -1], "mode": "reflect" if pytorch_after(1, 11) else "constant"}, + {"img": np.zeros((3, 8, 8, 4))}, + (3, 15, 4, 4), + ], + [ + {"keys": "img", "spatial_size": [-1, -1, -1], "mode": "reflect" if pytorch_after(1, 11) else "constant"}, + {"img": np.zeros((3, 8, 8, 4))}, + (3, 8, 8, 4), + ], ] class TestResizeWithPadOrCropd(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_pad_shape(self, input_param, input_data, expected_val): - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS_ALL: if isinstance(p(0), torch.Tensor) and ( "constant_values" in input_param or input_param["mode"] == "reflect" ): continue - paddcroper = ResizeWithPadOrCropd(**input_param) + padcropper = ResizeWithPadOrCropd(**input_param) input_data["img"] = p(input_data["img"]) - result = paddcroper(input_data) + result = padcropper(input_data) np.testing.assert_allclose(result["img"].shape, expected_val) + inv = padcropper.inverse(result) + for k in input_data: + self.assertTupleEqual(inv[k].shape, input_data[k].shape) if __name__ == "__main__": diff --git a/tests/test_spatial_crop.py b/tests/test_spatial_crop.py index bf1eb11491..6fdfbd3f70 100644 --- a/tests/test_spatial_crop.py +++ b/tests/test_spatial_crop.py @@ -11,12 +11,10 @@ import unittest -import numpy as np -import torch from parameterized import parameterized from monai.transforms import SpatialCrop -from 
tests.utils import TEST_NDARRAYS, assert_allclose +from tests.croppers import CropTest TESTS = [ [{"roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], @@ -26,31 +24,28 @@ [{"roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], [{"roi_start": [0, 0, 0, 0, 0], "roi_end": [8, 8, 8, 2, 2]}, (3, 3, 3, 3), (3, 3, 3, 3)], [{"roi_start": [1, 0, 0], "roi_end": [1, 8, 8]}, (3, 3, 3, 3), (3, 0, 3, 3)], - [{"roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, (3, 3, 3, 3), (3, 1, 2, 2)], + [ + {"roi_slices": [slice(s, e) for s, e in zip([None, None, None], [None, None, None])]}, + (3, 11, 12, 15), + (3, 11, 12, 15), + ], + [{"roi_slices": [slice(s, e) for s, e in zip([1, None, 0], [None, None, None])]}, (3, 7, 9, 11), (3, 6, 9, 11)], + [{"roi_slices": [slice(s, e) for s, e in zip([0, None, None], [-1, None, None])]}, (3, 7, 9, 11), (3, 6, 9, 11)], + [{"roi_slices": [slice(s, e) for s, e in zip([1, None, None], [None, None, None])]}, (3, 10, 8, 6), (3, 9, 8, 6)], + [{"roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, (3, 15, 17, 8), (3, 1, 2, 2)], + [{"roi_slices": [slice(s, e) for s, e in zip([None, None, None], [-2, -1, 2])]}, (3, 13, 8, 6), (3, 11, 7, 2)], + [{"roi_start": [-1, 0], "roi_end": [5, 5]}, (1, 5, 5), (1, 5, 5)], ] TEST_ERRORS = [[{"roi_slices": [slice(s, e, 2) for s, e in zip([-1, -2, 0], [None, None, 2])]}]] -class TestSpatialCrop(unittest.TestCase): +class TestSpatialCrop(CropTest): + Cropper = SpatialCrop + @parameterized.expand(TESTS) def test_shape(self, input_param, input_shape, expected_shape): - input_data = np.random.randint(0, 2, size=input_shape) - results = [] - for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS + (None,): - input_param_mod = { - k: q(v) if k != "roi_slices" and q is not None else v for k, v in input_param.items() - } - im = p(input_data) - result = SpatialCrop(**input_param_mod)(im) - self.assertEqual(type(im), type(result)) - if isinstance(result, torch.Tensor): - self.assertEqual(result.device, im.device) - self.assertTupleEqual(result.shape, expected_shape) - results.append(result) - if len(results) > 1: - assert_allclose(results[0], results[-1], type_test=False) + self.crop_test(input_param, input_shape, expected_shape) @parameterized.expand(TEST_ERRORS) def test_error(self, input_param): diff --git a/tests/test_spatial_cropd.py b/tests/test_spatial_cropd.py index 5b16f460fd..11f6da0811 100644 --- a/tests/test_spatial_cropd.py +++ b/tests/test_spatial_cropd.py @@ -11,56 +11,57 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import SpatialCropd -from tests.utils import TEST_NDARRAYS +from tests.croppers import CropTest -TESTS = [] -for p in TEST_NDARRAYS: - TESTS.append( - [ - {"keys": ["img"], "roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, - {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, - (3, 2, 2, 2), - ] - ) - TESTS.append( - [ - {"keys": ["img"], "roi_start": [0, 0, 0], "roi_end": [2, 2, 2]}, - {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, - (3, 2, 2, 2), - ] - ) - TESTS.append( - [ - {"keys": ["img"], "roi_start": [0, 0], "roi_end": [2, 2]}, - {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, - (3, 2, 2, 3), - ] - ) - TESTS.append( - [ - {"keys": ["img"], "roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, - {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, - (3, 2, 2, 2), - ] - ) - TESTS.append( - [ - {"keys": ["img"], "roi_slices": [slice(s, e) for s, 
e in zip([-1, -2, 0], [None, None, 2])]}, - {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, - (3, 1, 2, 2), - ] - ) +TESTS = [ + [ + {"keys": ["img"], "roi_center": [1, 1], "roi_size": [2, 2]}, + (1, 3, 3), + (1, 2, 2), + (slice(None), slice(None, 2), slice(None, 2)), + ], + [ + {"keys": ["img"], "roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, + (3, 3, 3, 3), + (3, 2, 2, 2), + (slice(None), slice(None, 2), slice(None, 2), slice(None, 2)), + ], + [ + {"keys": ["img"], "roi_start": [0, 0, 0], "roi_end": [2, 2, 2]}, + (3, 3, 3, 3), + (3, 2, 2, 2), + (slice(None), slice(None, 2), slice(None, 2), slice(None, 2)), + ], + [ + {"keys": ["img"], "roi_start": [0, 0], "roi_end": [2, 2]}, + (3, 3, 3, 3), + (3, 2, 2, 3), + (slice(None), slice(None, 2), slice(None, 2), slice(None)), + ], + [ + {"keys": ["img"], "roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, + (3, 3, 3, 3), + (3, 2, 2, 2), + (slice(None), slice(None, 2), slice(None, 2), slice(None, 2)), + ], + [ + {"keys": ["img"], "roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, + (3, 3, 3, 3), + (3, 1, 2, 2), + (slice(None), slice(-1, None), slice(-2, None), slice(0, 2)), + ], +] -class TestSpatialCropd(unittest.TestCase): +class TestSpatialCropd(CropTest): + Cropper = SpatialCropd + @parameterized.expand(TESTS) - def test_shape(self, input_param, input_data, expected_shape): - result = SpatialCropd(**input_param)(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + def test_shape(self, input_param, input_shape, expected_shape, same_area): + self.crop_test(input_param, input_shape, expected_shape, same_area) if __name__ == "__main__": diff --git a/tests/test_spatial_pad.py b/tests/test_spatial_pad.py index 4cdeb6d64e..5a70c10686 100644 --- a/tests/test_spatial_pad.py +++ b/tests/test_spatial_pad.py @@ -10,91 +10,28 @@ # limitations under the License. 
import unittest -from typing import List -import numpy as np -import torch from parameterized import parameterized from monai.transforms import SpatialPad -from monai.utils.enums import NumpyPadMode, PytorchPadMode -from monai.utils.misc import set_determinism -from tests.utils import TEST_NDARRAYS +from tests.padders import PadTest TESTS = [] +TESTS.append([{"spatial_size": [3, 4], "method": "end"}, (1, 2, 3), (1, 3, 4)]) +TESTS.append([{"spatial_size": [15, 4, -1], "method": "symmetric"}, (3, 8, 8, 4), (3, 15, 8, 4)]) -MODES = [] -# Test modes -NP_MODES: List = [ - "constant", - "edge", - # `reflect` mode is not supported in some PyTorch versions, skip the test - # "reflect", - "wrap", -] -MODES += NP_MODES -MODES += [NumpyPadMode(i) for i in NP_MODES] - -PT_MODES: list = [ - "constant", - "replicate", - "circular", - # `reflect` mode is not supported in some PyTorch versions, skip the test - # "reflect", -] -MODES += PT_MODES -MODES += [PytorchPadMode(i) for i in PT_MODES] - -for mode in MODES: - TESTS.append([{"spatial_size": [3, 4], "method": "end", "mode": mode}, (1, 2, 3), (1, 3, 4)]) - - TESTS.append([{"spatial_size": [15, 4, -1], "method": "symmetric", "mode": mode}, (3, 8, 8, 4), (3, 15, 8, 4)]) - - -class TestSpatialPad(unittest.TestCase): - def setUp(self) -> None: - set_determinism(seed=0) - - def tearDown(self) -> None: - set_determinism(None) - - @staticmethod - def get_arr(shape): - return np.random.randint(100, size=shape).astype(float) +class TestSpatialPad(PadTest): + Padder = SpatialPad @parameterized.expand(TESTS) - def test_pad_shape(self, input_param, input_shape, expected_shape): - results_1 = [] - results_2 = [] - input_data = self.get_arr(input_shape) - # check result is the same regardless of input type - for p in TEST_NDARRAYS: - padder = SpatialPad(**input_param) - r1 = padder(p(input_data)) - r2 = padder(p(input_data), mode=input_param["mode"]) - results_1.append(r1.cpu() if isinstance(r1, torch.Tensor) else r1) - results_2.append(r2.cpu() if isinstance(r2, torch.Tensor) else r2) - for results in (results_1, results_2): - np.testing.assert_allclose(results[-1].shape, expected_shape) - if input_param["mode"] not in ("empty", NumpyPadMode.EMPTY): - torch.testing.assert_allclose(results[0], results[-1], atol=0, rtol=1e-5) + def test_pad(self, input_param, input_shape, expected_shape): + self.pad_test(input_param, input_shape, expected_shape) def test_pad_kwargs(self): - for p in TEST_NDARRAYS: - input_data = p(np.zeros((3, 8, 4))) - if isinstance(input_data, torch.Tensor): - result = ( - SpatialPad(spatial_size=[15, 8], method="end", mode="constant", value=2)(img=input_data) - .cpu() - .numpy() - ) - else: - result = SpatialPad( - spatial_size=[15, 8], method="end", mode="constant", constant_values=((0, 0), (1, 1), (2, 2)) - )(img=input_data) - torch.testing.assert_allclose(result[:, 8:, :4], np.ones((3, 7, 4)), rtol=1e-7, atol=0) - torch.testing.assert_allclose(result[:, :, 4:], np.ones((3, 15, 4)) + 1, rtol=1e-7, atol=0) + kwargs = {"spatial_size": [15, 8], "method": "end", "mode": "constant"} + unchanged_slices = [slice(None), slice(None, 8), slice(None, 4)] + self.pad_test_kwargs(unchanged_slices, **kwargs) if __name__ == "__main__": diff --git a/tests/test_spatial_padd.py b/tests/test_spatial_padd.py index 762a1145f5..656a731de0 100644 --- a/tests/test_spatial_padd.py +++ b/tests/test_spatial_padd.py @@ -11,42 +11,26 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import SpatialPadd +from tests.padders 
import PadTest -TEST_CASE_1 = [ - {"keys": ["img"], "spatial_size": [15, 8, 8], "method": "symmetric", "mode": "constant"}, - {"img": np.zeros((3, 8, 8, 4))}, - np.zeros((3, 15, 8, 8)), +TESTS = [ + [{"keys": ["img"], "spatial_size": [15, 8, 8], "method": "symmetric"}, (3, 8, 8, 4), (3, 15, 8, 8)], + [{"keys": ["img"], "spatial_size": [15, 8, 8], "method": "end"}, (3, 8, 8, 4), (3, 15, 8, 8)], + [{"keys": ["img"], "spatial_size": [15, 8, 8], "method": "end"}, (3, 8, 8, 4), (3, 15, 8, 8)], + [{"keys": ["img"], "spatial_size": [15, 8, -1], "method": "end"}, (3, 8, 4, 4), (3, 15, 8, 4)], ] -TEST_CASE_2 = [ - {"keys": ["img"], "spatial_size": [15, 8, 8], "method": "end", "mode": "constant"}, - {"img": np.zeros((3, 8, 8, 4))}, - np.zeros((3, 15, 8, 8)), -] - -TEST_CASE_3 = [ - {"keys": ["img"], "spatial_size": [15, 8, 8], "method": "end", "mode": {"constant"}}, - {"img": np.zeros((3, 8, 8, 4))}, - np.zeros((3, 15, 8, 8)), -] - -TEST_CASE_4 = [ - {"keys": ["img"], "spatial_size": [15, 8, -1], "method": "end", "mode": {"constant"}}, - {"img": np.zeros((3, 8, 4, 4))}, - np.zeros((3, 15, 8, 4)), -] +class TestSpatialPadd(PadTest): + Padder = SpatialPadd -class TestSpatialPadd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) - def test_pad_shape(self, input_param, input_data, expected_val): - padder = SpatialPadd(**input_param) - result = padder(input_data) - np.testing.assert_allclose(result["img"].shape, expected_val.shape) + @parameterized.expand(TESTS) + def test_pad(self, input_param, input_shape, expected_shape): + modes = ["constant", {"constant"}] + self.pad_test(input_param, input_shape, expected_shape, modes) if __name__ == "__main__":
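
The refactored tests above import shared base classes, `tests.croppers.CropTest` and `tests.padders.PadTest`, which do not appear in this diff; subclasses set a `Cropper`/`Padder` class attribute and call `crop_test`/`pad_test` with an input shape and an expected output shape. The following is an illustrative sketch of how such a crop helper could be written to match that usage (a single "img" key, a shape check, and an optional check that the output equals the expected slice of the input); it is an assumption for readability, not the actual helper shipped in the PR.

import unittest

import numpy as np

from monai.transforms import MapTransform


class CropTest(unittest.TestCase):
    Cropper = None  # subclasses set this to the transform under test, e.g. SpatialCrop or SpatialCropd

    def crop_test(self, input_param, input_shape, expected_shape, same_area=None):
        input_image = np.random.randint(0, 100, size=input_shape).astype(float)
        cropper = self.Cropper(**input_param)
        # dictionary-based transforms (e.g. SpatialCropd) take and return a dict keyed by "img"
        is_map = isinstance(cropper, MapTransform)
        data = {"img": input_image} if is_map else input_image
        result = cropper(data)
        out = result["img"] if is_map else result
        # the cropped output should have the expected spatial shape
        self.assertTupleEqual(tuple(out.shape), tuple(expected_shape))
        if same_area is not None:
            # the cropped output should equal the corresponding region of the input
            np.testing.assert_allclose(np.asarray(out), input_image[same_area])

`PadTest` in `tests/padders.py` plays the analogous role for the pad transforms, with `pad_test` additionally iterating over padding modes and `pad_test_kwargs` checking that the region described by the unchanged slices is preserved.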