[fbsync] add proper smoke test for prototype transforms (#7238)
Reviewed By: vmoens

Differential Revision: D44416579

fbshipit-source-id: 9c3b0da79fe1270c13b6f705c5894ddd7783911f
NicolasHug authored and facebook-github-bot committed Mar 28, 2023
1 parent ac6942e commit c629504
Showing 1 changed file with 192 additions and 42 deletions.
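For orientation, the core pattern of the new smoke test is: build a heterogeneous sample (images, videos, bounding boxes, masks, plain Python objects), flatten it with torch.utils._pytree, run the transform once, and check that leaves the transform does not handle come back untouched while transformable leaves keep their type. A minimal sketch of that pattern with a stand-in transform (toy_transform and the sample dict below are illustrative stand-ins, not code from this commit):

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Stand-in "transform": only touches floating-point tensors, passes everything else through unchanged.
def toy_transform(item):
    if isinstance(item, torch.Tensor) and item.is_floating_point():
        return item + 1
    return item

# Heterogeneous sample, loosely mirroring the much larger dict built in test_common below.
sample = {"image": torch.rand(3, 16, 16), "label": 3, "path": "a/b.jpg"}

flat, spec = tree_flatten(sample)
output = tree_unflatten([toy_transform(item) for item in flat], spec)

assert output["path"] is sample["path"]  # non-transformable leaves are passed through as-is
assert not torch.equal(output["image"], sample["image"])  # the image leaf was transformed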
234 changes: 192 additions & 42 deletions test/test_prototype_transforms.py
@@ -1,4 +1,5 @@
import itertools
import pathlib
import re
import warnings
from collections import defaultdict
@@ -20,15 +21,16 @@
make_image,
make_images,
make_label,
make_masks,
make_one_hot_labels,
make_segmentation_mask,
make_video,
make_videos,
)
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.ops.boxes import box_iou
from torchvision.prototype import datapoints, transforms
-from torchvision.prototype.transforms.utils import check_type, is_simple_tensor
+from torchvision.prototype.transforms import functional as F
+from torchvision.prototype.transforms.utils import check_type, is_simple_tensor, query_chw
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -66,53 +68,201 @@ def parametrize(transforms_with_inputs):
)


-def parametrize_from_transforms(*transforms):
-    transforms_with_inputs = []
-    for transform in transforms:
-        for creation_fn in [
-            make_images,
-            make_bounding_boxes,
-            make_one_hot_labels,
-            make_vanilla_tensor_images,
-            make_pil_images,
-            make_masks,
-            make_videos,
-        ]:
-            inputs = list(creation_fn())
-            try:
-                output = transform(inputs[0])
-            except Exception:
-                continue
-            else:
-                if output is inputs[0]:
-                    continue
-
-            transforms_with_inputs.append((transform, inputs))
-
-    return parametrize(transforms_with_inputs)
+def auto_augment_adapter(transform, input, device):
+    adapted_input = {}
+    image_or_video_found = False
+    for key, value in input.items():
+        if isinstance(value, (datapoints.BoundingBox, datapoints.Mask)):
+            # AA transforms don't support bounding boxes or masks
+            continue
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)):
+            if image_or_video_found:
+                # AA transforms only support a single image or video
+                continue
+            image_or_video_found = True
+        adapted_input[key] = value
+    return adapted_input
+
+
+def linear_transformation_adapter(transform, input, device):
+    flat_inputs = list(input.values())
+    c, h, w = query_chw(
+        [
+            item
+            for item, needs_transform in zip(flat_inputs, transforms.Transform()._needs_transform_list(flat_inputs))
+            if needs_transform
+        ]
+    )
+    num_elements = c * h * w
+    transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device)
+    transform.mean_vector = torch.randn((num_elements,), device=device)
+    return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}
+
+
def normalize_adapter(transform, input, device):
adapted_input = {}
for key, value in input.items():
if isinstance(value, PIL.Image.Image):
# normalize doesn't support PIL images
continue
elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)):
# normalize doesn't support integer images
value = F.convert_dtype(value, torch.float32)
adapted_input[key] = value
return adapted_input


class TestSmoke:
-    @parametrize_from_transforms(
-        transforms.RandomErasing(p=1.0),
-        transforms.Resize([16, 16], antialias=True),
-        transforms.CenterCrop([16, 16]),
-        transforms.ConvertDtype(),
-        transforms.RandomHorizontalFlip(),
-        transforms.Pad(5),
-        transforms.RandomZoomOut(),
-        transforms.RandomRotation(degrees=(-45, 45)),
-        transforms.RandomAffine(degrees=(-45, 45)),
-        transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True),
-        # TODO: Something wrong with input data setup. Let's fix that
-        # transforms.RandomEqualize(),
-        # transforms.RandomInvert(),
-        # transforms.RandomPosterize(bits=4),
-        # transforms.RandomSolarize(threshold=0.5),
-        # transforms.RandomAdjustSharpness(sharpness_factor=0.5),
-    )
-    def test_common(self, transform, input):
-        transform(input)
@pytest.mark.parametrize(
("transform", "adapter"),
[
(transforms.RandomErasing(p=1.0), None),
(transforms.AugMix(), auto_augment_adapter),
(transforms.AutoAugment(), auto_augment_adapter),
(transforms.RandAugment(), auto_augment_adapter),
(transforms.TrivialAugmentWide(), auto_augment_adapter),
(transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
(transforms.Grayscale(), None),
(transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
(transforms.RandomAutocontrast(p=1.0), None),
(transforms.RandomEqualize(p=1.0), None),
(transforms.RandomGrayscale(p=1.0), None),
(transforms.RandomInvert(p=1.0), None),
(transforms.RandomPhotometricDistort(p=1.0), None),
(transforms.RandomPosterize(bits=4, p=1.0), None),
(transforms.RandomSolarize(threshold=0.5, p=1.0), None),
(transforms.CenterCrop([16, 16]), None),
(transforms.ElasticTransform(sigma=1.0), None),
(transforms.Pad(4), None),
(transforms.RandomAffine(degrees=30.0), None),
(transforms.RandomCrop([16, 16], pad_if_needed=True), None),
(transforms.RandomHorizontalFlip(p=1.0), None),
(transforms.RandomPerspective(p=1.0), None),
(transforms.RandomResize(min_size=10, max_size=20), None),
(transforms.RandomResizedCrop([16, 16]), None),
(transforms.RandomRotation(degrees=30), None),
(transforms.RandomShortestSize(min_size=10), None),
(transforms.RandomVerticalFlip(p=1.0), None),
(transforms.RandomZoomOut(p=1.0), None),
(transforms.Resize([16, 16], antialias=True), None),
(transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2)), None),
(transforms.ClampBoundingBoxes(), None),
(transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
(transforms.ConvertDtype(), None),
(transforms.GaussianBlur(kernel_size=3), None),
(
transforms.LinearTransformation(
                # These are just dummy values that will be filled by the adapter. We can't define them upfront,
                # because we know neither the spatial size nor the device at this point.
transformation_matrix=torch.empty((1, 1)),
mean_vector=torch.empty((1,)),
),
linear_transformation_adapter,
),
(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), normalize_adapter),
(transforms.ToDtype(torch.float64), None),
(transforms.UniformTemporalSubsample(num_samples=2), None),
],
ids=lambda transform: type(transform).__name__,
)
@pytest.mark.parametrize("container_type", [dict, list, tuple])
@pytest.mark.parametrize(
"image_or_video",
[
make_image(),
make_video(),
next(make_pil_images(color_spaces=["RGB"])),
next(make_vanilla_tensor_images()),
],
)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_common(self, transform, adapter, container_type, image_or_video, device):
spatial_size = F.get_spatial_size(image_or_video)
input = dict(
image_or_video=image_or_video,
image_datapoint=make_image(size=spatial_size),
video_datapoint=make_video(size=spatial_size),
image_pil=next(make_pil_images(sizes=[spatial_size], color_spaces=["RGB"])),
bounding_box_xyxy=make_bounding_box(
format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(3,)
),
bounding_box_xywh=make_bounding_box(
format=datapoints.BoundingBoxFormat.XYWH, spatial_size=spatial_size, extra_dims=(4,)
),
bounding_box_cxcywh=make_bounding_box(
format=datapoints.BoundingBoxFormat.CXCYWH, spatial_size=spatial_size, extra_dims=(5,)
),
bounding_box_degenerate_xyxy=datapoints.BoundingBox(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[2, 0, 1, 1], # x1 > x2, y1 < y2
[0, 2, 1, 1], # x1 < x2, y1 > y2
[2, 2, 1, 1], # x1 > x2, y1 > y2
],
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=spatial_size,
),
bounding_box_degenerate_xywh=datapoints.BoundingBox(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[0, 0, 1, -1], # negative height
[0, 0, -1, 1], # negative width
[0, 0, -1, -1], # negative height and width
],
format=datapoints.BoundingBoxFormat.XYWH,
spatial_size=spatial_size,
),
bounding_box_degenerate_cxcywh=datapoints.BoundingBox(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[0, 0, 1, -1], # negative height
[0, 0, -1, 1], # negative width
[0, 0, -1, -1], # negative height and width
],
format=datapoints.BoundingBoxFormat.CXCYWH,
spatial_size=spatial_size,
),
detection_mask=make_detection_mask(size=spatial_size),
segmentation_mask=make_segmentation_mask(size=spatial_size),
int=0,
float=0.0,
bool=True,
none=None,
str="str",
path=pathlib.Path.cwd(),
object=object(),
tensor=torch.empty(5),
array=np.empty(5),
)
if adapter is not None:
input = adapter(transform, input, device)

if container_type in {tuple, list}:
input = container_type(input.values())

input_flat, input_spec = tree_flatten(input)
input_flat = [item.to(device) if isinstance(item, torch.Tensor) else item for item in input_flat]
input = tree_unflatten(input_flat, input_spec)

torch.manual_seed(0)
output = transform(input)
output_flat, output_spec = tree_flatten(output)

assert output_spec == input_spec

for output_item, input_item, should_be_transformed in zip(
output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat)
):
if should_be_transformed:
assert type(output_item) is type(input_item)
else:
assert output_item is input_item

@parametrize(
[

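A note on the adapter hooks above: each adapter is a plain callable with the signature (transform, input, device) that returns an adapted input dict, dropping or converting entries the given transform cannot handle before test_common runs it. A minimal sketch of that contract (drop_strings_adapter is a hypothetical example for illustration, not one of the adapters in this commit):

# Hypothetical adapter, illustration only: drop string entries from the sample before the
# transform sees them; transform and device are accepted but unused, as in some adapters above.
def drop_strings_adapter(transform, input, device):
    return {key: value for key, value in input.items() if not isinstance(value, str)}

adapted = drop_strings_adapter(None, {"image_id": 7, "file_name": "img.jpg"}, "cpu")
assert adapted == {"image_id": 7}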