
BatchSizeFinder defining max validation batches for entire training loop #18394

Closed
joncarter1 opened this issue Aug 25, 2023 · 1 comment · Fixed by #18854
Labels
bug · help wanted · tuner · ver: 2.0.x

Comments


joncarter1 commented Aug 25, 2023

Bug description

When the BatchSizeFinder callback is used during fit, its steps_per_trial parameter ends up capping how many validation batches are run for the entire rest of training, not just during the batch-size trials. This is similar to the issue observed with the LR Finder (#17412).
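
For context, the mismatch can be surfaced by printing the trainer's validation-batch bookkeeping each epoch. Below is a minimal, hypothetical debugging helper (not part of the repro); it assumes the leaked cap shows up in trainer.limit_val_batches or trainer.num_val_batches:

from lightning.pytorch.callbacks import Callback

class ValLimitLogger(Callback):
    """Hypothetical helper: print the validation-batch cap every epoch."""

    def on_validation_epoch_start(self, trainer, pl_module):
        # With the bug present, the cap stays pinned at steps_per_trial
        # for the rest of training instead of reverting to the full loader.
        print(f"limit_val_batches={trainer.limit_val_batches}, "
              f"num_val_batches={trainer.num_val_batches}")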

What version are you seeing the problem on?

v2.0

How to reproduce the bug

(Adapted from @blainehoak in #17412)

import time
import torch
from torch.utils.data import DataLoader, Dataset

import lightning
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import BatchSizeFinder


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class DummyDataModule(lightning.LightningDataModule):
    def __init__(
        self,
        length: int,
        size: int = 32,
        batch_size: int = 32,
    ):
        super().__init__()
        self.size = size
        self.length = length
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(
            RandomDataset(self.size, self.length),
            batch_size=self.batch_size,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            RandomDataset(self.size, self.length),
            batch_size=self.batch_size,
            shuffle=False,
        )


class BoringModel(LightningModule):
    def __init__(self, size=32, lr=0.1):
        super().__init__()
        self.model = torch.nn.Linear(size, 2)
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        time.sleep(0.01)  # simulate a short training step
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        time.sleep(0.5)  # slow validation down so the step count is visible in the progress bar
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=self.lr)


def run():
    STEPS = 13  # This ends up determining the number of validation steps
    LENGTH = 10_000
    datamodule = DummyDataModule(length=LENGTH, batch_size=32)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir="/tmp/lightning_logs",
        max_epochs=10,
        enable_model_summary=False,
        callbacks=[BatchSizeFinder(steps_per_trial=STEPS, max_trials=3)],
    )
    trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":
    run()
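
Until this is fixed, a possible stop-gap is to subclass the callback and restore the validation limit after the trials finish. This is only a sketch: it assumes BatchSizeFinder launches its trials from on_fit_start, and it may not be sufficient if the stale cap is cached deeper inside the validation loop:

class RestoringBatchSizeFinder(BatchSizeFinder):
    """Hypothetical workaround: snapshot and restore the validation cap."""

    def on_fit_start(self, trainer, pl_module):
        limit_val_batches = trainer.limit_val_batches  # user-configured value
        super().on_fit_start(trainer, pl_module)  # runs the batch-size trials
        # Undo the leaked steps_per_trial cap before training proper begins.
        trainer.limit_val_batches = limit_val_batches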

Error messages and logs

No response

Environment

Current environment
  • CUDA:
    - GPU: NVIDIA GeForce GTX 1070
    - available: True
    - version: 11.8
  • Lightning:
    - lightning: 2.0.7
    - lightning-cloud: 0.5.37
    - lightning-utilities: 0.8.0
    - pytorch-lightning: 2.0.2
    - torch: 2.0.1
    - torchmetrics: 0.11.4
    - torchvision: 0.15.2

More info

No response

@joncarter1 added the bug and needs triage labels on Aug 25, 2023
joncarter1 (Author) commented:

N.B. it seems to work fine when the batch size is scaled via the Tuner instead:

from lightning.pytorch.tuner import Tuner

tuner = Tuner(trainer)
tuner.scale_batch_size(model=model, datamodule=datamodule, mode="power")
trainer.fit(model, datamodule=datamodule)
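
Presumably the Tuner entry point restores the trial-time limits before trainer.fit() builds its loops, whereas the callback path leaves the steps_per_trial cap in place; I haven't verified this against the source.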
