Small changes to reduce peak memory. #389

Merged: 2 commits merged into main on Jun 21, 2023
Conversation

@robieta (Contributor) commented on Jun 13, 2023

LLaMA 7B is very close to the limit for sharding across 4x40GB cards, and several recent library changes have pushed it over the limit, so it now OOMs when run. (Notably, it turns out we were cheating and using a bit less memory than the correct implementation.) This PR introduces two small changes that prevent Shakespeare pretraining from OOMing and greatly improve its performance:

  1. foreach=False: Some PyTorch optimizers can group parameter updates, which generally improves performance by reducing the number of kernel launches and, consequently, host latency. However, this grouping increases the peak memory footprint: instead of allocating and freeing a size-X tensor k times, the grouped path allocates and frees a single k * X-sized tensor once. Setting foreach=False disables the grouping and lowers peak memory.

  2. limit_all_gathers=True: Breaking up the optimizer step brings the logical memory below the OOM threshold, but memory statistics still show reserved memory maxing out. The reason is that FSDP is overzealous in launching all-gathers in an attempt to overlap as much communication and compute as possible. The trouble is that all of those concurrent in-flight requests take up memory; more specifically, the CUDACachingAllocator caches per stream, so it is not straightforward for it to reclaim memory from the communication stream and reuse it in the compute stream. By restricting the number of in-flight all-gathers we get about a 3x performance improvement (6-8 seconds/step -> 1.5-3 seconds/step). A sketch of how both settings are wired up follows below.
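This is a minimal illustration rather than the actual diff: it assumes a Lightning Fabric + FSDP setup along the lines of pretrain/shakespeare.py, and the model, device count, and hyperparameters are placeholders.

```python
import torch
import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

# limit_all_gathers is forwarded to torch.distributed.fsdp.FullyShardedDataParallel
# and caps how many all-gathers FSDP keeps in flight at once.
strategy = FSDPStrategy(limit_all_gathers=True)
fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy)
fabric.launch()

# Placeholder model standing in for the GPT model used by the script.
model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096))
model = fabric.setup_module(model)

# foreach=False disables the grouped ("foreach") update path in the optimizer,
# trading a few extra kernel launches for a lower peak memory footprint.
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4, foreach=False)
optimizer = fabric.setup_optimizers(optimizer)
```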

@rasbt (Contributor) commented on Jun 13, 2023

This is actually really nice, @robieta! Peak memory is one of the big issues in practice when finetuning LLaMA 65B / Falcon 40B as well.

Once this is merged, we should also apply this to lit-parrot!

Two review threads on pretrain/shakespeare.py were marked resolved.
@lantiga (Collaborator) commented on Jun 21, 2023

@carmocca definitely, let's add the same fixes to all scripts that use an optimizer and FSDP.

@lantiga (Collaborator) commented on Jun 21, 2023

OK, I added the same tweaks to full finetuning and redpajama pretraining, which are the two cases where the optimizer state is chunkier. For adapter and LoRA I'm not expecting this to be a problem; @robieta, please keep me honest here.

@carmocca (Contributor) left a review comment

I can confirm the foreach observation from experiments in lit-gpt; I haven't tested limit_all_gathers yet.

@carmocca (Contributor) commented

Whoa, I just ran those and I'm seeing a huge speedup (95%!) by setting limit_all_gathers for adapter too. So we should totally do it.

@carmocca (Contributor) commented

The finetuning scripts in this repo still use DeepSpeed, unlike lit-gpt, where I switched it to FSDP: Lightning-AI/litgpt#118. So we can merge this.

@carmocca merged commit c3c43b6 into main on Jun 21, 2023.
@carmocca deleted the robieta/llama_oom branch on Jun 21, 2023 at 18:14.