
Trainer(gradient_clip_algorithm='value') has no effect (from #6123) #6920

Closed · ceshine opened this issue Apr 9, 2021 · 7 comments · Fixed by #6928
Labels: bug (Something isn't working), help wanted (Open to be worked on), priority: 0 (High priority task)

ceshine commented Apr 9, 2021

🐛 Bug

I couldn't find anywhere in the code where the gradient_clip_algorithm argument (implemented in #6123) is passed to the Accelerator.clip_gradients method, so I suspected that the default algorithm (GradClipAlgorithmType.NORM) is always used, no matter what the user sets.

After a brief investigation, I believe I've confirmed that this is the case and that the original test case could not detect it.

I'm not sure how to properly fix this bug yet, but I'd like to warn other users that only clipping by norm works at the moment.
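
For context, here is a minimal standalone sketch of the two clipping behaviours in question, using plain PyTorch utilities (this is illustration only, not Lightning's internal code):

```python
# Sketch: clip_grad_norm_ rescales all gradients together so their combined
# norm is at most the threshold; clip_grad_value_ caps each element
# individually. Plain PyTorch, not Lightning internals.
import torch
from torch.nn.utils import clip_grad_norm_, clip_grad_value_

param = torch.nn.Parameter(torch.ones(4))

param.grad = torch.tensor([0.5, -2.0, 3.0, -4.0])
clip_grad_value_([param], clip_value=1.0)
print(param.grad)  # tensor([ 0.5000, -1.0000,  1.0000, -1.0000])

param.grad = torch.tensor([0.5, -2.0, 3.0, -4.0])
clip_grad_norm_([param], max_norm=1.0)
print(param.grad.norm())  # ~1.0: same direction, uniformly scaled down
```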

To Reproduce

This commit first disables the suppression of AssertionError in Trainer.run_train, and then tests whether the maximum gradient value is almost the same as the configured 1e-5 threshold.

I ran the command pytest tests/trainer/test_trainer.py -k "test_gradient_clipping_by_value and not test_gradient_clipping_by_value_fp16" and got this:

FAILED tests/trainer/test_trainer.py::test_gradient_clipping_by_value - AssertionError: Gradient max value 3.6332883155409945e-06 != grad_clip_val 1e-05 .
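
The stricter check is essentially the following (a sketch with illustrative names; the real test lives in tests/trainer/test_trainer.py):

```python
import math

def assert_clipped_by_value(parameters, grad_clip_val):
    # Sketch of the stricter check: after clipping by value, the largest
    # gradient entry should sit at the threshold (assuming at least one
    # raw gradient exceeded it). Names here are illustrative stand-ins.
    grad_max = max(
        p.grad.abs().max().item() for p in parameters if p.grad is not None
    )
    assert math.isclose(grad_max, grad_clip_val, rel_tol=1e-2), (
        f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val}"
    )
```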

If we change the default algorithm in PrecisionPlugin.clip_gradients to GradClipAlgorithmType.VALUE, this test case passes.

Alternatively, we can directly assert inside PrecisionPlugin.clip_gradients that the clip algorithm is by-value. We then get the following error:

FAILED tests/trainer/test_trainer.py::test_gradient_clipping_by_value - AssertionError: GradClipAlgorithmType.NORM

By now we can clearly see that:

  1. Setting gradient_clip_algorithm changes nothing in the training procedure.
  2. The original test case cannot distinguish between the two clipping algorithms (see the sketch after this list).
  3. The AssertionError in the original test case is ignored anyway because of the design of Trainer.run_train. (I'm not entirely sure of this one because I'm not familiar with the test-environment setup, but it is definitely what happens in my local environment.)
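
To see why point 2 holds: with a tiny threshold, clipping by norm also pushes every gradient element far below the threshold, so a plain "max gradient <= threshold" check passes under either algorithm. A quick sketch:

```python
import torch
from torch.nn.utils import clip_grad_norm_

param = torch.nn.Parameter(torch.ones(3))
param.grad = torch.tensor([10.0, 20.0, 30.0])

# Norm clipping with a tiny threshold shrinks every element well below it,
# so a "max grad <= threshold" assertion cannot tell the algorithms apart.
clip_grad_norm_([param], max_norm=1e-5)
print(param.grad.abs().max())  # ~8e-06, already <= 1e-5 without value clipping
```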

Environment

  • CUDA:
    - GPU: GeForce RTX 2070
    - available: True
    - version: 11.0
  • Packages:
    - numpy: 1.19.2
    - pyTorch_debug: False
    - pyTorch_version: 1.7.1
    - pytorch-lightning: 1.3.0rc0
    - tqdm: 4.49.0
  • System:
    - OS: Linux
    - architecture: 64bit
    - processor: x86_64
    - python: 3.7.9
    - version: #78-Ubuntu SMP Fri Mar 19 13:29:52 UTC 2021
ceshine added the bug and help wanted labels on Apr 9, 2021
ceshine changed the title from Trainer(gradient_clip_algorithm='value' has no effect (from #6123) to Trainer(gradient_clip_algorithm='value') has no effect (from #6123) on Apr 9, 2021
kaushikb11 added the priority: 0 label on Apr 9, 2021
carmocca commented Apr 9, 2021

cc: @dhkim0225 as the original PR author

ceshine commented Apr 9, 2021

I've created a potential solution to this bug. One major change I made is re-raising the caught RuntimeError or AssertionError in Trainer.run_train so the test cases work. I don't understand why the errors were suppressed there.

carmocca commented Apr 9, 2021

> One major change I made is re-raising the caught RuntimeError or AssertionError in Trainer.run_train so the test cases work. I don't understand why the errors were suppressed there.

This is a bug. Already reported in #6807

ceshine commented Apr 9, 2021

> One major change I made is re-raising the caught RuntimeError or AssertionError in Trainer.run_train so the test cases work. I don't understand why the errors were suppressed there.

> This is a bug. Already reported in #6807

Thanks for the info! Good to know that I have not misunderstood that part of the code.

dhkim0225 commented:

@ceshine Thank you for reporting this! I apologize for the inconvenience caused.
@carmocca Maybe this is because I didn't add a GradClipAlgorithmType parameter here: https://github.com/PyTorchLightning/pytorch-lightning/blob/55525031c65756c2e8574cfee5d386e584aecdca/pytorch_lightning/trainer/training_loop.py#L384
I will make a new PR for this.
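
A standalone sketch of the plumbing gap being described (all names below are illustrative stand-ins, not the actual Lightning source):

```python
from enum import Enum

class GradClipAlgorithmType(Enum):
    NORM = "norm"
    VALUE = "value"

class PrecisionPlugin:
    # The plugin defaults to NORM; if the call site never forwards the
    # user's choice, VALUE can never take effect.
    def clip_gradients(self, clip_val, algorithm=GradClipAlgorithmType.NORM):
        print(f"clipping by {algorithm.value} at {clip_val}")

plugin = PrecisionPlugin()
plugin.clip_gradients(1e-5)  # buggy call site: always NORM
plugin.clip_gradients(1e-5, algorithm=GradClipAlgorithmType.VALUE)  # fixed: forwards the choice
```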

ceshine added commits to veritable-tech/pytorch-lightning that referenced this issue on Apr 9, 2021
ceshine commented Apr 9, 2021

> @ceshine Thank you for reporting this! I apologize for the inconvenience caused.
> @carmocca Maybe this is because I didn't add a GradClipAlgorithmType parameter here:
> https://github.com/PyTorchLightning/pytorch-lightning/blob/55525031c65756c2e8574cfee5d386e584aecdca/pytorch_lightning/trainer/training_loop.py#L384
> I will make a new PR for this.

@dhkim0225 Thanks for the reply. I'm afraid the fix requires more than changing that line. I'm also creating a PR with my changes (to use the CI pipeline); hopefully it'll get us to a proper solution faster.

dhkim0225 commented:

Closing my PR since @ceshine's PR does everything I wanted.
