From 6ef66cab49561b54fd42fead1950822a54c33102 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 21 Sep 2022 13:43:11 +0200 Subject: [PATCH 1/4] [proto] Fixed fill type in AA --- .../prototype/transforms/_auto_augment.py | 21 ++++---- .../prototype/transforms/_deprecated.py | 5 +- torchvision/prototype/transforms/_geometry.py | 54 +++++-------------- torchvision/prototype/transforms/_utils.py | 48 ++++++++++++++++- 4 files changed, 71 insertions(+), 57 deletions(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index e9523fdd22e..bb57e5cad60 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -1,5 +1,4 @@ import math -import numbers from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union import PIL.Image @@ -10,7 +9,7 @@ from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform from torchvision.prototype.transforms.functional._meta import get_chw -from ._utils import _isinstance +from ._utils import _isinstance, _setup_fill_arg, FillType K = TypeVar("K") V = TypeVar("V") @@ -21,14 +20,11 @@ def __init__( self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Union[FillType, Dict[Type, FillType]] = 0, ) -> None: super().__init__() self.interpolation = interpolation - - if not isinstance(fill, (numbers.Number, tuple, list)): - raise TypeError("Got inappropriate fill arg") - self.fill = fill + self.fill = _setup_fill_arg(fill) def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: keys = tuple(dct.keys()) @@ -63,18 +59,19 @@ def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any: def _apply_image_transform( self, - image: Any, + image: Union[torch.Tensor, PIL.Image.Image, features.Image], transform_id: str, magnitude: float, interpolation: InterpolationMode, - fill: Union[int, float, Sequence[int], Sequence[float]], + fill: Dict[Type, FillType], ) -> Any: + fill_ = fill[type(image)] # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 # So, we have to put fill as None if fill == 0 # This is due to BC with stable API which has fill = None by default - fill_ = F._geometry._convert_fill_arg(fill) - if isinstance(fill, int) and fill == 0: + fill_ = F._geometry._convert_fill_arg(fill_) + if isinstance(fill_, int) and fill_ == 0: fill_ = None if transform_id == "Identity": @@ -286,7 +283,7 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] id, image = self._extract_image(sample) - num_channels, height, width = get_chw(image) + _, height, width = get_chw(image) policy = self._policies[int(torch.randint(len(self._policies), ()))] diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py index e6af6ba09b6..9b28551a048 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -11,10 +11,7 @@ from typing_extensions import Literal from ._transform import _RandomApplyTransform -from ._utils import query_chw - - -DType = Union[torch.Tensor, PIL.Image.Image, features._Feature] +from ._utils import DType, query_chw class ToTensor(Transform): diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index fa2baa1f77f..a8f0b09765b 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -1,7 +1,6 @@ import math import numbers import warnings -from collections import defaultdict from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type, Union import PIL.Image @@ -14,11 +13,20 @@ from typing_extensions import Literal from ._transform import _RandomApplyTransform -from ._utils import _check_sequence_input, _setup_angle, _setup_size, has_all, has_any, query_bounding_box, query_chw - - -DType = Union[torch.Tensor, PIL.Image.Image, features._Feature] -FillType = Union[int, float, Sequence[int], Sequence[float]] +from ._utils import ( + _check_padding_arg, + _check_padding_mode_arg, + _check_sequence_input, + _setup_angle, + _setup_fill_arg, + _setup_size, + DType, + FillType, + has_all, + has_any, + query_bounding_box, + query_chw, +) class RandomHorizontalFlip(_RandomApplyTransform): @@ -201,40 +209,6 @@ def forward(self, *inputs: Any) -> Any: return super().forward(*inputs) -def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: - if isinstance(fill, dict): - for key, value in fill.items(): - # Check key for type - _check_fill_arg(value) - else: - if not isinstance(fill, (numbers.Number, tuple, list)): - raise TypeError("Got inappropriate fill arg") - - -def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]: - _check_fill_arg(fill) - - if isinstance(fill, dict): - return fill - - return defaultdict(lambda: fill) # type: ignore[arg-type, return-value] - - -def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: - if not isinstance(padding, (numbers.Number, tuple, list)): - raise TypeError("Got inappropriate padding arg") - - if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]: - raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple") - - -# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums) -# https://github.com/pytorch/vision/issues/6250 -def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None: - if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: - raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") - - class Pad(Transform): def __init__( self, diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index eb48b134190..55c7cbabcdf 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,6 +1,11 @@ -from typing import Any, Callable, Tuple, Type, Union +import numbers +from collections import defaultdict + +from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union import PIL.Image + +import torch from torch.utils._pytree import tree_flatten from torchvision._utils import sequence_to_str from torchvision.prototype import features @@ -8,6 +13,47 @@ from torchvision.prototype.transforms.functional._meta import get_chw from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 +from typing_extensions import Literal + + +# Type shortcuts: +DType = Union[torch.Tensor, PIL.Image.Image, features._Feature] +FillType = Union[int, float, Sequence[int], Sequence[float]] + + +def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: + if isinstance(fill, dict): + for key, value in fill.items(): + # Check key for type + _check_fill_arg(value) + else: + if not isinstance(fill, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate fill arg") + + +def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]: + _check_fill_arg(fill) + + if isinstance(fill, dict): + return fill + + return defaultdict(lambda: fill) # type: ignore[arg-type, return-value] + + +def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: + if not isinstance(padding, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate padding arg") + + if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]: + raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple") + + +# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums) +# https://github.com/pytorch/vision/issues/6250 +def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None: + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") + def query_bounding_box(sample: Any) -> features.BoundingBox: flat_sample, _ = tree_flatten(sample) From 36c4d5cea501f939e853c42ed531e249cf8b290a Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 21 Sep 2022 14:03:39 +0200 Subject: [PATCH 2/4] Fixed missed typehints --- torchvision/prototype/transforms/_auto_augment.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index bb57e5cad60..a0b1f5f2f79 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -1,5 +1,5 @@ import math -from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, TypeVar, Union import PIL.Image import torch @@ -183,7 +183,7 @@ def __init__( self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Union[FillType, Dict[Type, FillType]] = 0, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.policy = policy @@ -343,7 +343,7 @@ def __init__( magnitude: int = 9, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Union[FillType, Dict[Type, FillType]] = 0, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.num_ops = num_ops @@ -399,7 +399,7 @@ def __init__( self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Union[FillType, Dict[Type, FillType]] = 0, ): super().__init__(interpolation=interpolation, fill=fill) self.num_magnitude_bins = num_magnitude_bins @@ -459,7 +459,7 @@ def __init__( alpha: float = 1.0, all_ops: bool = True, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Union[FillType, Dict[Type, FillType]] = 0, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self._PARAMETER_MAX = 10 From d7486daa7dba28452cf40ca0abd557f1dd5c5399 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 21 Sep 2022 15:58:11 +0200 Subject: [PATCH 3/4] Set fill as None by default --- .../prototype/transforms/_auto_augment.py | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index a0b1f5f2f79..704e2356e9f 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -20,11 +20,11 @@ def __init__( self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[FillType, Dict[Type, FillType]] = 0, + fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, ) -> None: super().__init__() self.interpolation = interpolation - self.fill = _setup_fill_arg(fill) + self.fill = _setup_fill_arg(fill) if fill is not None else fill def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: keys = tuple(dct.keys()) @@ -63,16 +63,10 @@ def _apply_image_transform( transform_id: str, magnitude: float, interpolation: InterpolationMode, - fill: Dict[Type, FillType], + fill: Optional[Dict[Type, FillType]], ) -> Any: - - fill_ = fill[type(image)] - # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 - # So, we have to put fill as None if fill == 0 - # This is due to BC with stable API which has fill = None by default + fill_ = fill[type(image)] if fill is not None else fill fill_ = F._geometry._convert_fill_arg(fill_) - if isinstance(fill_, int) and fill_ == 0: - fill_ = None if transform_id == "Identity": return image @@ -183,7 +177,7 @@ def __init__( self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[FillType, Dict[Type, FillType]] = 0, + fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.policy = policy @@ -343,7 +337,7 @@ def __init__( magnitude: int = 9, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[FillType, Dict[Type, FillType]] = 0, + fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.num_ops = num_ops @@ -399,7 +393,7 @@ def __init__( self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[FillType, Dict[Type, FillType]] = 0, + fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, ): super().__init__(interpolation=interpolation, fill=fill) self.num_magnitude_bins = num_magnitude_bins @@ -459,7 +453,7 @@ def __init__( alpha: float = 1.0, all_ops: bool = True, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Union[FillType, Dict[Type, FillType]] = 0, + fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self._PARAMETER_MAX = 10 From c8390ff81672d4f3b60810ae8bc4c1ebf88dc1f9 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 21 Sep 2022 19:07:47 +0200 Subject: [PATCH 4/4] Another fix --- torchvision/prototype/transforms/_auto_augment.py | 6 +++--- torchvision/prototype/transforms/_utils.py | 12 +++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 704e2356e9f..0d14a434c1e 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -24,7 +24,7 @@ def __init__( ) -> None: super().__init__() self.interpolation = interpolation - self.fill = _setup_fill_arg(fill) if fill is not None else fill + self.fill = _setup_fill_arg(fill) def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: keys = tuple(dct.keys()) @@ -63,9 +63,9 @@ def _apply_image_transform( transform_id: str, magnitude: float, interpolation: InterpolationMode, - fill: Optional[Dict[Type, FillType]], + fill: Union[Dict[Type, FillType], Dict[Type, None]], ) -> Any: - fill_ = fill[type(image)] if fill is not None else fill + fill_ = fill[type(image)] fill_ = F._geometry._convert_fill_arg(fill_) if transform_id == "Identity": diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 55c7cbabcdf..09829629e03 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,7 +1,7 @@ import numbers from collections import defaultdict -from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union import PIL.Image @@ -21,23 +21,25 @@ FillType = Union[int, float, Sequence[int], Sequence[float]] -def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: +def _check_fill_arg(fill: Optional[Union[FillType, Dict[Type, FillType]]]) -> None: if isinstance(fill, dict): for key, value in fill.items(): # Check key for type _check_fill_arg(value) else: - if not isinstance(fill, (numbers.Number, tuple, list)): + if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)): raise TypeError("Got inappropriate fill arg") -def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]: +def _setup_fill_arg( + fill: Optional[Union[FillType, Dict[Type, FillType]]] +) -> Union[Dict[Type, FillType], Dict[Type, None]]: _check_fill_arg(fill) if isinstance(fill, dict): return fill - return defaultdict(lambda: fill) # type: ignore[arg-type, return-value] + return defaultdict(lambda: fill) # type: ignore[return-value] def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: