diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a10d80b91678..97f5f038970c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for `log_model=best_and_last` option in `pytorch_lightning.loggers.WandbLogger`. + + - Register `ShardedTensor` state dict hooks in `LightningModule.__init__` if the pytorch version supports `ShardedTensor` ([#8944](https://github.com/PyTorchLightning/pytorch-lightning/pull/8944)) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index c93b8d02bca16..8b19097044ae2 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -64,6 +64,8 @@ class WandbLogger(LightningLoggerBase): as W&B artifacts. * if ``log_model == 'all'``, checkpoints are logged during training. + * if ``log_model == 'best_and_last'``, checkpoints are logged during training and only the best and the + last checkpoints (according to ModelCheckpoint) are kept as wandb artifacts. Previous versions are automatically deleted. * if ``log_model == True``, checkpoints are logged at the end of training, except when :paramref:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint.save_top_k` ``== -1`` which also logs every checkpoint during training. 
@@ -248,7 +250,11 @@ def version(self) -> Optional[str]: def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None: # log checkpoints as artifacts - if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1: + if ( + self._log_model == "all" + or (self._log_model is True and checkpoint_callback.save_top_k == -1) + or self._log_model == "best_and_last" + ): self._scan_and_log_checkpoints(checkpoint_callback) elif self._log_model is True: self._checkpoint_callback = checkpoint_callback @@ -301,3 +307,13 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[ModelChe self.experiment.log_artifact(artifact, aliases=aliases) # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) self._logged_model_time[p] = t + + if self._log_model == "best_and_last": + # Clean up previous artifacts. + # Adapted from https://gitbook-docs.wandb.ai/guides/artifacts/api#cleaning-up-unused-versions + api = wandb.Api(overrides={"project": self.experiment.project}) + + for version in api.artifact_versions("model", f"model-{self.experiment.id}"): + # Clean up all versions that don't have an alias such as 'latest'. 
+ if len(version.aliases) == 0: + version.delete() diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 8388d7877ab7e..bf33c1a974fc7 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -176,6 +176,16 @@ def test_wandb_log_model(wandb, tmpdir): trainer.fit(model) assert wandb.init().log_artifact.call_count == 2 + # test log_model='best_and_last' + wandb.init().log_artifact.reset_mock() + wandb.init.reset_mock() + logger = WandbLogger(log_model="best_and_last") + logger.experiment.id = "1" + logger.experiment.project = "project" + trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) + trainer.fit(model) + assert wandb.init().log_artifact.call_count == 2 + # test log_model=False wandb.init().log_artifact.reset_mock() wandb.init.reset_mock() @@ -203,7 +213,7 @@ def test_wandb_log_model(wandb, tmpdir): type="model", metadata={ "score": None, - "original_filename": "epoch=1-step=5-v3.ckpt", + "original_filename": "epoch=1-step=5-v4.ckpt", "ModelCheckpoint": { "monitor": None, "mode": "min", diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py index c37681e4831ca..d21e8efc7a5cb 100644 --- a/tests/trainer/loops/test_training_loop.py +++ b/tests/trainer/loops/test_training_loop.py @@ -191,7 +191,7 @@ def training_epoch_end(self, outputs) -> None: def test_batch_loop_releases_loss(tmpdir): - """Test that loss/graph is released so that it can be garbage collected before the next training step""" + """Test that loss/graph is released so that it can be garbage collected before the next training step.""" class TestModel(BoringModel): def training_step(self, batch, batch_idx):