Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warnings checks for v2 namespaces #7288

Merged
merged 6 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import pathlib
import random
import shutil
import sys
import tempfile
from collections import defaultdict
from subprocess import CalledProcessError, check_output, STDOUT
from typing import Callable, Sequence, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -838,3 +840,22 @@ def get_closeness_kwargs(self, test_id, *, dtype, device):
if isinstance(device, torch.device):
device = device.type
return self.closeness_kwargs.get((test_id, dtype, device), dict())


def assert_run_python_script(source_code):
"""Utility to check assertions in an independent Python subprocess.
The script provided in the source code should return 0 and not print
anything on stderr or stdout. Taken from scikit-learn test utils.
source_code (str): The Python source code to execute.
"""
with tempfile.NamedTemporaryFile(mode="wb") as f:
f.write(source_code.encode())
f.flush()

cmd = [sys.executable, f.name]
try:
out = check_output(cmd, stderr=STDOUT)
except CalledProcessError as e:
raise RuntimeError("script errored with output:\n%s" % e.output.decode())
NicolasHug marked this conversation as resolved.
Show resolved Hide resolved
if out != b"":
raise AssertionError(out.decode())
3 changes: 2 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import pytest
import torch
import torchvision
from common_utils import CUDA_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG


torchvision.disable_beta_transforms_warning()

from common_utils import CUDA_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you add this for #7265 or because we actually need it for these tests. I'm not against it, but want to make sure I understand correctly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is mostly a drive-by, I stole it from #7282.

I think #7278 added some v2 imports in common_utils and we started seeing the warnings in the test suite, so now we have to disable them before we import anything from common_utils.



def pytest_configure(config):
# register an additional marker (see pytest_collection_modifyitems)
Expand Down
47 changes: 46 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import random
import re
import textwrap
import warnings
from functools import partial

Expand All @@ -24,7 +25,7 @@
except ImportError:
stats = None

from common_utils import assert_equal, cycle_over, float_dtypes, int_dtypes
from common_utils import assert_equal, assert_run_python_script, cycle_over, float_dtypes, int_dtypes


GRACE_HOPPER = get_file_path_2(
Expand Down Expand Up @@ -2266,5 +2267,49 @@ def test_random_grayscale_with_grayscale_input():
torch.testing.assert_close(F.pil_to_tensor(output_pil), image_tensor)


def test_no_warnings_v1_namespace():
source = """
import warnings
with warnings.catch_warnings():
warnings.simplefilter("error")
import torchvision.transforms
from torchvision import transforms
import torchvision.transforms.functional
from torchvision.transforms import Resize
from torchvision.transforms.functional import resize
"""
assert_run_python_script(textwrap.dedent(source))


# TODO: remove in 0.17 when we can delete functional_pil.py and functional_tensor.py
@pytest.mark.parametrize(
"import_statement",
(
"from torchvision.transforms import functional_pil",
"from torchvision.transforms import functional_tensor",
"from torchvision.transforms.functional_tensor import resize",
"from torchvision.transforms.functional_pil import resize",
),
)
@pytest.mark.parametrize("from_private", (True, False))
def test_functional_deprecation_warning(import_statement, from_private):
Copy link
Member Author

@NicolasHug NicolasHug Feb 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize this is slightly overkill, we don't need assert_run_python_script to check the warnings of the deprecated (public) files, we could just use a normal test. But we kinda still need it to properly check the lack of warnings for the now private files, so I mixed both. The whole thing will be removed soon anyway.

if from_private:
import_statement = import_statement.replace("functional", "_functional")
prelude = """
import warnings

with warnings.catch_warnings():
warnings.simplefilter("error")
"""
else:
prelude = """
import pytest
with pytest.warns(UserWarning, match="removed in 0.17"):
"""

source = prelude + " " * 4 + import_statement
assert_run_python_script(textwrap.dedent(source))


if __name__ == "__main__":
pytest.main([__file__])
51 changes: 51 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pathlib
import random
import re
import textwrap
import warnings
from collections import defaultdict

Expand All @@ -14,6 +15,7 @@

from common_utils import (
assert_equal,
assert_run_python_script,
cpu_and_gpu,
make_bounding_box,
make_bounding_boxes,
Expand Down Expand Up @@ -2045,3 +2047,52 @@ def test_sanitize_bounding_boxes_errors():
)
different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
transforms.SanitizeBoundingBoxes()(different_sizes)


@pytest.mark.parametrize(
"import_statement",
(
"from torchvision.transforms import v2",
"import torchvision.transforms.v2",
"from torchvision.transforms.v2 import Resize",
"import torchvision.transforms.v2.functional",
"from torchvision.transforms.v2.functional import resize",
"from torchvision import datapoints",
"from torchvision.datapoints import Image",
"from torchvision.datasets import wrap_dataset_for_transforms_v2",
),
)
@pytest.mark.parametrize("call_disable_warning", (True, False))
def test_warnings_v2_namespaces(import_statement, call_disable_warning):
if call_disable_warning:
prelude = """
import warnings
import torchvision
torchvision.disable_beta_transforms_warning()
with warnings.catch_warnings():
warnings.simplefilter("error")

"""
else:
prelude = """
import pytest
with pytest.warns(UserWarning, match="v2 namespaces are still Beta"):
"""
source = prelude + " " * 4 + import_statement
assert_run_python_script(textwrap.dedent(source))


def test_no_warnings_v1_namespace():
NicolasHug marked this conversation as resolved.
Show resolved Hide resolved
source = """
import warnings
with warnings.catch_warnings():
warnings.simplefilter("error")
import torchvision.transforms
from torchvision import transforms
import torchvision.transforms.functional
from torchvision.transforms import Resize
from torchvision.transforms.functional import resize
from torchvision import datasets
from torchvision.datasets import ImageNet
"""
assert_run_python_script(textwrap.dedent(source))