How to save and load checkpoints using the DeepSpeed plugin (stage 3)? #9321
Unanswered
yidong72 asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Here is a small example. Run it twice: once without modifications, and a second time after increasing `max_epochs` and uncommenting the `resume_from_checkpoint` line:

```python
import torch
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import DeepSpeedPlugin


class RandomDataset(Dataset):
    """Fixed random tensors, just enough data to exercise the Trainer."""

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    # Checkpoint names come from the epoch index, e.g. checkpoints/epoch=00.ckpt.
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints/",
        filename="{epoch:02d}",
    )
    trainer = Trainer(
        # Second run: uncomment to resume from the first run's checkpoint
        # ("{epoch:02d}" turns epoch 0 into "epoch=00.ckpt").
        # resume_from_checkpoint="checkpoints/epoch=00.ckpt",
        max_epochs=1,  # increase when resuming
        gpus=2,
        accelerator="ddp",
        plugins=[DeepSpeedPlugin(stage=3)],
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        precision=16,
        weights_summary=None,
        callbacks=[checkpoint_callback],
    )
    # Pass the dataloader positionally so the call works whether the keyword
    # is `train_dataloader` or `train_dataloaders` in your Lightning version.
    trainer.fit(model, train_data)


if __name__ == "__main__":
    run()
```
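The path passed to `resume_from_checkpoint` has to match the name `ModelCheckpoint` actually wrote. The function below is a simplified, hypothetical mimic of how a template such as `"{epoch:02d}"` expands (it is not Lightning's implementation): it assumes each `{name...}` field gets a `name=` prefix, the values are filled in with `str.format`, and `.ckpt` is appended. Under those assumptions, the first run (`max_epochs=1`, so only epoch index 0) leaves `checkpoints/epoch=00.ckpt`.

```python
import re


def format_checkpoint_name(template: str, metrics: dict) -> str:
    """Simplified, hypothetical mimic of ModelCheckpoint filename formatting.

    Assumes the metric name is auto-inserted before each "{...}" field,
    e.g. "{epoch:02d}" -> "epoch={epoch:02d}", and that ".ckpt" is appended.
    """
    # Prefix every "{name<format spec>}" field with "name=".
    named = re.sub(r"\{(\w+)([^}]*)\}", r"\1={\1\2}", template)
    # Fill in the values with ordinary str.format and add the extension.
    return named.format(**metrics) + ".ckpt"


if __name__ == "__main__":
    # First run: max_epochs=1, so the last (and only) epoch index is 0.
    print(format_checkpoint_name("{epoch:02d}", {"epoch": 0}))  # epoch=00.ckpt
    # A run ending at epoch index 9 would write epoch=09.ckpt.
    print(format_checkpoint_name("{epoch:02d}", {"epoch": 9}))  # epoch=09.ckpt
```

Because of the zero padding, this template never produces a name like `epoch=9.ckpt`. With DeepSpeed stage 3, the object at that path may also be a directory of shard files rather than a single file.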
---
I have been struggling to figure out how to save and load my model with the DeepSpeed plugin, and I cannot find any examples of how to do it.
Here is how I set up the plugin:
I use the `ModelCheckpoint` callback to save the checkpoints. It generates either a single checkpoint file or a directory of `pt` files, depending on whether `save_full_weights` is true or false. However, I don't know how to load the checkpoint files. I tried both `Model.load_from_checkpoint` and `Trainer(resume_from_checkpoint=)`, and neither works for me. I got `AttributeError: 'NoneType' object has no attribute 'trainer'` and `Default process group has not been initialized, please make sure to call init_process_group.` errors. Could you show me a working example? Thanks.
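The question describes two on-disk layouts: a single checkpoint file or a directory of `.pt` files. Before choosing a loading route, it can help to check which one you actually have. The helper below is hypothetical (not a Lightning or DeepSpeed API); it searches the directory recursively because the exact sub-folder layout of the shard files is an assumption that may differ between DeepSpeed versions.

```python
import os
from typing import List, Tuple


def checkpoint_layout(path: str) -> Tuple[str, List[str]]:
    """Hypothetical helper: classify a checkpoint written by ModelCheckpoint.

    Returns ("sharded", [.pt paths relative to `path`]) for a directory
    checkpoint and ("single_file", [path]) for a consolidated file.
    """
    if os.path.isdir(path):
        shards = []
        # Walk recursively: shard files may sit in a sub-folder (assumption).
        for root, _dirs, files in os.walk(path):
            for name in files:
                if name.endswith(".pt"):
                    shards.append(os.path.relpath(os.path.join(root, name), path))
        return "sharded", sorted(shards)
    if os.path.isfile(path):
        return "single_file", [path]
    raise FileNotFoundError(path)
```

A `"sharded"` result points toward resuming through `Trainer(resume_from_checkpoint=...)` with the same DeepSpeed stage-3 configuration, as in the example above, since `load_from_checkpoint` expects a single file (it goes through `torch.load`). Newer Lightning releases also provide `pytorch_lightning.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict` for consolidating a sharded checkpoint into one file; check that your installed version includes it.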