diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index e9523fdd22e..0d14a434c1e 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -1,6 +1,5 @@ import math -import numbers -from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, TypeVar, Union import PIL.Image import torch @@ -10,7 +9,7 @@ from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform from torchvision.prototype.transforms.functional._meta import get_chw -from ._utils import _isinstance +from ._utils import _isinstance, _setup_fill_arg, FillType K = TypeVar("K") V = TypeVar("V") @@ -21,14 +20,11 @@ def __init__( self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, ) -> None: super().__init__() self.interpolation = interpolation - - if not isinstance(fill, (numbers.Number, tuple, list)): - raise TypeError("Got inappropriate fill arg") - self.fill = fill + self.fill = _setup_fill_arg(fill) def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: keys = tuple(dct.keys()) @@ -63,19 +59,14 @@ def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any: def _apply_image_transform( self, - image: Any, + image: Union[torch.Tensor, PIL.Image.Image, features.Image], transform_id: str, magnitude: float, interpolation: InterpolationMode, - fill: Union[int, float, Sequence[int], Sequence[float]], + fill: Union[Dict[Type, FillType], Dict[Type, None]], ) -> Any: - - # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 - # So, we have to put fill as None if fill == 0 - # This is due to BC with stable API which has fill = None by default - fill_ = F._geometry._convert_fill_arg(fill) - if isinstance(fill, int) and fill == 0: - fill_ = None + fill_ = fill[type(image)] + fill_ = F._geometry._convert_fill_arg(fill_) if transform_id == "Identity": return image @@ -186,7 +177,7 @@ def __init__( self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.policy = policy @@ -286,7 +277,7 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] id, image = self._extract_image(sample) - num_channels, height, width = get_chw(image) + _, height, width = get_chw(image) policy = self._policies[int(torch.randint(len(self._policies), ()))] @@ -346,7 +337,7 @@ def __init__( magnitude: int = 9, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.num_ops = num_ops @@ -402,7 +393,7 @@ def __init__( self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, ): super().__init__(interpolation=interpolation, fill=fill) self.num_magnitude_bins = num_magnitude_bins @@ -462,7 +453,7 @@ def __init__( alpha: float = 1.0, all_ops: bool = True, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self._PARAMETER_MAX = 10 diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py index e6af6ba09b6..9b28551a048 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -11,10 +11,7 @@ from typing_extensions import Literal from ._transform import _RandomApplyTransform -from ._utils import query_chw - - -DType = Union[torch.Tensor, PIL.Image.Image, features._Feature] +from ._utils import DType, query_chw class ToTensor(Transform): diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index fa2baa1f77f..a8f0b09765b 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -1,7 +1,6 @@ import math import numbers import warnings -from collections import defaultdict from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type, Union import PIL.Image @@ -14,11 +13,20 @@ from typing_extensions import Literal from ._transform import _RandomApplyTransform -from ._utils import _check_sequence_input, _setup_angle, _setup_size, has_all, has_any, query_bounding_box, query_chw - - -DType = Union[torch.Tensor, PIL.Image.Image, features._Feature] -FillType = Union[int, float, Sequence[int], Sequence[float]] +from ._utils import ( + _check_padding_arg, + _check_padding_mode_arg, + _check_sequence_input, + _setup_angle, + _setup_fill_arg, + _setup_size, + DType, + FillType, + has_all, + has_any, + query_bounding_box, + query_chw, +) class RandomHorizontalFlip(_RandomApplyTransform): @@ -201,40 +209,6 @@ def forward(self, *inputs: Any) -> Any: return super().forward(*inputs) -def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: - if isinstance(fill, dict): - for key, value in fill.items(): - # Check key for type - _check_fill_arg(value) - else: - if not isinstance(fill, (numbers.Number, tuple, list)): - raise TypeError("Got inappropriate fill arg") - - -def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]: - _check_fill_arg(fill) - - if isinstance(fill, dict): - return fill - - return defaultdict(lambda: fill) # type: ignore[arg-type, return-value] - - -def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: - if not isinstance(padding, (numbers.Number, tuple, list)): - raise TypeError("Got inappropriate padding arg") - - if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]: - raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple") - - -# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums) -# https://github.com/pytorch/vision/issues/6250 -def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None: - if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: - raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") - - class Pad(Transform): def __init__( self, diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index eb48b134190..09829629e03 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,6 +1,11 @@ -from typing import Any, Callable, Tuple, Type, Union +import numbers +from collections import defaultdict + +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union import PIL.Image + +import torch from torch.utils._pytree import tree_flatten from torchvision._utils import sequence_to_str from torchvision.prototype import features @@ -8,6 +13,49 @@ from torchvision.prototype.transforms.functional._meta import get_chw from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 +from typing_extensions import Literal + + +# Type shortcuts: +DType = Union[torch.Tensor, PIL.Image.Image, features._Feature] +FillType = Union[int, float, Sequence[int], Sequence[float]] + + +def _check_fill_arg(fill: Optional[Union[FillType, Dict[Type, FillType]]]) -> None: + if isinstance(fill, dict): + for key, value in fill.items(): + # Check key for type + _check_fill_arg(value) + else: + if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate fill arg") + + +def _setup_fill_arg( + fill: Optional[Union[FillType, Dict[Type, FillType]]] +) -> Union[Dict[Type, FillType], Dict[Type, None]]: + _check_fill_arg(fill) + + if isinstance(fill, dict): + return fill + + return defaultdict(lambda: fill) # type: ignore[return-value] + + +def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: + if not isinstance(padding, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate padding arg") + + if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]: + raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple") + + +# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums) +# https://github.com/pytorch/vision/issues/6250 +def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None: + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") + def query_bounding_box(sample: Any) -> features.BoundingBox: flat_sample, _ = tree_flatten(sample)