# Add tests for transform presets, and various fixes #7223
```diff
@@ -1,5 +1,6 @@
 import itertools
 import re
+from collections import defaultdict

 import numpy as np
```
```diff
@@ -1988,3 +1989,158 @@ def test__transform(self, inpt):
         assert type(output) is type(inpt)
         assert output.shape[-4] == num_samples
         assert output.dtype == inpt.dtype
+
+
+@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
+@pytest.mark.parametrize("label_type", (torch.Tensor, int))
+@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
+@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
+def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
+
+    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
+    if image_type is PIL.Image:
+        image = to_pil_image(image[0])
+    elif image_type is torch.Tensor:
+        image = image.as_subclass(torch.Tensor)
+        assert is_simple_tensor(image)
+
+    label = 1 if label_type is int else torch.tensor([1])
+
+    if dataset_return_type is dict:
+        sample = {
+            "image": image,
+            "label": label,
+        }
+    else:
+        sample = image, label
+
+    t = transforms.Compose(
+        [
+            transforms.RandomResizedCrop((224, 224)),
+            transforms.RandomHorizontalFlip(p=1),
+            transforms.RandAugment(),
+            transforms.TrivialAugmentWide(),
+            transforms.AugMix(),
+            transforms.AutoAugment(),
+            to_tensor(),
+            # TODO: ConvertImageDtype is a pass-through on PIL images, is that
+            # intended? This results in a failure if we convert to tensor after
+            # it, because the image would still be uint8, which makes Normalize
+            # fail.
+            transforms.ConvertImageDtype(torch.float),
+            transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
+            transforms.RandomErasing(p=1),
+        ]
+    )
+
+    out = t(sample)
+
+    assert type(out) == type(sample)
+
+    if dataset_return_type is tuple:
+        out_image, out_label = out
+    else:
+        assert out.keys() == sample.keys()
+        out_image, out_label = out.values()
+
+    assert out_image.shape[-2:] == (224, 224)
+    assert out_label == label
+
+
+@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
+@pytest.mark.parametrize("label_type", (torch.Tensor, list))
+@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
+@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
+def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
+    if data_augmentation == "hflip":
+        t = [
+            transforms.RandomHorizontalFlip(p=1),
+            to_tensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
+    elif data_augmentation == "lsj":
+        t = [
+            transforms.ScaleJitter(target_size=(1024, 1024), antialias=True),
+            # Note: replaced FixedSizeCrop with RandomCrop, because we're
+            # leaving FixedSizeCrop in prototype for now, and it expects Label
+            # classes which we won't release yet.
+            # transforms.FixedSizeCrop(
+            #     size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})
+            # ),
+            # TODO: I have to set a very low size for the crop (30, 30),
+            # otherwise we'd get an error saying the crop is larger than the
+            # image. This means RandomCrop doesn't do the same thing as
+            # FixedSizeCrop, and we need to figure out the key differences.
+            transforms.RandomCrop((30, 30)),
+            transforms.RandomHorizontalFlip(p=1),
+            to_tensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
+    elif data_augmentation == "multiscale":
+        t = [
+            transforms.RandomShortestSize(
+                min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333, antialias=True
+            ),
+            transforms.RandomHorizontalFlip(p=1),
+            to_tensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
+    elif data_augmentation == "ssd":
+        t = [
+            transforms.RandomPhotometricDistort(p=1),
+            transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})),
+            # TODO: put back IoUCrop once we remove its hard requirement for Labels
+            # transforms.RandomIoUCrop(),
+            transforms.RandomHorizontalFlip(p=1),
+            to_tensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
+    elif data_augmentation == "ssdlite":
+        t = [
+            # TODO: put back IoUCrop once we remove its hard requirement for Labels
+            # transforms.RandomIoUCrop(),
+            transforms.RandomHorizontalFlip(p=1),
+            to_tensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
+    t = transforms.Compose(t)
+
+    num_boxes = 5
+    H = W = 250
+
+    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
+    if image_type is PIL.Image:
+        image = to_pil_image(image[0])
+    elif image_type is torch.Tensor:
+        image = image.as_subclass(torch.Tensor)
+        assert is_simple_tensor(image)
+
+    label = torch.randint(0, 10, size=(num_boxes,))
+    if label_type is list:
+        label = label.tolist()
+
+    # TODO: is the shape of the boxes OK? Should it be (1, num_boxes, 4)?? Same for masks
+    boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
+    boxes[:, 2:] += boxes[:, :2]
+    boxes = boxes.clamp(min=0, max=min(H, W))
+    boxes = datapoints.BoundingBox(boxes, format="XYXY", spatial_size=(H, W))
+
+    masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))
+
+    sample = {
+        "image": image,
+        "label": label,
+        "boxes": boxes,
+        "masks": masks,
+    }
+
+    out = t(sample)
+
+    if to_tensor is transforms.ToTensor and image_type is not datapoints.Image:
+        assert is_simple_tensor(out["image"])
+    else:
+        assert isinstance(out["image"], datapoints.Image)
+    assert isinstance(out["label"], type(sample["label"]))
+
+    out["label"] = torch.tensor(out["label"])
+    assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes
```
```diff
@@ -37,10 +37,11 @@ def _flatten_and_extract_image_or_video(
         unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask),
     ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints.ImageType, datapoints.VideoType]]:
         flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
+        needs_transform_list = self._needs_transform_list(flat_inputs)

         image_or_videos = []
-        for idx, inpt in enumerate(flat_inputs):
-            if check_type(
+        for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
+            if needs_transform and check_type(
                 inpt,
                 (
                     datapoints.Image,
```

> **Comment on lines +43 to +44:** It is concerning that every single overload of …
```diff
@@ -169,7 +169,8 @@ def _permute_channels(
         if isinstance(orig_inpt, PIL.Image.Image):
             inpt = F.pil_to_tensor(inpt)

-        output = inpt[..., permutation, :, :]
+        # TODO: Find a better fix than as_subclass???
+        output = inpt[..., permutation, :, :].as_subclass(type(inpt))

         if isinstance(orig_inpt, PIL.Image.Image):
             output = F.to_image_pil(output)
```

> I am concerned that this slipped through the cracks: passing an Image would result in a pure Tensor.

> We do for the functional part, but this is a transform. We need to be extra careful for any kind of functionality that is implemented directly on the transform rather than in a dispatcher.
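The concern above is about subclass preservation: torchvision's datapoints intentionally unwrap to plain tensors on most tensor ops, so indexing drops the type and `as_subclass` restores it. A sketch of that behavior, using the current `torchvision.tv_tensors` names (the `datapoints` of this PR were later renamed to `tv_tensors`):

```python
import torch
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(3, 4, 4))
permutation = [2, 0, 1]

out = img[..., permutation, :, :]
print(type(out))   # plain torch.Tensor: indexing loses the subclass

out = img[..., permutation, :, :].as_subclass(type(img))
print(type(out))   # tv_tensors.Image again
```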
```diff
@@ -36,8 +36,19 @@ def forward(self, *inputs: Any) -> Any:

         self._check_inputs(flat_inputs)

-        params = self._get_params(flat_inputs)
+        needs_transform_list = self._needs_transform_list(flat_inputs)
+
+        params = self._get_params(
+            [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
+        )
+
+        flat_outputs = [
+            self._transform(inpt, params) if needs_transform else inpt
+            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
+        ]
+
+        return tree_unflatten(flat_outputs, spec)
+
+    def _needs_transform_list(self, flat_inputs):
         # Below is a heuristic on how to deal with simple tensor inputs:
         # 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
         #    (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
```

> I extracted out the heuristic in a method, because it is needed in other places (basically, all …

```diff
@@ -53,7 +64,8 @@ def forward(self, *inputs: Any) -> Any:
         # The heuristic should work well for most people in practice. The only case where it doesn't is if someone
         # tries to transform multiple simple tensors at the same time, expecting them all to be treated as images.
         # However, this case wasn't supported by transforms v1 either, so there is no BC concern.
-        flat_outputs = []
+
+        needs_transform_list = []
         transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
         for inpt in flat_inputs:
             needs_transform = True
```

```diff
@@ -65,10 +77,8 @@ def forward(self, *inputs: Any) -> Any:
                     transform_simple_tensor = False
                 else:
                     needs_transform = False
-
-            flat_outputs.append(self._transform(inpt, params) if needs_transform else inpt)
-
-        return tree_unflatten(flat_outputs, spec)
+            needs_transform_list.append(needs_transform)
+        return needs_transform_list

     def extra_repr(self) -> str:
         extra = []
```

```diff
@@ -159,10 +169,14 @@ def forward(self, *inputs: Any) -> Any:
         if torch.rand(1) >= self.p:
             return inputs

-        params = self._get_params(flat_inputs)
+        needs_transform_list = self._needs_transform_list(flat_inputs)
+
+        params = self._get_params(
+            inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform
+        )

         flat_outputs = [
-            self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs
+            self._transform(inpt, params) if needs_transform else inpt
+            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
         ]

         return tree_unflatten(flat_outputs, spec)
```

> We forgot to update this.
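A sketch of the heuristic's observable behavior, again using the current namespaces (`torchvision.transforms.v2` / `tv_tensors`) as stand-ins for the prototype modules in this PR: a simple tensor in the same sample as an explicit image is passed through untouched.

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2 as transforms

img = tv_tensors.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
extra = torch.arange(10.0).reshape(2, 5)  # a "simple" tensor, not a datapoint

flip = transforms.RandomHorizontalFlip(p=1)
out_img, out_extra = flip(img, extra)

assert isinstance(out_img, tv_tensors.Image)  # transformed as an image
assert torch.equal(out_extra, extra)          # passed through untouched
```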
> This test or the one below may not be your preferred test style. I'm happy to come back to this when we have more time, but considering the very pressing timeline with the upcoming release, I kindly request that we focus on the substance (i.e. what the tests are testing + how we fix them + correctness) rather than style / perf / optimizations.