Skip to content

Commit

Permalink
Small changes to reduce peak memory. (#389)
Browse files Browse the repository at this point in the history
Co-authored-by: Taylor Robie <[email protected]>
Co-authored-by: Luca Antiga <[email protected]>
  • Loading branch information
3 people authored Jun 21, 2023
1 parent 6d2c5ca commit c3c43b6
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions finetune/full.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def main(
):

auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
- strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block)
+ strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True)

fabric = L.Fabric(accelerator="cuda", devices=devices, precision="bf16-mixed", strategy=strategy)
fabric.launch()
Expand All @@ -79,7 +79,7 @@ def main(

model = fabric.setup_module(model)

- optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, foreach=False)
optimizer = fabric.setup_optimizers(optimizer)

train(fabric, model, optimizer, train_data, val_data, out_dir)
Expand Down
3 changes: 2 additions & 1 deletion pretrain/redpajama.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def main(
transformer_auto_wrap_policy, transformer_layer_cls={Block}
)
strategy = FSDPStrategy(
- auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block
+ auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True
)

fabric = L.Fabric(
Expand Down Expand Up @@ -110,6 +110,7 @@ def main(
lr=learning_rate,
weight_decay=weight_decay,
betas=(beta1, beta2),
+ foreach=False,
)

model, optimizer = fabric.setup(model, optimizer)
Expand Down
4 changes: 2 additions & 2 deletions pretrain/shakespeare.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

def main() -> None:
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
- strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block)
+ strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True)

fabric = L.Fabric(accelerator="cuda", devices=4, precision="bf16-mixed", strategy=strategy)
fabric.launch()
Expand All @@ -70,7 +70,7 @@ def main() -> None:

model = fabric.setup_module(model)

- optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2))
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False)
optimizer = fabric.setup_optimizers(optimizer)

train(fabric, model, optimizer, train_data, val_data)
Expand Down

0 comments on commit c3c43b6

Please sign in to comment.