From d2d448c71b4cb054d160000a0f63eecad7867bdb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 20 Jan 2023 13:12:22 +0100 Subject: [PATCH] add tests for the output types of prototype functional dispatchers (#7118) --- test/prototype_transforms_dispatcher_infos.py | 13 ++++++- test/test_prototype_transforms_functional.py | 37 +++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py index 4a54cb40d9a..f6b8786570c 100644 --- a/test/prototype_transforms_dispatcher_infos.py +++ b/test/prototype_transforms_dispatcher_infos.py @@ -112,6 +112,15 @@ def xfail_jit_list_of_ints(name, *, reason=None): pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."), ) +multi_crop_skips = [ + TestMark( + ("TestDispatchers", test_name), + pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."), + ) + for test_name in ["test_simple_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"] +] +multi_crop_skips.append(skip_dispatch_datapoint) + def fill_sequence_needs_broadcast(args_kwargs): (image_loader, *_), kwargs = args_kwargs @@ -404,7 +413,7 @@ def fill_sequence_needs_broadcast(args_kwargs): pil_kernel_info=PILKernelInfo(F.five_crop_image_pil), test_marks=[ xfail_jit_python_scalar_arg("size"), - skip_dispatch_datapoint, + *multi_crop_skips, ], ), DispatcherInfo( @@ -415,7 +424,7 @@ def fill_sequence_needs_broadcast(args_kwargs): }, test_marks=[ xfail_jit_python_scalar_arg("size"), - skip_dispatch_datapoint, + *multi_crop_skips, ], pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil), ), diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index bc299fd1f50..102f78e6e11 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -362,6 +362,16 @@ def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on): spy.assert_called_once() + @image_sample_inputs + def test_simple_tensor_output_type(self, info, args_kwargs): + (image_datapoint, *other_args), kwargs = args_kwargs.load() + image_simple_tensor = image_datapoint.as_subclass(torch.Tensor) + + output = info.dispatcher(image_simple_tensor, *other_args, **kwargs) + + # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well + assert type(output) is torch.Tensor + @make_info_args_kwargs_parametrization( [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None], args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image), @@ -381,6 +391,22 @@ def test_dispatch_pil(self, info, args_kwargs, spy_on): spy.assert_called_once() + @make_info_args_kwargs_parametrization( + [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None], + args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image), + ) + def test_pil_output_type(self, info, args_kwargs): + (image_datapoint, *other_args), kwargs = args_kwargs.load() + + if image_datapoint.ndim > 3: + pytest.skip("Input is batched") + + image_pil = F.to_image_pil(image_datapoint) + + output = info.dispatcher(image_pil, *other_args, **kwargs) + + assert isinstance(output, PIL.Image.Image) + @make_info_args_kwargs_parametrization( DISPATCHER_INFOS, args_kwargs_fn=lambda info: info.sample_inputs(), @@ -397,6 +423,17 @@ def test_dispatch_datapoint(self, info, args_kwargs, spy_on): spy.assert_called_once() + @make_info_args_kwargs_parametrization( + DISPATCHER_INFOS, + args_kwargs_fn=lambda info: info.sample_inputs(), + ) + def test_datapoint_output_type(self, info, args_kwargs): + (datapoint, *other_args), kwargs = args_kwargs.load() + + output = info.dispatcher(datapoint, *other_args, **kwargs) + + assert isinstance(output, type(datapoint)) + @pytest.mark.parametrize( ("dispatcher_info", "datapoint_type", "kernel_info"), [