[fbsync] [proto] Speed improvements for adjust hue op (#6805)
Summary:
* WIP

* Updated rgb2hsv and a bit of hsv2rgb

* Fix issue with batch of images

* Few improvements

* hsv2rgb improvements

* PR review

* another update

* Fix CUDA issue with empty images (torch.aminmax is failing)

Reviewed By: YosuaMichael

Differential Revision: D40722899

fbshipit-source-id: 59edbba970a015fbc58c26828b36197945f46080

Co-authored-by: Vasilis Vryniotis <[email protected]>
2 people authored and facebook-github-bot committed Oct 27, 2022
1 parent 27fc3e6 commit 0cc1c75
Showing 1 changed file with 98 additions and 1 deletion.
99 changes: 98 additions & 1 deletion torchvision/prototype/transforms/functional/_color.py
@@ -143,7 +143,104 @@ def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> fe
        return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)


adjust_hue_image_tensor = _FT.adjust_hue
def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
    r, g, _ = image.unbind(dim=-3)

    # Implementation is based on
    # https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/src/libImaging/Convert.c#L330
    minc, maxc = torch.aminmax(image, dim=-3)

    # The algorithm erases the S and H channels where `maxc = minc`. This avoids NaN
    # from appearing in the results, because
    # + the S channel has a division by `maxc`, which is zero only if `maxc = minc`
    # + the H channel has a division by `(maxc - minc)`.
    #
    # Instead of overwriting NaN afterwards, we just prevent it from occurring, so
    # we don't need to deal with it in case we save the NaN in a buffer in
    # backprop, if it is ever supported, but it doesn't hurt to do so.
    eqc = maxc == minc

    channels_range = maxc - minc
    # Since `eqc => channels_range = 0`, replacing the denominator with 1 when `eqc` holds is fine.
    ones = torch.ones_like(maxc)
    s = channels_range / torch.where(eqc, ones, maxc)
    # Note that `eqc => maxc = minc = r = g = b`. So the following calculation
    # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + gc - rc = 6`, so it
    # would not matter what values `rc`, `gc`, and `bc` have here, and thus
    # replacing the denominator with 1 when `eqc` holds is fine.
    channels_range_divisor = torch.where(eqc, ones, channels_range).unsqueeze_(dim=-3)
    rc, gc, bc = ((maxc.unsqueeze(dim=-3) - image) / channels_range_divisor).unbind(dim=-3)

    mask_maxc_neq_r = maxc != r
    mask_maxc_eq_g = maxc == g
    mask_maxc_neq_g = ~mask_maxc_eq_g

    hr = (bc - gc).mul_(~mask_maxc_neq_r)
    hg = (2.0 + rc).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r)
    hb = (4.0 + gc).sub_(rc).mul_(mask_maxc_neq_g & mask_maxc_neq_r)

    h = hr.add_(hg).add_(hb)
    h = h.div_(6.0).add_(1.0).fmod_(1.0)
    return torch.stack((h, s, maxc), dim=-3)
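
For context (not part of this commit): `_rgb_to_hsv` computes the same formulas as Python's `colorsys.rgb_to_hsv`, vectorized over whole images. A minimal sanity check, assuming the function above is in scope and the input is a float image in [0, 1]:

import colorsys

import torch

# Illustrative check only: compare _rgb_to_hsv against colorsys.rgb_to_hsv
# pixel by pixel on a small random image.
image = torch.rand(3, 4, 4)
hsv = _rgb_to_hsv(image)
for y in range(4):
    for x in range(4):
        expected = colorsys.rgb_to_hsv(*image[:, y, x].tolist())
        assert torch.allclose(hsv[:, y, x], torch.tensor(expected), atol=1e-5)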


def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
    h, s, v = img.unbind(dim=-3)
    h6 = h * 6
    i = torch.floor(h6)
    f = h6 - i
    i = i.to(dtype=torch.int32)

    p = (v * (1.0 - s)).clamp_(0.0, 1.0)
    q = (v * (1.0 - s * f)).clamp_(0.0, 1.0)
    t = (v * (1.0 - s * (1.0 - f))).clamp_(0.0, 1.0)
    i.remainder_(6)

    mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)

    a1 = torch.stack((v, q, p, p, t, v), dim=-3)
    a2 = torch.stack((t, v, v, q, p, p), dim=-3)
    a3 = torch.stack((p, p, t, v, v, q), dim=-3)
    a4 = torch.stack((a1, a2, a3), dim=-4)

    return (a4.mul_(mask.to(dtype=img.dtype).unsqueeze(dim=-4))).sum(dim=-3)
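
The `mask` above is a one-hot selector over the six hue sextants, so the candidate channel layouts in `a1`/`a2`/`a3` are combined with a single multiply-and-sum instead of per-pixel branching. A quick round-trip check, not part of this commit, assuming both helpers are in scope:

import torch

# Illustrative round-trip only: RGB -> HSV -> RGB should reproduce a float
# image in [0, 1] up to float32 rounding error.
image = torch.rand(2, 3, 8, 8)  # a small batch of RGB images
restored = _hsv_to_rgb(_rgb_to_hsv(image))
assert torch.allclose(restored, image, atol=1e-4)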


def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

    if not isinstance(image, torch.Tensor):
        raise TypeError("Input image should be a Tensor")

    c = get_num_channels_image_tensor(image)

    if c not in [1, 3]:
        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")

    if c == 1:  # Match PIL behaviour
        return image

    if image.numel() == 0:
        # exit early on empty images; torch.aminmax fails on empty CUDA tensors
        # (see the commit summary above)
        return image

    orig_dtype = image.dtype
    if image.dtype == torch.uint8:
        image = image / 255.0

    image = _rgb_to_hsv(image)
    h, s, v = image.unbind(dim=-3)
    h.add_(hue_factor).remainder_(1.0)
    image = torch.stack((h, s, v), dim=-3)
    image_hue_adj = _hsv_to_rgb(image)

    if orig_dtype == torch.uint8:
        image_hue_adj = image_hue_adj.mul_(255.0).to(dtype=orig_dtype)

    return image_hue_adj
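
Illustrative usage and timing, not part of this commit; the tensor shape, hue factor, and repeat count are arbitrary. Since the commit title claims speed improvements, `torch.utils.benchmark` is one way to measure the op:

import torch
from torch.utils import benchmark

# Hypothetical input: a random uint8 RGB image.
img = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)
out = adjust_hue_image_tensor(img, hue_factor=0.25)
assert out.shape == img.shape and out.dtype == torch.uint8

# Rough timing sketch of the new op.
timer = benchmark.Timer(
    stmt="adjust_hue_image_tensor(img, 0.25)",
    globals={"adjust_hue_image_tensor": adjust_hue_image_tensor, "img": img},
)
print(timer.timeit(100))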


adjust_hue_image_pil = _FP.adjust_hue

