Add TPU support to BERT example #207
Conversation
Looks good! Looking at the code more closely, I think we don't need the flag yet; it isn't adding anything.
examples/bert/bert_train.py
Outdated
@@ -386,6 +387,23 @@ def main(_):

    model_config = MODEL_CONFIGS[FLAGS.model_size]

    if FLAGS.use_tpu:
        if not tf.config.list_logical_devices("TPU"):
Looks like right now, at least, we don't need this flag. We could just do `if not tf.config.list_logical_devices("TPU")` and connect if we find any, right? I don't think there's any use case where we find a TPU but don't want a TPUStrategy.
As discussed, we may still need some sort of flags to support multi-worker training, but let's add them when we need them. For this PR we don't.
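For illustration, a minimal sketch of the flag-free logic being described (assuming TF 2.x; this is not the exact code in the PR):

```python
import tensorflow as tf

# If a TPU is attached, connect to it and build a TPUStrategy;
# otherwise fall back to the default (CPU/GPU) strategy.
if tf.config.list_logical_devices("TPU"):
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
else:
    strategy = tf.distribute.get_strategy()
```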
Ohh, there might be? Debugging on TPU is a bit subtle, so I usually check that the code runs on CPU before turning on the TPU testing flag. But yeah, this is a minor case; for debugging I can bypass the TPU with local changes. I'll let you make the call!
Yeah, let's remove the flag for now. GPU auto-connects and needs to be disabled manually, so I think it's reasonable for TPU to behave the same way.
If the main use case to cover is "I want to force running on CPU," we could consider adding a mechanism that works on both GPU and TPU machines. (Also, maybe one already exists? Is there a `CUDA_VISIBLE_DEVICES=-1` equivalent for TPU?)
Possibly `TPU_VISIBLE_CHIPS=-1` is an equivalent, though I can't really tell if that would work.
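For reference, the GPU-side mechanism works by hiding devices through an environment variable set before TensorFlow initializes; whether `TPU_VISIBLE_CHIPS` behaves the same way is not verified here. A minimal sketch:

```python
import os

# Hide all GPUs so TensorFlow falls back to CPU. This must be set
# before TensorFlow initializes its devices.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# TPU_VISIBLE_CHIPS is floated above as a possible TPU analogue,
# but it is not confirmed to have the same effect.
# os.environ["TPU_VISIBLE_CHIPS"] = "-1"

import tensorflow as tf

print(tf.config.list_physical_devices("GPU"))  # expected: []
```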
sg! changed.
lgtm. one nit
examples/bert/bert_train.py
Outdated
@@ -386,6 +385,18 @@ def main(_):

    model_config = MODEL_CONFIGS[FLAGS.model_size]

    if tf.config.list_logical_devices("TPU"):
        # Connect to TPU and create TPU strategy.
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
I think we can replace the next few lines with:
`resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(tpu='local')`
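A rough sketch of how the block would collapse with that suggestion (`connect()` wraps `experimental_connect_to_cluster` and `initialize_tpu_system` and returns the resolver); not the exact diff, just the idea:

```python
import tensorflow as tf

if tf.config.list_logical_devices("TPU"):
    # One call connects to the local TPU, initializes it, and returns a resolver.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(tpu="local")
    strategy = tf.distribute.TPUStrategy(resolver)
else:
    strategy = tf.distribute.get_strategy()
```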
done!