Skip to content

Commit

Permalink
Removed weights_summary argument from Trainer (#13070)
Browse files Browse the repository at this point in the history
Co-authored-by: rohitgr7 <[email protected]>
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
  • Loading branch information
4 people authored May 31, 2022
1 parent 5bdb936 commit f4f14bb
Show file tree
Hide file tree
Showing 10 changed files with 16 additions and 179 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

- Removed the deprecated `weights_summary` argument from the `Trainer` constructor ([#13070](https://github.com/PyTorchLightning/pytorch-lightning/pull/13070))


- Removed the deprecated `flush_logs_every_n_steps` argument from the `Trainer` constructor ([#13074](https://github.com/PyTorchLightning/pytorch-lightning/pull/13074))


Expand Down
30 changes: 0 additions & 30 deletions docs/source/common/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1546,36 +1546,6 @@ Example::
weights_save_path='my/path'
)

weights_summary
^^^^^^^^^^^^^^^

.. warning:: `weights_summary` is deprecated in v1.5 and will be removed in v1.7. Please pass :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary`
directly to the Trainer's ``callbacks`` argument instead. To disable the model summary,
pass ``enable_model_summary = False`` to the Trainer.


.. raw:: html

<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/weights_summary.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/weights_summary.mp4"></video>

|
Prints a summary of the weights when training begins.
Options: 'full', 'top', None.

.. testcode::

# default used by the Trainer (ie: print summary of top level modules)
trainer = Trainer(weights_summary="top")

# print full summary of all modules and submodules
trainer = Trainer(weights_summary="full")

# don't print a summary
trainer = Trainer(weights_summary=None)


enable_model_summary
^^^^^^^^^^^^^^^^^^^^
Expand Down
35 changes: 4 additions & 31 deletions pytorch_lightning/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
)
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.utilities.enums import ModelSummaryMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info
Expand All @@ -48,7 +47,6 @@ def on_trainer_init(
default_root_dir: Optional[str],
weights_save_path: Optional[str],
enable_model_summary: bool,
weights_summary: Optional[str],
max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
):
Expand Down Expand Up @@ -79,7 +77,7 @@ def on_trainer_init(
self._configure_progress_bar(enable_progress_bar)

# configure the ModelSummary callback
self._configure_model_summary_callback(enable_model_summary, weights_summary)
self._configure_model_summary_callback(enable_model_summary)

# accumulated grads
self._configure_accumulated_gradients(accumulate_grad_batches)
Expand Down Expand Up @@ -134,15 +132,7 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None:
elif enable_checkpointing:
self.trainer.callbacks.append(ModelCheckpoint())

def _configure_model_summary_callback(
self, enable_model_summary: bool, weights_summary: Optional[str] = None
) -> None:
if weights_summary is None:
rank_zero_deprecation(
"Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
" in v1.7. Please set `Trainer(enable_model_summary=False)` instead."
)
return
def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
if not enable_model_summary:
return

Expand All @@ -154,31 +144,14 @@ def _configure_model_summary_callback(
)
return

if weights_summary == "top":
# special case the default value for weights_summary to preserve backward compatibility
max_depth = 1
else:
rank_zero_deprecation(
f"Setting `Trainer(weights_summary={weights_summary})` is deprecated in v1.5 and will be removed"
" in v1.7. Please pass `pytorch_lightning.callbacks.model_summary.ModelSummary` with"
" `max_depth` directly to the Trainer's `callbacks` argument instead."
)
if weights_summary not in ModelSummaryMode.supported_types():
raise MisconfigurationException(
f"`weights_summary` can be None, {', '.join(ModelSummaryMode.supported_types())}",
f" but got {weights_summary}",
)
max_depth = ModelSummaryMode.get_max_depth(weights_summary)

progress_bar_callback = self.trainer.progress_bar_callback
is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar)

if progress_bar_callback is not None and is_progress_bar_rich:
model_summary = RichModelSummary(max_depth=max_depth)
model_summary = RichModelSummary()
else:
model_summary = ModelSummary(max_depth=max_depth)
model_summary = ModelSummary()
self.trainer.callbacks.append(model_summary)
self.trainer._weights_summary = weights_summary

def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None:
progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)]
Expand Down
23 changes: 0 additions & 23 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def __init__(
sync_batchnorm: bool = False,
precision: Union[int, str] = 32,
enable_model_summary: bool = True,
weights_summary: Optional[str] = "top",
weights_save_path: Optional[str] = None, # TODO: Remove in 1.8
num_sanity_val_steps: int = 2,
resume_from_checkpoint: Optional[Union[Path, str]] = None,
Expand Down Expand Up @@ -395,14 +394,6 @@ def __init__(
enable_model_summary: Whether to enable model summarization by default.
Default: ``True``.
weights_summary: Prints a summary of the weights when training begins.
.. deprecated:: v1.5
``weights_summary`` has been deprecated in v1.5 and will be removed in v1.7.
To disable the summary, pass ``enable_model_summary = False`` to the Trainer.
To customize the summary, pass :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary`
directly to the Trainer's ``callbacks`` argument.
weights_save_path: Where to save weights if specified. Will override default_root_dir
for checkpoints only. Use this if for whatever reason you need the checkpoints
stored in a different place than the logs written in `default_root_dir`.
Expand Down Expand Up @@ -485,9 +476,6 @@ def __init__(
self._tested_ckpt_path: Optional[str] = None # TODO: remove in v1.8
self._predicted_ckpt_path: Optional[str] = None # TODO: remove in v1.8

# todo: remove in v1.7
self._weights_summary: Optional[str] = None

# init callbacks
# Declare attributes to be set in _callback_connector on_trainer_init
self._callback_connector.on_trainer_init(
Expand All @@ -497,7 +485,6 @@ def __init__(
default_root_dir,
weights_save_path,
enable_model_summary,
weights_summary,
max_time,
accumulate_grad_batches,
)
Expand Down Expand Up @@ -2724,16 +2711,6 @@ def _should_terminate_gracefully(self) -> bool:
value = torch.tensor(int(self._terminate_gracefully), device=self.strategy.root_device)
return self.strategy.reduce(value, reduce_op="sum") > 0

@property
def weights_summary(self) -> Optional[str]:
rank_zero_deprecation("`Trainer.weights_summary` is deprecated in v1.5 and will be removed in v1.7.")
return self._weights_summary

@weights_summary.setter
def weights_summary(self, val: Optional[str]) -> None:
rank_zero_deprecation("Setting `Trainer.weights_summary` is deprecated in v1.5 and will be removed in v1.7.")
self._weights_summary = val

"""
Other
"""
Expand Down
1 change: 0 additions & 1 deletion pytorch_lightning/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
DistributedType,
GradClipAlgorithmType,
LightningEnum,
ModelSummaryMode,
)
from pytorch_lightning.utilities.grads import grad_norm # noqa: F401
from pytorch_lightning.utilities.imports import ( # noqa: F401
Expand Down
32 changes: 0 additions & 32 deletions pytorch_lightning/utilities/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,38 +199,6 @@ class AutoRestartBatchKeys(LightningEnum):
PL_RESTART_META = "__pl_restart_meta"


class ModelSummaryMode(LightningEnum):
# TODO: remove in v1.6 (as `mode` would be deprecated for `max_depth`)
"""Define the Model Summary mode to be used.
Can be one of
- `top`: only the top-level modules will be recorded (the children of the root module)
- `full`: summarizes all layers and their submodules in the root module
>>> # you can match the type with string
>>> ModelSummaryMode.TOP == 'TOP'
True
>>> # which is case invariant
>>> ModelSummaryMode.TOP in ('top', 'FULL')
True
"""

TOP = "top"
FULL = "full"

@staticmethod
def get_max_depth(mode: str) -> int:
if mode == ModelSummaryMode.TOP:
return 1
if mode == ModelSummaryMode.FULL:
return -1
raise ValueError(f"`mode` can be {', '.join(list(ModelSummaryMode))}, got {mode}.")

@staticmethod
def supported_types() -> list[str]:
return [x.value for x in ModelSummaryMode]


class _StrategyType(LightningEnum):
"""Define type of training strategy.
Expand Down
32 changes: 6 additions & 26 deletions tests/callbacks/test_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# limitations under the License.
from typing import List, Union

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary
from tests.helpers.boring_model import BoringModel
Expand All @@ -29,38 +27,20 @@ def test_model_summary_callback_present_trainer():
assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)


def test_model_summary_callback_with_weights_summary_none():
with pytest.deprecated_call(match=r"weights_summary=None\)` is deprecated"):
trainer = Trainer(weights_summary=None)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

def test_model_summary_callback_with_enable_model_summary_false():
trainer = Trainer(enable_model_summary=False)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

trainer = Trainer(enable_model_summary=False, weights_summary="full")
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

with pytest.deprecated_call(match=r"weights_summary=None\)` is deprecated"):
trainer = Trainer(enable_model_summary=True, weights_summary=None)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

def test_model_summary_callback_with_enable_model_summary_true():
trainer = Trainer(enable_model_summary=True)
assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

def test_model_summary_callback_with_weights_summary():
trainer = Trainer(weights_summary="top")
# Default value of max_depth is set as 1, when enable_model_summary is True
# and ModelSummary is not passed in callbacks list
model_summary_callback = list(filter(lambda cb: isinstance(cb, ModelSummary), trainer.callbacks))[0]
assert model_summary_callback._max_depth == 1

with pytest.deprecated_call(match=r"weights_summary=full\)` is deprecated"):
trainer = Trainer(weights_summary="full")
model_summary_callback = list(filter(lambda cb: isinstance(cb, ModelSummary), trainer.callbacks))[0]
assert model_summary_callback._max_depth == -1


def test_model_summary_callback_override_weights_summary_flag():
with pytest.deprecated_call(match=r"weights_summary=None\)` is deprecated"):
trainer = Trainer(callbacks=ModelSummary(), weights_summary=None)
assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)


def test_custom_model_summary_callback_summarize(tmpdir):
class CustomModelSummary(ModelSummary):
Expand Down
15 changes: 0 additions & 15 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,21 +122,6 @@ def test_v1_7_0_deprecate_parameter_validation():
from pytorch_lightning.core.decorators import parameter_validation # noqa: F401


def test_v1_7_0_weights_summary_trainer(tmpdir):
with pytest.deprecated_call(match=r"Setting `Trainer\(weights_summary=full\)` is deprecated in v1.5"):
t = Trainer(weights_summary="full")

with pytest.deprecated_call(match=r"Setting `Trainer\(weights_summary=None\)` is deprecated in v1.5"):
t = Trainer(weights_summary=None)

t = Trainer(weights_summary="top")
with pytest.deprecated_call(match=r"`Trainer.weights_summary` is deprecated in v1.5"):
_ = t.weights_summary

with pytest.deprecated_call(match=r"Setting `Trainer.weights_summary` is deprecated in v1.5"):
t.weights_summary = "blah"


def test_v1_7_0_deprecated_slurm_job_id():
trainer = Trainer()
with pytest.deprecated_call(match="Method `slurm_job_id` is deprecated in v1.6.0 and will be removed in v1.7.0."):
Expand Down
14 changes: 1 addition & 13 deletions tests/utilities/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from pytorch_lightning.utilities.enums import _AcceleratorType, GradClipAlgorithmType, ModelSummaryMode, PrecisionType
from pytorch_lightning.utilities.enums import _AcceleratorType, GradClipAlgorithmType, PrecisionType


def test_consistency():
Expand All @@ -34,16 +32,6 @@ def test_precision_supported_types():
assert not PrecisionType.supported_type("invalid")


def test_model_summary_mode():
assert ModelSummaryMode.supported_types() == ["top", "full"]
assert ModelSummaryMode.TOP in ("top", "full")
assert ModelSummaryMode.get_max_depth("top") == 1
assert ModelSummaryMode.get_max_depth("full") == -1

with pytest.raises(ValueError, match=f"`mode` can be {', '.join(list(ModelSummaryMode))}, got invalid."):
ModelSummaryMode.get_max_depth("invalid")


def test_gradient_clip_algorithms():
assert GradClipAlgorithmType.supported_types() == ["value", "norm"]
assert GradClipAlgorithmType.supported_type("norm")
Expand Down
10 changes: 2 additions & 8 deletions tests/utilities/test_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_9
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize, UNKNOWN_SIZE
from tests.helpers import BoringModel
from tests.helpers.advanced_models import ParityModuleRNN
Expand Down Expand Up @@ -139,15 +138,10 @@ def forward(self, inp):
return self.head(self.branch1(inp), self.branch2(inp))


def test_invalid_weights_summary():
"""Test that invalid value for weights_summary raises an error."""
def test_invalid_max_depth():
"""Test that invalid value for max_depth raises an error."""
model = LightningModule()

with pytest.raises(
MisconfigurationException, match="`weights_summary` can be None, .* got temp"
), pytest.deprecated_call(match="weights_summary=temp)` is deprecated"):
Trainer(weights_summary="temp")

with pytest.raises(ValueError, match="max_depth` can be .* got temp"):
ModelSummary(model, max_depth="temp")

Expand Down

0 comments on commit f4f14bb

Please sign in to comment.