From c3c43b6a911f3d6be5b5b3a16e844202647060e5 Mon Sep 17 00:00:00 2001
From: Taylor Robie
Date: Wed, 21 Jun 2023 11:14:54 -0700
Subject: [PATCH] Small changes to reduce peak memory. (#389)

Co-authored-by: Taylor Robie
Co-authored-by: Luca Antiga
---
 finetune/full.py        | 4 ++--
 pretrain/redpajama.py   | 3 ++-
 pretrain/shakespeare.py | 4 ++--
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/finetune/full.py b/finetune/full.py
index bf94da49..9248e8de 100644
--- a/finetune/full.py
+++ b/finetune/full.py
@@ -55,7 +55,7 @@ def main(
 ):
 
     auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
-    strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block)
+    strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True)
 
     fabric = L.Fabric(accelerator="cuda", devices=devices, precision="bf16-mixed", strategy=strategy)
     fabric.launch()
@@ -79,7 +79,7 @@ def main(
 
     model = fabric.setup_module(model)
 
-    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, foreach=False)
     optimizer = fabric.setup_optimizers(optimizer)
 
     train(fabric, model, optimizer, train_data, val_data, out_dir)
diff --git a/pretrain/redpajama.py b/pretrain/redpajama.py
index 16942873..97ebde28 100644
--- a/pretrain/redpajama.py
+++ b/pretrain/redpajama.py
@@ -69,7 +69,7 @@ def main(
         transformer_auto_wrap_policy, transformer_layer_cls={Block}
     )
     strategy = FSDPStrategy(
-        auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block
+        auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True
     )
 
     fabric = L.Fabric(
@@ -110,6 +110,7 @@ def main(
         lr=learning_rate,
         weight_decay=weight_decay,
         betas=(beta1, beta2),
+        foreach=False,
     )
 
     model, optimizer = fabric.setup(model, optimizer)
diff --git a/pretrain/shakespeare.py b/pretrain/shakespeare.py
index ae27d266..9daa064c 100644
--- a/pretrain/shakespeare.py
+++ b/pretrain/shakespeare.py
@@ -47,7 +47,7 @@ def main() -> None:
     auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
-    strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block)
+    strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True)
 
     fabric = L.Fabric(accelerator="cuda", devices=4, precision="bf16-mixed", strategy=strategy)
     fabric.launch()
@@ -70,7 +70,7 @@ def main() -> None:
 
     model = fabric.setup_module(model)
 
-    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2))
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False)
     optimizer = fabric.setup_optimizers(optimizer)
 
     train(fabric, model, optimizer, train_data, val_data)
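
For context, a minimal standalone sketch of the two knobs this patch flips is included below. It is not part of the patch: it substitutes a plain nn.TransformerEncoderLayer for lit-llama's Block so it runs without the repo, and the layer sizes, device count, and learning rate are illustrative; only the flag names and values (limit_all_gathers=True on FSDPStrategy, foreach=False on AdamW) come from the diff above.

# Standalone sketch of the memory-saving flags added by this patch.
# Assumes Lightning Fabric >= 2.0 and a multi-GPU CUDA machine; the model and
# hyperparameters are placeholders, not the ones used in the training scripts.
from functools import partial

import torch
import torch.nn as nn
import lightning as L
from lightning.fabric.strategies import FSDPStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Wrap each transformer layer as its own FSDP unit (the scripts use lit-llama's
# Block here; nn.TransformerEncoderLayer stands in so the sketch is self-contained).
auto_wrap_policy = partial(
    transformer_auto_wrap_policy, transformer_layer_cls={nn.TransformerEncoderLayer}
)

# limit_all_gathers=True throttles FSDP's all-gather prefetching so fewer
# unsharded parameter copies are alive at the same time, lowering peak GPU memory.
strategy = FSDPStrategy(
    auto_wrap_policy=auto_wrap_policy,
    activation_checkpointing=nn.TransformerEncoderLayer,
    limit_all_gathers=True,
)

fabric = L.Fabric(accelerator="cuda", devices=2, precision="bf16-mixed", strategy=strategy)
fabric.launch()

model = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=512, nhead=8), num_layers=4)
model = fabric.setup_module(model)

# foreach=False disables AdamW's multi-tensor path, which batches parameters into
# large temporary buffers at step time; per-tensor updates are somewhat slower but
# avoid that transient allocation spike.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, foreach=False)
optimizer = fabric.setup_optimizers(optimizer)

Both flags trade a little throughput for a lower peak allocation, which is why the patch applies them uniformly to the full finetuning and both pretraining scripts.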