-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
adjust_hue now supports inputs of type Tensor #2566
Changes from 7 commits
653d2fd
ff93f78
b565af5
535d1de
5f9ba27
3f6c5a5
270e02e
3b072db
bdae1cc
70cff26
bb7ec8c
aead63d
e28e558
d4dd848
8551011
71185bd
3f23938
38f33e7
e8b5f28
c58d151
a143835
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -130,10 +130,12 @@ def test_adjustments(self): | |||
script_adjust_brightness = torch.jit.script(F_t.adjust_brightness) | ||||
script_adjust_contrast = torch.jit.script(F_t.adjust_contrast) | ||||
script_adjust_saturation = torch.jit.script(F_t.adjust_saturation) | ||||
script_adjust_hue = torch.jit.script(F_t.adjust_hue) | ||||
|
||||
fns = ((F.adjust_brightness, F_t.adjust_brightness, script_adjust_brightness), | ||||
(F.adjust_contrast, F_t.adjust_contrast, script_adjust_contrast), | ||||
(F.adjust_saturation, F_t.adjust_saturation, script_adjust_saturation)) | ||||
(F.adjust_saturation, F_t.adjust_saturation, script_adjust_saturation), | ||||
(F.adjust_hue, F_t.adjust_hue, script_adjust_hue)) | ||||
|
||||
for _ in range(20): | ||||
channels = 3 | ||||
|
@@ -146,25 +148,36 @@ def test_adjustments(self): | |||
img = torch.randint(0, 256, shape, dtype=torch.uint8) | ||||
|
||||
factor = 3 * torch.rand(1) | ||||
hue_factor = torch.add(torch.rand(1), -0.5) | ||||
img_clone = img.clone() | ||||
for f, ft, sft in fns: | ||||
|
||||
ft_img = ft(img, factor) | ||||
sft_img = sft(img, factor) | ||||
if not img.dtype.is_floating_point: | ||||
ft_img = ft_img.to(torch.float) / 255 | ||||
sft_img = sft_img.to(torch.float) / 255 | ||||
|
||||
img_pil = transforms.ToPILImage()(img) | ||||
f_img_pil = f(img_pil, factor) | ||||
f_img = transforms.ToTensor()(f_img_pil) | ||||
for i, (f, ft, sft) in enumerate(fns): | ||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe it would be simpler to check the type of |
||||
if i == 3: | ||||
ft_img = ft(img, hue_factor) | ||||
sft_img = sft(img, hue_factor) | ||||
if not img.dtype.is_floating_point: | ||||
ft_img = ft_img.to(torch.float) / 255 | ||||
sft_img = sft_img.to(torch.float) / 255 | ||||
|
||||
img_pil = transforms.ToPILImage()(img) | ||||
f_img_pil = f(img_pil, hue_factor) | ||||
f_img = transforms.ToTensor()(f_img_pil) | ||||
else: | ||||
ft_img = ft(img, factor) | ||||
sft_img = sft(img, factor) | ||||
if not img.dtype.is_floating_point: | ||||
ft_img = ft_img.to(torch.float) / 255 | ||||
sft_img = sft_img.to(torch.float) / 255 | ||||
|
||||
img_pil = transforms.ToPILImage()(img) | ||||
f_img_pil = f(img_pil, factor) | ||||
f_img = transforms.ToTensor()(f_img_pil) | ||||
|
||||
# F uses uint8 and F_t uses float, so there is a small | ||||
# difference in values caused by (at most 5) truncations. | ||||
max_diff = (ft_img - f_img).abs().max() | ||||
max_diff_scripted = (sft_img - f_img).abs().max() | ||||
self.assertLess(max_diff, 5 / 255 + 1e-5) | ||||
self.assertLess(max_diff_scripted, 5 / 255 + 1e-5) | ||||
self.assertLess(max_diff, 5 + 1e-5) # TODO: 5 / 255, not 5 | ||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe we can compare the number of differing pixels, as is done here: vision/test/test_functional_tensor.py Line 587 in 270e02e
|
||||
self.assertLess(max_diff_scripted, 5 + 1e-5) # TODO: idem | ||||
self.assertTrue(torch.equal(img, img_clone)) | ||||
|
||||
# test for class interface | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -714,7 +714,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: | |
return F_t.adjust_saturation(img, saturation_factor) | ||
|
||
|
||
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Technically speaking, the type of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, since we're in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here the type hints are mostly for |
||
def adjust_hue(img, hue_factor: float): | ||
"""Adjust hue of an image. | ||
|
||
The image hue is adjusted by converting the image to HSV and | ||
|
@@ -729,20 +729,20 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: | |
.. _Hue: https://en.wikipedia.org/wiki/Hue | ||
|
||
Args: | ||
img (PIL Image): PIL Image to be adjusted. | ||
img (PIL Image or Tensor): Image to be adjusted. | ||
hue_factor (float): How much to shift the hue channel. Should be in | ||
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in | ||
HSV space in positive and negative direction respectively. | ||
0 means no shift. Therefore, both -0.5 and 0.5 will give an image | ||
with complementary colors while 0 gives the original image. | ||
|
||
Returns: | ||
PIL Image: Hue adjusted image. | ||
PIL Image or Tensor: Hue adjusted image. | ||
""" | ||
if not isinstance(img, torch.Tensor): | ||
return F_pil.adjust_hue(img, hue_factor) | ||
|
||
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) | ||
return F_t.adjust_hue(img, hue_factor) | ||
|
||
|
||
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: | ||
CristianManta marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -118,7 +118,7 @@ def adjust_saturation(img, saturation_factor): | |
|
||
|
||
@torch.jit.unused | ||
def adjust_hue(img, hue_factor): | ||
def adjust_hue(img, hue_factor: float): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not necessary. Otherwise, you can set it as |
||
"""Adjust hue of an image. | ||
|
||
The image hue is adjusted by converting the image to HSV and | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -129,7 +129,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: | |
return _blend(img, mean, contrast_factor) | ||
|
||
|
||
def adjust_hue(img, hue_factor): | ||
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: | ||
CristianManta marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Adjust hue of an image. | ||
|
||
The image hue is adjusted by converting the image to HSV and | ||
|
@@ -344,7 +344,9 @@ def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor: | |
return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype) | ||
|
||
|
||
def _rgb2hsv(img): | ||
def _rgb2hsv(img: Tensor) -> Tensor: | ||
if not isinstance(img, torch.Tensor): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
raise TypeError("img should be of type torch.Tensor. Got {}".format(type(img))) | ||
r, g, b = img.unbind(0) | ||
|
||
maxc = torch.max(img, dim=0).values | ||
|
@@ -362,12 +364,13 @@ def _rgb2hsv(img): | |
|
||
cr = maxc - minc | ||
# Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine. | ||
s = cr / torch.where(eqc, maxc.new_ones(()), maxc) | ||
ones = torch.ones_like(maxc) | ||
s = cr / torch.where(eqc, ones, maxc) | ||
# Note that `eqc => maxc = minc = r = g = b`. So the following calculation | ||
# of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it | ||
# would not matter what values `rc`, `gc`, and `bc` have here, and thus | ||
# replacing denominator with 1 when `eqc` is fine. | ||
cr_divisor = torch.where(eqc, maxc.new_ones(()), cr) | ||
cr_divisor = torch.where(eqc, ones, cr) | ||
rc = (maxc - r) / cr_divisor | ||
gc = (maxc - g) / cr_divisor | ||
bc = (maxc - b) / cr_divisor | ||
|
@@ -380,7 +383,9 @@ def _rgb2hsv(img): | |
return torch.stack((h, s, maxc)) | ||
|
||
|
||
def _hsv2rgb(img): | ||
def _hsv2rgb(img: Tensor) -> Tensor: | ||
if not isinstance(img, torch.Tensor): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here |
||
raise TypeError("img should be of type torch.Tensor. Got {}".format(type(img))) | ||
h, s, v = img.unbind(0) | ||
i = torch.floor(h * 6.0) | ||
f = (h * 6.0) - i | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As factor should be a float for all ops, let's just apply
.item()
on it.