Skip to content

Commit

Permalink
[fbsync] [prototype] Optimize Center Crop performance (#6880)
Browse files Browse the repository at this point in the history
Summary:
* Reducing unnecessary method calls

* Optimize pad branch.

* Remove unnecessary call to crop_image_tensor

* Fix linter

Reviewed By: datumbox

Differential Revision: D41020555

fbshipit-source-id: 55d55d80993830d0b70ad4140d55fab2cba9d21e

Co-authored-by: vfdev <[email protected]>
  • Loading branch information
2 people authored and facebook-github-bot committed Nov 4, 2022
1 parent e05a0d0 commit 24c3fb8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 32 deletions.
14 changes: 6 additions & 8 deletions torchvision/prototype/transforms/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

from ._meta import _rgb_to_gray, convert_dtype_image_tensor, get_dimensions_image_tensor, get_num_channels_image_tensor
from ._meta import _rgb_to_gray, convert_dtype_image_tensor


def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
Expand Down Expand Up @@ -45,7 +45,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
if saturation_factor < 0:
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")

c = get_num_channels_image_tensor(image)
c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")

Expand Down Expand Up @@ -75,7 +75,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
if contrast_factor < 0:
raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")

c = get_num_channels_image_tensor(image)
c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
Expand All @@ -101,7 +101,7 @@ def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> feat


def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
num_channels, height, width = get_dimensions_image_tensor(image)
num_channels, height, width = image.shape[-3:]
if num_channels not in (1, 3):
raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")

Expand Down Expand Up @@ -210,8 +210,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

c = get_num_channels_image_tensor(image)

c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")

Expand Down Expand Up @@ -342,8 +341,7 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp


def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
c = get_num_channels_image_tensor(image)

c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")

Expand Down
47 changes: 23 additions & 24 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,7 @@
)
from torchvision.transforms.functional_tensor import _parse_pad_padding

from ._meta import (
convert_format_bounding_box,
get_dimensions_image_tensor,
get_spatial_size_image_pil,
get_spatial_size_image_tensor,
)
from ._meta import convert_format_bounding_box, get_spatial_size_image_pil

horizontal_flip_image_tensor = _FT.hflip
horizontal_flip_image_pil = _FP.hflip
Expand Down Expand Up @@ -120,9 +115,9 @@ def resize_image_tensor(
max_size: Optional[int] = None,
antialias: bool = False,
) -> torch.Tensor:
num_channels, old_height, old_width = get_dimensions_image_tensor(image)
shape = image.shape
num_channels, old_height, old_width = shape[-3:]
new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
extra_dims = image.shape[:-3]

if image.numel() > 0:
image = image.reshape(-1, num_channels, old_height, old_width)
Expand All @@ -134,7 +129,7 @@ def resize_image_tensor(
antialias=antialias,
)

return image.reshape(extra_dims + (num_channels, new_height, new_width))
return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


@torch.jit.unused
Expand Down Expand Up @@ -270,8 +265,8 @@ def affine_image_tensor(
if image.numel() == 0:
return image

num_channels, height, width = image.shape[-3:]
extra_dims = image.shape[:-3]
shape = image.shape
num_channels, height, width = shape[-3:]
image = image.reshape(-1, num_channels, height, width)

angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
Expand All @@ -285,7 +280,7 @@ def affine_image_tensor(
matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

output = _FT.affine(image, matrix, interpolation=interpolation.value, fill=fill)
return output.reshape(extra_dims + (num_channels, height, width))
return output.reshape(shape)


@torch.jit.unused
Expand Down Expand Up @@ -511,8 +506,8 @@ def rotate_image_tensor(
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
num_channels, height, width = image.shape[-3:]
extra_dims = image.shape[:-3]
shape = image.shape
num_channels, height, width = shape[-3:]

center_f = [0.0, 0.0]
if center is not None:
Expand All @@ -538,7 +533,7 @@ def rotate_image_tensor(
else:
new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)

return image.reshape(extra_dims + (num_channels, new_height, new_width))
return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


@torch.jit.unused
Expand Down Expand Up @@ -675,8 +670,8 @@ def _pad_with_scalar_fill(
fill: Union[int, float, None],
padding_mode: str = "constant",
) -> torch.Tensor:
num_channels, height, width = image.shape[-3:]
extra_dims = image.shape[:-3]
shape = image.shape
num_channels, height, width = shape[-3:]

if image.numel() > 0:
image = _FT.pad(
Expand All @@ -688,7 +683,7 @@ def _pad_with_scalar_fill(
new_height = height + top + bottom
new_width = width + left + right

return image.reshape(extra_dims + (num_channels, new_height, new_width))
return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


# TODO: This should be removed once pytorch pad supports non-scalar padding values
Expand Down Expand Up @@ -1130,7 +1125,8 @@ def elastic(

def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
if isinstance(output_size, numbers.Number):
return [int(output_size), int(output_size)]
s = int(output_size)
return [s, s]
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
return [output_size[0], output_size[0]]
else:
Expand All @@ -1156,18 +1152,21 @@ def _center_crop_compute_crop_anchor(

def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
image_height, image_width = get_spatial_size_image_tensor(image)
shape = image.shape
if image.numel() == 0:
return image.reshape(shape[:-2] + (crop_height, crop_width))
image_height, image_width = shape[-2:]

if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
image = pad_image_tensor(image, padding_ltrb, fill=0)
image = _FT.torch_pad(image, _FT._parse_pad_padding(padding_ltrb), value=0.0)

image_height, image_width = get_spatial_size_image_tensor(image)
image_height, image_width = image.shape[-2:]
if crop_width == image_width and crop_height == image_height:
return image

crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
return crop_image_tensor(image, crop_top, crop_left, crop_height, crop_width)
return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]


@torch.jit.unused
Expand Down Expand Up @@ -1332,7 +1331,7 @@ def five_crop_image_tensor(
image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
crop_height, crop_width = _parse_five_crop_size(size)
image_height, image_width = get_spatial_size_image_tensor(image)
image_height, image_width = image.shape[-2:]

if crop_width > image_width or crop_height > image_height:
msg = "Requested crop size {} is bigger than input size {}"
Expand Down

0 comments on commit 24c3fb8

Please sign in to comment.