29
29
import six
30
30
import tensorflow as tf
31
31
32
+ from kungfu import current_rank , current_cluster_size
33
+
32
34
flags = tf .flags
33
35
34
36
FLAGS = flags .FLAGS
@@ -1141,11 +1143,15 @@ def main(_):
1141
1143
FLAGS .tpu_name , zone = FLAGS .tpu_zone , project = FLAGS .gcp_project )
1142
1144
1143
1145
is_per_host = tf .contrib .tpu .InputPipelineConfig .PER_HOST_V2
1146
+
1147
+ # KungFu: Let only one estimator (rank 0) save checkpoints.
1148
+ save_checkpoints_steps = None if current_rank () != 0 else FLAGS .save_checkpoints_steps
1149
+
1144
1150
run_config = tf .contrib .tpu .RunConfig (
1145
1151
cluster = tpu_cluster_resolver ,
1146
1152
master = FLAGS .master ,
1147
1153
model_dir = FLAGS .output_dir ,
1148
- save_checkpoints_steps = FLAGS . save_checkpoints_steps ,
1154
+ save_checkpoints_steps = save_checkpoints_steps ,
1149
1155
tpu_config = tf .contrib .tpu .TPUConfig (
1150
1156
iterations_per_loop = FLAGS .iterations_per_loop ,
1151
1157
num_shards = FLAGS .num_tpu_cores ,
@@ -1166,6 +1172,10 @@ def main(_):
1166
1172
rng = random .Random (12345 )
1167
1173
rng .shuffle (train_examples )
1168
1174
1175
+ # KungFu: Adjust training steps based on parallelism
1176
+ num_train_steps = num_train_steps // current_cluster_size ()
1177
+ num_warmup_steps = num_warmup_steps // current_cluster_size ()
1178
+
1169
1179
model_fn = model_fn_builder (
1170
1180
bert_config = bert_config ,
1171
1181
init_checkpoint = FLAGS .init_checkpoint ,
@@ -1212,7 +1222,12 @@ def main(_):
1212
1222
seq_length = FLAGS .max_seq_length ,
1213
1223
is_training = True ,
1214
1224
drop_remainder = True )
1215
- estimator .train (input_fn = train_input_fn , max_steps = num_train_steps )
1225
+
1226
+ # KungFu: Let the first estimator broadcast global variables.
1227
+ from kungfu .tensorflow .v1 .initializer import BroadcastGlobalVariablesHook
1228
+ hooks = [BroadcastGlobalVariablesHook ()]
1229
+
1230
+ estimator .train (input_fn = train_input_fn , max_steps = num_train_steps , hooks = hooks )
1216
1231
1217
1232
if FLAGS .do_predict :
1218
1233
eval_examples = read_squad_examples (
0 commit comments