Added support for log_model='best_and_last' option in wandb logger #9356

Closed
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added


- Added support for the `log_model='best_and_last'` option in `pytorch_lightning.loggers.WandbLogger` ([#9356](https://github.com/PyTorchLightning/pytorch-lightning/pull/9356))


- Register `ShardedTensor` state dict hooks in `LightningModule.__init__` if the pytorch version supports `ShardedTensor` ([#8944](https://github.com/PyTorchLightning/pytorch-lightning/pull/8944))


18 changes: 17 additions & 1 deletion pytorch_lightning/loggers/wandb.py
@@ -64,6 +64,8 @@ class WandbLogger(LightningLoggerBase):
as W&B artifacts.

* if ``log_model == 'all'``, checkpoints are logged during training.
* if ``log_model == 'best_and_last'``, checkpoints are logged during training, and only the best and the
last checkpoints (according to ``ModelCheckpoint``) are kept as W&B artifacts; previous versions are automatically deleted.
* if ``log_model == True``, checkpoints are logged at the end of training, except when
:paramref:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint.save_top_k` ``== -1``
which also logs every checkpoint during training.
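A minimal usage sketch of the new option (the project name, monitored metric, and model are placeholders for illustration, not part of this PR):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

# Placeholder project/metric names; any ModelCheckpoint configuration works.
logger = WandbLogger(project="my-project", log_model="best_and_last")
checkpoint = ModelCheckpoint(monitor="val_loss", mode="min")
trainer = Trainer(logger=logger, callbacks=[checkpoint], max_epochs=10)
# trainer.fit(model)  # only the "best" and "latest" checkpoint artifacts remain on W&B
```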
@@ -248,7 +250,11 @@ def version(self) -> Optional[str]:

def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
# log checkpoints as artifacts
if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
if (
self._log_model == "all"
or (self._log_model is True and checkpoint_callback.save_top_k == -1)
or self._log_model == "best_and_last"
):
self._scan_and_log_checkpoints(checkpoint_callback)
elif self._log_model is True:
self._checkpoint_callback = checkpoint_callback
@@ -301,3 +307,13 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
self.experiment.log_artifact(artifact, aliases=aliases)
# remember logged models - timestamp needed in case filename didn't change (last.ckpt or custom name)
self._logged_model_time[p] = t

if self._log_model == "best_and_last":
# Clean up previous artifacts.
# Adapted from https://gitbook-docs.wandb.ai/guides/artifacts/api#cleaning-up-unused-versions
api = wandb.Api(overrides={"project": self.experiment.project})

for version in api.artifact_versions(f"model-{self.experiment.id}", "model"):
# Clean up all versions that don't have an alias such as 'latest'.
if len(version.aliases) == 0:
version.delete()
Comment on lines +314 to +319
Contributor

Noob question, I am not familiar with the wandb API.

It seems models are being versioned and you are deleting all versions which don't have either the latest or best alias.

I am not sure I grasp why this would save only the best and last model weights.

Furthermore, I don't think this would work for multiple ModelCheckpoint callbacks. Should we save the monitor as metadata to perform the filtering?

best_and_last should produce at most num_model_checkpoints + 1 checkpoints, right?

Contributor

We automatically tag some artifact versions (model checkpoints). We tag the "latest", and we tag the "best" when a monitored value is defined (they can point to the same model), so there are at most 2 tagged versions.

Author

@borisdayma Yeah, exactly. The implementation relies on the fact that there are at most 2 live aliases at the same time.
Several aliases can probably be included (best_0, best_1, best_2, ..., latest), but this is another PR :)

Contributor

I believe best_{monitor_0}, best_{monitor_1}, best_{monitor_2} would be better; it would let users navigate their weights more easily in the Wandb UI.

Contributor

The idea is that we directly leverage ModelCheckpoint to identify best metrics (easier to maintain the callback, avoid replicating the same logic, and maybe easier for users).
You can see an example here.
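To make the cleanup step above concrete, here is a minimal standalone sketch adapted from the W&B docs linked in the diff; the project name and run id are placeholders, and the argument order (artifact type first, then artifact name) follows those docs:

```python
import wandb

# Placeholder project name; inside the logger this comes from the active run.
api = wandb.Api(overrides={"project": "my-project"})

# Iterate over every version of the run's checkpoint artifact and prune the
# ones that no longer carry an alias such as "latest" or "best".
for version in api.artifact_versions("model", "model-<run_id>"):
    if len(version.aliases) == 0:
        version.delete()
```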

12 changes: 11 additions & 1 deletion tests/loggers/test_wandb.py
@@ -176,6 +176,16 @@ def test_wandb_log_model(wandb, tmpdir):
trainer.fit(model)
assert wandb.init().log_artifact.call_count == 2

# test log_model='best_and_last'
wandb.init().log_artifact.reset_mock()
wandb.init.reset_mock()
logger = WandbLogger(log_model="best_and_last")
logger.experiment.id = "1"
logger.experiment.project_name.return_value = "project"
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
Contributor

Should this test use a ModelCheckpoint?

Author

Yup. We should definitely build a better test. I am just not familiar with the whole "mocking" thing, or with tests in general. Should my test function also be decorated with mocking?

If not, then this requires some default wandb setup to run the experiment. Not sure how to get this right, as I have never written tests for a project as large as Lightning :)

trainer.fit(model)
assert wandb.init().log_artifact.call_count == 2
Member

This does not test that the other ones are properly removed. Can we also test/mock this somehow?

Author (@ohayonguy, Sep 7, 2021)

For sure. I will figure out how this should be properly tested.
Do you think that this should be done with mocking? @justusschock

Contributor

Maybe you should train for 5 epochs in this test.
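One possible direction, sketched here under the assumption that `wandb` is the mock injected by this test file's existing fixtures; the helper name and alias values are hypothetical:

```python
from unittest import mock

def assert_best_and_last_cleanup(wandb):
    # Hypothetical helper: pretend the W&B API reports one aliased and one
    # unaliased artifact version for the run.
    kept = mock.MagicMock(aliases=["latest", "best"])  # should survive cleanup
    stale = mock.MagicMock(aliases=[])                 # should be deleted
    wandb.Api.return_value.artifact_versions.return_value = [kept, stale]

    # ... run trainer.fit(...) with WandbLogger(log_model="best_and_last") here ...

    stale.delete.assert_called_once()
    kept.delete.assert_not_called()
```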


# test log_model=False
wandb.init().log_artifact.reset_mock()
wandb.init.reset_mock()
@@ -203,7 +213,7 @@ def test_wandb_log_model(wandb, tmpdir):
type="model",
metadata={
"score": None,
"original_filename": "epoch=1-step=5-v3.ckpt",
"original_filename": "epoch=1-step=5-v4.ckpt",
"ModelCheckpoint": {
"monitor": None,
"mode": "min",
2 changes: 1 addition & 1 deletion tests/trainer/loops/test_training_loop.py
@@ -191,7 +191,7 @@ def training_epoch_end(self, outputs) -> None:


def test_batch_loop_releases_loss(tmpdir):
"""Test that loss/graph is released so that it can be garbage collected before the next training step"""
"""Test that loss/graph is released so that it can be garbage collected before the next training step."""

class TestModel(BoringModel):
def training_step(self, batch, batch_idx):