Support multinode training on GPU #2731
This is actually something that does work right now, but it's still experimental. There's also no real public-facing API for it yet; you have to type in some obscure and fairly magical things to set it all up correctly. We should polish it off and document it!
Can you say a bit more about your model, though? Would gradient all-reductions across multiple nodes suffice?
@hawkinsp Technically, I'm training a Reformer model using the Trax library.
And I assume you're just looking for data parallelism, i.e., partitioning a minibatch across GPUs, not partitioning in any other way (e.g., model parallelism)?
@hawkinsp yeah, my concern is data parallelism.
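For readers following along: the gradient all-reduction mentioned above can be sketched with `jax.pmap` on a single host (multi-node training builds on the same primitives). This is an illustrative toy example, not code from this thread; the model, the learning rate, and the `XLA_FLAGS` trick for simulating devices on CPU are all assumptions for the demo.

```python
import os
# Simulate 4 devices on CPU so the sketch runs anywhere (assumption for the demo;
# on a real multi-GPU node this line is unnecessary). Must be set before importing jax.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import functools
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Toy linear model: mean squared error of x @ w against targets y.
    return jnp.mean((x @ w - y) ** 2)

@functools.partial(jax.pmap, axis_name="data")
def train_step(w, x, y):
    g = jax.grad(loss)(w, x, y)
    # The gradient all-reduce: average gradients across all devices.
    g = jax.lax.pmean(g, axis_name="data")
    return w - 0.1 * g

n = jax.local_device_count()
w = jnp.zeros((3,))
ws = jnp.stack([w] * n)     # replicate parameters on every device
xs = jnp.ones((n, 8, 3))    # per-device shards of the minibatch
ys = jnp.ones((n, 8))

ws = train_step(ws, xs, ys)
# Because of pmean, every device ends up with identical updated parameters.
```

On multiple hosts the same SPMD program runs in every process; `pmean` then reduces across all devices globally, which is exactly the "gradient all-reductions across multiple nodes" asked about above.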
Data parallelism would be of value to other projects that use XLA as well (e.g., https://www.tensorflow.org/swift). Exposing this functionality in a standardized way would help drive progress in the broader ecosystem!
Hello py4, I am facing the same problem. Have you found a solution?
Hello hawkinsp, could you please provide more details about how to run data parallelism with multi-node GPUs?
@hawkinsp We are also interested in running JAX code on multiple nodes. Anything (hacky or not) that you can share would be appreciated. Thanks!
I really enjoyed Jax during my DM internship and wanted to use it on my university SLURM cluster, but the lack of a clear (official) multi-node data-parallel solution is a huge blocker to increasing Jax adoption outside of Google, where you can't just grab a TPU pod.
I would love this feature! I enjoy Jax, but I've been largely using DeepSpeed due to its ability to distribute across clusters.
Any progress on this issue? Using JAX to train a model on multi-node, multi-GPU is becoming a very important feature for us.
@sudhakarsingh27 I've been constantly monitoring the JAX releases, and there is something WIP that you might be interested in: #8364
See also: #9582
Yes indeed. We haven't advertised it that much yet, but (a) you need to initialize the cluster using that API, and (b) you need to follow the same rules of multi-host programming that also apply on TPU, documented here: https://jax.readthedocs.io/en/latest/multi_process.html. I suspect we can consider this issue closed when we've documented (a) in document (b).
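Concretely, the initialization step (a) might be sketched as below. The environment-variable names (`JAX_COORDINATOR`, etc.) and the single-process fallbacks are assumptions for illustration; in a real job, the launcher (SLURM, MPI, a shell wrapper, ...) would supply these values, with one process per node/GPU group.

```python
import os
import jax

# Initialize the multi-process runtime before any other JAX call.
# JAX_COORDINATOR / JAX_NUM_PROCESSES / JAX_PROCESS_ID are hypothetical
# variable names used only for this sketch; the fallbacks make it run
# standalone as a single-process "cluster".
jax.distributed.initialize(
    coordinator_address=os.environ.get("JAX_COORDINATOR", "127.0.0.1:12345"),
    num_processes=int(os.environ.get("JAX_NUM_PROCESSES", "1")),
    process_id=int(os.environ.get("JAX_PROCESS_ID", "0")),
)

# After initialization, jax.devices() spans all processes, while
# jax.local_devices() lists only this process's accelerators; each process
# then runs the same SPMD program on its local shard (step (b) above).
print(jax.process_index(), jax.process_count(), jax.device_count())
```

Process 0 must be reachable at `coordinator_address`; every other process connects to it at startup.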
@hawkinsp @zhangqiaorjc
Ran the attached minimal code on a single node with 8 V100 GPUs as follows (2 processes with 4 GPUs each):
I could check that the multi-process (host/node) run first fails with
I get the following error when I run the multi-process JAX commands above:
For reference, here's the output from
To run T5x on multi-node, multi-GPU, `jax.distributed.initialize` needs to be called with the appropriate setup, as mentioned here: jax-ml/jax#8364. Added a command line flag, `--multiprocess`, to enable multiprocess T5x runs on GPUs. Also added command line flags for the arguments to `jax.distributed.initialize`, namely `--coordinator_address`, `--num_processes`, and `--process_id`.

Example usage 1 (2 processes, running on 2 separate nodes, 8 GPUs each):

On the first node:

```
python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=i.p.ad.dr:port \
  --num_processes=2 \
  --process_id=0
```

On the second node:

```
python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=i.p.ad.dr:port \
  --num_processes=2 \
  --process_id=1
```

Notice that the `process_id` is different for the two processes. Also, substitute the appropriate coordinator address for `i.p.ad.dr:port`.

Example usage 2 (1 node, 2 processes, 4 GPUs each; the first process is backgrounded so both run concurrently):

```
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=127.0.0.1:12345 \
  --num_processes=2 \
  --process_id=0 &

CUDA_VISIBLE_DEVICES=4,5,6,7 python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=127.0.0.1:12345 \
  --num_processes=2 \
  --process_id=1
```

More information about multiprocess JAX runs: jax-ml/jax#2731

Note: The T5x partitioning fix google-research#608 complements this change.
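On a SLURM cluster (as asked about earlier in the thread), the per-process flags above don't have to be typed by hand; they can be derived from SLURM's environment. A hedged sketch follows, where the `train.py` path, the port, and the choice of `SLURM_LAUNCH_NODE_IPADDR` as the coordinator host are assumptions; the command is echoed as a dry run rather than executed:

```shell
# Derive multiprocess flags from SLURM environment variables (one process per task).
# Outside SLURM, the fallbacks reduce this to a single local process.
COORD_HOST="${SLURM_LAUNCH_NODE_IPADDR:-127.0.0.1}"   # node that launched the job
NUM_PROCS="${SLURM_NTASKS:-1}"                        # total number of processes
PROC_ID="${SLURM_PROCID:-0}"                          # this task's rank

CMD="python3 ${T5X_DIR:-.}/t5x/train.py \
  --multiprocess \
  --coordinator_address=${COORD_HOST}:12345 \
  --num_processes=${NUM_PROCS} \
  --process_id=${PROC_ID}"

echo "$CMD"   # dry run: print the command each rank would execute
```

Launching this script with `srun` would then give every rank its own `--process_id` automatically.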
Fixes google-research#410/google-research#89
I don't have a node with 8 GPUs. I have two nodes, each with 4 GPUs. So is it possible to train a model on multiple nodes?