From 4cad7a61742a566fce0100c66fd73c04abdf7b8e Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Thu, 31 Aug 2023 20:22:19 -0400 Subject: [PATCH 01/11] update common/lightning_module --- .../common/lightning_module.rst | 230 ++++++++++-------- 1 file changed, 128 insertions(+), 102 deletions(-) diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst index 3a2691880a20c..519f88840cd0d 100644 --- a/docs/source-pytorch/common/lightning_module.rst +++ b/docs/source-pytorch/common/lightning_module.rst @@ -84,39 +84,45 @@ Here are the only required methods. .. code-block:: python - import lightning.pytorch as pl - import torch.nn as nn - import torch.nn.functional as F + import pytorch_lightning as pl + import torch + from pytorch_lightning.demos.transformer import Transformer - class LitModel(pl.LightningModule): - def __init__(self): + + class LightningTransformer(pl.LightningModule): + def __init__(self, vocab_size): super().__init__() - self.l1 = nn.Linear(28 * 28, 10) + self.model = Transformer(vocab_size=vocab_size) - def forward(self, x): - return torch.relu(self.l1(x.view(x.size(0), -1))) + def forward(self, batch): + input, target = batch + return self.model(input.view(1, -1), target.view(1, -1)) def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - loss = F.cross_entropy(y_hat, y) + input, target = batch + output = self.model(input, target) + loss = torch.nn.functional.nll_loss(output, target.view(-1)) return loss def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=0.02) + return torch.optim.SGD(self.model.parameters(), lr=0.1) Which you can train by doing: .. code-block:: python - train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())) - trainer = pl.Trainer(max_epochs=1) - model = LitModel() + from pytorch_lightning.demos.transformer import WikiText2 + from torch.utils.data import DataLoader + + dataset = WikiText2() + dataloader = DataLoader(dataset) + model = LightningTransformer(vocab_size=dataset.vocab_size) - trainer.fit(model, train_dataloaders=train_loader) + trainer = pl.Trainer(fast_dev_run=True) + trainer.fit(model=model, train_dataloaders=dataloader) -The LightningModule has many convenience methods, but the core ones you need to know about are: +The LightningModule has many convenient methods, but the core ones you need to know about are: .. list-table:: :widths: 50 50 @@ -152,15 +158,15 @@ To activate the training loop, override the :meth:`~lightning.pytorch.core.modul .. code-block:: python - class LitClassifier(pl.LightningModule): - def __init__(self, model): + class LightningTransformer(pl.LightningModule): + def __init__(self, vocab_size): super().__init__() - self.model = model + self.model = Transformer(vocab_size=vocab_size) def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self.model(x) - loss = F.cross_entropy(y_hat, y) + input, target = batch + output = self.model(input, target) + loss = torch.nn.functional.nll_loss(output, target.view(-1)) return loss Under the hood, Lightning does the following (pseudocode): @@ -191,15 +197,15 @@ If you want to calculate epoch-level metrics and log them, use :meth:`~lightning .. 
code-block:: python - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self.model(x) - loss = F.cross_entropy(y_hat, y) + def training_step(self, batch, batch_idx): + input, target = batch + output = self.model(input, target) + loss = torch.nn.functional.nll_loss(output, target.view(-1)) - # logs metrics for each training_step, - # and the average across the epoch, to the progress bar and logger - self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) - return loss + # logs metrics for each training_step, + # and the average across the epoch, to the progress bar and logger + self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + return loss The :meth:`~lightning.pytorch.core.module.LightningModule.log` method automatically reduces the requested metrics across a complete epoch and devices. Here's the pseudocode of what it does under the hood: @@ -230,25 +236,25 @@ override the :meth:`~lightning.pytorch.LightningModule.on_train_epoch_end` metho .. code-block:: python - def __init__(self): - super().__init__() - self.training_step_outputs = [] - - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self.model(x) - loss = F.cross_entropy(y_hat, y) - preds = ... - self.training_step_outputs.append(preds) - return loss + class LightningTransformer(pl.LightningModule): + def __init__(self, vocab_size): + super().__init__() + self.model = Transformer(vocab_size=vocab_size) + self.training_step_outputs = [] + def training_step(self, batch, batch_idx): + input, target = batch + output = self.model(input, target) + loss = torch.nn.functional.nll_loss(output, target.view(-1)) + preds = ... + self.training_step_outputs.append(preds) + return loss - def on_train_epoch_end(self): - all_preds = torch.stack(self.training_step_outputs) - # do something with all preds - ... - self.training_step_outputs.clear() # free memory + def on_train_epoch_end(self): + all_preds = torch.stack(self.training_step_outputs) + # do something with all preds + ... + self.training_step_outputs.clear() # free memory ------------------ @@ -264,10 +270,10 @@ To activate the validation loop while training, override the :meth:`~lightning.p .. code-block:: python - class LitModel(pl.LightningModule): + class LightningTransformer(pl.LightningModule): def validation_step(self, batch, batch_idx): - x, y = batch - y_hat = self.model(x) + input, target = batch + output = self.model(input, target) loss = F.cross_entropy(y_hat, y) self.log("val_loss", loss) @@ -300,8 +306,8 @@ and calling :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`. .. code-block:: python - model = Model() - trainer = Trainer() + model = LightningTransformer(vocab_size=dataset.vocab_size) + trainer = pl.Trainer() trainer.validate(model) .. note:: @@ -322,25 +328,26 @@ Note that this method is called before :meth:`~lightning.pytorch.LightningModule .. code-block:: python - def __init__(self): - super().__init__() - self.validation_step_outputs = [] - - - def validation_step(self, batch, batch_idx): - x, y = batch - y_hat = self.model(x) - loss = F.cross_entropy(y_hat, y) - pred = ... 
- self.validation_step_outputs.append(pred) - return pred + class LightningTransformer(pl.LightningModule): + def __init__(self, vocab_size): + super().__init__() + self.model = Transformer(vocab_size=vocab_size) + self.validation_step_outputs = [] + def validation_step(self, batch, batch_idx): + x, y = batch + input, target = batch + output = self.model(input, target) + loss = torch.nn.functional.nll_loss(output, target.view(-1)) + pred = ... + self.validation_step_outputs.append(pred) + return pred - def on_validation_epoch_end(self): - all_preds = torch.stack(self.validation_step_outputs) - # do something with all preds - ... - self.validation_step_outputs.clear() # free memory + def on_validation_epoch_end(self): + all_preds = torch.stack(self.validation_step_outputs) + # do something with all preds + ... + self.validation_step_outputs.clear() # free memory ---------------- @@ -358,9 +365,10 @@ The only difference is that the test loop is only called when :meth:`~lightning. .. code-block:: python - model = Model() - trainer = Trainer() - trainer.fit(model) + model = LightningTransformer(vocab_size=dataset.vocab_size) + dataloader = DataLoader(dataset) + trainer = pl.Trainer() + trainer.fit(model=model, train_dataloaders=dataloader) # automatically loads the best weights for you trainer.test(model) @@ -370,17 +378,23 @@ There are two ways to call ``test()``: .. code-block:: python # call after training - trainer = Trainer() - trainer.fit(model) + trainer = pl.Trainer() + trainer.fit(model=model, train_dataloaders=dataloader) # automatically auto-loads the best weights from the previous run - trainer.test(dataloaders=test_dataloader) + trainer.test(dataloaders=trainer.test_dataloaders) # or call with pretrained model - model = MyLightningModule.load_from_checkpoint(PATH) - trainer = Trainer() + model = LightningTransformer.load_from_checkpoint(PATH) + dataset = WikiText2() + test_dataloader = DataLoader(dataset) + trainer = pl.Trainer() trainer.test(model, dataloaders=test_dataloader) +.. note:: + `WikiText2` is used in a manner that does not create a train, test, val split. This is done for illustrative purposes only. + A proper split can be created in :meth:`lightning.pytorch.core.module.LightningModule.setup` or :meth:`lightning.pytorch.core.module.LightningDataModule.setup`. + .. note:: It is recommended to validate on single device to ensure each sample/batch gets evaluated exactly once. @@ -403,24 +417,18 @@ By default, the :meth:`~lightning.pytorch.core.module.LightningModule.predict_st :meth:`~lightning.pytorch.core.module.LightningModule.forward` method. In order to customize this behaviour, simply override the :meth:`~lightning.pytorch.core.module.LightningModule.predict_step` method. -For the example let's override ``predict_step`` and try out `Monte Carlo Dropout `_: +For the example let's override ``predict_step``: .. 
code-block:: python - class LitMCdropoutModel(pl.LightningModule): - def __init__(self, model, mc_iteration): + class LightningTransformer(pl.LightningModule): + def __init__(self, vocab_size): super().__init__() - self.model = model - self.dropout = nn.Dropout() - self.mc_iteration = mc_iteration - - def predict_step(self, batch, batch_idx): - # enable Monte Carlo Dropout - self.dropout.train() + self.model = Transformer(vocab_size=vocab_size) - # take average of `self.mc_iteration` iterations - pred = torch.vstack([self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]).mean(dim=0) - return pred + def predict_step(self, batch): + input, target = batch + return self.model(input, target) Under the hood, Lightning does the following (pseudocode): @@ -440,15 +448,17 @@ There are two ways to call ``predict()``: .. code-block:: python # call after training - trainer = Trainer() - trainer.fit(model) + trainer = pl.Trainer() + trainer.fit(model=model, train_dataloaders=dataloader) # automatically auto-loads the best weights from the previous run predictions = trainer.predict(dataloaders=predict_dataloader) # or call with pretrained model - model = MyLightningModule.load_from_checkpoint(PATH) - trainer = Trainer() + model = LightningTransformer.load_from_checkpoint(PATH) + dataset = WikiText2() + test_dataloader = DataLoader(dataset) + trainer = pl.Trainer() predictions = trainer.predict(model, dataloaders=test_dataloader) Inference in Research @@ -460,15 +470,31 @@ If you want to perform inference with the system, you can add a ``forward`` meth .. code-block:: python - class Autoencoder(pl.LightningModule): - def forward(self, x): - return self.decoder(x) + class LightningTransformer(pl.LightningModule): + def __init__(self, vocab_size): + super().__init__() + self.model = Transformer(vocab_size=vocab_size) + def forward(self, batch): + input, target = batch + return self.model(input.view(1, -1), target.view(1, -1)) + + def training_step(self, batch, batch_idx): + input, target = batch + output = self.model(input, target) + loss = torch.nn.functional.nll_loss(output, target.view(-1)) + return loss + + def configure_optimizers(self): + return torch.optim.SGD(self.model.parameters(), lr=0.1) + + + model = LightningTransformer(vocab_size=dataset.vocab_size) - model = Autoencoder() model.eval() with torch.no_grad(): - reconstruction = model(embedding) + batch = dataloader.dataset[0] + pred = model(batch) The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure, such as text generation: @@ -618,7 +644,7 @@ checkpoint, which simplifies model re-instantiation after training. .. code-block:: python - class LitMNIST(LightningModule): + class LitMNIST(pl.LightningModule): def __init__(self, layer_1_dim=128, learning_rate=1e-2): super().__init__() # call this to save (layer_1_dim=128, learning_rate=1e-4) to the checkpoint @@ -642,7 +668,7 @@ parameters should be provided back when reloading the LightningModule. In this c .. 
code-block:: python - class LitMNIST(LightningModule): + class LitMNIST(pl.LightningModule): def __init__(self, loss_fx, generator_network, layer_1_dim=128): super().__init__() self.layer_1_dim = layer_1_dim From 648e8295d62b3b6019388ed5119aa64be7d89143 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Fri, 1 Sep 2023 07:41:18 -0400 Subject: [PATCH 02/11] update import convention --- docs/source-pytorch/common/lightning_module.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst index 519f88840cd0d..afa695ca1a3de 100644 --- a/docs/source-pytorch/common/lightning_module.rst +++ b/docs/source-pytorch/common/lightning_module.rst @@ -84,7 +84,7 @@ Here are the only required methods. .. code-block:: python - import pytorch_lightning as pl + import lightning.pytorch as pl import torch from pytorch_lightning.demos.transformer import Transformer From 62e20e2e2e8ebe6bbc9d9da03d96f355a038f88c Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Fri, 1 Sep 2023 10:16:50 -0400 Subject: [PATCH 03/11] update imports --- docs/source-pytorch/common/lightning_module.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst index afa695ca1a3de..4f56f13f3c5fd 100644 --- a/docs/source-pytorch/common/lightning_module.rst +++ b/docs/source-pytorch/common/lightning_module.rst @@ -87,7 +87,7 @@ Here are the only required methods. import lightning.pytorch as pl import torch - from pytorch_lightning.demos.transformer import Transformer + from lightning.pytorch.demos import LightningTransformer class LightningTransformer(pl.LightningModule): From 77318dbaa83414f843a252342c4c5f0e315b9e5b Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Fri, 1 Sep 2023 10:33:02 -0400 Subject: [PATCH 04/11] update imports --- docs/source-pytorch/common/lightning_module.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst index 4f56f13f3c5fd..639cf98ca3e62 100644 --- a/docs/source-pytorch/common/lightning_module.rst +++ b/docs/source-pytorch/common/lightning_module.rst @@ -112,7 +112,7 @@ Which you can train by doing: .. code-block:: python - from pytorch_lightning.demos.transformer import WikiText2 + from lightning.pytorch.demos import WikiText2 from torch.utils.data import DataLoader dataset = WikiText2() From f813873e8954d706771b9640c04ec95d58dcbb82 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Fri, 1 Sep 2023 14:36:51 -0400 Subject: [PATCH 05/11] fix import from demos --- docs/source-pytorch/common/lightning_module.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst index 639cf98ca3e62..596d9fa23d2ed 100644 --- a/docs/source-pytorch/common/lightning_module.rst +++ b/docs/source-pytorch/common/lightning_module.rst @@ -87,7 +87,7 @@ Here are the only required methods. 
import lightning.pytorch as pl import torch - from lightning.pytorch.demos import LightningTransformer + from lightning.pytorch.demos import Transformer class LightningTransformer(pl.LightningModule): From a6cbab69ad9abd00e373fb812cea5ddcb8cec57b Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Tue, 5 Sep 2023 14:22:56 -0400 Subject: [PATCH 06/11] implement suggestions --- docs/source-pytorch/common/lightning_module.rst | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst index 596d9fa23d2ed..a63644d14f718 100644 --- a/docs/source-pytorch/common/lightning_module.rst +++ b/docs/source-pytorch/common/lightning_module.rst @@ -95,13 +95,12 @@ Here are the only required methods. super().__init__() self.model = Transformer(vocab_size=vocab_size) - def forward(self, batch): - input, target = batch - return self.model(input.view(1, -1), target.view(1, -1)) + def forward(self, input, target): + return self.model(input, target) def training_step(self, batch, batch_idx): input, target = batch - output = self.model(input, target) + output = self(input, target) loss = torch.nn.functional.nll_loss(output, target.view(-1)) return loss @@ -119,7 +118,7 @@ Which you can train by doing: dataloader = DataLoader(dataset) model = LightningTransformer(vocab_size=dataset.vocab_size) - trainer = pl.Trainer(fast_dev_run=True) + trainer = pl.Trainer(fast_dev_run=100) trainer.fit(model=model, train_dataloaders=dataloader) The LightningModule has many convenient methods, but the core ones you need to know about are: @@ -382,7 +381,7 @@ There are two ways to call ``test()``: trainer.fit(model=model, train_dataloaders=dataloader) # automatically auto-loads the best weights from the previous run - trainer.test(dataloaders=trainer.test_dataloaders) + trainer.test(dataloaders=test_dataloaders) # or call with pretrained model model = LightningTransformer.load_from_checkpoint(PATH) From b6a78026a5fc48d8786eaafec54a5a4941d5a8fb Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Wed, 6 Sep 2023 14:32:10 -0400 Subject: [PATCH 07/11] commit suggestion Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- docs/source-pytorch/common/lightning_module.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst index a63644d14f718..39d55f7db8633 100644 --- a/docs/source-pytorch/common/lightning_module.rst +++ b/docs/source-pytorch/common/lightning_module.rst @@ -111,7 +111,6 @@ Which you can train by doing: .. 
code-block:: python - from lightning.pytorch.demos import WikiText2 from torch.utils.data import DataLoader dataset = WikiText2() From 9a17f1461712165239406c1b0f4dbb7b86c25b2d Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Wed, 6 Sep 2023 14:32:30 -0400 Subject: [PATCH 08/11] commit suggestion Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- docs/source-pytorch/common/lightning_module.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst index 39d55f7db8633..e523acb5d8ab9 100644 --- a/docs/source-pytorch/common/lightning_module.rst +++ b/docs/source-pytorch/common/lightning_module.rst @@ -113,7 +113,7 @@ Which you can train by doing: from torch.utils.data import DataLoader - dataset = WikiText2() + dataset = pl.demos.WikiText2() dataloader = DataLoader(dataset) model = LightningTransformer(vocab_size=dataset.vocab_size) From acf0917b1414f85e3a12bce52544fc34f7f9da2b Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Wed, 6 Sep 2023 14:32:56 -0400 Subject: [PATCH 09/11] commit suggestion Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- docs/source-pytorch/common/lightning_module.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst index e523acb5d8ab9..d8e3754dfe858 100644 --- a/docs/source-pytorch/common/lightning_module.rst +++ b/docs/source-pytorch/common/lightning_module.rst @@ -391,7 +391,7 @@ There are two ways to call ``test()``: .. note:: `WikiText2` is used in a manner that does not create a train, test, val split. This is done for illustrative purposes only. - A proper split can be created in :meth:`lightning.pytorch.core.module.LightningModule.setup` or :meth:`lightning.pytorch.core.module.LightningDataModule.setup`. + A proper split can be created in :meth:`lightning.pytorch.core.LightningModule.setup` or :meth:`lightning.pytorch.core.LightningDataModule.setup`. .. 
note:: From 646b5b26e1d101a139bc5fb16c0751d5fc53322a Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Wed, 6 Sep 2023 14:33:10 -0400 Subject: [PATCH 10/11] commit suggestion Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- docs/source-pytorch/common/lightning_module.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst index d8e3754dfe858..80faefee86e7c 100644 --- a/docs/source-pytorch/common/lightning_module.rst +++ b/docs/source-pytorch/common/lightning_module.rst @@ -454,7 +454,7 @@ There are two ways to call ``predict()``: # or call with pretrained model model = LightningTransformer.load_from_checkpoint(PATH) - dataset = WikiText2() + dataset = pl.demos.WikiText2() test_dataloader = DataLoader(dataset) trainer = pl.Trainer() predictions = trainer.predict(model, dataloaders=test_dataloader) From 93a94fd330b8da663797b727dc09293d37be46d5 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Wed, 6 Sep 2023 14:39:31 -0400 Subject: [PATCH 11/11] remove python keyword - change `input` to `inputs` --- .../common/lightning_module.rst | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst index 80faefee86e7c..19153919f8da7 100644 --- a/docs/source-pytorch/common/lightning_module.rst +++ b/docs/source-pytorch/common/lightning_module.rst @@ -95,12 +95,12 @@ Here are the only required methods. super().__init__() self.model = Transformer(vocab_size=vocab_size) - def forward(self, input, target): - return self.model(input, target) + def forward(self, inputs, target): + return self.model(inputs, target) def training_step(self, batch, batch_idx): - input, target = batch - output = self(input, target) + inputs, target = batch + output = self(inputs, target) loss = torch.nn.functional.nll_loss(output, target.view(-1)) return loss @@ -162,8 +162,8 @@ To activate the training loop, override the :meth:`~lightning.pytorch.core.modul self.model = Transformer(vocab_size=vocab_size) def training_step(self, batch, batch_idx): - input, target = batch - output = self.model(input, target) + inputs, target = batch + output = self.model(inputs, target) loss = torch.nn.functional.nll_loss(output, target.view(-1)) return loss @@ -196,8 +196,8 @@ If you want to calculate epoch-level metrics and log them, use :meth:`~lightning .. code-block:: python def training_step(self, batch, batch_idx): - input, target = batch - output = self.model(input, target) + inputs, target = batch + output = self.model(inputs, target) loss = torch.nn.functional.nll_loss(output, target.view(-1)) # logs metrics for each training_step, @@ -241,8 +241,8 @@ override the :meth:`~lightning.pytorch.LightningModule.on_train_epoch_end` metho self.training_step_outputs = [] def training_step(self, batch, batch_idx): - input, target = batch - output = self.model(input, target) + inputs, target = batch + output = self.model(inputs, target) loss = torch.nn.functional.nll_loss(output, target.view(-1)) preds = ... 
             self.training_step_outputs.append(preds)
@@ -270,8 +270,8 @@ To activate the validation loop while training, override the :meth:`~lightning.p

     class LightningTransformer(pl.LightningModule):
         def validation_step(self, batch, batch_idx):
-            input, target = batch
-            output = self.model(input, target)
-            loss = F.cross_entropy(y_hat, y)
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
             self.log("val_loss", loss)

@@ -334,8 +334,7 @@ Note that this method is called before :meth:`~lightning.pytorch.LightningModule
         def validation_step(self, batch, batch_idx):
-            x, y = batch
-            input, target = batch
-            output = self.model(input, target)
+            inputs, target = batch
+            output = self.model(inputs, target)
             loss = torch.nn.functional.nll_loss(output, target.view(-1))
             pred = ...
             self.validation_step_outputs.append(pred)
             return pred
@@ -425,8 +425,8 @@ For the example let's override ``predict_step``:
             self.model = Transformer(vocab_size=vocab_size)

         def predict_step(self, batch):
-            input, target = batch
-            return self.model(input, target)
+            inputs, target = batch
+            return self.model(inputs, target)

 Under the hood, Lightning does the following (pseudocode):

@@ -474,12 +474,12 @@ If you want to perform inference with the system, you can add a ``forward`` meth
             self.model = Transformer(vocab_size=vocab_size)

         def forward(self, batch):
-            input, target = batch
-            return self.model(input.view(1, -1), target.view(1, -1))
+            inputs, target = batch
+            return self.model(inputs, target)

         def training_step(self, batch, batch_idx):
-            input, target = batch
-            output = self.model(input, target)
+            inputs, target = batch
+            output = self.model(inputs, target)
             loss = torch.nn.functional.nll_loss(output, target.view(-1))
             return loss