[prototype] Optimize Center Crop performance #6880

Merged · 5 commits · Nov 1, 2022
14 changes: 6 additions & 8 deletions torchvision/prototype/transforms/functional/_color.py
@@ -2,7 +2,7 @@
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

- from ._meta import _rgb_to_gray, convert_dtype_image_tensor, get_dimensions_image_tensor, get_num_channels_image_tensor
+ from ._meta import _rgb_to_gray, convert_dtype_image_tensor


def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
@@ -45,7 +45,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
if saturation_factor < 0:
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")

- c = get_num_channels_image_tensor(image)
+ c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")

@@ -75,7 +75,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
if contrast_factor < 0:
raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")

- c = get_num_channels_image_tensor(image)
+ c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
@@ -101,7 +101,7 @@ def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> feat


def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
- num_channels, height, width = get_dimensions_image_tensor(image)
+ num_channels, height, width = image.shape[-3:]
@datumbox (Contributor Author) commented on Nov 1, 2022:
Aligns the idiom with other parts of the code-base where we directly rely on shape to determine the dimensions:

new_height, new_width = image.shape[-2:]

At that point the image is a pure tensor and making additional method calls to fetch the sizes is unnecessary.
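
To make the idiom concrete, here is a minimal standalone sketch (not taken from the PR) assuming the usual (..., C, H, W) tensor layout these kernels expect:

```python
import torch

# Hypothetical batched image tensor; the leading dims are arbitrary batch dims.
image = torch.rand(2, 4, 3, 40, 50)

# Old idiom: go through a helper that inspects the tensor, e.g.
# num_channels, height, width = get_dimensions_image_tensor(image)

# New idiom: read the sizes straight off the shape, with no extra function call.
num_channels, height, width = image.shape[-3:]
print(num_channels, height, width)  # 3 40 50
```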

if num_channels not in (1, 3):
raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")

@@ -210,8 +210,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

- c = get_num_channels_image_tensor(image)
-
+ c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")

@@ -342,8 +341,7 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp


def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
- c = get_num_channels_image_tensor(image)
-
+ c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")

47 changes: 23 additions & 24 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -16,12 +16,7 @@
)
from torchvision.transforms.functional_tensor import _parse_pad_padding

- from ._meta import (
-     convert_format_bounding_box,
-     get_dimensions_image_tensor,
-     get_spatial_size_image_pil,
-     get_spatial_size_image_tensor,
- )
+ from ._meta import convert_format_bounding_box, get_spatial_size_image_pil

horizontal_flip_image_tensor = _FT.hflip
horizontal_flip_image_pil = _FP.hflip
@@ -120,9 +115,9 @@ def resize_image_tensor(
max_size: Optional[int] = None,
antialias: bool = False,
) -> torch.Tensor:
- num_channels, old_height, old_width = get_dimensions_image_tensor(image)
+ shape = image.shape
+ num_channels, old_height, old_width = shape[-3:]
new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
- extra_dims = image.shape[:-3]

if image.numel() > 0:
image = image.reshape(-1, num_channels, old_height, old_width)
@@ -134,7 +129,7 @@
antialias=antialias,
)

- return image.reshape(extra_dims + (num_channels, new_height, new_width))
+ return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


@torch.jit.unused
@@ -270,8 +265,8 @@ def affine_image_tensor(
if image.numel() == 0:
return image

- num_channels, height, width = image.shape[-3:]
- extra_dims = image.shape[:-3]
+ shape = image.shape
+ num_channels, height, width = shape[-3:]
Comment on lines +268 to +269 — @datumbox (Contributor Author):
Minor nit to align idioms across the code-base. See ref:

Originally I wanted to do something like *extra_dims, num_channels, height, width but that's not JIT-scriptable. So I opted for keeping the whole original shape as we do on other parts of the code-base.
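
A rough eager-mode sketch of that pattern (illustrative only; the function name is made up, not the PR's code): keep the full original shape, flatten the leading batch dims for the kernel, then restore them on the way out.

```python
import torch

def _flatten_apply_restore(image: torch.Tensor) -> torch.Tensor:
    # Keep the whole original shape; star-unpacking such as
    # "*extra_dims, c, h, w = image.shape" is avoided because, per the
    # comment above, it is not JIT-scriptable.
    shape = image.shape
    num_channels, height, width = shape[-3:]
    # Collapse any leading batch dims so a kernel sees a plain (N, C, H, W) batch.
    batch = image.reshape(-1, num_channels, height, width)
    # ... a kernel that preserves (C, H, W) would run here ...
    output = batch
    # Restore the original leading dims from the saved shape.
    return output.reshape(shape[:-3] + (num_channels, height, width))

out = _flatten_apply_restore(torch.rand(2, 4, 3, 8, 8))
print(out.shape)  # torch.Size([2, 4, 3, 8, 8])
```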

image = image.reshape(-1, num_channels, height, width)

angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
Expand All @@ -285,7 +280,7 @@ def affine_image_tensor(
matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

output = _FT.affine(image, matrix, interpolation=interpolation.value, fill=fill)
- return output.reshape(extra_dims + (num_channels, height, width))
+ return output.reshape(shape)


@torch.jit.unused
@@ -511,8 +506,8 @@ def rotate_image_tensor(
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
- num_channels, height, width = image.shape[-3:]
- extra_dims = image.shape[:-3]
+ shape = image.shape
+ num_channels, height, width = shape[-3:]

center_f = [0.0, 0.0]
if center is not None:
Expand All @@ -538,7 +533,7 @@ def rotate_image_tensor(
else:
new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)

- return image.reshape(extra_dims + (num_channels, new_height, new_width))
+ return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


@torch.jit.unused
@@ -675,8 +670,8 @@ def _pad_with_scalar_fill(
fill: Union[int, float, None],
padding_mode: str = "constant",
) -> torch.Tensor:
- num_channels, height, width = image.shape[-3:]
- extra_dims = image.shape[:-3]
+ shape = image.shape
+ num_channels, height, width = shape[-3:]

if image.numel() > 0:
image = _FT.pad(
Expand All @@ -688,7 +683,7 @@ def _pad_with_scalar_fill(
new_height = height + top + bottom
new_width = width + left + right

- return image.reshape(extra_dims + (num_channels, new_height, new_width))
+ return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


# TODO: This should be removed once pytorch pad supports non-scalar padding values
@@ -1130,7 +1125,8 @@ def elastic(

def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
if isinstance(output_size, numbers.Number):
- return [int(output_size), int(output_size)]
+ s = int(output_size)
+ return [s, s]
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
return [output_size[0], output_size[0]]
else:
@@ -1156,18 +1152,21 @@ def _center_crop_compute_crop_anchor(

def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
- image_height, image_width = get_spatial_size_image_tensor(image)
+ shape = image.shape
+ if image.numel() == 0:
+     return image.reshape(shape[:-2] + (crop_height, crop_width))
Comment on lines +1156 to +1157 — @datumbox (Contributor Author):
This is needed because the original pad_image_tensor method below had a mitigation for zero batch images. So I thought to hit that point early.
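
As an illustration of that zero-batch shortcut (a hypothetical example, not a test from the PR): an empty batch skips the padding/cropping logic entirely and comes back with the requested spatial size.

```python
import torch

crop_height, crop_width = 224, 224

# Hypothetical empty batch: zero images of size (3, 40, 50).
image = torch.empty(0, 3, 40, 50)

if image.numel() == 0:
    # Same early exit as in the diff: only the last two dims change, so the
    # (0, 3) leading dims are kept and no padding or cropping kernel runs.
    out = image.reshape(image.shape[:-2] + (crop_height, crop_width))

print(out.shape)  # torch.Size([0, 3, 224, 224])
```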

+ image_height, image_width = shape[-2:]

if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
- image = pad_image_tensor(image, padding_ltrb, fill=0)
+ image = _FT.torch_pad(image, _FT._parse_pad_padding(padding_ltrb), value=0.0)
@datumbox (Contributor Author) commented on Nov 1, 2022:
There is no real reason to go through pad_image_tensor() and its input validation, parameter parsing and multiple method calls just to reach PyTorch's pad. We should try to be as explicit as possible in the internal implementations, as this pays off. As we can see below, the performance gains are significant.
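
A rough sketch of that shortcut using the public torch.nn.functional.pad instead of the private _FT helpers (illustrative only; the padding values below are hand-computed for growing a (40, 50) image to (224, 224), and padding_ltrb in the diff is in left/top/right/bottom order before _FT._parse_pad_padding reorders it for torch's pad):

```python
import torch
import torch.nn.functional as F

image = torch.rand(3, 40, 50)

# Hand-computed symmetric padding: 50 + 87 + 87 = 224, 40 + 92 + 92 = 224.
left, top, right, bottom = 87, 92, 87, 92

# F.pad takes pairs starting from the last dim: (left, right) then (top, bottom).
padded = F.pad(image, [left, right, top, bottom], value=0.0)
print(padded.shape)  # torch.Size([3, 224, 224])
```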

Here are V1 vs V2 benchmarks on latest main:

[-------------- CenterCrop cpu torch.float32 -------------]
                       |        old       |        new     
1 threads: ------------------------------------------------
      (3, 40, 50)      |   42 (+-  0) us  |   57 (+-  0) us
      (16, 3, 40, 50)  |  528 (+-  1) us  |  555 (+-  2) us
6 threads: ------------------------------------------------
      (3, 40, 50)      |   88 (+-  1) us  |  104 (+-  3) us
      (16, 3, 40, 50)  |  615 (+-  5) us  |  643 (+- 19) us

Times are in microseconds (us).

[------------- CenterCrop cuda torch.float32 -------------]
                       |        old       |        new     
1 threads: ------------------------------------------------
      (3, 40, 50)      |   37 (+-  0) us  |   48 (+-  0) us
      (16, 3, 40, 50)  |   31 (+-  0) us  |   48 (+-  0) us
6 threads: ------------------------------------------------
      (3, 40, 50)      |   37 (+-  1) us  |   48 (+-  1) us
      (16, 3, 40, 50)  |   31 (+-  0) us  |   48 (+-  1) us

Times are in microseconds (us).

[--------------- CenterCrop cpu torch.uint8 --------------]
                       |        old       |        new     
1 threads: ------------------------------------------------
      (3, 40, 50)      |   68 (+-  0) us  |   77 (+-  0) us
      (16, 3, 40, 50)  |  760 (+-  4) us  |  785 (+-  4) us
6 threads: ------------------------------------------------
      (3, 40, 50)      |  106 (+-  1) us  |  116 (+-  2) us
      (16, 3, 40, 50)  |  837 (+-  6) us  |  861 (+- 18) us

Times are in microseconds (us).

[-------------- CenterCrop cuda torch.uint8 --------------]
                       |        old       |        new     
1 threads: ------------------------------------------------
      (3, 40, 50)      |   37 (+-  0) us  |   48 (+-  0) us
      (16, 3, 40, 50)  |   31 (+-  0) us  |   48 (+-  0) us
6 threads: ------------------------------------------------
      (3, 40, 50)      |   37 (+-  1) us  |   48 (+-  1) us
      (16, 3, 40, 50)  |   31 (+-  0) us  |   48 (+-  1) us

Times are in microseconds (us).

Here is after the PR:

[-------------- CenterCrop cpu torch.float32 -------------]
                       |        old       |        new     
1 threads: ------------------------------------------------
      (3, 40, 50)      |   45 (+-  0) us  |   42 (+-  0) us
      (16, 3, 40, 50)  |  525 (+-  2) us  |  535 (+-  1) us
6 threads: ------------------------------------------------
      (3, 40, 50)      |   88 (+-  3) us  |   88 (+-  1) us
      (16, 3, 40, 50)  |  615 (+- 23) us  |  626 (+- 18) us

Times are in microseconds (us).

[------------- CenterCrop cuda torch.float32 -------------]
                       |        old       |        new     
1 threads: ------------------------------------------------
      (3, 40, 50)      |   37 (+-  0) us  |   35 (+-  0) us
      (16, 3, 40, 50)  |   31 (+-  0) us  |   35 (+-  0) us
6 threads: ------------------------------------------------
      (3, 40, 50)      |   37 (+-  0) us  |   35 (+-  0) us
      (16, 3, 40, 50)  |   31 (+-  0) us  |   35 (+-  0) us

Times are in microseconds (us).

[--------------- CenterCrop cpu torch.uint8 --------------]
                       |        old       |        new     
1 threads: ------------------------------------------------
      (3, 40, 50)      |   68 (+-  0) us  |   65 (+-  0) us
      (16, 3, 40, 50)  |  763 (+-  4) us  |  773 (+-  6) us
6 threads: ------------------------------------------------
      (3, 40, 50)      |  106 (+-  3) us  |  104 (+-  2) us
      (16, 3, 40, 50)  |  839 (+- 22) us  |  851 (+- 23) us

Times are in microseconds (us).

[-------------- CenterCrop cuda torch.uint8 --------------]
                       |        old       |        new     
1 threads: ------------------------------------------------
      (3, 40, 50)      |   37 (+-  0) us  |   34 (+-  0) us
      (16, 3, 40, 50)  |   31 (+-  0) us  |   35 (+-  0) us
6 threads: ------------------------------------------------
      (3, 40, 50)      |   37 (+-  0) us  |   35 (+-  0) us
      (16, 3, 40, 50)  |   31 (+-  0) us  |   35 (+-  1) us

Times are in microseconds (us).

This benchmark forces the padding branch by using input images of size (40, 50) and requesting a (224, 224) crop.
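
For reference, a benchmark along those lines could be sketched roughly as follows (the actual benchmark script is not part of this diff; the prototype functional import path and center_crop signature are assumptions):

```python
import torch
import torch.utils.benchmark as benchmark
from torchvision.prototype.transforms import functional as F

# A (40, 50) input with a (224, 224) crop forces the padding branch.
image = torch.randint(0, 256, (16, 3, 40, 50), dtype=torch.uint8)

timer = benchmark.Timer(
    stmt="F.center_crop(image, [224, 224])",
    globals={"F": F, "image": image},
    num_threads=1,
)
print(timer.blocked_autorange(min_run_time=1))
```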


- image_height, image_width = get_spatial_size_image_tensor(image)
+ image_height, image_width = image.shape[-2:]
if crop_width == image_width and crop_height == image_height:
return image

crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
- return crop_image_tensor(image, crop_top, crop_left, crop_height, crop_width)
+ return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]


@torch.jit.unused
@@ -1332,7 +1331,7 @@ def five_crop_image_tensor(
image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
crop_height, crop_width = _parse_five_crop_size(size)
- image_height, image_width = get_spatial_size_image_tensor(image)
+ image_height, image_width = image.shape[-2:]

if crop_width > image_width or crop_height > image_height:
msg = "Requested crop size {} is bigger than input size {}"