Add tests for transform presets, and various fixes #7223

Merged
7 commits merged into pytorch:main on Feb 13, 2023

Conversation

NicolasHug (Member) commented on Feb 11, 2023:

This PR adds tests for some transforms pipeline(s) and fixes a few things.

The common theme to most (all?) of the fixes here is that we forgot to also apply the heuristic we introduced in #7170 in the other places where it's needed.

(The fact that it's needed in so many places is IMO a red flag suggesting either that it's not a viable solution to the original problem, or that our current design involves highly coupled components. Either way, I'm not sure we'll have time to tackle that before the release, but we should come back to it.)

cc @pmeier @bjuncek

Comment on lines +43 to +44
for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
    if needs_transform and check_type(
NicolasHug (Member Author) commented:
It is concerning that every single overload of forward() must somehow be aware of the heuristic. It is probably going to be a problem for user-defined transforms, unless we decide that they shouldn't override forward() at all, which would be unfair, since we do it ourselves.
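To make the concern concrete, here is a rough sketch (hypothetical code, not part of this PR) of what a user-defined transform overriding forward() would have to do to stay compatible with the heuristic; _needs_transform_list, _get_params and _transform are the base-class hooks discussed in this PR, and the import paths are assumptions based on the prototype namespace:

from typing import Any

from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype.transforms import Transform


class MyTransform(Transform):
    def forward(self, *inputs: Any) -> Any:
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])

        # The subclass has to know about the heuristic and call it explicitly;
        # otherwise pure tensors (e.g. 1D labels) would be transformed as if they were images.
        needs_transform_list = self._needs_transform_list(flat_inputs)
        params = self._get_params(
            [inpt for inpt, needs in zip(flat_inputs, needs_transform_list) if needs]
        )

        flat_outputs = [
            self._transform(inpt, params) if needs else inpt
            for inpt, needs in zip(flat_inputs, needs_transform_list)
        ]
        return tree_unflatten(flat_outputs, spec)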


out = t(sample)
NicolasHug (Member Author) commented:

The asserts below are light; the actual purpose of this test is to make sure that this call doesn't fail. On main, it currently fails for a bunch of reasons.
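For context, the shape of such a smoke test is roughly the following (a hypothetical sketch, not the exact test added in this PR; the sample layout, sizes and the datapoints import are assumptions):

import torch
from torchvision.prototype import datapoints


def run_preset_smoke_test(t):
    # Minimal classification-style sample: an image plus a 1D label tensor,
    # i.e. exactly the kind of pure tensor the heuristic must leave alone.
    sample = {
        "image": datapoints.Image(torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)),
        "label": torch.tensor([3]),
    }

    # The real assertion is implicit: this call must not raise.
    out = t(sample)

    # Light sanity checks on top.
    assert isinstance(out, dict)
    assert set(out.keys()) == set(sample.keys())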

@@ -36,8 +36,19 @@ def forward(self, *inputs: Any) -> Any:

         self._check_inputs(flat_inputs)

-        params = self._get_params(flat_inputs)
+        needs_transform_list = self._needs_transform_list(flat_inputs)
NicolasHug (Member Author) commented:

I extracted the heuristic out into a method, because it is needed in other places (basically, in all the forward() overrides of the transform subclasses).
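For reference, a rough sketch of what that method does conceptually (an approximation of the heuristic from #7170, not the merged implementation; the helper names and import paths are assumptions based on the prototype utilities referenced elsewhere in this PR):

from typing import Any, List

import PIL.Image

from torchvision.prototype import datapoints
from torchvision.prototype.transforms.utils import check_type, has_any, is_simple_tensor


def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
    # Heuristic: a pure (non-datapoint) tensor is only treated as an image when the
    # sample contains no explicit image or video, and even then only the first one is.
    transform_simple_tensor = not has_any(
        flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image
    )

    needs_transform_list = []
    for inpt in flat_inputs:
        needs_transform = True
        if not check_type(inpt, self._transformed_types):
            needs_transform = False
        elif is_simple_tensor(inpt):
            if transform_simple_tensor:
                transform_simple_tensor = False  # only the first pure tensor gets transformed
            else:
                needs_transform = False
        needs_transform_list.append(needs_transform)
    return needs_transform_list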

Comment on lines 40 to 42
params = self._get_params(
    inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform
)
NicolasHug (Member Author) commented:

Note how we're only passing the relevant inputs to _get_params() now. If we were to pass all the inputs, we'd get a failure down the line in query_spatial_size() & Co., because pure tensors are also "queried" there:

or is_simple_tensor(inpt)

which, in the case of the 1D label tensor, leads to this error:

raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")

Another way to fix this would be to make query_spatial_size() & Co. aware of the heuristic somehow. But that's harder because right now the heuristic relies on self._transformed_types. It's probably safer this way for user-defined transforms.
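A tiny illustration of the failure mode (hypothetical snippet, just to show why a 1D label tensor cannot answer a spatial-size query):

import torch

image = torch.rand(3, 32, 32)  # image-like tensor, shape (C, H, W)
label = torch.tensor([3])      # 1D pure tensor, shape (1,)

# A spatial-size query reads the last two dimensions; the label simply doesn't have them.
print(list(image.shape[-2:]))  # [32, 32]
print(list(label.shape[-2:]))  # [1] -- fewer than two dimensions, hence the TypeError above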

@@ -159,10 +169,14 @@ def forward(self, *inputs: Any) -> Any:
         if torch.rand(1) >= self.p:
             return inputs

-        params = self._get_params(flat_inputs)
+        needs_transform_list = self._needs_transform_list(flat_inputs)
NicolasHug (Member Author) commented:

We forgot to update this forward() as well...

@@ -169,7 +169,8 @@ def _permute_channels(
         if isinstance(orig_inpt, PIL.Image.Image):
             inpt = F.pil_to_tensor(inpt)

-        output = inpt[..., permutation, :, :]
+        # TODO: Find a better fix than as_subclass???
+        output = inpt[..., permutation, :, :].as_subclass(type(inpt))
NicolasHug (Member Author) commented on Feb 12, 2023:

I am concerned that this slipped through the cracks: passing an Image would result in a pure Tensor.
@pmeier don't we have tests that make sure the types are preserved? Are we ensuring in the tests that all the random transforms are actually tested (i.e. set p=1)?
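A minimal, self-contained illustration of the problem (using a stand-in subclass rather than the real Image datapoint, and assuming a PyTorch recent enough to expose torch._C.DisableTorchFunctionSubclass): plain indexing comes back as a pure Tensor, and as_subclass() is what re-attaches the type.

import torch
from torch._C import DisableTorchFunctionSubclass


class FakeImage(torch.Tensor):
    # Stand-in for a datapoint-like subclass whose ops deliberately return plain tensors.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        with DisableTorchFunctionSubclass():
            return func(*args, **(kwargs or {}))


img = torch.rand(3, 4, 4).as_subclass(FakeImage)
permuted = img[..., [2, 0, 1], :, :]

print(type(permuted).__name__)                         # Tensor -- the subclass is lost
print(type(permuted.as_subclass(type(img))).__name__)  # FakeImage -- restored, as in the fix above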

Collaborator commented:

don't we have tests that make sure the types are preserved?

We do for the functional part, but this is a transform. We need to be extra careful for any kind of functionality that is implemented directly on the transform rather than in a dispatcher.
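A sketch of what a transform-level check could look like (hypothetical test, with p=1 so the random transforms actually fire; the transform list and import paths are assumptions):

import pytest
import torch
from torchvision.prototype import datapoints, transforms


@pytest.mark.parametrize(
    "transform",
    [
        transforms.RandomPhotometricDistort(p=1),
        transforms.RandomHorizontalFlip(p=1),
    ],
)
def test_transform_preserves_image_type(transform):
    image = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
    output = transform(image)
    # The point: going through the *transform* (not just the dispatcher) must not
    # silently downgrade an Image to a pure Tensor.
    assert isinstance(output, datapoints.Image)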

@pytest.mark.parametrize("label_type", (torch.Tensor, int))
@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
NicolasHug (Member Author) commented:

This test, and the one below, may not be your preferred test style. I'm happy to come back to this when we have more time, but given the very pressing timeline of the upcoming release, I kindly ask that we focus on the substance (i.e. what the tests are testing, how we fix things, and correctness) rather than on style / perf / optimizations.

Comment on lines 2070 to 2074
# TODO: I have to set a very low size for the crop (30, 30),
# otherwise we'd get an error saying the crop is larger than the
# image. This means RandomCrop doesn't do the same thing as
# FixedSizeCrop and we need to figure out the key differences
transforms.RandomCrop((30, 30)),
Collaborator commented:

RandomCrop doesn't do any padding by default, but FixedSizeCrop does. You probably only need to set pad_if_needed=True.
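Something along these lines (untested sketch; the crop size is a placeholder for whatever the preset originally used, and the prototype import path is an assumption):

from torchvision.prototype import transforms

# Let RandomCrop pad inputs that are smaller than the crop,
# which is what FixedSizeCrop does implicitly.
crop = transforms.RandomCrop((224, 224), pad_if_needed=True)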

pmeier (Collaborator) left a comment:

Thanks Nicolas!

@NicolasHug NicolasHug merged commit 0aed832 into pytorch:main Feb 13, 2023
facebook-github-bot pushed a commit that referenced this pull request Mar 28, 2023
Reviewed By: vmoens

Differential Revision: D44416274

fbshipit-source-id: 87f1e0dd1b8bafc383cef15f31391d7c3c0ed6d3