diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py
index 220a793ac9d..e56da8bbacc 100644
--- a/test/prototype_common_utils.py
+++ b/test/prototype_common_utils.py
@@ -22,7 +22,7 @@
     UnsupportedInputs,
 )
 from torchvision.prototype import features
-from torchvision.prototype.transforms.functional import convert_image_dtype, to_image_tensor
+from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
 from torchvision.transforms.functional_tensor import _max_value as get_max_value
 
 __all__ = [
@@ -97,8 +97,8 @@ def _process_inputs(self, actual, expected, *, id, allow_subclasses):
     def _equalize_attributes(self, actual, expected):
         if actual.dtype != expected.dtype:
             dtype = torch.promote_types(actual.dtype, expected.dtype)
-            actual = convert_image_dtype(actual, dtype)
-            expected = convert_image_dtype(expected, dtype)
+            actual = convert_dtype_image_tensor(actual, dtype)
+            expected = convert_dtype_image_tensor(expected, dtype)
 
         return super()._equalize_attributes(actual, expected)
 
diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py
index 82173907c6f..fab4b3583d5 100644
--- a/test/prototype_transforms_dispatcher_infos.py
+++ b/test/prototype_transforms_dispatcher_infos.py
@@ -416,4 +416,14 @@ def xfail_all_tests(*, reason, condition):
             skip_dispatch_feature,
         ],
     ),
+    DispatcherInfo(
+        F.convert_dtype,
+        kernels={
+            features.Image: F.convert_dtype_image_tensor,
+            features.Video: F.convert_dtype_video,
+        },
+        test_marks=[
+            skip_dispatch_feature,
+        ],
+    ),
 ]
diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py
index aab01904026..77a77444ba7 100644
--- a/test/prototype_transforms_kernel_infos.py
+++ b/test/prototype_transforms_kernel_infos.py
@@ -1979,7 +1979,7 @@ def sample_inputs_normalize_video():
     )
 
 
-def sample_inputs_convert_image_dtype():
+def sample_inputs_convert_dtype_image_tensor():
     for input_dtype, output_dtype in itertools.product(
         [torch.uint8, torch.int64, torch.float32, torch.float64], repeat=2
     ):
@@ -1992,10 +1992,8 @@ def sample_inputs_convert_image_dtype():
     ):
         yield ArgsKwargs(image_loader, dtype=output_dtype)
 
-    yield ArgsKwargs(make_image_loader(color_space=features.ColorSpace.RGB), dtype=torch.uint8)
-
 
-def reference_convert_image_dtype(image, dtype=torch.float):
+def reference_convert_dtype_image_tensor(image, dtype=torch.float):
     input_dtype = image.dtype
     output_dtype = dtype
 
@@ -2026,7 +2024,7 @@ def fn(value):
     return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype)
 
 
-def reference_inputs_convert_image_dtype():
+def reference_inputs_convert_dtype_image_tensor():
     for input_dtype, output_dtype in itertools.product(
         [
             torch.uint8,
@@ -2055,24 +2053,32 @@ def reference_inputs_convert_image_dtype():
         yield ArgsKwargs(image, dtype=output_dtype)
 
 
+def sample_inputs_convert_dtype_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader)
+
+
+_common_convert_dtype_marks = [
+    TestMark(
+        ("TestKernels", "test_dtype_and_device_consistency"),
+        pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
+        condition=lambda args_kwargs: args_kwargs.args[0].dtype != args_kwargs.kwargs.get("dtype", torch.float32),
+    ),
+    TestMark(
+        ("TestKernels", "test_scripted_vs_eager"),
+        pytest.mark.filterwarnings(f"ignore:{re.escape('operator() profile_node %')}:UserWarning"),
+    ),
+]
+
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.convert_image_dtype,
-            sample_inputs_fn=sample_inputs_convert_image_dtype,
-            reference_fn=reference_convert_image_dtype,
-            reference_inputs_fn=reference_inputs_convert_image_dtype,
+            F.convert_dtype_image_tensor,
+            sample_inputs_fn=sample_inputs_convert_dtype_image_tensor,
+            reference_fn=reference_convert_dtype_image_tensor,
+            reference_inputs_fn=reference_inputs_convert_dtype_image_tensor,
             test_marks=[
-                TestMark(
-                    ("TestKernels", "test_scripted_vs_eager"),
-                    pytest.mark.filterwarnings(f"ignore:{re.escape('operator() profile_node %41')}:UserWarning"),
-                ),
-                TestMark(
-                    ("TestKernels", "test_dtype_and_device_consistency"),
-                    pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
-                    condition=lambda args_kwargs: args_kwargs.args[0].dtype
-                    != args_kwargs.kwargs.get("dtype", torch.float32),
-                ),
+                *_common_convert_dtype_marks,
                 TestMark(
                     ("TestKernels", "test_against_reference"),
                     pytest.mark.xfail(reason="Conversion overflows"),
@@ -2080,10 +2086,6 @@ def reference_inputs_convert_image_dtype():
                             args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
                             and not args_kwargs.kwargs["dtype"].is_floating_point
                         )
-                        or (
-                            args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
-                            and args_kwargs.kwargs["dtype"] == torch.int64
-                        )
                         or (
                             args_kwargs.args[0].dtype in {torch.int32, torch.int64}
                             and args_kwargs.kwargs["dtype"] == torch.float16
@@ -2091,5 +2093,10 @@ def reference_inputs_convert_image_dtype():
                 ),
             ],
         ),
+        KernelInfo(
+            F.convert_dtype_video,
+            sample_inputs_fn=sample_inputs_convert_dtype_video,
+            test_marks=_common_convert_dtype_marks,
+        ),
     ]
 )
diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 351430e1c9d..1a4a098dbab 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -92,7 +92,7 @@ class TestSmoke:
         transforms.RandomErasing(p=1.0),
         transforms.Resize([16, 16]),
         transforms.CenterCrop([16, 16]),
-        transforms.ConvertImageDtype(),
+        transforms.ConvertDtype(),
         transforms.RandomHorizontalFlip(),
         transforms.Pad(5),
         transforms.RandomZoomOut(),
diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py
index 7d72463260e..b0022baaa37 100644
--- a/test/test_prototype_transforms_consistency.py
+++ b/test/test_prototype_transforms_consistency.py
@@ -153,7 +153,7 @@ def __init__(
         ),
     ),
     ConsistencyConfig(
-        prototype_transforms.ConvertImageDtype,
+        prototype_transforms.ConvertDtype,
         legacy_transforms.ConvertImageDtype,
         [
             ArgsKwargs(torch.float16),
diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index dc867f8ffa4..cad5c204af8 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -307,6 +307,7 @@ def test_dispatch_feature(self, info, args_kwargs, spy_on):
             (F.get_image_num_channels, F.get_num_channels),
             (F.to_pil_image, F.to_image_pil),
             (F.elastic_transform, F.elastic),
+            (F.convert_image_dtype, F.convert_dtype_image_tensor),
         ]
     ],
 )
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index 5bf5a12cd78..099c30c9c6c 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -39,7 +39,7 @@
     ScaleJitter,
     TenCrop,
 )
-from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
+from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertDtype, ConvertImageDtype
 from ._misc import (
     GaussianBlur,
     Identity,
diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py
index 6e5a8139704..4a85175e901 100644
--- a/torchvision/prototype/transforms/_meta.py
+++ b/torchvision/prototype/transforms/_meta.py
@@ -25,7 +25,7 @@ def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> feat
         return features.BoundingBox.wrap_like(inpt, output, format=params["format"])
 
 
-class ConvertImageDtype(Transform):
+class ConvertDtype(Transform):
     _transformed_types = (features.is_simple_tensor, features.Image, features.Video)
 
     def __init__(self, dtype: torch.dtype = torch.float32) -> None:
@@ -35,12 +35,12 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None:
     def _transform(
         self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
     ) -> Union[features.TensorImageType, features.TensorVideoType]:
-        # TODO: the `inpt.as_subclass(torch.Tensor)` call can be removed as soon as we have a proper dispatcher that
-        # handles this. See https://github.com/pytorch/vision/pull/6783 for details.
-        output = F.convert_image_dtype(inpt.as_subclass(torch.Tensor), dtype=self.dtype)
-        return (
-            output if features.is_simple_tensor(inpt) else type(inpt).wrap_like(inpt, output)  # type: ignore[attr-defined]
-        )
+        return F.convert_dtype(inpt, self.dtype)
+
+
+# We changed the name to align it with the new naming scheme. Still, `ConvertImageDtype` is
+# prevalent and well understood. Thus, we just alias it without deprecating the old name.
+ConvertImageDtype = ConvertDtype
 
 
 class ConvertColorSpace(Transform):
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index fb72e7b57a3..7e520c98691 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -8,6 +8,10 @@
     convert_color_space_image_pil,
     convert_color_space_video,
     convert_color_space,
+    convert_dtype_image_tensor,
+    convert_dtype,
+    convert_dtype_video,
+    convert_image_dtype,
     get_dimensions_image_tensor,
     get_dimensions_image_pil,
     get_dimensions,
@@ -162,7 +166,6 @@
     normalize_video,
 )
 from ._type_conversion import (
-    convert_image_dtype,
     decode_image_with_pil,
     decode_video_with_av,
     pil_to_tensor,
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index 57155656212..674cba84677 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -285,3 +285,99 @@ def convert_color_space(
         return features.Video.wrap_like(inpt, output, color_space=color_space)
     else:
         return convert_color_space_image_pil(inpt, color_space)
+
+
+def _num_value_bits(dtype: torch.dtype) -> int:
+    if dtype == torch.uint8:
+        return 8
+    elif dtype == torch.int8:
+        return 7
+    elif dtype == torch.int16:
+        return 15
+    elif dtype == torch.int32:
+        return 31
+    elif dtype == torch.int64:
+        return 63
+    else:
+        raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
+
+
+def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
+    if image.dtype == dtype:
+        return image
+
+    float_input = image.is_floating_point()
+    if torch.jit.is_scripting():
+        # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
+        float_output = torch.tensor(0, dtype=dtype).is_floating_point()
+    else:
+        float_output = dtype.is_floating_point
+
+    if float_input:
+        # float to float
+        if float_output:
+            return image.to(dtype)
+
+        # float to int
+        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
+            image.dtype == torch.float64 and dtype == torch.int64
+        ):
+            raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")
+
+        # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
+        # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
+        # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
+        # for a detailed analysis.
+        # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
+        # Instead, we can also multiply by the maximum value plus something close to `1`. See
+        # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
+        eps = 1e-3
+        max_value = float(_FT._max_value(dtype))
+        # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
+        # discrete set `{0, 1}`.
+        return image.mul(max_value + 1.0 - eps).to(dtype)
+    else:
+        # int to float
+        if float_output:
+            return image.to(dtype).div_(_FT._max_value(image.dtype))
+
+        # int to int
+        num_value_bits_input = _num_value_bits(image.dtype)
+        num_value_bits_output = _num_value_bits(dtype)
+
+        if num_value_bits_input > num_value_bits_output:
+            return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
+        else:
+            # The bitshift kernel is not vectorized
+            # https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322
+            # This results in the multiplication actually being faster.
+            # TODO: If the bitshift kernel is optimized in core, replace the computation below with
+            # `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)`
+            max_value_input = float(_FT._max_value(image.dtype))
+            max_value_output = float(_FT._max_value(dtype))
+            factor = int((max_value_output + 1) // (max_value_input + 1))
+            return image.to(dtype).mul_(factor)
+
+
+# We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is
+# prevalent and well understood. Thus, we just alias it without deprecating the old name.
+convert_image_dtype = convert_dtype_image_tensor
+
+
+def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
+    return convert_dtype_image_tensor(video, dtype)
+
+
+def convert_dtype(
+    inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], dtype: torch.dtype = torch.float
+) -> torch.Tensor:
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
+    ):
+        return convert_dtype_image_tensor(inpt, dtype)
+    elif isinstance(inpt, features.Image):
+        output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
+        return features.Image.wrap_like(inpt, output)
+    else:  # isinstance(inpt, features.Video):
+        output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype)
+        return features.Video.wrap_like(inpt, output)
diff --git a/torchvision/prototype/transforms/functional/_type_conversion.py b/torchvision/prototype/transforms/functional/_type_conversion.py
index a57fbc65536..712ca62ecb5 100644
--- a/torchvision/prototype/transforms/functional/_type_conversion.py
+++ b/torchvision/prototype/transforms/functional/_type_conversion.py
@@ -7,7 +7,7 @@
 from torchvision.io.video import read_video
 from torchvision.prototype import features
 from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer
-from torchvision.transforms import functional as _F, functional_tensor as _FT
+from torchvision.transforms import functional as _F
 
 
 @torch.jit.unused
@@ -41,78 +41,3 @@ def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) ->
 # We changed the names to align them with the new naming scheme. Still, `to_pil_image` is
 # prevalent and well understood. Thus, we just alias it without deprecating the old name.
 to_pil_image = to_image_pil
-
-
-def _num_value_bits(dtype: torch.dtype) -> int:
-    if dtype == torch.uint8:
-        return 8
-    elif dtype == torch.int8:
-        return 7
-    elif dtype == torch.int16:
-        return 15
-    elif dtype == torch.int32:
-        return 31
-    elif dtype == torch.int64:
-        return 63
-    else:
-        raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
-
-
-def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
-    if not isinstance(image, torch.Tensor):
-        raise TypeError("Input img should be Tensor Image")
-
-    if image.dtype == dtype:
-        return image
-
-    float_input = image.is_floating_point()
-    if torch.jit.is_scripting():
-        # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
-        float_output = torch.tensor(0, dtype=dtype).is_floating_point()
-    else:
-        float_output = dtype.is_floating_point
-
-    if float_input:
-        # float to float
-        if float_output:
-            return image.to(dtype)
-
-        # float to int
-        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
-            image.dtype == torch.float64 and dtype == torch.int64
-        ):
-            raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")
-
-        # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
-        # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
-        # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
-        # for a detailed analysis.
-        # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
-        # Instead, we can also multiply by the maximum value plus something close to `1`. See
-        # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
-        eps = 1e-3
-        max_value = float(_FT._max_value(dtype))
-        # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
-        # discrete set `{0, 1}`.
-        return image.mul(max_value + 1.0 - eps).to(dtype)
-    else:
-        # int to float
-        if float_output:
-            return image.to(dtype).div_(_FT._max_value(image.dtype))
-
-        # int to int
-        num_value_bits_input = _num_value_bits(image.dtype)
-        num_value_bits_output = _num_value_bits(dtype)
-
-        if num_value_bits_input > num_value_bits_output:
-            return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
-        else:
-            # The bitshift kernel is not vectorized
-            # https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322
-            # This results in the multiplication actually being faster.
-            # TODO: If the bitshift kernel is optimized in core, replace the computation below with
-            # `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)`
-            max_value_input = float(_FT._max_value(dtype))
-            max_value_output = float(_FT._max_value(image.dtype))
-            factor = int((max_value_input + 1) // (max_value_output + 1))
-            return image.to(dtype).mul_(factor)
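
For reviewers who want to sanity-check the arithmetic that `convert_dtype_image_tensor` relies on outside the test suite, here is a minimal standalone sketch. It is not part of the patch and only assumes `torch`: it reproduces the float-to-int `eps` trick (multiply by `max_value + 1 - eps` so that an input of exactly `1.0` still truncates to the dtype maximum) and the int-to-int upscale factor (the ratio of the two value ranges, equivalent to a left shift by the difference in value bits).

```python
import torch

# float32 -> uint8: a plain `mul(255)` maps values just below 1.0 to 254,
# while `mul(255 + 1 - eps)` maps them to 255, matching the kernel's eps trick.
eps = 1e-3
almost_one = torch.tensor([0.9999])
assert almost_one.mul(255).to(torch.uint8).item() == 254
assert almost_one.mul(255 + 1.0 - eps).to(torch.uint8).item() == 255

# uint8 -> int16: 8 value bits grow to 15 value bits, so the factor is
# (32767 + 1) // (255 + 1) == 128 == 2 ** (15 - 8), i.e. a left shift by 7 bits.
factor = int((32767 + 1) // (255 + 1))
pixel = torch.tensor([255], dtype=torch.uint8)
assert pixel.to(torch.int16).mul_(factor).item() == 255 << 7  # 32640
```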
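And a quick usage sketch of the renamed surface. Again not part of the patch; it assumes the prototype `features.Image`/`features.Video` constructors accept a raw tensor at this revision, while the dispatch and aliasing behavior itself is exactly what the diff above implements.

```python
import torch
from torchvision.prototype import features, transforms
from torchvision.prototype.transforms import functional as F

image = features.Image(torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8))
video = features.Video(torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8))

# The dispatcher routes to convert_dtype_image_tensor / convert_dtype_video
# and re-wraps the output in the input's feature type.
assert isinstance(F.convert_dtype(image, torch.float32), features.Image)
assert isinstance(F.convert_dtype(video, torch.float32), features.Video)

# Plain tensors skip the wrapping and go straight to the image kernel.
assert F.convert_dtype(torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)).dtype == torch.float32

# The old names keep working as non-deprecated aliases.
assert transforms.ConvertImageDtype is transforms.ConvertDtype
assert F.convert_image_dtype is F.convert_dtype_image_tensor
```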