
exclude_from_weight_decay for AdamW and SGDW (#2624)
* exclude_from_weight_decay for AdamW and SGDW
leondgarse authored Jan 3, 2022
1 parent ef80dc4 commit 37a368a
Showing 5 changed files with 151 additions and 55 deletions.
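Taken together, the diff adds an `exclude_from_weight_decay` argument to the decoupled weight-decay optimizers (AdamW, SGDW) and routes LAMB's existing exclusion logic through a shared helper. As a rough usage sketch (not part of the commit; the optimizer arguments follow the new API below, while the model and exclusion patterns are illustrative):

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Variables whose names match any of the patterns (via re.search) are not
# decayed, e.g. biases and LayerNorm parameters.
optimizer = tfa.optimizers.AdamW(
    learning_rate=1e-3,
    weight_decay=1e-4,
    exclude_from_weight_decay=["bias", "layer_norm"],
)

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
model.compile(optimizer=optimizer, loss="mse")
model.fit(tf.random.normal((32, 4)), tf.random.normal((32, 2)), epochs=1, verbose=0)
```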
39 changes: 13 additions & 26 deletions tensorflow_addons/optimizers/lamb.py
@@ -18,14 +18,14 @@
76 minutes](https://arxiv.org/abs/1904.00962).
"""

import re
import warnings

from typing import Optional, Union, Callable, List
from typeguard import typechecked

import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike
from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes


@tf.keras.utils.register_keras_serializable(package="Addons")
@@ -163,12 +163,11 @@ def _resource_apply_dense(self, grad, var, apply_state=None):
v_sqrt = tf.sqrt(v_t_hat)
update = m_t_hat / (v_sqrt + coefficients["epsilon"])

var_name = self._get_variable_name(var.name)
if self._do_use_weight_decay(var_name):
if self._do_use_weight_decay(var):
update += coefficients["weight_decay"] * var

ratio = 1.0
if self._do_layer_adaptation(var_name):
if self._do_layer_adaptation(var):
w_norm = tf.norm(var, ord=2)
g_norm = tf.norm(update, ord=2)
ratio = tf.where(
@@ -206,12 +205,11 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
v_sqrt = tf.sqrt(v_t_hat)
update = m_t_hat / (v_sqrt + coefficients["epsilon"])

var_name = self._get_variable_name(var.name)
if self._do_use_weight_decay(var_name):
if self._do_use_weight_decay(var):
update += coefficients["weight_decay"] * var

ratio = 1.0
if self._do_layer_adaptation(var_name):
if self._do_layer_adaptation(var):
w_norm = tf.norm(var, ord=2)
g_norm = tf.norm(update, ord=2)
ratio = tf.where(
@@ -241,26 +239,15 @@ def get_config(self):
)
return config

def _do_use_weight_decay(self, param_name):
def _do_use_weight_decay(self, variable):
"""Whether to use L2 weight decay for `param_name`."""
if self.exclude_from_weight_decay:
for r in self.exclude_from_weight_decay:
if re.search(r, param_name) is not None:
return False
return True
return not is_variable_matched_by_regexes(
variable, self.exclude_from_weight_decay
)

def _do_layer_adaptation(self, param_name):
def _do_layer_adaptation(self, variable):
"""Whether to do layer-wise learning rate adaptation for
`param_name`."""
if self.exclude_from_layer_adaptation:
for r in self.exclude_from_layer_adaptation:
if re.search(r, param_name) is not None:
return False
return True

def _get_variable_name(self, param_name):
"""Get the variable name from the tensor name."""
m = re.match("^(.*):\\d+$", param_name)
if m is not None:
param_name = m.group(1)
return param_name
return not is_variable_matched_by_regexes(
variable, self.exclude_from_layer_adaptation
)
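For LAMB, both exclusion lists are now resolved against the variable itself through the shared helper. A short illustrative sketch of the public API, assuming the `tfa.optimizers.LAMB` entry point and the `weight_decay` keyword exercised in the tests below; the model and patterns are made up:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Skip weight decay and layer-wise adaptation for bias and LayerNorm
# variables, matched by regex against the variable names.
opt = tfa.optimizers.LAMB(
    learning_rate=1e-3,
    weight_decay=0.01,
    exclude_from_weight_decay=["bias", "layer_norm"],
    exclude_from_layer_adaptation=["bias", "layer_norm"],
)

model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(8, activation="relu", input_shape=(4,)),
        tf.keras.layers.LayerNormalization(),
        tf.keras.layers.Dense(1),
    ]
)
model.compile(optimizer=opt, loss="mse")
model.fit(tf.random.normal((16, 4)), tf.random.normal((16, 1)), epochs=1, verbose=0)
```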
16 changes: 9 additions & 7 deletions tensorflow_addons/optimizers/tests/lamb_test.py
@@ -335,20 +335,22 @@ def test_get_config():

def test_exclude_weight_decay():
opt = lamb.LAMB(0.01, weight_decay=0.01, exclude_from_weight_decay=["var1"])
assert opt._do_use_weight_decay("var0")
assert not opt._do_use_weight_decay("var1")
assert not opt._do_use_weight_decay("var1_weight")
assert opt._do_use_weight_decay(tf.Variable([], name="var0"))
assert not opt._do_use_weight_decay(tf.Variable([], name="var1"))
assert not opt._do_use_weight_decay(tf.Variable([], name="var1_weight"))


def test_exclude_layer_adaptation():
opt = lamb.LAMB(0.01, exclude_from_layer_adaptation=["var1"])
assert opt._do_layer_adaptation("var0")
assert not opt._do_layer_adaptation("var1")
assert not opt._do_layer_adaptation("var1_weight")
assert opt._do_layer_adaptation(tf.Variable([], name="var0"))
assert not opt._do_layer_adaptation(tf.Variable([], name="var1"))
assert not opt._do_layer_adaptation(tf.Variable([], name="var1_weight"))


def test_serialization():
optimizer = lamb.LAMB(1e-4)
optimizer = lamb.LAMB(
1e-4, weight_decay_rate=0.01, exclude_from_weight_decay=["var1"]
)
config = tf.keras.optimizers.serialize(optimizer)
new_optimizer = tf.keras.optimizers.deserialize(config)
assert new_optimizer.get_config() == optimizer.get_config()
54 changes: 53 additions & 1 deletion tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py
@@ -80,6 +80,7 @@ def do_test(
opt = optimizer(**optimizer_kwargs)
# Create the update op.
# Run 3 steps of the optimizer
optimizer_kwargs.pop("exclude_from_weight_decay", None)
for _ in range(3):
if do_decay_var_list:
opt.apply_gradients(
@@ -241,6 +242,31 @@ def test_basic_decay_var_list_adamw(dtype):
)


def test_exclude_weight_decay_adamw():
optimizer = weight_decay_optimizers.AdamW(
learning_rate=1e-4, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
)
assert optimizer._do_use_weight_decay(tf.Variable([], name="var0"))
assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1"))
assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight"))


@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)])
def test_var_list_with_exclude_list_adamw(dtype):
do_test(
dtype,
weight_decay_optimizers.AdamW,
adamw_update_numpy,
do_decay_var_list=True,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-8,
weight_decay=WEIGHT_DECAY,
exclude_from_weight_decay=["var0_*", "var1_*"],
)


def test_keras_fit():
"""Check if calling model.fit works."""
model = tf.keras.models.Sequential([tf.keras.layers.Dense(2)])
@@ -341,6 +367,30 @@ def test_basic_decay_var_list_sgdw(dtype):
)


def test_exclude_weight_decay_sgdw():
optimizer = weight_decay_optimizers.SGDW(
learning_rate=0.01, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
)
assert optimizer._do_use_weight_decay(tf.Variable([], name="var0"))
assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1"))
assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight"))


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)])
def test_var_list_with_exclude_list_sgdw(dtype):
do_test(
dtype,
weight_decay_optimizers.SGDW,
sgdw_update_numpy,
do_decay_var_list=True,
learning_rate=0.001,
momentum=0.9,
weight_decay=WEIGHT_DECAY,
exclude_from_weight_decay=["var0_*", "var1_*"],
)


@pytest.mark.parametrize(
"optimizer",
[
@@ -379,7 +429,9 @@ def test_optimizer_sparse(dtype, optimizer):


def test_serialization():
optimizer = weight_decay_optimizers.AdamW(learning_rate=1e-4, weight_decay=1e-4)
optimizer = weight_decay_optimizers.AdamW(
learning_rate=1e-4, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
)
config = tf.keras.optimizers.serialize(optimizer)
new_optimizer = tf.keras.optimizers.deserialize(config)
assert new_optimizer.get_config() == optimizer.get_config()
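A quick sketch (not in the commit) of what the serialization test above relies on: the exclusion patterns are part of `get_config()` and therefore survive a serialize/deserialize round trip.

```python
import tensorflow as tf
import tensorflow_addons as tfa

opt = tfa.optimizers.SGDW(
    learning_rate=0.01, weight_decay=1e-4, exclude_from_weight_decay=["bias"]
)
config = tf.keras.optimizers.serialize(opt)
restored = tf.keras.optimizers.deserialize(config)

# The exclusion list is stored in the optimizer config and restored as-is.
assert restored.get_config()["exclude_from_weight_decay"] == ["bias"]
```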
22 changes: 22 additions & 0 deletions tensorflow_addons/optimizers/utils.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""Additional Utilities used for tfa.optimizers."""

import re
import tensorflow as tf
from typing import List


def fit_bn(model, *args, **kwargs):
@@ -51,3 +53,23 @@ def fit_bn(model, *args, **kwargs):

model.trainable = _trainable
model._metrics = _metrics


def get_variable_name(variable) -> str:
"""Get the variable name from the variable tensor."""
param_name = variable.name
m = re.match("^(.*):\\d+$", param_name)
if m is not None:
param_name = m.group(1)
return param_name


def is_variable_matched_by_regexes(variable, regexes: List[str]) -> bool:
"""Whether variable is matched in regexes list by its name."""
if regexes:
# var_name = get_variable_name(variable)
var_name = variable.name
for r in regexes:
if re.search(r, var_name):
return True
return False
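A small sketch (not from the commit) of how the two helpers above behave; the variable name is illustrative. `is_variable_matched_by_regexes` applies `re.search` to the raw `variable.name`, which still carries the `:0` output suffix, so plain substrings work as patterns:

```python
import tensorflow as tf
from tensorflow_addons.optimizers.utils import (
    get_variable_name,
    is_variable_matched_by_regexes,
)

var = tf.Variable(tf.zeros([2]), name="encoder/dense/bias")

# The raw name carries the ":0" output suffix; get_variable_name strips it.
assert var.name == "encoder/dense/bias:0"
assert get_variable_name(var) == "encoder/dense/bias"

# Patterns are applied with re.search, so any substring match counts.
assert is_variable_matched_by_regexes(var, ["bias"])
assert not is_variable_matched_by_regexes(var, ["kernel"])
assert not is_variable_matched_by_regexes(var, None)  # empty/None list never matches
```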
75 changes: 54 additions & 21 deletions tensorflow_addons/optimizers/weight_decay_optimizers.py
@@ -16,9 +16,10 @@

import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike
from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes

from typeguard import typechecked
from typing import Union, Callable, Type
from typing import Union, Callable, Type, Optional, List


class DecoupledWeightDecayExtension:
@@ -71,24 +72,40 @@ def __init__(self, weight_decay, *args, **kwargs):
"""

@typechecked
def __init__(self, weight_decay: Union[FloatTensorLike, Callable], **kwargs):
def __init__(
self,
weight_decay: Union[FloatTensorLike, Callable],
exclude_from_weight_decay: Optional[List[str]] = None,
**kwargs,
):
"""Extension class that adds weight decay to an optimizer.
Args:
weight_decay: A `Tensor`, a floating point value, or a schedule
that is a `tf.keras.optimizers.schedules.LearningRateSchedule`
to decay the variable by, in the update step.
exclude_from_weight_decay: List of regex patterns of
variables excluded from weight decay. Variables whose name
contains a substring matching the pattern will be excluded.
Note `decay_var_list` in `minimize` or `apply_gradients` takes
priority over `exclude_from_weight_decay` if specified.
**kwargs: Optional list or tuple or set of `Variable` objects to
decay.
"""
wd = kwargs.pop("weight_decay", weight_decay)
super().__init__(**kwargs)
self._decay_var_list = None # is set in minimize or apply_gradients
self._set_hyper("weight_decay", wd)
self.exclude_from_weight_decay = exclude_from_weight_decay

def get_config(self):
config = super().get_config()
config.update({"weight_decay": self._serialize_hyperparameter("weight_decay")})
config.update(
{
"weight_decay": self._serialize_hyperparameter("weight_decay"),
"exclude_from_weight_decay": self.exclude_from_weight_decay,
}
)
return config

@classmethod
@@ -130,7 +147,8 @@ def minimize(
grad_loss: Optional. A `Tensor` holding the gradient computed for
`loss`.
decay_var_list: Optional list of variables to be decayed. Defaults
to all variables in var_list.
to all variables in var_list. Note `decay_var_list` takes
priority over `exclude_from_weight_decay` if specified.
name: Optional name for the returned operation.
tape: (Optional) `tf.GradientTape`. If `loss` is provided as a
`Tensor`, the tape that computed the `loss` must be provided.
@@ -154,10 +172,11 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwargs):
Args:
grads_and_vars: List of (gradient, variable) pairs.
name: Optional name for the returned operation. Default to the
name passed to the `Optimizer` constructor.
decay_var_list: Optional list of variables to be decayed. Defaults
to all variables in var_list.
to all variables in var_list. Note `decay_var_list` takes
priority over `exclude_from_weight_decay` if specified.
**kwargs: Additional arguments to pass to the base optimizer's
apply_gradient method, e.g., TF2.2 added an argument
`experimental_aggregate_gradients`.
Expand All @@ -173,7 +192,7 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwar
return super().apply_gradients(grads_and_vars, name=name, **kwargs)

def _decay_weights_op(self, var, apply_state=None):
if not self._decay_var_list or var.ref() in self._decay_var_list:
if self._do_use_weight_decay(var):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = (apply_state or {}).get(
(var_device, var_dtype)
Expand All @@ -183,7 +202,7 @@ def _decay_weights_op(self, var, apply_state=None):
return tf.no_op()

def _decay_weights_sparse_op(self, var, indices, apply_state=None):
if not self._decay_var_list or var.ref() in self._decay_var_list:
if self._do_use_weight_decay(var):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = (apply_state or {}).get(
(var_device, var_dtype)
@@ -226,6 +245,12 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
grad, var, indices, apply_state=apply_state
)

def _do_use_weight_decay(self, var):
"""Whether to use L2 weight decay for `var`."""
if self._decay_var_list and var.ref() in self._decay_var_list:
return True
return not is_variable_matched_by_regexes(var, self.exclude_from_weight_decay)
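To make the priority rule concrete, here is an illustrative sketch (not part of the diff) using SGDW: a variable excluded by pattern is still decayed once it is listed explicitly in `decay_var_list`.

```python
import tensorflow as tf
import tensorflow_addons as tfa

var0 = tf.Variable([1.0, 2.0], name="var0")
var1 = tf.Variable([3.0, 4.0], name="var1")
grads = [tf.constant([0.1, 0.1]), tf.constant([0.1, 0.1])]

opt = tfa.optimizers.SGDW(
    learning_rate=0.1, weight_decay=1e-3, exclude_from_weight_decay=["var1"]
)

# var1 matches the exclusion pattern, so only var0 is decayed on this step.
opt.apply_gradients(zip(grads, [var0, var1]))

# decay_var_list takes priority over exclude_from_weight_decay:
# both variables are decayed on this step.
opt.apply_gradients(zip(grads, [var0, var1]), decay_var_list=[var0, var1])
```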


@typechecked
def extend_with_decoupled_weight_decay(
@@ -243,9 +268,13 @@ def extend_with_decoupled_weight_decay(
The API of the new optimizer class slightly differs from the API of the
base optimizer:
- The first argument to the constructor is the weight decay rate.
- Optional keyword argument `exclude_from_weight_decay` accepts list of
regex patterns of variables excluded from weight decay. Variables whose
name contains a substring matching the pattern will be excluded.
- `minimize` and `apply_gradients` accept the optional keyword argument
`decay_var_list`, which specifies the variables that should be decayed.
If `None`, all variables that are optimized are decayed.
Note this takes priority over `exclude_from_weight_decay` if specified.
If both `None`, all variables that are optimized are decayed.
Usage example:
```python
@@ -376,12 +405,14 @@ def __init__(
nesterov: boolean. Whether to apply Nesterov momentum.
name: Optional name prefix for the operations created when applying
gradients. Defaults to 'SGD'.
**kwargs: keyword arguments. Allowed to be {`clipnorm`,
`clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
norm; `clipvalue` is clip gradients by value, `decay` is
included for backward compatibility to allow time inverse decay
of learning rate. `lr` is included for backward compatibility,
recommended to use `learning_rate` instead.
**kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
`lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` is clip
gradients by norm; `clipvalue` is clip gradients by value.
`decay` is included for backward compatibility to allow time
inverse decay of learning rate. `lr` is included for backward
compatibility, recommended to use `learning_rate` instead.
`exclude_from_weight_decay` accepts list of regex patterns of
variables excluded from weight decay.
"""
super().__init__(
weight_decay,
@@ -466,12 +497,14 @@ def __init__(
beyond".
name: Optional name for the operations created when applying
gradients. Defaults to "AdamW".
**kwargs: keyword arguments. Allowed to be {`clipnorm`,
`clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
norm; `clipvalue` is clip gradients by value, `decay` is
included for backward compatibility to allow time inverse decay
of learning rate. `lr` is included for backward compatibility,
recommended to use `learning_rate` instead.
**kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
`lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` is clip
gradients by norm; `clipvalue` is clip gradients by value.
`decay` is included for backward compatibility to allow time
inverse decay of learning rate. `lr` is included for backward
compatibility, recommended to use `learning_rate` instead.
`exclude_from_weight_decay` accepts list of regex patterns of
variables excluded from weight decay.
"""
super().__init__(
weight_decay,
