[proto] Optimized functional pad op for bboxes + tests (#6890)
* [proto] Speed-up crop on bboxes and tests

* Fix linter

* Update _geometry.py

* Fixed device issue

* Revert changes in test/prototype_transforms_kernel_infos.py

* Fixed failing correctness tests

* [proto] Optimized functional pad op for bboxes + tests

* Renamed copy-pasted variable name

* Code update

* Fixes according to the review
vfdev-5 authored Nov 3, 2022
1 parent d8cec34 commit 79ca506
Showing 2 changed files with 43 additions and 13 deletions.
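Note (not part of the commit): the core change in the kernel is to stop mutating the box tensor coordinate-by-coordinate and instead add a single 4-element offset tensor in one broadcasted op. A minimal standalone sketch of that equivalence for XYXY boxes, with made-up values:

import torch

# One box in XYXY format; pad by (left, top).
boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
left, top = 3, 5

# New style: one broadcasted add with matching dtype/device.
pad = torch.tensor([left, top, left, top], dtype=boxes.dtype, device=boxes.device)
out_vectorized = boxes + pad

# Old style: four separate in-place indexed additions on a clone.
out_per_coord = boxes.clone()
out_per_coord[..., 0] += left
out_per_coord[..., 1] += top
out_per_coord[..., 2] += left
out_per_coord[..., 3] += top

assert torch.equal(out_vectorized, out_per_coord)  # same result, fewer kernel launches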
36 changes: 35 additions & 1 deletion test/prototype_transforms_kernel_infos.py
@@ -25,7 +25,7 @@
 )
 from torch.utils._pytree import tree_map
 from torchvision.prototype import features
-from torchvision.transforms.functional_tensor import _max_value as get_max_value
+from torchvision.transforms.functional_tensor import _max_value as get_max_value, _parse_pad_padding
 
 __all__ = ["KernelInfo", "KERNEL_INFOS"]
 
@@ -1078,6 +1078,38 @@ def sample_inputs_pad_video():
         yield ArgsKwargs(video_loader, padding=[1])
 
 
+def reference_pad_bounding_box(bounding_box, *, format, spatial_size, padding, padding_mode):
+
+    left, right, top, bottom = _parse_pad_padding(padding)
+
+    affine_matrix = np.array(
+        [
+            [1, 0, left],
+            [0, 1, top],
+        ],
+        dtype="float32",
+    )
+
+    height = spatial_size[0] + top + bottom
+    width = spatial_size[1] + left + right
+
+    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    return expected_bboxes, (height, width)
+
+
+def reference_inputs_pad_bounding_box():
+    for bounding_box_loader, padding in itertools.product(
+        make_bounding_box_loaders(extra_dims=((), (4,))), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
+    ):
+        yield ArgsKwargs(
+            bounding_box_loader,
+            format=bounding_box_loader.format,
+            spatial_size=bounding_box_loader.spatial_size,
+            padding=padding,
+            padding_mode="constant",
+        )
+
+
 KERNEL_INFOS.extend(
     [
         KernelInfo(
@@ -1097,6 +1129,8 @@ def sample_inputs_pad_video():
         KernelInfo(
             F.pad_bounding_box,
             sample_inputs_fn=sample_inputs_pad_bounding_box,
+            reference_fn=reference_pad_bounding_box,
+            reference_inputs_fn=reference_inputs_pad_bounding_box,
             test_marks=[
                 xfail_jit_python_scalar_arg("padding"),
                 xfail_jit_tuple_instead_of_list("padding"),
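Note (not part of the commit): the new reference kernel reuses the affine reference helper with a pure translation matrix, since padding by (left, top) simply shifts every box coordinate, while the padded spatial size grows by the padding on each side. A standalone illustration of what that 2x3 matrix does to a box corner, with made-up values:

import numpy as np

left, top = 2, 4
affine_matrix = np.array(
    [
        [1, 0, left],
        [0, 1, top],
    ],
    dtype="float32",
)

# A corner (x, y) in homogeneous coordinates maps to (x + left, y + top).
corner = np.array([10.0, 20.0, 1.0])
print(affine_matrix @ corner)  # -> [12. 24.]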
20 changes: 8 additions & 12 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -768,14 +768,11 @@ def pad_bounding_box(
 
     left, right, top, bottom = _parse_pad_padding(padding)
 
-    bounding_box = bounding_box.clone()
-
-    # this works without conversion since padding only affects xy coordinates
-    bounding_box[..., 0] += left
-    bounding_box[..., 1] += top
     if format == features.BoundingBoxFormat.XYXY:
-        bounding_box[..., 2] += left
-        bounding_box[..., 3] += top
+        pad = [left, top, left, top]
+    else:
+        pad = [left, top, 0, 0]
+    bounding_box = bounding_box + torch.tensor(pad, dtype=bounding_box.dtype, device=bounding_box.device)
 
     height, width = spatial_size
     height += top + bottom
@@ -821,14 +818,13 @@ def crop_bounding_box(
     width: int,
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
 
-    bounding_box = bounding_box.clone()
-
     # Crop or implicit pad if left and/or top have negative values:
     if format == features.BoundingBoxFormat.XYXY:
-        sub = torch.tensor([left, top, left, top], device=bounding_box.device)
+        sub = [left, top, left, top]
     else:
-        sub = torch.tensor([left, top, 0, 0], device=bounding_box.device)
-    bounding_box = bounding_box.sub_(sub)
+        sub = [left, top, 0, 0]
+
+    bounding_box = bounding_box - torch.tensor(sub, dtype=bounding_box.dtype, device=bounding_box.device)
 
     return bounding_box, (height, width)

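Note (not part of the commit): both kernels now build the per-coordinate offset as a plain Python list and materialize one tensor with the bounding box's dtype and device before applying it out-of-place, which avoids mutating the input and avoids a cpu/cuda mismatch (cf. "Fixed device issue" in the commit history). A sketch of that shared pattern, using a hypothetical helper name that is not part of torchvision:

import torch

def _shift_boxes(boxes: torch.Tensor, offset: list, sign: int = 1) -> torch.Tensor:
    # Offset tensor matches the input's dtype and device; the add/sub is out-of-place.
    shift = torch.tensor(offset, dtype=boxes.dtype, device=boxes.device)
    return boxes + sign * shift

xyxy = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
padded = _shift_boxes(xyxy, [3, 5, 3, 5])        # pad: add (left, top) to both corners
cropped = _shift_boxes(xyxy, [3, 5, 3, 5], -1)   # crop: subtract the crop origin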
