[trainer] add tf32-mode control #14606
Conversation
As usual, let's operate on a strict no-breaking-change rule. I understand PyTorch activates this feature by default when available (at least in some versions), so I would leave the default for `--tf32` as `None` and let PyTorch decide when `tf32=None`. Then obviously `True` and `False` will force-activate/deactivate the feature.
src/transformers/training_args.py
Outdated
@@ -548,6 +552,12 @@ class TrainingArguments:
        default=False,
        metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
    )
    tf32: bool = field(
        default=True,
The default should be whatever PyTorch has by default, so `None` here, and the user can set it to `True` or `False` to force/unforce it.
I understand it's `True` for versions >= 1.7 and < 1.10 but `False` after?
ah, good idea! let the user decide!
`True` for versions >= 1.7 and < 1.11, and probably `False` after; the nightly still defaults to `True` as of today.
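A minimal sketch, assuming the suggestion above is adopted, of what the field declaration with a `None` default could look like (the class name and help text here are illustrative, not the PR's final wording):

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class TrainingArgumentsSketch:
    # Illustrative stand-in for transformers.TrainingArguments: a None default
    # leaves PyTorch's own TF32 behaviour untouched, while True/False let the
    # user force the feature on or off explicitly.
    tf32: Optional[bool] = field(
        default=None,
        metadata={"help": "Whether to enable tf32 mode, available on Ampere and newer GPU architectures."},
    )
```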
src/transformers/training_args.py
Outdated
@@ -802,6 +812,9 @@ def __post_init__(self):
                "Mixed precision training with AMP or APEX (`--fp16` or `--bf16`) and half precision evaluation (`--fp16_full_eval` or `--bf16_full_eval`) can only be used on CUDA devices."
            )

        if is_torch_available() and is_torch_tf32_available():
            torch.backends.cuda.matmul.allow_tf32 = True if self.tf32 else False
So here we should only change that boolean if the value set was not `None`. If the value is `True`, there should be an error if `is_torch_tf32_available()` is `False`, so the user is not surprised if they don't get what they want.
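A hedged sketch of the tri-state handling being suggested here (the helper name and error message are illustrative; the PR's final code may differ):

```python
import torch

def _configure_tf32(tf32, tf32_supported):
    """Illustrative helper: apply the suggested tri-state --tf32 semantics.

    `tf32_supported` stands in for the PR's `is_torch_tf32_available()` check.
    """
    if tf32 is None:
        # Leave PyTorch's own default untouched.
        return
    if tf32 and not tf32_supported:
        # Fail loudly instead of silently ignoring the user's request.
        raise ValueError("--tf32 requires an Ampere or newer GPU arch, CUDA >= 11 and torch >= 1.7")
    # Force-enable or force-disable TF32 matmuls.
    torch.backends.cuda.matmul.allow_tf32 = tf32
```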
thanks for this great feedback, Sylvain. Please have another look.
@@ -492,6 +493,15 @@ def test_mixed_bf16(self):

    # will add more specific tests once there are some bugs to fix

    @require_torch_gpu
    @require_torch_tf32
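For context, a hedged sketch of what a `require_torch_tf32` test decorator could look like (the helper actually added by this PR may be implemented differently):

```python
import unittest

def require_torch_tf32(test_case):
    # Sketch: skip the test unless TF32 can actually be exercised.
    # The inner check stands in for the PR's `is_torch_tf32_available()` helper,
    # which additionally verifies the torch and CUDA versions.
    def tf32_available():
        try:
            import torch
        except ImportError:
            return False
        # Ampere and newer GPUs report compute capability >= 8.
        return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

    return unittest.skipUnless(
        tf32_available(), "test requires Ampere or a newer GPU arch, CUDA >= 11 and torch >= 1.7"
    )(test_case)
```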
Out of curiosity, do we have a setup that has the right CUDA version and GPU capabilities?
I have an RTX 3090, if that's what you're asking.
Running benchmarks now; will post those shortly.
I was wondering for our testing machines on the automatic CI :-)
One day we will have those newer GPUs.
The benchmarks are terrible: #14608
Thanks a lot for the update, just left some styling nits but it's great!
Co-authored-by: Sylvain Gugger <[email protected]>
This PR adds tf32-mode control support to the HF Trainer for Ampere cards. RFC: #14450

PyTorch has had this mode on by default since pt-1.7, but there is a discussion about turning it off in the coming release: pytorch/pytorch#67384

Here is the proposed logic:
- `--tf32 0` will disable it.

The PR adds:
- `is_torch_tf32_available` and `require_torch_tf32` helper utils

Fixes: #14450
@sgugger, @LysandreJik
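As a usage sketch (everything except `--tf32` itself is an illustrative assumption), the new flag can be parsed like any other `TrainingArguments` field:

```python
# Illustrative only: parse --tf32 through TrainingArguments via HfArgumentParser.
# --output_dir is a required TrainingArguments field; "0" is read as False.
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(["--output_dir", "out", "--tf32", "0"])
print(training_args.tf32)  # False: TF32 matmuls are force-disabled where supported
```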