Add typing to LightningModule.trainer #12345

Merged
merged 13 commits on Mar 29, 2022
fixed mypy issues
johnhenning authored Mar 22, 2022
commit f18f3d9603454b46723f58ccd290fa5f2d920eb4
2 changes: 2 additions & 0 deletions pytorch_lightning/core/optimizer.py
@@ -176,6 +176,8 @@ def _init_optimizers_and_lr_schedulers(
model: "pl.LightningModule",
) -> Tuple[List[Optimizer], List[LRSchedulerConfig], List[int]]:
"""Calls `LightningModule.configure_optimizers` and parses and validates the output."""

assert model.trainer is not None
optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)

if optim_conf is None:
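This is the pattern the rest of the commit repeats: `LightningModule.trainer` is typed as optional, so call sites assert it is not `None` before dereferencing it. A minimal, self-contained sketch (hypothetical class names, not the Lightning types) of how mypy narrows an optional attribute through such an assert:

```python
from typing import Optional


class Trainer:
    def call_hook(self, name: str) -> None:
        print(f"calling {name}")


class Module:
    trainer: Optional[Trainer] = None  # attribute stays None until the module is attached


def configure(module: Module) -> None:
    # without this assert, mypy reports:
    #   Item "None" of "Optional[Trainer]" has no attribute "call_hook"
    assert module.trainer is not None
    module.trainer.call_hook("configure_optimizers")  # narrowed to Trainer here


m = Module()
m.trainer = Trainer()
configure(m)
```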
4 changes: 3 additions & 1 deletion pytorch_lightning/overrides/base.py
@@ -88,7 +88,9 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
# `require_backward_grad_sync` will be reset in the
# ddp_strategy `post_training_step` hook
if not lightning_module.automatic_optimization:
trainer.model.require_backward_grad_sync = False
# NOTE: ignoring this since trainer assumes this is a Module not DataParallel

trainer.model.require_backward_grad_sync = False # type: ignore [assignment]
elif trainer and trainer.testing:
output = self.module.test_step(*inputs, **kwargs)
elif trainer and (trainer.sanity_checking or trainer.validating):
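The replacement line keeps the original assignment but adds an error-code-scoped ignore, which silences only mypy's `assignment` complaint on that line rather than every check. A tiny illustrative sketch of the convention (not the Lightning code itself):

```python
# mypy flags this line with error code "assignment"; the bracketed code in the
# ignore comment suppresses only that error, so other mistakes still surface.
count: int = 0
count = "three"  # type: ignore[assignment]
print(count)
```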
4 changes: 2 additions & 2 deletions pytorch_lightning/overrides/distributed.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Any, cast, Iterator, List, Sized, Union
from typing import Any, cast, Iterable, Iterator, List, Sized, Union

import torch
from torch import Tensor
@@ -163,5 +163,5 @@ def batch_size(self) -> int:
return self._sampler.batch_size

@property
def sampler(self) -> Sampler:
def sampler(self) -> Union[Sampler, Iterable]:
return self._sampler.sampler
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -69,6 +69,7 @@ def backward(
closure_loss: the loss value obtained from the closure
optimizer: current optimizer being used. ``None`` if using manual optimization
"""
assert model.trainer is not None
opt = optimizer or model.trainer.optimizers
with amp.scale_loss(closure_loss, opt) as closure_loss:
super().backward(model, closure_loss, optimizer, *args, **kwargs)
6 changes: 4 additions & 2 deletions pytorch_lightning/plugins/precision/deepspeed.py
@@ -46,10 +46,11 @@ def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any
"You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
" the backward logic internally."
)
assert model.trainer is not None
deepspeed_engine: DeepSpeedEngine = model.trainer.model
deepspeed_engine.backward(closure_loss, *args, **kwargs)

def _run_backward(self, tensor: Tensor, model: Optional["DeepSpeedEngine"], *args: Any, **kwargs: Any) -> None:
def _run_backward(self, tensor: Tensor, model: Optional[DeepSpeedEngine], *args: Any, **kwargs: Any) -> None:
if model is None:
raise ValueError("Please provide the model as input to `backward`.")
model.backward(tensor, *args, **kwargs)
@@ -75,7 +76,8 @@ def optimizer_step(
"Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
)
# DeepSpeed handles the optimizer step internally
deepspeed_engine = model.trainer.model if isinstance(model, pl.LightningModule) else model
assert model.trainer is not None
deepspeed_engine: DeepSpeedEngine = model.trainer.model if isinstance(model, pl.LightningModule) else model
return deepspeed_engine.step(**kwargs)

def clip_gradients(
10 changes: 8 additions & 2 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -55,8 +55,9 @@ def pre_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor:
model: the model to be optimized
closure_loss: the loss value obtained from the closure
"""
model.trainer._call_callback_hooks("on_before_backward", closure_loss)
model.trainer._call_lightning_module_hook("on_before_backward", closure_loss)
if model.trainer is not None:
model.trainer._call_callback_hooks("on_before_backward", closure_loss)
model.trainer._call_lightning_module_hook("on_before_backward", closure_loss)
return closure_loss

def backward(
@@ -89,6 +90,7 @@ def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor:
"""
# once backward has been applied, release graph
closure_loss = closure_loss.detach()
assert model.trainer is not None
model.trainer._call_callback_hooks("on_after_backward")
model.trainer._call_lightning_module_hook("on_after_backward")
return closure_loss
@@ -109,9 +111,13 @@ def _after_closure(
return
trainer = model.trainer
assert trainer is not None

trainer._call_callback_hooks("on_before_optimizer_step", optimizer, optimizer_idx)
trainer._call_lightning_module_hook("on_before_optimizer_step", optimizer, optimizer_idx)
# TODO: this is done for the entire model but should be changed to per-optimizer

assert hasattr(trainer, "gradient_clip_algorithm")

if optimizer_idx == 0:
self._track_grad_norm(trainer)
self._clip_gradients(
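Two narrowing styles appear in this file: `pre_backward` now wraps the hook calls in an `if model.trainer is not None:` guard (quietly skipping them when no trainer is attached), while `post_backward` and `_after_closure` assert instead, failing loudly if the assumption is violated. A short sketch of the difference, assuming an `Optional` value:

```python
from typing import Optional


def with_guard(trainer: Optional[str]) -> None:
    if trainer is not None:      # narrows Optional[str] to str; does nothing when None
        print(trainer.upper())


def with_assert(trainer: Optional[str]) -> None:
    assert trainer is not None   # narrows as well, but raises AssertionError on None
    print(trainer.upper())


with_guard(None)        # silently skipped
with_assert("trainer")  # prints TRAINER
```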
2 changes: 2 additions & 0 deletions pytorch_lightning/strategies/bagua.py
@@ -161,6 +161,7 @@ def configure_ddp(self) -> None:
self._model = self._setup_model(model)

# start the background communication for async algorithm
assert self.lightning_module.trainer is not None
if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
self.model.bagua_algorithm.resume(self.model) # type: ignore

@@ -188,6 +189,7 @@ def register_strategies(cls, strategy_registry: Dict) -> None:

def teardown(self) -> None:
# abort the background communication for async algorithm
assert self.lightning_module.trainer is not None
if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
self.model.bagua_algorithm.abort(self.model) # type: ignore

4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -596,7 +596,7 @@ def __init__(

self._terminate_on_nan = terminate_on_nan
self.gradient_clip_val: Union[int, float] = gradient_clip_val
self.gradient_clip_algorithm = (
self.gradient_clip_algorithm: Optional[GradClipAlgorithmType] = (
GradClipAlgorithmType(gradient_clip_algorithm.lower())
if gradient_clip_algorithm is not None
else gradient_clip_algorithm
@@ -1184,7 +1184,7 @@ def _run(
# ----------------------------
# INSPECT THE CORE LOOPS
# ----------------------------
fr"""
rf"""
Lightning internal flow looks like this:
{Trainer.fit} or {Trainer.test} or {Trainer.predict} ||
| ||
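Two small fixes ride along in `trainer.py`: `gradient_clip_algorithm` gets an explicit `Optional[GradClipAlgorithmType]` annotation at its first assignment so mypy records the attribute as possibly `None`, and the raw f-string prefix is reordered from `fr` to `rf` (both spellings are valid; the reorder is purely a formatter preference). A sketch of the annotated-assignment pattern, with hypothetical names:

```python
from enum import Enum
from typing import Optional


class ClipAlgorithm(Enum):
    NORM = "norm"
    VALUE = "value"


class Settings:
    def __init__(self, algorithm: Optional[str]) -> None:
        # the annotation pins the attribute's declared type as Optional, so every
        # later use has to narrow it before treating it as a ClipAlgorithm
        self.clip_algorithm: Optional[ClipAlgorithm] = (
            ClipAlgorithm(algorithm.lower()) if algorithm is not None else None
        )


print(Settings("NORM").clip_algorithm)  # ClipAlgorithm.NORM
print(Settings(None).clip_algorithm)    # None
```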
6 changes: 4 additions & 2 deletions pytorch_lightning/utilities/data.py
@@ -17,7 +17,7 @@
from contextlib import contextmanager
from functools import partial
from itertools import chain
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Mapping, Optional, Set, Type, Union

import torch
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler, SequentialSampler
@@ -170,7 +170,9 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
return float("inf")


def _update_dataloader(dataloader: DataLoader, sampler: Sampler, mode: Optional[RunningStage] = None) -> DataLoader:
def _update_dataloader(
dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
) -> DataLoader:
dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
dl_cls = type(dataloader)
try:
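`_update_dataloader` is widened to accept `Union[Sampler, Iterable]`, matching the return type of the wrapped `sampler` property changed above and PyTorch's own `DataLoader(sampler=...)` parameter, which (in recent releases) takes either a `Sampler` or a plain iterable of indices. A small sketch of a rebuild helper with the widened signature (hypothetical function, not the Lightning one):

```python
from typing import Iterable, Union

import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


def rebuild_loader(loader: DataLoader, sampler: Union[Sampler, Iterable]) -> DataLoader:
    # DataLoader accepts either a Sampler instance or any iterable of indices
    return DataLoader(loader.dataset, batch_size=loader.batch_size, sampler=sampler)


dataset = TensorDataset(torch.arange(4.0).unsqueeze(1))
loader = rebuild_loader(DataLoader(dataset, batch_size=2), sampler=[3, 2, 1, 0])
print([batch[0].tolist() for batch in loader])  # batches drawn in reversed index order
```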
7 changes: 4 additions & 3 deletions pytorch_lightning/utilities/model_summary.py
@@ -16,6 +16,7 @@
import contextlib
import logging
from collections import OrderedDict
from multiprocessing import context
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -265,10 +266,10 @@ def _forward_example_input(self) -> None:
mode = model.training
model.eval()

forward_context: contextlib.AbstractContextManager = contextlib.nullcontext()

if trainer is not None:
forward_context = trainer.precision_plugin.forward_context()
else:
forward_context = contextlib.nullcontext()

with torch.no_grad(), forward_context:
# let the model hooks collect the input- and output shapes
@@ -398,7 +399,7 @@ def get_human_readable_count(number: int) -> str:
num_groups = int(np.ceil(num_digits / 3))
num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions
shift = -3 * (num_groups - 1)
number = number * (10 ** shift)
number = number * (10**shift)
index = num_groups - 1
if index < 1 or number >= 100:
return f"{int(number):,d} {labels[index]}"
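The `_forward_example_input` change replaces an `if/else` that bound `forward_context` to two different concrete types with a single annotated default: declaring the variable as `contextlib.AbstractContextManager` up front gives mypy one type that covers both the null context and the precision plugin's context. A standalone sketch of that pattern:

```python
import contextlib


def forward_context(use_plugin_context: bool) -> contextlib.AbstractContextManager:
    # annotate with the common base type, then rebind only when needed
    ctx: contextlib.AbstractContextManager = contextlib.nullcontext()
    if use_plugin_context:
        ctx = contextlib.suppress(ValueError)  # stand-in for precision_plugin.forward_context()
    return ctx


with forward_context(True):
    pass  # the example forward pass would run here
```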