Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix LAMB optimizer regex parsing #1532

Merged
merged 7 commits into from
Apr 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tensorflow_addons/optimizers/lamb.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""

import re
from typing import Optional, Union, Callable
from typing import Optional, Union, Callable, List
from typeguard import typechecked

import tensorflow as tf
Expand All @@ -42,8 +42,8 @@ def __init__(
beta_2: FloatTensorLike = 0.999,
epsilon: FloatTensorLike = 1e-6,
weight_decay_rate: FloatTensorLike = 0.0,
exclude_from_weight_decay: Optional[str] = None,
exclude_from_layer_adaptation: Optional[str] = None,
exclude_from_weight_decay: Optional[List[str]] = None,
exclude_from_layer_adaptation: Optional[List[str]] = None,
name: str = "LAMB",
**kwargs
):
Expand All @@ -59,10 +59,10 @@ def __init__(
The exponential decay rate for the 2nd moment estimates.
epsilon: A small constant for numerical stability.
weight_decay_rate: weight decay rate.
exclude_from_weight_decay: comma separated name patterns of
exclude_from_weight_decay: List of regex patterns of
variables excluded from weight decay. Variables whose name
contain a substring matching the pattern will be excluded.
exclude_from_layer_adaptation: comma separated name patterns of
exclude_from_layer_adaptation: List of regex patterns of
variables excluded from layer adaptation. Variables whose name
contain a substring matching the pattern will be excluded.
name: Optional name for the operations created when applying
Expand Down
14 changes: 14 additions & 0 deletions tensorflow_addons/optimizers/lamb_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,17 @@ def test_get_config(self):
opt = lamb.LAMB(1e-4)
config = opt.get_config()
self.assertEqual(config["learning_rate"], 1e-4)

def test_exclude_weight_decay(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a test for exclude_from_layer_adaption as well please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

opt = lamb.LAMB(
0.01, weight_decay_rate=0.01, exclude_from_weight_decay=["var1"]
)
assert opt._do_use_weight_decay("var0")
assert not opt._do_use_weight_decay("var1")
assert not opt._do_use_weight_decay("var1_weight")

def test_exclude_layer_adaptation(self):
opt = lamb.LAMB(0.01, exclude_from_layer_adaptation=["var1"])
assert opt._do_layer_adaptation("var0")
assert not opt._do_layer_adaptation("var1")
assert not opt._do_layer_adaptation("var1_weight")