adjust_hue now supports inputs of type Tensor #2566

Merged
merged 21 commits into from Sep 3, 2020

Changes from 7 commits

21 commits
653d2fd
adjust_hue now supports inputs of type Tensor
CristianManta Aug 9, 2020
ff93f78
Added comparison between original adjust_hue and its Tensor and torch…
CristianManta Aug 11, 2020
b565af5
Added a few type checkings related to adjust_hue in functional_tensor…
CristianManta Aug 11, 2020
535d1de
Changed implementation of _rgb2hsv and removed useless type declarati…
CristianManta Aug 12, 2020
5f9ba27
Handled the range of hue_factor in the assertions and temporarily inc…
CristianManta Aug 13, 2020
3f6c5a5
Fixed some lint issues with CircleCI and added type hints in function…
CristianManta Aug 13, 2020
270e02e
Corrected type hint mistakes.
CristianManta Aug 13, 2020
3b072db
Followed PR review recommendations and added test for class interface…
CristianManta Aug 14, 2020
bdae1cc
Refactored test_functional_tensor.py to match vfdev-5's d016cab branc…
CristianManta Aug 15, 2020
70cff26
Removed test_adjustments from test_transforms_tensor.py and moved the…
CristianManta Aug 15, 2020
bb7ec8c
Added cuda test cases for test_adjustments and tried to fix conflict.
CristianManta Aug 21, 2020
aead63d
Merge branch 'master' into adjust_hue_tensor
CristianManta Aug 21, 2020
e28e558
[WIP] Merge branch 'master' of https://github.com/pytorch/vision into…
vfdev-5 Sep 1, 2020
d4dd848
Updated tests
vfdev-5 Sep 1, 2020
8551011
Fixes incompatible devices
vfdev-5 Sep 1, 2020
71185bd
Increased tol for cuda tests
vfdev-5 Sep 2, 2020
3f23938
Merge branch 'master' of https://github.com/pytorch/vision into adjus…
vfdev-5 Sep 2, 2020
38f33e7
Merge branch 'master' of https://github.com/pytorch/vision into adjus…
vfdev-5 Sep 2, 2020
e8b5f28
Merge branch 'master' of github.com:pytorch/vision into cm/adjust_hue…
vfdev-5 Sep 2, 2020
c58d151
Fixes potential issue with inplace op
vfdev-5 Sep 2, 2020
a143835
Reverted fmod -> %
vfdev-5 Sep 2, 2020
41 changes: 27 additions & 14 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,12 @@ def test_adjustments(self):
script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
script_adjust_saturation = torch.jit.script(F_t.adjust_saturation)
script_adjust_hue = torch.jit.script(F_t.adjust_hue)

fns = ((F.adjust_brightness, F_t.adjust_brightness, script_adjust_brightness),
(F.adjust_contrast, F_t.adjust_contrast, script_adjust_contrast),
(F.adjust_saturation, F_t.adjust_saturation, script_adjust_saturation))
(F.adjust_saturation, F_t.adjust_saturation, script_adjust_saturation),
(F.adjust_hue, F_t.adjust_hue, script_adjust_hue))

for _ in range(20):
channels = 3
Expand All @@ -146,25 +148,36 @@ def test_adjustments(self):
img = torch.randint(0, 256, shape, dtype=torch.uint8)

factor = 3 * torch.rand(1)
Collaborator comment: As factor should be a float for all ops, let's just apply .item() on it.

hue_factor = torch.add(torch.rand(1), -0.5)
img_clone = img.clone()
for f, ft, sft in fns:

ft_img = ft(img, factor)
sft_img = sft(img, factor)
if not img.dtype.is_floating_point:
ft_img = ft_img.to(torch.float) / 255
sft_img = sft_img.to(torch.float) / 255

img_pil = transforms.ToPILImage()(img)
f_img_pil = f(img_pil, factor)
f_img = transforms.ToTensor()(f_img_pil)
for i, (f, ft, sft) in enumerate(fns):
Collaborator comment: Maybe it would be simpler to check the type of f, like f == F.adjust_hue, and redefine factor as factor = torch.rand(1).item() - 0.5, instead of an if-else with almost the same code in both branches.

if i == 3:
ft_img = ft(img, hue_factor)
sft_img = sft(img, hue_factor)
if not img.dtype.is_floating_point:
ft_img = ft_img.to(torch.float) / 255
sft_img = sft_img.to(torch.float) / 255

img_pil = transforms.ToPILImage()(img)
f_img_pil = f(img_pil, hue_factor)
f_img = transforms.ToTensor()(f_img_pil)
else:
ft_img = ft(img, factor)
sft_img = sft(img, factor)
if not img.dtype.is_floating_point:
ft_img = ft_img.to(torch.float) / 255
sft_img = sft_img.to(torch.float) / 255

img_pil = transforms.ToPILImage()(img)
f_img_pil = f(img_pil, factor)
f_img = transforms.ToTensor()(f_img_pil)

# F uses uint8 and F_t uses float, so there is a small
# difference in values caused by (at most 5) truncations.
max_diff = (ft_img - f_img).abs().max()
max_diff_scripted = (sft_img - f_img).abs().max()
self.assertLess(max_diff, 5 / 255 + 1e-5)
self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
self.assertLess(max_diff, 5 + 1e-5) # TODO: 5 / 255, not 5
Collaborator comment: Maybe we can compare the number of differing pixels instead, as is done here:

num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
self.assertLess(max_diff_scripted, 5 + 1e-5) # TODO: idem
self.assertTrue(torch.equal(img, img_clone))

# test for class interface
Expand Down
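The reviewer's suggestion above (counting how many pixels differ rather than bounding the maximum per-value difference) could look like the minimal sketch below. The count_diff_pixels helper and the toy tensors are hypothetical illustrations, not code from the PR:

```python
import torch

def count_diff_pixels(out_tensor: torch.Tensor, out_pil_tensor: torch.Tensor) -> float:
    # Both tensors are assumed to have shape (3, H, W). Dividing the number of
    # differing channel values by 3 approximates the number of differing pixels.
    return (out_tensor != out_pil_tensor).sum().item() / 3.0

a = torch.zeros(3, 4, 4)
b = a.clone()
b[:, 0, 0] = 1.0  # exactly one pixel differs, in all three channels

assert count_diff_pixels(a, b) == 1.0
assert count_diff_pixels(a, a) == 0.0
```

A pixel-count threshold is more robust against a handful of off-by-one truncation differences between the uint8 PIL path and the float tensor path than a max-difference bound.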
8 changes: 4 additions & 4 deletions torchvision/transforms/functional.py
Expand Up @@ -714,7 +714,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
return F_t.adjust_saturation(img, saturation_factor)


def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
Contributor Author comment: Technically speaking, the type of hue_factor is torch.Tensor with dtype matching whatever torch.rand() outputs, but I guess it would be misleading to put hue_factor: Tensor in the header.

Contributor Author comment: Also, since we're in functional.py, the type of img would be something like Union[Tensor, PIL], but I cannot use PIL as a type, so I left img as it is. Not sure what to do. I don't think putting Any is correct either.

Collaborator comment: Here the type hints are mostly for torch.jit.script, which does not accept type mixing with Union. Let's keep the annotations as is done for the others. This is correct: def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:

def adjust_hue(img, hue_factor: float):
"""Adjust hue of an image.

The image hue is adjusted by converting the image to HSV and
Expand All @@ -729,20 +729,20 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
.. _Hue: https://en.wikipedia.org/wiki/Hue

Args:
img (PIL Image): PIL Image to be adjusted.
img (PIL Image or Tensor): Image to be adjusted.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
with complementary colors while 0 gives the original image.

Returns:
PIL Image: Hue adjusted image.
PIL Image or Tensor: Hue adjusted image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.adjust_hue(img, hue_factor)

raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return F_t.adjust_hue(img, hue_factor)


def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
CristianManta marked this conversation as resolved.
Expand Down
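The dispatch pattern the diff above introduces in functional.py (keep the unannotated img for torch.jit.script compatibility, route PIL inputs to the PIL backend and Tensors to the tensor backend) can be sketched as follows. pil_backend and tensor_backend are stand-ins for F_pil.adjust_hue and F_t.adjust_hue, and the return values are dummies for illustration:

```python
import torch
from typing import Any

def pil_backend(img: Any, hue_factor: float):
    # Stand-in for F_pil.adjust_hue.
    return ("pil", hue_factor)

def tensor_backend(img: torch.Tensor, hue_factor: float):
    # Stand-in for F_t.adjust_hue.
    return ("tensor", hue_factor)

def adjust_hue(img, hue_factor: float):
    # Non-Tensor inputs (e.g. PIL Images) go to the PIL backend;
    # everything else is treated as a Tensor.
    if not isinstance(img, torch.Tensor):
        return pil_backend(img, hue_factor)
    return tensor_backend(img, hue_factor)

assert adjust_hue("not a tensor", 0.1) == ("pil", 0.1)
assert adjust_hue(torch.zeros(3, 2, 2), 0.1) == ("tensor", 0.1)
```

Note the isinstance branch replaces the old behavior of raising TypeError for non-PIL inputs, which is exactly what the diff to adjust_hue shows.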
2 changes: 1 addition & 1 deletion torchvision/transforms/functional_pil.py
Expand Up @@ -118,7 +118,7 @@ def adjust_saturation(img, saturation_factor):


@torch.jit.unused
def adjust_hue(img, hue_factor):
def adjust_hue(img, hue_factor: float):
Collaborator comment: This is not necessary. Otherwise, you can set it as def adjust_hue(img: Any, hue_factor: float):

"""Adjust hue of an image.

The image hue is adjusted by converting the image to HSV and
Expand Down
15 changes: 10 additions & 5 deletions torchvision/transforms/functional_tensor.py
Expand Up @@ -129,7 +129,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
return _blend(img, mean, contrast_factor)


def adjust_hue(img, hue_factor):
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
CristianManta marked this conversation as resolved.
"""Adjust hue of an image.

The image hue is adjusted by converting the image to HSV and
Expand Down Expand Up @@ -344,7 +344,9 @@ def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)


def _rgb2hsv(img):
def _rgb2hsv(img: Tensor) -> Tensor:
if not isinstance(img, torch.Tensor):
Collaborator comment: _rgb2hsv is a private function, so I think we can skip the type check and error here. By the way, the functional_tensor.py module is intended to be private too; users should not call these functions directly. Warnings will be added soon, per #2547.

raise TypeError("img should be of type torch.Tensor. Got {}".format(type(img)))
r, g, b = img.unbind(0)

maxc = torch.max(img, dim=0).values
Expand All @@ -362,12 +364,13 @@ def _rgb2hsv(img):

cr = maxc - minc
# Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine.
s = cr / torch.where(eqc, maxc.new_ones(()), maxc)
ones = torch.ones_like(maxc)
s = cr / torch.where(eqc, ones, maxc)
# Note that `eqc => maxc = minc = r = g = b`. So the following calculation
# of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
# would not matter what values `rc`, `gc`, and `bc` have here, and thus
# replacing denominator with 1 when `eqc` is fine.
cr_divisor = torch.where(eqc, maxc.new_ones(()), cr)
cr_divisor = torch.where(eqc, ones, cr)
rc = (maxc - r) / cr_divisor
gc = (maxc - g) / cr_divisor
bc = (maxc - b) / cr_divisor
Expand All @@ -380,7 +383,9 @@ def _rgb2hsv(img):
return torch.stack((h, s, maxc))


def _hsv2rgb(img):
def _hsv2rgb(img: Tensor) -> Tensor:
if not isinstance(img, torch.Tensor):
Collaborator comment: Same here.

raise TypeError("img should be of type torch.Tensor. Got {}".format(type(img)))
h, s, v = img.unbind(0)
i = torch.floor(h * 6.0)
f = (h * 6.0) - i
Expand Down
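The torch.where change in _rgb2hsv above keeps the safe-division trick intact: when max == min (a gray pixel), the saturation numerator cr is zero, so the denominator can be swapped for 1 without changing the result, avoiding division by zero. A minimal sketch of that step, using a single hand-built gray pixel rather than the PR's real inputs:

```python
import torch

# One gray pixel, shape (3, 1, 1): r == g == b == 0.5.
img = torch.tensor([[[0.5]], [[0.5]], [[0.5]]])

maxc = torch.max(img, dim=0).values
minc = torch.min(img, dim=0).values
eqc = maxc == minc          # True where the pixel is gray
cr = maxc - minc            # 0 where eqc is True

# torch.ones_like gives the same shape, dtype, and device as maxc,
# which is why the diff prefers it over a 0-dim maxc.new_ones(()).
ones = torch.ones_like(maxc)
s = cr / torch.where(eqc, ones, maxc)

assert s.item() == 0.0  # gray pixels have zero saturation
```

Since eqc implies cr == 0, the numerator is already zero wherever the denominator was replaced, so the substitution never alters a real saturation value.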