Resize V2 relies on interpolate's native uint8 handling #7557
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7557
Note: links to docs will display an error until the docs builds have been completed.
❌ 2 new failures as of commit 0e27ad8.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Context: In torchvision we ensure that functional ops are torchscriptable. The recently exposed `torch.backends.cpu.get_cpu_capability()` (pytorch#100164) is failing in torchvision CI:

```
RuntimeError: Python builtin <built-in function _get_cpu_capability> is currently not supported in Torchscript:
  File "/usr/local/lib/python3.10/dist-packages/torch/backends/cpu/__init__.py", line 17
        - "AVX512"
    """
    return torch._C._get_cpu_capability()
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
```

Ref: pytorch/vision#7557

In this PR, `torch._C._get_cpu_capability()` is explicitly registered for JIT and tested.
Pull Request resolved: #100723. Approved by: https://github.com/albanD
Description:
- Now that pytorch/pytorch#90771 is merged, let Resize() rely on interpolate()'s native uint8 handling instead of converting to and from float.
- uint8 input is not cast to float32 for nearest mode, nor for bilinear mode when AVX2 is available.

Context: #7217 and #7497

Benchmarks:

```
[----------- Resize cpu torch.uint8 InterpolationMode.NEAREST -----------]
                        |  resize v2  |  resize stable  |  resize nightly
1 threads: ---------------------------------------------------------------
      (3, 400, 400)     |     457     |       461       |       480
      (16, 3, 400, 400) |    6870     |      6850       |     10100

Times are in microseconds (us).

[---------- Resize cpu torch.uint8 InterpolationMode.BILINEAR -----------]
                        |  resize v2  |  resize stable  |  resize nightly
1 threads: ---------------------------------------------------------------
      (3, 400, 400)     |     326     |       329       |       844
      (16, 3, 400, 400) |    4380     |      4390       |     14800

Times are in microseconds (us).
```

[Source](https://gist.github.com/vfdev-5/a2e30ed50b5996807c9b09d5d33d8bc2)
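As a side note on why the native uint8 path cannot be expected to match a float computation exactly: integer SIMD kernels typically pre-quantize interpolation weights to fixed point and round once at the end, which can shift a result by one intensity level. A framework-free sketch of that effect (the 15-bit weight precision here is an illustrative assumption, not the actual kernel implementation):

```python
def lerp_float(a: int, b: int, w: float) -> int:
    """Float path: interpolate in floating point, round at the end."""
    return round(a * (1.0 - w) + b * w)

def lerp_fixed(a: int, b: int, w: float, bits: int = 15) -> int:
    """uint8 path (sketch): quantize the weight to fixed point,
    accumulate in integers, then do a single rounding shift."""
    scale = 1 << bits
    wq = int(w * scale + 0.5)            # quantized weight
    acc = a * (scale - wq) + b * wq      # integer accumulate
    return (acc + (scale >> 1)) >> bits  # round-half-up shift

# The two paths agree to within one intensity level:
for a, b in [(0, 255), (10, 11), (200, 201), (37, 199)]:
    for w in [0.0, 0.25, 0.4999, 0.5, 0.75, 1.0]:
        assert abs(lerp_float(a, b, w) - lerp_fixed(a, b, w)) <= 1
```

This is also why the equivalence test below compares the two implementations with an absolute tolerance of 1 rather than requiring exact equality.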
torch.backends.cpu.get_cpu_capability() can't be scripted
Added tests on mem format
Thanks @vfdev-5 , PR looks good. I made a few comments, LMK what you think
Alright, tests are passing and we also stress-tested the changes locally against the float implementation. Looks safe to merge now.
Thanks a ton @vfdev-5 for all your work on interpolate() and for the bug chasing and fixing!
```python
import random

import pytest
import torch
from torchvision.transforms import Resize as Resize_v1
from torchvision.transforms.v2 import Resize as Resize_v2


@pytest.mark.parametrize("C", range(1, 6))
@pytest.mark.parametrize("batch_size", ("3d", 1, 4))
@pytest.mark.parametrize(
    "memory_format", (torch.contiguous_format, torch.channels_last, "strided", "cropped")
)
@pytest.mark.parametrize("antialias", (True, False))
@pytest.mark.parametrize("seed", range(100))
def test_resize(C, batch_size, memory_format, antialias, seed):
    torch.manual_seed(seed)
    random.seed(seed)
    Hi = 2 ** random.randint(3, 10) + random.randint(0, 30)
    Wi = 2 ** random.randint(3, 10) + random.randint(0, 30)
    Ho = 2 ** random.randint(3, 10) + random.randint(0, 30)
    Wo = 2 ** random.randint(3, 10) + random.randint(0, 30)
    print(Hi, Wi, Ho, Wo)

    if batch_size == "3d":
        img = torch.randint(0, 256, size=(C, Hi, Wi), dtype=torch.uint8)
    else:
        img = torch.randint(0, 256, size=(batch_size, C, Hi, Wi), dtype=torch.uint8)

    if memory_format in (torch.contiguous_format, torch.channels_last):
        if batch_size == "3d":
            return  # memory formats only apply to batched (4D) tensors
        img = img.to(memory_format=memory_format, copy=True)
    elif memory_format == "strided":
        # Non-contiguous view: every other row and column
        if batch_size == "3d":
            img = img[:, ::2, ::2]
        else:
            img = img[:, :, ::2, ::2]
    elif memory_format == "cropped":
        # Non-contiguous view: random central crop
        a = random.randint(1, Hi // 2)
        b = random.randint(Hi // 2 + 1, Hi)
        c = random.randint(1, Wi // 2)
        d = random.randint(Wi // 2 + 1, Wi)
        if batch_size == "3d":
            img = img[:, a:b, c:d]
        else:
            img = img[:, :, a:b, c:d]
    else:
        raise ValueError(f"Unexpected memory_format: {memory_format}")

    out_uint8 = Resize_v2(size=(Ho, Wo), antialias=antialias)(img)
    out_float = Resize_v1(size=(Ho, Wo), antialias=antialias)(img.float()).round().to(torch.uint8)
    torch.testing.assert_close(out_uint8, out_float, rtol=0, atol=1)
```
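The "strided" and "cropped" parametrizations above matter because slicing produces non-contiguous views, which exercise different paths in the resize kernels than plain contiguous inputs. A small illustration (sizes are arbitrary):

```python
import torch

img = torch.randint(0, 256, size=(3, 64, 64), dtype=torch.uint8)
assert img.is_contiguous()

# Taking every other row/column yields a non-contiguous view:
strided = img[:, ::2, ::2]
assert not strided.is_contiguous()

# A central crop is also a non-contiguous view: it keeps the
# original strides but starts at a different storage offset.
cropped = img[:, 10:50, 10:50]
assert not cropped.is_contiguous()
assert cropped.stride() == img.stride()
```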
Summary:
Co-authored-by: Nicolas Hug <[email protected]>
Reviewed By: vmoens
Differential Revision: D45908452
fbshipit-source-id: a9821a4e1c50b973b2488753a3117faf59ffe585