Skip to content

Commit

Permalink
[fbsync] [proto][tests] Added ref functions for h/v flips (#6876)
Browse files Browse the repository at this point in the history
Summary:
* [proto][tests] Added ref functions for h/v flips

* Better dtype handling in reference_affine_bounding_box_helper

Reviewed By: datumbox

Differential Revision: D41020552

fbshipit-source-id: 24cd18322b1b83d53479c427751476618a6938b3
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Nov 4, 2022
1 parent f68c4f4 commit 95be285
Showing 1 changed file with 66 additions and 13 deletions.
79 changes: 66 additions & 13 deletions test/prototype_transforms_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,29 @@ def sample_inputs_horizontal_flip_video():
yield ArgsKwargs(video_loader)


def reference_horizontal_flip_bounding_box(bounding_box, *, format, spatial_size):
affine_matrix = np.array(
[
[-1, 0, spatial_size[1]],
[0, 1, 0],
],
dtype="float32",
)

expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)

return expected_bboxes


def reference_inputs_flip_bounding_box():
for bounding_box_loader in make_bounding_box_loaders(extra_dims=[()]):
yield ArgsKwargs(
bounding_box_loader,
format=bounding_box_loader.format,
spatial_size=bounding_box_loader.spatial_size,
)


KERNEL_INFOS.extend(
[
KernelInfo(
Expand All @@ -158,6 +181,8 @@ def sample_inputs_horizontal_flip_video():
KernelInfo(
F.horizontal_flip_bounding_box,
sample_inputs_fn=sample_inputs_horizontal_flip_bounding_box,
reference_fn=reference_horizontal_flip_bounding_box,
reference_inputs_fn=reference_inputs_flip_bounding_box,
),
KernelInfo(
F.horizontal_flip_mask,
Expand Down Expand Up @@ -409,15 +434,13 @@ def _compute_affine_matrix(angle, translate, scale, shear, center):
return true_matrix


def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle, translate, scale, shear, center=None):
if center is None:
center = [s * 0.5 for s in spatial_size[::-1]]

def transform(bbox):
affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center)
affine_matrix = affine_matrix[:2, :]

bbox_xyxy = F.convert_format_bounding_box(bbox, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix):
def transform(bbox, affine_matrix_, format_):
# Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
in_dtype = bbox.dtype
bbox_xyxy = F.convert_format_bounding_box(
bbox.float(), old_format=format_, new_format=features.BoundingBoxFormat.XYXY, inplace=True
)
points = np.array(
[
[bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
Expand All @@ -426,22 +449,24 @@ def transform(bbox):
[bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
]
)
transformed_points = np.matmul(points, affine_matrix.T)
transformed_points = np.matmul(points, affine_matrix_.T)
out_bbox = torch.tensor(
[
np.min(transformed_points[:, 0]).item(),
np.min(transformed_points[:, 1]).item(),
np.max(transformed_points[:, 0]).item(),
np.max(transformed_points[:, 1]).item(),
],
dtype=bbox.dtype,
)
return F.convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format)
out_bbox = F.convert_format_bounding_box(
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
)
return out_bbox.to(dtype=in_dtype)

if bounding_box.ndim < 2:
bounding_box = [bounding_box]

expected_bboxes = [transform(bbox) for bbox in bounding_box]
expected_bboxes = [transform(bbox, affine_matrix, format) for bbox in bounding_box]
if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes)
else:
Expand All @@ -450,6 +475,18 @@ def transform(bbox):
return expected_bboxes


def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle, translate, scale, shear, center=None):
if center is None:
center = [s * 0.5 for s in spatial_size[::-1]]

affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center)
affine_matrix = affine_matrix[:2, :]

expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)

return expected_bboxes


def reference_inputs_affine_bounding_box():
for bounding_box_loader, affine_kwargs in itertools.product(
make_bounding_box_loaders(extra_dims=[()]),
Expand Down Expand Up @@ -643,6 +680,20 @@ def sample_inputs_vertical_flip_video():
yield ArgsKwargs(video_loader)


def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size):
affine_matrix = np.array(
[
[1, 0, 0],
[0, -1, spatial_size[0]],
],
dtype="float32",
)

expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)

return expected_bboxes


KERNEL_INFOS.extend(
[
KernelInfo(
Expand All @@ -656,6 +707,8 @@ def sample_inputs_vertical_flip_video():
KernelInfo(
F.vertical_flip_bounding_box,
sample_inputs_fn=sample_inputs_vertical_flip_bounding_box,
reference_fn=reference_vertical_flip_bounding_box,
reference_inputs_fn=reference_inputs_flip_bounding_box,
),
KernelInfo(
F.vertical_flip_mask,
Expand Down

0 comments on commit 95be285

Please sign in to comment.