From 615b2f736303ad3cd56b33111f69952255f01c0f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 10 Mar 2021 00:18:38 +0100
Subject: [PATCH] Improve DummyLogger (#6398)

* fix dummy logger

* docs

* update docs

* add changelog

* add none return annotation

* return empty string for name, version
---
 CHANGELOG.md                             |  3 +++
 pytorch_lightning/loggers/base.py        | 34 +++++++++++++-----------
 tests/loggers/test_base.py               |  9 +++++++
 tests/trainer/flags/test_fast_dev_run.py |  1 +
 4 files changed, 31 insertions(+), 16 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bdae336a18874..046e07cc55736 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -136,6 +136,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))
 
 
+- Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398))
+
+
 ## [1.2.2] - 2021-03-02
 
 ### Added
diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py
index 4fdb5e8c437bf..035a42338fe68 100644
--- a/pytorch_lightning/loggers/base.py
+++ b/pytorch_lightning/loggers/base.py
@@ -279,12 +279,14 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
         return params
 
     @abstractmethod
-    def log_hyperparams(self, params: argparse.Namespace):
+    def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs):
         """
         Record hyperparameters.
 
         Args:
             params: :class:`~argparse.Namespace` containing the hyperparameters
+            args: Optional positional arguments, depends on the specific logger being used
+            kwargs: Optional keyword arguments, depends on the specific logger being used
         """
 
     def log_graph(self, model: LightningModule, input_array=None) -> None:
@@ -418,41 +420,41 @@ def nop(*args, **kw):
     def __getattr__(self, _):
         return self.nop
 
-    def __getitem__(self, idx):
-        # enables self.logger[0].experiment.add_image
-        # and self.logger.experiment[0].add_image(...)
+    def __getitem__(self, idx) -> "DummyExperiment":
+        # enables self.logger.experiment[0].add_image(...)
         return self
 
 
 class DummyLogger(LightningLoggerBase):
-    """ Dummy logger for internal use. Is usefull if we want to disable users
-    logger for a feature, but still secure that users code can run """
+    """
+    Dummy logger for internal use. It is useful if we want to disable user's
+    logger for a feature, but still ensure that user code can run
+    """
 
     def __init__(self):
         super().__init__()
         self._experiment = DummyExperiment()
 
     @property
-    def experiment(self):
+    def experiment(self) -> DummyExperiment:
         return self._experiment
 
-    @rank_zero_only
-    def log_metrics(self, metrics, step):
+    def log_metrics(self, *args, **kwargs) -> None:
         pass
 
-    @rank_zero_only
-    def log_hyperparams(self, params):
+    def log_hyperparams(self, *args, **kwargs) -> None:
         pass
 
     @property
-    def name(self):
-        pass
+    def name(self) -> str:
+        return ""
 
     @property
-    def version(self):
-        pass
+    def version(self) -> str:
+        return ""
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx) -> "DummyLogger":
+        # enables self.logger[0].experiment.add_image(...)
         return self
 
 
diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py
index c48fef5e04b49..cf3a0cb74b3f4 100644
--- a/tests/loggers/test_base.py
+++ b/tests/loggers/test_base.py
@@ -229,15 +229,24 @@ def log_metrics(self, metrics, step):
 
 
 def test_dummyexperiment_support_indexing():
+    """ Test that the DummyExperiment can imitate indexing the experiment in a LoggerCollection. """
     experiment = DummyExperiment()
     assert experiment[0] == experiment
 
 
 def test_dummylogger_support_indexing():
+    """ Test that the DummyLogger can imitate indexing of a LoggerCollection. """
     logger = DummyLogger()
     assert logger[0] == logger
 
 
+def test_dummylogger_noop_method_calls():
+    """ Test that the DummyLogger methods can be called with arbitrary arguments. """
+    logger = DummyLogger()
+    logger.log_hyperparams("1", 2, three="three")
+    logger.log_metrics("1", 2, three="three")
+
+
 def test_np_sanitization():
     class CustomParamsLogger(CustomLogger):
diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py
index 09c5b58d363d9..9160d8d0f3d61 100644
--- a/tests/trainer/flags/test_fast_dev_run.py
+++ b/tests/trainer/flags/test_fast_dev_run.py
@@ -71,6 +71,7 @@ def test_step(self, batch, batch_idx):
     checkpoint_callback = ModelCheckpoint()
     early_stopping_callback = EarlyStopping()
     trainer_config = dict(
+        default_root_dir=tmpdir,
         fast_dev_run=fast_dev_run,
         val_check_interval=2,
         logger=True,
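
For reference, a minimal usage sketch of the patched DummyLogger (an illustration only, not part of the diff above; it assumes a PyTorch Lightning checkout with this patch applied, with DummyLogger and DummyExperiment importable from pytorch_lightning.loggers.base as in the hunks above):

# Illustrative sketch: every logging call on DummyLogger is a silent no-op,
# which is what lets fast_dev_run=True disable the user's logger without breaking user code.
from pytorch_lightning.loggers.base import DummyExperiment, DummyLogger

logger = DummyLogger()
logger.log_hyperparams({"lr": 1e-3}, "extra", key="value")  # accepts any arguments, does nothing
logger.log_metrics({"loss": 0.1}, step=0)                   # no-op as well
assert logger.name == "" and logger.version == ""           # now return empty strings instead of None
assert isinstance(logger.experiment, DummyExperiment)       # experiment is a DummyExperiment
assert logger[0] is logger                                   # indexing mimics a LoggerCollection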