Resize V2 relies on interpolate's native uint8 handling #7557

Merged (8 commits, May 16, 2023)
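
In short, the change lets Resize v2 hand uint8 tensors directly to torch.nn.functional.interpolate on the supported paths instead of round-tripping through float32. A minimal illustration of the underlying capability (not code from this PR; assumes a recent PyTorch build with native uint8 interpolate support on CPU):

    import torch
    import torch.nn.functional as F

    # NCHW uint8 image; recent PyTorch can interpolate this directly on CPU
    img = torch.randint(0, 256, (1, 3, 64, 64), dtype=torch.uint8)
    out = F.interpolate(img, size=(32, 32), mode="bilinear", antialias=True)
    assert out.dtype == torch.uint8  # no intermediate float32 cast required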
37 changes: 27 additions & 10 deletions test/common_utils.py
@@ -465,11 +465,15 @@ def load(self, device):
 class ImageLoader(TensorLoader):
     spatial_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
+    memory_format: torch.memory_format = torch.contiguous_format

     def __post_init__(self):
         self.spatial_size = self.shape[-2:]
         self.num_channels = self.shape[-3]

+    def load(self, device):
+        return self.fn(self.shape, self.dtype, device, memory_format=self.memory_format)
+

 NUM_CHANNELS_MAP = {
     "GRAY": 1,
@@ -493,18 +497,21 @@ def make_image_loader(
     extra_dims=(),
     dtype=torch.float32,
     constant_alpha=True,
+    memory_format=torch.contiguous_format,
 ):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)

-    def fn(shape, dtype, device):
+    def fn(shape, dtype, device, memory_format):
         max_value = get_max_value(dtype)
-        data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device)
+        data = torch.testing.make_tensor(
+            shape, low=0, high=max_value, dtype=dtype, device=device, memory_format=memory_format
+        )
         if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
             data[..., -1, :, :] = max_value
         return datapoints.Image(data)

-    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype)
+    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, memory_format=memory_format)


 make_image = from_loader(make_image_loader)
@@ -530,11 +537,13 @@ def make_image_loaders(
 make_images = from_loaders(make_image_loaders)


-def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dtype=torch.uint8):
+def make_image_loader_for_interpolation(
+    size="random", *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
+):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)

-    def fn(shape, dtype, device):
+    def fn(shape, dtype, device, memory_format):
         height, width = shape[-2:]

         image_pil = (
@@ -550,19 +559,25 @@ def fn(shape, dtype, device):
             )
         )

-        image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)
+        image_tensor = to_image_tensor(image_pil)
+        if memory_format == torch.contiguous_format:
+            image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
+        else:
+            image_tensor = image_tensor.to(device=device)
+        image_tensor = convert_dtype_image_tensor(image_tensor, dtype=dtype)

         return datapoints.Image(image_tensor)

-    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype)
+    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, memory_format=memory_format)


 def make_image_loaders_for_interpolation(
     sizes=((233, 147),),
     color_spaces=("RGB",),
     dtypes=(torch.uint8,),
+    memory_formats=(torch.contiguous_format, torch.channels_last),
 ):
-    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):
+    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes, memory_format=memory_formats):
         yield make_image_loader_for_interpolation(**params)

@@ -744,8 +759,10 @@ def make_video_loader(
     size = _parse_spatial_size(size)
     num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames

-    def fn(shape, dtype, device):
-        video = make_image(size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device)
+    def fn(shape, dtype, device, memory_format):
+        video = make_image(
+            size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device, memory_format=memory_format
+        )
         return datapoints.Video(video)

     return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype)
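
The loaders above thread memory_format down to torch.testing.make_tensor, which can allocate channels_last tensors directly for 4D shapes. A small standalone sketch of that plumbing (hypothetical shape, not part of the PR):

    import torch

    t = torch.testing.make_tensor(
        (1, 3, 8, 8), dtype=torch.uint8, device="cpu", low=0, high=255,
        memory_format=torch.channels_last,
    )
    assert t.is_contiguous(memory_format=torch.channels_last)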
13 changes: 10 additions & 3 deletions test/test_transforms_v2_consistency.py
@@ -98,6 +98,8 @@ def __init__(
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        closeness_kwargs=dict(rtol=0, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.CenterCrop,
@@ -313,6 +315,8 @@ def __init__(
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        closeness_kwargs=dict(rtol=0, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.RandomErasing,
@@ -783,7 +787,8 @@ def test_compose(self):
             ]
         )

-        check_call_consistency(prototype_transform, legacy_transform)
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

     @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
     @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
@@ -807,7 +812,8 @@ def test_random_apply(self, p, sequence_type):
             p=p,
         )

-        check_call_consistency(prototype_transform, legacy_transform)
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

         if sequence_type is nn.ModuleList:
             # quick and dirty test that it is jit-scriptable
@@ -832,7 +838,8 @@ def test_random_choice(self, probabilities):
             p=probabilities,
         )

-        check_call_consistency(prototype_transform, legacy_transform)
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))


 class TestToTensorTransforms:
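
The rtol=0, atol=1 tolerance used above reflects that the native uint8 path may round differently from the float path, so v1 and v2 outputs can differ by at most one intensity level. A minimal sketch of what the closeness check permits (hypothetical values, not from the test suite):

    import torch

    v1_out = torch.tensor([10, 199, 255], dtype=torch.uint8)
    v2_out = torch.tensor([11, 200, 255], dtype=torch.uint8)  # off by at most one level
    torch.testing.assert_close(v2_out, v1_out, rtol=0, atol=1)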
30 changes: 30 additions & 0 deletions test/test_transforms_v2_functional.py
@@ -1365,3 +1365,33 @@ def test_correctness_uniform_temporal_subsample(device):

     out_video = F.uniform_temporal_subsample(video, 8)
     assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
+
+
+# TODO: We can remove this test and the related torchvision workaround
+# once the related pytorch issue is fixed: https://github.com/pytorch/pytorch/issues/68430
+@make_info_args_kwargs_parametrization(
+    [info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
+    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
+)
+def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
+    (input, *other_args), kwargs = args_kwargs.load("cpu")
+
+    output = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
+
+    error_msg_fn = parametrized_error_message(input, *other_args, **kwargs)
+    assert input.ndim == 3, error_msg_fn
+    input_stride = input.stride()
+    output_stride = output.stride()
+    # Here we check the output memory format according to the input:
+    # if input_stride is (..., 1) then input is most likely channels-first and thus
+    # output strides should match channels-first strides (H * W, W, 1);
+    # if input_stride is (1, ...) then input is most likely channels-last and thus
+    # output strides should match channels-last strides (1, W * C, C)
+    if input_stride[-1] == 1:
+        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
+        assert expected_stride == output_stride, error_msg_fn("")
+    elif input_stride[0] == 1:
+        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
+        assert expected_stride == output_stride, error_msg_fn("")
+    else:
+        assert False, error_msg_fn("")
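
For reference, a standalone sketch (hypothetical sizes, not part of the test) of the stride conventions the assertions above encode for 3D CHW tensors:

    import torch

    c, h, w = 3, 16, 24
    channels_first = torch.empty(c, h, w)  # row-major CHW
    assert channels_first.stride() == (h * w, w, 1)  # stride[-1] == 1

    channels_last = torch.empty(1, c, h, w, memory_format=torch.channels_last)[0]
    assert channels_last.stride() == (1, w * c, c)  # stride[0] == 1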
18 changes: 11 additions & 7 deletions test/transforms_v2_kernel_infos.py
@@ -1569,31 +1569,35 @@ def reference_inputs_equalize_image_tensor():
     # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
     # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
     # the information gain is low if we already provide something really close to the expected value.
-    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor):
+    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor, memory_format):
         if dtype.is_floating_point:
             low = low_factor
             high = high_factor
         else:
             max_value = torch.iinfo(dtype).max
             low = int(low_factor * max_value)
             high = int(high_factor * max_value)
-        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
+        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high).to(
+            memory_format=memory_format, copy=True
+        )

-    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta):
+    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta, memory_format):
         image = torch.distributions.Beta(alpha, beta).sample(shape)
         if not dtype.is_floating_point:
             image.mul_(torch.iinfo(dtype).max).round_()
-        return image.to(dtype=dtype, device=device)
+        return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True)

     spatial_size = (256, 256)
     for dtype, color_space, fn in itertools.product(
         [torch.uint8],
         ["GRAY", "RGB"],
         [
-            lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
-            lambda shape, dtype, device: torch.full(
-                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
+            lambda shape, dtype, device, memory_format: torch.zeros(shape, dtype=dtype, device=device).to(
+                memory_format=memory_format, copy=True
             ),
+            lambda shape, dtype, device, memory_format: torch.full(
+                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
+            ).to(memory_format=memory_format, copy=True),
             *[
                 functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
                 for low_factor, high_factor in [
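
For context, the beta-distributed images above are built by sampling torch.distributions.Beta and rescaling to the integer range. A standalone sketch with hypothetical shape and parameters:

    import torch

    image = torch.distributions.Beta(0.5, 0.5).sample((3, 8, 8))  # values piled near 0 and 1
    image = image.mul_(255).round_().to(torch.uint8)
    image = image.unsqueeze(0).to(memory_format=torch.channels_last, copy=True)  # 4D for channels_last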
35 changes: 33 additions & 2 deletions torchvision/transforms/v2/functional/_geometry.py
@@ -176,16 +176,47 @@ def resize_image_tensor(
         antialias = False

     shape = image.shape
+    numel = image.numel()
     num_channels, old_height, old_width = shape[-3:]
     new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)

     if (new_height, new_width) == (old_height, old_width):
         return image
-    elif image.numel() > 0:
+    elif numel > 0:
         image = image.reshape(-1, num_channels, old_height, old_width)

         dtype = image.dtype
-        need_cast = dtype not in (torch.float32, torch.float64)
+        acceptable_dtypes = [torch.float32, torch.float64]
+        if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
+            # uint8 dtype can be included for cpu and cuda input in nearest mode
+            acceptable_dtypes.append(torch.uint8)
+        elif interpolation == InterpolationMode.BILINEAR and image.device.type == "cpu":
+            # uint8 dtype support for bilinear mode is limited to cpu and,
+            # according to our benchmarks, non-AVX CPUs should prefer the u8->f32->interpolate->u8 path
+            if "AVX2" in torch.backends.cpu.get_cpu_capability():
+                acceptable_dtypes.append(torch.uint8)
+
+        # TODO: Remove when https://github.com/pytorch/pytorch/pull/101136 is landed
+        if dtype == torch.uint8 and not (
+            image.is_contiguous() or image.is_contiguous(memory_format=torch.channels_last)
+        ):
+            image = image.contiguous(memory_format=torch.channels_last)
+
+        strides = image.stride()
+        if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
+            # There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as
+            # contiguous even though the input is unambiguously channels_last (https://github.com/pytorch/pytorch/issues/68430).
+            # In particular this happens for the typical torchvision use-case of single CHW images where we fake the batch dim
+            # to become 1CHW. Below, we restride those tensors to trick torch core into properly allocating the output as
+            # channels_last, thus preserving the memory format of the input. This is not just for format consistency:
+            # for uint8 bilinear images, this also avoids an extra copy (re-packing) of the output and saves time.
+            # TODO: when https://github.com/pytorch/pytorch/issues/68430 is fixed (possibly by https://github.com/pytorch/pytorch/pull/100373),
+            # we should be able to remove this hack.
+            new_strides = list(strides)
+            new_strides[0] = numel
+            image = image.as_strided((1, num_channels, old_height, old_width), new_strides)
+
+        need_cast = dtype not in acceptable_dtypes
         if need_cast:
             image = image.to(dtype=torch.float32)

Expand Down