From b7a5d039bd9f9c6b4673bc6f32bcebd3dfe48c52 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 18 Oct 2022 07:52:34 +0200 Subject: [PATCH 1/7] add KernelInfo --- test/prototype_transforms_kernel_infos.py | 33 +++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index c455caa6b7a..7c227b79af8 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -1940,3 +1940,36 @@ def sample_inputs_normalize_video(): ), ] ) + + +def sample_inputs_convert_image_dtype(): + for input_dtype, output_dtype in itertools.product( + [torch.uint8, torch.int64, torch.float32, torch.float64], repeat=2 + ): + if input_dtype.is_floating_point and output_dtype == torch.int64: + continue + + for image_loader in make_image_loaders( + sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[input_dtype] + ): + yield ArgsKwargs(image_loader, dtype=output_dtype) + + yield ArgsKwargs(make_image_loader(color_space=features.ColorSpace.RGB), dtype=torch.uint8) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.convert_image_dtype, + sample_inputs_fn=sample_inputs_convert_image_dtype, + test_marks=[ + TestMark( + ("TestKernels", "test_dtype_and_device_consistency"), + pytest.mark.skip(reason="`convert_image_dtype` converts the dtype"), + condition=lambda args_kwargs: args_kwargs.args[0].dtype + != args_kwargs.kwargs.get("dtype", torch.float32), + ) + ], + ), + ] +) From 18e61d20c719f46efe993af3ad7c776b2bb9755f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 18 Oct 2022 08:51:49 +0200 Subject: [PATCH 2/7] split dtype and device consistency tests --- test/prototype_transforms_kernel_infos.py | 5 +++-- test/test_prototype_transforms_functional.py | 13 ++++++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 7c227b79af8..dd5931927a5 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -122,7 +122,8 @@ def xfail_all_tests(*, reason, condition): "test_batched_vs_single", "test_no_inplace", "test_cuda_vs_cpu", - "test_dtype_and_device_consistency", + "test_dtype_consistency", + "test_device_consistency", ] ] @@ -1964,7 +1965,7 @@ def sample_inputs_convert_image_dtype(): sample_inputs_fn=sample_inputs_convert_image_dtype, test_marks=[ TestMark( - ("TestKernels", "test_dtype_and_device_consistency"), + ("TestKernels", "test_dtype_consistency"), pytest.mark.skip(reason="`convert_image_dtype` converts the dtype"), condition=lambda args_kwargs: args_kwargs.args[0].dtype != args_kwargs.kwargs.get("dtype", torch.float32), diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 34291611d8d..56430fcafac 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -175,7 +175,7 @@ def test_cuda_vs_cpu(self, info, args_kwargs): @sample_inputs @pytest.mark.parametrize("device", cpu_and_gpu()) - def test_dtype_and_device_consistency(self, info, args_kwargs, device): + def test_dtype_consistency(self, info, args_kwargs, device): (input, *other_args), kwargs = args_kwargs.load(device) output = info.kernel(input, *other_args, **kwargs) @@ -184,6 +184,17 @@ def test_dtype_and_device_consistency(self, info, args_kwargs, device): output, *_ = output assert output.dtype == input.dtype + + @sample_inputs + @pytest.mark.parametrize("device", cpu_and_gpu()) + def test_device_consistency(self, info, args_kwargs, device): + (input, *other_args), kwargs = args_kwargs.load(device) + + output = info.kernel(input, *other_args, **kwargs) + # Most kernels just return a tensor, but some also return some additional metadata + if not isinstance(output, torch.Tensor): + output, *_ = output + assert output.device == input.device @reference_inputs From 55ade917161a6c338eb6d8f25e35b449d8f8dbe0 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 18 Oct 2022 09:07:33 +0200 Subject: [PATCH 3/7] add proper support for video --- test/test_prototype_transforms.py | 2 +- test/test_prototype_transforms_consistency.py | 2 +- torchvision/prototype/features/_image.py | 3 +++ torchvision/prototype/features/_video.py | 3 +++ torchvision/prototype/transforms/__init__.py | 2 +- torchvision/prototype/transforms/_meta.py | 10 ++++---- .../transforms/functional/__init__.py | 4 +++- .../prototype/transforms/functional/_meta.py | 24 +++++++++++++++++++ .../transforms/functional/_type_conversion.py | 2 -- 9 files changed, 40 insertions(+), 12 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 5928e6718c1..d41e524d2aa 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -90,7 +90,7 @@ class TestSmoke: transforms.RandomErasing(p=1.0), transforms.Resize([16, 16]), transforms.CenterCrop([16, 16]), - transforms.ConvertImageDtype(), + transforms.ConvertDtype(), transforms.RandomHorizontalFlip(), transforms.Pad(5), transforms.RandomZoomOut(), diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 7d2f1d735ea..c68828cf44c 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -152,7 +152,7 @@ def __init__( ), ), ConsistencyConfig( - prototype_transforms.ConvertImageDtype, + prototype_transforms.ConvertDtype, legacy_transforms.ConvertImageDtype, [ ArgsKwargs(torch.float16), diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index e9128b94be0..117a8cd560b 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -122,6 +122,9 @@ def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) color_space=color_space, ) + def to_dtype(self, dtype: torch.dtype = torch.float, copy: bool = True) -> Image: + return Image.wrap_like(self, self._F.convert_dtype_image_tensor(self, dtype, copy=copy)) + def horizontal_flip(self) -> Image: output = self._F.horizontal_flip_image_tensor(self) return Image.wrap_like(self, output) diff --git a/torchvision/prototype/features/_video.py b/torchvision/prototype/features/_video.py index 26f97549ac5..82517b20c39 100644 --- a/torchvision/prototype/features/_video.py +++ b/torchvision/prototype/features/_video.py @@ -78,6 +78,9 @@ def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) color_space=color_space, ) + def to_dtype(self, dtype: torch.dtype = torch.float, copy: bool = True) -> Video: + return Video.wrap_like(self, self._F.convert_dtype_video(self, dtype, copy=copy)) + def horizontal_flip(self) -> Video: output = self._F.horizontal_flip_video(self) return Video.wrap_like(self, output) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 5324db63496..b3f8bba2cbc 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -39,7 +39,7 @@ ScaleJitter, TenCrop, ) -from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype +from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertDtype from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, RemoveSmallBoundingBoxes, ToDtype from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index bdfe8b47a89..63919c3f814 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -21,20 +21,18 @@ def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> feat return features.BoundingBox.wrap_like(inpt, output, format=params["format"]) -class ConvertImageDtype(Transform): +class ConvertDtype(Transform): _transformed_types = (features.is_simple_tensor, features.Image, features.Video) - def __init__(self, dtype: torch.dtype = torch.float32) -> None: + def __init__(self, dtype: torch.dtype = torch.float32, copy: bool = True) -> None: super().__init__() self.dtype = dtype + self.copy = copy def _transform( self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any] ) -> Union[features.TensorImageType, features.TensorVideoType]: - output = F.convert_image_dtype(inpt, dtype=self.dtype) - return ( - output if features.is_simple_tensor(inpt) else type(inpt).wrap_like(inpt, output) # type: ignore[attr-defined] - ) + return F.convert_dtype(inpt, self.dtype, copy=self.copy) class ConvertColorSpace(Transform): diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index fb72e7b57a3..98c114b7732 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -24,6 +24,9 @@ get_spatial_size_mask, get_spatial_size_video, get_spatial_size, + convert_dtype_image_tensor, + convert_dtype, + convert_dtype_video, ) # usort: skip from ._augment import erase, erase_image_pil, erase_image_tensor, erase_video @@ -162,7 +165,6 @@ normalize_video, ) from ._type_conversion import ( - convert_image_dtype, decode_image_with_pil, decode_video_with_av, pil_to_tensor, diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 2903d73ce95..71f4740c08c 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -282,3 +282,27 @@ def convert_color_space( return inpt.to_color_space(color_space, copy=copy) else: return convert_color_space_image_pil(inpt, color_space, copy=copy) + + +def convert_dtype_image_tensor( + image: torch.Tensor, dtype: torch.dtype = torch.float, copy: bool = True +) -> torch.Tensor: + if copy and image.dtype == dtype: + return image.clone() + + return _FT.convert_image_dtype(image, dtype) + + +def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, copy: bool = True) -> torch.Tensor: + return convert_dtype_image_tensor(video, dtype, copy=copy) + + +def convert_dtype( + inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], dtype: torch.dtype = torch.float, copy: bool = True +) -> torch.Tensor: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)) + ): + return convert_dtype_image_tensor(inpt, dtype, copy=copy) + else: # isinstance(inpt, (features.Image, features.Video)): + return inpt.to_dtype(dtype, copy=copy) diff --git a/torchvision/prototype/transforms/functional/_type_conversion.py b/torchvision/prototype/transforms/functional/_type_conversion.py index b171716ae87..86ff256cdb5 100644 --- a/torchvision/prototype/transforms/functional/_type_conversion.py +++ b/torchvision/prototype/transforms/functional/_type_conversion.py @@ -41,5 +41,3 @@ def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> # We changed the names to align them with the new naming scheme. Still, `to_pil_image` is # prevalent and well understood. Thus, we just alias it without deprecating the old name. to_pil_image = to_image_pil - -convert_image_dtype = _F.convert_image_dtype From a28af5a6ba734d6d3610cf57e587617ca35b3ad9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 18 Oct 2022 09:37:22 +0200 Subject: [PATCH 4/7] fix tests and add DispatcherInfo --- test/prototype_common_utils.py | 6 +-- test/prototype_transforms_dispatcher_infos.py | 11 ++++++ test/prototype_transforms_kernel_infos.py | 39 +++++++++++++------ test/test_prototype_transforms_functional.py | 6 +-- 4 files changed, 43 insertions(+), 19 deletions(-) diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py index 220a793ac9d..e56da8bbacc 100644 --- a/test/prototype_common_utils.py +++ b/test/prototype_common_utils.py @@ -22,7 +22,7 @@ UnsupportedInputs, ) from torchvision.prototype import features -from torchvision.prototype.transforms.functional import convert_image_dtype, to_image_tensor +from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor from torchvision.transforms.functional_tensor import _max_value as get_max_value __all__ = [ @@ -97,8 +97,8 @@ def _process_inputs(self, actual, expected, *, id, allow_subclasses): def _equalize_attributes(self, actual, expected): if actual.dtype != expected.dtype: dtype = torch.promote_types(actual.dtype, expected.dtype) - actual = convert_image_dtype(actual, dtype) - expected = convert_image_dtype(expected, dtype) + actual = convert_dtype_image_tensor(actual, dtype) + expected = convert_dtype_image_tensor(expected, dtype) return super()._equalize_attributes(actual, expected) diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py index 82173907c6f..7c2f21545f3 100644 --- a/test/prototype_transforms_dispatcher_infos.py +++ b/test/prototype_transforms_dispatcher_infos.py @@ -31,6 +31,8 @@ def __init__( *, # Dictionary of types that map to the kernel the dispatcher dispatches to. kernels, + # Name of the corresponding method on the features. If omitted, defaults to the dispatcher name + method_name=None, # If omitted, no PIL dispatch test will be performed. pil_kernel_info=None, # See InfoBase @@ -41,6 +43,7 @@ def __init__( super().__init__(id=dispatcher.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs) self.dispatcher = dispatcher self.kernels = kernels + self.method_name = method_name or self.id self.pil_kernel_info = pil_kernel_info kernel_infos = {} @@ -416,4 +419,12 @@ def xfail_all_tests(*, reason, condition): skip_dispatch_feature, ], ), + DispatcherInfo( + F.convert_dtype, + kernels={ + features.Image: F.convert_dtype_image_tensor, + features.Video: F.convert_dtype_video, + }, + method_name="to_dtype", + ), ] diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index dd5931927a5..5fc93137a3d 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -1943,33 +1943,48 @@ def sample_inputs_normalize_video(): ) -def sample_inputs_convert_image_dtype(): +def sample_inputs_convert_dtype_image_tensor(): for input_dtype, output_dtype in itertools.product( [torch.uint8, torch.int64, torch.float32, torch.float64], repeat=2 ): if input_dtype.is_floating_point and output_dtype == torch.int64: continue - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[input_dtype] + for image_loader, copy_kwargs in itertools.product( + make_image_loaders(sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[input_dtype]), + [dict(copy=True), dict(copy=False)] if input_dtype == output_dtype else [dict()], ): - yield ArgsKwargs(image_loader, dtype=output_dtype) + yield ArgsKwargs(image_loader, dtype=output_dtype, **copy_kwargs) yield ArgsKwargs(make_image_loader(color_space=features.ColorSpace.RGB), dtype=torch.uint8) +def sample_inputs_convert_dtype_video(): + for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + yield ArgsKwargs(video_loader) + + +_convert_dtype_skip_dtype_consistency = TestMark( + ("TestKernels", "test_dtype_consistency"), + pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"), + condition=lambda args_kwargs: args_kwargs.args[0].dtype != args_kwargs.kwargs.get("dtype", torch.float32), +) + + KERNEL_INFOS.extend( [ KernelInfo( - F.convert_image_dtype, - sample_inputs_fn=sample_inputs_convert_image_dtype, + F.convert_dtype_image_tensor, + sample_inputs_fn=sample_inputs_convert_dtype_image_tensor, + test_marks=[ + _convert_dtype_skip_dtype_consistency, + ], + ), + KernelInfo( + F.convert_dtype_video, + sample_inputs_fn=sample_inputs_convert_dtype_video, test_marks=[ - TestMark( - ("TestKernels", "test_dtype_consistency"), - pytest.mark.skip(reason="`convert_image_dtype` converts the dtype"), - condition=lambda args_kwargs: args_kwargs.args[0].dtype - != args_kwargs.kwargs.get("dtype", torch.float32), - ) + _convert_dtype_skip_dtype_consistency, ], ), ] diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 56430fcafac..02e06f1ccaf 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -243,7 +243,6 @@ def test_scripted_smoke(self, info, args_kwargs, device): [ F.clamp_bounding_box, F.convert_color_space, - F.convert_image_dtype, F.get_dimensions, F.get_image_num_channels, F.get_image_size, @@ -296,10 +295,9 @@ def test_dispatch_pil(self, info, args_kwargs, spy_on): def test_dispatch_feature(self, info, args_kwargs, spy_on): (feature, *other_args), kwargs = args_kwargs.load() - method_name = info.id - method = getattr(feature, method_name) + method = getattr(feature, info.method_name) feature_type = type(feature) - spy = spy_on(method, module=feature_type.__module__, name=f"{feature_type.__name__}.{method_name}") + spy = spy_on(method, module=feature_type.__module__, name=f"{feature_type.__name__}.{info.method_name}") info.dispatcher(feature, *other_args, **kwargs) From 5426d4242a3ae8c6259d5f8b3b4d8fd21e519f2d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 18 Oct 2022 09:42:11 +0200 Subject: [PATCH 5/7] add aliases --- test/test_prototype_transforms_functional.py | 1 + torchvision/prototype/transforms/__init__.py | 2 +- torchvision/prototype/transforms/_meta.py | 5 +++++ torchvision/prototype/transforms/functional/__init__.py | 7 ++++--- torchvision/prototype/transforms/functional/_meta.py | 5 +++++ 5 files changed, 16 insertions(+), 4 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 02e06f1ccaf..884dffd6966 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -314,6 +314,7 @@ def test_dispatch_feature(self, info, args_kwargs, spy_on): (F.get_image_num_channels, F.get_num_channels), (F.to_pil_image, F.to_image_pil), (F.elastic_transform, F.elastic), + (F.convert_image_dtype, F.convert_dtype_image_tensor), ] ], ) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index b3f8bba2cbc..ad1484d95c6 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -39,7 +39,7 @@ ScaleJitter, TenCrop, ) -from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertDtype +from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertDtype, ConvertImageDtype from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, RemoveSmallBoundingBoxes, ToDtype from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index 63919c3f814..ab8aa6cf8d1 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -35,6 +35,11 @@ def _transform( return F.convert_dtype(inpt, self.dtype, copy=self.copy) +# We changed the name to align it with the new naming scheme. Still, `ConvertImageDtype` is +# prevalent and well understood. Thus, we just alias it without deprecating the old name. +ConvertImageDtype = ConvertDtype + + class ConvertColorSpace(Transform): _transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image, features.Video) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 98c114b7732..7e520c98691 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -8,6 +8,10 @@ convert_color_space_image_pil, convert_color_space_video, convert_color_space, + convert_dtype_image_tensor, + convert_dtype, + convert_dtype_video, + convert_image_dtype, get_dimensions_image_tensor, get_dimensions_image_pil, get_dimensions, @@ -24,9 +28,6 @@ get_spatial_size_mask, get_spatial_size_video, get_spatial_size, - convert_dtype_image_tensor, - convert_dtype, - convert_dtype_video, ) # usort: skip from ._augment import erase, erase_image_pil, erase_image_tensor, erase_video diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 71f4740c08c..d2a17e0ec96 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -293,6 +293,11 @@ def convert_dtype_image_tensor( return _FT.convert_image_dtype(image, dtype) +# We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is +# prevalent and well understood. Thus, we just alias it without deprecating the old name. +convert_image_dtype = convert_dtype_image_tensor + + def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, copy: bool = True) -> torch.Tensor: return convert_dtype_image_tensor(video, dtype, copy=copy) From beedf49a9d16b043a7f081a1bcf767fa67f38302 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 21 Oct 2022 15:05:55 +0200 Subject: [PATCH 6/7] cleanup --- test/prototype_transforms_kernel_infos.py | 5 ++--- test/test_prototype_transforms_functional.py | 13 +------------ 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 39a2b565e6b..c342d5f85ee 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -127,8 +127,7 @@ def xfail_all_tests(*, reason, condition): "test_batched_vs_single", "test_no_inplace", "test_cuda_vs_cpu", - "test_dtype_consistency", - "test_device_consistency", + "test_dtype_and_dtype_consistency", ] ] @@ -2061,7 +2060,7 @@ def sample_inputs_convert_dtype_video(): _common_convert_dtype_marks = [ TestMark( - ("TestKernels", "test_dtype_consistency"), + ("TestKernels", "test_dtype_and_device_consistency"), pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"), condition=lambda args_kwargs: args_kwargs.args[0].dtype != args_kwargs.kwargs.get("dtype", torch.float32), ), diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index e58b04c61f7..cad5c204af8 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -178,7 +178,7 @@ def test_cuda_vs_cpu(self, info, args_kwargs): @sample_inputs @pytest.mark.parametrize("device", cpu_and_gpu()) - def test_dtype_consistency(self, info, args_kwargs, device): + def test_dtype_and_device_consistency(self, info, args_kwargs, device): (input, *other_args), kwargs = args_kwargs.load(device) output = info.kernel(input, *other_args, **kwargs) @@ -187,17 +187,6 @@ def test_dtype_consistency(self, info, args_kwargs, device): output, *_ = output assert output.dtype == input.dtype - - @sample_inputs - @pytest.mark.parametrize("device", cpu_and_gpu()) - def test_device_consistency(self, info, args_kwargs, device): - (input, *other_args), kwargs = args_kwargs.load(device) - - output = info.kernel(input, *other_args, **kwargs) - # Most kernels just return a tensor, but some also return some additional metadata - if not isinstance(output, torch.Tensor): - output, *_ = output - assert output.device == input.device @reference_inputs From 6c506fcb98ea794c7ba36f8ee076ae331d60ef72 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 21 Oct 2022 15:10:45 +0200 Subject: [PATCH 7/7] fix typo --- test/prototype_transforms_kernel_infos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index c342d5f85ee..77a77444ba7 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -127,7 +127,7 @@ def xfail_all_tests(*, reason, condition): "test_batched_vs_single", "test_no_inplace", "test_cuda_vs_cpu", - "test_dtype_and_dtype_consistency", + "test_dtype_and_device_consistency", ] ]