
Saving of checkpoint after every epoch using ModelCheckpoint if no metric is monitored #596

Closed
ghost opened this issue Dec 6, 2019 · 16 comments
Labels
feature (Is an improvement or enhancement) · help wanted (Open to be worked on) · won't fix (This will not be worked on)

Comments

@ghost

ghost commented Dec 6, 2019

I may have missed something, but based on the docs and the code it seems that ModelCheckpoint does not allow this?

@ghost added the feature and help wanted labels on Dec 6, 2019
@williamFalcon
Contributor

Yup. I guess either return the epoch number as the thing to monitor, or we modify ModelCheckpoint to add this option.
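
For reference, the first option might look roughly like this. This is an illustrative sketch, not code from this thread: it assumes the 0.x-era behavior where values returned under the 'log' key end up in trainer.callback_metrics, and the names LitModel and epoch_num are made up.

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # expose the epoch number as a "metric" that ModelCheckpoint can monitor
        return {'loss': loss, 'log': {'epoch_num': self.current_epoch}}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

# 'epoch_num' only ever grows, so under mode='max' every epoch is a new "best";
# save_top_k=-1 keeps every saved file around
checkpoint_callback = ModelCheckpoint(
    filepath='./checkpoints',
    monitor='epoch_num',
    mode='max',
    save_top_k=-1,
)

Note that this still relies on the checkpoint callback being invoked at all, which (as the next comment points out) is tied to the validation loop in this version.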

@ghost
Author

ghost commented Dec 6, 2019

Hmm. I guess monitoring the epoch number could work, but I think some modifications would be needed to handle the case where no validation loop is initialized. What do you think?

@simonjaq

This would also be super important for me. I had a fairly complicated experiment running on an older version that relied on save_best_only=False to save every epoch without a validation step. I lost quite a bit of training before I realized it was no longer saving checkpoints. @williamFalcon is there a workaround, like adding an empty validation step?

@knuser

knuser commented Mar 5, 2020

I also need this functionality. @simonjaq did you find a workaround for this problem?

@simonjaq

simonjaq commented Mar 6, 2020

Hi. I made a custom checkpoint callback: I copied all the code but changed def on_validation_end(self): to def on_epoch_end(self):.
Then I call the checkpoint from the Lightning loop:

def on_epoch_end(self):
    trainer.checkpoint_callback.on_epoch_end()

This works quite well. Before starting the trainer I do this:

checkpoint_callback = ModelCheckpoint(
    filepath='./checkpoints/AD_15',
    save_top_k=10,
    monitor='g_loss',
    verbose=True,
    prefix='V0.13.8-RGB',
)

@jamesjjcondon
Contributor

Same here. One epoch takes me about 30 minutes, so I am only validating every 10 epochs, say, to save time, and therefore need to save every epoch without validating. @simonjaq can you point me in the right direction: which code did you copy? Was it callbacks.model_checkpoint, or like here?

@simonjaq

simonjaq commented Mar 15, 2020

Hello,
I took the whole code from /pytorch_lightning/callbacks/model_checkpoint.py and changed line 189, on_validation_end, to on_epoch_end:

"""
Callbacks
=========
Callbacks supported by Lightning
"""

import os
import shutil
import logging as log
import warnings

import numpy as np


class Callback(object):
    """Abstract base class used to build new callbacks."""

    def __init__(self):
        self._trainer = None

    def set_trainer(self, trainer):
        """Make a link to the trainer, so different things like `trainer.current_epoch`,
        `trainer.batch_idx`, `trainer.global_step` can be used."""
        self._trainer = trainer

    def on_epoch_begin(self):
        """Called when the epoch begins."""
        pass

    def on_epoch_end(self):
        """Called when the epoch ends."""
        pass

    def on_batch_begin(self):
        """Called when the training batch begins."""
        pass

    def on_batch_end(self):
        """Called when the training batch ends."""
        pass

    def on_train_begin(self):
        """Called when the train begins."""
        pass

    def on_train_end(self):
        """Called when the train ends."""
        pass

    def on_validation_begin(self):
        """Called when the validation loop begins."""
        pass

    def on_validation_end(self):
        """Called when the validation loop ends."""
        pass

    def on_test_begin(self):
        """Called when the test begins."""
        pass

    def on_test_end(self):
        """Called when the test ends."""
        pass


_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"


class ModelCheckpoint(Callback):
    r"""
    Save the model after every epoch.
    Args:
        filepath (str): path to save the model file.
            Can contain named formatting options to be auto-filled.
            Example::
                # save epoch and val_loss in name
                ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
                # saves file like: /path/epoch_2-val_loss_0.2.hdf5
        monitor (str): quantity to monitor.
        verbose (bool): verbosity mode, 0 or 1.
        save_top_k (int): if `save_top_k == k`,
            the best k models according to
            the quantity monitored will be saved.
            if `save_top_k == 0`, no models are saved.
            if `save_top_k == -1`, all models are saved.
            Please note that the monitors are checked every `period` epochs.
            if `save_top_k >= 2` and the callback is called multiple
            times inside an epoch, the name of the saved file will be
            appended with a version count starting with `v0`.
        mode (str): one of {auto, min, max}.
            If `save_top_k != 0`, the decision
            to overwrite the current save file is made
            based on either the maximization or the
            minimization of the monitored quantity. For `val_acc`,
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
        save_weights_only (bool): if True, then only the model's weights will be
            saved (`model.save_weights(filepath)`), else the full model
            is saved (`model.save(filepath)`).
        period (int): Interval (number of epochs) between checkpoints.
    Example::
        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import ModelCheckpoint
        checkpoint_callback = ModelCheckpoint(filepath='my_path')
        Trainer(checkpoint_callback=checkpoint_callback)
        # saves checkpoints to my_path whenever 'val_loss' has a new min
    """

    def __init__(self, filepath, monitor='val_loss', verbose=0,
                 save_top_k=1, save_weights_only=False,
                 mode='auto', period=1, prefix=''):
        super(ModelCheckpoint, self).__init__()
        if (
            save_top_k and
            os.path.isdir(filepath) and
            len(os.listdir(filepath)) > 0
        ):
            warnings.warn(
                f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                "All files in this directory will be deleted when a checkpoint is saved!"
            )

        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        os.makedirs(filepath, exist_ok=True)
        self.save_top_k = save_top_k
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_check = 0
        self.prefix = prefix
        self.best_k_models = {}
        # {filename: monitor}
        self.kth_best_model = ''
        self.best = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn(
                f'ModelCheckpoint mode {mode} is unknown, '
                'fallback to auto mode.', RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.kth_value = np.Inf
            self.mode = 'min'
        elif mode == 'max':
            self.monitor_op = np.greater
            self.kth_value = -np.Inf
            self.mode = 'max'
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.kth_value = -np.Inf
                self.mode = 'max'
            else:
                self.monitor_op = np.less
                self.kth_value = np.Inf
                self.mode = 'min'

    def _del_model(self, filepath):
        dirpath = os.path.dirname(filepath)

        # make paths
        os.makedirs(dirpath, exist_ok=True)

        try:
            shutil.rmtree(filepath)
        except OSError:
            os.remove(filepath)

    def _save_model(self, filepath):
        dirpath = os.path.dirname(filepath)

        # make paths
        os.makedirs(dirpath, exist_ok=True)

        # delegate the saving to the model
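        # (save_function is attached to this callback by the trainer before training)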
        self.save_function(filepath)

    def check_monitor_top_k(self, current):
        less_than_k_models = len(self.best_k_models.keys()) < self.save_top_k
        if less_than_k_models:
            return True
        return self.monitor_op(current, self.best_k_models[self.kth_best_model])

    def on_epoch_end(self):
        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG

        logs = self._trainer.callback_metrics
        epoch = self._trainer.current_epoch
        self.epochs_since_last_check += 1

        if self.save_top_k == 0:
            # no models are saved
            return
        if self.epochs_since_last_check >= self.period:
            self.epochs_since_last_check = 0
            filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
            version_cnt = 0
            while os.path.isfile(filepath):
                # this epoch called before
                filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
                version_cnt += 1

            if self.save_top_k != -1:
                current = logs.get(self.monitor)

                if current is None:
                    warnings.warn(
                        f'Can save best model only with {self.monitor} available,'
                        ' skipping.', RuntimeWarning)
                else:
                    if self.check_monitor_top_k(current):

                        # remove kth
                        if len(self.best_k_models.keys()) == self.save_top_k:
                            delpath = self.kth_best_model
                            self.best_k_models.pop(self.kth_best_model)
                            self._del_model(delpath)

                        self.best_k_models[filepath] = current
                        if len(self.best_k_models.keys()) == self.save_top_k:
                            # monitor dict has reached k elements
                            if self.mode == 'min':
                                self.kth_best_model = max(self.best_k_models, key=self.best_k_models.get)
                            else:
                                self.kth_best_model = min(self.best_k_models, key=self.best_k_models.get)
                            self.kth_value = self.best_k_models[self.kth_best_model]

                        if self.mode == 'min':
                            self.best = min(self.best_k_models.values())
                        else:
                            self.best = max(self.best_k_models.values())
                        if self.verbose > 0:
                            log.info(
                                f'\nEpoch {epoch:05d}: {self.monitor} reached'
                                f' {current:0.5f} (best {self.best:0.5f}), saving model to'
                                f' {filepath} as top {self.save_top_k}')
                        self._save_model(filepath)

                    else:
                        if self.verbose > 0:
                            log.info(
                                f'\nEpoch {epoch:05d}: {self.monitor}'
                                f' was not in top {self.save_top_k}')

            else:
                if self.verbose > 0:
                    log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
                self._save_model(filepath)

In my train loop I call the checkpoint:

class DiscriminatorGenerator(pl.LightningModule):

    ...  # all my training code

    def on_epoch_end(self):
        trainer.checkpoint_callback.on_epoch_end()

My training block looks like this:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger


checkpoint_callback = ModelCheckpoint(
    filepath='./checkpoints/AD_15',  
    save_top_k=10,  
    monitor='g_loss',
    #save_best_only=False,
    verbose=True,
    prefix='V0.13.8-RGB'
)



#default logger used by trainer
logger = TensorBoardLogger(
    save_dir='./logs',
    version=252,
    name='lightning_logs'
)

trainer = Trainer(
    logger=logger,
    min_epochs=2000,
    max_epochs=5000,
    checkpoint_callback=checkpoint_callback,
    amp_level='O1', use_amp=True,
    gpus=1,
    weights_summary='full'
)

This works for me. Note that I work in a Jupyter notebook and just insert the modified callback near the beginning of the notebook; importing the modified callback should work just as well.

@jamesjjcondon
Contributor

Hopefully this is an improvement @williamFalcon (but still doesn't allow saving all models independent of validation).

@jamesjjcondon
Contributor

jamesjjcondon commented Mar 18, 2020

Just for anyone else: I couldn't get the above to work (our pl versions are different, and it got messy putting the trainer into the model). I'm now saving every epoch, while still validating only every n > 1 epochs, using the custom callback below. It doesn't require adjusting callbacks/model_checkpoint.py. It is fairly hacky and redoes filenames, but it works.

class Non_val_epoch_saves(pl.Callback):
    def __init__(self, iteration, filepath):
        self.iteration = iteration
        self.filepath = filepath
        self.ver = int(self.iteration[-1])
        if any(self.iteration in x for x in os.listdir(self.filepath)):
            self.ver += 1

    def on_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        if 'avg_val_loss' in metrics:
            avl = metrics['avg_val_loss']
            avl = f'{avl:.3f}'
        else:
            avl = 'NA'
        tl = metrics.get(trainer.checkpoint_callback.monitor)
        current_tl = f'{tl:0.3f}'
        self.name = (
            self.iteration[:-1] + str(self.ver)
            + '_epo=' + str(trainer.current_epoch)
            + '_tloss=' + current_tl
            + '_avloss=' + avl
            + '.ckpt'
        )
        trainer.checkpoint_callback._save_model(
            filepath=os.path.join(self.filepath, self.name))

which is called like:

iteration = '18Mar_v0'

callback_dir = os.path.join(DATADIR, 'dev_test_models/ckpts_' + iteration + '/')

callback = ModelCheckpoint(
    filepath=callback_dir,
    monitor='loss',
    verbose=1,
    save_top_k=0,
    save_weights_only=False,
    mode='min',
    period=1,
    prefix='')

trainer = Trainer(
    accumulate_grad_batches=6,
    callbacks=[Non_val_epoch_saves(
        iteration=iteration,
        filepath=callback_dir)],
    checkpoint_callback=callback,
    check_val_every_n_epoch=2,
    distributed_backend='ddp')

@lizhitwo

lizhitwo commented Apr 24, 2020

What I did was:

class ModelCheckpointAtEpochEnd(pl.Callback):
    def on_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        metrics['epoch'] = trainer.current_epoch
        if trainer.disable_validation:
            trainer.checkpoint_callback.on_validation_end(trainer, pl_module)

Add this callback to the trainer too, and the existing checkpoint_callback will do its thing. It would be nice if this were built into Lightning.
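
For completeness, a sketch of the wiring. The filepath value is illustrative, monitor='epoch' matches the key the callback above injects into callback_metrics, and exact argument names vary across pl versions:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# 'epoch' only ever increases, so mode='max' makes every epoch a new "best";
# save_top_k=-1 keeps every checkpoint file
checkpoint_callback = ModelCheckpoint(
    filepath='./checkpoints/{epoch}',
    monitor='epoch',
    mode='max',
    save_top_k=-1,
)

trainer = pl.Trainer(
    checkpoint_callback=checkpoint_callback,
    callbacks=[ModelCheckpointAtEpochEnd()],
    max_epochs=10,
)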

@stale

stale bot commented Jun 23, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the won't fix label on Jun 23, 2020
stale bot closed this as completed on Jul 2, 2020
@andrewjong

Please reopen this issue.

@Skyy93
Contributor

Skyy93 commented Feb 18, 2021

I would appreciate this feature too

@ankit61

ankit61 commented May 4, 2021

I agree - this feature will be helpful!

@ananthsub
Contributor

ananthsub commented May 4, 2021

This is supported today inside the ModelCheckpoint callback (see the sketch below):

  • You can use save_top_k=-1 to save a new checkpoint whenever the callback is run
  • Or you can set save_last=True to save a checkpoint to the file last.ckpt (by default) whenever the checkpoint callback is run
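
A minimal sketch of the first option against the 1.3-era API (argument names such as dirpath and filename differ in older releases; check your installed version):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# monitor=None: no metric is required; save_top_k=-1 keeps every checkpoint
checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoints',
    filename='{epoch}',
    monitor=None,
    save_top_k=-1,
)

trainer = Trainer(callbacks=[checkpoint_callback], max_epochs=10)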

@ankit61

ankit61 commented May 4, 2021

I see - that's great
