diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index 67a55cfb1d7..f0b47b4411f 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -2,7 +2,7 @@
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 
-from ._meta import _rgb_to_gray, convert_dtype_image_tensor, get_dimensions_image_tensor, get_num_channels_image_tensor
+from ._meta import _rgb_to_gray, convert_dtype_image_tensor
 
 
 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
@@ -45,7 +45,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
     if saturation_factor < 0:
         raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
 
-    c = get_num_channels_image_tensor(image)
+    c = image.shape[-3]
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
 
@@ -75,7 +75,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
     if contrast_factor < 0:
         raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
 
-    c = get_num_channels_image_tensor(image)
+    c = image.shape[-3]
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
@@ -101,7 +101,7 @@ def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> feat
 
 
 def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
-    num_channels, height, width = get_dimensions_image_tensor(image)
+    num_channels, height, width = image.shape[-3:]
     if num_channels not in (1, 3):
         raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")
 
@@ -210,8 +210,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
     if not (-0.5 <= hue_factor <= 0.5):
         raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
 
-    c = get_num_channels_image_tensor(image)
-
+    c = image.shape[-3]
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
 
@@ -342,8 +341,7 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp
 
 
 def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
-    c = get_num_channels_image_tensor(image)
-
+    c = image.shape[-3]
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
 
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index b4a528c5478..9ed8a965ee3 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -16,12 +16,7 @@
 )
 from torchvision.transforms.functional_tensor import _parse_pad_padding
 
-from ._meta import (
-    convert_format_bounding_box,
-    get_dimensions_image_tensor,
-    get_spatial_size_image_pil,
-    get_spatial_size_image_tensor,
-)
+from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
 
 horizontal_flip_image_tensor = _FT.hflip
 horizontal_flip_image_pil = _FP.hflip
@@ -120,9 +115,9 @@ def resize_image_tensor(
     max_size: Optional[int] = None,
     antialias: bool = False,
 ) -> torch.Tensor:
-    num_channels, old_height, old_width = get_dimensions_image_tensor(image)
+    shape = image.shape
+    num_channels, old_height, old_width = shape[-3:]
     new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
-    extra_dims = image.shape[:-3]
 
     if image.numel() > 0:
         image = image.reshape(-1, num_channels, old_height, old_width)
@@ -134,7 +129,7 @@ def resize_image_tensor(
             antialias=antialias,
         )
 
-    return image.reshape(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
 
 
 @torch.jit.unused
@@ -270,8 +265,8 @@ def affine_image_tensor(
     if image.numel() == 0:
         return image
 
-    num_channels, height, width = image.shape[-3:]
-    extra_dims = image.shape[:-3]
+    shape = image.shape
+    num_channels, height, width = shape[-3:]
     image = image.reshape(-1, num_channels, height, width)
 
     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
@@ -285,7 +280,7 @@ def affine_image_tensor(
     matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
 
     output = _FT.affine(image, matrix, interpolation=interpolation.value, fill=fill)
-    return output.reshape(extra_dims + (num_channels, height, width))
+    return output.reshape(shape)
 
 
 @torch.jit.unused
@@ -511,8 +506,8 @@ def rotate_image_tensor(
     fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
-    num_channels, height, width = image.shape[-3:]
-    extra_dims = image.shape[:-3]
+    shape = image.shape
+    num_channels, height, width = shape[-3:]
 
     center_f = [0.0, 0.0]
     if center is not None:
@@ -538,7 +533,7 @@ def rotate_image_tensor(
     else:
         new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)
 
-    return image.reshape(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
 
 
 @torch.jit.unused
@@ -675,8 +670,8 @@ def _pad_with_scalar_fill(
     fill: Union[int, float, None],
     padding_mode: str = "constant",
 ) -> torch.Tensor:
-    num_channels, height, width = image.shape[-3:]
-    extra_dims = image.shape[:-3]
+    shape = image.shape
+    num_channels, height, width = shape[-3:]
 
     if image.numel() > 0:
         image = _FT.pad(
@@ -688,7 +683,7 @@ def _pad_with_scalar_fill(
         new_height = height + top + bottom
         new_width = width + left + right
 
-    return image.reshape(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
 
 
 # TODO: This should be removed once pytorch pad supports non-scalar padding values
@@ -1130,7 +1125,8 @@ def elastic(
 
 def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
     if isinstance(output_size, numbers.Number):
-        return [int(output_size), int(output_size)]
+        s = int(output_size)
+        return [s, s]
     elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
         return [output_size[0], output_size[0]]
     else:
@@ -1156,18 +1152,21 @@ def _center_crop_compute_crop_anchor(
 
 def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    image_height, image_width = get_spatial_size_image_tensor(image)
+    shape = image.shape
+    if image.numel() == 0:
+        return image.reshape(shape[:-2] + (crop_height, crop_width))
+    image_height, image_width = shape[-2:]
 
     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-        image = pad_image_tensor(image, padding_ltrb, fill=0)
+        image = _FT.torch_pad(image, _FT._parse_pad_padding(padding_ltrb), value=0.0)
 
-        image_height, image_width = get_spatial_size_image_tensor(image)
+        image_height, image_width = image.shape[-2:]
         if crop_width == image_width and crop_height == image_height:
             return image
 
     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
-    return crop_image_tensor(image, crop_top, crop_left, crop_height, crop_width)
+    return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]
 
 
 @torch.jit.unused
@@ -1332,7 +1331,7 @@ def five_crop_image_tensor(
     image: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    image_height, image_width = get_spatial_size_image_tensor(image)
+    image_height, image_width = image.shape[-2:]
 
     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"