Resize V2 relies on interpolate's native uint8 handling #7557

Merged (8 commits, May 16, 2023). Changes shown from 5 commits.
38 changes: 28 additions & 10 deletions test/common_utils.py
@@ -465,11 +465,15 @@ def load(self, device):
class ImageLoader(TensorLoader):
spatial_size: Tuple[int, int] = dataclasses.field(init=False)
num_channels: int = dataclasses.field(init=False)
memory_format: torch.memory_format = torch.contiguous_format

def __post_init__(self):
self.spatial_size = self.shape[-2:]
self.num_channels = self.shape[-3]

def load(self, device):
return self.fn(self.shape, self.dtype, device, memory_format=self.memory_format)


NUM_CHANNELS_MAP = {
"GRAY": 1,
@@ -493,18 +497,21 @@ def make_image_loader(
extra_dims=(),
dtype=torch.float32,
constant_alpha=True,
memory_format=torch.contiguous_format,
):
size = _parse_spatial_size(size)
num_channels = get_num_channels(color_space)

- def fn(shape, dtype, device):
+ def fn(shape, dtype, device, memory_format):
max_value = get_max_value(dtype)
- data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device)
+ data = torch.testing.make_tensor(
+ shape, low=0, high=max_value, dtype=dtype, device=device, memory_format=memory_format
+ )
if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
data[..., -1, :, :] = max_value
return datapoints.Image(data)

- return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype)
+ return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, memory_format=memory_format)


make_image = from_loader(make_image_loader)
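
As background for the memory_format plumbing above: a minimal sketch (not part of the diff; shapes are arbitrary) of how the two layouts differ in their strides:

import torch

# contiguous_format stores an NCHW tensor as N, then C, then H, then W
nchw = torch.empty(1, 3, 32, 32)
assert nchw.stride() == (3 * 32 * 32, 32 * 32, 32, 1)

# channels_last keeps the same NCHW shape but stores pixels as NHWC in memory,
# so the channel stride becomes 1
nhwc = nchw.to(memory_format=torch.channels_last)
assert nhwc.stride() == (3 * 32 * 32, 1, 3 * 32, 3)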
@@ -530,11 +537,13 @@ def make_image_loaders(
make_images = from_loaders(make_image_loaders)


- def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dtype=torch.uint8):
+ def make_image_loader_for_interpolation(
+ size="random", *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
+ ):
size = _parse_spatial_size(size)
num_channels = get_num_channels(color_space)

- def fn(shape, dtype, device):
+ def fn(shape, dtype, device, memory_format):
height, width = shape[-2:]

image_pil = (
@@ -550,19 +559,26 @@ def fn(shape, dtype, device):
)
)

- image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)
+ image_tensor = to_image_tensor(image_pil)
+ if memory_format == torch.contiguous_format:
+ image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
+ else:
+ image_tensor = image_tensor.to(device=device)
+ image_tensor = convert_dtype_image_tensor(image_tensor, dtype=dtype)
+ assert image_tensor[None].is_contiguous(memory_format=memory_format)

return datapoints.Image(image_tensor)

- return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype)
+ return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, memory_format=memory_format)


def make_image_loaders_for_interpolation(
sizes=((233, 147),),
color_spaces=("RGB",),
dtypes=(torch.uint8,),
memory_formats=(torch.contiguous_format, torch.channels_last),
):
- for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):
+ for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes, memory_format=memory_formats):
yield make_image_loader_for_interpolation(**params)
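
Note on the image_tensor[None] in the assertion earlier in this file: torch.channels_last is defined for 4D NCHW tensors, so the 3D CHW image is unsqueezed with a leading batch dim before the contiguity check. A minimal sketch (assumed shapes):

import torch

chw = torch.empty(3, 8, 8)  # 3D image, no batch dimension
nchw = chw[None].to(memory_format=torch.channels_last)  # (1, 3, 8, 8)
assert nchw.is_contiguous(memory_format=torch.channels_last)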


@@ -744,8 +760,10 @@ def make_video_loader(
size = _parse_spatial_size(size)
num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames

- def fn(shape, dtype, device):
- video = make_image(size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device)
+ def fn(shape, dtype, device, memory_format):
+ video = make_image(
+ size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device, memory_format=memory_format
+ )
return datapoints.Video(video)

return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype)
13 changes: 10 additions & 3 deletions test/test_transforms_v2_consistency.py
@@ -98,6 +98,8 @@ def __init__(
ArgsKwargs((29, 32), antialias=False),
ArgsKwargs((28, 31), antialias=True),
],
# rtol=atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
closeness_kwargs=dict(rtol=1, atol=1),
),
ConsistencyConfig(
v2_transforms.CenterCrop,
@@ -313,6 +315,8 @@ def __init__(
ArgsKwargs((29, 32), antialias=False),
ArgsKwargs((28, 31), antialias=True),
],
# rtol=atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
closeness_kwargs=dict(rtol=1, atol=1),
),
ConsistencyConfig(
v2_transforms.RandomErasing,
@@ -783,7 +787,8 @@ def test_compose(self):
]
)

- check_call_consistency(prototype_transform, legacy_transform)
+ # rtol=atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+ check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=1, atol=1))

@pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
@pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
@@ -807,7 +812,8 @@ def test_random_apply(self, p, sequence_type):
p=p,
)

- check_call_consistency(prototype_transform, legacy_transform)
+ # rtol=atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+ check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=1, atol=1))

if sequence_type is nn.ModuleList:
# quick and dirty test that it is jit-scriptable
@@ -832,7 +838,8 @@ def test_random_choice(self, probabilities):
p=probabilities,
)

- check_call_consistency(prototype_transform, legacy_transform)
+ # rtol=atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+ check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=1, atol=1))


class TestToTensorTransforms:
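For intuition on the rtol=atol=1 tolerance used throughout this file: the native uint8 kernel rounds intermediate results differently from the u8->f32->interpolate->u8 path, so individual pixels may differ by one. A hedged sketch of the comparison (assumes a PyTorch build where interpolate accepts uint8 bilinear input on CPU):

import torch
from torch.nn import functional as nnF

img = torch.randint(0, 256, (1, 3, 32, 32), dtype=torch.uint8)
native = nnF.interpolate(img, size=(16, 16), mode="bilinear", antialias=True)
via_float = (
    nnF.interpolate(img.float(), size=(16, 16), mode="bilinear", antialias=True)
    .round_()
    .clamp_(0, 255)
    .to(torch.uint8)
)
# the two paths agree up to off-by-one rounding differences
assert (native.int() - via_float.int()).abs().max() <= 1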
25 changes: 25 additions & 0 deletions test/test_transforms_v2_functional.py
@@ -1365,3 +1365,28 @@ def test_correctness_uniform_temporal_subsample(device):

out_video = F.uniform_temporal_subsample(video, 8)
assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]


# TODO: We can remove this test and the related torchvision workaround
# once the related pytorch issue is fixed: https://github.com/pytorch/pytorch/issues/68430
@make_info_args_kwargs_parametrization(
[info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
args_kwargs_fn=lambda info: info.reference_inputs_fn(),
)
def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
(input, *other_args), kwargs = args_kwargs.load("cpu")

output = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)

error_msg_fn = parametrized_error_message(input, *other_args, **kwargs)
assert input.ndim == 3, error_msg_fn
input_stride = input.stride()
output_stride = output.stride()
if input_stride[-1] == 1:
expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
assert expected_stride == output_stride, error_msg_fn("")
elif input_stride[0] == 1:
expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
assert expected_stride == output_stride, error_msg_fn("")
else:
assert False, error_msg_fn("")
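
As a sanity check on the expected strides asserted above, a small sketch (shapes assumed): squeezing the batch dim off a channels_last 4D tensor yields a 3D CHW tensor whose channel stride is 1:

import torch

c, h, w = 3, 4, 5
img = torch.empty(1, c, h, w).to(memory_format=torch.channels_last).squeeze(0)
# channels_last stores pixels as HWC, so the 3D CHW strides are (1, C*W, C),
# matching the (1, shape[0] * shape[-1], shape[0]) pattern the test expects
assert img.stride() == (1, c * w, c)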
18 changes: 11 additions & 7 deletions test/transforms_v2_kernel_infos.py
@@ -1569,31 +1569,35 @@ def reference_inputs_equalize_image_tensor():
# We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
# Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
# the information gain is low if we already provide something really close to the expected value.
- def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor):
+ def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor, memory_format):
if dtype.is_floating_point:
low = low_factor
high = high_factor
else:
max_value = torch.iinfo(dtype).max
low = int(low_factor * max_value)
high = int(high_factor * max_value)
- return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
+ return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high).to(
+ memory_format=memory_format, copy=True
+ )

- def make_beta_distributed_image(shape, dtype, device, *, alpha, beta):
+ def make_beta_distributed_image(shape, dtype, device, *, alpha, beta, memory_format):
image = torch.distributions.Beta(alpha, beta).sample(shape)
if not dtype.is_floating_point:
image.mul_(torch.iinfo(dtype).max).round_()
- return image.to(dtype=dtype, device=device)
+ return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True)

spatial_size = (256, 256)
for dtype, color_space, fn in itertools.product(
[torch.uint8],
["GRAY", "RGB"],
[
- lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
- lambda shape, dtype, device: torch.full(
- shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
- ),
+ lambda shape, dtype, device, memory_format: torch.zeros(shape, dtype=dtype, device=device).to(
+ memory_format=memory_format, copy=True
+ ),
+ lambda shape, dtype, device, memory_format: torch.full(
+ shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
+ ).to(memory_format=memory_format, copy=True),
*[
functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
for low_factor, high_factor in [
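A quick illustration of the .to(memory_format=..., copy=True) pattern used above (shapes assumed): it always returns a freshly restrided copy instead of a possible no-op view:

import torch

t = torch.zeros(1, 3, 8, 8)  # NCHW-contiguous
u = t.to(memory_format=torch.channels_last, copy=True)
assert not t.is_contiguous(memory_format=torch.channels_last)
assert u.is_contiguous(memory_format=torch.channels_last)
assert u.data_ptr() != t.data_ptr()  # copy=True guarantees new storage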
30 changes: 29 additions & 1 deletion torchvision/transforms/v2/functional/_geometry.py
@@ -185,7 +185,35 @@ def resize_image_tensor(
image = image.reshape(-1, num_channels, old_height, old_width)

dtype = image.dtype
- need_cast = dtype not in (torch.float32, torch.float64)
+ acceptable_dtypes = [torch.float32, torch.float64]
+ if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
+ # uint8 dtype can be included for cpu and cuda inputs in nearest mode
+ acceptable_dtypes.append(torch.uint8)
+ elif interpolation == InterpolationMode.BILINEAR and image.device.type == "cpu":
+ # uint8 dtype support for bilinear mode is limited to cpu; according to our benchmarks,
+ # non-AVX CPUs should prefer the u8->f32->interpolate->u8 path
+ if "AVX2" in torch.backends.cpu.get_cpu_capability():
+ acceptable_dtypes.append(torch.uint8)
+
+ # TODO: Remove when https://github.com/pytorch/pytorch/pull/101136 is landed
+ if dtype == torch.uint8 and not (
+ image.is_contiguous() or image.is_contiguous(memory_format=torch.channels_last)
+ ):
+ image = image.contiguous(memory_format=torch.channels_last)
+
+ if image.is_contiguous(memory_format=torch.channels_last):
+ strides = image.stride()
+ numel = image.numel()
+ if image.shape[0] == 1 and numel != strides[0]:
+ # This is the case when a channels_last tensor was squeezed and unsqueezed such that
+ # stride[0] is set to image.shape[1] * image.stride()[1] instead of image.numel().
+ # Let's restride the image so that it is correctly treated as channels last.
+ # Related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
+ new_strides = list(strides)
+ new_strides[0] = numel
+ image = image.as_strided((1, num_channels, old_height, old_width), new_strides)
+
+ need_cast = dtype not in acceptable_dtypes
if need_cast:
image = image.to(dtype=torch.float32)

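
A hedged reproduction of the stride anomaly that the as_strided workaround above targets (behavior as described in pytorch/pytorch#68430 on affected versions; it may differ in newer releases):

import torch

x = torch.empty(1, 3, 8, 8).to(memory_format=torch.channels_last)
assert x.stride() == (192, 1, 24, 3)  # stride[0] == numel() == 192

y = x.squeeze(0).unsqueeze(0)  # round-trip through a 3D view
# On affected versions y is still reported as channels_last-contiguous, but
# stride[0] becomes shape[1] * stride()[1] == 3 rather than numel() == 192,
# which is exactly the case the restriding above corrects.
assert y.is_contiguous(memory_format=torch.channels_last)
assert y.stride() == (3, 1, 24, 3)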