Support all integer and floating point dtypes in prototype transform kernels? #6840
Comments
I'm in favour of doing this. It will allow users to cast to float32 first, which might be beneficial in cases where the uint8 kernels are slower.
Before closing this ticket, we should make the following proposed optimization on AugMix: vision/torchvision/prototype/transforms/_auto_augment.py Lines 509 to 513 in e1f464b
Edit: I checked the above and it doesn't actually improve the speed, so we should just remove the comment.
@pmeier I think the last thing to consider before closing the ticket is how
Thoughts on this? Edit: I issued a PR for this at #6874
The standard rule for dtype support for images and videos is: floating point tensors in [0.0, 1.0] and integer tensors in [0, torch.iinfo(dtype).max]. (This is currently under review, since there were a few cases where this was not true or simply not handled; see "Don't hardcode 255 unless uint8 is enforced" #6825.) However, we currently have two kernels that only support uint8 images or videos:
vision/torchvision/prototype/transforms/functional/_color.py
Lines 373 to 375 in c84dbfa
vision/torchvision/transforms/functional_tensor.py
Lines 788 to 789 in c84dbfa
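The rule above can be written down as a tiny helper. This is a hypothetical sketch (`max_value` is not an existing torchvision function), just to make the convention concrete:

```python
import torch

def max_value(dtype: torch.dtype) -> float:
    # Hypothetical helper encoding the rule stated above:
    # floating point images live in [0.0, 1.0],
    # integer images in [0, torch.iinfo(dtype).max].
    if dtype.is_floating_point:
        return 1.0
    return float(torch.iinfo(dtype).max)
```

So `max_value(torch.uint8)` is 255 and `max_value(torch.float32)` is 1.0, which is why hardcoding 255 in a kernel silently assumes uint8 inputs.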
This also holds for transforms v1, so this is not a problem of the new API. One consequence of this is that AA transforms are only supported for uint8 images:
vision/torchvision/transforms/autoaugment.py
Lines 104 to 107 in c84dbfa
since both
vision/torchvision/transforms/autoaugment.py
Lines 76 to 77 in c84dbfa
and
vision/torchvision/transforms/autoaugment.py
Lines 82 to 83 in c84dbfa
are used.
One possible way of mitigating this is to simply have a convert_dtype(image, torch.uint8) at the beginning and to convert back after the computation. That is probably needed for equalize, since we recently switched away from the histogram ops of torch towards our "custom" implementation to enable batch processing (#6757). However, this relies on the input being an integer tensor and, in its current form, even on uint8 due to some hardcoded constants.
For posterize I think it is fairly easy to provide the same functionality for float inputs directly, without going through a dtype conversion first.
cc @vfdev-5 @datumbox @bjuncek
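The convert-to-uint8-and-back round trip suggested above for equalize could look roughly like this. A minimal sketch in plain torch: `call_with_uint8` is a hypothetical wrapper (not a torchvision API), and float inputs are assumed to be in [0.0, 1.0] per the rule stated earlier:

```python
import torch

def call_with_uint8(kernel, image: torch.Tensor) -> torch.Tensor:
    # Hypothetical wrapper: run a uint8-only kernel on other dtypes by
    # converting to uint8 first and converting back after the computation.
    orig_dtype = image.dtype
    if orig_dtype == torch.uint8:
        return kernel(image)
    if orig_dtype.is_floating_point:
        # floating point images are assumed to be in [0.0, 1.0]
        as_uint8 = image.mul(255).round_().clamp_(0, 255).to(torch.uint8)
        return kernel(as_uint8).to(orig_dtype).div_(255)
    raise NotImplementedError(f"no conversion rule for {orig_dtype}")

# usage: a toy uint8-only "invert" kernel applied to a float32 image
img = torch.rand(3, 4, 4)
inverted = call_with_uint8(lambda t: 255 - t, img)
```

The obvious cost is that float inputs get quantized to 256 levels on the way through, which is exactly why a native float path is preferable where it is easy to provide.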
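A direct float posterize along these lines might look as follows. This is a hypothetical sketch (`posterize_float` is not an existing kernel), assuming float inputs in [0.0, 1.0]; it snaps each value to the floor of one of `2**bits` equal-width buckets, mirroring the bit-masking the uint8 kernel performs:

```python
import torch

def posterize_float(image: torch.Tensor, bits: int) -> torch.Tensor:
    # Hypothetical sketch: posterize a float image in [0.0, 1.0] directly,
    # without a round trip through uint8. Keeping only the top `bits` bits
    # of a uint8 value corresponds to snapping a float value down to the
    # nearest multiple of 1 / 2**bits.
    levels = 2 ** bits
    return image.mul(levels).floor_().clamp_(max=levels - 1).div_(levels)

# usage: bits=1 leaves only two output values, 0.0 and 0.5
out = posterize_float(torch.tensor([0.0, 0.3, 0.9, 1.0]), 1)
```

Since this never leaves floating point, it works for float16/float32/float64 alike and avoids the quantization loss of a dtype conversion.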