[fbsync] Add tests and proper support for videos in `ConvertImageDtype` (#6783)

Summary:
* add KernelInfo

* split dtype and device consistency tests

* add proper support for video

* fix tests and add DispatcherInfo

* add aliases

* cleanup

* fix typo

Reviewed By: YosuaMichael

Differential Revision: D40722908

fbshipit-source-id: 36adda72819a12167ed12d27f6715a46c8ee9b56
datumbox authored and facebook-github-bot committed Oct 27, 2022
1 parent e1a66c2 commit 08758ca
Showing 11 changed files with 155 additions and 113 deletions.
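
For orientation before the diff: this commit renames the image kernel to `convert_dtype_image_tensor`, adds a `convert_dtype_video` kernel, and routes both through a new `convert_dtype` dispatcher, keeping `convert_image_dtype` as an alias. A minimal sketch of the resulting API (assuming the prototype `features` wrappers accept raw tensors, as the tests below do; shapes and values are illustrative):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

image = features.Image(torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8))
video = features.Video(torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8))

# The dispatcher routes to the kernel matching the feature type.
assert F.convert_dtype(image, torch.float32).dtype == torch.float32
assert F.convert_dtype(video, torch.float32).dtype == torch.float32

# The old name stays available as an alias of the image kernel.
assert F.convert_image_dtype is F.convert_dtype_image_tensor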
6 changes: 3 additions & 3 deletions test/prototype_common_utils.py
@@ -22,7 +22,7 @@
UnsupportedInputs,
)
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import convert_image_dtype, to_image_tensor
from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
from torchvision.transforms.functional_tensor import _max_value as get_max_value

__all__ = [
@@ -97,8 +97,8 @@ def _process_inputs(self, actual, expected, *, id, allow_subclasses):
def _equalize_attributes(self, actual, expected):
if actual.dtype != expected.dtype:
dtype = torch.promote_types(actual.dtype, expected.dtype)
actual = convert_image_dtype(actual, dtype)
expected = convert_image_dtype(expected, dtype)
actual = convert_dtype_image_tensor(actual, dtype)
expected = convert_dtype_image_tensor(expected, dtype)

return super()._equalize_attributes(actual, expected)

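
The helper above uses `torch.promote_types` to pick a common dtype before converting both tensors for comparison; for reference, the promotion follows PyTorch's standard type-promotion rules:

import torch

# An integer dtype promotes to the floating-point operand's dtype ...
assert torch.promote_types(torch.uint8, torch.float32) == torch.float32
# ... and narrower integers promote to the wider integer dtype.
assert torch.promote_types(torch.int32, torch.int64) == torch.int64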
10 changes: 10 additions & 0 deletions test/prototype_transforms_dispatcher_infos.py
@@ -416,4 +416,14 @@ def xfail_all_tests(*, reason, condition):
skip_dispatch_feature,
],
),
DispatcherInfo(
F.convert_dtype,
kernels={
features.Image: F.convert_dtype_image_tensor,
features.Video: F.convert_dtype_video,
},
test_marks=[
skip_dispatch_feature,
],
),
]
53 changes: 30 additions & 23 deletions test/prototype_transforms_kernel_infos.py
@@ -1979,7 +1979,7 @@ def sample_inputs_normalize_video():
)


def sample_inputs_convert_image_dtype():
def sample_inputs_convert_dtype_image_tensor():
for input_dtype, output_dtype in itertools.product(
[torch.uint8, torch.int64, torch.float32, torch.float64], repeat=2
):
@@ -1992,10 +1992,8 @@
):
yield ArgsKwargs(image_loader, dtype=output_dtype)

yield ArgsKwargs(make_image_loader(color_space=features.ColorSpace.RGB), dtype=torch.uint8)


def reference_convert_image_dtype(image, dtype=torch.float):
def reference_convert_dtype_image_tensor(image, dtype=torch.float):
input_dtype = image.dtype
output_dtype = dtype

@@ -2026,7 +2024,7 @@ def fn(value):
return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype)


def reference_inputs_convert_image_dtype():
def reference_inputs_convert_dtype_image_tensor():
for input_dtype, output_dtype in itertools.product(
[
torch.uint8,
@@ -2055,41 +2053,50 @@ def reference_inputs_convert_image_dtype():
yield ArgsKwargs(image, dtype=output_dtype)


def sample_inputs_convert_dtype_video():
for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
yield ArgsKwargs(video_loader)


_common_convert_dtype_marks = [
TestMark(
("TestKernels", "test_dtype_and_device_consistency"),
pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
condition=lambda args_kwargs: args_kwargs.args[0].dtype != args_kwargs.kwargs.get("dtype", torch.float32),
),
TestMark(
("TestKernels", "test_scripted_vs_eager"),
pytest.mark.filterwarnings(f"ignore:{re.escape('operator() profile_node %')}:UserWarning"),
),
]

KERNEL_INFOS.extend(
[
KernelInfo(
F.convert_image_dtype,
sample_inputs_fn=sample_inputs_convert_image_dtype,
reference_fn=reference_convert_image_dtype,
reference_inputs_fn=reference_inputs_convert_image_dtype,
F.convert_dtype_image_tensor,
sample_inputs_fn=sample_inputs_convert_dtype_image_tensor,
reference_fn=reference_convert_dtype_image_tensor,
reference_inputs_fn=reference_inputs_convert_dtype_image_tensor,
test_marks=[
TestMark(
("TestKernels", "test_scripted_vs_eager"),
pytest.mark.filterwarnings(f"ignore:{re.escape('operator() profile_node %41')}:UserWarning"),
),
TestMark(
("TestKernels", "test_dtype_and_device_consistency"),
pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
condition=lambda args_kwargs: args_kwargs.args[0].dtype
!= args_kwargs.kwargs.get("dtype", torch.float32),
),
*_common_convert_dtype_marks,
TestMark(
("TestKernels", "test_against_reference"),
pytest.mark.xfail(reason="Conversion overflows"),
condition=lambda args_kwargs: (
args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
and not args_kwargs.kwargs["dtype"].is_floating_point
)
or (
args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
and args_kwargs.kwargs["dtype"] == torch.int64
)
or (
args_kwargs.args[0].dtype in {torch.int32, torch.int64}
and args_kwargs.kwargs["dtype"] == torch.float16
),
),
],
),
KernelInfo(
F.convert_dtype_video,
sample_inputs_fn=sample_inputs_convert_dtype_video,
test_marks=_common_convert_dtype_marks,
),
]
)
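
The "Conversion overflows" xfail above marks cases where the conversion genuinely cannot be performed in the given precision: half-precision floats can represent neither the `max_value + 1 - eps` scale factor nor the maximum value of wide integer dtypes. A minimal reproduction of both effects (plain PyTorch; the values are illustrative):

import torch

# 256.0 - 1e-3 rounds up to exactly 256.0 in float16, so an input of 1.0
# scales to 256 and overflows a uint8 cast.
scale = torch.tensor(256.0 - 1e-3, dtype=torch.float16)
assert scale.item() == 256.0

# The maximum of a wide integer dtype overflows float16 outright.
big = torch.tensor(torch.iinfo(torch.int32).max).to(torch.float16)
assert torch.isinf(big)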
2 changes: 1 addition & 1 deletion test/test_prototype_transforms.py
@@ -92,7 +92,7 @@ class TestSmoke:
transforms.RandomErasing(p=1.0),
transforms.Resize([16, 16]),
transforms.CenterCrop([16, 16]),
transforms.ConvertImageDtype(),
transforms.ConvertDtype(),
transforms.RandomHorizontalFlip(),
transforms.Pad(5),
transforms.RandomZoomOut(),
2 changes: 1 addition & 1 deletion test/test_prototype_transforms_consistency.py
@@ -153,7 +153,7 @@ def __init__(
),
),
ConsistencyConfig(
prototype_transforms.ConvertImageDtype,
prototype_transforms.ConvertDtype,
legacy_transforms.ConvertImageDtype,
[
ArgsKwargs(torch.float16),
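
The config above pins the prototype `ConvertDtype` against the legacy `ConvertImageDtype`; in spirit, the consistency test asserts something like the following sketch (a simplified illustration of the test's intent, not its actual mechanics):

import torch
from torchvision import transforms as legacy_transforms
from torchvision.prototype import transforms as prototype_transforms

image = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)

new = prototype_transforms.ConvertDtype(torch.float16)(image)
old = legacy_transforms.ConvertImageDtype(torch.float16)(image)
assert torch.equal(new, old)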
1 change: 1 addition & 0 deletions test/test_prototype_transforms_functional.py
@@ -307,6 +307,7 @@ def test_dispatch_feature(self, info, args_kwargs, spy_on):
(F.get_image_num_channels, F.get_num_channels),
(F.to_pil_image, F.to_image_pil),
(F.elastic_transform, F.elastic),
(F.convert_image_dtype, F.convert_dtype_image_tensor),
]
],
)
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
@@ -39,7 +39,7 @@
ScaleJitter,
TenCrop,
)
from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertDtype, ConvertImageDtype
from ._misc import (
GaussianBlur,
Identity,
14 changes: 7 additions & 7 deletions torchvision/prototype/transforms/_meta.py
@@ -25,7 +25,7 @@ def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
return features.BoundingBox.wrap_like(inpt, output, format=params["format"])


class ConvertImageDtype(Transform):
class ConvertDtype(Transform):
_transformed_types = (features.is_simple_tensor, features.Image, features.Video)

def __init__(self, dtype: torch.dtype = torch.float32) -> None:
@@ -35,12 +35,12 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None:
def _transform(
self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
) -> Union[features.TensorImageType, features.TensorVideoType]:
# TODO: the `inpt.as_subclass(torch.Tensor)` call can be removed as soon as we have a proper dispatcher that
# handles this. See https://github.com/pytorch/vision/pull/6783 for details.
output = F.convert_image_dtype(inpt.as_subclass(torch.Tensor), dtype=self.dtype)
return (
output if features.is_simple_tensor(inpt) else type(inpt).wrap_like(inpt, output) # type: ignore[attr-defined]
)
return F.convert_dtype(inpt, self.dtype)


# We changed the name to align it with the new naming scheme. Still, `ConvertImageDtype` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
ConvertImageDtype = ConvertDtype


class ConvertColorSpace(Transform):
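
For completeness, the renamed transform is a drop-in replacement for the old one; a short usage sketch (assuming `features.Image` accepts a raw tensor, as elsewhere in the test suite):

import torch
from torchvision.prototype import features, transforms

transform = transforms.ConvertDtype(torch.float32)  # `ConvertImageDtype` is an alias

image = features.Image(torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8))
converted = transform(image)

assert isinstance(converted, features.Image)
assert converted.dtype == torch.float32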
5 changes: 4 additions & 1 deletion torchvision/prototype/transforms/functional/__init__.py
@@ -8,6 +8,10 @@
convert_color_space_image_pil,
convert_color_space_video,
convert_color_space,
convert_dtype_image_tensor,
convert_dtype,
convert_dtype_video,
convert_image_dtype,
get_dimensions_image_tensor,
get_dimensions_image_pil,
get_dimensions,
@@ -162,7 +166,6 @@
normalize_video,
)
from ._type_conversion import (
convert_image_dtype,
decode_image_with_pil,
decode_video_with_av,
pil_to_tensor,
96 changes: 96 additions & 0 deletions torchvision/prototype/transforms/functional/_meta.py
@@ -285,3 +285,99 @@ def convert_color_space(
return features.Video.wrap_like(inpt, output, color_space=color_space)
else:
return convert_color_space_image_pil(inpt, color_space)


def _num_value_bits(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 8
elif dtype == torch.int8:
return 7
elif dtype == torch.int16:
return 15
elif dtype == torch.int32:
return 31
elif dtype == torch.int64:
return 63
else:
raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")


def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
if image.dtype == dtype:
return image

float_input = image.is_floating_point()
if torch.jit.is_scripting():
# TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
float_output = torch.tensor(0, dtype=dtype).is_floating_point()
else:
float_output = dtype.is_floating_point

if float_input:
# float to float
if float_output:
return image.to(dtype)

# float to int
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
image.dtype == torch.float64 and dtype == torch.int64
):
raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")

# For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
# to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
# be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# for a detailed analysis.
# To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
# Instead, we can also multiply by the maximum value plus something close to `1`. See
# https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
eps = 1e-3
max_value = float(_FT._max_value(dtype))
# We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
# discrete set `{0, 1}`.
return image.mul(max_value + 1.0 - eps).to(dtype)
else:
# int to float
if float_output:
return image.to(dtype).div_(_FT._max_value(image.dtype))

# int to int
num_value_bits_input = _num_value_bits(image.dtype)
num_value_bits_output = _num_value_bits(dtype)

if num_value_bits_input > num_value_bits_output:
return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
else:
# The bitshift kernel is not vectorized
# https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322
# This results in the multiplication actually being faster.
# TODO: If the bitshift kernel is optimized in core, replace the computation below with
# `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)`
max_value_input = float(_FT._max_value(image.dtype))
max_value_output = float(_FT._max_value(dtype))
factor = int((max_value_output + 1) // (max_value_input + 1))
return image.to(dtype).mul_(factor)


# We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
convert_image_dtype = convert_dtype_image_tensor


def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
return convert_dtype_image_tensor(video, dtype)


def convert_dtype(
inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], dtype: torch.dtype = torch.float
) -> torch.Tensor:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
):
return convert_dtype_image_tensor(inpt, dtype)
elif isinstance(inpt, features.Image):
output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
return features.Image.wrap_like(inpt, output)
else: # isinstance(inpt, features.Video):
output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype)
return features.Video.wrap_like(inpt, output)
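
The float-to-int and int-to-int comments in `convert_dtype_image_tensor` deserve a worked example: naive multiplication by the maximum value maps only an input of exactly 1.0 to 255, while scaling by `max_value + 1 - eps` spreads `[0.0, 1.0]` evenly over `[0, 255]`, and the int-to-int upcast is a plain rescale by a power of two (plain PyTorch; values are illustrative):

import torch

y = torch.tensor([0.999])

# Naive scaling: 0.999 * 255 = 254.745, which truncates to 254.
assert y.mul(255).to(torch.uint8).item() == 254

# The kernel's scaling: 0.999 * (256 - 1e-3) ~= 255.74, which truncates to 255.
assert y.mul(256.0 - 1e-3).to(torch.uint8).item() == 255

# int to int: uint8 -> int16 multiplies by (32767 + 1) // (255 + 1) == 128,
# which equals a left shift by 15 - 8 = 7 value bits.
assert torch.tensor([255], dtype=torch.uint8).to(torch.int16).mul(128).item() == 32640
assert 255 << 7 == 32640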