[prototype] Optimize Center Crop performance #6880
@@ -16,12 +16,7 @@
 )
 from torchvision.transforms.functional_tensor import _parse_pad_padding

-from ._meta import (
-    convert_format_bounding_box,
-    get_dimensions_image_tensor,
-    get_spatial_size_image_pil,
-    get_spatial_size_image_tensor,
-)
+from ._meta import convert_format_bounding_box, get_spatial_size_image_pil

 horizontal_flip_image_tensor = _FT.hflip
 horizontal_flip_image_pil = _FP.hflip
@@ -120,9 +115,9 @@ def resize_image_tensor(
     max_size: Optional[int] = None,
     antialias: bool = False,
 ) -> torch.Tensor:
-    num_channels, old_height, old_width = get_dimensions_image_tensor(image)
+    shape = image.shape
+    num_channels, old_height, old_width = shape[-3:]
     new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
-    extra_dims = image.shape[:-3]

     if image.numel() > 0:
         image = image.reshape(-1, num_channels, old_height, old_width)
@@ -134,7 +129,7 @@ def resize_image_tensor(
         antialias=antialias,
     )

-    return image.reshape(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


 @torch.jit.unused
@@ -270,8 +265,8 @@ def affine_image_tensor(
     if image.numel() == 0:
         return image

-    num_channels, height, width = image.shape[-3:]
-    extra_dims = image.shape[:-3]
+    shape = image.shape
+    num_channels, height, width = shape[-3:]
Comment on lines +268 to +269:
Minor nit to align idioms across the code-base. See ref: […] Originally I wanted to do something like […]
     image = image.reshape(-1, num_channels, height, width)

     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
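For context, a minimal sketch of the single-shape-lookup idiom the nit above refers to and the new lines adopt. The tensor below is a made-up example, and this is not the snippet the comment left unfinished:

```python
import torch

# Hypothetical input: a batched image tensor with arbitrary leading dims.
example = torch.rand(2, 5, 3, 32, 32)

# Read the shape once and reuse it instead of calling helper getters.
shape = example.shape
num_channels, height, width = shape[-3:]

# Flatten the leading dims for the kernel call, then restore them afterwards.
flat = example.reshape(-1, num_channels, height, width)
restored = flat.reshape(shape)
assert restored.shape == example.shape
```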
@@ -285,7 +280,7 @@ def affine_image_tensor(
     matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

     output = _FT.affine(image, matrix, interpolation=interpolation.value, fill=fill)
-    return output.reshape(extra_dims + (num_channels, height, width))
+    return output.reshape(shape)


 @torch.jit.unused
@@ -511,8 +506,8 @@ def rotate_image_tensor(
     fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
-    num_channels, height, width = image.shape[-3:]
-    extra_dims = image.shape[:-3]
+    shape = image.shape
+    num_channels, height, width = shape[-3:]

     center_f = [0.0, 0.0]
     if center is not None:
@@ -538,7 +533,7 @@ def rotate_image_tensor(
     else:
         new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)

-    return image.reshape(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


 @torch.jit.unused
@@ -675,8 +670,8 @@ def _pad_with_scalar_fill(
     fill: Union[int, float, None],
     padding_mode: str = "constant",
 ) -> torch.Tensor:
-    num_channels, height, width = image.shape[-3:]
-    extra_dims = image.shape[:-3]
+    shape = image.shape
+    num_channels, height, width = shape[-3:]

     if image.numel() > 0:
         image = _FT.pad(
@@ -688,7 +683,7 @@ def _pad_with_scalar_fill(
     new_height = height + top + bottom
     new_width = width + left + right

-    return image.reshape(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


 # TODO: This should be removed once pytorch pad supports non-scalar padding values
@@ -1130,7 +1125,8 @@ def elastic(

 def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
     if isinstance(output_size, numbers.Number):
-        return [int(output_size), int(output_size)]
+        s = int(output_size)
+        return [s, s]
     elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
         return [output_size[0], output_size[0]]
     else:
@@ -1156,18 +1152,21 @@ def _center_crop_compute_crop_anchor(


 def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    image_height, image_width = get_spatial_size_image_tensor(image)
+    shape = image.shape
+    if image.numel() == 0:
+        return image.reshape(shape[:-2] + (crop_height, crop_width))
Comment on lines +1156 to +1157:
This is needed because the original […]
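The early return above keeps the leading dimensions and swaps the spatial size for the requested crop size, so empty inputs still produce an output of the expected shape. A hedged sketch with made-up shapes:

```python
import torch

# Hypothetical empty batch: zero images, 3 channels, 32x32 spatial size.
empty = torch.empty(0, 3, 32, 32)
crop_height, crop_width = 10, 12

# Mirroring the early-return branch: keep all leading dims, swap in the crop size.
out = empty.reshape(empty.shape[:-2] + (crop_height, crop_width))
assert out.shape == (0, 3, 10, 12)
```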
+    image_height, image_width = shape[-2:]

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-        image = pad_image_tensor(image, padding_ltrb, fill=0)
+        image = _FT.torch_pad(image, _FT._parse_pad_padding(padding_ltrb), value=0.0)
Comment:
There is no real reason to hit […]
Here are V1 vs V2 benchmarks on latest main: […]
Here is after the PR: […]
This benchmark forces the branch of padding by putting images with size […]
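The comment above argues for calling torch's pad operator directly in this branch rather than going through the generic pad kernel. A hedged sketch of that idea using the public torch.nn.functional.pad instead of the torchvision-internal helpers; the shapes and the padding-order conversion are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

image = torch.rand(3, 8, 8)        # hypothetical image smaller than the crop
padding_ltrb = [2, 3, 2, 3]        # torchvision-style (left, top, right, bottom)

# torch.nn.functional.pad expects (left, right, top, bottom) for the last two
# dims, which is roughly what the _parse_pad_padding helper in the diff produces.
left, top, right, bottom = padding_ltrb
torch_padding = [left, right, top, bottom]

padded = F.pad(image, torch_padding, value=0.0)
assert padded.shape == (3, 8 + top + bottom, 8 + left + right)
```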
-        image_height, image_width = get_spatial_size_image_tensor(image)
+        image_height, image_width = image.shape[-2:]
         if crop_width == image_width and crop_height == image_height:
             return image

     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
-    return crop_image_tensor(image, crop_top, crop_left, crop_height, crop_width)
+    return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]


 @torch.jit.unused
@@ -1332,7 +1331,7 @@ def five_crop_image_tensor(
     image: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    image_height, image_width = get_spatial_size_image_tensor(image)
+    image_height, image_width = image.shape[-2:]

     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
Comment:
Aligns the idiom with other parts of the code-base where we directly rely on shape to determine the dimensions (see vision/torchvision/prototype/transforms/functional/_geometry.py, line 539 in 70faba9). At that point the image is a pure tensor and making additional method calls to fetch the sizes is unnecessary.
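To make the cost of that indirection concrete, here is a hedged micro-benchmark sketch; the helper below is a stand-in defined for illustration, not the torchvision helper, and no timings from the PR are reproduced:

```python
import torch
from torch.utils.benchmark import Timer

image = torch.rand(3, 224, 224)

def get_spatial_size(img: torch.Tensor):
    # Stand-in for a helper that reads the size through an extra function call.
    return [img.shape[-2], img.shape[-1]]

t_helper = Timer("get_spatial_size(image)", globals={"get_spatial_size": get_spatial_size, "image": image})
t_attr = Timer("image.shape[-2:]", globals={"image": image})

print(t_helper.timeit(10_000))
print(t_attr.timeit(10_000))
```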