From 85ea319a6a08129244f2e02154dd534912134c04 Mon Sep 17 00:00:00 2001 From: chaton <thomas@grid.ai> Date: Mon, 28 Dec 2020 15:34:18 +0100 Subject: [PATCH 01/16] Trainer.test should return only test metrics (#5214) * resolve bug * merge tests --- .../trainer/connectors/logger_connector/logger_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 595a5e84bf630..0afb08d120d67 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from copy import deepcopy +import os from pprint import pprint from typing import Dict, Iterable, Union From 5c112f045d441323cb675254a02c4a56ed8123c6 Mon Sep 17 00:00:00 2001 From: Tadej Svetina <tadej.svetina@gmail.com> Date: Tue, 29 Dec 2020 22:09:10 +0100 Subject: [PATCH 02/16] Fix metric state reset (#5273) * Fix metric state reset * Fix test * Improve formatting Co-authored-by: Ananya Harsh Jha <ananya@pytorchlightning.ai> --- tests/metrics/test_metric.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index 03b79633e3eb7..00aaefcabd5c5 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -41,6 +41,20 @@ def compute(self): pass +class DummyList(Metric): + name = "DummyList" + + def __init__(self): + super().__init__() + self.add_state("x", list(), dist_reduce_fx=None) + + def update(self): + pass + + def compute(self): + pass + + def test_inherit(): Dummy() From 2072b61ac1506282e60b2fd0477f9da4ade08402 Mon Sep 17 00:00:00 2001 From: Alexander Snorkin <Alexander.Snorkin@acronis.com> Date: Sun, 10 Jan 2021 10:54:01 +0300 Subject: [PATCH 03/16] print() method added to ProgressBar --- pytorch_lightning/callbacks/progress.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 7de7982b4a2de..bcc36f86b4c60 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -426,6 +426,20 @@ def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, da def on_predict_end(self, trainer, pl_module): self.predict_progress_bar.close() + def print(self, *args, sep=' ', end='\n', file=None, nolock=False): + active_progress_bar = None + + if not self.main_progress_bar.disable: + active_progress_bar = self.main_progress_bar + elif not self.val_progress_bar.disable: + active_progress_bar = self.val_progress_bar + elif not self.test_progress_bar.disable: + active_progress_bar = self.test_progress_bar + + if active_progress_bar is not None: + s = sep.join(map(str, args)) + active_progress_bar.write(s, end=end, file=file, nolock=nolock) + def _should_update(self, current, total): return self.is_enabled and (current % self.refresh_rate == 0 or current == total) From 539e2e660db09987fcac8046f18389055ab3d671 Mon Sep 17 00:00:00 2001 From: Alexander Snorkin <Alexander.Snorkin@acronis.com> Date: Sun, 10 Jan 2021 10:58:38 +0300 Subject: [PATCH 04/16] printing alongside progress bar added to LightningModule.print() --- pytorch_lightning/core/lightning.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 9c87836b4415a..a3776c4493487 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -194,7 +194,11 @@ def forward(self, x): """ if self.trainer.is_global_zero: - print(*args, **kwargs) + progress_bar = self.trainer.progress_bar_callback + if progress_bar is not None and progress_bar.is_enabled: + progress_bar.print(*args, **kwargs) + else: + print(*args, **kwargs) def log( self, From 023993b323954b676042409ee2768d970ebf5a27 Mon Sep 17 00:00:00 2001 From: Alexander Snorkin <Alexander.Snorkin@acronis.com> Date: Sun, 10 Jan 2021 11:09:01 +0300 Subject: [PATCH 05/16] LightningModule.print() method documentation updated --- pytorch_lightning/core/lightning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a3776c4493487..cf8666bfd7b44 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -184,8 +184,8 @@ def print(self, *args, **kwargs) -> None: Prints only from process 0. Use this in any distributed mode to log only once. Args: - *args: The thing to print. Will be passed to Python's built-in print function. - **kwargs: Will be passed to Python's built-in print function. + *args: The thing to print. The same as for Python's built-in print function. + **kwargs: The same as for Python's built-in print function. Example:: From dbba2d759ad03f519359eaf6583ce6b9226cf99b Mon Sep 17 00:00:00 2001 From: Alexander Snorkin <Alexander.Snorkin@acronis.com> Date: Fri, 15 Jan 2021 23:42:16 +0300 Subject: [PATCH 06/16] ProgressBarBase.print() stub added --- pytorch_lightning/callbacks/progress.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index bcc36f86b4c60..e1c676e686980 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -162,6 +162,12 @@ def enable(self): """ raise NotImplementedError + def print(self, *args): + """ + You should provide a way to print without breaking the progress bar. + """ + raise NotImplementedError + def on_init_end(self, trainer): self._trainer = trainer From 1bfc866b0e9fd87c8d898dfa86a2590520ccb037 Mon Sep 17 00:00:00 2001 From: rohitgr7 <rohitgr1998@gmail.com> Date: Tue, 19 Jan 2021 22:58:59 +0530 Subject: [PATCH 07/16] stub --- pytorch_lightning/callbacks/progress.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index e1c676e686980..048af0f190480 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -168,6 +168,12 @@ def print(self, *args): """ raise NotImplementedError + def print(self, *args): + """ + You should provide a way to print without breaking the progress bar. + """ + raise NotImplementedError + def on_init_end(self, trainer): self._trainer = trainer From 0ee4abc554ca698633e66f9f219a5b83f20629be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= <aedu.waelchli@gmail.com> Date: Wed, 20 Jan 2021 02:14:01 +0100 Subject: [PATCH 08/16] add progress bar tests --- tests/callbacks/test_progress_bar.py | 69 +++++++++++++++++++++++++++- 1 file changed, 68 insertions(+), 1 deletion(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 8398aec88fe68..90bedba7f5046 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import sys from unittest import mock -from unittest.mock import call, Mock +from unittest.mock import call, Mock, ANY import pytest import torch @@ -371,3 +372,69 @@ def training_step(self, batch, batch_idx): pbar = trainer.progress_bar_callback.main_progress_bar actual = str(pbar.postfix) assert actual.endswith("foo=0.123, bar={'baz': tensor([1])}") + + +class PrintModel(BoringModel): + + def training_step(self, *args, **kwargs): + self.print("training_step", end="") + return super().training_step(*args, **kwargs) + + def validation_step(self, *args, **kwargs): + self.print("validation_step", file=sys.stderr) + return super().validation_step(*args, **kwargs) + + def test_step(self, *args, **kwargs): + self.print("test_step") + return super().test_step(*args, **kwargs) + + +@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write") +def test_progress_bar_print(tqdm_write, tmpdir): + """ Test that printing in the LightningModule redirects arguments to the progress bar. """ + model = PrintModel() + bar = ProgressBar() + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + max_steps=1, + callbacks=[bar], + ) + trainer.fit(model) + trainer.test(model) + assert tqdm_write.call_count == 3 + assert tqdm_write.call_args_list == [ + call("training_step", end="", file=None, nolock=False), + call("validation_step", end="\n", file=sys.stderr, nolock=False), + call("test_step", end="\n", file=None, nolock=False), + ] + + +@mock.patch('builtins.print') +@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write") +def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): + """ Test that printing in LightningModule goes through built-in print functin when progress bar is disabled. """ + model = PrintModel() + bar = ProgressBar() + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + max_steps=1, + callbacks=[bar], + ) + bar.disable() + trainer.fit(model) + trainer.test(model) + + mock_print.assert_has_calls([ + call("training_step", end=""), + call("validation_step", file=ANY), + call("test_step"), + ]) + tqdm_write.assert_not_called() From a70b7c07de059ffca33bba6b599886ab70da313c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= <aedu.waelchli@gmail.com> Date: Wed, 20 Jan 2021 02:15:18 +0100 Subject: [PATCH 09/16] fix isort --- tests/callbacks/test_progress_bar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 90bedba7f5046..aa614cbaaf509 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -14,7 +14,7 @@ import os import sys from unittest import mock -from unittest.mock import call, Mock, ANY +from unittest.mock import ANY, call, Mock import pytest import torch From a13a2187b593230d2bfe8bc91a5a30eb339edbc8 Mon Sep 17 00:00:00 2001 From: Alexander Snorkin <Alexander.Snorkin@acronis.com> Date: Wed, 3 Feb 2021 12:20:47 +0300 Subject: [PATCH 10/16] Progress Callback fixes --- pytorch_lightning/callbacks/progress.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 048af0f190480..db563dc0854f5 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -19,7 +19,10 @@ """ import importlib +import io +import os import sys +from typing import Optional # check if ipywidgets is installed before importing tqdm.auto # to ensure it won't fail and a progress bar is displayed @@ -168,11 +171,11 @@ def print(self, *args): """ raise NotImplementedError - def print(self, *args): + def print(self, *args, **kwargs): """ You should provide a way to print without breaking the progress bar. """ - raise NotImplementedError + print(*args, **kwargs) def on_init_end(self, trainer): self._trainer = trainer @@ -438,7 +441,9 @@ def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, da def on_predict_end(self, trainer, pl_module): self.predict_progress_bar.close() - def print(self, *args, sep=' ', end='\n', file=None, nolock=False): + def print( + self, *args, sep: str = ' ', end: str = os.linesep, file: Optional[io.TextIOBase] = None, nolock: bool = False + ): active_progress_bar = None if not self.main_progress_bar.disable: From 19411f772a98119300e98570b919d960191e0dbe Mon Sep 17 00:00:00 2001 From: Alexander Snorkin <Alexander.Snorkin@acronis.com> Date: Mon, 8 Feb 2021 17:25:52 +0300 Subject: [PATCH 11/16] test_metric.py duplicate DummyList removed --- tests/metrics/test_metric.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index 00aaefcabd5c5..03b79633e3eb7 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -41,20 +41,6 @@ def compute(self): pass -class DummyList(Metric): - name = "DummyList" - - def __init__(self): - super().__init__() - self.add_state("x", list(), dist_reduce_fx=None) - - def update(self): - pass - - def compute(self): - pass - - def test_inherit(): Dummy() From 68be2b846fcb9d1c364facae62846acb8dca0e92 Mon Sep 17 00:00:00 2001 From: Alexander Snorkin <Alexander.Snorkin@acronis.com> Date: Mon, 8 Feb 2021 17:52:34 +0300 Subject: [PATCH 12/16] PEP and isort fixes --- pytorch_lightning/callbacks/progress.py | 7 ------- .../connectors/logger_connector/logger_connector.py | 2 +- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index db563dc0854f5..85ffe4be2c869 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -22,7 +22,6 @@ import io import os import sys -from typing import Optional # check if ipywidgets is installed before importing tqdm.auto # to ensure it won't fail and a progress bar is displayed @@ -165,12 +164,6 @@ def enable(self): """ raise NotImplementedError - def print(self, *args): - """ - You should provide a way to print without breaking the progress bar. - """ - raise NotImplementedError - def print(self, *args, **kwargs): """ You should provide a way to print without breaking the progress bar. diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 0afb08d120d67..595a5e84bf630 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy import os +from copy import deepcopy from pprint import pprint from typing import Dict, Iterable, Union From ab908cfafd81b69be9b360cfb3328b5e1a38866d Mon Sep 17 00:00:00 2001 From: Alexander Snorkin <Alexander.Snorkin@acronis.com> Date: Mon, 8 Feb 2021 18:04:22 +0300 Subject: [PATCH 13/16] CHANGELOG updated --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 95ca3329f8497..3a3fdd7074515 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -105,6 +105,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `LightningDataModule.from_datasets(...)` ([#5133](https://github.com/PyTorchLightning/pytorch-lightning/pull/5133)) +- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) + + ### Changed - Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) From 599aa6a29056d2e47aa6635f7b85e97f5f7f1ed5 Mon Sep 17 00:00:00 2001 From: Alexander Snorkin <Alexander.Snorkin@acronis.com> Date: Mon, 8 Feb 2021 18:34:23 +0300 Subject: [PATCH 14/16] test_progress_bar_print win linesep fix --- tests/callbacks/test_progress_bar.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index aa614cbaaf509..f9add25b72534 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -408,8 +408,8 @@ def test_progress_bar_print(tqdm_write, tmpdir): assert tqdm_write.call_count == 3 assert tqdm_write.call_args_list == [ call("training_step", end="", file=None, nolock=False), - call("validation_step", end="\n", file=sys.stderr, nolock=False), - call("test_step", end="\n", file=None, nolock=False), + call("validation_step", end=os.linesep, file=sys.stderr, nolock=False), + call("test_step", end=os.linesep, file=None, nolock=False), ] From 331aaccb5982d28c96a50a8af6e9b66b4baaa49c Mon Sep 17 00:00:00 2001 From: Alexander Snorkin <Alexander.Snorkin@acronis.com> Date: Thu, 18 Feb 2021 16:15:34 +0300 Subject: [PATCH 15/16] test_progress_bar.py remove whitespaces --- tests/callbacks/test_progress_bar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 22ab1c27ab6db..f16d8afd9cffd 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -374,7 +374,7 @@ def training_step(self, batch, batch_idx): actual = str(pbar.postfix) assert actual.endswith("foo=0.123, bar={'baz': tensor([1])}") - + @pytest.mark.parametrize( "input_num, expected", [[1, '1'], [1.0, '1.000'], [0.1, '0.100'], [1e-3, '0.001'], [1e-5, '1e-5'], ['1.0', '1.000'], ['10000', '10000'], ['abc', 'abc']] From c152e0ce5d73cadc2e21dd371908bc618ca7da04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= <carlossmocholi@gmail.com> Date: Thu, 18 Feb 2021 16:32:08 +0100 Subject: [PATCH 16/16] Update CHANGELOG.md --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b61089c83a5df..067f1e4a08a31 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -121,7 +121,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) - ### Changed - Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))