Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Enable self.log in callbacks #5094

Merged
merged 6 commits into from
Dec 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,8 @@ def configure_progress_bar(self, refresh_rate=1, process_position=0):

def _trainer_has_checkpoint_callbacks(self):
return len(self.trainer.checkpoint_callbacks) > 0

def attach_model_logging_functions(self, model):
for callback in self.trainer.callbacks:
callback.log = model.log
tchaton marked this conversation as resolved.
Show resolved Hide resolved
callback.log_dict = model.log_dict
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: should we also allow for callbacks to write predictions aka self.write_prediction and self.write_prediction_dict?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but within another PR :)

3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
# check that model is configured correctly
self.trainer.config_validator.verify_loop_configurations(model)

# attach model log function to callback
self.trainer.callback_connector.attach_model_logging_functions(model)

def setup_training(self, model: LightningModule):
"""Sanity check a few things before starting actual training.

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
# the goal of the project is simplicity for researchers, don't want to add too much
# engineer specific practices
setup(
name="pytorch-lightning",
name="pytorch-lightning-nightly",
version=pytorch_lightning.__version__,
description=pytorch_lightning.__docs__,
author=pytorch_lightning.__author__,
Expand Down
45 changes: 45 additions & 0 deletions tests/trainer/logging_tests/test_train_loop_logging_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,3 +771,48 @@ def on_train_epoch_end(self, *_):
trainer.fit(model)
assert model.epoch_end_called
assert model.on_train_epoch_end_called


def test_logging_in_callbacks_with_log_function(tmpdir):
"""
Tests ensure self.log can be used directly in callbacks.
"""
class LoggingCallback(callbacks.Callback):
def on_train_start(self, trainer, pl_module):
self.log("on_train_start", 1)

def on_train_epoch_start(self, trainer, pl_module):
self.log("on_train_epoch_start", 2)

def on_batch_end(self, trainer, pl_module):
self.log("on_batch_end", 3)

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self.log("on_train_batch_end", 4)

def on_epoch_end(self, trainer, pl_module):
self.log("on_epoch_end", 5)

def on_train_epoch_end(self, trainer, pl_module, outputs):
self.log("on_train_epoch_end", 6)
self.callback_metrics = trainer.logger_connector.callback_metrics

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=1,
max_epochs=1,
weights_summary=None,
callbacks=[LoggingCallback()]
)
trainer.fit(model)

expected = {
'on_train_start': 1,
'on_train_epoch_start': 2,
'on_batch_end': 3,
'on_train_batch_end': 4,
'on_epoch_end': 5,
'on_train_epoch_end': 6}
assert trainer.callback_metrics == expected