[fbsync] Resize V2 relies on interpolate's native uint8 handling (#7557)
Summary: Co-authored-by: Nicolas Hug <[email protected]>

Reviewed By: vmoens

Differential Revision: D45908452

fbshipit-source-id: a9821a4e1c50b973b2488753a3117faf59ffe585
NicolasHug authored and facebook-github-bot committed May 16, 2023
1 parent 5d9e676 commit fb9842d
Showing 5 changed files with 111 additions and 22 deletions.
37 changes: 27 additions & 10 deletions test/common_utils.py
@@ -465,11 +465,15 @@ def load(self, device):
 class ImageLoader(TensorLoader):
     spatial_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
+    memory_format: torch.memory_format = torch.contiguous_format
 
     def __post_init__(self):
         self.spatial_size = self.shape[-2:]
         self.num_channels = self.shape[-3]
 
+    def load(self, device):
+        return self.fn(self.shape, self.dtype, device, memory_format=self.memory_format)
+
 
 NUM_CHANNELS_MAP = {
     "GRAY": 1,
@@ -493,18 +497,21 @@ def make_image_loader(
     extra_dims=(),
     dtype=torch.float32,
     constant_alpha=True,
+    memory_format=torch.contiguous_format,
 ):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)
 
-    def fn(shape, dtype, device):
+    def fn(shape, dtype, device, memory_format):
         max_value = get_max_value(dtype)
-        data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device)
+        data = torch.testing.make_tensor(
+            shape, low=0, high=max_value, dtype=dtype, device=device, memory_format=memory_format
+        )
         if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
             data[..., -1, :, :] = max_value
         return datapoints.Image(data)
 
-    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype)
+    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, memory_format=memory_format)
 
 
 make_image = from_loader(make_image_loader)
@@ -530,11 +537,13 @@ def make_image_loaders(
 make_images = from_loaders(make_image_loaders)
 
 
-def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dtype=torch.uint8):
+def make_image_loader_for_interpolation(
+    size="random", *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
+):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)
 
-    def fn(shape, dtype, device):
+    def fn(shape, dtype, device, memory_format):
         height, width = shape[-2:]
 
         image_pil = (
@@ -550,19 +559,25 @@ def fn(shape, dtype, device):
             )
         )
 
-        image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)
+        image_tensor = to_image_tensor(image_pil)
+        if memory_format == torch.contiguous_format:
+            image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
+        else:
+            image_tensor = image_tensor.to(device=device)
+        image_tensor = convert_dtype_image_tensor(image_tensor, dtype=dtype)
 
         return datapoints.Image(image_tensor)
 
-    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype)
+    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, memory_format=memory_format)
 
 
 def make_image_loaders_for_interpolation(
     sizes=((233, 147),),
     color_spaces=("RGB",),
     dtypes=(torch.uint8,),
+    memory_formats=(torch.contiguous_format, torch.channels_last),
 ):
-    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):
+    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes, memory_format=memory_formats):
         yield make_image_loader_for_interpolation(**params)
@@ -744,8 +759,10 @@ def make_video_loader(
     size = _parse_spatial_size(size)
     num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames
 
-    def fn(shape, dtype, device):
-        video = make_image(size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device)
+    def fn(shape, dtype, device, memory_format):
+        video = make_image(
+            size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device, memory_format=memory_format
+        )
         return datapoints.Video(video)
 
     return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype)
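
As background for the memory_format plumbing above: a minimal standalone sketch (not part of this commit) of what a layout request does to strides; the shapes are arbitrary.

import torch

# A contiguous CHW uint8 image faked as a 1CHW batch.
image = torch.randint(0, 256, (1, 3, 32, 32), dtype=torch.uint8)

# .to(memory_format=..., copy=True) forces a fresh tensor in the requested
# layout, which is the mechanism the updated loaders rely on.
channels_last = image.to(memory_format=torch.channels_last, copy=True)

print(image.stride())          # (3072, 1024, 32, 1): contiguous, channels-first
print(channels_last.stride())  # (3072, 1, 96, 3): channels-last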
13 changes: 10 additions & 3 deletions test/test_transforms_v2_consistency.py
@@ -98,6 +98,8 @@ def __init__(
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        closeness_kwargs=dict(rtol=0, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.CenterCrop,
@@ -313,6 +315,8 @@ def __init__(
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        closeness_kwargs=dict(rtol=0, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.RandomErasing,
@@ -783,7 +787,8 @@ def test_compose(self):
             ]
         )
 
-        check_call_consistency(prototype_transform, legacy_transform)
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
 
     @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
     @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
@@ -807,7 +812,8 @@ def test_random_apply(self, p, sequence_type):
             p=p,
         )
 
-        check_call_consistency(prototype_transform, legacy_transform)
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
 
         if sequence_type is nn.ModuleList:
             # quick and dirty test that it is jit-scriptable
@@ -832,7 +838,8 @@ def test_random_choice(self, probabilities):
             p=probabilities,
         )
 
-        check_call_consistency(prototype_transform, legacy_transform)
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
 
 
 class TestToTensorTransforms:
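
To see why the tolerance is needed, a standalone sketch (not part of this commit) with hypothetical values: the native uint8 path and the legacy u8->f32->interpolate->u8 path may round differently, and rtol=0, atol=1 accepts such off-by-one results.

import torch

v1_out = torch.tensor([127, 200, 13], dtype=torch.uint8)  # legacy float path
v2_out = torch.tensor([128, 200, 12], dtype=torch.uint8)  # native uint8 path

# Off-by-one differences pass with the tolerances used by the consistency tests.
torch.testing.assert_close(v2_out, v1_out, rtol=0, atol=1)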
30 changes: 30 additions & 0 deletions test/test_transforms_v2_functional.py
@@ -1365,3 +1365,33 @@ def test_correctness_uniform_temporal_subsample(device):
 
     out_video = F.uniform_temporal_subsample(video, 8)
     assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
+
+
+# TODO: We can remove this test and the related torchvision workaround
+# once the underlying pytorch issue is fixed: https://github.com/pytorch/pytorch/issues/68430
+@make_info_args_kwargs_parametrization(
+    [info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
+    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
+)
+def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
+    (input, *other_args), kwargs = args_kwargs.load("cpu")
+
+    output = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
+
+    error_msg_fn = parametrized_error_message(input, *other_args, **kwargs)
+    assert input.ndim == 3, error_msg_fn
+    input_stride = input.stride()
+    output_stride = output.stride()
+    # Here we check the output memory format according to the input:
+    # if input_stride is (..., 1), the input is most likely channels-first, and thus
+    # the output strides should match channels-first strides (H * W, W, 1);
+    # if input_stride is (1, ...), the input is most likely channels-last, and thus
+    # the output strides should match channels-last strides (1, W * C, C).
+    if input_stride[-1] == 1:
+        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
+        assert expected_stride == output_stride, error_msg_fn("")
+    elif input_stride[0] == 1:
+        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
+        assert expected_stride == output_stride, error_msg_fn("")
+    else:
+        assert False, error_msg_fn("")
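
The stride patterns the test asserts can be reproduced directly; a standalone sketch (not part of this commit), assuming a small 3x4x5 CHW image.

import torch

chw = torch.empty(3, 4, 5, dtype=torch.uint8)  # C=3, H=4, W=5
print(chw.stride())  # (20, 5, 1): last stride is 1 -> channels-first

# Emulate channels-last for a 3D CHW image via a fake batch dim.
cl = chw.unsqueeze(0).to(memory_format=torch.channels_last).squeeze(0)
print(cl.stride())   # (1, 15, 3), i.e. (1, W * C, C): channels-last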
18 changes: 11 additions & 7 deletions test/transforms_v2_kernel_infos.py
@@ -1569,31 +1569,35 @@ def reference_inputs_equalize_image_tensor():
     # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
     # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
     # the information gain is low if we already provide something really close to the expected value.
-    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor):
+    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor, memory_format):
         if dtype.is_floating_point:
             low = low_factor
             high = high_factor
         else:
             max_value = torch.iinfo(dtype).max
             low = int(low_factor * max_value)
             high = int(high_factor * max_value)
-        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
+        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high).to(
+            memory_format=memory_format, copy=True
+        )
 
-    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta):
+    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta, memory_format):
         image = torch.distributions.Beta(alpha, beta).sample(shape)
         if not dtype.is_floating_point:
             image.mul_(torch.iinfo(dtype).max).round_()
-        return image.to(dtype=dtype, device=device)
+        return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True)
 
     spatial_size = (256, 256)
     for dtype, color_space, fn in itertools.product(
         [torch.uint8],
         ["GRAY", "RGB"],
         [
-            lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
-            lambda shape, dtype, device: torch.full(
+            lambda shape, dtype, device, memory_format: torch.zeros(shape, dtype=dtype, device=device).to(
+                memory_format=memory_format, copy=True
+            ),
+            lambda shape, dtype, device, memory_format: torch.full(
                 shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
-            ),
+            ).to(memory_format=memory_format, copy=True),
             *[
                 functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
                 for low_factor, high_factor in [
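
The helpers above append .to(memory_format=..., copy=True) rather than a plain .to(); a standalone sketch (not part of this commit) of the difference.

import torch

x = torch.zeros(1, 3, 8, 8)

# A plain .to(memory_format=...) may return the input itself when the tensor
# already has the requested layout, i.e. it can be a no-op alias.
maybe_alias = x.to(memory_format=torch.contiguous_format)

# copy=True always allocates a fresh tensor in the requested layout, which is
# what the reference-input helpers need to exercise both layouts reliably.
fresh = x.to(memory_format=torch.channels_last, copy=True)
assert fresh.data_ptr() != x.data_ptr()
assert fresh.is_contiguous(memory_format=torch.channels_last)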
35 changes: 33 additions & 2 deletions torchvision/transforms/v2/functional/_geometry.py
@@ -176,16 +176,47 @@ def resize_image_tensor(
         antialias = False
 
     shape = image.shape
+    numel = image.numel()
     num_channels, old_height, old_width = shape[-3:]
     new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
 
     if (new_height, new_width) == (old_height, old_width):
         return image
-    elif image.numel() > 0:
+    elif numel > 0:
         image = image.reshape(-1, num_channels, old_height, old_width)
 
         dtype = image.dtype
-        need_cast = dtype not in (torch.float32, torch.float64)
+        acceptable_dtypes = [torch.float32, torch.float64]
+        if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
+            # uint8 dtype can be included for cpu and cuda inputs in nearest mode
+            acceptable_dtypes.append(torch.uint8)
+        elif interpolation == InterpolationMode.BILINEAR and image.device.type == "cpu":
+            # uint8 dtype support for bilinear mode is limited to cpu and,
+            # according to our benchmarks, non-AVX CPUs should prefer the u8->f32->interpolate->u8 path
+            if "AVX2" in torch.backends.cpu.get_cpu_capability():
+                acceptable_dtypes.append(torch.uint8)
+
+        # TODO: Remove when https://github.com/pytorch/pytorch/pull/101136 is landed
+        if dtype == torch.uint8 and not (
+            image.is_contiguous() or image.is_contiguous(memory_format=torch.channels_last)
+        ):
+            image = image.contiguous(memory_format=torch.channels_last)
+
+        strides = image.stride()
+        if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
+            # There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as
+            # contiguous even though the input is unambiguously channels_last (https://github.com/pytorch/pytorch/issues/68430).
+            # In particular this happens for the typical torchvision use-case of single CHW images where we fake the batch dim
+            # to become 1CHW. Below, we restride those tensors to trick torch core into properly allocating the output as
+            # channels_last, thus preserving the memory format of the input. This is not just for format consistency:
+            # for uint8 bilinear images, this also avoids an extra copy (re-packing) of the output and saves time.
+            # TODO: when https://github.com/pytorch/pytorch/issues/68430 is fixed (possibly by https://github.com/pytorch/pytorch/pull/100373),
+            # we should be able to remove this hack.
+            new_strides = list(strides)
+            new_strides[0] = numel
+            image = image.as_strided((1, num_channels, old_height, old_width), new_strides)
+
+        need_cast = dtype not in acceptable_dtypes
        if need_cast:
             image = image.to(dtype=torch.float32)
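
To make the ambiguity concrete, a standalone sketch (not part of this commit) of the stride situation the restriding hack targets; the shapes are arbitrary.

import torch

# Channels-last storage (HWC in memory) viewed as CHW, then given a fake batch
# dim, as resize_image_tensor does for single images.
chw = torch.empty(32, 32, 3).permute(2, 0, 1)  # strides (1, 96, 3)
batched = chw.unsqueeze(0)                     # 1CHW, strides (3, 1, 96, 3)

# Size-1 dims are skipped by the contiguity check, so this reports channels_last
# even though stride[0] (3) != numel (3072), and torch core may still allocate a
# contiguous output for it (https://github.com/pytorch/pytorch/issues/68430).
print(batched.is_contiguous(memory_format=torch.channels_last))  # True

# Setting stride[0] = numel, as the kernel does, makes the layout unambiguous.
fixed = batched.as_strided(batched.shape, (batched.numel(), 1, 96, 3))
print(fixed.stride())  # (3072, 1, 96, 3)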
