num_partitions does not work for GPU #410

Closed
Namco0816 opened this issue Apr 4, 2022 · 8 comments
Namco0816 commented Apr 4, 2022

Hi thanks for the great work.
I have already read the partitioning docs carefully, but I am still confused about how it works and what the partitioning rules mean.
I tried to run the pretraining code on a single node with 8 A100 GPUs. When I pretrain T5 with the Hugging Face trainer and DeepSpeed ZeRO-2, it works well. However, when I run pretraining with the scripts provided in the examples, using

partitioning.PjitPartitioner:
  num_partitions = 1
  logical_axis_rules= @partitioning.standard_logical_axis_rules()

partitioning.standard_logical_axis_rules:
  activation_partitioning_dims = 2
  parameter_partitioning_dims = 2

I get the following error:

Traceback (most recent call last):
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/cache/namco/t5x/t5x/train.py", line 659, in <module>
    gin_utils.run(main)
  File "/mnt/cache/namco/t5x/t5x/gin_utils.py", line 105, in run
    app.run(
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/mnt/cache/namco/t5x/t5x/train.py", line 637, in main
    _main(argv)
  File "/mnt/cache/namco/t5x/t5x/train.py", line 657, in _main
    train_using_gin()
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/mnt/cache/namco/t5x/t5x/train.py", line 321, in train
    train_state = train_state_initializer.from_checkpoint_or_scratch(
  File "/mnt/cache/namco/t5x/t5x/utils.py", line 523, in from_checkpoint_or_scratch
    or self.from_scratch(init_rng))
  File "/mnt/cache/namco/t5x/t5x/utils.py", line 395, in from_scratch
    return p_initialize_train_state_fn(init_rng)
  File "/mnt/cache/namco/t5x/t5x/partitioning.py", line 729, in __call__
    return self._pjitted_fn(*args)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/experimental/pjit.py", line 267, in wrapped
    args_flat, _, params, _, out_tree, _ = infer_params(*args, **kwargs)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/experimental/pjit.py", line 246, in infer_params
    jaxpr, canonicalized_out_axis_resources_flat = _pjit_jaxpr(
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/linear_util.py", line 272, in memoized_fun
    ans = call(fun, *args)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/experimental/pjit.py", line 411, in _pjit_jaxpr
    _check_shapes_against_resources("pjit outputs", mesh.is_multi_process, mesh.shape,
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/experimental/pjit.py", line 588, in _check_shapes_against_resources
    raise ValueError(f"One of {what} was given the resource assignment "
ValueError: One of pjit outputs was given the resource assignment of PartitionSpec('model', None), which implies that the size of its dimension 0 should be divisible by 8, but it is equal to 12
  In call to configurable 'train' (<function train at 0x7f598523c790>)

Could you please help me to fix this error?

Collaborator

adarob commented Apr 4, 2022

Thanks for reporting this -- it looks like a GPU-specific bug in T5X. Can you try setting

partitioning.PjitPartitioner:
  model_parallel_submesh = (1, 1)
  logical_axis_rules= @partitioning.standard_logical_axis_rules()

Collaborator

adarob commented Apr 4, 2022

@jekbradbury as FYI

Also, I think your activation/parameter partitioning dims options will be no-ops since you're not using model parallelism.

adarob changed the title from "Issues about partitioning" to "num_partitions does not work for GPU" on Apr 4, 2022
@Namco0816
Author

Thanks for your reply!
I've modified the code with:
partitioning.PjitPartitioner:
  model_parallel_submesh = (1, 1)
  logical_axis_rules= @partitioning.standard_logical_axis_rules()

However, I still get the following error:

Traceback (most recent call last):
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/cache/namco/t5x/t5x/train.py", line 659, in <module>
    gin_utils.run(main)
  File "/mnt/cache/namco/t5x/t5x/gin_utils.py", line 105, in run
    app.run(
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/mnt/cache/namco/t5x/t5x/train.py", line 637, in main
    _main(argv)
  File "/mnt/cache/namco/t5x/t5x/train.py", line 657, in _main
    train_using_gin()
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/mnt/cache/namco/t5x/t5x/train.py", line 507, in train
    trainer.compile_train(first_batch)
  File "/mnt/cache/namco/t5x/t5x/trainer.py", line 549, in compile_train
    self._compiled_train_step = self._partitioner.compile(
  File "/mnt/cache/namco/t5x/t5x/partitioning.py", line 779, in compile
    return partitioned_fn.lower(*args).compile()
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/stages.py", line 174, in compile
    self._lowering.compile(), self.in_tree, self.in_avals,
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 2280, in compile
    self._executable = MeshExecutable.from_hlo(
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 2371, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(backend, computation, compile_options)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/dispatch.py", line 583, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/dispatch.py", line 537, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: UNIMPLEMENTED: Requested AllReduce not implemented on GPU; replica_count: 1; partition_count: 8, group_mode: kCrossReplicaAndPartition, operand_count: 26; NCCL support: 1; first operand array element-type: BF16
  In call to configurable 'train' (<function train at 0x7f3d01a81240>)
Rewritten gin arg: --gin_bindings=MODEL_DIR = "/mnt/lustre/namco/jax-model/t5-base"

I am also confused about the partition rules in partitioning.py. I've noticed that get_gpu_mesh returns the mesh for GPUs; on my 8-A100 machine it returns a (1, 8) mesh, where the first dimension (1) is the number of hosts and the second dimension (8) is the number of GPUs. If I understand correctly, the partition rules assign 'data' to the first axis and 'model' to the second axis. However, plain DDP shards the input data across all 8 GPUs, which would imply a partition rule of ("batch", "model"), yet the code provided in the partitioning example defines data parallelism as ("batch", "data"). I am really confused about this part.
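
A minimal sketch of the mesh shape described above, assuming a single host with 8 GPUs and using the `jax.sharding.Mesh` API (the reshape and axis names here are illustrative, not the exact T5X code path):

```
import numpy as np
import jax
from jax.sharding import Mesh

# Arrange the local devices the way the (1, 8) GPU mesh described above
# looks: one row for the single host, one column per local GPU.
devices = np.array(jax.devices()).reshape(1, jax.local_device_count())

# With the first axis named 'data' and the second 'model', all 8 GPUs
# land on the 'model' axis, so the batch cannot be sharded across them.
mesh = Mesh(devices, axis_names=("data", "model"))
print(mesh.shape)  # {'data': 1, 'model': 8} on an 8-GPU host
```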

Thanks for your help!


ibulu commented Apr 29, 2022

I think I am getting a related error when trying to fine-tune a longT5 model on CPU:

ValueError: Failed to map logical axes for target/decoder/logits_dense/kernel
In call to configurable 'train' (<function train at 0x17a77b940>)


sudhakarsingh27 commented May 6, 2022

TLDR;

@adarob For data+model parallelism on GPUs, is model_parallel_submesh=(1,1,1,<#GPU for model parallelism>) the way to go for a single-node, multi-GPU case (this also seems to be suggested by this line in the code)?
Thanks!


More context:

I'm working with this config: t5_1_1/base.gin, and I hit the same error as the OP when I used the default partitioning config. (My intention is to run data+model parallelism.)

I followed @adarob's suggestion but couldn't get it running.

Thanks for reporting this -- it looks like a GPU-specific bug in T5X. Can you try setting

partitioning.PjitPartitioner:
  model_parallel_submesh = (1, 1)
  logical_axis_rules= @partitioning.standard_logical_axis_rules()

After playing around, I found the following partitioning rule:

partitioning.PjitPartitioner:
  model_parallel_submesh = (1, 1, 1, 2)
  logical_axis_rules= @partitioning.standard_logical_axis_rules()

seems to produce the following [4, 2] device mesh (mapped to ('data', 'model')), which is what I wanted:

[[GpuDevice(id=0, process_index=0) GpuDevice(id=1, process_index=0)]
 [GpuDevice(id=2, process_index=0) GpuDevice(id=3, process_index=0)]
 [GpuDevice(id=4, process_index=0) GpuDevice(id=5, process_index=0)]
 [GpuDevice(id=6, process_index=0) GpuDevice(id=7, process_index=0)]]
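
For reference, a stand-alone way to build the same [4, 2] layout outside of T5X (illustrative only; `mesh_utils.create_device_mesh` is a JAX helper, not the T5X code path):

```
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Build a 4 x 2 grid over the 8 local GPUs: 4-way data parallelism
# and 2-way model parallelism, matching the device layout printed above.
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=("data", "model"))
print(mesh.shape)  # {'data': 4, 'model': 2}
```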

@sudhakarsingh27

Also, as @Namco0816 pointed out, when using data parallelism only, the GPU mesh returned is

[[GpuDevice(id=0, process_index=0) GpuDevice(id=1, process_index=0)
  GpuDevice(id=2, process_index=0) GpuDevice(id=3, process_index=0)
  GpuDevice(id=4, process_index=0) GpuDevice(id=5, process_index=0)
  GpuDevice(id=6, process_index=0) GpuDevice(id=7, process_index=0)]]

which has shape [1, 8] and maps to the ('data', 'model') axes, so there is effectively no data parallelism even when data-only parallelism is selected. I think this function is completely agnostic of the data/model parallel axes/partitions, which is why we see the issue.
Can someone confirm this? @adarob @jekbradbury

sudhakarsingh27 added a commit to sudhakarsingh27/t5x that referenced this issue Jun 16, 2022
Currently, partitioning on GPUs is broken since neither `num_partitions` nor
`model_parallel_submesh` creates an appropriate mesh for GPUs.
Out of the box, the default GPU mesh created is `[1, #num_gpus]`, which is
incorrect if "data-only parallelism" mode is to be selected.

The correct way would be to create a mesh with `[#num_gpus, 1]` dimensions for
"data-only parallelism"; more generally,
`[jax.local_device_count() // num_partitions, num_partitions]`
should be the mesh dimensions for data+model parallelism. Even more
generally, the mesh should also account for multiple processes (which might
be running on separate nodes).

For more info, see the `create_hybrid_device_mesh` method in
https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py

This PR fixes google-research#410
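
A rough sketch of the mesh construction described in the commit message above (a hypothetical helper under those assumptions, not the actual code in the referenced PR):

```
import numpy as np
import jax

def gpu_data_model_mesh(num_partitions: int) -> np.ndarray:
    """Split this process's GPUs into a [data, model] grid of shape
    [local_device_count // num_partitions, num_partitions]."""
    local = jax.local_device_count()
    assert local % num_partitions == 0, "num_partitions must divide the GPU count"
    return np.array(jax.local_devices()).reshape(
        local // num_partitions, num_partitions)

# With 8 GPUs and num_partitions=1 this yields an (8, 1) mesh
# (pure data parallelism) instead of the broken (1, 8) default.
```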

StephennFernandes commented Jun 16, 2022

@adarob, some of us are actually trying to pretrain and fine-tune locally on GPUs, since I fear a reduced batch size could affect generalization. Is there DeepSpeed ZeRO integration in T5X?

sudhakarsingh27 added a commit to sudhakarsingh27/t5x that referenced this issue Jun 23, 2022
To run T5X on multiple nodes and multiple GPUs, `jax.distributed.initialize`
needs to be called with the appropriate setup, as mentioned here:
jax-ml/jax#8364.
Added a command-line flag, `--multiprocess`, to enable multiprocess T5X runs
on GPUs. Also added command-line flags for the arguments to
`jax.distributed.initialize`, namely `--coordinator_address`,
`--num_processes`, and `--process_id`.

Example usage 1 (2 processes, running on 2 separate nodes, 8 GPUs each):
On the first node:
```
python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=i.p.ad.dr:port \
  --num_processes=2 \
  --process_id=0
```

On the second node:
```
python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=i.p.ad.dr:port \
  --num_processes=2 \
  --process_id=1
```
Notice that `process_id` is different for the two processes. Also,
substitute the appropriate coordinator address for `i.p.ad.dr:port`.

Example usage 2 (1 node, 2 processes, 4 GPUs each):
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=127.0.0.1:12345 \
  --num_processes=2 \
  --process_id=0 &
CUDA_VISIBLE_DEVICES=4,5,6,7 python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=127.0.0.1:12345 \
  --num_processes=2 \
  --process_id=1
```

More information about multiprocess JAX runs:
jax-ml/jax#2731

Note: the T5X partitioning fix google-research#608 complements this change.

Fixes google-research#410/google-research#89
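
For context, a minimal sketch of the initialization call that the new flags presumably wrap (the flag-to-argument mapping is assumed from the description above):

```
import jax

# One call per process, before any other JAX operations.
jax.distributed.initialize(
    coordinator_address="127.0.0.1:12345",  # --coordinator_address
    num_processes=2,                        # --num_processes
    process_id=0,                           # --process_id (unique per process)
)

# Afterwards, jax.devices() reports GPUs from every process, while
# jax.local_devices() still lists only this process's GPUs.
print(jax.process_index(), jax.device_count(), jax.local_device_count())
```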
Collaborator

adarob commented Jul 6, 2022

Fixed by #643

adarob closed this as completed on Jul 6, 2022