Skip to content

Commit

Permalink
Update docs/source-pytorch/common/lightning_module.rst (#18451)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
jxtngx and Borda authored Sep 8, 2023
1 parent 4512265 commit a013386
Showing 1 changed file with 125 additions and 101 deletions.
226 changes: 125 additions & 101 deletions docs/source-pytorch/common/lightning_module.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,38 +85,42 @@ Here are the only required methods.
.. code-block:: python
import lightning.pytorch as pl
import torch.nn as nn
import torch.nn.functional as F
import torch
from lightning.pytorch.demos import Transformer
class LitModel(pl.LightningModule):
def __init__(self):
class LightningTransformer(pl.LightningModule):
def __init__(self, vocab_size):
super().__init__()
self.l1 = nn.Linear(28 * 28, 10)
self.model = Transformer(vocab_size=vocab_size)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
def forward(self, inputs, target):
return self.model(inputs, target)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
inputs, target = batch
output = self(inputs, target)
loss = torch.nn.functional.nll_loss(output, target.view(-1))
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)
return torch.optim.SGD(self.model.parameters(), lr=0.1)
Which you can train by doing:

.. code-block:: python
train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
trainer = pl.Trainer(max_epochs=1)
model = LitModel()
from torch.utils.data import DataLoader
dataset = pl.demos.WikiText2()
dataloader = DataLoader(dataset)
model = LightningTransformer(vocab_size=dataset.vocab_size)
trainer.fit(model, train_dataloaders=train_loader)
trainer = pl.Trainer(fast_dev_run=100)
trainer.fit(model=model, train_dataloaders=dataloader)
The LightningModule has many convenience methods, but the core ones you need to know about are:
The LightningModule has many convenient methods, but the core ones you need to know about are:

.. list-table::
:widths: 50 50
Expand Down Expand Up @@ -152,15 +156,15 @@ To activate the training loop, override the :meth:`~lightning.pytorch.core.modul

.. code-block:: python
class LitClassifier(pl.LightningModule):
def __init__(self, model):
class LightningTransformer(pl.LightningModule):
def __init__(self, vocab_size):
super().__init__()
self.model = model
self.model = Transformer(vocab_size=vocab_size)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
inputs, target = batch
output = self.model(inputs, target)
loss = torch.nn.functional.nll_loss(output, target.view(-1))
return loss
Under the hood, Lightning does the following (pseudocode):
Expand Down Expand Up @@ -191,15 +195,15 @@ If you want to calculate epoch-level metrics and log them, use :meth:`~lightning

.. code-block:: python
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
def training_step(self, batch, batch_idx):
inputs, target = batch
output = self.model(inputs, target)
loss = torch.nn.functional.nll_loss(output, target.view(-1))
# logs metrics for each training_step,
# and the average across the epoch, to the progress bar and logger
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss
# logs metrics for each training_step,
# and the average across the epoch, to the progress bar and logger
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss
The :meth:`~lightning.pytorch.core.module.LightningModule.log` method automatically reduces the
requested metrics across a complete epoch and devices. Here's the pseudocode of what it does under the hood:
Expand Down Expand Up @@ -230,25 +234,25 @@ override the :meth:`~lightning.pytorch.LightningModule.on_train_epoch_end` metho

.. code-block:: python
def __init__(self):
super().__init__()
self.training_step_outputs = []
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
preds = ...
self.training_step_outputs.append(preds)
return loss
class LightningTransformer(pl.LightningModule):
def __init__(self, vocab_size):
super().__init__()
self.model = Transformer(vocab_size=vocab_size)
self.training_step_outputs = []
def training_step(self, batch, batch_idx):
inputs, target = batch
output = self.model(inputs, target)
loss = torch.nn.functional.nll_loss(output, target.view(-1))
preds = ...
self.training_step_outputs.append(preds)
return loss
def on_train_epoch_end(self):
all_preds = torch.stack(self.training_step_outputs)
# do something with all preds
...
self.training_step_outputs.clear() # free memory
def on_train_epoch_end(self):
all_preds = torch.stack(self.training_step_outputs)
# do something with all preds
...
self.training_step_outputs.clear() # free memory
------------------
Expand All @@ -264,10 +268,10 @@ To activate the validation loop while training, override the :meth:`~lightning.p

.. code-block:: python
class LitModel(pl.LightningModule):
class LightningTransformer(pl.LightningModule):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
inputs, target = batch
output = self.model(inputs, target)
loss = F.cross_entropy(y_hat, y)
self.log("val_loss", loss)
Expand Down Expand Up @@ -300,8 +304,8 @@ and calling :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`.

.. code-block:: python
model = Model()
trainer = Trainer()
model = LightningTransformer(vocab_size=dataset.vocab_size)
trainer = pl.Trainer()
trainer.validate(model)
.. note::
Expand All @@ -322,25 +326,26 @@ Note that this method is called before :meth:`~lightning.pytorch.LightningModule

.. code-block:: python
def __init__(self):
super().__init__()
self.validation_step_outputs = []
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
pred = ...
self.validation_step_outputs.append(pred)
return pred
class LightningTransformer(pl.LightningModule):
def __init__(self, vocab_size):
super().__init__()
self.model = Transformer(vocab_size=vocab_size)
self.validation_step_outputs = []
def validation_step(self, batch, batch_idx):
x, y = batch
inputs, target = batch
output = self.model(inputs, target)
loss = torch.nn.functional.nll_loss(output, target.view(-1))
pred = ...
self.validation_step_outputs.append(pred)
return pred
def on_validation_epoch_end(self):
all_preds = torch.stack(self.validation_step_outputs)
# do something with all preds
...
self.validation_step_outputs.clear() # free memory
def on_validation_epoch_end(self):
all_preds = torch.stack(self.validation_step_outputs)
# do something with all preds
...
self.validation_step_outputs.clear() # free memory
----------------

Expand All @@ -358,9 +363,10 @@ The only difference is that the test loop is only called when :meth:`~lightning.

.. code-block:: python
model = Model()
trainer = Trainer()
trainer.fit(model)
model = LightningTransformer(vocab_size=dataset.vocab_size)
dataloader = DataLoader(dataset)
trainer = pl.Trainer()
trainer.fit(model=model, train_dataloaders=dataloader)
# automatically loads the best weights for you
trainer.test(model)
Expand All @@ -370,17 +376,23 @@ There are two ways to call ``test()``:
.. code-block:: python
# call after training
trainer = Trainer()
trainer.fit(model)
trainer = pl.Trainer()
trainer.fit(model=model, train_dataloaders=dataloader)
# automatically auto-loads the best weights from the previous run
trainer.test(dataloaders=test_dataloader)
trainer.test(dataloaders=test_dataloaders)
# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
model = LightningTransformer.load_from_checkpoint(PATH)
dataset = WikiText2()
test_dataloader = DataLoader(dataset)
trainer = pl.Trainer()
trainer.test(model, dataloaders=test_dataloader)
.. note::
`WikiText2` is used in a manner that does not create a train, test, val split. This is done for illustrative purposes only.
A proper split can be created in :meth:`lightning.pytorch.core.LightningModule.setup` or :meth:`lightning.pytorch.core.LightningDataModule.setup`.

.. note::

It is recommended to validate on single device to ensure each sample/batch gets evaluated exactly once.
Expand All @@ -403,24 +415,18 @@ By default, the :meth:`~lightning.pytorch.core.module.LightningModule.predict_st
:meth:`~lightning.pytorch.core.module.LightningModule.forward` method. In order to customize this behaviour,
simply override the :meth:`~lightning.pytorch.core.module.LightningModule.predict_step` method.

For the example let's override ``predict_step`` and try out `Monte Carlo Dropout <https://arxiv.org/pdf/1506.02142.pdf>`_:
For the example let's override ``predict_step``:

.. code-block:: python
class LitMCdropoutModel(pl.LightningModule):
def __init__(self, model, mc_iteration):
class LightningTransformer(pl.LightningModule):
def __init__(self, vocab_size):
super().__init__()
self.model = model
self.dropout = nn.Dropout()
self.mc_iteration = mc_iteration
def predict_step(self, batch, batch_idx):
# enable Monte Carlo Dropout
self.dropout.train()
self.model = Transformer(vocab_size=vocab_size)
# take average of `self.mc_iteration` iterations
pred = torch.vstack([self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]).mean(dim=0)
return pred
def predict_step(self, batch):
inputs, target = batch
return self.model(inputs, target)
Under the hood, Lightning does the following (pseudocode):

Expand All @@ -440,15 +446,17 @@ There are two ways to call ``predict()``:
.. code-block:: python
# call after training
trainer = Trainer()
trainer.fit(model)
trainer = pl.Trainer()
trainer.fit(model=model, train_dataloaders=dataloader)
# automatically auto-loads the best weights from the previous run
predictions = trainer.predict(dataloaders=predict_dataloader)
# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
model = LightningTransformer.load_from_checkpoint(PATH)
dataset = pl.demos.WikiText2()
test_dataloader = DataLoader(dataset)
trainer = pl.Trainer()
predictions = trainer.predict(model, dataloaders=test_dataloader)
Inference in Research
Expand All @@ -460,15 +468,31 @@ If you want to perform inference with the system, you can add a ``forward`` meth

.. code-block:: python
class Autoencoder(pl.LightningModule):
def forward(self, x):
return self.decoder(x)
class LightningTransformer(pl.LightningModule):
def __init__(self, vocab_size):
super().__init__()
self.model = Transformer(vocab_size=vocab_size)
def forward(self, batch):
inputs, target = batch
return self.model(inputs, target)
def training_step(self, batch, batch_idx):
inputs, target = batch
output = self.model(inputs, target)
loss = torch.nn.functional.nll_loss(output, target.view(-1))
return loss
def configure_optimizers(self):
return torch.optim.SGD(self.model.parameters(), lr=0.1)
model = LightningTransformer(vocab_size=dataset.vocab_size)
model = Autoencoder()
model.eval()
with torch.no_grad():
reconstruction = model(embedding)
batch = dataloader.dataset[0]
pred = model(batch)
The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure,
such as text generation:
Expand Down Expand Up @@ -618,7 +642,7 @@ checkpoint, which simplifies model re-instantiation after training.

.. code-block:: python
class LitMNIST(LightningModule):
class LitMNIST(pl.LightningModule):
def __init__(self, layer_1_dim=128, learning_rate=1e-2):
super().__init__()
# call this to save (layer_1_dim=128, learning_rate=1e-4) to the checkpoint
Expand All @@ -642,7 +666,7 @@ parameters should be provided back when reloading the LightningModule. In this c

.. code-block:: python
class LitMNIST(LightningModule):
class LitMNIST(pl.LightningModule):
def __init__(self, loss_fx, generator_network, layer_1_dim=128):
super().__init__()
self.layer_1_dim = layer_1_dim
Expand Down

0 comments on commit a013386

Please sign in to comment.