-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add tests for transform presets, and various fixes #7223
Conversation
for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)): | ||
if needs_transform and check_type( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is concerning that every single overload of forward()
must be aware of the heuristic somehow. It is probably going to be a problem for user-defined transforms, unless we don't want them to override forward()
- which would be unfair, since we do it ourselves.
] | ||
) | ||
|
||
out = t(sample) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The asserts
below are light, the actual purpose of this test is to make sure that this call doesn't fail.
On main
, it fails for a bunch of reasons right now
@@ -36,8 +36,19 @@ def forward(self, *inputs: Any) -> Any: | |||
|
|||
self._check_inputs(flat_inputs) | |||
|
|||
params = self._get_params(flat_inputs) | |||
needs_transform_list = self._needs_transform_list(flat_inputs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I extracted-out the heuristic in a method, because it is needed in other places (basically, all forward()s
of transform subclasses).
params = self._get_params( | ||
inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note how we're only passing the relevant inputs to get_params()
now. This is because if we were to pass all the inputs, we'd get a failure down the line in query_spatial_size()
& Co.
This is because pure tensors are also "queried" there:
or is_simple_tensor(inpt) |
which in the case of the 1D label tensor would lead to an error:
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}") |
Another way to fix this would be to make query_spatial_size()
& Co aware of the heuristic somehow. But it's harder because right now, the heuristic relies on self._transformed_types
. It's probably safer this way for user-defined transforms.
@@ -159,10 +169,14 @@ def forward(self, *inputs: Any) -> Any: | |||
if torch.rand(1) >= self.p: | |||
return inputs | |||
|
|||
params = self._get_params(flat_inputs) | |||
needs_transform_list = self._needs_transform_list(flat_inputs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we forgot to udpate this forward()
as well...
@@ -169,7 +169,8 @@ def _permute_channels( | |||
if isinstance(orig_inpt, PIL.Image.Image): | |||
inpt = F.pil_to_tensor(inpt) | |||
|
|||
output = inpt[..., permutation, :, :] | |||
# TODO: Find a better fix than as_subclass??? | |||
output = inpt[..., permutation, :, :].as_subclass(type(inpt)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am concerned that this slipped through the cracks: passing an Image would result in a pure Tensor.
@pmeier don't we have tests that make sure the types are preserved? Are we ensuring in the tests that all the random transforms are actually tested (i.e. set p=1
)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't we have tests that make sure the types are preserved?
We do for the functional part, but this is a transform. We need to be extra careful for any kind of functionality that is implemented directly on the transform rather than in a dispatcher.
@pytest.mark.parametrize("label_type", (torch.Tensor, int)) | ||
@pytest.mark.parametrize("dataset_return_type", (dict, tuple)) | ||
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor)) | ||
def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test or the one below may not be your preferred test style. I'm happy to come back to this when we have more time,, but considering the very pressing timeline with the upcoming release, I kindly request that we focus on the substance (i.e. what the tests are testing + how we fix them + correctness) rather than style / perf / optimizations.
test/test_prototype_transforms.py
Outdated
# TODO: I have to set a very low size for the crop (30, 30), | ||
# otherwise we'd get an error saying the crop is larger than the | ||
# image. This means RandomCrop doesn't do the same thing as | ||
# FixedSizeCrop and we need ot figure out the key differences | ||
transforms.RandomCrop((30, 30)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
RandomCrop
doesn't do any padding by default, but FixedSizedCrop
does. You probably only need to set pad_if_needed=True
.
@@ -169,7 +169,8 @@ def _permute_channels( | |||
if isinstance(orig_inpt, PIL.Image.Image): | |||
inpt = F.pil_to_tensor(inpt) | |||
|
|||
output = inpt[..., permutation, :, :] | |||
# TODO: Find a better fix than as_subclass??? | |||
output = inpt[..., permutation, :, :].as_subclass(type(inpt)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't we have tests that make sure the types are preserved?
We do for the functional part, but this is a transform. We need to be extra careful for any kind of functionality that is implemented directly on the transform rather than in a dispatcher.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Nicolas!
Reviewed By: vmoens Differential Revision: D44416274 fbshipit-source-id: 87f1e0dd1b8bafc383cef15f31391d7c3c0ed6d3
This PR adds tests for some transforms pipeline(s) and fixes a few stuff.
The common theme to most (all?) of the fixes here is that we forgot to also apply the heuristic we introduced in #7170 in the other places where it's needed.
(The fact that it's needed in so many places is IMO a red-flag that suggests it's either not a viable solution to the original problem, or that the design we currently have involves highly coupled components. Either way, i'm not sure we'll have time to tackle that before the release, but we should come back to it)
cc @pmeier @bjuncek