Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resize V2 relies on interpolate's native uint8 handling #7557

Merged
merged 8 commits into from
May 16, 2023

Conversation

vfdev-5
Copy link
Collaborator

@vfdev-5 vfdev-5 commented May 4, 2023

Description:

Context: #7217 and #7497

Benchmarks:

[-- Resize cpu torch.uint8 InterpolationMode.NEAREST -]
                         |  resize v2  |  resize stable
1 threads: --------------------------------------------
      (3, 400, 400)      |      457    |        487    
      (16, 3, 400, 400)  |     6880    |      10100    

Times are in microseconds (us).

[- Resize cpu torch.uint8 InterpolationMode.BILINEAR -]
                         |  resize v2  |  resize stable
1 threads: --------------------------------------------
      (3, 400, 400)      |      330    |        849    
      (16, 3, 400, 400)  |     4420    |      14900    

Times are in microseconds (us).

Source

@pytorch-bot
Copy link

pytorch-bot bot commented May 4, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7557

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 0e27ad8:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@vfdev-5 vfdev-5 requested a review from NicolasHug May 4, 2023 14:46
vfdev-5 added a commit to vfdev-5/pytorch that referenced this pull request May 8, 2023
Context: In torchvision we ensure that functional ops are torchsciptable.
Recently exposed `torch.backends.cpu.get_cpu_capability()` in pytorch#100164 is failing in torchvision CI
```
RuntimeError:
Python builtin <built-in function _get_cpu_capability> is currently not supported in Torchscript:
  File "/usr/local/lib/python3.10/dist-packages/torch/backends/cpu/__init__.py", line 17
    - "AVX512"
    """
    return torch._C._get_cpu_capability()
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
```
Ref: pytorch/vision#7557

In this PR, `torch._C._get_cpu_capability()` is explicitly registered for JIT and tested.
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request May 9, 2023
Description:

Context: In torchvision we ensure that functional ops are torchscriptable. Recently exposed `torch.backends.cpu.get_cpu_capability()` in #100164 is failing in torchvision CI
```
RuntimeError:
Python builtin <built-in function _get_cpu_capability> is currently not supported in Torchscript:
  File "/usr/local/lib/python3.10/dist-packages/torch/backends/cpu/__init__.py", line 17
    - "AVX512"
    """
    return torch._C._get_cpu_capability()
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
```
Ref: pytorch/vision#7557

In this PR, `torch._C._get_cpu_capability()` is explicitly registered for JIT and tested.

Pull Request resolved: #100723
Approved by: https://github.com/albanD
vfdev-5 added 4 commits May 9, 2023 13:08
Description:
- Now that pytorch/pytorch#90771 is merged, let Resize() rely on interpolate()'s native uint8 handling instead of converting to and from float.

  - uint8 input is not casted to f32 for nearest mode and bilinear mode if the latter has AVX2.

Context: pytorch#7217

Benchmarks:
```
[----------- Resize cpu torch.uint8 InterpolationMode.NEAREST -----------]
                         |  resize v2  |  resize stable  |  resize nightly
1 threads: ---------------------------------------------------------------
      (3, 400, 400)      |      457    |        461      |        480
      (16, 3, 400, 400)  |     6870    |       6850      |      10100

Times are in microseconds (us).

[---------- Resize cpu torch.uint8 InterpolationMode.BILINEAR -----------]
                         |  resize v2  |  resize stable  |  resize nightly
1 threads: ---------------------------------------------------------------
      (3, 400, 400)      |      326    |        329      |        844
      (16, 3, 400, 400)  |     4380    |       4390      |      14800

Times are in microseconds (us).
```

[Source](https://gist.github.com/vfdev-5/a2e30ed50b5996807c9b09d5d33d8bc2)
torch.backends.cpu.get_cpu_capability() can't be scripted
@vfdev-5 vfdev-5 force-pushed the resize-use-interp-uint8 branch from c87da29 to 427c4c1 Compare May 9, 2023 20:36
@vfdev-5 vfdev-5 force-pushed the resize-use-interp-uint8 branch from 7506d89 to 72ac231 Compare May 12, 2023 07:34
Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @vfdev-5 , PR looks good. I made a few comments, LMK what you think

torchvision/transforms/v2/functional/_geometry.py Outdated Show resolved Hide resolved
torchvision/transforms/v2/functional/_geometry.py Outdated Show resolved Hide resolved
torchvision/transforms/v2/functional/_geometry.py Outdated Show resolved Hide resolved
test/common_utils.py Outdated Show resolved Hide resolved
test/test_transforms_v2_consistency.py Outdated Show resolved Hide resolved
test/test_transforms_v2_consistency.py Outdated Show resolved Hide resolved
test/test_transforms_v2_functional.py Show resolved Hide resolved
Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, tests are passing and we also stress-tested the changes locally against the float implementation. Looks safe to merge now.

Thanks a ton @vfdev-5 for all your work on interpolate() and with bug chasing / fixing!

import random

import torch
from torchvision.transforms import Resize as Resize_v1
from torchvision.transforms.v2 import Resize as Resize_v2
import pytest

@pytest.mark.parametrize("C", range(1, 6))
@pytest.mark.parametrize("batch_size", ("3d", 1, 4))
@pytest.mark.parametrize("memory_format", (torch.contiguous_format, torch.channels_last, "strided", "cropped"))
@pytest.mark.parametrize("antialias", (True, False))
@pytest.mark.parametrize("seed", range(100))
def test_resize(C, batch_size, memory_format, antialias, seed):
    
    torch.manual_seed(seed)
    random.seed(seed)

    Hi = 2**random.randint(3, 10) + random.randint(0, 30) 
    Wi = 2**random.randint(3, 10) + random.randint(0, 30) 
    Ho = 2**random.randint(3, 10) + random.randint(0, 30) 
    Wo = 2**random.randint(3, 10) + random.randint(0, 30) 
    print(Hi, Wi, Ho, Wo)

    if batch_size == "3d":
        img = torch.randint(0, 256, size=(C, Hi, Wi), dtype=torch.uint8)
    else:
        img = torch.randint(0, 256, size=(batch_size, C, Hi, Wi), dtype=torch.uint8)

    if memory_format in (torch.contiguous_format, torch.channels_last):
        if batch_size == "3d":
            return
        img = img.to(memory_format=memory_format, copy=True)
    elif memory_format == "strided":
        if batch_size == "3d":
            img = img[:, ::2, ::2]
        else:
            img = img[:, :, ::2, ::2]
    elif memory_format == "cropped":
        a = random.randint(1, Hi // 2)
        b = random.randint(Hi // 2 + 1, Hi)
        c = random.randint(1, Wi // 2)
        d = random.randint(Wi // 2 + 1, Wi)
        if batch_size == "3d":
            img = img[:, a:b, c:d]
        else:
            img = img[:, :, a:b, c:d]
    else:
        raise ValueError("Uh?")
    

    out_uint8 = Resize_v2(size=(Ho, Wo), antialias=antialias)(img)
    out_float = Resize_v1(size=(Ho, Wo), antialias=antialias)(img.float()).round().to(torch.uint8)

    torch.testing.assert_close(out_uint8, out_float, rtol=0, atol=1)

        

@NicolasHug NicolasHug changed the title Resize relies on interpolate's native uint8 handling Resize V2 relies on interpolate's native uint8 handling May 16, 2023
@NicolasHug NicolasHug merged commit 99ec261 into pytorch:main May 16, 2023
@github-actions
Copy link

Hey @NicolasHug!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

@NicolasHug NicolasHug added enhancement module: transforms Perf For performance improvements labels May 16, 2023
@vfdev-5 vfdev-5 deleted the resize-use-interp-uint8 branch May 16, 2023 13:44
facebook-github-bot pushed a commit that referenced this pull request May 16, 2023
Summary: Co-authored-by: Nicolas Hug <[email protected]>

Reviewed By: vmoens

Differential Revision: D45908452

fbshipit-source-id: a9821a4e1c50b973b2488753a3117faf59ffe585
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants