Update docs/source-pytorch/common/lightning_module.rst #18451

Merged · 18 commits · Sep 8, 2023

Changes from all commits
226 changes: 125 additions & 101 deletions docs/source-pytorch/common/lightning_module.rst
Here are the only required methods.
.. code-block:: python

    import lightning.pytorch as pl
    import torch

    from lightning.pytorch.demos import Transformer


    class LightningTransformer(pl.LightningModule):
        def __init__(self, vocab_size):
            super().__init__()
            self.model = Transformer(vocab_size=vocab_size)

        def forward(self, inputs, target):
            return self.model(inputs, target)

        def training_step(self, batch, batch_idx):
            inputs, target = batch
            output = self(inputs, target)
            loss = torch.nn.functional.nll_loss(output, target.view(-1))
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.model.parameters(), lr=0.1)

Which you can train by doing:

.. code-block:: python

    from torch.utils.data import DataLoader

    from lightning.pytorch.demos import WikiText2

    dataset = WikiText2()
    dataloader = DataLoader(dataset)
    model = LightningTransformer(vocab_size=dataset.vocab_size)

    trainer = pl.Trainer(fast_dev_run=100)
    trainer.fit(model=model, train_dataloaders=dataloader)

The LightningModule has many convenience methods, but the core ones you need to know about are:

.. list-table::
:widths: 50 50
To activate the training loop, override the :meth:`~lightning.pytorch.core.module.LightningModule.training_step` method.

.. code-block:: python

    class LightningTransformer(pl.LightningModule):
        def __init__(self, vocab_size):
            super().__init__()
            self.model = Transformer(vocab_size=vocab_size)

        def training_step(self, batch, batch_idx):
            inputs, target = batch
            output = self.model(inputs, target)
            loss = torch.nn.functional.nll_loss(output, target.view(-1))
            return loss

Under the hood, Lightning does the following (pseudocode):
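
The pseudocode block itself is collapsed in this diff view. Roughly, it corresponds to a sketch like the following, where names such as ``train_dataloader``, ``training_step``, and ``optimizer`` are assumed for illustration and are not part of the diff:

.. code-block:: python

    # rough sketch of the fit loop around training_step (illustrative, not Lightning's source)
    for batch_idx, batch in enumerate(train_dataloader):
        loss = training_step(batch, batch_idx)

        # clear gradients
        optimizer.zero_grad()

        # backward
        loss.backward()

        # update parameters
        optimizer.step()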
If you want to calculate epoch-level metrics and log them, use :meth:`~lightning.pytorch.core.module.LightningModule.log`.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))

        # logs metrics for each training_step,
        # and the average across the epoch, to the progress bar and logger
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

The :meth:`~lightning.pytorch.core.module.LightningModule.log` method automatically reduces the
requested metrics across a complete epoch and devices. Here's the pseudocode of what it does under the hood:
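
That pseudocode block is collapsed here. As a rough, single-device sketch of the epoch-level reduction it describes (the loop and variable names are assumptions, not the docs' exact text):

.. code-block:: python

    # illustrative only: epoch-level mean of the logged value across steps
    outs = []
    for batch_idx, batch in enumerate(train_dataloader):
        # forward
        loss = training_step(batch, batch_idx)
        outs.append(loss.detach())

    # note: in reality the reduction also happens across devices
    epoch_metric = torch.mean(torch.stack(outs))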
If you need to do something with all the outputs of each ``training_step``, override the :meth:`~lightning.pytorch.LightningModule.on_train_epoch_end` method.

.. code-block:: python

    class LightningTransformer(pl.LightningModule):
        def __init__(self, vocab_size):
            super().__init__()
            self.model = Transformer(vocab_size=vocab_size)
            self.training_step_outputs = []

        def training_step(self, batch, batch_idx):
            inputs, target = batch
            output = self.model(inputs, target)
            loss = torch.nn.functional.nll_loss(output, target.view(-1))
            preds = ...
            self.training_step_outputs.append(preds)
            return loss

        def on_train_epoch_end(self):
            all_preds = torch.stack(self.training_step_outputs)
            # do something with all preds
            ...
            self.training_step_outputs.clear()  # free memory


------------------
To activate the validation loop while training, override the :meth:`~lightning.pytorch.core.module.LightningModule.validation_step` method.

.. code-block:: python

    class LightningTransformer(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            inputs, target = batch
            output = self.model(inputs, target)
            loss = torch.nn.functional.nll_loss(output, target.view(-1))
            self.log("val_loss", loss)
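
To actually run this loop during training, pass a validation DataLoader to ``fit``. A minimal sketch, reusing the demo dataset for both loaders purely for illustration:

.. code-block:: python

    from torch.utils.data import DataLoader
    from lightning.pytorch.demos import WikiText2

    dataset = WikiText2()
    train_loader = DataLoader(dataset)
    val_loader = DataLoader(dataset)  # illustrative; use a real validation split in practice

    model = LightningTransformer(vocab_size=dataset.vocab_size)
    trainer = pl.Trainer(fast_dev_run=100)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)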

You can also run just the validation loop on your validation dataloaders by overriding :meth:`~lightning.pytorch.core.module.LightningModule.validation_step` and calling :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`.

.. code-block:: python

    model = LightningTransformer(vocab_size=dataset.vocab_size)
    trainer = pl.Trainer()
    trainer.validate(model)
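
If the LightningModule does not define its own ``val_dataloader``, you can pass one to ``validate`` directly. A sketch, assuming the ``dataloader`` built earlier:

.. code-block:: python

    trainer.validate(model, dataloaders=dataloader)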

.. note::
If you need to do something with all the outputs of each ``validation_step``, override the :meth:`~lightning.pytorch.LightningModule.on_validation_epoch_end` method. Note that this method is called before :meth:`~lightning.pytorch.LightningModule.on_train_epoch_end`.

.. code-block:: python

    class LightningTransformer(pl.LightningModule):
        def __init__(self, vocab_size):
            super().__init__()
            self.model = Transformer(vocab_size=vocab_size)
            self.validation_step_outputs = []

        def validation_step(self, batch, batch_idx):
            inputs, target = batch
            output = self.model(inputs, target)
            loss = torch.nn.functional.nll_loss(output, target.view(-1))
            pred = ...
            self.validation_step_outputs.append(pred)
            return pred

        def on_validation_epoch_end(self):
            all_preds = torch.stack(self.validation_step_outputs)
            # do something with all preds
            ...
            self.validation_step_outputs.clear()  # free memory

----------------

The only difference is that the test loop is only called when :meth:`~lightning.pytorch.trainer.trainer.Trainer.test` is used.

.. code-block:: python

    model = LightningTransformer(vocab_size=dataset.vocab_size)
    dataloader = DataLoader(dataset)
    trainer = pl.Trainer()
    trainer.fit(model=model, train_dataloaders=dataloader)

    # automatically loads the best weights for you
    trainer.test(model)
There are two ways to call ``test()``:
.. code-block:: python

    # call after training
    trainer = pl.Trainer()
    trainer.fit(model=model, train_dataloaders=dataloader)

    # automatically loads the best weights from the previous run
    trainer.test(dataloaders=test_dataloaders)

    # or call with pretrained model
    model = LightningTransformer.load_from_checkpoint(PATH)
    dataset = WikiText2()
    test_dataloader = DataLoader(dataset)
    trainer = pl.Trainer()
    trainer.test(model, dataloaders=test_dataloader)

.. note::
    ``WikiText2`` is used here without creating a train, validation, and test split. This is done for illustrative purposes only.
    A proper split can be created in :meth:`lightning.pytorch.core.LightningModule.setup` or :meth:`lightning.pytorch.core.LightningDataModule.setup`.
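
As an illustration of that suggestion, here is a minimal, hypothetical ``LightningDataModule`` that creates the split in ``setup``; the split sizes and the use of ``random_split`` are assumptions, not part of the docs:

.. code-block:: python

    import torch
    import torch.utils.data as data
    import lightning.pytorch as pl
    from lightning.pytorch.demos import WikiText2


    class WikiText2DataModule(pl.LightningDataModule):
        def setup(self, stage=None):
            dataset = WikiText2()
            n = len(dataset)
            # hypothetical split sizes, seeded for reproducibility
            self.train_set, self.val_set, self.test_set = data.random_split(
                dataset, [n - 2000, 1000, 1000], generator=torch.Generator().manual_seed(42)
            )

        def train_dataloader(self):
            return data.DataLoader(self.train_set)

        def val_dataloader(self):
            return data.DataLoader(self.val_set)

        def test_dataloader(self):
            return data.DataLoader(self.test_set)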

.. note::

    It is recommended to validate on a single device to ensure each sample/batch gets evaluated exactly once.
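
For instance, a minimal sketch of the kind of setup the note recommends (the ``devices`` and ``num_nodes`` values here are illustrative):

.. code-block:: python

    # run the validation loop on a single device so each sample is seen exactly once
    trainer = pl.Trainer(devices=1, num_nodes=1)
    trainer.validate(model, dataloaders=dataloader)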
By default, the :meth:`~lightning.pytorch.core.module.LightningModule.predict_step` method runs the
:meth:`~lightning.pytorch.core.module.LightningModule.forward` method. In order to customize this behaviour,
simply override the :meth:`~lightning.pytorch.core.module.LightningModule.predict_step` method.

For the example, let's override ``predict_step``:

.. code-block:: python

    class LightningTransformer(pl.LightningModule):
        def __init__(self, vocab_size):
            super().__init__()
            self.model = Transformer(vocab_size=vocab_size)

        def predict_step(self, batch):
            inputs, target = batch
            return self.model(inputs, target)

Under the hood, Lightning does the following (pseudocode):
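
The pseudocode block is collapsed in this view. Roughly, the predict loop behaves like the sketch below, with assumed names and gradients disabled:

.. code-block:: python

    # illustrative sketch of the predict loop (not Lightning's source)
    model.eval()
    all_preds = []
    with torch.no_grad():
        for batch_idx, batch in enumerate(predict_dataloader):
            pred = model.predict_step(batch)
            all_preds.append(pred)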

There are two ways to call ``predict()``:
.. code-block:: python

    # call after training
    trainer = pl.Trainer()
    trainer.fit(model=model, train_dataloaders=dataloader)

    # automatically loads the best weights from the previous run
    predictions = trainer.predict(dataloaders=predict_dataloader)

    # or call with pretrained model
    model = LightningTransformer.load_from_checkpoint(PATH)
    dataset = WikiText2()
    test_dataloader = DataLoader(dataset)
    trainer = pl.Trainer()
    predictions = trainer.predict(model, dataloaders=test_dataloader)

Inference in Research
If you want to perform inference with the system, you can add a ``forward`` method to the LightningModule.

.. code-block:: python

    class LightningTransformer(pl.LightningModule):
        def __init__(self, vocab_size):
            super().__init__()
            self.model = Transformer(vocab_size=vocab_size)

        def forward(self, batch):
            inputs, target = batch
            return self.model(inputs, target)

        def training_step(self, batch, batch_idx):
            inputs, target = batch
            output = self.model(inputs, target)
            loss = torch.nn.functional.nll_loss(output, target.view(-1))
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.model.parameters(), lr=0.1)


    model = LightningTransformer(vocab_size=dataset.vocab_size)

    model.eval()
    with torch.no_grad():
        batch = dataloader.dataset[0]
        pred = model(batch)

The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure,
such as text generation:
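
The text-generation example that follows in the source file is collapsed in this diff. As a stand-in, here is a simpler, hypothetical variant whose ``forward`` returns the most likely token ids (a greedy choice) instead of raw log-probabilities; the class name is invented for illustration:

.. code-block:: python

    class GreedyLightningTransformer(pl.LightningModule):
        def __init__(self, vocab_size):
            super().__init__()
            self.model = Transformer(vocab_size=vocab_size)

        def forward(self, batch):
            inputs, target = batch
            log_probs = self.model(inputs, target)
            # a more involved inference step: return the most likely token ids
            # rather than the raw log-probabilities
            return torch.argmax(log_probs, dim=-1)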
These hyperparameters will also be stored within the model checkpoint, which simplifies model re-instantiation after training.

.. code-block:: python

    class LitMNIST(pl.LightningModule):
        def __init__(self, layer_1_dim=128, learning_rate=1e-2):
            super().__init__()
            # call this to save (layer_1_dim=128, learning_rate=1e-2) to the checkpoint
            self.save_hyperparameters()
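
Once saved this way, the hyperparameters are available under ``self.hparams`` and are restored when loading from a checkpoint. A sketch, with a hypothetical checkpoint path:

.. code-block:: python

    # hypothetical path, for illustration
    model = LitMNIST.load_from_checkpoint("path/to/checkpoint.ckpt")
    print(model.hparams.layer_1_dim)  # 128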
Any excluded parameters should be provided back when reloading the LightningModule. In this case, exclude them explicitly:

.. code-block:: python

    class LitMNIST(pl.LightningModule):
        def __init__(self, loss_fx, generator_network, layer_1_dim=128):
            super().__init__()
            self.layer_1_dim = layer_1_dim
            self.loss_fx = loss_fx

            # call this to save only (layer_1_dim=128) to the checkpoint
            self.save_hyperparameters("layer_1_dim")
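
When loading such a checkpoint, the arguments that were not saved have to be supplied again. A sketch with hypothetical values:

.. code-block:: python

    # `loss_fx` and `generator_network` were not saved, so pass them back in explicitly
    model = LitMNIST.load_from_checkpoint(
        "path/to/checkpoint.ckpt",  # hypothetical path
        loss_fx=torch.nn.functional.nll_loss,
        generator_network=MyGenerator(),  # hypothetical network, for illustration
    )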