Add tests and proper support for videos in ConvertImageDtype #6783

Merged · 10 commits · Oct 21, 2022
6 changes: 3 additions & 3 deletions test/prototype_common_utils.py
@@ -22,7 +22,7 @@
UnsupportedInputs,
)
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import convert_image_dtype, to_image_tensor
from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
from torchvision.transforms.functional_tensor import _max_value as get_max_value

__all__ = [
@@ -97,8 +97,8 @@ def _process_inputs(self, actual, expected, *, id, allow_subclasses):
def _equalize_attributes(self, actual, expected):
if actual.dtype != expected.dtype:
dtype = torch.promote_types(actual.dtype, expected.dtype)
actual = convert_image_dtype(actual, dtype)
expected = convert_image_dtype(expected, dtype)
actual = convert_dtype_image_tensor(actual, dtype)
expected = convert_dtype_image_tensor(expected, dtype)

return super()._equalize_attributes(actual, expected)

10 changes: 10 additions & 0 deletions test/prototype_transforms_dispatcher_infos.py
@@ -416,4 +416,14 @@ def xfail_all_tests(*, reason, condition):
skip_dispatch_feature,
],
),
DispatcherInfo(
F.convert_dtype,
kernels={
features.Image: F.convert_dtype_image_tensor,
features.Video: F.convert_dtype_video,
},
test_marks=[
skip_dispatch_feature,
],
),
]
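
For orientation, a minimal sketch of what this dispatcher entry means in practice (assuming the prototype API as of this PR; the prototype namespace is unstable, so treat this as a sketch rather than a reference): `F.convert_dtype` routes `features.Image` inputs to `convert_dtype_image_tensor` and `features.Video` inputs to `convert_dtype_video`.

import torch

from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# F.convert_dtype dispatches to the per-type kernels registered above.
image = features.Image(torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8))
video = features.Video(torch.randint(0, 256, (2, 3, 8, 8), dtype=torch.uint8))

print(type(F.convert_dtype(image, torch.float32)))  # features.Image, via convert_dtype_image_tensor
print(type(F.convert_dtype(video, torch.float32)))  # features.Video, via convert_dtype_video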
53 changes: 30 additions & 23 deletions test/prototype_transforms_kernel_infos.py
@@ -1979,7 +1979,7 @@ def sample_inputs_normalize_video():
)


def sample_inputs_convert_image_dtype():
def sample_inputs_convert_dtype_image_tensor():
for input_dtype, output_dtype in itertools.product(
[torch.uint8, torch.int64, torch.float32, torch.float64], repeat=2
):
@@ -1992,10 +1992,8 @@ def sample_inputs_convert_image_dtype():
):
yield ArgsKwargs(image_loader, dtype=output_dtype)

yield ArgsKwargs(make_image_loader(color_space=features.ColorSpace.RGB), dtype=torch.uint8)


def reference_convert_image_dtype(image, dtype=torch.float):
def reference_convert_dtype_image_tensor(image, dtype=torch.float):
input_dtype = image.dtype
output_dtype = dtype

@@ -2026,7 +2024,7 @@ def fn(value):
return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype)


def reference_inputs_convert_image_dtype():
def reference_inputs_convert_dtype_image_tensor():
for input_dtype, output_dtype in itertools.product(
[
torch.uint8,
@@ -2055,41 +2053,50 @@ def reference_inputs_convert_image_dtype():
yield ArgsKwargs(image, dtype=output_dtype)


def sample_inputs_convert_dtype_video():
for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
yield ArgsKwargs(video_loader)


_common_convert_dtype_marks = [
TestMark(
("TestKernels", "test_dtype_and_device_consistency"),
pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
condition=lambda args_kwargs: args_kwargs.args[0].dtype != args_kwargs.kwargs.get("dtype", torch.float32),
),
TestMark(
("TestKernels", "test_scripted_vs_eager"),
pytest.mark.filterwarnings(f"ignore:{re.escape('operator() profile_node %')}:UserWarning"),
),
]

KERNEL_INFOS.extend(
[
KernelInfo(
F.convert_image_dtype,
sample_inputs_fn=sample_inputs_convert_image_dtype,
reference_fn=reference_convert_image_dtype,
reference_inputs_fn=reference_inputs_convert_image_dtype,
F.convert_dtype_image_tensor,
sample_inputs_fn=sample_inputs_convert_dtype_image_tensor,
reference_fn=reference_convert_dtype_image_tensor,
reference_inputs_fn=reference_inputs_convert_dtype_image_tensor,
test_marks=[
TestMark(
("TestKernels", "test_scripted_vs_eager"),
pytest.mark.filterwarnings(f"ignore:{re.escape('operator() profile_node %41')}:UserWarning"),
),
TestMark(
("TestKernels", "test_dtype_and_device_consistency"),
pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
condition=lambda args_kwargs: args_kwargs.args[0].dtype
!= args_kwargs.kwargs.get("dtype", torch.float32),
),
*_common_convert_dtype_marks,
TestMark(
("TestKernels", "test_against_reference"),
pytest.mark.xfail(reason="Conversion overflows"),
condition=lambda args_kwargs: (
args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
and not args_kwargs.kwargs["dtype"].is_floating_point
)
or (
args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
and args_kwargs.kwargs["dtype"] == torch.int64
)
or (
args_kwargs.args[0].dtype in {torch.int32, torch.int64}
and args_kwargs.kwargs["dtype"] == torch.float16
),
),
],
),
KernelInfo(
F.convert_dtype_video,
sample_inputs_fn=sample_inputs_convert_dtype_video,
test_marks=_common_convert_dtype_marks,
),
]
)
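
As a rough illustration of the "Conversion overflows" xfail above (plain PyTorch, not part of the test suite): float16 cannot represent the large integer maxima, so scaling a half-precision input in [0.0, 1.0] by roughly 2**31 overflows before the cast to the integer dtype, which is why the reference comparison is expected to fail for those dtype pairs.

import torch

x = torch.tensor([1.0], dtype=torch.float16)
scaled = x * (2**31 - 1)  # exceeds the float16 maximum (~65504)
print(scaled)             # tensor([inf], dtype=torch.float16) -- the subsequent integer cast is meaningless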
2 changes: 1 addition & 1 deletion test/test_prototype_transforms.py
@@ -92,7 +92,7 @@ class TestSmoke:
transforms.RandomErasing(p=1.0),
transforms.Resize([16, 16]),
transforms.CenterCrop([16, 16]),
transforms.ConvertImageDtype(),
transforms.ConvertDtype(),
transforms.RandomHorizontalFlip(),
transforms.Pad(5),
transforms.RandomZoomOut(),
2 changes: 1 addition & 1 deletion test/test_prototype_transforms_consistency.py
@@ -153,7 +153,7 @@ def __init__(
),
),
ConsistencyConfig(
prototype_transforms.ConvertImageDtype,
prototype_transforms.ConvertDtype,
legacy_transforms.ConvertImageDtype,
[
ArgsKwargs(torch.float16),
1 change: 1 addition & 0 deletions test/test_prototype_transforms_functional.py
@@ -307,6 +307,7 @@ def test_dispatch_feature(self, info, args_kwargs, spy_on):
(F.get_image_num_channels, F.get_num_channels),
(F.to_pil_image, F.to_image_pil),
(F.elastic_transform, F.elastic),
(F.convert_image_dtype, F.convert_dtype_image_tensor),
]
],
)
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
@@ -39,7 +39,7 @@
ScaleJitter,
TenCrop,
)
from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertDtype, ConvertImageDtype
from ._misc import (
GaussianBlur,
Identity,
14 changes: 7 additions & 7 deletions torchvision/prototype/transforms/_meta.py
@@ -25,7 +25,7 @@ def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> feat
return features.BoundingBox.wrap_like(inpt, output, format=params["format"])


class ConvertImageDtype(Transform):
class ConvertDtype(Transform):
_transformed_types = (features.is_simple_tensor, features.Image, features.Video)

def __init__(self, dtype: torch.dtype = torch.float32) -> None:
@@ -35,12 +35,12 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None:
def _transform(
self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
) -> Union[features.TensorImageType, features.TensorVideoType]:
# TODO: the `inpt.as_subclass(torch.Tensor)` call can be removed as soon as we have a proper dispatcher that
# handles this. See https://github.com/pytorch/vision/pull/6783 for details.
output = F.convert_image_dtype(inpt.as_subclass(torch.Tensor), dtype=self.dtype)
return (
output if features.is_simple_tensor(inpt) else type(inpt).wrap_like(inpt, output) # type: ignore[attr-defined]
)
return F.convert_dtype(inpt, self.dtype)


# We changed the name to align it with the new naming scheme. Still, `ConvertImageDtype` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
ConvertImageDtype = ConvertDtype


class ConvertColorSpace(Transform):
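
For illustration, a minimal usage sketch of the renamed transform (assuming the prototype `features` wrappers; the prototype API is unstable, so this is a sketch rather than documented behavior):

import torch
from torchvision.prototype import features, transforms

convert = transforms.ConvertDtype(torch.float32)

# uint8 image in [0, 255] -> float32 image in [0.0, 1.0]
image = features.Image(torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8))
image_f32 = convert(image)

# the same transform now also covers videos (extra leading frame dimension)
video = features.Video(torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8))
video_f32 = convert(video)

# ConvertImageDtype is kept as an alias, so existing call sites keep working
assert transforms.ConvertImageDtype is transforms.ConvertDtype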
5 changes: 4 additions & 1 deletion torchvision/prototype/transforms/functional/__init__.py
@@ -8,6 +8,10 @@
convert_color_space_image_pil,
convert_color_space_video,
convert_color_space,
convert_dtype_image_tensor,
convert_dtype,
convert_dtype_video,
convert_image_dtype,
get_dimensions_image_tensor,
get_dimensions_image_pil,
get_dimensions,
@@ -162,7 +166,6 @@
normalize_video,
)
from ._type_conversion import (
convert_image_dtype,
decode_image_with_pil,
decode_video_with_av,
pil_to_tensor,
96 changes: 96 additions & 0 deletions torchvision/prototype/transforms/functional/_meta.py
@@ -285,3 +285,99 @@ def convert_color_space(
return features.Video.wrap_like(inpt, output, color_space=color_space)
else:
return convert_color_space_image_pil(inpt, color_space)


def _num_value_bits(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 8
elif dtype == torch.int8:
return 7
elif dtype == torch.int16:
return 15
elif dtype == torch.int32:
return 31
elif dtype == torch.int64:
return 63
else:
raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")


def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
if image.dtype == dtype:
return image

float_input = image.is_floating_point()
if torch.jit.is_scripting():
# TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
float_output = torch.tensor(0, dtype=dtype).is_floating_point()
else:
float_output = dtype.is_floating_point

if float_input:
# float to float
if float_output:
return image.to(dtype)

# float to int
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
image.dtype == torch.float64 and dtype == torch.int64
):
raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")

# For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
# to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
# be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# for a detailed analysis.
# To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
# Instead, we can also multiply by the maximum value plus something close to `1`. See
# https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
eps = 1e-3
max_value = float(_FT._max_value(dtype))
# We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
# discrete set `{0, 1}`.
return image.mul(max_value + 1.0 - eps).to(dtype)
else:
# int to float
if float_output:
return image.to(dtype).div_(_FT._max_value(image.dtype))

# int to int
num_value_bits_input = _num_value_bits(image.dtype)
num_value_bits_output = _num_value_bits(dtype)

if num_value_bits_input > num_value_bits_output:
return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
else:
# The bitshift kernel is not vectorized
# https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322
# This results in the multiplication actually being faster.
# TODO: If the bitshift kernel is optimized in core, replace the computation below with
# `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)`
max_value_input = float(_FT._max_value(dtype))
max_value_output = float(_FT._max_value(image.dtype))
factor = int((max_value_input + 1) // (max_value_output + 1))
return image.to(dtype).mul_(factor)
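
A small standalone illustration of why the `eps` above is needed (plain PyTorch, not part of this diff): naively scaling by the integer maximum only reaches the top value for inputs that are exactly 1.0, whereas scaling by `max_value + 1.0 - eps` maps values just below 1.0 to the maximum as well.

import torch

x = torch.tensor([0.9999, 1.0])
print(x.mul(255).to(torch.uint8))               # tensor([254, 255], dtype=torch.uint8)
print(x.mul(255 + 1.0 - 1e-3).to(torch.uint8))  # tensor([255, 255], dtype=torch.uint8)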


# We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
convert_image_dtype = convert_dtype_image_tensor


def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
return convert_dtype_image_tensor(video, dtype)


def convert_dtype(
inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], dtype: torch.dtype = torch.float
) -> torch.Tensor:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
):
return convert_dtype_image_tensor(inpt, dtype)
elif isinstance(inpt, features.Image):
output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
return features.Image.wrap_like(inpt, output)
else: # isinstance(inpt, features.Video):
output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype)
return features.Video.wrap_like(inpt, output)
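
As a sanity check on the int-to-int path above (plain PyTorch, independent of the kernel): for a uint8 to int32 conversion the multiplication by `factor` is equivalent to the left bit shift it stands in for.

import torch

image = torch.arange(256, dtype=torch.uint8)
factor = (2**31) // (2**8)  # (max_int32 + 1) // (max_uint8 + 1) == 2**23

via_mul = image.to(torch.int32) * factor
via_shift = image.to(torch.int32) << (31 - 8)  # 31 and 8 value bits, respectively

assert torch.equal(via_mul, via_shift)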