
adjust_hue now supports inputs of type Tensor #2566

Merged: 21 commits merged into pytorch:master on Sep 3, 2020

Conversation

CristianManta (Contributor) commented Aug 9, 2020

Attempts to solve the bug reported in issue #2563.

Fixes #2563

Blocked by #2586

vfdev-5 (Collaborator) commented Aug 9, 2020

@CristianManta thanks for the PR. The main point of unifying the inputs of the F.adjust_* methods is to ensure that the result is very close to PIL's; see, for example, here for adjust_saturation, adjust_brightness, etc. I also started working on this feature but have not yet sent a PR. If you would like to get your PR landed, I can guide you through exactly what needs to be done to have it merged: code, tests, docs.

CristianManta (Author) commented Aug 10, 2020

@vfdev-5 I would be happy to help with this feature to the best of my ability. I can give it a crack.

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Aug 10, 2020

@CristianManta sounds great! Could you please confirm whether you can get it done this week?
If so, the next step is to update the test: https://github.com/pytorch/vision/blob/master/test/test_functional_tensor.py#L129
There are two problems if we do it straightforwardly, without any adaptations:

EDIT: I forgot to say that we also have to ensure that everything is torch.jit.script-able.

CristianManta (Author) commented Aug 10, 2020

@vfdev-5 I can spend around 2 hours per day on it (including reading about torch.jit.script and other needed material), since I have a full-time summer job. I've also only been using Python for 7 months and have a lot to learn/read, so I'll try it tonight, and if I'm too lost or feel I won't be able to complete it by the end of this week, I'll let you know (tonight, Montreal time) so that you can delegate this task to someone else.

So, if I understand correctly, I have to add a few tests in test_adjustments to make sure that the difference between F.adjust_hue, F_t.adjust_hue, and the scripted version is not too big (in a similar fashion to the code you quoted, but with a different factor)? Basically, I'll add (F.adjust_hue, F_t.adjust_hue, script_adjust_hue) at line 134 and see how to make the test pass, right?

I'll also read about torch.jit.script, since I've never used it before.
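
For concreteness, the kind of comparison being discussed might look roughly like the sketch below. The image shape, the exact tuple layout, and the tolerance are assumptions here rather than the final test, and the 5 / 255 bound is precisely what the rest of this thread ends up debating for adjust_hue.

import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import torchvision.transforms.functional_tensor as F_t

script_adjust_hue = torch.jit.script(F_t.adjust_hue)

img = torch.randint(0, 256, (3, 44, 56), dtype=torch.uint8)  # illustrative shape
hue_factor = torch.rand(1).item() - 0.5                      # adjust_hue expects a factor in [-0.5, 0.5]

# PIL reference output, converted back to a float tensor in [0, 1]
f_img = transforms.ToTensor()(F.adjust_hue(transforms.ToPILImage()(img), hue_factor))

# Tensor and scripted outputs, rescaled to [0, 1] for comparison
ft_img = F_t.adjust_hue(img, hue_factor).to(torch.float) / 255.0
sft_img = script_adjust_hue(img, hue_factor).to(torch.float) / 255.0

# Bound the maximum per-value difference, as the existing adjustment tests do;
# for adjust_hue this bound turns out to be too tight (see the discussion below).
assert (ft_img - f_img).abs().max() < 5 / 255 + 1e-5
assert (sft_img - f_img).abs().max() < 5 / 255 + 1e-5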

vfdev-5 (Collaborator) commented Aug 10, 2020

> I can spend around 2 hours per day on it (including reading about torch.jit.script and other needed material), since I have a full-time summer job. I've also only been using Python for 7 months and have a lot to learn/read, so I'll try it tonight, and if I'm too lost or feel I won't be able to complete it by the end of this week, I'll let you know (tonight, Montreal time) so that you can delegate this task to someone else.

@CristianManta no worries, you can start it and I can finish your PR by pushing to your branch.

> So, if I understand correctly, I have to add a few tests in test_adjustments to make sure that the difference between F.adjust_hue, F_t.adjust_hue, and the scripted version is not too big (in a similar fashion to the code you quoted, but with a different factor)? Basically, I'll add (F.adjust_hue, F_t.adjust_hue, script_adjust_hue) at line 134 and see how to make the test pass, right?

Yes, and I know that it won't pass, as I checked it locally. That's why:

> We have to verify PIL and Tensor outputs on a small image and adapt the assertions accordingly...

(see my message above).

> I'll also read about torch.jit.script, since I've never used it before.

There is nothing complicated here; the task is mostly to fix the runtime errors that appear while scripting the function and its sub-functions, since several variables probably do not have explicitly declared types.
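
As a tiny illustration of that kind of fix (the scale function below is hypothetical, not from torchvision): torch.jit.script assumes that unannotated arguments are Tensors, so a float argument needs an explicit annotation before the scripted function behaves as intended.

import torch
from torch import Tensor

def scale(img: Tensor, factor: float) -> Tensor:
    # Without the `factor: float` annotation, the scripter would infer factor
    # as a Tensor, and calling the scripted function with a Python float would
    # then fail with a type error.
    return img * factor

scripted_scale = torch.jit.script(scale)
print(scripted_scale(torch.ones(2, 2), 0.5))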

CristianManta (Author) commented

@vfdev-5 I added the appropriate assertions in the tests, but of course there is an error when defining torch.jit.script(F_t.adjust_hue), since I guess I need to modify F_t.adjust_hue to make it scriptable. However, I read the tutorials and examples about converting functions to scripts, and I can reproduce them on trivial examples like

import torch

@torch.jit.script
def func(x):
    return x ** 2

def func2(z):
    return 2 * func(2 * z)

print(func.code)

sc2 = torch.jit.script(func2)
print(sc2.code)

so I'm not sure I understand why it doesn't work with F_t.adjust_hue. I tried adding type declarations (which you can see in my updated PR), since you mentioned the lack of type declarations. Am I supposed to add a type in front of each variable involved in the function to make it work? (I doubt that, because the errors were in the same place regardless of whether I added those types or not.)

vfdev-5 (Collaborator) commented Aug 12, 2020

@CristianManta sorry for the delay. Yes, you need to do exactly as is done, for example, for F_t.adjust_gamma.
Another thing to do is to add type hints to the _hsv2rgb and _rgb2hsv functions. In my tests, we also have to modify the implementation of _rgb2hsv:

-    s = cr / torch.where(eqc, maxc.new_ones(()), maxc)
+    ones = torch.ones_like(maxc)
+    s = cr / torch.where(eqc, ones, maxc)

-    cr_divisor = torch.where(eqc, maxc.new_ones(()), cr)
+    cr_divisor = torch.where(eqc, ones, cr)

Feel free to ask if you have other questions.

@@ -164,7 +163,7 @@ def adjust_hue(img, hue_factor):
    if img.dtype == torch.uint8:
        img = img.to(dtype=torch.float32) / 255.0

-   img = _rgb2hsv(img)
+   img: Tensor = _rgb2hsv(img)

Collaborator review comment: I think it is not necessary to add the : Tensor annotation here.

@@ -362,7 +361,7 @@ def _rgb2hsv(img):

    cr = maxc - minc
    # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine.
-   s = cr / torch.where(eqc, maxc.new_ones(()), maxc)
+   s: Tensor = cr / torch.where(eqc, maxc.new_ones(()), maxc)

Collaborator review comment: Same here.

CristianManta (Author) commented Aug 12, 2020

Ah, I see. Yes, I recall that the test generated an error coming from _rgb2hsv, and I was going around in circles trying to solve it. I'll continue working on it this evening and add the remaining type hints to the headers of the functions being called. I think I'll also raise a TypeError when the input does not match the type hint of the function arguments.

@@ -714,7 +714,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
return F_t.adjust_saturation(img, saturation_factor)


def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
CristianManta (author) review comment: Technically speaking, the type of hue_factor here is a torch.Tensor (with whatever dtype torch.rand() produces), but I guess it would be misleading to put hue_factor: Tensor in the header.

CristianManta (author) review comment: Also, since we're in functional.py, the type of img would be something like Union[Tensor, PIL image], but I cannot use PIL as a type here, so I left img as it is. Not sure what to do; I don't think putting Any is correct either.

Collaborator review comment: Here the type hints are mostly for torch.jit.script, which does not accept type mixing via Union. Let's keep the annotations as they are done for the other functions. This is correct: def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
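
For reference, a sketch of the dispatch pattern the existing adjust_* functions in functional.py follow; it mirrors that pattern rather than reproducing the merged code, and the import paths are spelled out only to make the snippet self-contained.

import torch
from torch import Tensor
from torchvision.transforms import functional_pil as F_pil, functional_tensor as F_t

def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    # The annotation says Tensor for torch.jit.script's benefit; in eager mode a
    # PIL image still takes the F_pil branch, which is decorated with
    # @torch.jit.unused and therefore ignored by the scripter.
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_hue(img, hue_factor)
    return F_t.adjust_hue(img, hue_factor)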

CristianManta (Author) commented Aug 13, 2020

I think I covered all the type hint problems, and I adapted the test to take into account the different range of the factor for adjust_hue. torch.jit.script also seems to work. I also added some comments/remarks to what I just pushed.

I guess the remaining part is to fix the error bound in the assertLess, which is now the only source of failure in the tests. I temporarily set it to 5 just to make sure that there are no other issues.

Now, in order to fix it, I'm wondering what approach I should follow. The 5 / 255 factor seems to come from the worst case where a conversion from uint8 to torch.float can be off by at most 1 (floor/ceiling rounding error) before dividing by 255, hence an error of 1 / 255 for one such conversion. But why 5 times? Also, where does the 1e-5 term come from?

If all of this is correct, then isn't changing the assertion cheating? Shouldn't we rather change F_t.adjust_hue to make it more precise?

vfdev-5 (Collaborator) left a review:

@CristianManta thanks for working on the PR! I left some comments.

@@ -146,25 +148,36 @@ def test_adjustments(self):
        img = torch.randint(0, 256, shape, dtype=torch.uint8)

        factor = 3 * torch.rand(1)

Collaborator review comment: As factor should be a float for all the ops, let's just apply .item() to it.

img_pil = transforms.ToPILImage()(img)
f_img_pil = f(img_pil, factor)
f_img = transforms.ToTensor()(f_img_pil)
for i, (f, ft, sft) in enumerate(fns):

Collaborator review comment: Maybe it would be simpler to check the type of f, e.g. f == F.adjust_hue, and redefine factor as factor = torch.rand(1).item() - 0.5, instead of an if-else with almost the same code in both branches...
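
A sketch of the suggested simplification; pick_factor is a hypothetical helper name used only for illustration, the real test inlines this logic inside the loop.

import torch
import torchvision.transforms.functional as F

def pick_factor(f) -> float:
    # adjust_hue expects a factor in [-0.5, 0.5]; the other adjust_* ops in
    # the test use a non-negative factor, here drawn from [0, 3).
    if f == F.adjust_hue:
        return torch.rand(1).item() - 0.5
    return 3 * torch.rand(1).item()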


torchvision/transforms/functional.py (review thread resolved)
@@ -118,7 +118,7 @@ def adjust_saturation(img, saturation_factor):


 @torch.jit.unused
-def adjust_hue(img, hue_factor):
+def adjust_hue(img, hue_factor: float):

Collaborator review comment: This is not necessary. Otherwise, you can write it as def adjust_hue(img: Any, hue_factor: float):

torchvision/transforms/functional_tensor.py (review thread resolved)
@@ -344,7 +344,9 @@ def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
    return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)


-def _rgb2hsv(img):
+def _rgb2hsv(img: Tensor) -> Tensor:
+    if not isinstance(img, torch.Tensor):

Collaborator review comment: _rgb2hsv is a private method; I think we can skip the type checking and error raising here. By the way, the functional_tensor.py module is intended to be private too; users should not call those functions in their code. Warnings will be added soon, according to #2547.

@@ -380,7 +383,9 @@ def _rgb2hsv(img):
    return torch.stack((h, s, maxc))


-def _hsv2rgb(img):
+def _hsv2rgb(img: Tensor) -> Tensor:
+    if not isinstance(img, torch.Tensor):

Collaborator review comment: Same here.


# F uses uint8 and F_t uses float, so there is a small
# difference in values caused by (at most 5) truncations.
max_diff = (ft_img - f_img).abs().max()
max_diff_scripted = (sft_img - f_img).abs().max()
self.assertLess(max_diff, 5 / 255 + 1e-5)
self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
self.assertLess(max_diff, 5 + 1e-5) # TODO: 5 / 255, not 5
Collaborator review comment: Maybe we can compare the number of differing pixels instead, as is done here:

num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
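
As a sketch, a check along those lines could bound the fraction of differing values instead of the maximum difference. It would sit inside the existing unittest method, and max_ratio is a placeholder threshold, not a value from the PR.

# Count mismatching channel values, average over the 3 channels, and
# normalise by the number of pixels before asserting.
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / (out_tensor.shape[-1] * out_tensor.shape[-2])
self.assertLess(ratio_diff_pixels, max_ratio)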

vfdev-5 (Collaborator) commented Aug 13, 2020

> Now, in order to fix it, I'm wondering what approach I should follow. The 5 / 255 factor seems to come from the worst case where a conversion from uint8 to torch.float can be off by at most 1 (floor/ceiling rounding error) before dividing by 255, hence an error of 1 / 255 for one such conversion. But why 5 times? Also, where does the 1e-5 term come from?
> If all of this is correct, then isn't changing the assertion cheating? Shouldn't we rather change F_t.adjust_hue to make it more precise?

I was thinking, first, to see how different the two outputs (PIL and Tensor) are: do we have a majority of exact values and only a few different ones, or are they all different? If you could compute F.adjust_hue on a small image and tensor of size 5x5 (e.g. torch.arange(3 * 5 * 5).reshape(3, 5, 5)) and post a few values here, that would be nice.

Next, we can take a look at the F_t.adjust_hue implementation and see if it could be slightly changed to match the PIL result.

Finally, we adapt the test: either we compute the number of differing pixels and raise an error if it is above some threshold, or we rework the current max-diff check and leave a comment explaining what is checked.

What do you think?

CristianManta (Author) commented

That's a good idea. I'll compute the distances between pairs of images using the two measures you mentioned; this should give us one more clue about what's wrong. My suspicion is that many pixels will be different, each slightly exceeding the established max difference.

I'll compute F.adjust_hue on a PIL image, convert it to a Tensor, re-compute F.adjust_hue on a Tensor representing the same image, and measure the distances between the two outputs. I'll repeat that a few times.

CristianManta (Author) commented Aug 14, 2020

I fixed the mistakes that you mentioned in the previous review. I also wrote a script to get some statistics about the error differences. If you want to run it, you need to place it one level above the vision directory:

import torch
import vision.torchvision.transforms.functional as F
import vision.torchvision.transforms.transforms as transforms
import vision.torchvision.transforms.functional_tensor as F_t


# Parameters
tol = 1e-5
k = 20

def get_number_of_diff_pixels(img1, img2):
    W = img1.shape[1]
    H = img1.shape[2]
    counter = 0
    for i in range(W):
        for j in range(H):
            for ch in range(3):
                if abs(img1[ch][i][j].item() - img2[ch][i][j].item()) > tol:
                    counter += 1
                    print('The discrepancy was in channel: {} for pixel ({}, {})'.format(ch, i, j))
                    break
    return counter


tensor_img = torch.arange(3 * 5 * 5).reshape(3, 5, 5).to(torch.float) / 255
pil_img = transforms.ToPILImage()(tensor_img)

# first testing if composing the ToPILImage and its inverse produces the 
# original image indeed

test_inverse_pil_img = transforms.ToTensor()(pil_img)
if (test_inverse_pil_img != tensor_img).sum().item() / 3.0 == 0:
    print('ToPILImage composed with its inverse doesn\'t lose accuracy.\n')
else:
    print('ToPILImage composed with its inverse loses accuracy.\n')


for i in range(k + 1):
    hue_factor = -0.5 + i / k
    out_tensor = F.adjust_hue(tensor_img, hue_factor)
    out_pil = F.adjust_hue(pil_img, hue_factor)
    out_pil_to_tensor = transforms.ToTensor()(out_pil)

    # num_diff_pixels = (out_pil_to_tensor != out_tensor).sum().item() / 3.0
    num_diff_pixels = get_number_of_diff_pixels(out_pil_to_tensor, out_tensor)
    print('hue_factor = {}: number of different pixels = {}'.format(hue_factor, num_diff_pixels))
    pct = 100.0 * num_diff_pixels / (tensor_img.shape[1] * tensor_img.shape[2])
    print('percentage of pixels that are different: {}%\n'.format(pct))
    print('------------------------------------------------\n')

# testing if _rgb2hsv and _hsv2rgb are precise:
hsv = F_t._rgb2hsv(tensor_img)
rgb = F_t._hsv2rgb(hsv)

diff = get_number_of_diff_pixels(rgb, tensor_img)
print('diff = {}'.format(diff))

With a relatively mild tolerance tol, for most of the hue_factor values tried, all 25 pixels differed by more than tol when comparing the original tensor output to the one obtained via the PIL transformations. For instance, if you set tol = 1e-5 (or even 1e-4), it looks pretty bad. In particular, when i = 10, hue_factor = 0 and the image is supposed to be left unchanged; yet all 25 pixels differ. The problem doesn't seem to favour one channel in particular.

I also tested whether the two conversions, _rgb2hsv and its inverse, have bad accuracy by composing one with the other and comparing against the original tensor. A difference can already be noticed with tol = 1e-7.

Sorry if the output of the script isn't pretty; I didn't have time to clean it up.

scripted_fn = torch.jit.script(f)
scripted_fn(img)

f = transforms.ColorJitter(saturation=factor.item())
f = transforms.ColorJitter(hue=abs(hue_factor))
Collaborator review comment: Why do you use abs here?

CristianManta (author) review comment: Because the ColorJitter class only accepts non-negative hue inputs.

CristianManta (author) review comment: I also had to create two new variables, bcs_factor and hue_factor, because after the inner loop, once factor had been overwritten, I still needed the original factor and the hue_factor.

vfdev-5 (Collaborator) commented Aug 14, 2020

@CristianManta I was working on the test of another color transform, rgb_to_grayscale, and since some of the adjust_* ops use it, I think it would be better to refactor test_adjustments into separate tests and follow the same test schema as test_adjust_gamma. Here is my working branch on that: d016cab.
I think it would be better to rewrite the test of adjust_hue in the same way.

…h by simple copy/paste and added the test_adjust_hue and ColorJitter class interface test in the same style (class interface test was removed in vfdev-5's branch for some reason).
CristianManta (Author) commented

OK, so can I also remove test_adjustments from there, since we already do those tests in test_functional_tensor.py?

vfdev-5 (Collaborator) commented Aug 15, 2020

Yes, you can remove test_adjustments from test_transforms_tensor.py, but we need to add a test for ColorJitter with all 4 options: brightness, contrast, saturation and hue.
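
A rough sketch of such a test, assuming the scriptable ColorJitter that this line of work enables; the jitter values and the image shape are illustrative, not taken from the final test.

import torch
from torchvision import transforms

img = torch.randint(0, 256, (3, 44, 56), dtype=torch.uint8)

# Exercise all four options at once and check that the scripted transform runs too.
f = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3)
scripted_f = torch.jit.script(f)
f(img)
scripted_f(img)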

… ColorJitter class interface test in test_transforms_tensor.py.
CristianManta (Author) commented Aug 19, 2020

@vfdev-5 There is still one test failing: test_adjust_hue reports that the difference between PIL and tensor is still too big. I tried fixing this by using colorsys instead of _rgb2hsv and _hsv2rgb, to perhaps reduce the error coming from them, but the colorsys functions are not scriptable. What should I do?

fmassa (Member) commented Aug 20, 2020

@CristianManta I think there might be some changes that went into this PR that are unrelated (in the tests).

About the differences from _rgb2hsv and _hsv2rgb, do they also fail on master? I think I've seen some occasional failures.
In the end, our implementation should follow the official definitions more closely. PIL does some approximations to work with uint8 images, so I'm not so worried if it fails for some input types.

CristianManta (Author) commented

@fmassa I realize that many changes have been merged into master since this PR was opened 12 days ago, so it's worth retrying with those changes. Also, regarding the unrelated changes that went into this PR: I think it's because I copied some things from d016cab, which @vfdev-5 was working on, since his changes were somewhat related.

Tonight I'll save my 4 changed files elsewhere, copy the new versions of them from master into my branch, and manually re-apply only the parts that are relevant to adjust_hue; this should hopefully resolve the conflicts as well. I'll re-run the tests afterwards.

CristianManta (Author) commented Aug 21, 2020

@fmassa I just updated the PR and resolved the conflict. I kept the d016cab style for testing the adjustments and incorporated the CUDA case. And of course, _test_adjust_hue (the goal of this PR) is still there.

I saw that, in fact, only CUDA test coverage has been merged into master in torchvision/transforms/functional_tensor since last time, so adjust_hue is still the same. There is still the same (and only) failure from test_functional_tensor:

FAIL: test_adjustments (test_functional_tensor.Tester)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/.../vision/test/test_functional_tensor.py", line 232, in test_adjustments
    self._test_adjust_hue("cpu")
  File "/.../vision/test/test_functional_tensor.py", line 225, in _test_adjust_hue
    device=device
  File "/.../vision/test/test_functional_tensor.py", line 187, in _test_adjust_fn
    msg="{}: tensor:\n{} \ndid not equal PIL tensor:\n{}".format(msg, rbg_tensor, adjusted_pil_tensor)
AssertionError: 9.0 not less than or equal to 1.0 : None, {'hue_factor': -0.5}: tensor:
tensor([[[236.,   3., 122.,  ..., 178., 126., 146.],
         [222.,  94.,  52.,  ..., 229., 112.,  17.],
         [ 79., 143., 161.,  ...,  69., 204.,  19.],
         ...,
         [ 73., 131., 223.,  ..., 239.,  17., 140.],
         [ 79., 115.,  59.,  ..., 223.,  73., 205.],
         [169.,  15., 213.,  ..., 169., 186.,  42.]],

        [[162., 132., 172.,  ..., 149., 245., 202.],
         [ 44.,  97.,  12.,  ...,  62.,  80.,  70.],
         [ 67.,  11., 185.,  ..., 217., 202., 134.],
         ...,
         [107.,  42., 226.,  ...,  66.,  26., 123.],
         [117., 247., 239.,  ..., 172., 134.,  79.],
         [155., 142., 183.,  ...,   4., 185., 184.]],

        [[ 12., 219., 114.,  ..., 202.,  91.,  55.],
         [145., 183., 143.,  ..., 225.,  60.,  72.],
         [ 40., 245., 249.,  ...,  42.,  92.,  13.],
         ...,
         [203., 220.,  64.,  ...,  21., 207.,  27.],
         [138.,  22., 202.,  ..., 210., 159., 247.],
         [ 47., 189.,  41.,  ...,  77., 231., 232.]]]) 
did not equal PIL tensor:
tensor([[[236.,   3., 122.,  ..., 179., 127., 145.],
         [222.,  94.,  56.,  ..., 229., 112.,  17.],
         [ 79., 149., 161.,  ...,  68., 204.,  20.],
         ...,
         [ 74., 134., 224.,  ..., 239.,  19., 140.],
         [ 79., 117.,  59.,  ..., 223.,  74., 210.],
         [169.,  16., 213.,  ..., 169., 189.,  44.]],

        [[165., 125., 172.,  ..., 150., 245., 202.],
         [ 44.,  94.,  13.,  ...,  63.,  80.,  69.],
         [ 67.,  12., 184.,  ..., 217., 203., 134.],
         ...,
         [104.,  42., 226.,  ...,  68.,  23., 123.],
         [115., 247., 239.,  ..., 172., 134.,  80.],
         [156., 138., 183.,  ...,   5., 186., 181.]],

        [[ 12., 219., 115.,  ..., 202.,  91.,  56.],
         [141., 183., 143.,  ..., 219.,  60.,  72.],
         [ 41., 245., 249.,  ...,  43.,  92.,  14.],
         ...,
         [203., 220.,  65.,  ...,  22., 207.,  27.],
         [138.,  22., 203.,  ..., 209., 159., 247.],
         [ 47., 189.,  43.,  ...,  74., 231., 232.]]])

----------------------------------------------------------------------

fmassa (Member) commented Aug 31, 2020

@vfdev-5 could you help take this PR to completion?

fmassa mentioned this pull request on Sep 1, 2020.
vfdev-5 (Collaborator) commented Sep 2, 2020

Travis failed: https://travis-ci.org/github/pytorch/vision/jobs/723374281

/functional_tensor.py", line 197, in adjust_hue
E               img = _rgb2hsv(img)
E               h, s, v = img.unbind(0)
E               h += hue_factor
E               ~~~~~~~~~~~~~~~ <--- HERE
E               h = h % 1.0
E               img = torch.stack((h, s, v))
E           RuntimeError: Output 0 of UnbindBackward is a view and is being modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.

EDIT: fixed in c58d151
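
For illustration, the out-of-place rewrite that this kind of error asks for looks roughly like the following; the actual fix landed in c58d151, and _shift_hue is a hypothetical helper, not the torchvision function.

import torch

def _shift_hue(hsv_img: torch.Tensor, hue_factor: float) -> torch.Tensor:
    # hsv_img is an HSV image tensor of shape (3, H, W).
    h, s, v = hsv_img.unbind(0)
    # `h += hue_factor` would modify a view returned by unbind in place, which is
    # what the RuntimeError above complains about; building a new tensor avoids it.
    h = (h + hue_factor) % 1.0
    return torch.stack((h, s, v))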

vfdev-5 requested a review from fmassa on September 2, 2020.
fmassa (Member) left a review: One more comment; otherwise looks good to me.

torchvision/transforms/functional_tensor.py (outdated review thread, resolved)
fmassa (Member) left a review: Thanks a lot @CristianManta and @vfdev-5!

fmassa merged commit bb88c45 into pytorch:master on Sep 3, 2020.
yaox12 (Contributor) commented Sep 3, 2020

Hi, I think the current F_t.adjust_hue cannot process images in batches, which is inconsistent with F_t.adjust_saturation, F_t.adjust_brightness, etc.

vfdev-5 (Collaborator) commented Sep 3, 2020

@yaox12 yes, the current implementation should be improved for that case. Batch support will be covered by #2583.

bryant1410 pushed a commit to bryant1410/vision-1 that referenced this pull request Nov 22, 2020
* adjust_hue now supports inputs of type Tensor

* Added comparison between original adjust_hue and its Tensor and torch.jit.script versions.

* Added a few type checkings related to adjust_hue in functional_tensor.py in hopes to make F_t.adjust_hue scriptable...but to no avail.

* Changed implementation of _rgb2hsv and removed useless type declaration according to PR's review.

* Handled the range of hue_factor in the assertions and temporarily increased the assertLess bound to make sure that no other test fails.

* Fixed some lint issues with CircleCI and added type hints in functional_pil.py as well.

* Corrected type hint mistakes.

* Followed PR review recommendations and added test for class interface with hue.

* Refactored test_functional_tensor.py to match vfdev-5's d016cab branch by simple copy/paste and added the test_adjust_hue and ColorJitter class interface test in the same style (class interface test was removed in vfdev-5's branch for some reason).

* Removed test_adjustments from test_transforms_tensor.py and moved the ColorJitter class interface test in test_transforms_tensor.py.

* Added cuda test cases for test_adjustments and tried to fix conflict.

* Updated tests
- adjust hue
- color jitter

* Fixes incompatible devices

* Increased tol for cuda tests

* Fixes potential issue with inplace op
- fixes irreproducible failing test on Travis CI

* Reverted fmod -> %

Co-authored-by: vfdev-5 <[email protected]>
Successfully merging this pull request may close these issues:

BUG: ColorJitter in torchvision.transforms