Feature/5275 clean progress bar print #5470

Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -118,6 +118,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `Trainer` flag to activate Stochastic Weight Averaging (SWA) `Trainer(stochastic_weight_avg=True)` ([#6038](https://github.com/PyTorchLightning/pytorch-lightning/pull/6038))


- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))


### Changed

- Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
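For context, the entry added above (#5470) corresponds to the following usage pattern. This is a minimal sketch assuming a standard `LightningModule`; the `LitModel` name and its layer sizes are illustrative, not part of this diff.

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # self.print goes through the progress bar callback when one is
        # enabled, so the output no longer breaks up the bar rendering.
        self.print(f"batch {batch_idx}: loss={loss.item():.4f}")
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)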
24 changes: 24 additions & 0 deletions pytorch_lightning/callbacks/progress.py
@@ -19,6 +19,8 @@

"""
import importlib
import io
import os
import sys

# check if ipywidgets is installed before importing tqdm.auto
@@ -187,6 +189,12 @@ def enable(self):
"""
raise NotImplementedError

def print(self, *args, **kwargs):
"""
You should provide a way to print without breaking the progress bar.
"""
print(*args, **kwargs)

def on_init_end(self, trainer):
self._trainer = trainer

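Because `print` is now a hook on the progress bar base class, a custom progress bar can override it. A hypothetical sketch, assuming the `ProgressBarBase` import path from this codebase; the `LoggingProgressBar` name and its logging behavior are ours:

import logging

from pytorch_lightning.callbacks.progress import ProgressBarBase

log = logging.getLogger(__name__)


class LoggingProgressBar(ProgressBarBase):

    def __init__(self):
        super().__init__()
        self._enabled = True

    @property
    def is_enabled(self):
        # LightningModule.print checks this before redirecting output.
        return self._enabled

    def disable(self):
        self._enabled = False

    def enable(self):
        self._enabled = True

    def print(self, *args, **kwargs):
        # Route LightningModule.print output through the logging module
        # instead of writing directly to stdout.
        log.info(" ".join(map(str, args)))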
@@ -451,6 +459,22 @@ def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, da
    def on_predict_end(self, trainer, pl_module):
        self.predict_progress_bar.close()

    def print(
        self, *args, sep: str = ' ', end: str = os.linesep, file: Optional[io.TextIOBase] = None, nolock: bool = False
    ):
        active_progress_bar = None

        if not self.main_progress_bar.disable:
            active_progress_bar = self.main_progress_bar
        elif not self.val_progress_bar.disable:
            active_progress_bar = self.val_progress_bar
        elif not self.test_progress_bar.disable:
            active_progress_bar = self.test_progress_bar

        if active_progress_bar is not None:
            s = sep.join(map(str, args))
            active_progress_bar.write(s, end=end, file=file, nolock=nolock)

    def _should_update(self, current, total):
        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
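The override above funnels all output through `tqdm.write`, which clears the currently active bar, writes the message, and then redraws the bar. A standalone sketch of that behavior, independent of Lightning and assuming only `tqdm` is installed:

import time

from tqdm import tqdm

for i in tqdm(range(5)):
    # A bare print() here would leave a partially drawn bar on the
    # previous line; tqdm.write() clears the bar, prints the message,
    # and redraws the bar afterwards.
    tqdm.write(f"step {i} done")
    time.sleep(0.1)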
10 changes: 7 additions & 3 deletions pytorch_lightning/core/lightning.py
@@ -190,8 +190,8 @@ def print(self, *args, **kwargs) -> None:
        Prints only from process 0. Use this in any distributed mode to log only once.

        Args:
-            *args: The thing to print. Will be passed to Python's built-in print function.
-            **kwargs: Will be passed to Python's built-in print function.
+            *args: The thing to print. The same as for Python's built-in print function.
+            **kwargs: The same as for Python's built-in print function.

        Example::

@@ -200,7 +200,11 @@ def forward(self, x):

"""
if self.trainer.is_global_zero:
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
print(*args, **kwargs)
progress_bar = self.trainer.progress_bar_callback
if progress_bar is not None and progress_bar.is_enabled:
progress_bar.print(*args, **kwargs)
else:
print(*args, **kwargs)

def log(
self,
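With this change, keyword arguments such as `sep`, `end`, and `file` are forwarded unchanged to either `tqdm.write` or the built-in `print`, so both paths behave alike. A small sketch; the `VerboseModel` name is ours and the class is deliberately minimal:

import sys

import pytorch_lightning as pl


class VerboseModel(pl.LightningModule):

    def validation_step(self, batch, batch_idx):
        # sep= and file= are forwarded to the active progress bar's write,
        # or to builtins.print when no bar is enabled.
        self.print("validating batch", batch_idx, sep=": ", file=sys.stderr)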
69 changes: 68 additions & 1 deletion tests/callbacks/test_progress_bar.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+import sys
from unittest import mock
-from unittest.mock import call, Mock
+from unittest.mock import ANY, call, Mock

import pytest
import torch
@@ -381,3 +382,69 @@ def training_step(self, batch, batch_idx):
def test_tqdm_format_num(input_num, expected):
    """ Check that the specialized tqdm.format_num appends 0 to floats and strings """
    assert tqdm.format_num(input_num) == expected


class PrintModel(BoringModel):

    def training_step(self, *args, **kwargs):
        self.print("training_step", end="")
        return super().training_step(*args, **kwargs)

    def validation_step(self, *args, **kwargs):
        self.print("validation_step", file=sys.stderr)
        return super().validation_step(*args, **kwargs)

    def test_step(self, *args, **kwargs):
        self.print("test_step")
        return super().test_step(*args, **kwargs)


@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write")
def test_progress_bar_print(tqdm_write, tmpdir):
    """ Test that printing in the LightningModule redirects arguments to the progress bar. """
    model = PrintModel()
    bar = ProgressBar()
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        max_steps=1,
        callbacks=[bar],
    )
    trainer.fit(model)
    trainer.test(model)
    assert tqdm_write.call_count == 3
    assert tqdm_write.call_args_list == [
        call("training_step", end="", file=None, nolock=False),
        call("validation_step", end=os.linesep, file=sys.stderr, nolock=False),
        call("test_step", end=os.linesep, file=None, nolock=False),
    ]


@mock.patch('builtins.print')
@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write")
def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
    """ Test that printing in the LightningModule goes through the built-in print function when the progress bar is disabled. """
    model = PrintModel()
    bar = ProgressBar()
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        max_steps=1,
        callbacks=[bar],
    )
    bar.disable()
    trainer.fit(model)
    trainer.test(model)

    mock_print.assert_has_calls([
        call("training_step", end=""),
        call("validation_step", file=ANY),
        call("test_step"),
    ])
    tqdm_write.assert_not_called()
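A note on the assertions: the disabled-bar test uses `unittest.mock.ANY`, which compares equal to any value, so the stream chosen by the fallback path does not have to be pinned down. A minimal standalone illustration of that pattern:

from unittest import mock
from unittest.mock import ANY, call

m = mock.Mock()
m("validation_step", file=object())

# ANY matches the arbitrary file object recorded above.
m.assert_has_calls([call("validation_step", file=ANY)])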