Small changes to reduce peak memory. #389
Conversation
This is actually really nice @robieta! Peak memory is one of the big issues in practice when finetuning LLaMA 65B / Falcon 40B as well. Once this is merged, we should also apply this to lit-parrot!
@carmocca definitely, let's add the same fixes to all scripts using an optimizer and FSDP.
Ok, added the same tweaks to full finetuning and redpajama pretraining, which are the two cases where the optimizer state is chunkier. For adapter and LoRA I'm not expecting this to be a problem; @robieta please keep me honest here.
I can confirm it about foreach from experiments in lit-gpt; haven't tested limit_all_gathers yet.
Whoa, I just ran those and I'm seeing a huge speedup (95%!) by setting
The finetuning scripts in this repo use DeepSpeed, unlike lit-gpt, where I swapped it for FSDP: Lightning-AI/litgpt#118. So we can merge this.
LLaMA 7B is very close to the memory limit when sharded across 4x 40GB cards, and several recent library changes have pushed it over that limit, so it now OOMs when run. (Notably, it turns out we were cheating and using a bit less memory than a correct implementation would.) This PR introduces two small changes which prevent Shakespeare pretraining from OOMing and greatly improve its performance:
foreach=False

Some PyTorch optimizers can group parameter updates, which generally improves performance by reducing the number of kernel launches and, consequently, host-side latency. However, this grouping increases the peak memory footprint: instead of allocating and freeing a size-`X` Tensor `k` times, you allocate and free a single `k * X` sized Tensor.
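For illustration, here is a minimal sketch of disabling the grouped update path on a stock PyTorch optimizer; the model and learning rate are placeholders, not the values used by the scripts in this PR.

```python
import torch

# Placeholder module; the actual scripts build a full GPT-style model.
model = torch.nn.Linear(4096, 4096)

# foreach=False disables the grouped ("foreach") update path, so parameters are
# stepped one tensor at a time. This costs some extra kernel-launch overhead but
# avoids materializing the large fused buffer, lowering peak memory.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, foreach=False)
```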
limit_all_gathers=True

Breaking up the optimizer step brings the logical memory back below the OOM threshold, but the memory statistics still show reserved memory maxing out. The reason is that FSDP is overzealous about launching all-gathers in an attempt to overlap as much communication and compute as possible. The trouble is that all of those concurrent in-flight requests take up memory; more specifically, the CUDACachingAllocator caches per stream, so it's not straightforward for it to reclaim memory from the communication stream and reuse it in the compute stream. By restricting the number of in-flight all-gathers we get about a 3x performance improvement (6-8 seconds/step -> 1.5-3 seconds/step).
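For reference, a minimal sketch of passing this flag to raw PyTorch FSDP, assuming a process group has already been initialized (e.g. via torchrun); the scripts in this repo wire FSDP up through their own launcher, so this is illustrative only.

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Placeholder module; assumes torch.distributed is already initialized.
model = torch.nn.Linear(4096, 4096).cuda()

# limit_all_gathers=True rate-limits FSDP's parameter prefetching so only a
# bounded number of all-gathered shards are in flight at once, which keeps the
# communication stream from hoarding blocks in the CUDACachingAllocator.
sharded_model = FSDP(model, limit_all_gathers=True)
```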