[fbsync] clamp bounding boxes in some geometry kernels (#7215)
Summary: Co-authored-by: vfdev-5 <[email protected]>

Reviewed By: vmoens

Differential Revision: D44416581

fbshipit-source-id: 52a04b420e7364ca3d980ed940b2807f6384eeff
NicolasHug authored and facebook-github-bot committed Mar 28, 2023
1 parent 6309187 commit d8d4539
Showing 8 changed files with 189 additions and 89 deletions.
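For context: "clamping" restricts bounding box coordinates to the image canvas, and the reference kernels in this diff now apply it before casting results back to the input dtype. A minimal standalone sketch of the behavior for XYXY boxes, assuming a (height, width) spatial_size; clamp_xyxy is a hypothetical helper for illustration, not the F.clamp_bounding_box kernel itself:

import torch

def clamp_xyxy(boxes, spatial_size):
    # boxes: (..., 4) in XYXY format; spatial_size: (height, width).
    height, width = spatial_size
    out = boxes.clone()
    out[..., 0::2].clamp_(min=0, max=width)   # x1, x2
    out[..., 1::2].clamp_(min=0, max=height)  # y1, y2
    return out

print(clamp_xyxy(torch.tensor([[-3.0, 5.0, 12.0, 40.0]]), (32, 10)))
# tensor([[ 0.,  5., 10., 32.]])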
126 changes: 93 additions & 33 deletions test/prototype_transforms_kernel_infos.py
@@ -108,7 +108,7 @@ def float32_vs_uint8_pixel_difference(atol=1, mae=False):
 }
 
 
-def scripted_vs_eager_double_pixel_difference(device, atol=1e-6, rtol=1e-6):
+def scripted_vs_eager_float64_tolerances(device, atol=1e-6, rtol=1e-6):
     return {
         (("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False},
     }
@@ -211,10 +211,12 @@ def reference_horizontal_flip_bounding_box(bounding_box, *, format, spatial_size
             [-1, 0, spatial_size[1]],
             [0, 1, 0],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )
 
-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )
 
     return expected_bboxes
 
@@ -322,7 +324,7 @@ def reference_inputs_resize_image_tensor():
 def sample_inputs_resize_bounding_box():
     for bounding_box_loader in make_bounding_box_loaders():
         for size in _get_resize_sizes(bounding_box_loader.spatial_size):
-            yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)
+            yield ArgsKwargs(bounding_box_loader, spatial_size=bounding_box_loader.spatial_size, size=size)
 
 
 def sample_inputs_resize_mask():
@@ -344,19 +346,20 @@ def reference_resize_bounding_box(bounding_box, *, spatial_size, size, max_size=
             [new_width / old_width, 0, 0],
             [0, new_height / old_height, 0],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )
 
     expected_bboxes = reference_affine_bounding_box_helper(
-        bounding_box, format=datapoints.BoundingBoxFormat.XYXY, affine_matrix=affine_matrix
+        bounding_box,
+        format=bounding_box.format,
+        spatial_size=(new_height, new_width),
+        affine_matrix=affine_matrix,
     )
     return expected_bboxes, (new_height, new_width)
 
 
 def reference_inputs_resize_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders(
-        formats=[datapoints.BoundingBoxFormat.XYXY], extra_dims=((), (4,))
-    ):
+    for bounding_box_loader in make_bounding_box_loaders(extra_dims=((), (4,))):
         for size in _get_resize_sizes(bounding_box_loader.spatial_size):
             yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)
 
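The resize reference scales x by new_width/old_width and y by new_height/old_height via the affine matrix above. A hand-computed example of applying that matrix to homogeneous box corners (illustrative numbers, not from the diff):

import numpy as np

# Resizing a 10x10 canvas to 20x30 (height x width): x scales by 3, y by 2.
affine_matrix = np.array([[30 / 10, 0, 0], [0, 20 / 10, 0]], dtype="float64")
corners = np.array([[2.0, 3.0, 1.0], [5.0, 7.0, 1.0]])  # (x1, y1), (x2, y2)
print(corners @ affine_matrix.T)  # [[ 6.  6.] [15. 14.]] -> box (6, 6, 15, 14)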
@@ -543,14 +546,17 @@ def _compute_affine_matrix(angle, translate, scale, shear, center):
     return true_matrix
 
 
-def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix):
-    def transform(bbox, affine_matrix_, format_):
+def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, affine_matrix):
+    def transform(bbox, affine_matrix_, format_, spatial_size_):
         # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
         in_dtype = bbox.dtype
         if not torch.is_floating_point(bbox):
             bbox = bbox.float()
         bbox_xyxy = F.convert_format_bounding_box(
-            bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+            bbox.as_subclass(torch.Tensor),
+            old_format=format_,
+            new_format=datapoints.BoundingBoxFormat.XYXY,
+            inplace=True,
         )
         points = np.array(
             [
@@ -573,12 +579,15 @@ def transform(bbox, affine_matrix_, format_):
         out_bbox = F.convert_format_bounding_box(
             out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
         )
-        return out_bbox.to(dtype=in_dtype)
+        # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
+        out_bbox = F.clamp_bounding_box(out_bbox, format=format_, spatial_size=spatial_size_)
+        out_bbox = out_bbox.to(dtype=in_dtype)
+        return out_bbox
 
     if bounding_box.ndim < 2:
         bounding_box = [bounding_box]
 
-    expected_bboxes = [transform(bbox, affine_matrix, format) for bbox in bounding_box]
+    expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_box]
     if len(expected_bboxes) > 1:
         expected_bboxes = torch.stack(expected_bboxes)
     else:
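Why the clamp must precede the integer cast: for int64 boxes in CXCYWH format, truncating to int before clamping shifts the center and size relative to a kernel that clamps while still in float. A standalone illustration with hand-rolled converters, used here as illustrative stand-ins for F.convert_format_bounding_box and F.clamp_bounding_box:

import torch

def cxcywh_to_xyxy(b):
    cx, cy, w, h = b.unbind(-1)
    return torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], -1)

def xyxy_to_cxcywh(b):
    x1, y1, x2, y2 = b.unbind(-1)
    return torch.stack([(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1], -1)

# A transformed CXCYWH box that pokes past a 10-pixel-wide canvas.
cxcywh = torch.tensor([9.7, 5.0, 3.0, 2.0])  # spans x in [8.2, 11.2]

# Clamp in XYXY space while still float, then cast (matches the kernels):
xyxy = cxcywh_to_xyxy(cxcywh)
xyxy[2].clamp_(max=10)  # clamp x2 to the canvas width
print(xyxy_to_cxcywh(xyxy).to(torch.int64))  # tensor([9, 5, 1, 2])

# Cast to int64 first, clamp after: truncation happens on unclamped values,
# and the resulting center/width come out one pixel off.
xyxy2 = cxcywh_to_xyxy(cxcywh.to(torch.int64).float())
xyxy2[2].clamp_(max=10)
print(xyxy_to_cxcywh(xyxy2).to(torch.int64))  # tensor([8, 5, 2, 2])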
@@ -594,7 +603,9 @@ def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle,
     affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center)
     affine_matrix = affine_matrix[:2, :]
 
-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )
 
     return expected_bboxes
 
@@ -643,9 +654,6 @@ def sample_inputs_affine_video():
         sample_inputs_fn=sample_inputs_affine_bounding_box,
         reference_fn=reference_affine_bounding_box,
         reference_inputs_fn=reference_inputs_affine_bounding_box,
-        closeness_kwargs={
-            (("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
-        },
         test_marks=[
             xfail_jit_python_scalar_arg("shear"),
         ],
@@ -729,10 +737,12 @@ def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size):
             [1, 0, 0],
             [0, -1, spatial_size[0]],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )
 
-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )
 
     return expected_bboxes
 
@@ -806,6 +816,43 @@ def sample_inputs_rotate_bounding_box():
         )
 
 
+def reference_inputs_rotate_bounding_box():
+    for bounding_box_loader, angle in itertools.product(
+        make_bounding_box_loaders(extra_dims=((), (4,))), _ROTATE_ANGLES
+    ):
+        yield ArgsKwargs(
+            bounding_box_loader,
+            format=bounding_box_loader.format,
+            spatial_size=bounding_box_loader.spatial_size,
+            angle=angle,
+        )
+
+    # TODO: add samples with expand=True and center
+
+
+def reference_rotate_bounding_box(bounding_box, *, format, spatial_size, angle, expand=False, center=None):
+
+    if center is None:
+        center = [spatial_size[1] * 0.5, spatial_size[0] * 0.5]
+
+    a = np.cos(angle * np.pi / 180.0)
+    b = np.sin(angle * np.pi / 180.0)
+    cx = center[0]
+    cy = center[1]
+    affine_matrix = np.array(
+        [
+            [a, b, cx - cx * a - b * cy],
+            [-b, a, cy + cx * b - a * cy],
+        ],
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
+    )
+
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )
+    return expected_bboxes, spatial_size
+
+
 def sample_inputs_rotate_mask():
     for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
         yield ArgsKwargs(mask_loader, angle=15.0)
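The closed-form matrix in reference_rotate_bounding_box above is the standard rotate-about-a-point composition T(center) @ R(angle) @ T(-center) collapsed into a single 2x3 matrix. A quick numpy sanity check of that identity (illustrative values, not part of the diff):

import numpy as np

def rotate_about_center(angle_deg, cx, cy):
    a = np.cos(np.radians(angle_deg))
    b = np.sin(np.radians(angle_deg))
    T_fwd = np.array([[1.0, 0.0, cx], [0.0, 1.0, cy], [0.0, 0.0, 1.0]])
    R = np.array([[a, b, 0.0], [-b, a, 0.0], [0.0, 0.0, 1.0]])
    T_back = np.array([[1.0, 0.0, -cx], [0.0, 1.0, -cy], [0.0, 0.0, 1.0]])
    return (T_fwd @ R @ T_back)[:2, :]  # keep the top 2x3 block

a, b, cx, cy = np.cos(np.radians(30)), np.sin(np.radians(30)), 16.0, 12.0
closed_form = np.array(
    [
        [a, b, cx - cx * a - b * cy],
        [-b, a, cy + cx * b - a * cy],
    ]
)
assert np.allclose(rotate_about_center(30, cx, cy), closed_form)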
@@ -834,9 +881,11 @@ def sample_inputs_rotate_video():
     KernelInfo(
         F.rotate_bounding_box,
         sample_inputs_fn=sample_inputs_rotate_bounding_box,
+        reference_fn=reference_rotate_bounding_box,
+        reference_inputs_fn=reference_inputs_rotate_bounding_box,
         closeness_kwargs={
-            **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
-            **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+            **scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
+            **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
         },
     ),
     KernelInfo(
@@ -897,17 +946,19 @@ def sample_inputs_crop_video():
 
 
 def reference_crop_bounding_box(bounding_box, *, format, top, left, height, width):
-
     affine_matrix = np.array(
         [
             [1, 0, -left],
             [0, 1, -top],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )
 
-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
-    return expected_bboxes, (height, width)
+    spatial_size = (height, width)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )
+    return expected_bboxes, spatial_size
 
 
 def reference_inputs_crop_bounding_box():
@@ -1119,13 +1170,15 @@ def reference_pad_bounding_box(bounding_box, *, format, spatial_size, padding, p
             [1, 0, left],
             [0, 1, top],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )
 
     height = spatial_size[0] + top + bottom
     width = spatial_size[1] + left + right
 
-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=(height, width), affine_matrix=affine_matrix
+    )
     return expected_bboxes, (height, width)
 
 
@@ -1225,14 +1278,16 @@ def sample_inputs_perspective_bounding_box():
         yield ArgsKwargs(
             bounding_box_loader,
             format=bounding_box_loader.format,
+            spatial_size=bounding_box_loader.spatial_size,
             startpoints=None,
             endpoints=None,
             coefficients=_PERSPECTIVE_COEFFS[0],
         )
 
     format = datapoints.BoundingBoxFormat.XYXY
+    loader = make_bounding_box_loader(format=format)
     yield ArgsKwargs(
-        make_bounding_box_loader(format=format), format=format, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
+        loader, format=format, spatial_size=loader.spatial_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
     )
 
 
@@ -1269,13 +1324,17 @@ def sample_inputs_perspective_video():
             **pil_reference_pixel_difference(2, mae=True),
             **cuda_vs_cpu_pixel_difference(),
             **float32_vs_uint8_pixel_difference(),
-            **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
-            **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+            **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
+            **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
         },
     ),
     KernelInfo(
         F.perspective_bounding_box,
         sample_inputs_fn=sample_inputs_perspective_bounding_box,
+        closeness_kwargs={
+            **scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
+            **scripted_vs_eager_float64_tolerances("cuda", atol=1e-6, rtol=1e-6),
+        },
     ),
     KernelInfo(
         F.perspective_mask,
@@ -1292,8 +1351,8 @@ def sample_inputs_perspective_video():
         sample_inputs_fn=sample_inputs_perspective_video,
         closeness_kwargs={
             **cuda_vs_cpu_pixel_difference(),
-            **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
-            **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+            **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
+            **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
         },
     ),
 ]
Expand Down Expand Up @@ -1331,6 +1390,7 @@ def sample_inputs_elastic_bounding_box():
yield ArgsKwargs(
bounding_box_loader,
format=bounding_box_loader.format,
spatial_size=bounding_box_loader.spatial_size,
displacement=displacement,
)

2 changes: 1 addition & 1 deletion test/test_prototype_transforms.py
@@ -146,7 +146,7 @@ class TestSmoke:
         (transforms.RandomZoomOut(p=1.0), None),
         (transforms.Resize([16, 16], antialias=True), None),
         (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2)), None),
-        (transforms.ClampBoundingBoxes(), None),
+        (transforms.ClampBoundingBox(), None),
         (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
         (transforms.ConvertDtype(), None),
         (transforms.GaussianBlur(kernel_size=3), None),
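For reference, a hedged sketch of how the renamed transform might be exercised; the import paths and the datapoints.BoundingBox constructor are assumptions about the prototype API at this commit and may differ:

import torch
from torchvision.prototype import datapoints, transforms

# A box poking past a 32x32 canvas; ClampBoundingBox should pull it back in.
box = datapoints.BoundingBox(
    torch.tensor([[-2.0, 4.0, 40.0, 30.0]]),
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(32, 32),
)
print(transforms.ClampBoundingBox()(box))  # expected: [[0., 4., 32., 30.]]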