Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add sequence fill support for ElasticTransform #7141

Merged
merged 4 commits into from
Jan 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,3 +858,35 @@ def test_gaussian_blur(device, channels, meth_kwargs):
agg_method="max",
tol=tol,
)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"fill",
[
1,
1.0,
[1],
[1.0],
(1,),
(1.0,),
[1, 2, 3],
[1.0, 2.0, 3.0],
(1, 2, 3),
(1.0, 2.0, 3.0),
],
)
@pytest.mark.parametrize("channels", [1, 3])
def test_elastic_transform(device, channels, fill):
if isinstance(fill, (list, tuple)) and len(fill) > 1 and channels == 1:
# For this the test would correctly fail, since the number of channels in the image does not match `fill`.
# Thus, this is not an issue in the transform, but rather a problem of parametrization that just gives the
# product of `fill` and `channels`.
return

_test_class_op(
T.ElasticTransform,
meth_kwargs=dict(fill=fill),
channels=channels,
device=device,
)
2 changes: 0 additions & 2 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1539,8 +1539,6 @@ def elastic_transform(
fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
If a tuple of length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant.
Only number is supported for torch Tensor.
Only int or str or tuple value is supported for PIL Image.
Comment on lines -1542 to -1543
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drive-by since I was looking into the fill support. This seems to be a copy-paste error. Internally we just convert PIL images to tensors

t_img = img
if not isinstance(img, torch.Tensor):
if not F_pil._is_pil_image(img):
raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")
t_img = pil_to_tensor(img)

and then call the tensor kernel:

output = F_t.elastic_transform(
t_img,
displacement,
interpolation=interpolation.value,
fill=fill,
)

Meaning, there is no difference between both types.

"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(elastic_transform)
Expand Down
8 changes: 6 additions & 2 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2104,8 +2104,12 @@ def __init__(self, alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINE
interpolation = _interpolation_modes_from_int(interpolation)
self.interpolation = interpolation

if not isinstance(fill, (int, float)):
raise TypeError(f"fill should be int or float. Got {type(fill)}")
if isinstance(fill, (int, float)):
fill = [float(fill)]
elif isinstance(fill, (list, tuple)):
fill = [float(f) for f in fill]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we actually need to convert to float?

Copy link
Collaborator Author

@pmeier pmeier Jan 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, we do due to JIT 🥲 Removing the float conversion from L2110 gives us

import torch.jit

from torchvision import transforms

torch.jit.script(transforms.ElasticTransform(fill=[1]))
[...]
Expected a value of type 'Optional[List[float]]' for argument 'fill' but instead found type 'List[int]'.
[...]

I know this is ugly AF and far from being Pythonic, but given that it is on v1 I really don't want to deal with this any more than I have to.

else:
raise TypeError(f"fill should be int or float or a list or tuple of them. Got {type(fill)}")
self.fill = fill

@staticmethod
Expand Down