From 84d8f9af44bb82f13d571e1905e105cc220ae201 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 8 Sep 2023 10:52:41 +0200 Subject: [PATCH] Close validation progress bar before updating training bar (#18503) (cherry picked from commit a3f6e98e7c8a12e677bb9d10bd4f61ba716e0cad) --- src/lightning/pytorch/CHANGELOG.md | 3 +++ src/lightning/pytorch/callbacks/progress/tqdm_progress.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 8ad0dc388431f..1ee0cfa7cd4ca 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +- Fixed visual glitch with the TQDM progress bar leaving the validation bar incomplete before switching back to the training display ([#18503](https://github.com/Lightning-AI/lightning/pull/18503)) + + ## [2.0.7] - 2023-08-14 ### Added diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index 075dc3b99356e..dd93b8ab46655 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -240,8 +240,8 @@ def on_sanity_check_start(self, *_: Any) -> None: self.train_progress_bar = Tqdm(disable=True) # dummy progress bar def on_sanity_check_end(self, *_: Any) -> None: - self.train_progress_bar.close() self.val_progress_bar.close() + self.train_progress_bar.close() def on_train_start(self, *_: Any) -> None: self.train_progress_bar = self.init_train_tqdm() @@ -300,10 +300,10 @@ def on_validation_batch_end( _update_n(self.val_progress_bar, n) def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self._train_progress_bar is not None and trainer.state.fn == "fit": - self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) self.val_progress_bar.close() self.reset_dataloader_idx_tracker() + if self._train_progress_bar is not None and trainer.state.fn == "fit": + self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.test_progress_bar = self.init_test_tqdm()