Excessive GPU-GPU communication with GPT2 making multi-GPU training slow? #9371
Comments
Not an answer to your issue/question, but have you tried running with distributed training (DDP), which is the recommended way of running over multiple GPUs: https://github.com/huggingface/transformers/tree/master/examples#distributed-training-and-mixed-precision I'd be curious to see the same with/without NVLink experiment there.
Hmm, I don't have much experience using torch.distributed. I tried just running the existing script with `python -m torch.distributed.launch`. If you have a link to some documentation that explains how to set up the training script so that it can be used with torch.distributed, I can give that a try.
The command you posted "should" work. @sgugger might have links to better content when he's back, but the PyTorch tutorials are pretty good: https://pytorch.org/tutorials/beginner/dist_overview.html#data-parallel-training Your initial experiment is using `DataParallel`, not DDP.
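For reference, a minimal sketch of the general DDP setup being discussed, assuming a recent PyTorch; this is an illustration only, not the actual training script, and the model here is just a placeholder:

```python
# Minimal DDP sketch (illustrative only). Launch with, e.g.:
#   python -m torch.distributed.launch --nproc_per_node=2 this_script.py
# Newer PyTorch launchers set the LOCAL_RANK env var; older ones pass a --local_rank argument instead.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    dist.init_process_group(backend="nccl")      # NCCL performs the GPU-GPU gradient all-reduce
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(1024, 1024).cuda(local_rank)   # stand-in for the real GPT2 model
    model = DDP(model, device_ids=[local_rank])

    # training loop goes here: each process reads its own shard of the data
    # (e.g. via DistributedSampler) and gradients are synced after every backward pass

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```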
OK, I got around to spending some more time with this today. I realized that the Trainer already supports distributed training when launched via torch.distributed, so now I'm just using that, with:
For single GPU I drop the torch.distributed.launch wrapper. Results:
So the general observation is that for block size 512, two GPUs without NVLink are about the same performance as a single GPU. For block size 128, two GPUs without NVLink are typically quite a bit slower than a single GPU. It doesn't seem like DistributedDataParallel helps with this issue, in other words.
I think @sgugger has experience with multi-GPU and works on the example scripts, so pinging him!
A friend linked me to this issue. Thank you for your work on this benchmark! It is some interesting data. I still believe the poor performance could be a hardware issue, though. As far as I know, RTX 3090 GPUs have peer-to-peer access disabled; in other words, you cannot transfer memory directly from GPU to GPU on these cards. All data is first routed through the CPU, which is often slow because the CPU buffers are not pinned, meaning that memory transfers are synchronous. So in my eyes, slow performance without NVLink is a hardware issue in this case. It would be interesting, though, to see whether these numbers are similar for peer-to-peer-enabled GPUs. Do you have access to such a GPU?
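As a side note, whether PyTorch sees peer-to-peer access between two GPUs can be checked directly; a quick sketch (not from this thread):

```python
# Quick check of peer-to-peer visibility between the first two GPUs.
import torch

if torch.cuda.device_count() >= 2:
    print("GPU0 -> GPU1 peer access:", torch.cuda.can_device_access_peer(0, 1))
    print("GPU1 -> GPU0 peer access:", torch.cuda.can_device_access_peer(1, 0))
```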
You're thinking of something like P2P over PCIe? You're right that NVIDIA has disabled that for the 3090s. The only other hardware I have access to is our HPC cluster, which has RTX 8000s and V100s (non-NVLinked); I believe both show similar slowdowns. One thing I have been looking into is whether using something like DeepSpeed will help. I got their Megatron-LM example working and it does much better at scaling to at least two GPUs without NVLink using the 1-bit Adam optimizer. I'm still waiting for my HPC job to get scheduled to confirm that it scales well there too. If that works, then presumably something like what's being done for the t5-3b model in #8771 would help?
If you confirm you have the same results for the RTX 8000, that would rule out any GPU issue. It could still be a hardware issue with PCIe lanes. I believe there is a bandwidth test among the NVIDIA samples that come with CUDA, with which you can measure the available bandwidth to/from the GPUs. If this shows good numbers, it should be purely an issue of software or network architecture.
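A rough GPU-to-GPU copy timing can also be done from PyTorch; this is only an illustrative sketch, not the CUDA bandwidthTest/p2pBandwidthLatencyTest samples mentioned above:

```python
# Rough measurement of GPU0 -> GPU1 copy bandwidth (illustrative only).
import time

import torch

n_bytes = 1 << 30                                    # ~1 GiB payload
src = torch.empty(n_bytes // 4, dtype=torch.float32, device="cuda:0")
torch.cuda.synchronize()

start = time.time()
for _ in range(10):
    dst = src.to("cuda:1")                           # goes over NVLink, PCIe P2P, or via the host
torch.cuda.synchronize()
elapsed = time.time() - start

print(f"~{10 * n_bytes / elapsed / 1e9:.1f} GB/s")
```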
OK, I'll give this a try. Our HPC cluster is a bit busy so it may be a while before I can get a slot on the RTX 8000 nodes.
I managed to get some time on a node with 4x V100s. For the Large model, it gets 3.83s/it with an ETA of 1248:01:43 (!). Here's the output of p2pBandwidthLatencyTest on the V100 system:
And for comparison, here's the dual 3090 w/NVLINK system:
Thank you - these data are very valuable! They also show that no hardware problem exists. It seems you could confirm poor performance on the V100, which makes it very likely that you can also reproduce the performance issues with the RTX 8000. With that, it seems the only remaining explanation is an issue with the combination of parallelism and network architecture.
Great benchmarks! Thank you for sharing the data, @moyix. Do you have the same benchmarks for V100s too? Just one set is enough (1 vs 2). Also, why are you running comparison benchmarks on such a huge number of items? Running enough items so that the runtime is around a few minutes should be plenty to see the difference. Or is it that you were aborting these early and just recording the projected ETA and it/s from tqdm? Here are some ideas that may address your issue:
OK, so here is my benchmark with the same tool. Edit: my initial benchmark had a bug in it, as pointed out by @sgugger, since one has to tweak how much data each configuration consumes. So for 1 GPU, I had to double the number of steps so that both setups cover the same number of samples. Hardware: 2x TITAN RTX (24GB each) + NVLink
I get the same bus report w/ and w/o NCCL_P2P_DISABLE=1 (I don't think that environment variable changes what the topology report shows),
but clearly the runtime is much slower w/o the NVLink, as the benchmark shows, so pytorch/cuda does respect it. Analysis:
Here is the full benchmark code and outputs:
Yes, apologies for the confusion; the ETA numbers above are from aborting early (after a few minutes) and noting the ETA. I actually did compile PyTorch from source with CUDA 11.2 and it doesn't seem to have changed the results (although I don't know if there are further changes PyTorch will make to take full advantage of 11.2). Your benchmark code is much more self-contained than mine, so I will give your benchmarks a shot with the RTX 8000 and V100 nodes on our cluster, but it will probably be a few days before I can get time there as the ICML deadline is very close :) Here's nvidia-smi topo -m for the 3090 machine:
Note that the timings compare 200 training steps, so the numbers you reported are wrong, @stas00, in the sense that 2 GPUs have covered 400 samples instead of 200. Training on the full dataset would therefore go twice as fast as with one GPU.
This is correct: my report was incorrect. Thank you for validating my concern in #9801, @sgugger. That's why I'm asking for a less confusing way to truncate the dataset. I need to find an easy way to do it so I don't have to be at full thinking capacity if I do it late at night, which was the case last night. I will revisit my benchmark with corrections, hopefully today. But it doesn't change the fact that NVLink gives ~30% faster performance.
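For illustration, one way to keep such comparisons fair is to fix the total sample budget rather than the step count, so that 1-GPU and 2-GPU runs cover the same data. A hedged sketch with the datasets library (hypothetical dataset and numbers, not the fix discussed in #9801):

```python
# Illustrative only: truncate to a fixed sample budget so every configuration
# consumes the same amount of data regardless of the number of GPUs.
from datasets import load_dataset

total_samples = 400                                   # hypothetical budget for every run
raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
subset = raw.select(range(total_samples))             # every run trains on exactly these samples
# 1 GPU at batch size 2 -> 200 steps; 2 GPUs (DDP) at batch size 2 each -> 100 steps per GPU
```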
That's what I guessed - I'm glad you didn't waste all that electricity running these to completion! Aborting after a few minutes was a smart move.
Oh, thank you for validating that! Building pytorch from source is hard! Hats off to you! Yes, we don't know whether everything has been put in place for 11.2 support.
Please note that I corrected a mistake in my benchmark, as kindly pointed out by @sgugger:
Looks very similar. Do you know what exactly NV4 means? Is NV4 better than NV2, since I get NV2? Why do you have 4? As far as I can see, you only have 2 GPUs.
According to this table NV4 means "Connection traversing a bonded set of 4 NVLinks". There are some more details in the GA102 whitepaper:
Super! Thank you for that insight, @moyix! I started compiling performance/scalability notes here: #9824, where I summarized the useful insights from this thread. If you get a chance to validate the GPU inter-connectivity section, that would be great! And if you have other insights to contribute, I'm all ears. If you don't have the time/inspiration to write something complete, even a stab would be great, and then over time we will fill it out with details and benchmarks. The idea is to discuss in depth the different hardware/software nuances to speed up training and fit larger models. Thank you!
Very nice, I will take a look at it! While I am waiting for HPC time, I ran your benchmark script on the 3090 system while varying two parameters: the model size (gpt2, gpt2-medium, and gpt2-large) and the block size (128, 256, 512). The script:
And the results:
One thing that I find interesting is that the behavior I originally observed, where training on multiple GPUs without NVLink could be slower than on a single GPU, only seems to hold for small block sizes like 128 or (sometimes) 256. So my hypothesis is that with smaller block sizes it is effectively using smaller batches and therefore synchronizing between GPUs more often? As soon as I can get some time on our HPC I can update this with numbers for the 4x RTX 8000 and the 4x V100, although the NVLink rows will no longer be applicable (since I don't have access to a machine with those cards in an NVLink/NVSwitch configuration).
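A back-of-the-envelope sketch of that hypothesis (assumed parameter count and batch size, not measured values): in DDP the gradient all-reduce per step is fixed by the model size, while the compute per step shrinks with the block size, so smaller blocks pay the synchronization cost relatively more often.

```python
# Rough illustration: bytes of gradient traffic per token processed, for a few block sizes.
gpt2_medium_params = 355e6                       # approximate parameter count (assumption)
grad_bytes_per_step = gpt2_medium_params * 4     # fp32 gradients all-reduced every step

per_device_batch = 4                             # hypothetical per-GPU batch size
for block_size in (128, 256, 512):
    tokens_per_step = per_device_batch * block_size
    mb_per_token = grad_bytes_per_step / tokens_per_step / 1e6
    print(f"block {block_size}: ~{mb_per_token:.1f} MB of gradient traffic per token")
```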
Awesome! Thank you for the additional benchmarks, @moyix. Let's apply some magic to your log:
but let's round it up to make reading easier:
Doing a quick scan, it's clear that as the model grows in size and the block size increases, the results start to diverge more and more, though the proportions don't change much. One could probably pipe this through a conversion into relative sizes and then it'd be very clear.
It certainly has less data to communicate to the other GPUs with smaller blocks.
OK, a quick hack to add ratios relative to 1 GPU, so now it's easier to see the comparison.
So I added a new runtime-ratio column. Edit: someone asked to explain the ratio and why the runtime is faster for DDP while samples per second is smaller. Here is a puzzle to solve:
Will one cake eater finish the cake faster than two of them? (the answer is after the table, so you don't see it right away)
And the answer to the puzzle posted at the beginning of this comment: 2 cake eaters will eat the cake faster together despite the slowdown, because they only have half a cake each to finish! Same here: each of the GPUs in the DDP assembly performs slower due to the gradient syncing, but because each has to consume only half the samples, overall the assembly will train faster. Further, this benchmark is just for 2 GPUs. So going from 1 GPU to 2 GPUs, you create the overhead, and so you get some loss in performance along with the gain. When you go from 2 GPUs to 4 GPUs (on the same node), it's pure performance doubling.
So they add this overhead, but then they can parallelize the processing, so the overhead becomes almost negligible as the number of GPUs grows. The next problem is once you outgrow a single node. So the next issue is the inter-node connect, which can be blazing fast (InfiniBand) or super slow (an Ethernet hub). So to scale from 8 GPUs to 10 (with 8-GPU nodes), you first lose performance, because now the inter-node connection is the slow component that slows everything down. But as you add more nodes, again that overhead becomes less and less significant. Of course, when working with multiple nodes one often uses parallelization techniques other than DDP, i.e. PP or TP (https://huggingface.co/transformers/parallelism.html#concepts), and there one generally performs TP only inside a node, and PP and DP over nodes. It'd be amazing if someone re-did this table for 1, 2, 4 GPUs, then 1, 2, 4 nodes.
OK, now we have some extensive benchmarks for the RTX 8000 machine. This machine does not have NVLink, but it apparently can do P2P GPU-GPU communication via the PCI bus. However, this seems to be quite slow; slower, in fact, than disabling P2P altogether. Here's the output:
I used the script from before (slightly expanded) and set NCCL_P2P_DISABLE=1 for the runs without P2P:
Thank you for doing this immense work, @moyix! From a quick look it appears the model size doesn't matter, but the block size makes a big difference to the outcome with the various DDP approaches: the larger the block, the more benefit one gets, and for small blocks the performance is terrible.
@JJack0812, your issue report won't get addressed here, as we are talking about a totally different topic in this thread. I'd say post a separate issue, maybe under pytorch or transformers, but first study existing tickets, e.g. this one.
This issue has been automatically marked as stale and closed because it has not had recent activity. Thank you for your contributions. If you think this still needs to be addressed, please comment on this thread.
I have run into the same problem; thanks for the answers.
Summary: on a multi-GPU system, training GPT2 seems to scale poorly unless a very fast GPU-GPU interconnect like NVLink is available. In particular, without NVLink using two GPUs is slower than using just one GPU.
Environment info
transformers version: 4.1.1
Who can help
Maybe @LysandreJik or @patrickvonplaten ?
Information
Model I am using (Bert, XLNet ...): GPT2
The problem arises when using my own scripts.
The script is a pretty basic example of training a medium-size GPT2 from scratch. The script is here: https://panda.moyix.net/~moyix/train_csrc.py
The dataset and tokenized vocab:
The task I am working on is:
Training a GPT2 language model on C source code.
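For orientation, here is a minimal sketch of this kind of from-scratch GPT2 + Trainer setup. It is illustrative only, not the linked train_csrc.py: the real script uses a custom tokenizer/vocab trained on C source, while a dummy dataset and the stock gpt2 tokenizer stand in here.

```python
# Minimal from-scratch GPT2 training sketch with the Trainer API (illustrative only).
from datasets import Dataset
from transformers import (GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast,
                          Trainer, TrainingArguments)

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")   # stand-in for the custom C-source vocab
config = GPT2Config(vocab_size=tokenizer.vocab_size, n_positions=1024)
model = GPT2LMHeadModel(config)                         # randomly initialized, trained from scratch

# Tiny dummy dataset of fixed-length token blocks; the real data would be the tokenized C corpus.
block = tokenizer("int main(void) { return 0; }")["input_ids"]
train_dataset = Dataset.from_dict({"input_ids": [block] * 64, "labels": [block] * 64})

args = TrainingArguments(
    output_dir="gpt2-csrc-sketch",
    per_device_train_batch_size=4,
    num_train_epochs=1,
)
Trainer(model=model, args=args, train_dataset=train_dataset).train()
```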
To reproduce
Run with only one GPU:
CUDA_VISIBLE_DEVICES=0 python train_csrc.py
Run with two GPUs, NVLink disabled:
NCCL_P2P_DISABLE=1 python train_csrc.py
Run with two GPUs and NVLink enabled:
python train_csrc.py
Here is some benchmarking I did with my dataset on transformers 3.3.1 and 4.1.1 (note the difference in ETA is just because 3.3.1 only seems to report the ETA for the current epoch):
You can see that using two GPUs is actually slower than using a single GPU, unless NVLink is available (599 hours for 1 GPU vs 1025 hours for two GPUs). So presumably there is a large amount of GPU-GPU communication going on?
Expected behavior
Scaling should be roughly linear with the number of GPUs. Unfortunately I am not very familiar with the implementation details of GPT2 in Huggingface, but others report roughly linear scaling with Transformer models like BERT so it should work here as well: https://towardsdatascience.com/training-bert-at-a-university-eedcf940c754
Although I have a system with NVLink at home, this issue is still affecting me because I would like to be able to run this on the university HPC cluster, where most nodes do not have NVLink.