Add tests for transform presets, and various fixes #7223

Merged
merged 7 commits on Feb 13, 2023
Changes from 3 commits
156 changes: 156 additions & 0 deletions test/test_prototype_transforms.py
@@ -1,5 +1,6 @@
import itertools
import re
from collections import defaultdict

import numpy as np

@@ -1988,3 +1989,158 @@ def test__transform(self, inpt):
assert type(output) is type(inpt)
assert output.shape[-4] == num_samples
assert output.dtype == inpt.dtype


@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, int))
@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
Member Author: This test or the one below may not be your preferred test style. I'm happy to come back to this when we have more time, but considering the very pressing timeline with the upcoming release, I kindly request that we focus on the substance (i.e. what the tests are testing + how we fix them + correctness) rather than style / perf / optimizations.


image = datapoints.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
if image_type is PIL.Image:
image = to_pil_image(image[0])
elif image_type is torch.Tensor:
image = image.as_subclass(torch.Tensor)
assert is_simple_tensor(image)

label = 1 if label_type is int else torch.tensor([1])

if dataset_return_type is dict:
sample = {
"image": image,
"label": label,
}
else:
sample = image, label

t = transforms.Compose(
[
transforms.RandomResizedCrop((224, 224)),
transforms.RandomHorizontalFlip(p=1),
transforms.RandAugment(),
transforms.TrivialAugmentWide(),
transforms.AugMix(),
transforms.AutoAugment(),
to_tensor(),
# TODO: ConvertImageDtype is a pass-through on PIL images, is that
# intended? This results in a failure if we convert to tensor after
# it, because the image would still be uint8, which makes Normalize
# fail.
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
transforms.RandomErasing(p=1),
]
)

out = t(sample)
Member Author: The asserts below are light; the actual purpose of this test is to make sure that this call doesn't fail. On main, it currently fails for a bunch of reasons.


assert type(out) == type(sample)

if dataset_return_type is tuple:
out_image, out_label = out
else:
assert out.keys() == sample.keys()
out_image, out_label = out.values()

assert out_image.shape[-2:] == (224, 224)
assert out_label == label


@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, list))
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
if data_augmentation == "hflip":
t = [
transforms.RandomHorizontalFlip(p=1),
to_tensor(),
transforms.ConvertImageDtype(torch.float),
]
elif data_augmentation == "lsj":
t = [
transforms.ScaleJitter(target_size=(1024, 1024), antialias=True),
# Note: replaced FixedSizeCrop with RandomCrop, because we're
# leaving FixedSizeCrop in prototype for now, and it expects Label
# classes which we won't release yet.
# transforms.FixedSizeCrop(
# size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})
# ),
# TODO: I have to set a very low size for the crop (30, 30),
# otherwise we'd get an error saying the crop is larger than the
# image. This means RandomCrop doesn't do the same thing as
# FixedSizeCrop and we need to figure out the key differences
transforms.RandomCrop((30, 30)),
Collaborator: RandomCrop doesn't do any padding by default, but FixedSizeCrop does. You probably only need to set pad_if_needed=True.
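A change along these lines might address the suggestion above (an untested sketch; it reuses the imports already present in this test file and assumes the prototype RandomCrop accepts the same per-type fill mapping used for RandomZoomOut further down):

    # Keep the original (1024, 1024) target and let RandomCrop pad images that are
    # smaller than the crop instead of raising, mimicking FixedSizeCrop's padding behaviour.
    transforms.RandomCrop(
        (1024, 1024),
        pad_if_needed=True,
        fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0}),
    ),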

transforms.RandomHorizontalFlip(p=1),
to_tensor(),
transforms.ConvertImageDtype(torch.float),
]
elif data_augmentation == "multiscale":
t = [
transforms.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333, antialias=True
),
transforms.RandomHorizontalFlip(p=1),
to_tensor(),
transforms.ConvertImageDtype(torch.float),
]
elif data_augmentation == "ssd":
t = [
transforms.RandomPhotometricDistort(p=1),
transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})),
# TODO: put back IoUCrop once we remove its hard requirement for Labels
# transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(p=1),
to_tensor(),
transforms.ConvertImageDtype(torch.float),
]
elif data_augmentation == "ssdlite":
t = [
# TODO: put back IoUCrop once we remove its hard requirement for Labels
# transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(p=1),
to_tensor(),
transforms.ConvertImageDtype(torch.float),
]
t = transforms.Compose(t)

num_boxes = 5
H = W = 250

image = datapoints.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
if image_type is PIL.Image:
image = to_pil_image(image[0])
elif image_type is torch.Tensor:
image = image.as_subclass(torch.Tensor)
assert is_simple_tensor(image)

label = torch.randint(0, 10, size=(num_boxes,))
if label_type is list:
label = label.tolist()

# TODO: is the shape of the boxes OK? Should it be (1, num_boxes, 4)?? Same for masks
boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
boxes[:, 2:] += boxes[:, :2]
boxes = boxes.clamp(min=0, max=min(H, W))
boxes = datapoints.BoundingBox(boxes, format="XYXY", spatial_size=(H, W))

masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))

sample = {
"image": image,
"label": label,
"boxes": boxes,
"masks": masks,
}

out = t(sample)

if to_tensor is transforms.ToTensor and image_type is not datapoints.Image:
assert is_simple_tensor(out["image"])
else:
assert isinstance(out["image"], datapoints.Image)
assert isinstance(out["label"], type(sample["label"]))

out["label"] = torch.tensor(out["label"])
assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes
5 changes: 3 additions & 2 deletions torchvision/prototype/transforms/_auto_augment.py
@@ -37,10 +37,11 @@ def _flatten_and_extract_image_or_video(
unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask),
) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints.ImageType, datapoints.VideoType]]:
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
needs_transform_list = self._needs_transform_list(flat_inputs)

image_or_videos = []
for idx, inpt in enumerate(flat_inputs):
if check_type(
for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
if needs_transform and check_type(
Comment on lines +43 to +44
Member Author: It is concerning that every single overload of forward() must be aware of the heuristic somehow. It is probably going to be a problem for user-defined transforms, unless we don't want them to override forward() - which would be unfair, since we do it ourselves.
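To make the concern concrete, a user-defined transform that overrides forward() would have to reproduce the same dance itself, roughly like this (an illustrative sketch only, not part of this PR; it assumes the same helpers the base Transform uses, i.e. tree_flatten/tree_unflatten, _needs_transform_list, _get_params and _transform):

    class MyCustomTransform(Transform):
        def forward(self, *inputs: Any) -> Any:
            flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
            # If this call is skipped, simple tensors get transformed even when the sample
            # also contains an explicit Image or Video, silently breaking the heuristic.
            needs_transform_list = self._needs_transform_list(flat_inputs)
            params = self._get_params(
                [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
            )
            flat_outputs = [
                self._transform(inpt, params) if needs_transform else inpt
                for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
            ]
            return tree_unflatten(flat_outputs, spec)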

inpt,
(
datapoints.Image,
3 changes: 2 additions & 1 deletion torchvision/prototype/transforms/_color.py
@@ -169,7 +169,8 @@ def _permute_channels(
if isinstance(orig_inpt, PIL.Image.Image):
inpt = F.pil_to_tensor(inpt)

output = inpt[..., permutation, :, :]
# TODO: Find a better fix than as_subclass???
output = inpt[..., permutation, :, :].as_subclass(type(inpt))
Member Author (@NicolasHug, Feb 12, 2023): I am concerned that this slipped through the cracks: passing an Image would result in a pure Tensor.
@pmeier don't we have tests that make sure the types are preserved? Are we ensuring in the tests that all the random transforms are actually tested (i.e. set p=1)?

Collaborator: Regarding "don't we have tests that make sure the types are preserved?": we do for the functional part, but this is a transform. We need to be extra careful for any kind of functionality that is implemented directly on the transform rather than in a dispatcher.
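A check along the lines discussed here could look roughly like this (a hedged sketch, not part of this PR; picking RandomPhotometricDistort with p=1 is just one way to force the channel-permutation branch above to actually run):

    def test_random_photometric_distort_preserves_image_type():
        image = datapoints.Image(torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8))
        # p=1 so every sub-transform, including the channel permutation, is applied
        output = transforms.RandomPhotometricDistort(p=1)(image)
        assert isinstance(output, datapoints.Image)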


if isinstance(orig_inpt, PIL.Image.Image):
output = F.to_image_pil(output)
30 changes: 22 additions & 8 deletions torchvision/prototype/transforms/_transform.py
@@ -36,8 +36,19 @@ def forward(self, *inputs: Any) -> Any:

self._check_inputs(flat_inputs)

params = self._get_params(flat_inputs)
needs_transform_list = self._needs_transform_list(flat_inputs)
Member Author: I extracted the heuristic out into a method, because it is needed in other places (basically, all forward()s of transform subclasses).

params = self._get_params(
[inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
)

flat_outputs = [
self._transform(inpt, params) if needs_transform else inpt
for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
]

return tree_unflatten(flat_outputs, spec)

def _needs_transform_list(self, flat_inputs):
# Below is a heuristic on how to deal with simple tensor inputs:
# 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
# (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
@@ -53,7 +64,8 @@ def forward(self, *inputs: Any) -> Any:
# The heuristic should work well for most people in practice. The only case where it doesn't is if someone
# tries to transform multiple simple tensors at the same time, expecting them all to be treated as images.
# However, this case wasn't supported by transforms v1 either, so there is no BC concern.
flat_outputs = []

needs_transform_list = []
transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
for inpt in flat_inputs:
needs_transform = True
@@ -65,10 +77,8 @@ def forward(self, *inputs: Any) -> Any:
transform_simple_tensor = False
else:
needs_transform = False

flat_outputs.append(self._transform(inpt, params) if needs_transform else inpt)

return tree_unflatten(flat_outputs, spec)
needs_transform_list.append(needs_transform)
return needs_transform_list
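For illustration, the user-visible effect of the heuristic documented above is roughly the following (a hedged sketch, not part of the diff):

    # A sample that contains an explicit Image: the plain "aux" tensor is passed
    # through untouched, only the Image is flipped.
    sample = {"image": datapoints.Image(torch.rand(3, 8, 8)), "aux": torch.rand(3, 8, 8)}
    out = transforms.RandomHorizontalFlip(p=1)(sample)

    # A sample with only a plain tensor: it is treated as the image and transformed.
    out = transforms.RandomHorizontalFlip(p=1)(torch.rand(3, 8, 8))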

def extra_repr(self) -> str:
extra = []
@@ -159,10 +169,14 @@ def forward(self, *inputs: Any) -> Any:
if torch.rand(1) >= self.p:
return inputs

params = self._get_params(flat_inputs)
needs_transform_list = self._needs_transform_list(flat_inputs)
Member Author: We forgot to update this forward() as well...

params = self._get_params(
inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform
)

flat_outputs = [
self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs
self._transform(inpt, params) if needs_transform else inpt
for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
]

return tree_unflatten(flat_outputs, spec)