Using MultiWorkerMirroredStrategy instead of MirroredStrategy since it uses TF_CONFIG here, and changes to keep it close to the source version.
chamorajg committed Jun 21, 2019
1 parent 7116db2 commit b0c21af
Showing 1 changed file with 10 additions and 9 deletions.
adanet/core/estimator_distributed_test_runner.py: 10 additions & 9 deletions
@@ -307,16 +307,17 @@ def _model_fn(features, labels, mode):
     train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
     return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
 
-  def _input_fn():
-    features = tf.data.Dataset.from_tensors([[1.]]).repeat(100)
-    labels = tf.data.Dataset.from_tensors(1.).repeat(100)
-    return tf.data.Dataset.zip((features, labels))
-
   distribution = tf.distribute.experimental.MultiWorkerMirroredStrategy()
-  config = tf.estimator.RunConfig(train_distribute=distribution)
-  classifier = tf.estimator.Estimator(model_fn=_model_fn, config=config)
-  classifier.train(input_fn=_input_fn)
-  classifier.evaluate(input_fn=_input_fn)
+  config = tf.estimator.RunConfig(
+      tf_random_seed=42,
+      train_distribute=distribution,
+      model_dir=FLAGS.model_dir,
+      session_config=tf_compat.v1.ConfigProto(
+          log_device_placement=False,
+          # Ignore other workers; only talk to parameter servers.
+          # Otherwise, when a chief/worker terminates, the others will hang.
+          device_filters=["/job:ps"]))
+  estimator = tf.estimator.Estimator(model_fn=_model_fn, config=config)
 
   def input_fn():
     input_features = {"x": tf.constant(features, name="x")}
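
Note on the commit message: tf.distribute.experimental.MultiWorkerMirroredStrategy picks up its cluster definition from the TF_CONFIG environment variable rather than from constructor arguments, which is why the strategy swap matters here. A minimal sketch of the TF_CONFIG a worker process would export, assuming a hypothetical two-worker cluster (the host:port addresses are placeholders, not from this commit):

    import json
    import os

    # Hypothetical two-worker cluster; the addresses are placeholders.
    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": {
            "worker": ["localhost:2222", "localhost:2223"],
        },
        # This process is worker 0; the second worker would set "index": 1.
        "task": {"type": "worker", "index": 0},
    })

Each process in the cluster exports the same "cluster" dict but its own "task" entry; the strategy resolves its peers from this at construction time.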
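The device_filters=["/job:ps"] setting in the new RunConfig does what its inline comments say: each session only communicates with the parameter-server job, so a chief or worker that terminates cannot hang the remaining processes. For illustration, a hedged sketch of how the estimator built above would typically be driven; the TrainSpec/EvalSpec step counts are placeholder values, not part of this diff:

    # Hypothetical driver; estimator and input_fn come from the diff above.
    train_spec = tf.estimator.TrainSpec(input_fn=input_fn, max_steps=100)
    eval_spec = tf.estimator.EvalSpec(input_fn=input_fn, steps=None)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)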