diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst
index 3a2691880a20c..19153919f8da7 100644
--- a/docs/source-pytorch/common/lightning_module.rst
+++ b/docs/source-pytorch/common/lightning_module.rst
@@ -85,38 +85,42 @@ Here are the only required methods.
 .. code-block:: python
 
     import lightning.pytorch as pl
-    import torch.nn as nn
-    import torch.nn.functional as F
+    import torch
+    from lightning.pytorch.demos import Transformer
 
 
-    class LitModel(pl.LightningModule):
-        def __init__(self):
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
             super().__init__()
-            self.l1 = nn.Linear(28 * 28, 10)
+            self.model = Transformer(vocab_size=vocab_size)
 
-        def forward(self, x):
-            return torch.relu(self.l1(x.view(x.size(0), -1)))
+        def forward(self, inputs, target):
+            return self.model(inputs, target)
 
         def training_step(self, batch, batch_idx):
-            x, y = batch
-            y_hat = self(x)
-            loss = F.cross_entropy(y_hat, y)
+            inputs, target = batch
+            output = self(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
             return loss
 
         def configure_optimizers(self):
-            return torch.optim.Adam(self.parameters(), lr=0.02)
+            return torch.optim.SGD(self.model.parameters(), lr=0.1)
 
 Which you can train by doing:
 
 .. code-block:: python
 
-    train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
-    trainer = pl.Trainer(max_epochs=1)
-    model = LitModel()
+    from torch.utils.data import DataLoader
+
+    dataset = pl.demos.WikiText2()
+    dataloader = DataLoader(dataset)
+    model = LightningTransformer(vocab_size=dataset.vocab_size)
 
-    trainer.fit(model, train_dataloaders=train_loader)
+    trainer = pl.Trainer(fast_dev_run=100)
+    trainer.fit(model=model, train_dataloaders=dataloader)
 
-The LightningModule has many convenience methods, but the core ones you need to know about are:
+The LightningModule has many convenient methods, but the core ones you need to know about are:
 
 .. list-table::
    :widths: 50 50
@@ -152,15 +156,15 @@ To activate the training loop, override the :meth:`~lightning.pytorch.core.modul
 
 .. code-block:: python
 
-    class LitClassifier(pl.LightningModule):
-        def __init__(self, model):
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
             super().__init__()
-            self.model = model
+            self.model = Transformer(vocab_size=vocab_size)
 
         def training_step(self, batch, batch_idx):
-            x, y = batch
-            y_hat = self.model(x)
-            loss = F.cross_entropy(y_hat, y)
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
             return loss
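+
+``training_step`` does not have to return a bare loss tensor. As a minimal sketch (the extra ``"batch_size"``
+entry below is only illustrative), it can also return a dictionary, as long as the loss is stored under the
+``"loss"`` key:
+
+.. code-block:: python
+
+    def training_step(self, batch, batch_idx):
+        inputs, target = batch
+        output = self.model(inputs, target)
+        loss = torch.nn.functional.nll_loss(output, target.view(-1))
+        # whatever is returned here is also made available to callbacks (e.g. on_train_batch_end)
+        return {"loss": loss, "batch_size": target.size(0)}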
 
 Under the hood, Lightning does the following (pseudocode):
@@ -191,15 +195,15 @@ If you want to calculate epoch-level metrics and log them, use :meth:`~lightning
 
 .. code-block:: python
 
-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self.model(x)
-        loss = F.cross_entropy(y_hat, y)
+    def training_step(self, batch, batch_idx):
+        inputs, target = batch
+        output = self.model(inputs, target)
+        loss = torch.nn.functional.nll_loss(output, target.view(-1))
 
-        # logs metrics for each training_step,
-        # and the average across the epoch, to the progress bar and logger
-        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
-        return loss
+        # logs metrics for each training_step,
+        # and the average across the epoch, to the progress bar and logger
+        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        return loss
 
 The :meth:`~lightning.pytorch.core.module.LightningModule.log` method automatically reduces the
 requested metrics across a complete epoch and devices. Here's the pseudocode of what it does under the hood:
@@ -230,25 +234,25 @@ override the :meth:`~lightning.pytorch.LightningModule.on_train_epoch_end` metho
 
 .. code-block:: python
 
-    def __init__(self):
-        super().__init__()
-        self.training_step_outputs = []
-
-
-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self.model(x)
-        loss = F.cross_entropy(y_hat, y)
-        preds = ...
-        self.training_step_outputs.append(preds)
-        return loss
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
+            super().__init__()
+            self.model = Transformer(vocab_size=vocab_size)
+            self.training_step_outputs = []
+
+        def training_step(self, batch, batch_idx):
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
+            preds = ...
+            self.training_step_outputs.append(preds)
+            return loss
 
-    def on_train_epoch_end(self):
-        all_preds = torch.stack(self.training_step_outputs)
-        # do something with all preds
-        ...
-        self.training_step_outputs.clear()  # free memory
+        def on_train_epoch_end(self):
+            all_preds = torch.stack(self.training_step_outputs)
+            # do something with all preds
+            ...
+            self.training_step_outputs.clear()  # free memory
 
 ------------------
@@ -264,10 +268,10 @@ To activate the validation loop while training, override the :meth:`~lightning.p
 
 .. code-block:: python
 
-    class LitModel(pl.LightningModule):
+    class LightningTransformer(pl.LightningModule):
         def validation_step(self, batch, batch_idx):
-            x, y = batch
-            y_hat = self.model(x)
-            loss = F.cross_entropy(y_hat, y)
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
             self.log("val_loss", loss)
 
@@ -300,8 +304,8 @@ and calling :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`.
 
 .. code-block:: python
 
-    model = Model()
-    trainer = Trainer()
+    model = LightningTransformer(vocab_size=dataset.vocab_size)
+    trainer = pl.Trainer()
     trainer.validate(model)
 
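+If the LightningModule does not define a ``val_dataloader``, the validation data has to be provided to the
+trainer directly. A minimal sketch, reusing the ``dataset`` created in the earlier snippets:
+
+.. code-block:: python
+
+    from torch.utils.data import DataLoader
+
+    trainer.validate(model, dataloaders=DataLoader(dataset))
+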
 .. note::
@@ -322,25 +326,26 @@ Note that this method is called before :meth:`~lightning.pytorch.LightningModule
 
 .. code-block:: python
 
-    def __init__(self):
-        super().__init__()
-        self.validation_step_outputs = []
-
-
-    def validation_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self.model(x)
-        loss = F.cross_entropy(y_hat, y)
-        pred = ...
-        self.validation_step_outputs.append(pred)
-        return pred
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
+            super().__init__()
+            self.model = Transformer(vocab_size=vocab_size)
+            self.validation_step_outputs = []
+
+        def validation_step(self, batch, batch_idx):
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
+            pred = ...
+            self.validation_step_outputs.append(pred)
+            return pred
 
-    def on_validation_epoch_end(self):
-        all_preds = torch.stack(self.validation_step_outputs)
-        # do something with all preds
-        ...
-        self.validation_step_outputs.clear()  # free memory
+        def on_validation_epoch_end(self):
+            all_preds = torch.stack(self.validation_step_outputs)
+            # do something with all preds
+            ...
+            self.validation_step_outputs.clear()  # free memory
 
 ----------------
@@ -358,9 +363,10 @@ The only difference is that the test loop is only called when :meth:`~lightning.
 
 .. code-block:: python
 
-    model = Model()
-    trainer = Trainer()
-    trainer.fit(model)
+    model = LightningTransformer(vocab_size=dataset.vocab_size)
+    dataloader = DataLoader(dataset)
+    trainer = pl.Trainer()
+    trainer.fit(model=model, train_dataloaders=dataloader)
 
     # automatically loads the best weights for you
     trainer.test(model)
@@ -370,17 +376,23 @@ There are two ways to call ``test()``:
 
 .. code-block:: python
 
     # call after training
-    trainer = Trainer()
-    trainer.fit(model)
+    trainer = pl.Trainer()
+    trainer.fit(model=model, train_dataloaders=dataloader)
 
     # automatically auto-loads the best weights from the previous run
-    trainer.test(dataloaders=test_dataloader)
+    trainer.test(dataloaders=test_dataloaders)
 
     # or call with pretrained model
-    model = MyLightningModule.load_from_checkpoint(PATH)
-    trainer = Trainer()
+    model = LightningTransformer.load_from_checkpoint(PATH)
+    dataset = pl.demos.WikiText2()
+    test_dataloader = DataLoader(dataset)
+    trainer = pl.Trainer()
     trainer.test(model, dataloaders=test_dataloader)
 
+.. note::
+    ``WikiText2`` is used in a manner that does not create a train, test, val split. This is done for illustrative purposes only.
+    A proper split can be created in :meth:`lightning.pytorch.core.LightningModule.setup` or :meth:`lightning.pytorch.core.LightningDataModule.setup`.
+
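+One possible way to create such a split is sketched below (only the data-related hooks are shown; the split
+fractions and the seed are arbitrary, and the same idea works inside a ``LightningDataModule``):
+
+.. code-block:: python
+
+    from torch.utils.data import DataLoader, random_split
+
+
+    class LightningTransformer(pl.LightningModule):
+        def setup(self, stage):
+            dataset = pl.demos.WikiText2()
+            # illustrative 80/10/10 split with a fixed seed for reproducibility
+            generator = torch.Generator().manual_seed(42)
+            self.train_set, self.val_set, self.test_set = random_split(
+                dataset, [0.8, 0.1, 0.1], generator=generator
+            )
+
+        def train_dataloader(self):
+            return DataLoader(self.train_set)
+
+        def val_dataloader(self):
+            return DataLoader(self.val_set)
+
+        def test_dataloader(self):
+            return DataLoader(self.test_set)
+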
 .. note::
     It is recommended to validate on single device to ensure each sample/batch gets evaluated exactly once.
@@ -403,24 +415,18 @@ By default, the :meth:`~lightning.pytorch.core.module.LightningModule.predict_st
 :meth:`~lightning.pytorch.core.module.LightningModule.forward` method. In order to customize this behaviour,
 simply override the :meth:`~lightning.pytorch.core.module.LightningModule.predict_step` method.
 
-For the example let's override ``predict_step`` and try out `Monte Carlo Dropout `_:
+For the example let's override ``predict_step``:
 
 .. code-block:: python
 
-    class LitMCdropoutModel(pl.LightningModule):
-        def __init__(self, model, mc_iteration):
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
             super().__init__()
-            self.model = model
-            self.dropout = nn.Dropout()
-            self.mc_iteration = mc_iteration
-
-        def predict_step(self, batch, batch_idx):
-            # enable Monte Carlo Dropout
-            self.dropout.train()
+            self.model = Transformer(vocab_size=vocab_size)
 
-            # take average of `self.mc_iteration` iterations
-            pred = torch.vstack([self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]).mean(dim=0)
-            return pred
+        def predict_step(self, batch):
+            inputs, target = batch
+            return self.model(inputs, target)
 
 Under the hood, Lightning does the following (pseudocode):
@@ -440,15 +446,17 @@ There are two ways to call ``predict()``:
 
 .. code-block:: python
 
     # call after training
-    trainer = Trainer()
-    trainer.fit(model)
+    trainer = pl.Trainer()
+    trainer.fit(model=model, train_dataloaders=dataloader)
 
     # automatically auto-loads the best weights from the previous run
    predictions = trainer.predict(dataloaders=predict_dataloader)
 
     # or call with pretrained model
-    model = MyLightningModule.load_from_checkpoint(PATH)
-    trainer = Trainer()
+    model = LightningTransformer.load_from_checkpoint(PATH)
+    dataset = pl.demos.WikiText2()
+    test_dataloader = DataLoader(dataset)
+    trainer = pl.Trainer()
     predictions = trainer.predict(model, dataloaders=test_dataloader)
 
 Inference in Research
@@ -460,15 +468,31 @@ If you want to perform inference with the system, you can add a ``forward`` meth
 
 .. code-block:: python
 
-    class Autoencoder(pl.LightningModule):
-        def forward(self, x):
-            return self.decoder(x)
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
+            super().__init__()
+            self.model = Transformer(vocab_size=vocab_size)
+
+        def forward(self, batch):
+            inputs, target = batch
+            return self.model(inputs, target)
+
+        def training_step(self, batch, batch_idx):
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
+            return loss
+
+        def configure_optimizers(self):
+            return torch.optim.SGD(self.model.parameters(), lr=0.1)
+
+
+    model = LightningTransformer(vocab_size=dataset.vocab_size)
 
-    model = Autoencoder()
     model.eval()
     with torch.no_grad():
-        reconstruction = model(embedding)
+        batch = dataloader.dataset[0]
+        pred = model(batch)
 
 The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure, such as text generation:
@@ -618,7 +642,7 @@ checkpoint, which simplifies model re-instantiation after training.
 
 .. code-block:: python
 
-    class LitMNIST(LightningModule):
+    class LitMNIST(pl.LightningModule):
         def __init__(self, layer_1_dim=128, learning_rate=1e-2):
             super().__init__()
             # call this to save (layer_1_dim=128, learning_rate=1e-4) to the checkpoint
@@ -642,7 +666,7 @@ parameters should be provided back when reloading the LightningModule. In this c
 
 .. code-block:: python
 
-    class LitMNIST(LightningModule):
+    class LitMNIST(pl.LightningModule):
         def __init__(self, loss_fx, generator_network, layer_1_dim=128):
             super().__init__()
             self.layer_1_dim = layer_1_dim