# Added support for log_model='best_and_last' option in wandb logger #9356
```diff
@@ -176,6 +176,16 @@ def test_wandb_log_model(wandb, tmpdir):
     trainer.fit(model)
     assert wandb.init().log_artifact.call_count == 2

+    # test log_model='best_and_last'
+    wandb.init().log_artifact.reset_mock()
+    wandb.init.reset_mock()
+    logger = WandbLogger(log_model="best_and_last")
+    logger.experiment.id = "1"
+    logger.experiment.project_name.return_value = "project"
+    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
```
Review thread on the `trainer = Trainer(...)` line:

**Review comment:** Should this test use a `ModelCheckpoint`?

**Reply:** Yup, we should definitely build a better test. I am just not familiar with the whole "mocking" thing, or with tests in general. Should my test function also be decorated for mocking? If not, it would require some default wandb setup to run the experiment. I'm not sure how to get this right, as I've never written tests for a project as large as Lightning :)
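For what it's worth, the surrounding tests in this file already patch the `wandb` module with a decorator, so no real wandb setup is needed; a minimal sketch of that pattern (the details below are assumptions, not the final test):

```python
import unittest.mock as mock

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger


# Patching the wandb module imported by the logger turns wandb.init() and
# everything reached through it into MagicMocks: no account, network, or
# local wandb configuration is required.
@mock.patch("pytorch_lightning.loggers.wandb.wandb")
def test_wandb_log_model_best_and_last(wandb, tmpdir):
    logger = WandbLogger(log_model="best_and_last")
    logger.experiment.id = "1"  # attributes can be set freely on a mock
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2)
    # ... fit a model here, then assert on wandb.init().log_artifact calls
```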
The new test continues:

```diff
+    trainer.fit(model)
+    assert wandb.init().log_artifact.call_count == 2
```
Review thread on the `assert` line:

**Review comment:** This does not test that the other versions are properly removed. Can we also test/mock this somehow?

**Reply:** For sure. I will figure out how this should be properly tested.

**Review comment:** Maybe you should train for 5 epochs in this test.
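One possible approach, assuming the cleanup lists artifact versions through `wandb.Api()` and deletes the untagged ones (a sketch against the mocked `wandb` argument, not the PR's actual code):

```python
# Pretend the run has two artifact versions: one stale, one still tagged.
stale = mock.MagicMock(aliases=[])
kept = mock.MagicMock(aliases=["latest", "best"])
wandb.Api.return_value.artifact_versions.return_value = [stale, kept]

trainer.fit(model)

# Only the version without aliases should have been deleted.
stale.delete.assert_called_once()
kept.delete.assert_not_called()
```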
…followed by the existing `log_model=False` block:

```diff
+
     # test log_model=False
     wandb.init().log_artifact.reset_mock()
     wandb.init.reset_mock()
```
```diff
@@ -203,7 +213,7 @@ def test_wandb_log_model(wandb, tmpdir):
         type="model",
         metadata={
             "score": None,
-            "original_filename": "epoch=1-step=5-v3.ckpt",
+            "original_filename": "epoch=1-step=5-v4.ckpt",
             "ModelCheckpoint": {
                 "monitor": None,
                 "mode": "min",
```
**Review comment:** Noob question, I am not familiar with the wandb API. It seems models are being versioned, and you are deleting all versions that have neither the `latest` nor the `best` alias. I am not sure I grasp why this would save only the best and last model weights. Furthermore, I don't think this would work with multiple `ModelCheckpoint` callbacks; should we save the monitor as metadata to perform the filtering? `best_and_last` should produce at most `num_model_checkpoints + 1` checkpoints, right?
**Reply:** We automatically tag some artifact versions (model checkpoints): "latest" always, and "best" when a monitored value is defined (they can point to the same model). So at most 2 versions are tagged at any time.
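For context, deleting untagged versions through the wandb public API looks roughly like this (a sketch; the artifact name below is a placeholder, since the logger derives it from the run):

```python
import wandb

api = wandb.Api()
# "my-project/model-1abcd234" is a placeholder; WandbLogger names its
# checkpoint artifact after the run it logs to.
for version in api.artifact_versions("model", "my-project/model-1abcd234"):
    # A version with no alias is neither "latest" nor "best",
    # so only the best and last checkpoints survive this sweep.
    if len(version.aliases) == 0:
        version.delete()
```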
**Reply:** @borisdayma Yeah, exactly. The implementation relies on the fact that there are at most 2 live aliases at the same time. Several aliases could probably be supported (`best_0`, `best_1`, `best_2`, ..., `latest`), but that is another PR :)
**Review comment:** I believe `best_{monitor_0}`, `best_{monitor_1}`, `best_{monitor_2}` would be better; it would enable users to navigate their weights more easily in the wandb UI.
**Reply:** The idea is that we directly leverage `ModelCheckpoint` to identify the best metric (easier to maintain the callback, avoids replicating the same logic, and maybe easier for users). You can see an example here.
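As a usage sketch (the monitored metric name is illustrative), the option from this PR would pair with a `ModelCheckpoint` like so:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

# ModelCheckpoint defines what "best" means through its monitored metric;
# the logger mirrors that decision as the "best" artifact alias.
checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min")
logger = WandbLogger(log_model="best_and_last")  # option added by this PR
trainer = Trainer(logger=logger, callbacks=[checkpoint_cb], max_epochs=5)
```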