Commit 45ee88b

BERT runs with KungFu in the run_squad task.

1 parent cc7051d · commit 45ee88b
4 files changed (+58, -3)

README.md (+12, -1)

````diff
@@ -1,4 +1,15 @@
-# BERT
+# Distributed BERT with KungFu
+
+## Scaling out BERT with KungFu
+
+Install [KungFu](https://github.com/lsds/KungFu) first.
+Configure the relevant paths in `run_kungfu.sh`, and simply launch:
+
+```bash
+./run_kungfu.sh
+```
+
+
 ## BERT Releases
 
 **\*\*\*\*\* New May 31st, 2019: Whole Word Masking Models \*\*\*\*\***
````

optimization.py (+3)

```diff
@@ -67,6 +67,9 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
   if use_tpu:
     optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
 
+  from kungfu.tensorflow.v1.optimizers import SynchronousSGDOptimizer
+  optimizer = SynchronousSGDOptimizer(optimizer)
+
   tvars = tf.trainable_variables()
   grads = tf.gradients(loss, tvars)
```
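
Wrapping the optimizer is the core of the data-parallel change: on each training step, `SynchronousSGDOptimizer` averages the gradients across all workers before they are applied, so every replica makes the same update. Below is a minimal sketch of the same wrapping pattern outside of BERT, assuming KungFu's TensorFlow 1.x API as imported above; the toy model, tensors, and learning rate are illustrative and not part of the commit.

```python
import tensorflow as tf
from kungfu.tensorflow.v1.optimizers import SynchronousSGDOptimizer

# Toy least-squares problem: fit y = 2x with a single weight (illustrative only).
x = tf.constant([[1.0], [2.0], [3.0]])
y = tf.constant([[2.0], [4.0], [6.0]])
w = tf.Variable([[0.0]])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))

# Wrap any tf.train optimizer; the wrapper all-reduces the gradients so that
# every worker applies the same averaged update.
optimizer = SynchronousSGDOptimizer(tf.train.GradientDescentOptimizer(0.1))
train_op = optimizer.minimize(loss)
# Intended to be launched through kungfu-run, as run_kungfu.sh does below.
```

Because the change sits inside `create_optimizer`, any BERT script that builds its optimizer through this function picks up synchronous training when launched with `kungfu-run`.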

run_kungfu.sh (+26, new file)

```diff
@@ -0,0 +1,26 @@
+# Path to the pre-trained model
+BERT_BASE_DIR=/data/uncased_L-12_H-768_A-12
+
+# Path to the squad dataset
+SQUAD_DIR=/data/squad1
+
+# Path to the checkpoint folder
+OUTPUT_DIR=./tmp/squad_base_kungfu
+
+# Path to the kungfu-run executable
+KUNGFU_RUN=$HOME/KungFu/bin/kungfu-run
+
+$KUNGFU_RUN -np 4 python3 run_squad.py \
+  --vocab_file=$BERT_BASE_DIR/vocab.txt \
+  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
+  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
+  --do_train=True \
+  --train_file=$SQUAD_DIR/train-v1.1.json \
+  --do_predict=True \
+  --predict_file=$SQUAD_DIR/dev-v1.1.json \
+  --train_batch_size=8 \
+  --learning_rate=3e-5 \
+  --num_train_epochs=2.0 \
+  --max_seq_length=384 \
+  --doc_stride=128 \
+  --output_dir=$OUTPUT_DIR
```
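
`-np 4` starts four local copies of `run_squad.py`; each copy discovers its place in the job through the two KungFu calls imported in the next file, which run_squad.py also uses to split the step budget and to pick the checkpointing worker. A minimal standalone sketch of what each worker would observe; the file name `peers.py` and the prints are illustrative, not part of the commit.

```python
# peers.py -- launch with: kungfu-run -np 4 python3 peers.py
from kungfu import current_cluster_size, current_rank

# With -np 4, every worker sees a cluster size of 4 and a unique rank in 0..3.
print("cluster size:", current_cluster_size())
print("my rank:", current_rank())
```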

run_squad.py (+17, -2)

```diff
@@ -29,6 +29,8 @@
 import six
 import tensorflow as tf
 
+from kungfu import current_rank, current_cluster_size
+
 flags = tf.flags
 
 FLAGS = flags.FLAGS
@@ -1141,11 +1143,15 @@ def main(_):
         FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
 
   is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
+
+  # KungFu: let only one estimator do the checkpointing.
+  save_checkpoints_steps = None if current_rank() != 0 else FLAGS.save_checkpoints_steps
+
   run_config = tf.contrib.tpu.RunConfig(
       cluster=tpu_cluster_resolver,
       master=FLAGS.master,
       model_dir=FLAGS.output_dir,
-      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
+      save_checkpoints_steps=save_checkpoints_steps,
       tpu_config=tf.contrib.tpu.TPUConfig(
           iterations_per_loop=FLAGS.iterations_per_loop,
           num_shards=FLAGS.num_tpu_cores,
@@ -1166,6 +1172,10 @@ def main(_):
     rng = random.Random(12345)
     rng.shuffle(train_examples)
 
+  # KungFu: adjust training steps based on parallelism.
+  num_train_steps = num_train_steps // current_cluster_size()
+  num_warmup_steps = num_warmup_steps // current_cluster_size()
+
   model_fn = model_fn_builder(
       bert_config=bert_config,
       init_checkpoint=FLAGS.init_checkpoint,
@@ -1212,7 +1222,12 @@ def main(_):
         seq_length=FLAGS.max_seq_length,
         is_training=True,
         drop_remainder=True)
-    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
+
+    # KungFu: let the first estimator broadcast the global variables.
+    from kungfu.tensorflow.v1.initializer import BroadcastGlobalVariablesHook
+    hooks = [BroadcastGlobalVariablesHook()]
+
+    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps, hooks=hooks)
 
   if FLAGS.do_predict:
     eval_examples = read_squad_examples(
```
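
Taken together, the run_squad.py changes follow a standard recipe for data-parallel Estimator training: broadcast the initial variables so all workers start from the same weights, divide the global step budget by the number of workers, and let only rank 0 keep a periodic checkpoint schedule for the shared output directory. A condensed sketch of that recipe, reusing the commit's KungFu imports but with a plain `tf.estimator` setup and a hypothetical helper with placeholder parameter names rather than the actual run_squad.py code:

```python
import tensorflow as tf
from kungfu import current_rank, current_cluster_size
from kungfu.tensorflow.v1.initializer import BroadcastGlobalVariablesHook

def train_distributed(model_fn, input_fn, total_steps, checkpoint_steps, output_dir):
  """Hypothetical helper illustrating the KungFu + Estimator pattern above."""
  # Mirror the commit: only rank 0 keeps a periodic checkpoint schedule, so the
  # other workers do not race on the shared model_dir.
  save_checkpoints_steps = checkpoint_steps if current_rank() == 0 else None

  run_config = tf.estimator.RunConfig(
      model_dir=output_dir,
      save_checkpoints_steps=save_checkpoints_steps)
  estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

  # Each worker processes 1/N of the global step budget.
  max_steps = total_steps // current_cluster_size()

  # Sync all workers to rank 0's initial variables before training starts.
  hooks = [BroadcastGlobalVariablesHook()]
  estimator.train(input_fn=input_fn, max_steps=max_steps, hooks=hooks)
```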
