Add init_system API for multi-host GPU. #8364

Merged · 1 commit into jax-ml:main on Oct 28, 2021

Conversation

zhangqiaorjc (Collaborator) commented on Oct 25, 2021:

Currently, JAX doesn't expose our multi-host GPU backend to our open-source users.

This PR exposes an experimental API, `jax.distributed.initialize`, to initialize the multi-host GPU backend.
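For reference, a minimal sketch of how a user script would call the new API (the address and sizes below are placeholders; run one copy of the script per host):

```python
# Placeholder two-host usage of the new API; run one copy per host,
# changing process_id (0 on the first host, 1 on the second).
import jax

jax.distributed.initialize(
    coordinator_address="10.128.0.47:1456",  # process 0 hosts the coordinator here
    num_processes=2,                          # total number of participating hosts
    process_id=0,                             # this host's index
)

# After initialization, jax.devices() spans the GPUs on all hosts, while
# jax.local_devices() lists only this host's GPUs.
print(jax.process_index(), jax.local_device_count(), jax.device_count())
```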

I tested it on 2 GPU VMs on GCP:

== VM0

```
$ TF_CPP_MIN_LOG_LEVEL=0 python -c "import jax; jax.distributed.initialize('10.128.0.47:1456', 2, 0)"
2021-10-26 01:19:17.133471: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/service.cc:369] Jax service listening on 10.128.0.47:1456
2021-10-26 01:19:28.339404: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/client.cc:129] Connected to distributed JAX controller
2021-10-26 01:19:28.375758: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/client.cc:163] Waiting for all distributed JAX tasks to shut down.
2021-10-26 01:19:28.376179: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/client.cc:180] Distributed task shutdown result: OK
2021-10-26 01:19:28.397722: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/service.cc:381] Jax service shutting down
```

== VM1

```
$ TF_CPP_MIN_LOG_LEVEL=0 python -c "import jax; jax.distributed.initialize('10.128.0.47:1456', 2, 1)"
2021-10-26 01:19:28.339474: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/client.cc:129] Connected to distributed JAX controller
2021-10-26 01:19:28.371461: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/client.cc:163] Waiting for all distributed JAX tasks to shut down.
2021-10-26 01:19:28.376473: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/client.cc:180] Distributed task shutdown result: OK
```

For an example of this API in action, see the following gist, which does model parallelism across 2 GPU VMs:

https://gist.github.com/zhangqiaorjc/0ae6e7114fb0b3e9243e6420e4d6f3e4

See this screenshot for the results:

https://photos.google.com/share/AF1QipMfIpFOpmckl86lU4WS4nb2IzMDkrOqLyafa4C3Vx7zMqoyy6NOM8PiS8gH7zaLIw?key=bjVRWHZoRmFUTkhhLVBOdzFlYWg4bG5nZ3NJYVpB
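The gist is the authoritative example; as a rough illustration of the same idea (not the gist's code), a multi-host reduction with `pmap` might look like:

```python
# Illustrative only, not the gist's code: run one copy per VM, with
# process_id=0 on VM0 and process_id=1 on VM1.
import jax
import jax.numpy as jnp

jax.distributed.initialize("10.128.0.47:1456", 2, 0)  # process_id differs per VM

# Each process feeds only its local devices; the psum over the mapped axis
# reduces across every device on every host.
local_inputs = jnp.ones((jax.local_device_count(),))
global_sum = jax.pmap(lambda x: jax.lax.psum(x, axis_name="i"), axis_name="i")
print(global_sum(local_inputs))  # each entry equals jax.device_count()
```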

Resolved review threads on jax/_src/dist_system.py and jax/dist_system.py (outdated). One thread was attached to this backend-registration code:

```python
client.connect()

factory = functools.partial(xla_client.make_gpu_client, client, node_id)
xla_bridge.register_backend_factory('gpu', factory, priority=300)
```
Collaborator:

You should verify the GPU backend has not already been initialized.

zhangqiaorjc (Collaborator, Author):

xla_bridge.backends() seems to initialize backends only once already.

Collaborator:

Yes, but you still need to ensure that you issue an error if the backend has already been initialized.

zhangqiaorjc (Collaborator, Author):

Added a check; it will now raise a RuntimeError if the GPU backend is already initialized.

zhangqiaorjc (Collaborator, Author):

Also updated xla_bridge to error out if a backend is already registered when we try to register a new one.
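For context, a minimal sketch of the kind of guard being discussed (the names and the module-level flag are illustrative; the real check lives in jax's xla_bridge and distributed modules):

```python
# Hypothetical sketch of the "already initialized" guard; illustrative only.
_gpu_backend_registered = False  # stand-in for xla_bridge's internal backend state

def register_distributed_gpu_backend(factory):
    global _gpu_backend_registered
    if _gpu_backend_registered:
        # Error out instead of silently re-registering, per the review comments.
        raise RuntimeError(
            "GPU backend is already initialized; call jax.distributed.initialize "
            "before running any JAX computations.")
    _gpu_backend_registered = True
    # The actual registration call from the diff above:
    # xla_bridge.register_backend_factory('gpu', factory, priority=300)
```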

@zhangqiaorjc force-pushed the dsys branch 3 times, most recently from c30ec02 to a774a19 on October 26, 2021 at 20:05.
Resolved review threads on jax/_src/distributed.py and CHANGELOG.md (outdated). Another thread was attached to the new API's docstring:

```python
def init_system(coordinator_address: str, num_processes: int, process_id: int):
  """Initialize distributed system for topology discovery.

  init_system is required to set up the runtime for multi-host GPU usage.
```
Collaborator:

Give an example of how to use it and when/why.

zhangqiaorjc (Collaborator, Author):

Added example.

zhangqiaorjc (Collaborator, Author):

Skipping doctest.
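(Since the example requires multiple cooperating hosts, it cannot run under doctest; a hypothetical rendering of the skipped snippet might look like the following.)

```python
>>> import jax                                          # doctest: +SKIP
>>> # Needs a second process running the same call with process_id=1.
>>> jax.distributed.initialize('10.0.0.1:1234', 2, 0)   # doctest: +SKIP
```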

@zhangqiaorjc force-pushed the dsys branch 2 times, most recently from 2cb595b to a0a1559 on October 26, 2021 at 20:57.
@zhangqiaorjc self-assigned this on Oct 26, 2021.
@zhangqiaorjc force-pushed the dsys branch 3 times, most recently from f2dacf1 to c320d65 on October 26, 2021 at 21:19.
@copybara-service bot merged commit 934bfc0 into jax-ml:main on Oct 28, 2021.
sudhakarsingh27 added a commit to sudhakarsingh27/t5x that referenced this pull request on Jun 23, 2022:
To run T5x on multiple nodes and multiple GPUs, `jax.distributed.initialize`
needs to be called with the appropriate setup, as described in
jax-ml/jax#8364.
Added a command-line flag, `multiprocess`, to enable multiprocess T5x runs
on GPUs. Also added command-line flags for the arguments to
`jax.distributed.initialize`, namely `coordinator_address`,
`num_processes`, and `process_id`.
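
A hypothetical sketch of the flag wiring (not the actual t5x patch; the flag names follow the ones listed above):

```python
# Hypothetical flag wiring for multiprocess runs; not the actual t5x code.
from absl import flags
import jax

flags.DEFINE_bool("multiprocess", False, "Enable multi-process JAX on GPUs.")
flags.DEFINE_string("coordinator_address", None, "IP:port of process 0.")
flags.DEFINE_integer("num_processes", 1, "Total number of processes.")
flags.DEFINE_integer("process_id", 0, "Index of this process.")
FLAGS = flags.FLAGS

def maybe_initialize_distributed():
    # Only touch the distributed runtime when --multiprocess is set.
    if FLAGS.multiprocess:
        jax.distributed.initialize(FLAGS.coordinator_address,
                                   FLAGS.num_processes,
                                   FLAGS.process_id)
```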

Example usage 1 (2 processes, running on 2 separate nodes, 8 GPUs each):
On the first node:
```
python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=i.p.ad.dr:port \
  --num_processes=2 \
  --process_id=0
```

On the second node:
```
python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=i.p.ad.dr:port \
  --num_processes=2 \
  --process_id=1
```
Notice that `process_id` differs between the two processes. Also,
substitute the appropriate `coordinator_address` for `i.p.ad.dr:port`.

Example usage 2 (1 node, 2 processes, 4 GPUs each):
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=127.0.0.1:12345 \
  --num_processes=2 \
  --process_id=0 &
CUDA_VISIBLE_DEVICES=4,5,6,7 python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=127.0.0.1:12345 \
  --num_processes=2 \
  --process_id=1
```

More information about multiprocess JAX runs:
jax-ml/jax#2731

Note: the T5x partitioning fix google-research#608 complements this change.

Fixes google-research#410 and google-research#89.
Labels: cla: yes, pull ready

3 participants