feat: Add Rich Progress Bar (#8929)
Co-authored-by: thomas chaton <[email protected]>
kaushikb11 and tchaton authored Aug 24, 2021
1 parent 1e4d892 commit 538e743
Showing 11 changed files with 552 additions and 178 deletions.
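
For orientation, a minimal usage sketch of the feature added here (hedged: it assumes the `rich` package is installed and that `RichProgressBar` requires no constructor arguments; see the class added in this commit for the actual options):

# Hypothetical usage of the RichProgressBar introduced in this commit.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar

trainer = Trainer(callbacks=[RichProgressBar()])
# trainer.fit(model)  # the Rich-rendered bar replaces the default tqdm-based ProgressBar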
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -75,6 +75,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a friendly error message when DDP attempts to spawn new distributed processes with rank > 0 ([#9005](https://github.com/PyTorchLightning/pytorch-lightning/pull/9005))


- Added Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929))


### Changed

- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
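
A short behavior sketch for the `gpus` change noted above (assumes a machine with at least three visible GPUs; which physical devices are used ultimately depends on CUDA_VISIBLE_DEVICES):

from pytorch_lightning import Trainer

trainer = Trainer(gpus="3")   # now selects the first 3 devices, same as gpus=3
# to target the single device with index 3, pass an explicit list: Trainer(gpus=[3])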
3 changes: 2 additions & 1 deletion pytorch_lightning/callbacks/__init__.py
@@ -20,7 +20,7 @@
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar
from pytorch_lightning.callbacks.pruning import ModelPruning
from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining
from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
@@ -45,4 +45,5 @@
"QuantizationAwareTraining",
"StochasticWeightAveraging",
"Timer",
"RichProgressBar",
]
23 changes: 23 additions & 0 deletions pytorch_lightning/callbacks/progress/__init__.py
@@ -0,0 +1,23 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Progress Bars
=============
Use or override one of the progress bar callbacks.
"""
from pytorch_lightning.callbacks.progress.base import ProgressBarBase # noqa: F401
from pytorch_lightning.callbacks.progress.progress import ProgressBar, tqdm # noqa: F401
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar # noqa: F401
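
Both import paths now resolve to the same classes; a small sketch (assuming nothing else is re-exported beyond what the new `__init__.py` above shows):

from pytorch_lightning.callbacks import RichProgressBar as PublicRichProgressBar
from pytorch_lightning.callbacks.progress import RichProgressBar

assert PublicRichProgressBar is RichProgressBar  # one class, two import paths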
181 changes: 181 additions & 0 deletions pytorch_lightning/callbacks/progress/base.py
@@ -0,0 +1,181 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.callbacks import Callback


class ProgressBarBase(Callback):
r"""
The base class for progress bars in Lightning. It is a :class:`~pytorch_lightning.callbacks.Callback`
that keeps track of the batch progress in the :class:`~pytorch_lightning.trainer.trainer.Trainer`.
    Implement your own custom progress bar with this as the base class.

    Example::

        import sys

        class LitProgressBar(ProgressBarBase):

            def __init__(self):
                super().__init__()  # don't forget this :)
                self.enable = True

            def disable(self):
                self.enable = False

            def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
                super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
                percent = (self.train_batch_idx / self.total_train_batches) * 100
                sys.stdout.flush()
                sys.stdout.write(f"{percent:.01f} percent complete \r")

        bar = LitProgressBar()
        trainer = Trainer(callbacks=[bar])
"""

    def __init__(self):
        self._trainer = None
self._train_batch_idx = 0
self._val_batch_idx = 0
self._test_batch_idx = 0
self._predict_batch_idx = 0

@property
def trainer(self):
return self._trainer

@property
def train_batch_idx(self) -> int:
"""
The current batch index being processed during training.
Use this to update your progress bar.
"""
return self._train_batch_idx

@property
def val_batch_idx(self) -> int:
"""
The current batch index being processed during validation.
Use this to update your progress bar.
"""
return self._val_batch_idx

@property
def test_batch_idx(self) -> int:
"""
The current batch index being processed during testing.
Use this to update your progress bar.
"""
return self._test_batch_idx

@property
def predict_batch_idx(self) -> int:
"""
        The current batch index being processed during prediction.
Use this to update your progress bar.
"""
return self._predict_batch_idx

@property
def total_train_batches(self) -> int:
"""
The total number of training batches during training, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
training dataloader is of infinite size.
"""
return self.trainer.num_training_batches

@property
def total_val_batches(self) -> int:
"""
The total number of validation batches during validation, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
validation dataloader is of infinite size.
"""
total_val_batches = 0
if self.trainer.enable_validation:
is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0

return total_val_batches

@property
def total_test_batches(self) -> int:
"""
The total number of testing batches during testing, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
test dataloader is of infinite size.
"""
return sum(self.trainer.num_test_batches)

@property
def total_predict_batches(self) -> int:
"""
        The total number of prediction batches during prediction, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
predict dataloader is of infinite size.
"""
return sum(self.trainer.num_predict_batches)

def disable(self):
"""
You should provide a way to disable the progress bar.
The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this to disable the
output on processes that have a rank different from 0, e.g., in multi-node training.
"""
raise NotImplementedError

def enable(self):
"""
You should provide a way to enable the progress bar.
The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training
routines like the :ref:`learning rate finder <advanced/lr_finder:Learning Rate Finder>`
to temporarily enable and disable the main progress bar.
"""
raise NotImplementedError

def print(self, *args, **kwargs):
"""
You should provide a way to print without breaking the progress bar.
"""
print(*args, **kwargs)

def on_init_end(self, trainer):
self._trainer = trainer

def on_train_start(self, trainer, pl_module):
self._train_batch_idx = trainer.fit_loop.batch_idx

def on_train_epoch_start(self, trainer, pl_module):
self._train_batch_idx = 0

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._train_batch_idx += 1

def on_validation_start(self, trainer, pl_module):
self._val_batch_idx = 0

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._val_batch_idx += 1

def on_test_start(self, trainer, pl_module):
self._test_batch_idx = 0

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._test_batch_idx += 1

def on_predict_epoch_start(self, trainer, pl_module):
self._predict_batch_idx = 0

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._predict_batch_idx += 1
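
To show how the hooks and properties above fit together, here is an illustrative subclass (hedged: `SimpleBar` is not part of this commit; it only exercises the API defined in base.py):

import sys

from pytorch_lightning.callbacks.progress import ProgressBarBase


class SimpleBar(ProgressBarBase):
    """Illustrative example only; not part of this commit."""

    def __init__(self):
        super().__init__()
        self._enabled = True

    def disable(self):
        # The Trainer calls this on processes with rank != 0.
        self._enabled = False

    def enable(self):
        self._enabled = True

    def print(self, *args, **kwargs):
        # Start a fresh line so the in-place bar is not clobbered.
        sys.stdout.write("\n")
        print(*args, **kwargs)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if self._enabled:
            sys.stdout.write(f"\rtrain {self.train_batch_idx}/{self.total_train_batches}")
            sys.stdout.flush()

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if self._enabled:
            sys.stdout.write(f"\rval {self.val_batch_idx}/{self.total_val_batches}")
            sys.stdout.flush()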