diff --git a/tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py b/tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py
index 525b2bfda7..edf3b97e0e 100644
--- a/tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py
+++ b/tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py
@@ -246,9 +246,14 @@ def test_exclude_weight_decay_adamw():
     optimizer = weight_decay_optimizers.AdamW(
         learning_rate=1e-4, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
     )
-    assert optimizer._do_use_weight_decay(tf.Variable([], name="var0"))
-    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1"))
-    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight"))
+    var0 = tf.Variable([], name="var0")
+    var1 = tf.Variable([], name="var1")
+    var1_weight = tf.Variable([], name="var1_weight")
+
+    optimizer._set_decay_var_list([var0, var1, var1_weight])
+    assert optimizer._do_use_weight_decay(var0)
+    assert not optimizer._do_use_weight_decay(var1)
+    assert not optimizer._do_use_weight_decay(var1_weight)
 
 
 @pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)])
@@ -371,9 +376,14 @@ def test_exclude_weight_decay_sgdw():
     optimizer = weight_decay_optimizers.SGDW(
         learning_rate=0.01, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
     )
-    assert optimizer._do_use_weight_decay(tf.Variable([], name="var0"))
-    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1"))
-    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight"))
+    var0 = tf.Variable([], name="var0")
+    var1 = tf.Variable([], name="var1")
+    var1_weight = tf.Variable([], name="var1_weight")
+
+    optimizer._set_decay_var_list([var0, var1, var1_weight])
+    assert optimizer._do_use_weight_decay(var0)
+    assert not optimizer._do_use_weight_decay(var1)
+    assert not optimizer._do_use_weight_decay(var1_weight)
 
 
 @pytest.mark.usefixtures("maybe_run_functions_eagerly")
diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers.py b/tensorflow_addons/optimizers/weight_decay_optimizers.py
index 3d882b0169..c4fbd60e5a 100644
--- a/tensorflow_addons/optimizers/weight_decay_optimizers.py
+++ b/tensorflow_addons/optimizers/weight_decay_optimizers.py
@@ -157,9 +157,7 @@ def minimize(
         Raises:
             ValueError: If some of the variables are not `Variable` objects.
         """
-        self._decay_var_list = (
-            set([v.ref() for v in decay_var_list]) if decay_var_list else False
-        )
+        self._set_decay_var_list(var_list, decay_var_list)
         return super().minimize(
             loss, var_list=var_list, grad_loss=grad_loss, name=name, tape=tape
         )
@@ -186,9 +184,8 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwar
            TypeError: If `grads_and_vars` is malformed.
            ValueError: If none of the variables have gradients.
         """
-        self._decay_var_list = (
-            set([v.ref() for v in decay_var_list]) if decay_var_list else False
-        )
+        grads_and_vars = list(grads_and_vars)
+        self._set_decay_var_list((v for _, v in grads_and_vars), decay_var_list)
         return super().apply_gradients(grads_and_vars, name=name, **kwargs)
 
     def _decay_weights_op(self, var, apply_state=None):
@@ -245,11 +242,23 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
             grad, var, indices, apply_state=apply_state
         )
 
+    def _set_decay_var_list(self, var_list, decay_var_list=None):
+        if decay_var_list:
+            self._decay_var_list = set(v.ref() for v in decay_var_list)
+        elif self.exclude_from_weight_decay:
+            self._decay_var_list = set(
+                v.ref()
+                for v in var_list
+                if not is_variable_matched_by_regexes(v, self.exclude_from_weight_decay)
+            )
+        else:
+            self._decay_var_list = None
+
     def _do_use_weight_decay(self, var):
         """Whether to use L2 weight decay for `var`."""
-        if self._decay_var_list and var.ref() in self._decay_var_list:
+        if self._decay_var_list is None:
             return True
-        return not is_variable_matched_by_regexes(var, self.exclude_from_weight_decay)
+        return var.ref() in self._decay_var_list
 
 
 @typechecked