Skip to content

Commit

Permalink
add tests for the output types of prototype functional dispatchers (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Jan 20, 2023
1 parent 01d138d commit d2d448c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
13 changes: 11 additions & 2 deletions test/prototype_transforms_dispatcher_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ def xfail_jit_list_of_ints(name, *, reason=None):
pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."),
)

multi_crop_skips = [
TestMark(
("TestDispatchers", test_name),
pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
)
for test_name in ["test_simple_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"]
]
multi_crop_skips.append(skip_dispatch_datapoint)


def fill_sequence_needs_broadcast(args_kwargs):
(image_loader, *_), kwargs = args_kwargs
Expand Down Expand Up @@ -404,7 +413,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
test_marks=[
xfail_jit_python_scalar_arg("size"),
skip_dispatch_datapoint,
*multi_crop_skips,
],
),
DispatcherInfo(
Expand All @@ -415,7 +424,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
test_marks=[
xfail_jit_python_scalar_arg("size"),
skip_dispatch_datapoint,
*multi_crop_skips,
],
pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
),
Expand Down
37 changes: 37 additions & 0 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,16 @@ def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):

spy.assert_called_once()

@image_sample_inputs
def test_simple_tensor_output_type(self, info, args_kwargs):
(image_datapoint, *other_args), kwargs = args_kwargs.load()
image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)

output = info.dispatcher(image_simple_tensor, *other_args, **kwargs)

# We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
assert type(output) is torch.Tensor

@make_info_args_kwargs_parametrization(
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
Expand All @@ -381,6 +391,22 @@ def test_dispatch_pil(self, info, args_kwargs, spy_on):

spy.assert_called_once()

@make_info_args_kwargs_parametrization(
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
)
def test_pil_output_type(self, info, args_kwargs):
(image_datapoint, *other_args), kwargs = args_kwargs.load()

if image_datapoint.ndim > 3:
pytest.skip("Input is batched")

image_pil = F.to_image_pil(image_datapoint)

output = info.dispatcher(image_pil, *other_args, **kwargs)

assert isinstance(output, PIL.Image.Image)

@make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
Expand All @@ -397,6 +423,17 @@ def test_dispatch_datapoint(self, info, args_kwargs, spy_on):

spy.assert_called_once()

@make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
)
def test_datapoint_output_type(self, info, args_kwargs):
(datapoint, *other_args), kwargs = args_kwargs.load()

output = info.dispatcher(datapoint, *other_args, **kwargs)

assert isinstance(output, type(datapoint))

@pytest.mark.parametrize(
("dispatcher_info", "datapoint_type", "kernel_info"),
[
Expand Down

0 comments on commit d2d448c

Please sign in to comment.