`num_partitions` does not work for GPU #410
Comments
Thanks for reporting this -- it looks like a GPU-specific bug in T5X. Can you try setting
@jekbradbury as FYI. Also, I think your activation/parameter partitioning dims options will be no-ops since you're not using model parallelism.
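For context, here is a minimal sketch of what that configuration boils down to, assuming the `t5x.partitioning` API (`PjitPartitioner` and `standard_logical_axis_rules`); the exact gin bindings in a given run may differ. With `num_partitions=1` there is no model axis to shard over, so the activation/parameter partitioning dims have nothing to act on:

```python
from t5x import partitioning

# Sketch (assumed API): a partitioner configured for data-only parallelism.
# With num_partitions=1 the "model" mesh axis has size 1, so the
# activation/parameter partitioning dims below are effectively no-ops.
partitioner = partitioning.PjitPartitioner(
    num_partitions=1,
    logical_axis_rules=partitioning.standard_logical_axis_rules(
        activation_partitioning_dims=1,
        parameter_partitioning_dims=1,
    ),
)
```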
Thanks for your reply!
However, I still get the error:
I am also very confused about the partitioning rules in partitioning.py. I've noticed that the
Thanks for your help!
I think I am getting a related error when trying to fine-tune the longT5 model on CPU: `ValueError: Failed to map logical axes for target/decoder/logits_dense/kernel`
TL;DR: @adarob, for data+model parallelism on GPUs, is

More context: I'm working with this config: t5_1_1/base.gin, and I faced the same error as the OP when I used the default partitioning config (my intention is to run data+model parallelism). I followed @adarob's suggestion but couldn't get it running.
After playing around, I found that the following partitioning rule:
seems to produce the following mesh of devices:
Also, as @Namco0816 pointed out, when using
which is of dim
Currently, partitioning on GPUs is broken, since neither `num_partitions` nor `model_parallel_submesh` creates an appropriate mesh for GPUs. Out of the box, the default GPU mesh created is `[1, #num_gpus]`, which is incorrect if "data-only parallelism" mode is to be selected. The correct way would be to create a mesh with `[#num_gpus, 1]` dimensions for "data-only parallelism"; more generally, `[jax.local_device_count() // num_partitions, num_partitions]` should be the mesh dimensions for data+model parallelism. Even more generally, the mesh should also account for multiple processes (which might be running on separate nodes). For more info, see the `create_hybrid_device_mesh` method in https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py. This PR fixes google-research#410
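As a rough illustration of the mesh shapes described above, here is a small JAX sketch (not T5X code; `num_partitions = 1` is just an example value) that builds a `[#num_gpus // num_partitions, num_partitions]` mesh with `jax.experimental.mesh_utils.create_device_mesh`:

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

num_partitions = 1  # example value: data-only parallelism

# The data axis gets local_device_count // num_partitions devices and the model
# axis gets num_partitions devices. On an 8-GPU node with num_partitions=1 this
# yields an (8, 1) mesh rather than the broken (1, 8) default.
devices = mesh_utils.create_device_mesh(
    (jax.local_device_count() // num_partitions, num_partitions))
mesh = Mesh(devices, axis_names=("data", "model"))
print(mesh.shape)  # e.g. {'data': 8, 'model': 1}
```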
@adarob, some of us are actually trying to pretrain and finetune locally on GPUs, as I fear a reduced batch size could affect generalization. Is there DeepSpeed ZeRO integration in T5X?
To run T5X on multi-node, multi-GPU setups, `jax.distributed.initialize` needs to be called with the appropriate setup, as mentioned here: jax-ml/jax#8364. Added a command-line flag, `multiprocess`, to enable multiprocess T5X runs on GPUs. Also added command-line flags for the arguments to `jax.distributed.initialize`, namely `coordinator_address`, `num_processes`, and `process_id`.

Example usage 1 (2 processes, running on 2 separate nodes, 8 GPUs each):

On the first node:

```
python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=i.p.ad.dr:port \
  --num_processes=2 \
  --process_id=0
```

On the second node:

```
python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=i.p.ad.dr:port \
  --num_processes=2 \
  --process_id=1
```

Notice that the `process_id` is different for the two processes. Also, substitute the appropriate coordinator address for `i.p.ad.dr:port`.

Example usage 2 (1 node, 2 processes, 4 GPUs each):

```
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=127.0.0.1:12345 \
  --num_processes=2 \
  --process_id=0 &

CUDA_VISIBLE_DEVICES=4,5,6,7 python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=127.0.0.1:12345 \
  --num_processes=2 \
  --process_id=1
```

More information about multiprocess JAX runs: jax-ml/jax#2731

Note: the T5X partitioning fix google-research#608 complements this change.

Fixes google-research#410, google-research#89
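For reference, the flags above roughly correspond to the following call, made in each process before any other JAX operation. This is a sketch using the public `jax.distributed.initialize` API; the flag-to-argument mapping is as described in the PR text:

```python
import jax

# Same coordinator_address and num_processes on every process; process_id is
# unique per process (0 on the first, 1 on the second).
jax.distributed.initialize(
    coordinator_address="127.0.0.1:12345",
    num_processes=2,
    process_id=0,
)

# After initialization, device_count() reports all GPUs across processes,
# while local_device_count() reports only the GPUs visible to this process.
print(jax.process_index(), jax.device_count(), jax.local_device_count())
```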
Fixed by #643
Hi, thanks for the great work.
I have already carefully read the docs on partitioning, but I am still confused about how it works and what the partitioning rules mean.
I tried to run the pretraining code on a single node with 8 A100 GPUs. When I pretrain T5 with the Hugging Face trainer and DeepSpeed ZeRO-2, it works well. However, when I tried to run the pretraining with the scripts provided in the examples with
,
I get the following errors:
Could you please help me fix this error?