exclude_from_weight_decay for AdamW and SGDW #2624
Thanks for the PR! Would you mind adding test cases for your changes to https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py?
def _do_use_weight_decay(self, var):
    """Whether to use L2 weight decay for `var`."""
    if not self._decay_var_list or var.ref() in self._decay_var_list:
        if self.exclude_from_weight_decay:
Could we maybe factor this part out into a function in optimizers/utils.py as exclude_from_weight_decay(variable, exclude_regexes: List[str]) -> bool? This would ensure consistency between LAMB and potential future optimizers with exclude lists for weight decay, and it would reduce nesting.
Also, I'd consider it more intuitive for decay_var_list to have higher priority than exclude_from_weight_decay, since it's a function argument (i.e., if decay_var_list is specified, always decay those variables regardless of a match in exclude_from_weight_decay). WDYT?
- I added a function is_variable_excluded_by_regexes(variable, exclude_regexes: List[str]) -> bool in optimizers/utils.py, as it can also be used in lamb.py _do_layer_adaptation. I will change the name if you are not comfortable with it. It's mostly called as not is_variable_excluded_by_regexes(...), so maybe you'd prefer a function like is_variable_not_excluded_by_regexes? WDYT?
- I changed _do_use_weight_decay in optimizers/weight_decay_optimizers.py so that if self._decay_var_list and var.ref() in self._decay_var_list, it returns True. This makes decay_var_list higher priority than exclude_from_weight_decay.
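The priority rule described above can be sketched in plain Python. This is only an illustration under assumptions, not the code from the PR: variables are represented by their name strings instead of tf.Variable objects, do_use_weight_decay is a hypothetical free function standing in for the _do_use_weight_decay method, and the fallback for a variable that is neither listed in decay_var_list nor matched by a regex (decay it) is assumed.

```python
import re
from typing import List, Optional, Set


def is_variable_excluded_by_regexes(
    name: str, exclude_regexes: Optional[List[str]]
) -> bool:
    """Return True if any regex matches somewhere in the variable name."""
    if not exclude_regexes:
        return False
    return any(re.search(regex, name) is not None for regex in exclude_regexes)


def do_use_weight_decay(
    name: str,
    decay_var_names: Optional[Set[str]] = None,
    exclude_regexes: Optional[List[str]] = None,
) -> bool:
    """Hypothetical stand-in for _do_use_weight_decay, keyed by name."""
    # An explicit decay list has the higher priority: listed variables
    # are always decayed, even if an exclude regex matches them.
    if decay_var_names and name in decay_var_names:
        return True
    # Otherwise (assumed fallback) decay unless an exclude regex matches.
    return not is_variable_excluded_by_regexes(name, exclude_regexes)
```

With exclude_regexes=["bias"], a variable named "dense/bias:0" is skipped unless it is also listed in decay_var_names, in which case it is decayed.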
Thanks for factoring this out and making decay_var_list higher priority!
How about renaming is_variable_excluded_by_regexes to is_variable_matched_by_regexes(variable, regexes)? I feel this is a bit easier to parse. I think having the not in front is fine as it keeps the function name simpler.
Yes, I renamed is_variable_excluded_by_regexes to is_variable_matched_by_regexes, and also renamed its param exclude_regexes to just regexes.
return True
return False

def _get_variable_name(self, param_name):
Do we need this or wouldn't re.search find the substring anyway?
I'm not sure about this, as this is from lamb.py. I kept a function get_variable_name in optimizers/utils.py, but I'm not using it in is_variable_excluded_by_regexes. Maybe some consideration from the author? @junjiek
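On the re.search question: Python's re.search scans the whole string, so an unanchored pattern matches a variable name whether or not a suffix such as ":0" is stripped first. The two approaches only differ for end-anchored patterns. The check below is a sketch with a made-up variable name, and get_variable_name here is an assumed reimplementation (strip a trailing ":<digits>"), not necessarily what lamb.py does.

```python
import re


def get_variable_name(param_name: str) -> str:
    """Strip a trailing ':<digits>' suffix if present (assumed behavior)."""
    match = re.match(r"^(.*):\d+$", param_name)
    return match.group(1) if match is not None else param_name


raw_name = "layer_norm/gamma:0"  # hypothetical variable name

# Unanchored pattern: found as a substring of the raw name.
unanchored_raw = re.search(r"gamma", raw_name) is not None

# End-anchored pattern: the ':0' suffix blocks the match on the raw name,
# but the pattern matches once the suffix has been stripped.
anchored_raw = re.search(r"gamma$", raw_name) is not None
anchored_stripped = re.search(r"gamma$", get_variable_name(raw_name)) is not None
```

So stripping the suffix matters only if users are expected to write patterns anchored to the end of the name.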
def __init__(
    self,
    weight_decay: Union[FloatTensorLike, Callable],
    exclude_from_weight_decay: Optional[List[str]] = None,
Please add this to the Args section below.
Added.
Thanks! Could you also add a sentence explaining that decay_var_list in minimize takes priority over exclude_from_weight_decay if specified (and also add a corresponding sentence to the documentation of minimize)?
Updated the DecoupledWeightDecayExtension __init__, minimize and apply_gradients docs, and also the extend_with_decoupled_weight_decay doc. Added exclude_from_weight_decay to the AdamW and SGDW **kwargs docs.
Four test cases added:
Thanks for making these changes, looks good from my side! :)
LGTM. Thanks @leondgarse for the PR and @PhilJd for the review!
Description
Brief Description of the PR:
exclude_from_weight_decay for DecoupledWeightDecayExtension optimizers including AdamW and SGDW, like LAMB. There are several issues on this, like "Support exclude_from_weight_decay in AdamW #1903" and "Add decay_var_list as init option to DecoupledWeightDecayExtension #2018".
LAMB exclude_from_weight_decay behavior, and it has a conflict with the current DecoupledWeightDecayExtension _decay_var_list. I'm actually looking forward to a better solution, if one can be introduced.
Fixes # (issue)
Type of change
Checklist:
How Has This Been Tested?
If we uncomment the print statement on line 248, then:
If you're adding a bugfix or new feature please describe the tests that you ran to verify your changes: