diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2f6f3f3693b1b..067f1e4a08a31 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -118,6 +118,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `Trainer` flag to activate Stochastic Weight Averaging (SWA) `Trainer(stochastic_weight_avg=True)` ([#6038](https://github.com/PyTorchLightning/pytorch-lightning/pull/6038))
 
+- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
+
+
 ### Changed
 
 - Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py
index 3f401669c351e..2f133eaccf512 100644
--- a/pytorch_lightning/callbacks/progress.py
+++ b/pytorch_lightning/callbacks/progress.py
@@ -19,6 +19,8 @@
 """
 import importlib
+import io
+import os
 import sys
 
 # check if ipywidgets is installed before importing tqdm.auto
@@ -187,6 +189,12 @@ def enable(self):
         """
         raise NotImplementedError
 
+    def print(self, *args, **kwargs):
+        """
+        You should provide a way to print without breaking the progress bar.
+        """
+        print(*args, **kwargs)
+
     def on_init_end(self, trainer):
         self._trainer = trainer
 
@@ -451,6 +459,22 @@ def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, da
     def on_predict_end(self, trainer, pl_module):
         self.predict_progress_bar.close()
 
+    def print(
+        self, *args, sep: str = ' ', end: str = os.linesep, file: Optional[io.TextIOBase] = None, nolock: bool = False
+    ):
+        active_progress_bar = None
+
+        if not self.main_progress_bar.disable:
+            active_progress_bar = self.main_progress_bar
+        elif not self.val_progress_bar.disable:
+            active_progress_bar = self.val_progress_bar
+        elif not self.test_progress_bar.disable:
+            active_progress_bar = self.test_progress_bar
+
+        if active_progress_bar is not None:
+            s = sep.join(map(str, args))
+            active_progress_bar.write(s, end=end, file=file, nolock=nolock)
+
     def _should_update(self, current, total):
         return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 91d2b3565d193..03aea0df9c533 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -190,8 +190,8 @@ def print(self, *args, **kwargs) -> None:
         Prints only from process 0. Use this in any distributed mode to log only once.
 
         Args:
-            *args: The thing to print. Will be passed to Python's built-in print function.
-            **kwargs: Will be passed to Python's built-in print function.
+            *args: The thing to print. The same as for Python's built-in print function.
+            **kwargs: The same as for Python's built-in print function.
 
         Example::
 
@@ -200,7 +200,11 @@ def forward(self, x):
 
         """
         if self.trainer.is_global_zero:
-            print(*args, **kwargs)
+            progress_bar = self.trainer.progress_bar_callback
+            if progress_bar is not None and progress_bar.is_enabled:
+                progress_bar.print(*args, **kwargs)
+            else:
+                print(*args, **kwargs)
 
     def log(
         self,
diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py
index 9ec48008512fb..f16d8afd9cffd 100644
--- a/tests/callbacks/test_progress_bar.py
+++ b/tests/callbacks/test_progress_bar.py
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import sys
 from unittest import mock
-from unittest.mock import call, Mock
+from unittest.mock import ANY, call, Mock
 
 import pytest
 import torch
@@ -381,3 +382,69 @@ def training_step(self, batch, batch_idx):
 def test_tqdm_format_num(input_num, expected):
     """ Check that the specialized tqdm.format_num appends 0 to floats and strings """
     assert tqdm.format_num(input_num) == expected
+
+
+class PrintModel(BoringModel):
+
+    def training_step(self, *args, **kwargs):
+        self.print("training_step", end="")
+        return super().training_step(*args, **kwargs)
+
+    def validation_step(self, *args, **kwargs):
+        self.print("validation_step", file=sys.stderr)
+        return super().validation_step(*args, **kwargs)
+
+    def test_step(self, *args, **kwargs):
+        self.print("test_step")
+        return super().test_step(*args, **kwargs)
+
+
+@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write")
+def test_progress_bar_print(tqdm_write, tmpdir):
+    """ Test that printing in the LightningModule redirects arguments to the progress bar. """
+    model = PrintModel()
+    bar = ProgressBar()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=0,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        limit_test_batches=1,
+        max_steps=1,
+        callbacks=[bar],
+    )
+    trainer.fit(model)
+    trainer.test(model)
+    assert tqdm_write.call_count == 3
+    assert tqdm_write.call_args_list == [
+        call("training_step", end="", file=None, nolock=False),
+        call("validation_step", end=os.linesep, file=sys.stderr, nolock=False),
+        call("test_step", end=os.linesep, file=None, nolock=False),
+    ]
+
+
+@mock.patch('builtins.print')
+@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write")
+def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
+    """ Test that printing in the LightningModule goes through the built-in print function when the progress bar is disabled. """
+    model = PrintModel()
+    bar = ProgressBar()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=0,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        limit_test_batches=1,
+        max_steps=1,
+        callbacks=[bar],
+    )
+    bar.disable()
+    trainer.fit(model)
+    trainer.test(model)
+
+    mock_print.assert_has_calls([
+        call("training_step", end=""),
+        call("validation_step", file=ANY),
+        call("test_step"),
+    ])
+    tqdm_write.assert_not_called()
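
Usage sketch (reviewer note, not part of the patch): after this change, `LightningModule.print` routes through the active tqdm bar's `write()` on rank 0, so user messages appear above the bar instead of breaking its line, and fall back to the built-in `print()` when the bar is disabled. The toy dataset and module below are hypothetical illustrations; only `self.print(...)` and the progress bar behavior come from this PR.

```python
import sys

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    """Hypothetical toy dataset, only here to make the sketch runnable."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(4)


class LitModel(LightningModule):
    """Hypothetical module; self.print() is the API exercised by this patch."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        # With the bar enabled, this is forwarded to tqdm.write() on rank 0,
        # printing above the bar; with the bar disabled, it falls back to the
        # built-in print(). The sep/end/file kwargs are passed through, as the
        # tests above verify.
        self.print("processing batch", batch_idx, file=sys.stderr)
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = Trainer(max_epochs=1, limit_train_batches=4)
    trainer.fit(LitModel(), DataLoader(RandomDataset(), batch_size=2))
```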