[proto] Fixed fill type in AA #6621

Merged 6 commits on Sep 22, 2022
Changes from 3 commits
31 changes: 14 additions & 17 deletions torchvision/prototype/transforms/_auto_augment.py
@@ -1,6 +1,5 @@
import math
import numbers
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, TypeVar, Union

import PIL.Image
import torch
@@ -10,7 +9,7 @@
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._meta import get_chw

from ._utils import _isinstance
from ._utils import _isinstance, _setup_fill_arg, FillType

K = TypeVar("K")
V = TypeVar("V")
@@ -21,14 +20,11 @@ def __init__(
self,
*,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
) -> None:
super().__init__()
self.interpolation = interpolation

if not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
self.fill = fill
self.fill = _setup_fill_arg(fill)

def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
keys = tuple(dct.keys())
@@ -63,18 +59,19 @@ def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any:

def _apply_image_transform(
self,
image: Any,
image: Union[torch.Tensor, PIL.Image.Image, features.Image],
transform_id: str,
magnitude: float,
interpolation: InterpolationMode,
fill: Union[int, float, Sequence[int], Sequence[float]],
fill: Dict[Type, FillType],
) -> Any:

fill_ = fill[type(image)]
# Fill = 0 is not equivalent to None (https://github.com/pytorch/vision/issues/6517),
# so we translate fill == 0 into None here.
# This keeps BC with the stable API, where fill = None is the default.
fill_ = F._geometry._convert_fill_arg(fill)
if isinstance(fill, int) and fill == 0:
fill_ = F._geometry._convert_fill_arg(fill_)
if isinstance(fill_, int) and fill_ == 0:
fill_ = None

if transform_id == "Identity":
@@ -186,7 +183,7 @@ def __init__(
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy
@@ -286,7 +283,7 @@ def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

id, image = self._extract_image(sample)
num_channels, height, width = get_chw(image)
_, height, width = get_chw(image)

policy = self._policies[int(torch.randint(len(self._policies), ()))]

@@ -346,7 +343,7 @@ def __init__(
magnitude: int = 9,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops
@@ -402,7 +399,7 @@ def __init__(
self,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
):
super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins
@@ -462,7 +459,7 @@ def __init__(
alpha: float = 1.0,
all_ops: bool = True,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10
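
The net effect of the changes above is that the auto-augment transforms now accept either a single fill value or a mapping from input type to fill value; `_setup_fill_arg` normalizes both forms, and `_apply_image_transform` looks up `fill[type(image)]` before dispatching, with `fill == 0` translated to `None` for backward compatibility with the stable API (see pytorch/vision#6517). A minimal usage sketch, assuming `AutoAugment` is exported from `torchvision.prototype.transforms` as in the stable API; the per-type values are purely illustrative:

```python
import PIL.Image
import torch

from torchvision.prototype import features
from torchvision.prototype.transforms import AutoAugment

# A single scalar still works and applies to every supported input type.
aa_scalar = AutoAugment(fill=0)

# A dict keyed by input type gives each type its own fill value;
# _setup_fill_arg validates it and _apply_image_transform picks
# fill[type(image)] at call time.
aa_per_type = AutoAugment(
    fill={
        features.Image: 128,         # prototype Image features filled with gray
        torch.Tensor: 0,             # plain tensors keep the old default
        PIL.Image.Image: (0, 0, 0),  # PIL images use an RGB tuple
    }
)
```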
5 changes: 1 addition & 4 deletions torchvision/prototype/transforms/_deprecated.py
@@ -11,10 +11,7 @@
from typing_extensions import Literal

from ._transform import _RandomApplyTransform
from ._utils import query_chw


DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
from ._utils import DType, query_chw


class ToTensor(Transform):
54 changes: 14 additions & 40 deletions torchvision/prototype/transforms/_geometry.py
@@ -1,7 +1,6 @@
import math
import numbers
import warnings
from collections import defaultdict
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type, Union

import PIL.Image
@@ -14,11 +13,20 @@
from typing_extensions import Literal

from ._transform import _RandomApplyTransform
from ._utils import _check_sequence_input, _setup_angle, _setup_size, has_all, has_any, query_bounding_box, query_chw


DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
FillType = Union[int, float, Sequence[int], Sequence[float]]
from ._utils import (
_check_padding_arg,
_check_padding_mode_arg,
_check_sequence_input,
_setup_angle,
_setup_fill_arg,
_setup_size,
DType,
FillType,
has_all,
has_any,
query_bounding_box,
query_chw,
)


class RandomHorizontalFlip(_RandomApplyTransform):
@@ -201,40 +209,6 @@ def forward(self, *inputs: Any) -> Any:
return super().forward(*inputs)


def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
if isinstance(fill, dict):
for key, value in fill.items():
# Check key for type
_check_fill_arg(value)
else:
if not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")


def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
_check_fill_arg(fill)

if isinstance(fill, dict):
return fill

return defaultdict(lambda: fill) # type: ignore[arg-type, return-value]


def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")

if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")


# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
# https://github.com/pytorch/vision/issues/6250
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")


class Pad(Transform):
def __init__(
self,
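
The helpers removed from `_geometry.py` above (`_check_fill_arg`, `_setup_fill_arg`, `_check_padding_arg`, `_check_padding_mode_arg`) are not dropped; they move into `_utils.py` (next file) so that both the geometry and the auto-augment transforms can share them. A small sketch of the validation behavior they keep, for illustration only since these are private helpers:

```python
from torchvision.prototype.transforms._utils import (
    _check_padding_arg,
    _check_padding_mode_arg,
)

_check_padding_arg(4)             # OK: a single int
_check_padding_arg([1, 2, 4, 2])  # OK: sequences of length 1, 2, or 4 pass
try:
    _check_padding_arg([1, 2, 3])  # a 3-element sequence is rejected
except ValueError as e:
    print(e)

_check_padding_mode_arg("reflect")   # OK
try:
    _check_padding_mode_arg("wrap")  # not one of the four allowed modes
except ValueError as e:
    print(e)
```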
48 changes: 47 additions & 1 deletion torchvision/prototype/transforms/_utils.py
@@ -1,13 +1,59 @@
from typing import Any, Callable, Tuple, Type, Union
import numbers
from collections import defaultdict

from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union

import PIL.Image

import torch
from torch.utils._pytree import tree_flatten
from torchvision._utils import sequence_to_str
from torchvision.prototype import features

from torchvision.prototype.transforms.functional._meta import get_chw
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401

from typing_extensions import Literal


# Type shortcuts:
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
FillType = Union[int, float, Sequence[int], Sequence[float]]


def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
if isinstance(fill, dict):
for key, value in fill.items():
# Check key for type
_check_fill_arg(value)
else:
if not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")


def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
_check_fill_arg(fill)

if isinstance(fill, dict):
return fill

return defaultdict(lambda: fill) # type: ignore[arg-type, return-value]


def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")

if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")


# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
# https://github.com/pytorch/vision/issues/6250
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")


def query_bounding_box(sample: Any) -> features.BoundingBox:
flat_sample, _ = tree_flatten(sample)
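
`_setup_fill_arg` is the piece that makes the new `fill` signature work: a plain value is wrapped in a `defaultdict` that returns the same fill for every input type, while a user-supplied dict is validated and returned as-is. A minimal sketch of that behavior, based on the helper as defined in this file:

```python
import torch

from torchvision.prototype import features
from torchvision.prototype.transforms._utils import _setup_fill_arg

# Scalar fill: any type looks up to the same value through a defaultdict.
fill = _setup_fill_arg(0)
assert fill[features.Image] == 0
assert fill[torch.Tensor] == 0

# Dict fill: the mapping is validated and passed through unchanged,
# so only the types listed as keys can be looked up.
fill = _setup_fill_arg({features.Image: 1, torch.Tensor: 0})
assert fill[features.Image] == 1
```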