From 3acee5ee0f76afc383b885601d432f22768453b8 Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Fri, 15 Apr 2022 23:01:17 -0700
Subject: [PATCH] Fix bert example so it is runnable

This broke in a recent update to the learning rate.
---
 examples/bert/README.md          |  2 ++
 examples/bert/run_pretraining.py | 17 ++++++++++-------
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/examples/bert/README.md b/examples/bert/README.md
index a8ffa96964..15df69ecdb 100644
--- a/examples/bert/README.md
+++ b/examples/bert/README.md
@@ -40,6 +40,8 @@ python3 examples/bert/run_pretraining.py \
     --input_files $OUTPUT_DIR/pretraining-data/ \
     --vocab_file $OUTPUT_DIR/bert_vocab_uncased.txt \
     --bert_config_file examples/bert/configs/bert_tiny.json \
+    --num_warmup_steps 20 \
+    --num_train_steps 200 \
     --saved_model_output $OUTPUT_DIR/model/
 
 # Run finetuning.
diff --git a/examples/bert/run_pretraining.py b/examples/bert/run_pretraining.py
index 580694afb4..5bfaf723cf 100644
--- a/examples/bert/run_pretraining.py
+++ b/examples/bert/run_pretraining.py
@@ -63,14 +63,14 @@
 
 flags.DEFINE_integer(
     "num_warmup_steps",
-    1e4,
+    10000,
     "The number of warmup steps during which the learning rate will increase "
     "till a threshold.",
 )
 
 flags.DEFINE_integer(
     "num_train_steps",
-    1e6,
+    1000000,
     "The total fixed number of steps till which the model will train.",
 )
 
@@ -326,11 +326,14 @@ def __call__(self, step):
         is_warmup = step < warmup
 
         # Linear Warmup will be implemented if current step is less than
-        # `num_warmup_steps`.
-        if is_warmup:
-            return peak_lr * (step / warmup)
-        # else Linear Decay will be implemented
-        return max(0.0, peak_lr * (training - step) / (training - warmup))
+        # `num_warmup_steps` else Linear Decay will be implemented.
+        return tf.cond(
+            is_warmup,
+            lambda: peak_lr * (step / warmup),
+            lambda: tf.math.maximum(
+                0.0, peak_lr * (training - step) / (training - warmup)
+            ),
+        )
 
 
 def decode_record(record):
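
Note on the schedule hunk: when the optimizer runs under `tf.function`, `__call__` receives `step` as a scalar tensor, so the old Python `if is_warmup:` and built-in `max(0.0, ...)` try to use a tensor as a Python bool and can fail to trace; `tf.cond` and `tf.math.maximum` keep both branches inside the graph instead. Below is a minimal, self-contained sketch of the fixed schedule. The class name `LinearDecayWithWarmup`, the constructor, and the `tf.cast` boilerplate are reconstructions for illustration; only the `tf.cond` branch logic is taken directly from the hunk above.

```python
import tensorflow as tf


class LinearDecayWithWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup to `learning_rate`, then linear decay to zero."""

    def __init__(self, learning_rate, num_warmup_steps, num_train_steps):
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.num_train_steps = num_train_steps

    def __call__(self, step):
        # Cast everything to float32 so the arithmetic stays in one dtype.
        peak_lr = tf.cast(self.learning_rate, dtype=tf.float32)
        warmup = tf.cast(self.num_warmup_steps, dtype=tf.float32)
        training = tf.cast(self.num_train_steps, dtype=tf.float32)
        step = tf.cast(step, dtype=tf.float32)
        is_warmup = step < warmup

        # `step` is a symbolic tensor in graph mode, so a Python `if`/`max`
        # cannot branch on it; tf.cond defers the choice to graph execution.
        return tf.cond(
            is_warmup,
            lambda: peak_lr * (step / warmup),
            lambda: tf.math.maximum(
                0.0, peak_lr * (training - step) / (training - warmup)
            ),
        )


# Hypothetical usage mirroring the README flags above:
schedule = LinearDecayWithWarmup(
    learning_rate=1e-4, num_warmup_steps=20, num_train_steps=200
)
optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)
```

Since both lambdas return float32 scalars, `tf.cond` yields a consistent dtype on either branch, and the schedule works both eagerly and inside a traced training step.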