
Excessive GPU-GPU communication with GPT2 making multi-GPU training slow? #9371

Closed · 2 of 4 tasks · moyix opened this issue Dec 31, 2020 · 27 comments
Labels: Benchmarks, Performance, wontfix

@moyix

moyix commented Dec 31, 2020

Summary: on a multi-GPU system, training GPT2 seems to scale poorly unless a very fast GPU-GPU interconnect like NVLink is available. In particular, without NVLink using two GPUs is slower than using just one GPU.

Environment info

  • transformers version: 4.1.1
  • Platform: Linux-5.8.0-rc7-custom-x86_64-with-glibc2.29
  • Python version: 3.8.5
  • PyTorch version (GPU?): 1.8.0.dev20201214+cu110 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No?
  • Hardware: 2 x NVIDIA RTX 3090 w/NVLink

Who can help

Maybe @LysandreJik or @patrickvonplaten ?

Information

Model I am using (Bert, XLNet ...): GPT2

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The script is a pretty basic example of training a medium-size GPT2 from scratch; it is available here: https://panda.moyix.net/~moyix/train_csrc.py

The dataset and tokenized vocab:

The task I am working on is:

  • an official GLUE/SQUaD task: (give the name)
  • my own task or dataset: (give details below)

Training a GPT2 language model on C source code.

To reproduce

Run with only one GPU: CUDA_VISIBLE_DEVICES=0 python train_csrc.py

Run with two GPUs, NVLink disabled: NCCL_P2P_DISABLE=1 python train_csrc.py

Run with two GPUs and NVLink enabled: python train_csrc.py

Here is some benchmarking I did with my dataset on transformers 3.3.1 and 4.1.1 (note the difference in ETA is just because 3.3.1 only seems to report the ETA for the current epoch):

| Version | NVLINK | GPUs | ETA | Perf |
|---------|--------|------|-----|------|
| 4.1.1 | Yes | 2GPU | 419:52:28 | 1.94it/s |
| 4.1.1 | No | 2GPU | 1025:06:27 | 1.26s/it |
| 4.1.1 | N/A | 1GPU | 599:14:57 | 2.72it/s |
| 3.3.1 | Yes | 2GPU | 83:46:51 | 1.94it/s |
| 3.3.1 | No | 2GPU | 204:54:22 | 1.26s/it |
| 3.3.1 | N/A | 1GPU | 119:02:34 | 2.73it/s |

You can see that using two GPUs is actually slower than using a single GPU, unless NVLink is available (599 hours for 1 GPU vs 1025 hours for two GPUs). So presumably there is a large amount of GPU-GPU communication going on?

Expected behavior

Scaling should be roughly linear with the number of GPUs. Unfortunately I am not very familiar with the implementation details of GPT2 in Huggingface, but others report roughly linear scaling with Transformer models like BERT so it should work here as well: https://towardsdatascience.com/training-bert-at-a-university-eedcf940c754

Although I have a system with NVLink at home, this issue is still affecting me because I would like to be able to run this on the university HPC cluster, where most nodes do not have NVLink.

@julien-c
Member

julien-c commented Jan 1, 2021

Not an answer to your issue/question, but have you tried running in distributed training (DDP), which is the recommended way of running over multiple GPUs: https://github.com/huggingface/transformers/tree/master/examples#distributed-training-and-mixed-precision

Would be curious to see the same with/without NVLink experiment there.

@moyix
Author

moyix commented Jan 1, 2021

Hmm, I don't have much experience using torch.distributed. I tried just running the existing script with python -m torch.distributed.launch --nproc_per_node 2 train.py, but it runs out of GPU memory almost immediately, so I assume I'm doing something wrong.

If you have a link to some documentation that explains how to set up the training script so that it can be used with torch.distributed, I can give that a try.

@julien-c
Member

julien-c commented Jan 2, 2021

The command you posted "should" work.

@sgugger might have links to better content when he's back, but the PyTorch tutorials are pretty good: https://pytorch.org/tutorials/beginner/dist_overview.html#data-parallel-training

Your initial experiment is using DataParallel (not DistributedDataParallel) under the hood.
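
(For readers unfamiliar with the distinction, here is a minimal, hypothetical sketch of the two modes - not code from this thread; the model, sizes, and launcher details are placeholders.)

```python
# Sketch of the two multi-GPU modes discussed here (toy model, 2 visible GPUs assumed).
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def data_parallel_model():
    # DataParallel: a single process scatters each batch across the GPUs and
    # gathers outputs/gradients back on GPU 0 every step. This is what Trainer
    # falls back to when several GPUs are visible in one process.
    return nn.DataParallel(nn.Linear(1024, 1024).cuda())

def distributed_data_parallel_model():
    # DistributedDataParallel: one process per GPU, started e.g. with
    # `python -m torch.distributed.launch --nproc_per_node 2 train.py`;
    # NCCL all-reduces the gradients across processes after each backward pass.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))  # set by torchrun / --use_env;
                                                       # older launchers pass --local_rank as an argument
    dist.init_process_group(backend="nccl")            # reads MASTER_ADDR/PORT etc. set by the launcher
    torch.cuda.set_device(local_rank)
    return DDP(nn.Linear(1024, 1024).cuda(), device_ids=[local_rank])
```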

@moyix
Author

moyix commented Jan 12, 2021

OK, I got around to spending some more time with this today. I realized that the run_language_modeling.py script can do everything my script was doing, and it uses DDP by default (Note: looking at the most recent version on git, I see that run_language_modeling.py has been replaced by run_clm.py. However, after trying to upgrade transformers to that version, it seems to no longer use the GPU for reasons I don't have time to debug.).

So now I'm just using that, with:

python -m torch.distributed.launch --nproc_per_node 2 \
    ~/git/transformers/examples/language-modeling/run_language_modeling.py \
    --model_type gpt2 \
    --config_name ./csrc_config \
    --tokenizer_name ./csrc_tokenizer \
    --fp16 --fp16_opt_level O3 \
    --do_train --output_dir csrc_output \
    --per_device_train_batch_size 4 \
    --train_data_file plainsrc_all.txt --block_size 128

For a single GPU I drop torch.distributed.launch and use CUDA_VISIBLE_DEVICES=1; to disable NVLink I use NCCL_P2P_DISABLE=1 as before. The --block_size 128 argument matches the default from my training script (without it I run out of GPU RAM).

Results:

| Model | Block Size | GPUs | NVLINK | ETA | Perf |
|-------|------------|------|--------|-----|------|
| Small | 512 | 2GPU | No | 17:08:12 | 4.75it/s |
| Small | 512 | 2GPU | Yes | 10:24:20 | 7.79it/s |
| Small | 512 | 1GPU | N/A | 18:37:17 | 8.74it/s |
| Medium | 512 | 2GPU | No | 43:07:49 | 1.89it/s |
| Medium | 512 | 2GPU | Yes | 26:19:09 | 3.09it/s |
| Medium | 512 | 1GPU | N/A | 45:36:37 | 3.57it/s |
| Small | 128 | 2GPU | No | 48:12:05 | 6.75it/s |
| Small | 128 | 2GPU | Yes | 21:26:31 | 15.17it/s |
| Small | 128 | 1GPU | N/A | 30:54:41 | 21.06it/s |
| Medium | 128 | 2GPU | No | 118:43:09 | 2.74it/s |
| Medium | 128 | 2GPU | Yes | 51:55:58 | 6.27it/s |
| Medium | 128 | 1GPU | N/A | 74:02:16 | 8.79it/s |
| Large | 128 | 2GPU | No | 239:19:44 | 1.36it/s |
| Large | 128 | 2GPU | Yes | 102:17:18 | 3.18it/s |
| Large | 128 | 1GPU | N/A | 143:34:42 | 4.54it/s |

So the general observation is that for block size 512, two GPUs without NVLink are about the same performance as a single GPU. For block size 128, two GPUs without NVLink are typically quite a bit slower than a single GPU.

It doesn't seem like DistributedDataParallel helps with this issue, in other words.

@LysandreJik
Member

I think @sgugger has experience with multi-GPU, and works on the example scripts, pinging him!

@TimDettmers
Contributor

A friend linked me to this issue. Thank you for your work on this benchmark! It is some interesting data. I still believe the poor performance could be a hardware issue, though.

As far as I know, RTX 3090 GPUs have peer-to-peer access disabled; in other words, you cannot transfer memory directly from GPU to GPU on these cards. All data is first routed through the CPU, which is often slow because the CPU buffers are not pinned, meaning that memory transfers are synchronous. So in my eyes, slow performance without NVLink is a hardware issue in this case. It would be interesting, though, to see whether these numbers are similar on peer-to-peer-enabled GPUs. Do you have access to such a GPU?
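
(A quick, hypothetical way to see what the software stack reports for this - not something run in this thread; torch.cuda.can_device_access_peer is the relevant call:)

```python
# Sketch: check whether CUDA reports peer-to-peer access between GPU 0 and GPU 1.
# Per the comment above, GeForce RTX 30xx cards are expected to report False here.
import torch

print(torch.cuda.can_device_access_peer(0, 1))

# Without P2P, device-to-device traffic is staged through host memory; pinned
# host buffers at least allow those copies to be issued asynchronously.
pinned = torch.empty(1 << 24, pin_memory=True)
on_gpu = pinned.to("cuda:0", non_blocking=True)
```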

@moyix
Author

moyix commented Jan 22, 2021

You're thinking of something like P2P over PCIe? You're right that NVIDIA has disabled that for the 3090s. The only other hardware I have access to is our HPC cluster, which has RTX8000s and V100s (non-NVLINKed); I believe both show similar slowdowns.

One thing I have been looking into is whether using something like DeepSpeed will help. I got their Megatron-LM example working and, without NVLink, it does much better at scaling to at least two GPUs using the 1-bit Adam optimizer. I'm still waiting for my HPC job to get scheduled to confirm that it scales well there too. If that works, then presumably something like what's being done for the t5-3b model here would help? #8771
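
(For reference, a minimal sketch of the kind of DeepSpeed config that selects 1-bit Adam - the values below are placeholders rather than the settings from the Megatron-LM example:)

```python
# Hypothetical DeepSpeed config enabling the 1-bit Adam optimizer (OneBitAdam);
# values are illustrative placeholders only.
import json

ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "fp16": {"enabled": True},
    "optimizer": {
        "type": "OneBitAdam",
        "params": {
            "lr": 1e-4,
            "freeze_step": 1000,   # steps of ordinary (uncompressed) Adam before compression kicks in
            "cuda_aware": False,   # only set True with a CUDA-aware MPI stack
        },
    },
}

with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)   # then point the DeepSpeed launcher/script at this file
```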

@TimDettmers
Contributor

If you confirm you get the same results on the RTX 8000, that would rule out any GPU-specific issue. It could still be a hardware issue with the PCIe lanes, though. I believe there is a bandwidth test among the NVIDIA samples that come with CUDA with which you can measure the available bandwidth to/from the GPUs. If that shows good numbers, it should be purely an issue of software or network architecture.
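
(The CUDA sample in question is p2pBandwidthLatencyTest, used below; as a rough stand-in, a few lines of PyTorch can also estimate device-to-device copy bandwidth - a sketch, not a calibrated benchmark:)

```python
# Sketch: time repeated cuda:0 -> cuda:1 copies of a fixed-size buffer and report GiB/s.
import time
import torch

def copy_bandwidth_gib_s(src="cuda:0", dst="cuda:1", mib=256, iters=20):
    x = torch.empty(mib * 1024 * 1024, dtype=torch.uint8, device=src)
    torch.cuda.synchronize(src)
    start = time.time()
    for _ in range(iters):
        x.to(dst, non_blocking=True)
    torch.cuda.synchronize(src)
    torch.cuda.synchronize(dst)
    return (mib * iters / 1024) / (time.time() - start)

print(f"{copy_bandwidth_gib_s():.1f} GiB/s")
```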

@moyix
Author

moyix commented Jan 23, 2021

OK, I'll give this a try. Our HPC cluster is a bit busy so it may be a while before I can get a slot on the RTX 8000 nodes.

@moyix
Author

moyix commented Jan 23, 2021

I managed to get some time on a node with 4x V100s. For the Large model, it gets 3.83s/it with an ETA of 1248:01:43 (!).

Here's the output of p2pBandwidthLatencyTest on the V100 system:

[bd52@gv02 p2pBandwidthLatencyTest]$ ./p2pBandwidthLatencyTest 
[P2P (Peer-to-Peer) GPU Bandwidth Latency Test]
Device: 0, Tesla V100-PCIE-32GB, pciBusID: 6, pciDeviceID: 0, pciDomainID:0
Device: 1, Tesla V100-PCIE-32GB, pciBusID: 2f, pciDeviceID: 0, pciDomainID:0
Device: 2, Tesla V100-PCIE-32GB, pciBusID: 86, pciDeviceID: 0, pciDomainID:0
Device: 3, Tesla V100-PCIE-32GB, pciBusID: d8, pciDeviceID: 0, pciDomainID:0
Device=0 CAN Access Peer Device=1
Device=0 CAN Access Peer Device=2
Device=0 CAN Access Peer Device=3
Device=1 CAN Access Peer Device=0
Device=1 CAN Access Peer Device=2
Device=1 CAN Access Peer Device=3
Device=2 CAN Access Peer Device=0
Device=2 CAN Access Peer Device=1
Device=2 CAN Access Peer Device=3
Device=3 CAN Access Peer Device=0
Device=3 CAN Access Peer Device=1
Device=3 CAN Access Peer Device=2

***NOTE: In case a device doesn't have P2P access to other one, it falls back to normal memcopy procedure.
So you can see lesser Bandwidth (GB/s) and unstable Latency (us) in those cases.

P2P Connectivity Matrix
     D\D     0     1     2     3
     0       1     1     1     1
     1       1     1     1     1
     2       1     1     1     1
     3       1     1     1     1
Unidirectional P2P=Disabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3 
     0 768.57  11.42  11.52  11.53 
     1  11.39 770.46  11.50  11.53 
     2  11.42  11.43 771.22  11.45 
     3  11.42  11.43  11.44 769.70 
Unidirectional P2P=Enabled Bandwidth (P2P Writes) Matrix (GB/s)
   D\D     0      1      2      3 
     0 767.06   9.93   9.68   9.49 
     1   9.93 769.33   9.33   9.50 
     2   9.87   9.35 769.70  10.05 
     3   9.66   9.68   9.92 770.08 
Bidirectional P2P=Disabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3 
     0 771.22  15.98  16.04  16.16 
     1  16.00 773.51  16.11  16.07 
     2  15.90  15.99 772.75  15.83 
     3  16.05  16.01  15.85 772.55 
Bidirectional P2P=Enabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3 
     0 770.84  18.72  18.41  18.07 
     1  18.52 772.94  18.82  18.30 
     2  18.41  18.16 771.80  19.13 
     3  18.40  17.99  18.94 771.22 
P2P=Disabled Latency Matrix (us)
   GPU     0      1      2      3 
     0   1.89  14.77  14.42  14.59 
     1  14.52   1.91  15.50  15.50 
     2  15.53  15.42   1.87  14.44 
     3  14.76  14.71  14.51   1.82 

   CPU     0      1      2      3 
     0   2.52   8.33   8.61   8.55 
     1   8.20   2.49   8.50   8.49 
     2   8.30   8.29   2.61   8.69 
     3   8.41   8.36   8.74   2.56 
P2P=Enabled Latency (P2P Writes) Matrix (us)
   GPU     0      1      2      3 
     0   1.86   1.60   1.65   1.64 
     1   1.59   1.91   1.64   1.65 
     2   1.65   1.63   1.88   1.58 
     3   1.65   1.64   1.59   1.82 

   CPU     0      1      2      3 
     0   2.51   2.05   2.02   2.02 
     1   2.14   2.54   2.04   2.02 
     2   2.28   2.18   2.61   2.18 
     3   2.32   2.19   2.24   2.73 

NOTE: The CUDA Samples are not meant for performance measurements. Results may vary when GPU Boost is enabled.

And for comparison, here's the dual 3090 w/NVLINK system:

[P2P (Peer-to-Peer) GPU Bandwidth Latency Test]
Device: 0, GeForce RTX 3090, pciBusID: 1, pciDeviceID: 0, pciDomainID:0
Device: 1, GeForce RTX 3090, pciBusID: 21, pciDeviceID: 0, pciDomainID:0
Device=0 CAN Access Peer Device=1
Device=1 CAN Access Peer Device=0

***NOTE: In case a device doesn't have P2P access to other one, it falls back to normal memcopy procedure.
So you can see lesser Bandwidth (GB/s) and unstable Latency (us) in those cases.

P2P Connectivity Matrix
     D\D     0     1
     0       1     1
     1       1     1
Unidirectional P2P=Disabled Bandwidth Matrix (GB/s)
   D\D     0      1 
     0 831.56  11.25 
     1  11.33 831.12 
Unidirectional P2P=Enabled Bandwidth (P2P Writes) Matrix (GB/s)
   D\D     0      1 
     0 810.85  52.77 
     1  52.85 832.89 
Bidirectional P2P=Disabled Bandwidth Matrix (GB/s)
   D\D     0      1 
     0 812.31  16.55 
     1  16.75 838.03 
Bidirectional P2P=Enabled Bandwidth Matrix (GB/s)
   D\D     0      1 
     0 821.29 101.41 
     1 101.80 835.34 
P2P=Disabled Latency Matrix (us)
   GPU     0      1 
     0   1.59  33.13 
     1  20.55   1.48 

   CPU     0      1 
     0   2.89   8.85 
     1   8.81   2.85 
P2P=Enabled Latency (P2P Writes) Matrix (us)
   GPU     0      1 
     0   1.59   1.43 
     1   1.40   1.47 

   CPU     0      1 
     0   2.93   2.45 
     1   2.39   2.90 

@TimDettmers
Contributor

Thank you - these data are very valuable! They also show that no hardware problem exists. It seems you could confirm the poor performance on the V100, which makes it very likely that you can also reproduce the performance issues with the RTX 8000. With that, it seems the only remaining option is that this is an issue with the combination of parallelism and network architecture.

@stas00
Contributor

stas00 commented Jan 26, 2021

Great benchmarks! Thank you for sharing the data, @moyix

Do you have the same benchmarks for V100s too - just one set is enough (1 vs 2).

Also, why are you running comparison benchmarks on such a huge number of items? Running enough items so that the runtime is around a few minutes should be plenty to see the difference. Or is it that you were aborting these early and just recording the projected ETA and it/s from tqdm? e.g. --max_steps 1000

Here are some ideas that may address your issue

  1. If I understand things right, the 3090 won't work at full capacity until we get pytorch w/ cuda-11.2
    Support CUDA 11.2 pytorch/pytorch#50232
    I don't know the nuances yet, but could it be that the communication channel is limited with cuda-11.0?

    That's why I wanted to see the results from the V100

  2. In one place it was suggested to check how your GPUs are inter-connected with help of:

    nvidia-smi topo -m
    

    that is, do this check with NVLink disabled as well.

  3. Also, are you sure both your GPUs are running at the same PCIe speed (e.g. 8x if it's a consumer motherboard)? It should be, but it's worth checking. Doing a single-GPU test on the other GPU would show whether it's somehow on a slow PCIe slot; I'd run that test just to rule it out. If you get a slower result running the same test on the 2nd GPU, that would explain the situation. (A quick way to check the link speed is sketched after this list.)
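
(A minimal sketch, not from the thread, of reading the current PCIe link generation and width per GPU via the pynvml / nvidia-ml-py bindings; nvidia-smi -q reports the same fields:)

```python
# Sketch: query each GPU's current PCIe link generation and width via NVML,
# assuming the pynvml (nvidia-ml-py) package is installed.
import pynvml

pynvml.nvmlInit()
for i in range(pynvml.nvmlDeviceGetCount()):
    handle = pynvml.nvmlDeviceGetHandleByIndex(i)
    gen = pynvml.nvmlDeviceGetCurrPcieLinkGeneration(handle)
    width = pynvml.nvmlDeviceGetCurrPcieLinkWidth(handle)
    print(f"GPU {i}: PCIe gen {gen}, x{width}")
pynvml.nvmlShutdown()
```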

@stas00
Contributor

stas00 commented Jan 26, 2021

OK, so here is my benchmark with the same tool.

edit: my initial benchmark had a bug in it, as pointed out by @sgugger: one has to tweak --max_steps when changing the number of GPUs. I'm proposing to change that and have a way to truncate the dataset to a fixed size regardless of the number of GPUs used. #9801

So for 1 gpu, I had to double --max_steps to get the same number of items. The rest of this comment has been updated to reflect the corrected state:

Hardware 2x TITAN RTX 24GB each + NVlink

| type | time (secs) |
|------|-------------|
| 1 GPU | 204 |
| 2 GPU: DP w/ NVlink | 110 |
| 2 GPU: DDP w/ NVlink | 101 |
| 2 GPU: DDP w/o NVlink | 131 |

I get the same bus report w/ and w/o NCCL_P2P_DISABLE=1 - I don't think nvidia-smi respects this env var:

NCCL_P2P_DISABLE=1 nvidia-smi topo -m

        GPU0    GPU1    CPU Affinity    NUMA Affinity
GPU0     X      NV2     0-23            N/A
GPU1    NV2      X      0-23            N/A

but clearly the runtime is much slower w/o the NVlink as the benchmark shows, so pytorch/cuda does respect it.

Analysis:

  1. DP is ~10% slower than DDP w/ NVlink, but ~15% faster than DDP w/o NVlink
  2. DDP w/ NVlink doubles the speed of a single GPU, so the communication overhead is almost nil in this particular experiment

Here is the full benchmark code and outputs:

# 1 gpu

rm -r /tmp/test-clm; CUDA_VISIBLE_DEVICES=0 python run_clm.py --model_name_or_path gpt2 \
--dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train --output_dir \
/tmp/test-clm --per_device_train_batch_size 4 --max_steps 400

{'train_runtime': 204.8202, 'train_samples_per_second': 1.953, 'epoch': 0.69}

# DP

rm -r /tmp/test-clm; CUDA_VISIBLE_DEVICES=0,1 python run_clm.py --model_name_or_path gpt2 \
--dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train --output_dir \
/tmp/test-clm --per_device_train_batch_size 4 --max_steps 200

{'train_runtime': 110.5948, 'train_samples_per_second': 1.808, 'epoch': 0.69}

# DDP

rm -r /tmp/test-clm; CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node 2 \
run_clm.py --model_name_or_path gpt2 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
--do_train --output_dir /tmp/test-clm --per_device_train_batch_size 4 --max_steps 200

{'train_runtime': 101.9003, 'train_samples_per_second': 1.963, 'epoch': 0.69}

# DDP w/o NVlink

rm -r /tmp/test-clm; NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
--nproc_per_node 2 run_clm.py --model_name_or_path gpt2 --dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 --do_train --output_dir /tmp/test-clm \
--per_device_train_batch_size 4 --max_steps 200

{'train_runtime': 131.4367, 'train_samples_per_second': 1.522, 'epoch': 0.69}

@moyix
Author

moyix commented Jan 26, 2021

Yes, apologies for the confusion; the ETA numbers above are from aborting early (after a few minutes) and noting the ETA. I actually did compile PyTorch from source with CUDA 11.2 and it doesn't seem to have changed the results (although I don't know if there are further changes PyTorch will make to take full advantage of 11.2).

Your benchmark code is much more self-contained than mine, so I will give your benchmarks a shot with the RTX8000 and V100 nodes on our cluster, but it will probably be a few days before I can get time there as the ICML deadline is very close :)

Here's nvidia-smi topo -m for the 3090 machine:

nvidia-smi topo -m
        GPU0    GPU1    CPU Affinity    NUMA Affinity
GPU0     X      NV4     0-31            N/A
GPU1    NV4      X      0-31            N/A

@sgugger
Collaborator

sgugger commented Jan 26, 2021

Note that the timings compare 200 training steps, so the numbers you reported are wrong, @stas00, in the sense that the 2-GPU runs have covered the equivalent of 400 samples instead of 200. Training on the full dataset would therefore go twice as fast as with one GPU.

@stas00
Contributor

stas00 commented Jan 26, 2021

This is correct - that is, my report was incorrect. Thank you for validating my concern in #9801, @sgugger.

That's why I'm asking for a less confusing way to truncate the dataset.

I need to find an easy way to do it so that I don't have to be at full thinking capacity if I do it late at night, which was the case last night.

I will revisit my benchmark with corrections hopefully today.

But it doesn't change the fact that NVLink gives ~30% faster performance.

@stas00
Contributor

stas00 commented Jan 26, 2021

Yes, apologies for the confusion; the ETA numbers above are from aborting early (after a few minutes) and noting the ETA.

That's what I guessed - I am glad you didn't waste all that electricity by running these to completion! Aborting after a few minutes was the smart move.

I actually did compile PyTorch from source with CUDA 11.2 and it doesn't seem to have changed the results (although I don't know if there are further changes PyTorch will make to take full advantage of 11.2).

Oh, thank you for validating that!

Building pytorch from source is hard! Hats off to you!

Yes, we don't know whether everything has been put in place for 11.2 support.

Your benchmark code is much more self-contained than mine, so I will give your benchmarks a shot with the RTX8000 and V100 nodes on our cluster, but it will probably be a few days before I can get time there as the ICML deadline is very close :)

please note that I corrected a mistake in my benchmark as kindly pointed out by @sgugger:
#9371 (comment)

Here's nvidia-smi topo -m for the 3090 machine:

nvidia-smi topo -m
        GPU0    GPU1    CPU Affinity    NUMA Affinity
GPU0     X      NV4     0-31            N/A
GPU1    NV4      X      0-31            N/A

Looks very similar. Do you know what exactly:

  NV#  = Connection traversing a bonded set of # NVLinks

means? Is NV4 better than NV2, since I get NV2? Why do you have 4? As far as I can see you only have 2 GPUs.

@moyix
Author

moyix commented Jan 27, 2021

According to this table NV4 means "Connection traversing a bonded set of 4 NVLinks".

There are some more details in the GA102 whitepaper:

GA102 GPUs utilize NVIDIA’s third-generation NVLink interface, which includes four x4 links, with each link providing 14.0625 GB/sec bandwidth in each direction between two GPUs. Four links provide 56.25 GB/sec bandwidth in each direction, and 112.5 GB/sec total bandwidth between two GPUs. Two RTX 3090 GPUs can be connected together for SLI using NVLink.
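
As a quick sanity check (a sketch using only figures already quoted in this thread), the measured NVLink numbers from the dual-3090 p2pBandwidthLatencyTest output above land close to those theoretical limits:

```python
# Compare the GA102 whitepaper NVLink figures with the dual-3090 measurements above.
links, per_link_gb_s = 4, 14.0625            # four x4 links, each direction
theoretical_uni = links * per_link_gb_s      # 56.25 GB/s per direction
theoretical_bi = 2 * theoretical_uni         # 112.5 GB/s total

measured_uni, measured_bi = 52.8, 101.4      # from the P2P=Enabled matrices earlier
print(f"unidirectional: {measured_uni / theoretical_uni:.0%} of peak")  # ~94%
print(f"bidirectional:  {measured_bi / theoretical_bi:.0%} of peak")    # ~90%
```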

@stas00
Contributor

stas00 commented Jan 27, 2021

Super! Thank you for that insight, @moyix!

I started compiling performance/scalability notes here: #9824

I summarized the useful insights from this thread. If you get a chance to validate the GPU inter-connectivity section that would be great!

And if you have other insights to contribute I'm all ears. If you don't have time/inspiration to write something complete even a stab would be great and then over time we will fill it out with details and benchmarks.

The idea is to discuss in-depth the different hardware/software nuances to speed up training and fit larger models.

Thank you!

@moyix
Author

moyix commented Jan 27, 2021

Very nice, I will take a look at it!

While I am waiting for HPC time, I ran your benchmark script on the 3090 system while varying two parameters: the model size (gpt2, gpt2-medium, and gpt2-large) and the block size (128, 256, 512).

The script:

for MODEL in gpt2 gpt2-medium gpt2-large; do
    for BLOCK_SIZE in 128 256 512 ; do 
        # Skip gpt2-large at block size 512 due to memory constraints
        if [ $MODEL = "gpt2-large" ] && [ $BLOCK_SIZE -eq 512 ] ; then continue ; fi
        # 1 gpu

        rm -r /tmp/test-clm; CUDA_VISIBLE_DEVICES=0 python run_clm.py --model_name_or_path $MODEL \
            --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train --output_dir \
            /tmp/test-clm --per_device_train_batch_size 4 --max_steps 400 --block_size $BLOCK_SIZE 2>&1 > /tmp/clm_bench.log
        result=$(grep train_runtime /tmp/clm_bench.log)
        echo $MODEL $BLOCK_SIZE "1GPU" $result >> clm_bench_results.log

        # DP

        rm -r /tmp/test-clm; CUDA_VISIBLE_DEVICES=0,1 python run_clm.py --model_name_or_path $MODEL \
            --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train --output_dir \
            /tmp/test-clm --per_device_train_batch_size 4 --max_steps 200 --block_size $BLOCK_SIZE 2>&1 > /tmp/clm_bench.log

        result=$(grep train_runtime /tmp/clm_bench.log)
        echo $MODEL $BLOCK_SIZE "DP" $result >> clm_bench_results.log

        # DDP

        rm -r /tmp/test-clm; CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node 2 \
            run_clm.py --model_name_or_path $MODEL --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
            --do_train --output_dir /tmp/test-clm --per_device_train_batch_size 4 --max_steps 200 --block_size $BLOCK_SIZE 2>&1 > /tmp/clm_bench.log

        result=$(grep train_runtime /tmp/clm_bench.log)
        echo $MODEL $BLOCK_SIZE "DDP" $result >> clm_bench_results.log

        # DDP w/o NVlink

        rm -r /tmp/test-clm; NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
            --nproc_per_node 2 run_clm.py --model_name_or_path $MODEL --dataset_name wikitext \
            --dataset_config_name wikitext-2-raw-v1 --do_train --output_dir /tmp/test-clm \
            --per_device_train_batch_size 4 --max_steps 200 --block_size $BLOCK_SIZE 2>&1 > /tmp/clm_bench.log

        result=$(grep train_runtime /tmp/clm_bench.log)
        echo $MODEL $BLOCK_SIZE "DDP_no_NV" $result >> clm_bench_results.log
    done
done

And the results:

gpt2 128 1GPU {'train_runtime': 19.5621, 'train_samples_per_second': 20.448, 'epoch': 0.09}
gpt2 128 DP {'train_runtime': 16.6426, 'train_samples_per_second': 12.017, 'epoch': 0.09}
gpt2 128 DDP {'train_runtime': 13.5368, 'train_samples_per_second': 14.775, 'epoch': 0.09}
gpt2 128 DDP_no_NV {'train_runtime': 30.0181, 'train_samples_per_second': 6.663, 'epoch': 0.09}
gpt2 256 1GPU {'train_runtime': 30.423, 'train_samples_per_second': 13.148, 'epoch': 0.17}
gpt2 256 DP {'train_runtime': 22.6101, 'train_samples_per_second': 8.846, 'epoch': 0.17}
gpt2 256 DDP {'train_runtime': 18.6943, 'train_samples_per_second': 10.698, 'epoch': 0.17}
gpt2 256 DDP_no_NV {'train_runtime': 35.4208, 'train_samples_per_second': 5.646, 'epoch': 0.17}
gpt2 512 1GPU {'train_runtime': 58.0856, 'train_samples_per_second': 6.886, 'epoch': 0.34}
gpt2 512 DP {'train_runtime': 37.6376, 'train_samples_per_second': 5.314, 'epoch': 0.34}
gpt2 512 DDP {'train_runtime': 32.3616, 'train_samples_per_second': 6.18, 'epoch': 0.34}
gpt2 512 DDP_no_NV {'train_runtime': 49.1999, 'train_samples_per_second': 4.065, 'epoch': 0.34}
gpt2-medium 128 1GPU {'train_runtime': 49.3823, 'train_samples_per_second': 8.1, 'epoch': 0.09}
gpt2-medium 128 DP {'train_runtime': 40.5947, 'train_samples_per_second': 4.927, 'epoch': 0.09}
gpt2-medium 128 DDP {'train_runtime': 33.4365, 'train_samples_per_second': 5.981, 'epoch': 0.09}
gpt2-medium 128 DDP_no_NV {'train_runtime': 74.9924, 'train_samples_per_second': 2.667, 'epoch': 0.09}
gpt2-medium 256 1GPU {'train_runtime': 79.6724, 'train_samples_per_second': 5.021, 'epoch': 0.17}
gpt2-medium 256 DP {'train_runtime': 56.0446, 'train_samples_per_second': 3.569, 'epoch': 0.17}
gpt2-medium 256 DDP {'train_runtime': 47.7543, 'train_samples_per_second': 4.188, 'epoch': 0.17}
gpt2-medium 256 DDP_no_NV {'train_runtime': 89.3616, 'train_samples_per_second': 2.238, 'epoch': 0.17}
gpt2-medium 512 1GPU {'train_runtime': 152.6255, 'train_samples_per_second': 2.621, 'epoch': 0.34}
gpt2-medium 512 DP {'train_runtime': 92.4563, 'train_samples_per_second': 2.163, 'epoch': 0.34}
gpt2-medium 512 DDP {'train_runtime': 82.1935, 'train_samples_per_second': 2.433, 'epoch': 0.34}
gpt2-medium 512 DDP_no_NV {'train_runtime': 124.1163, 'train_samples_per_second': 1.611, 'epoch': 0.34}
gpt2-large 128 1GPU {'train_runtime': 98.5939, 'train_samples_per_second': 4.057, 'epoch': 0.09}
gpt2-large 128 DP {'train_runtime': 79.2193, 'train_samples_per_second': 2.525, 'epoch': 0.09}
gpt2-large 128 DDP {'train_runtime': 65.7918, 'train_samples_per_second': 3.04, 'epoch': 0.09}
gpt2-large 128 DDP_no_NV {'train_runtime': 152.2178, 'train_samples_per_second': 1.314, 'epoch': 0.09}
gpt2-large 256 1GPU {'train_runtime': 154.5437, 'train_samples_per_second': 2.588, 'epoch': 0.17}
gpt2-large 256 DP {'train_runtime': 106.7075, 'train_samples_per_second': 1.874, 'epoch': 0.17}
gpt2-large 256 DDP [out of memory]
gpt2-large 256 DDP_no_NV [out of memory]
gpt2-large 512 1GPU [out of memory]
gpt2-large 512 DP [out of memory]
gpt2-large 512 DDP [out of memory]
gpt2-large 152 DDP_no_NV [out of memory]

One thing that I find interesting is that the behavior I originally observed - training on multiple GPUs without NVLink being slower than on a single GPU - only seems to hold for small block sizes like 128 or (sometimes) 256. So my hypothesis is that with smaller block sizes it is effectively using smaller batches and therefore synchronizing between GPUs more often?

As soon as I can get some time on our HPC I can update this with numbers for the 4xRTX8000 and the 4xV100, although the NVLink rows will no longer be applicable (since I don't have access to a machine with those cards in NVLink/NVSwitch configuration).
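
One way to make that hypothesis concrete: in DDP the gradient all-reduce per step is roughly the size of the model, independent of block size, so as the compute per step shrinks the fixed communication cost starts to dominate. A back-of-envelope sketch (assuming fp32 gradients and the ~11 GB/s and ~53 GB/s unidirectional figures measured above; the gpt2-medium parameter count is approximate):

```python
# Rough per-step gradient-sync cost for gpt2-medium over PCIe vs NVLink.
# This is an order-of-magnitude estimate, not a measurement.
params = 355e6                    # gpt2-medium, approximately
grad_bytes = params * 4           # fp32 gradients
bandwidths = {"PCIe (no P2P/NVLink)": 11.3e9, "NVLink": 52.8e9}  # bytes/s, from the matrices above

for name, bw in bandwidths.items():
    # With 2 GPUs, a ring all-reduce moves roughly the full gradient size per GPU.
    print(f"{name}: ~{grad_bytes / bw:.2f} s of communication per step")
# ~0.13 s/step over PCIe vs ~0.03 s/step over NVLink -- comparable to the compute
# time of a small-block step, which is consistent with 2 GPUs sometimes losing to 1.
```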

@stas00
Contributor

stas00 commented Jan 27, 2021

Awesome! Thank you for more benchmarks, @moyix

Let's apply some magic to your log:

perl -lne 'BEGIN{ print qq[|model|block|type|runtime|sample/sec|]; print "|-" x 5, "|"} $d=qr/([\d\.]+)/; m|^(\S+) $d (\S+) ..train_runtime.. $d, .train_samples_per_second.. $d| && print qq[|$1|$2|$3|$4|$5|]' log.txt

but let's round it up to make reading easier:

perl -lne 'BEGIN{ print qq[|model|block|type|runtime|sample/sec|]; print "|-" x 5, "|"} $d=qr/([\d\.]+)/; m|^(\S+) $d (\S+) ..train_runtime.. $d, .train_samples_per_second.. $d| && print qq[|$1|$2|$3|] . int($4). "|". sprintf("%0.1f", $5)."|"' log.txt
| model | block | type | runtime | sample/sec |
|-------|-------|------|---------|------------|
| gpt2 | 128 | 1GPU | 19 | 20.4 |
| gpt2 | 128 | DP | 16 | 12.0 |
| gpt2 | 128 | DDP | 13 | 14.8 |
| gpt2 | 128 | DDP_no_NV | 30 | 6.7 |
| gpt2 | 256 | 1GPU | 30 | 13.1 |
| gpt2 | 256 | DP | 22 | 8.8 |
| gpt2 | 256 | DDP | 18 | 10.7 |
| gpt2 | 256 | DDP_no_NV | 35 | 5.6 |
| gpt2 | 512 | 1GPU | 58 | 6.9 |
| gpt2 | 512 | DP | 37 | 5.3 |
| gpt2 | 512 | DDP | 32 | 6.2 |
| gpt2 | 512 | DDP_no_NV | 49 | 4.1 |
| gpt2-medium | 128 | 1GPU | 49 | 8.1 |
| gpt2-medium | 128 | DP | 40 | 4.9 |
| gpt2-medium | 128 | DDP | 33 | 6.0 |
| gpt2-medium | 128 | DDP_no_NV | 74 | 2.7 |
| gpt2-medium | 256 | 1GPU | 79 | 5.0 |
| gpt2-medium | 256 | DP | 56 | 3.6 |
| gpt2-medium | 256 | DDP | 47 | 4.2 |
| gpt2-medium | 256 | DDP_no_NV | 89 | 2.2 |
| gpt2-medium | 512 | 1GPU | 152 | 2.6 |
| gpt2-medium | 512 | DP | 92 | 2.2 |
| gpt2-medium | 512 | DDP | 82 | 2.4 |
| gpt2-medium | 512 | DDP_no_NV | 124 | 1.6 |
| gpt2-large | 128 | 1GPU | 98 | 4.1 |
| gpt2-large | 128 | DP | 79 | 2.5 |
| gpt2-large | 128 | DDP | 65 | 3.0 |
| gpt2-large | 128 | DDP_no_NV | 152 | 1.3 |
| gpt2-large | 256 | 1GPU | 154 | 2.6 |
| gpt2-large | 256 | DP | 106 | 1.9 |
Doing a quick scan, it's clear that as the model size and the block size grow, the results diverge more and more, though the proportions don't change much. One could probably pipe this through another conversion into relative numbers and then it would be very clear.

my hypothesis is that with smaller block sizes it is effectively using smaller batches and therefore synchronizing between GPUs more often?

It certainly has less data to communicate to the other gpus with smaller blocks

@stas00
Contributor

stas00 commented Jan 27, 2021

ok, a quick hack to add ratios relative to 1gpu, so now it's easier to see the comparison.

perl -lne 'BEGIN{ print qq[|model|block|type|runtime|sample/sec|ratios]; print "|-" x 6, "|"} $d=qr/([\d\.]+)/; if (m|^(\S+) $d (\S+) ..train_runtime.. $d, .train_samples_per_second.. $d|) {if($3=="1GPU") {$s=$4; print "| " x 6, "|"}; print qq[|$1|$2|$3|] . int($4). "|". sprintf("%0.1f", $5)."|".sprintf("%0.1f", $4/$s)."|"}'  log.txt

So I added a new runtime-ratios column; each group of 4 rows is recalculated with respect to its first runtime entry (the 1-GPU run).

edit: someone asked me to explain the ratio, and why the runtime is faster for DDP while the reported samples per second is smaller.

Here is a puzzle to solve:

  1. one cake eater eats the cake at 60 sec/cake
  2. now a second cake eater joins who eats at the same speed as the first one, but after every bite they both have to shout "ML rocks", which slows them down, so they are now eating 20% slower than when alone

Will one cake eater finish the cake faster than two of them?

(the answer is after the table, so you don't see it right away)

| model | block | type | runtime | sample/sec | ratios |
|-------|-------|------|---------|------------|--------|
| gpt2 | 128 | 1GPU | 19 | 20.4 | 1.0 |
| gpt2 | 128 | DP | 16 | 12.0 | 0.9 |
| gpt2 | 128 | DDP | 13 | 14.8 | 0.7 |
| gpt2 | 128 | DDP_no_NV | 30 | 6.7 | 1.5 |
| gpt2 | 256 | 1GPU | 30 | 13.1 | 1.0 |
| gpt2 | 256 | DP | 22 | 8.8 | 0.7 |
| gpt2 | 256 | DDP | 18 | 10.7 | 0.6 |
| gpt2 | 256 | DDP_no_NV | 35 | 5.6 | 1.2 |
| gpt2 | 512 | 1GPU | 58 | 6.9 | 1.0 |
| gpt2 | 512 | DP | 37 | 5.3 | 0.6 |
| gpt2 | 512 | DDP | 32 | 6.2 | 0.6 |
| gpt2 | 512 | DDP_no_NV | 49 | 4.1 | 0.8 |
| gpt2-medium | 128 | 1GPU | 49 | 8.1 | 1.0 |
| gpt2-medium | 128 | DP | 40 | 4.9 | 0.8 |
| gpt2-medium | 128 | DDP | 33 | 6.0 | 0.7 |
| gpt2-medium | 128 | DDP_no_NV | 74 | 2.7 | 1.5 |
| gpt2-medium | 256 | 1GPU | 79 | 5.0 | 1.0 |
| gpt2-medium | 256 | DP | 56 | 3.6 | 0.7 |
| gpt2-medium | 256 | DDP | 47 | 4.2 | 0.6 |
| gpt2-medium | 256 | DDP_no_NV | 89 | 2.2 | 1.1 |
| gpt2-medium | 512 | 1GPU | 152 | 2.6 | 1.0 |
| gpt2-medium | 512 | DP | 92 | 2.2 | 0.6 |
| gpt2-medium | 512 | DDP | 82 | 2.4 | 0.5 |
| gpt2-medium | 512 | DDP_no_NV | 124 | 1.6 | 0.8 |
| gpt2-large | 128 | 1GPU | 98 | 4.1 | 1.0 |
| gpt2-large | 128 | DP | 79 | 2.5 | 0.8 |
| gpt2-large | 128 | DDP | 65 | 3.0 | 0.7 |
| gpt2-large | 128 | DDP_no_NV | 152 | 1.3 | 1.5 |
| gpt2-large | 256 | 1GPU | 154 | 2.6 | 1.0 |
| gpt2-large | 256 | DP | 106 | 1.9 | 0.7 |

and the answer to the puzzle posted at the beginning of this comment: 2 cake eaters will eat the cake faster together despite the slowdown, because they only have half a cake to finish each!

Same here: each GPU in the DDP assembly runs slower due to the gradient syncing, but because each one only has to consume half the samples, the assembly as a whole trains faster.
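
To put numbers on it using the 2x TITAN RTX results above (every run covers the same 1600 samples: 400 steps x batch 4 on one GPU, otherwise 200 steps x batch 4 x 2 GPUs), here is a small sketch of the effective wall-clock throughput:

```python
# Effective throughput = total samples / wall-clock train_runtime,
# using the numbers reported in the benchmark above (1600 samples per run).
runs = {
    "1 GPU":                204.8,
    "2 GPU DP w/ NVLink":   110.6,
    "2 GPU DDP w/ NVLink":  101.9,
    "2 GPU DDP w/o NVLink": 131.4,
}
total_samples = 400 * 4  # == 200 steps * 4 per device * 2 devices
for name, runtime in runs.items():
    print(f"{name}: {total_samples / runtime:.1f} samples/s")
# 1 GPU ~7.8 vs DDP w/ NVLink ~15.7 samples/s -- close to linear scaling, even
# though the per-process it/s that tqdm reports is lower for DDP.
```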

Further, this benchmark is just for 2 GPUs

So going from 1 GPU to 2 GPUs you introduce that overhead, so you get some loss in performance along with the gain.

When you go from 2 GPUs to 4 GPUs (on the same node), it's essentially pure performance doubling, i.e. 4 GPUs will show disproportionally more speedup over 1 GPU than 2 GPUs do.

  • 1 GPU has no inter-gpu communication to do
  • 2+ gpus have to average gradients

so they add this overhead, but then they can parallelize the processing so the overhead becomes almost negligible as the number of GPUs grows

The next problem is once you outgrow a single node. The issue there is the inter-node connection, which can be blazing fast (InfiniBand) or super slow (an Ethernet hub). So to scale from 8 GPUs to 10 (with 8-GPU nodes), you first lose performance, because now the inter-node connection is the slow component that drags everything down. But as you add more nodes, that overhead again becomes less and less significant.

Of course when working with multi-node one often uses other parallelization techniques than DDP, so it's PP or TP (https://huggingface.co/transformers/parallelism.html#concepts), and there one generally performs TP only inside a node, and PP and DP over nodes.

It'd be amazing if someone re-did this table for 1, 2, 4 gpus, then 1, 2, 4 nodes.

stas00 added the Benchmarks label on Feb 1, 2021
@moyix
Author

moyix commented Feb 2, 2021

OK, now we have some extensive benchmarks for the RTX8000 machine. This machine does not have NVLink, but it apparently can do P2P GPU-GPU communication via the PCI bus. However, this seems to be quite slow – slower, in fact, than disabling P2P altogether.

Here's nvidia-smi topo -m:

        GPU0    GPU1    GPU2    GPU3    mlx5_0  CPU Affinity    NUMA Affinity
GPU0     X      SYS     SYS     SYS     SYS     0-7     0-1
GPU1    SYS      X      SYS     SYS     SYS     0-7     0-1
GPU2    SYS     SYS      X      SYS     SYS     0-7     0-1
GPU3    SYS     SYS     SYS      X      SYS     0-7     0-1
mlx5_0  SYS     SYS     SYS     SYS      X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

I used the script from before (slightly expanded) and set max-steps to 800 for the single GPU case, 400 for two GPUs, and 200 for 4 GPUs. Here are the benchmarks (long!):

| model | block | type | runtime | sample/sec | ratios |
|-------|-------|------|---------|------------|--------|
| gpt2 | 128 | 1GPU | 67 | 11.9 | 1.0 |
| gpt2 | 128 | DP_2GPU | 530 | 0.8 | 7.9 |
| gpt2 | 128 | DDP_2GPU | 350 | 1.1 | 5.2 |
| gpt2 | 128 | DDP_no_P2P_2GPU | 119 | 3.3 | 1.8 |
| gpt2 | 128 | DP_4GPU | 243 | 0.8 | 3.6 |
| gpt2 | 128 | DDP_4GPU | 159 | 1.3 | 2.4 |
| gpt2 | 128 | DDP_no_P2P_4GPU | 88 | 2.3 | 1.3 |
| gpt2 | 256 | 1GPU | 113 | 7.0 | 1.0 |
| gpt2 | 256 | DP_2GPU | 582 | 0.7 | 5.1 |
| gpt2 | 256 | DDP_2GPU | 376 | 1.1 | 3.3 |
| gpt2 | 256 | DDP_no_P2P_2GPU | 142 | 2.8 | 1.3 |
| gpt2 | 256 | DP_4GPU | 313 | 0.6 | 2.8 |
| gpt2 | 256 | DDP_4GPU | 174 | 1.1 | 1.5 |
| gpt2 | 256 | DDP_no_P2P_4GPU | 102 | 1.9 | 0.9 |
| gpt2 | 512 | 1GPU | 215 | 3.7 | 1.0 |
| gpt2 | 512 | DP_2GPU | 694 | 0.6 | 3.2 |
| gpt2 | 512 | DDP_2GPU | 426 | 0.9 | 2.0 |
| gpt2 | 512 | DDP_no_P2P_2GPU | 192 | 2.1 | 0.9 |
| gpt2 | 512 | DP_4GPU | 454 | 0.4 | 2.1 |
| gpt2 | 512 | DDP_4GPU | 201 | 1.0 | 0.9 |
| gpt2 | 512 | DDP_no_P2P_4GPU | 124 | 1.6 | 0.6 |
| gpt2-medium | 128 | 1GPU | 183 | 4.4 | 1.0 |
| gpt2-medium | 128 | DP_2GPU | 1476 | 0.3 | 8.0 |
| gpt2-medium | 128 | DDP_2GPU | 863 | 0.5 | 4.7 |
| gpt2-medium | 128 | DDP_no_P2P_2GPU | 280 | 1.4 | 1.5 |
| gpt2-medium | 128 | DP_4GPU | 653 | 0.3 | 3.6 |
| gpt2-medium | 128 | DDP_4GPU | 375 | 0.5 | 2.0 |
| gpt2-medium | 128 | DDP_no_P2P_4GPU | 193 | 1.0 | 1.1 |
| gpt2-medium | 256 | 1GPU | 306 | 2.6 | 1.0 |
| gpt2-medium | 256 | DP_2GPU | 1600 | 0.2 | 5.2 |
| gpt2-medium | 256 | DDP_2GPU | 919 | 0.4 | 3.0 |
| gpt2-medium | 256 | DDP_no_P2P_2GPU | 339 | 1.2 | 1.1 |
| gpt2-medium | 256 | DP_4GPU | 814 | 0.2 | 2.7 |
| gpt2-medium | 256 | DDP_4GPU | 401 | 0.5 | 1.3 |
| gpt2-medium | 256 | DDP_no_P2P_4GPU | 218 | 0.9 | 0.7 |
| gpt2-medium | 512 | 1GPU | 573 | 1.4 | 1.0 |
| gpt2-medium | 512 | DP_2GPU | 1884 | 0.2 | 3.3 |
| gpt2-medium | 512 | DDP_2GPU | 1053 | 0.4 | 1.8 |
| gpt2-medium | 512 | DDP_no_P2P_2GPU | 472 | 0.8 | 0.8 |
| gpt2-medium | 512 | DP_4GPU | 1177 | 0.2 | 2.1 |
| gpt2-medium | 512 | DDP_4GPU | 462 | 0.4 | 0.8 |
| gpt2-medium | 512 | DDP_no_P2P_4GPU | 278 | 0.7 | 0.5 |
| gpt2-large | 128 | 1GPU | 402 | 2.0 | 1.0 |
| gpt2-large | 128 | DP_2GPU | 3181 | 0.1 | 7.9 |
| gpt2-large | 128 | DDP_2GPU | 1760 | 0.2 | 4.4 |
| gpt2-large | 128 | DDP_no_P2P_2GPU | 565 | 0.7 | 1.4 |
| gpt2-large | 128 | DP_4GPU | 1361 | 0.1 | 3.4 |
| gpt2-large | 128 | DDP_4GPU | 717 | 0.3 | 1.8 |
| gpt2-large | 128 | DDP_no_P2P_4GPU | 349 | 0.6 | 0.9 |
| gpt2-large | 256 | 1GPU | 642 | 1.2 | 1.0 |
| gpt2-large | 256 | DP_2GPU | 3440 | 0.1 | 5.4 |
| gpt2-large | 256 | DDP_2GPU | 1882 | 0.2 | 2.9 |
| gpt2-large | 256 | DDP_no_P2P_2GPU | 686 | 0.6 | 1.1 |
| gpt2-large | 256 | DP_4GPU | 1673 | 0.1 | 2.6 |
| gpt2-large | 256 | DDP_4GPU | 770 | 0.3 | 1.2 |
| gpt2-large | 256 | DDP_no_P2P_4GPU | 403 | 0.5 | 0.6 |
| gpt2-large | 512 | 1GPU | 1168 | 0.7 | 1.0 |
| gpt2-large | 512 | DP_2GPU | 3947 | 0.1 | 3.4 |
| gpt2-large | 512 | DDP_2GPU | 2145 | 0.2 | 1.8 |
| gpt2-large | 512 | DDP_no_P2P_2GPU | 952 | 0.4 | 0.8 |
| gpt2-large | 512 | DP_4GPU | 2303 | 0.1 | 2.0 |
| gpt2-large | 512 | DDP_4GPU | 902 | 0.2 | 0.8 |
| gpt2-large | 512 | DDP_no_P2P_4GPU | 531 | 0.4 | 0.5 |
| gpt2-xl | 128 | 1GPU | 770 | 1.0 | 1.0 |
| gpt2-xl | 128 | DP_2GPU | 6391 | 0.1 | 8.3 |
| gpt2-xl | 128 | DDP_2GPU | 3396 | 0.1 | 4.4 |
| gpt2-xl | 128 | DDP_no_P2P_2GPU | 751 | 0.5 | 1.0 |
| gpt2-xl | 128 | DP_4GPU | 2588 | 0.1 | 3.4 |
| gpt2-xl | 128 | DDP_4GPU | 1356 | 0.1 | 1.8 |
| gpt2-xl | 128 | DDP_no_P2P_4GPU | 635 | 0.3 | 0.8 |
| gpt2-xl | 256 | 1GPU | 1210 | 0.7 | 1.0 |
| gpt2-xl | 256 | DP_2GPU | 6826 | 0.1 | 5.6 |
| gpt2-xl | 256 | DP_4GPU | 3130 | 0.1 | 2.6 |

@stas00
Contributor

stas00 commented Feb 2, 2021

Thank you for doing this immense work, @moyix!

From a quick look it appears that the model size doesn't matter much, but the block size makes a big difference to how well the various DDP approaches fare - the larger the block the more benefit one gets, and for small blocks the performance is terrible.

@stas00
Contributor

stas00 commented Feb 3, 2021

@JJack0812, your issue report won't get addressed here, as we are talking about a totally different topic in this thread - I'd say post a separate issue, maybe under pytorch or transformers, but first study the existing tickets, e.g. this one.

@github-actions

github-actions bot commented Mar 6, 2021

This issue has been automatically marked as stale and been closed because it has not had recent activity. Thank you for your contributions.

If you think this still needs to be addressed please comment on this thread.

@lizhengbuaa

I have run into the same problem, thanks for the answers.
