From ce16e6292ea25ad92e72ef54098f43a2c454fdb0 Mon Sep 17 00:00:00 2001
From: Jared T Nielsen <jaredtnielsen@gmail.com>
Date: Fri, 3 Apr 2020 02:44:31 -0700
Subject: [PATCH] Fix LAMB optimizer regex parsing (#1532)

* Fix type for LAMB optimizer exclude_from_weight_decay

* Add import

* Add optional wrapper

* Add test

* Layer adaption test

* Typo
---
 tensorflow_addons/optimizers/lamb.py      | 10 +++++-----
 tensorflow_addons/optimizers/lamb_test.py | 14 ++++++++++++++
 2 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/tensorflow_addons/optimizers/lamb.py b/tensorflow_addons/optimizers/lamb.py
index d5a5807048..1e121e172d 100644
--- a/tensorflow_addons/optimizers/lamb.py
+++ b/tensorflow_addons/optimizers/lamb.py
@@ -19,7 +19,7 @@
 """
 
 import re
-from typing import Optional, Union, Callable
+from typing import Optional, Union, Callable, List
 from typeguard import typechecked
 
 import tensorflow as tf
@@ -42,8 +42,8 @@ def __init__(
         beta_2: FloatTensorLike = 0.999,
         epsilon: FloatTensorLike = 1e-6,
         weight_decay_rate: FloatTensorLike = 0.0,
-        exclude_from_weight_decay: Optional[str] = None,
-        exclude_from_layer_adaptation: Optional[str] = None,
+        exclude_from_weight_decay: Optional[List[str]] = None,
+        exclude_from_layer_adaptation: Optional[List[str]] = None,
         name: str = "LAMB",
         **kwargs
     ):
@@ -59,10 +59,10 @@ def __init__(
               The exponential decay rate for the 2nd moment estimates.
             epsilon: A small constant for numerical stability.
             weight_decay_rate: weight decay rate.
-            exclude_from_weight_decay: comma separated name patterns of
+            exclude_from_weight_decay: List of regex patterns of
               variables excluded from weight decay. Variables whose name
               contain a substring matching the pattern will be excluded.
-            exclude_from_layer_adaptation: comma separated name patterns of
+            exclude_from_layer_adaptation: List of regex patterns of
               variables excluded from layer adaptation. Variables whose name
               contain a substring matching the pattern will be excluded.
             name: Optional name for the operations created when applying
diff --git a/tensorflow_addons/optimizers/lamb_test.py b/tensorflow_addons/optimizers/lamb_test.py
index e738c111f7..ede68e5f24 100644
--- a/tensorflow_addons/optimizers/lamb_test.py
+++ b/tensorflow_addons/optimizers/lamb_test.py
@@ -401,3 +401,17 @@ def test_get_config(self):
         opt = lamb.LAMB(1e-4)
         config = opt.get_config()
         self.assertEqual(config["learning_rate"], 1e-4)
+
+    def test_exclude_weight_decay(self):
+        opt = lamb.LAMB(
+            0.01, weight_decay_rate=0.01, exclude_from_weight_decay=["var1"]
+        )
+        assert opt._do_use_weight_decay("var0")
+        assert not opt._do_use_weight_decay("var1")
+        assert not opt._do_use_weight_decay("var1_weight")
+
+    def test_exclude_layer_adaptation(self):
+        opt = lamb.LAMB(0.01, exclude_from_layer_adaptation=["var1"])
+        assert opt._do_layer_adaptation("var0")
+        assert not opt._do_layer_adaptation("var1")
+        assert not opt._do_layer_adaptation("var1_weight")