"Wrong" code example in doc auto-scaling-of-batch-size section #5967
Comments
Something like this (#5968) will probably fix this issue, but it's not a good fix.

Minimal reproducible example:
```python
from argparse import ArgumentParser
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner
from torch.nn import functional as F
import pytorch_lightning as pl
from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule


class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim=128, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('valid_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_dim', type=int, default=128)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser


if __name__ == '__main__':
    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser = LitClassifier.add_model_specific_args(parser)
    parser = MNISTDataModule.add_argparse_args(parser)
    args = parser.parse_args()

    dm = MNISTDataModule.from_argparse_args(args)
    model = LitClassifier(args.hidden_dim, args.learning_rate)

    trainer = Trainer(
        gpus=1,
        accelerator='dp',
        auto_scale_batch_size='binsearch',
    )

    # Calling the tuner directly (without going through trainer.tune())
    # triggers the MisconfigurationException described in this issue.
    tuner = Tuner(trainer)
    new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=128, max_trials=3, datamodule=dm)
    model.hparams.batch_size = new_batch_size

    trainer.fit(model, datamodule=dm)
```
I have the same issue. If you first call fit, the problem does not occur, because fit attaches the datamodule to the model. So that registration step just needs to happen before tuning (see the sketch below).
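A hedged sketch of that kind of workaround, reusing `model` and `dm` from the minimal reproducible example above and assuming `trainer.tune()` performs the same datamodule registration as `fit()`:

```python
# Sketch of a user-level workaround (assumption: trainer.tune() attaches the
# datamodule to the model before invoking the batch-size scaler, which a
# direct Tuner.scale_batch_size() call skips).
trainer = Trainer(gpus=1, accelerator='dp', auto_scale_batch_size='binsearch')
trainer.tune(model, datamodule=dm)   # runs the registration, then scales batch_size
trainer.fit(model, datamodule=dm)    # train with the tuned batch size
```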
Another workaround would be to add this line of code in TrainLoop.setup_fit and delete it from trainer.fit. I do not know which solution is more idiomatic to PyTorch Lightning, but I will prepare a pull request for the first solution.
See the linked PR I am already working on.
I have written in Lightning-AI#5967 how it fixes the bug.
@awaelchli, sorry, I didn't manage to find it. It looks nice and allows using the batch-size auto-scaling and the LR finder without calling tune. Hoping for a quick approval.
Yes, working on adding tests right now, hopefully done soon! Thanks for your help.
Could you explain why we have to add this to lr_finder as well?
Because for the LR finder it will be a similar bug, just not with batch_size but with the lr attribute.
Will it? We cannot put the LR in the datamodule, only in the LightningModule, and I cannot reproduce this bug with the learning rate.
🐛 Bug
In the doc auto-scaling-of-batch-size section, a code example is given. However, it will not work as expected in the case where a LightningModule contains an attribute `self.datamodule`. Following the code will give `MisconfigurationException: Field batch_size not found in both model and model.hparams`.
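For context, the documented pattern is, as far as it can be reconstructed here, roughly the following (a paraphrase, not a verbatim copy of the docs):

```python
# Approximate paraphrase of the documented auto-scaling pattern.
trainer = Trainer(auto_scale_batch_size='binsearch')
tuner = Tuner(trainer)

# Find a workable batch size, write it back, then fit.
new_batch_size = tuner.scale_batch_size(model, mode='binsearch', init_val=128)
model.hparams.batch_size = new_batch_size
trainer.fit(model)
```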
To Reproduce

See my one-page code.
Expected behavior

Tuner should find the attribute `batch_size` in `model.datamodule` in the method `Tuner.scale_batch_size()`.

Environment
This issue should be independent of the environment.
Additional context
I took a look at the source code and found out that if we call `Trainer.tune()` directly, the invoke chain is

`trainer.tune() -> tuner.tune() -> tuner.scale_batch_size() -> batch_size_scaling.scale_batch_size() -> lightning_hasattr(model, attribute) -> ...`

while the invoke chain of calling `Tuner.scale_batch_size()` is

`tuner.scale_batch_size() -> batch_size_scaling.scale_batch_size() -> lightning_hasattr(model, attribute) -> ...`

The problem is that `lightning_hasattr(model, attribute)` cannot find the attribute `model.datamodule.batch_size` if we skip the registration steps in `trainer.tune()`.
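To make the failure mode concrete, here is a conceptual sketch of that lookup order (not the actual `lightning_hasattr` source, which handles more cases):

```python
# Conceptual sketch of the attribute lookup performed during batch-size scaling.
# Assumption: model.datamodule is only populated by the registration step inside
# trainer.tune()/fit(), so it is missing when Tuner.scale_batch_size() is called
# directly on a freshly constructed model.
def has_batch_arg(model, attribute="batch_size"):
    def _has(obj):
        return obj is not None and hasattr(obj, attribute)

    return (
        hasattr(model, attribute)                    # 1) model.batch_size
        or _has(getattr(model, "hparams", None))     # 2) model.hparams.batch_size
        or _has(getattr(model, "datamodule", None))  # 3) model.datamodule.batch_size
    )
```

With the direct `Tuner.scale_batch_size()` chain, step 3 never succeeds because nothing has attached the datamodule to the model, hence the `MisconfigurationException`.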