Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify exclude_from_weight_decay implementation #2676

Merged
merged 1 commit into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,14 @@ def test_exclude_weight_decay_adamw():
optimizer = weight_decay_optimizers.AdamW(
learning_rate=1e-4, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
)
assert optimizer._do_use_weight_decay(tf.Variable([], name="var0"))
assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1"))
assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight"))
var0 = tf.Variable([], name="var0")
var1 = tf.Variable([], name="var1")
var1_weight = tf.Variable([], name="var1_weight")

optimizer._set_decay_var_list([var0, var1, var1_weight])
assert optimizer._do_use_weight_decay(var0)
assert not optimizer._do_use_weight_decay(var1)
assert not optimizer._do_use_weight_decay(var1_weight)


@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)])
Expand Down Expand Up @@ -371,9 +376,14 @@ def test_exclude_weight_decay_sgdw():
optimizer = weight_decay_optimizers.SGDW(
learning_rate=0.01, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
)
assert optimizer._do_use_weight_decay(tf.Variable([], name="var0"))
assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1"))
assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight"))
var0 = tf.Variable([], name="var0")
var1 = tf.Variable([], name="var1")
var1_weight = tf.Variable([], name="var1_weight")

optimizer._set_decay_var_list([var0, var1, var1_weight])
assert optimizer._do_use_weight_decay(var0)
assert not optimizer._do_use_weight_decay(var1)
assert not optimizer._do_use_weight_decay(var1_weight)


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
Expand Down
25 changes: 17 additions & 8 deletions tensorflow_addons/optimizers/weight_decay_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,7 @@ def minimize(
Raises:
ValueError: If some of the variables are not `Variable` objects.
"""
self._decay_var_list = (
set([v.ref() for v in decay_var_list]) if decay_var_list else False
)
self._set_decay_var_list(var_list, decay_var_list)
return super().minimize(
loss, var_list=var_list, grad_loss=grad_loss, name=name, tape=tape
)
Expand All @@ -186,9 +184,8 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwar
TypeError: If `grads_and_vars` is malformed.
ValueError: If none of the variables have gradients.
"""
self._decay_var_list = (
set([v.ref() for v in decay_var_list]) if decay_var_list else False
)
grads_and_vars = list(grads_and_vars)
self._set_decay_var_list((v for _, v in grads_and_vars), decay_var_list)
return super().apply_gradients(grads_and_vars, name=name, **kwargs)

def _decay_weights_op(self, var, apply_state=None):
Expand Down Expand Up @@ -245,11 +242,23 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
grad, var, indices, apply_state=apply_state
)

def _set_decay_var_list(self, var_list, decay_var_list=None):
if decay_var_list:
self._decay_var_list = set(v.ref() for v in decay_var_list)
elif self.exclude_from_weight_decay:
self._decay_var_list = set(
v.ref()
for v in var_list
if not is_variable_matched_by_regexes(v, self.exclude_from_weight_decay)
)
else:
self._decay_var_list = None

def _do_use_weight_decay(self, var):
"""Whether to use L2 weight decay for `var`."""
if self._decay_var_list and var.ref() in self._decay_var_list:
if self._decay_var_list is None:
return True
return not is_variable_matched_by_regexes(var, self.exclude_from_weight_decay)
return var.ref() in self._decay_var_list


@typechecked
Expand Down