From 89d0d7b40cd6dae66672987bc9125350a239c89b Mon Sep 17 00:00:00 2001 From: ohayonguy Date: Tue, 7 Sep 2021 11:25:08 +0300 Subject: [PATCH 1/4] Added support for log_model='best_and_last' option in wandb logger --- pytorch_lightning/loggers/wandb.py | 15 ++++++++++++++- tests/loggers/test_wandb.py | 10 ++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index c93b8d02bca16..9179e8ebb07d5 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -64,6 +64,8 @@ class WandbLogger(LightningLoggerBase): as W&B artifacts. * if ``log_model == 'all'``, checkpoints are logged during training. + * if ``log_model == 'best_and_last'``, checkpoints are logged during training and only the best and the + last checkpoints (according to ModelCheckpoint) are kept as wandb artifacts. Previous versions are automatically deleted. * if ``log_model == True``, checkpoints are logged at the end of training, except when :paramref:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint.save_top_k` ``== -1`` which also logs every checkpoint during training. @@ -248,7 +250,8 @@ def version(self) -> Optional[str]: def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None: # log checkpoints as artifacts - if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1: + if self._log_model == "all" or (self._log_model is True and checkpoint_callback.save_top_k == -1) \ + or self._log_model == 'best_and_last': self._scan_and_log_checkpoints(checkpoint_callback) elif self._log_model is True: self._checkpoint_callback = checkpoint_callback @@ -301,3 +304,13 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[ModelChe self.experiment.log_artifact(artifact, aliases=aliases) # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) self._logged_model_time[p] = t + + if self._log_model == 'best_and_last': + # Clean up previous artifacts. + # Adapted from https://gitbook-docs.wandb.ai/guides/artifacts/api#cleaning-up-unused-versions + api = wandb.Api(overrides={'project': self.experiment.project}) + + for version in api.artifact_versions(f"model-{self.experiment.id}", "model"): + # Clean up all versions that don't have an alias such as 'latest'. + if len(version.aliases) == 0: + version.delete() diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 8388d7877ab7e..bbd9ad99c95f4 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -176,6 +176,16 @@ def test_wandb_log_model(wandb, tmpdir): trainer.fit(model) assert wandb.init().log_artifact.call_count == 2 + # test log_model='best_and_last' + wandb.init().log_artifact.reset_mock() + wandb.init.reset_mock() + logger = WandbLogger(log_model="best_and_last") + logger.experiment.id = "1" + logger.experiment.project_name.return_value = "project" + trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) + trainer.fit(model) + assert wandb.init().log_artifact.call_count == 2 + # test log_model=False wandb.init().log_artifact.reset_mock() wandb.init.reset_mock() From 55bdc95e46395179815aeed3ccfe84bf549ee10b Mon Sep 17 00:00:00 2001 From: ohayonguy Date: Tue, 7 Sep 2021 11:47:57 +0300 Subject: [PATCH 2/4] fixed test_wandb.py --- tests/loggers/test_wandb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index bbd9ad99c95f4..bf33c1a974fc7 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -213,7 +213,7 @@ def test_wandb_log_model(wandb, tmpdir): type="model", metadata={ "score": None, - "original_filename": "epoch=1-step=5-v3.ckpt", + "original_filename": "epoch=1-step=5-v4.ckpt", "ModelCheckpoint": { "monitor": None, "mode": "min", From 723a47bf70e77735482caf8967562106dd8b92c0 Mon Sep 17 00:00:00 2001 From: ohayonguy Date: Tue, 7 Sep 2021 11:51:51 +0300 Subject: [PATCH 3/4] updated changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a10d80b91678..97f5f038970c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for `log_model=best_and_last` option in `pytorch_lightning.loggers.WandbLogger`. + + - Register `ShardedTensor` state dict hooks in `LightningModule.__init__` if the pytorch version supports `ShardedTensor` ([#8944](https://github.com/PyTorchLightning/pytorch-lightning/pull/8944)) From e0b0fa7b848da57896ecd8c0286ccbcd9a0b47f3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Sep 2021 08:53:40 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loggers/wandb.py | 11 +++++++---- tests/trainer/loops/test_training_loop.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 9179e8ebb07d5..8b19097044ae2 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -250,8 +250,11 @@ def version(self) -> Optional[str]: def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None: # log checkpoints as artifacts - if self._log_model == "all" or (self._log_model is True and checkpoint_callback.save_top_k == -1) \ - or self._log_model == 'best_and_last': + if ( + self._log_model == "all" + or (self._log_model is True and checkpoint_callback.save_top_k == -1) + or self._log_model == "best_and_last" + ): self._scan_and_log_checkpoints(checkpoint_callback) elif self._log_model is True: self._checkpoint_callback = checkpoint_callback @@ -305,10 +308,10 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[ModelChe # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) self._logged_model_time[p] = t - if self._log_model == 'best_and_last': + if self._log_model == "best_and_last": # Clean up previous artifacts. # Adapted from https://gitbook-docs.wandb.ai/guides/artifacts/api#cleaning-up-unused-versions - api = wandb.Api(overrides={'project': self.experiment.project}) + api = wandb.Api(overrides={"project": self.experiment.project}) for version in api.artifact_versions(f"model-{self.experiment.id}", "model"): # Clean up all versions that don't have an alias such as 'latest'. diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py index c37681e4831ca..d21e8efc7a5cb 100644 --- a/tests/trainer/loops/test_training_loop.py +++ b/tests/trainer/loops/test_training_loop.py @@ -191,7 +191,7 @@ def training_epoch_end(self, outputs) -> None: def test_batch_loop_releases_loss(tmpdir): - """Test that loss/graph is released so that it can be garbage collected before the next training step""" + """Test that loss/graph is released so that it can be garbage collected before the next training step.""" class TestModel(BoringModel): def training_step(self, batch, batch_idx):