
Combining check_val_every_n_epoch and save_top_k is broken #9163

Closed
nicola-decao opened this issue Aug 27, 2021 · 4 comments
Labels: bug (Something isn't working) · checkpointing (Related to checkpointing) · help wanted (Open to be worked on)

@nicola-decao

🐛 Bug

When using a Trainer with check_val_every_n_epoch = n with n > 1, the trainer runs validation every n epochs, and this works. But when used in combination with a ModelCheckpoint with save_top_k = m with m > 1, the model is saved every epoch instead of every n epochs. This worked in previous versions (if I remember correctly it worked in 1.2), but it is now broken.

To Reproduce

This piece of code with the BoringModel reproduces the issue: it saves the model every epoch instead of every n epochs (see the shell output at the bottom).

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=10,
        check_val_every_n_epoch=2,
        weights_summary=None,
        callbacks=[
            ModelCheckpoint(
                monitor="valid_loss",
                mode="min",
                dirpath="./",
                save_top_k=10,
                filename="model-{epoch:02d}-{valid_loss:.2f}",
            )
        ]
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)

if __name__ == "__main__":
    run()
>>> ls -l *.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=01-valid_loss=-6.00.ckpt
-rw-r--r--. 1 ndecao Domain Users 2579 Aug 27 09:39 model-epoch=02-valid_loss=-6.00.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=03-valid_loss=-11.57.ckpt
-rw-r--r--. 1 ndecao Domain Users 2643 Aug 27 09:39 model-epoch=04-valid_loss=-11.57.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=05-valid_loss=-17.14.ckpt
-rw-r--r--. 1 ndecao Domain Users 2643 Aug 27 09:39 model-epoch=06-valid_loss=-17.14.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=07-valid_loss=-22.70.ckpt
-rw-r--r--. 1 ndecao Domain Users 2643 Aug 27 09:39 model-epoch=08-valid_loss=-22.70.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=09-valid_loss=-28.27.ckpt

Expected behavior

Validation should run and the model should be saved every check_val_every_n_epoch epochs, so these are the checkpoints that should have been saved:

>>> ls -l *.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=01-valid_loss=-6.00.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=03-valid_loss=-11.57.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=05-valid_loss=-17.14.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=07-valid_loss=-22.70.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=09-valid_loss=-28.27.ckpt

Environment

  • CUDA:
    • GPU:
      • TITAN X (Pascal)
      • TITAN X (Pascal)
      • TITAN X (Pascal)
      • TITAN X (Pascal)
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.19.5
    • pyTorch_debug: False
    • pyTorch_version: 1.8.1
    • pytorch-lightning: 1.4.4
    • tqdm: 4.62.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.10
    • version: 1 SMP Wed Feb 3 15:06:38 UTC 2021
@nicola-decao nicola-decao added bug Something isn't working help wanted Open to be worked on labels Aug 27, 2021
@nicola-decao
Author

@Borda this seems like an easy fix (although I do not know the codebase well enough to be quick on this one). I also suggest adding a test so this does not break again, since it was working in previous versions.

@Borda
Member

Borda commented Aug 27, 2021

@nicola-decao yep, do you want to take it and send a PR? 🐰

@Borda Borda added the checkpointing Related to checkpointing label Aug 27, 2021
@nicola-decao
Author

nicola-decao commented Aug 27, 2021

@Borda I could try, though TBH I have no idea how to fix this yet. I found this: https://github.com/PyTorchLightning/pytorch-lightning/blob/6da3a6185f19a2dc64ee742b5a9b7fc200059582/pytorch_lightning/callbacks/model_checkpoint.py#L117
So in theory there is a workaround: I can set every_n_epochs=args.check_val_every_n_epoch on the ModelCheckpoint, and no bug fix is needed. However, this seems impractical: once a user sets check_val_every_n_epoch, it should be clear that the ModelCheckpoint can only be evaluated every n epochs, since the monitored validation metric is only produced then.
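
Concretely, the workaround applied to my repro above would look roughly like this (just a sketch: the shared VAL_EVERY_N_EPOCHS constant is my own name, while every_n_epochs is the existing ModelCheckpoint argument from the line linked above):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Illustrative shared constant (not a Lightning name) so the two schedules
# cannot drift apart.
VAL_EVERY_N_EPOCHS = 2

checkpoint_callback = ModelCheckpoint(
    monitor="valid_loss",
    mode="min",
    dirpath="./",
    save_top_k=10,
    every_n_epochs=VAL_EVERY_N_EPOCHS,  # mirrors check_val_every_n_epoch below
    filename="model-{epoch:02d}-{valid_loss:.2f}",
)

trainer = Trainer(
    max_epochs=10,
    check_val_every_n_epoch=VAL_EVERY_N_EPOCHS,
    callbacks=[checkpoint_callback],
)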

I can see 2 solutions here:

  1. we emit a warning when Trainer.check_val_every_n_epoch != ModelCheckpoint.every_n_epochs to warn the user that this might not be the desired behaviour
  2. we set ModelCheckpoint.every_n_epochs = Trainer.check_val_every_n_epoch when it is not specified by the user

This is kind of important, as one might want to check validation every n epochs but save the model every m epochs where n != m.
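
To make the proposal concrete, here is a rough sketch of the resolution logic I have in mind (purely illustrative, not actual Lightning code; the function name resolve_checkpoint_every_n_epochs and modelling "not specified by the user" as None are my assumptions):

import warnings
from typing import Optional


def resolve_checkpoint_every_n_epochs(
    check_val_every_n_epoch: int,
    every_n_epochs: Optional[int],
) -> int:
    # Hypothetical resolution step that the Trainer/ModelCheckpoint setup could run.
    if every_n_epochs is None:
        # Option 2: default the checkpoint schedule to the validation schedule.
        return check_val_every_n_epoch
    if every_n_epochs != check_val_every_n_epoch:
        # Option 1: warn the user about the schedule mismatch.
        warnings.warn(
            "ModelCheckpoint.every_n_epochs differs from "
            "Trainer.check_val_every_n_epoch; checkpoints may not be saved "
            "on the epochs where validation actually runs."
        )
    return every_n_epochs

For example, resolve_checkpoint_every_n_epochs(2, None) would return 2 (option 2), while resolve_checkpoint_every_n_epochs(2, 1) would keep 1 but emit the warning (option 1).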

@nicola-decao
Author

@Borda do you think this needs a fix? I can add the warning if so.
