Remove the deprecated get_progress_bar_dict #12839

Merged · 20 commits · Apr 22, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -129,6 +129,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed deprecated `dataloader_idx` argument from `on_train_batch_start/end` hooks `Callback` and `LightningModule` ([#12769](https://github.com/PyTorchLightning/pytorch-lightning/pull/12769))


- Removed deprecated `get_progress_bar_dict` property from `LightningModule` ([#12839](https://github.com/PyTorchLightning/pytorch-lightning/pull/12839))

### Fixed


2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/progress/base.py
@@ -220,7 +220,7 @@ def get_metrics(self, trainer, model):
        Return:
            Dictionary with the items to be displayed in the progress bar.
        """
        standard_metrics = pl_module.get_progress_bar_dict()
        standard_metrics = get_standard_metrics(trainer, pl_module)
        pbar_metrics = trainer.progress_bar_metrics
        duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
        if duplicates:
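The swapped-in call is the module-level `get_standard_metrics` helper defined in this same file. A minimal sketch of how downstream code could rebuild the merged dictionary the bar displays (the function name `combined_progress_bar_metrics` is illustrative, not part of this PR):

```python
from typing import Dict, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.base import get_standard_metrics


def combined_progress_bar_metrics(
    trainer: "pl.Trainer", pl_module: "pl.LightningModule"
) -> Dict[str, Union[int, str, float]]:
    # Standard items (running loss, `v_num` when a logger is attached, ...) now come
    # from the module-level helper rather than `LightningModule.get_progress_bar_dict`.
    metrics: Dict[str, Union[int, str, float]] = dict(get_standard_metrics(trainer, pl_module))
    # Anything logged with `prog_bar=True` is layered on top, mirroring `get_metrics`.
    metrics.update(trainer.progress_bar_metrics)
    return metrics
```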
30 changes: 0 additions & 30 deletions pytorch_lightning/core/lightning.py
@@ -32,7 +32,6 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.progress import base as progress_base
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -1731,35 +1730,6 @@ def unfreeze(self) -> None:

        self.train()

    def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]:
        r"""
        .. deprecated:: v1.5
            This method was deprecated in v1.5 in favor of
            `pytorch_lightning.callbacks.progress.base.get_metrics` and will be removed in v1.7.

        Implement this to override the default items displayed in the progress bar.
        By default it includes the average loss value, split index of BPTT (if used)
        and the version of the experiment when using a logger.

        .. code-block::

            Epoch 1: 4%|▎ | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10]

        Here is an example how to override the defaults:

        .. code-block:: python

            def get_progress_bar_dict(self):
                # don't show the version number
                items = super().get_progress_bar_dict()
                items.pop("v_num", None)
                return items

        Return:
            Dictionary with the items to be displayed in the progress bar.
        """
        return progress_base.get_standard_metrics(self.trainer, self)

    def _verify_is_manual_optimization(self, fn_name):
        if self.automatic_optimization:
            raise MisconfigurationException(
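With the hook removed from `LightningModule`, the override example that lived in this docstring moves to the progress bar callback, as the deprecation message pointed out. A minimal migration sketch (the class name `NoVersionBar` is made up for illustration; `TQDMProgressBar` is assumed as the bar being customized):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import TQDMProgressBar


class NoVersionBar(TQDMProgressBar):
    def get_metrics(self, trainer, pl_module):
        # Same intent as the removed docstring example: hide the version number.
        items = super().get_metrics(trainer, pl_module)
        items.pop("v_num", None)
        return items


trainer = Trainer(callbacks=[NoVersionBar()])
```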
16 changes: 0 additions & 16 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -49,8 +49,6 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:

    __verify_dp_batch_transfer_support(trainer, model)
    _check_add_get_queue(model)
    # TODO: Delete _check_progress_bar in v1.7
    _check_progress_bar(model)
    # TODO: Delete _check_on_post_move_to_device in v1.7
    _check_on_post_move_to_device(model)
    _check_deprecated_callback_hooks(trainer)
@@ -143,20 +141,6 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh
)


def _check_progress_bar(model: "pl.LightningModule") -> None:
    r"""
    Checks if get_progress_bar_dict is overridden and sends a deprecation warning.

    Args:
        model: The model to check the get_progress_bar_dict method.
    """
    if is_overridden("get_progress_bar_dict", model):
        rank_zero_deprecation(
            "The `LightningModule.get_progress_bar_dict` method was deprecated in v1.5 and will be removed in v1.7."
            " Please use the `ProgressBarBase.get_metrics` instead."
        )


def _check_on_post_move_to_device(model: "pl.LightningModule") -> None:
    r"""
    Checks if `on_post_move_to_device` method is overridden and sends a deprecation warning.
15 changes: 1 addition & 14 deletions pytorch_lightning/trainer/trainer.py
@@ -23,7 +23,7 @@
from copy import deepcopy
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, cast, Dict, Generator, Iterable, List, Optional, Type, Union
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type, Union
from weakref import proxy

import torch
@@ -2190,19 +2190,6 @@ def distributed_sampler_kwargs(self) -> Optional[dict]:
    def data_parallel(self) -> bool:
        return isinstance(self.strategy, ParallelStrategy)

    @property
    def progress_bar_dict(self) -> dict:
        """Read-only for progress bar metrics."""
        rank_zero_deprecation(
            "`trainer.progress_bar_dict` is deprecated in v1.5 and will be removed in v1.7."
            " Use `ProgressBarBase.get_metrics` instead."
        )
        ref_model = self.lightning_module
        ref_model = cast(pl.LightningModule, ref_model)
        if self.progress_bar_callback:
            return self.progress_bar_callback.get_metrics(self, ref_model)
        return self.progress_bar_metrics

    @property
    def enable_validation(self) -> bool:
        """Check if we should run validation during training."""
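Code that read the removed `Trainer.progress_bar_dict` property can go through the progress bar callback directly, mirroring what the property did. A small drop-in sketch (exposing it as a free function named `progress_bar_dict` is illustrative, not part of the library API):

```python
import pytorch_lightning as pl


def progress_bar_dict(trainer: "pl.Trainer") -> dict:
    # Same fallback the removed property used: ask the bar callback if one is
    # configured, otherwise return the raw logged progress bar metrics.
    if trainer.progress_bar_callback is not None:
        return trainer.progress_bar_callback.get_metrics(trainer, trainer.lightning_module)
    return trainer.progress_bar_metrics
```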
22 changes: 0 additions & 22 deletions tests/deprecated_api/test_remove_1-7.py
@@ -38,28 +38,6 @@
from tests.plugins.environments.test_lsf_environment import _make_rankfile


def test_v1_7_0_moved_get_progress_bar_dict(tmpdir):
    class TestModel(BoringModel):
        def get_progress_bar_dict(self):
            items = super().get_progress_bar_dict()
            items.pop("v_num", None)
            return items

    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
    )
    test_model = TestModel()
    with pytest.deprecated_call(match=r"`LightningModule.get_progress_bar_dict` method was deprecated in v1.5"):
        trainer.fit(test_model)
    standard_metrics_postfix = trainer.progress_bar_callback.main_progress_bar.postfix
    assert "loss" in standard_metrics_postfix
    assert "v_num" not in standard_metrics_postfix

    with pytest.deprecated_call(match=r"`trainer.progress_bar_dict` is deprecated in v1.5"):
        _ = trainer.progress_bar_dict


def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
    class CustomBoringModel(BoringModel):
        def on_train_dataloader(self):
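If equivalent coverage is wanted for the replacement API, a rough sketch of such a test follows; it is not part of this PR, the test name and bar subclass are hypothetical, and the import paths are assumed from the surrounding suite:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import TQDMProgressBar
from tests.helpers.boring_model import BoringModel


def test_progress_bar_metrics_via_get_metrics(tmpdir):
    class NoVersionBar(TQDMProgressBar):
        def get_metrics(self, trainer, pl_module):
            # Customize the displayed items through the callback, not the module.
            items = super().get_metrics(trainer, pl_module)
            items.pop("v_num", None)
            return items

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[NoVersionBar()])
    trainer.fit(BoringModel())

    postfix = trainer.progress_bar_callback.main_progress_bar.postfix
    assert "loss" in postfix
    assert "v_num" not in postfix
```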
1 change: 0 additions & 1 deletion tests/trainer/logging_/test_logger_connector.py
@@ -194,7 +194,6 @@ def __init__(self, not_supported):
"on_before_batch_transfer",
"transfer_batch_to_device",
"on_after_batch_transfer",
"get_progress_bar_dict",
}
)
# remove `nn.Module` hooks