Support all integer and floating point dtypes in prototype transform kernels? #6840

Closed · pmeier opened this issue Oct 26, 2022 · 3 comments · Fixed by #6874

@pmeier (Collaborator) commented Oct 26, 2022

The standard rule for dtype support for images and videos is:

  • All floating point and integer tensors are supported.
  • Floating point tensors are valid in the range [0.0, 1.0] and integer tensors in [0, torch.iinfo(dtype).max]; see the sketch right after this list. (This is currently under review, since there were a few cases where this was not true or simply not handled. See "Don't hardcode 255 unless uint8 is enforced" #6825.)
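
A minimal sketch of that rule as code (`max_value` is a hypothetical helper name, not a torchvision API):

import torch

def max_value(dtype: torch.dtype) -> float:
    # Largest valid value under the rule above: 1.0 for floating point
    # dtypes, torch.iinfo(dtype).max for integer dtypes.
    return 1.0 if dtype.is_floating_point else float(torch.iinfo(dtype).max)

assert max_value(torch.float32) == 1.0
assert max_value(torch.uint8) == 255
assert max_value(torch.int32) == 2**31 - 1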

However, we currently have two kernels that only support uint8 images or videos: equalize and posterize.

This also holds for transforms v1, so this is not a problem specific to the new API.

One consequence of that is that the AA transforms only support uint8 images:

class AutoAugment(torch.nn.Module):
    r"""AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected

since both

elif op_name == "Posterize":
img = F.posterize(img, int(magnitude))

and

elif op_name == "Equalize":
img = F.equalize(img)

are used.

One possible way of mitigating this is to simply call convert_dtype(image, torch.uint8) at the beginning and convert back after the computation.

That is probably needed for equalize, since we recently switched away from torch's histogram ops towards our own implementation to enable batch processing (#6757). However, that implementation relies on the input being an integer tensor and, in its current form, even specifically on uint8, due to some hardcoded constants.
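
A sketch of that mitigation for equalize, assuming a float image in [0.0, 1.0] (`equalize_any_dtype` is a hypothetical wrapper; `convert_image_dtype` is the stable-API counterpart of the prototype `convert_dtype`):

import torch
import torchvision.transforms.functional as F

def equalize_any_dtype(image: torch.Tensor) -> torch.Tensor:
    # Round-trip through uint8 around the uint8-only kernel, then
    # restore the original dtype afterwards.
    orig_dtype = image.dtype
    out = F.equalize(F.convert_image_dtype(image, torch.uint8))
    return F.convert_image_dtype(out, orig_dtype)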

For posterize, I think it is fairly easy to provide the same functionality for float inputs directly, without going through a dtype conversion first.
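
For illustration, a float-native posterize could look roughly like this (a minimal sketch assuming float inputs in [0.0, 1.0]; `posterize_float` is a hypothetical helper, not the actual kernel):

import torch

def posterize_float(image: torch.Tensor, bits: int) -> torch.Tensor:
    # Quantize a float image in [0.0, 1.0] down to 2**bits intensity
    # levels, mirroring what masking the low bits does for uint8 inputs.
    levels = 1 << bits
    return image.mul(levels).floor_().clamp_(max=levels - 1).div_(levels)

posterize_float(torch.rand(3, 32, 32), bits=2)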

cc @vfdev-5 @datumbox @bjuncek

@datumbox (Contributor) commented:

I'm in favour of doing this. It will allow users to cast to float32 first, which might be beneficial in cases where the uint8 kernels are slower.

@datumbox (Contributor) commented Oct 27, 2022

Before closing this ticket, we should apply the following proposed optimization to AugMix:

# The multiplication below could become in-place provided `aug is not batch and aug.is_floating_point()`.
# Currently we can't do this because `aug` has to be `uint8` to support ops like `equalize`.
# TODO: change this once all ops in `F` support floats. https://github.com/pytorch/vision/issues/6840
combined_weights[:, i].reshape(batch_dims) * aug

Edit: I checked the above and it doesn't actually improve the speed, so we should just remove the comment.

@datumbox (Contributor) commented Oct 31, 2022

@pmeier I think the last thing to consider before closing the ticket is how AA can send the right threshold values to solarize when the image is a float, i.e.:

"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),

Thoughts on this?
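
One option, sketched below, is to scale the hardcoded 255.0 by the maximum valid value for the input's dtype, so float images get thresholds in [1.0, 0.0] (`solarize_thresholds` is a hypothetical helper, not the actual augmentation-space code):

import torch

def solarize_thresholds(num_bins: int, image: torch.Tensor) -> torch.Tensor:
    # Replace the hardcoded 255.0 with the maximum valid value for the
    # input's dtype: 1.0 for floats, torch.iinfo(dtype).max for integers.
    if image.is_floating_point():
        max_value = 1.0
    else:
        max_value = float(torch.iinfo(image.dtype).max)
    return torch.linspace(max_value, 0.0, num_bins)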

Edit: I issued a PR for this at #6874.
