Add init_system API for multi-host GPU. #8364
Conversation
Force-pushed from dd39818 to 8205350
jax/_src/dist_system.py
Outdated
client.connect()

factory = functools.partial(xla_client.make_gpu_client, client, node_id)
xla_bridge.register_backend_factory('gpu', factory, priority=300)
You should verify the GPU backend has not already been initialized.
xla_bridge.backends() already seems to initialize backends only once
Yes, but you still need to ensure that you issue an error if the backend has already been initialized.
Added a check; it will raise a RuntimeError if the GPU backend is already initialized.
Also updated xla_bridge to error out if a backend is already registered when we try to register a new one.
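As a rough illustration of the guard being discussed, here is a minimal sketch with illustrative names; it is not the actual code added to xla_bridge in this PR.

```python
# Hypothetical sketch only: `_backend_factories` and this standalone
# `register_backend_factory` are illustrative, not JAX's real internals.
_backend_factories = {}

def register_backend_factory(name, factory, *, priority=0):
  """Record a backend factory, refusing to overwrite an existing one."""
  if name in _backend_factories:
    raise RuntimeError(
        f"Backend factory for {name!r} is already registered; the "
        "distributed system must be initialized before JAX creates backends.")
  _backend_factories[name] = (factory, priority)
```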
Force-pushed from c30ec02 to a774a19
jax/_src/distributed.py
Outdated
def init_system(coordinator_address: str, num_processes: int, process_id: int):
  """Initialize distributed system for topology discovery.

  init_system is required to setup the runtime for multi-host GPU usage.
Give an example of how to use it and when/why.
Added example.
skipping doctest
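The kind of usage example under discussion would look roughly like the sketch below; it is a hedged illustration, not the exact example added to the docstring, and the coordinator address is a placeholder.

```python
# Illustrative only; the coordinator address and port are placeholders.
import jax

# Call on every process, before any other JAX operation.
jax.distributed.initialize(coordinator_address='10.0.0.1:1234',
                           num_processes=2,
                           process_id=0)  # process_id=1 on the other host

print(jax.device_count())        # total devices across all hosts
print(jax.local_device_count())  # devices attached to this host
```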
Force-pushed from 2cb595b to a0a1559
Force-pushed from f2dacf1 to c320d65
To run T5x on multi-node and multi-GPUs, `jax.distributed.initialize` needs to be called with appropriate setup as mentioned here: jax-ml/jax#8364.

Added a command line flag - `multiprocess` to enable multiprocess T5x run on GPUs. Also, added command line flags for the arguments to `jax.distributed.initialize`, namely - `coordinator_address`, `num_processes` and `process_id`.

Example usage 1 (2 processes, running on 2 separate nodes, 8 GPUs each):

On the first node:

```
python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=i.p.ad.dr:port \
  --num_processes=2 \
  --process_id=0
```

On the second node:

```
python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=i.p.ad.dr:port \
  --num_processes=2 \
  --process_id=1
```

Notice that the `process_id` is different for the two processes. Also, substitute the appropriate coordinator_address in `i.p.ad.dr:port`.

Example usage 2 (1 node, 2 processes, 4 GPUs each):

```
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=127.0.0.1:12345 \
  --num_processes=2 \
  --process_id=0 & \
&& CUDA_VISIBLE_DEVICES=4,5,6,7 python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=127.0.0.1:12345 \
  --num_processes=2 \
  --process_id=1
```

More information about multiprocess JAX runs: jax-ml/jax#2731

Note: T5x partitioning fix: google-research#608 complements this change.

Fixes google-research#410/google-research#89
Currently JAX doesn't expose our multi-host GPU backend to our open source users.
This PR exposes an experimental API, `jax.distributed.initialize`, to initialize the multi-host GPU backend. I tested it on 2 GPU VMs on GCP:
== VM0
== VM1
For an example of this API in action, see code that does model parallelism on 2 GPU VMs: https://gist.github.com/zhangqiaorjc/0ae6e7114fb0b3e9243e6420e4d6f3e4
See screenshot for results: https://photos.google.com/share/AF1QipMfIpFOpmckl86lU4WS4nb2IzMDkrOqLyafa4C3Vx7zMqoyy6NOM8PiS8gH7zaLIw?key=bjVRWHZoRmFUTkhhLVBOdzFlYWg4bG5nZ3NJYVpB
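For orientation, here is a minimal sketch of the kind of multi-host run this enables; it is not the code from the gist above, and the coordinator address is a placeholder.

```python
# Run one copy of this per VM, changing process_id; the address is a placeholder.
import jax
import jax.numpy as jnp

jax.distributed.initialize(coordinator_address='10.128.0.2:1234',
                           num_processes=2,
                           process_id=0)  # use process_id=1 on the second VM

# Each process drives its local GPUs; collectives span all hosts.
x = jnp.ones((jax.local_device_count(),))
total = jax.pmap(lambda v: jax.lax.psum(v, 'i'), axis_name='i')(x)
print(total)  # every entry equals the global device count
```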