Failing to restore AdaBelief optimizer from checkpoint (#2705)
* Update adabelief.py

* addressed pr comments
denadai2 authored May 18, 2022
1 parent e5bfef1 commit 953d848
Showing 2 changed files with 25 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tensorflow_addons/optimizers/adabelief.py
@@ -140,7 +140,7 @@ def __init__(
         self._set_hyper("decay", self._initial_decay)
         self._set_hyper("weight_decay", weight_decay)
         self._set_hyper("sma_threshold", sma_threshold)
-        self._set_hyper("total_steps", int(total_steps))
+        self._set_hyper("total_steps", float(total_steps))
         self._set_hyper("warmup_proportion", warmup_proportion)
         self._set_hyper("min_lr", min_lr)
         self.epsilon = epsilon or tf.keras.backend.epsilon()
@@ -325,7 +325,7 @@ def get_config(self):
                 "epsilon": self.epsilon,
                 "amsgrad": self.amsgrad,
                 "rectify": self.rectify,
-                "total_steps": self._serialize_hyperparameter("total_steps"),
+                "total_steps": int(self._serialize_hyperparameter("total_steps")),
                 "warmup_proportion": self._serialize_hyperparameter(
                     "warmup_proportion"
                 ),
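The two casts work together: `_set_hyper` now stores `total_steps` as a float, so it is tracked like the optimizer's other scalar hyperparameters and survives a `tf.train.Checkpoint` round trip, while `get_config` casts it back to `int` so the serialized Keras config keeps the documented integer type. A minimal sketch of the intended round trip (illustrative only, not part of the commit; it assumes the fixed `AdaBelief` above):

from tensorflow_addons.optimizers import AdaBelief

opt = AdaBelief(total_steps=10000, warmup_proportion=0.1, min_lr=1e-5)

# The hyperparameter is held as a float internally, but get_config casts it
# back, so the serialized config reports the documented integer type.
config = opt.get_config()
assert isinstance(config["total_steps"], int)

# The config round-trips losslessly through from_config.
restored = AdaBelief.from_config(config)
assert restored.get_config()["total_steps"] == 10000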
23 changes: 23 additions & 0 deletions tensorflow_addons/optimizers/tests/adabelief_test.py
@@ -236,3 +236,26 @@ def test_scheduler_serialization():
         "class_name": "InverseTimeDecay",
         "config": wd_scheduler.get_config(),
     }
+
+
+def test_checkpoint_serialization(tmpdir):
+    optimizer = AdaBelief()
+    optimizer2 = AdaBelief()
+
+    var_0 = tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)
+    var_1 = tf.Variable([3.0, 4.0], dtype=tf.dtypes.float32)
+
+    grad_0 = tf.constant([0.1, 0.2], dtype=tf.dtypes.float32)
+    grad_1 = tf.constant([0.03, 0.04], dtype=tf.dtypes.float32)
+
+    grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1]))
+
+    optimizer.apply_gradients(grads_and_vars)
+
+    checkpoint = tf.train.Checkpoint(optimizer=optimizer)
+    checkpoint2 = tf.train.Checkpoint(optimizer=optimizer2)
+    model_path = str(tmpdir / "adabelief_chkpt")
+    checkpoint.write(model_path)
+    checkpoint2.read(model_path)
+
+    optimizer2.apply_gradients(grads_and_vars)
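A note on the test's ordering: `checkpoint2.read(model_path)` only schedules the restore, because the fresh `optimizer2` has not created its slot variables yet; restoration is deferred until `apply_gradients` creates them, which is why the final `apply_gradients` call is the point where the scenario this commit fixes actually exercises the restore. A hedged sketch of how the test could additionally assert that the deferred restore completed (an assumption, not part of the commit):

status = checkpoint2.read(model_path)
# Slot variables are created lazily, so the restore is deferred until here.
optimizer2.apply_gradients(grads_and_vars)
# Now every restored value should have found a matching object.
status.assert_existing_objects_matched()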
