Add typing to LightningModule.trainer #12345

Merged: 13 commits, Mar 29, 2022
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -94,7 +94,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         torch._C._log_api_usage_once(f"lightning.module.{self.__class__.__name__}")

         # pointer to the trainer object
-        self.trainer = None
+        self.trainer: Optional["pl.Trainer"] = None

         # true if using amp
         self.use_amp: bool = False
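Note: with `trainer` now annotated as `Optional["pl.Trainer"]`, static checkers require the reference to be narrowed before attribute access, which is why the rest of this PR adds `assert ... is not None` guards. A minimal, self-contained sketch of that narrowing pattern, using toy stand-in classes rather than the real Lightning ones:

```python
from typing import Optional


class Trainer:
    """Stand-in for pl.Trainer, only for illustration."""

    training: bool = True


class Module:
    def __init__(self) -> None:
        # Declared Optional because the trainer is attached later.
        self.trainer: Optional[Trainer] = None


def run_hook(model: Module) -> bool:
    # Without this assert, mypy flags `model.trainer.training` because
    # `model.trainer` may still be None at this point.
    assert model.trainer is not None
    return model.trainer.training


module = Module()
module.trainer = Trainer()  # attach the trainer, as Lightning does internally
print(run_hook(module))  # True
```

Running mypy on the sketch without the `assert` reports an error along the lines of `Item "None" of "Optional[Trainer]" has no attribute "training"`.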
1 change: 1 addition & 0 deletions pytorch_lightning/core/optimizer.py
@@ -176,6 +176,7 @@ def _init_optimizers_and_lr_schedulers(
     model: "pl.LightningModule",
 ) -> Tuple[List[Optimizer], List[LRSchedulerConfig], List[int]]:
     """Calls `LightningModule.configure_optimizers` and parses and validates the output."""
+    assert model.trainer is not None
     optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)

     if optim_conf is None:
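Note: the new `assert` guards the hook dispatch that fetches the user's `configure_optimizers` result. For context, a hedged example of a typical `configure_optimizers` return value that this parser has to handle (standard Lightning usage as documented for this release line; the model itself is made up and not part of the diff):

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        # One of several accepted return formats: a dict with an optimizer
        # and an optional lr_scheduler entry.
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler}}
```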
53 changes: 23 additions & 30 deletions pytorch_lightning/overrides/base.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Union
+from typing import Any

 import torch
 import torch.nn as nn
@@ -57,13 +57,10 @@ def on_post_move_to_device(self) -> None:


 class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
-    def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]):
-        """
-        Wraps the user's LightningModule and redirects the forward call to the appropriate
-        method, either ``training_step``, ``validation_step`` or ``test_step``.
-        If the LightningModule is in none of the states `training`, `testing` or `validation`,
-        the inputs will be redirected to the
-        :meth:`~pytorch_lightning.core.lightning.LightningModule.predict` method.
+    def __init__(self, pl_module: "pl.LightningModule"):
+        """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
+        ``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.
+
         Inheriting classes may also modify the inputs or outputs of forward.

         Args:
@@ -77,28 +74,24 @@ def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionMod
         self._ddp_params_and_buffers_to_ignore = [f"module.{p}" for p in _ddp_params_and_buffers_to_ignore]

     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
-        lightning_module = unwrap_lightning_module(self.module)
-        trainer = lightning_module.trainer
-
-        if trainer and trainer.training:
-            output = self.module.training_step(*inputs, **kwargs)
-
-            # In manual_optimization, we need to prevent DDP reducer as
-            # it is done manually in `LightningModule.manual_backward`
-            # `require_backward_grad_sync` will be reset in the
-            # ddp_strategy `post_training_step` hook
-            if not lightning_module.automatic_optimization:
-                trainer.model.require_backward_grad_sync = False
-        elif trainer and trainer.testing:
-            output = self.module.test_step(*inputs, **kwargs)
-        elif trainer and (trainer.sanity_checking or trainer.validating):
-            output = self.module.validation_step(*inputs, **kwargs)
-        elif trainer and trainer.predicting:
-            output = self.module.predict_step(*inputs, **kwargs)
-        else:
-            output = self.module(*inputs, **kwargs)
-
-        return output
+        trainer = self.module.trainer
+        if trainer is not None:
+            if trainer.training:
+                output = self.module.training_step(*inputs, **kwargs)
+                # In manual_optimization, we need to prevent DDP reducer as
+                # it is done manually in `LightningModule.manual_backward`
+                # `require_backward_grad_sync` will be reset in the
+                # ddp_strategy `post_training_step` hook
+                if not self.module.automatic_optimization:
+                    trainer.model.require_backward_grad_sync = False  # type: ignore[assignment]
+                return output
+            if trainer.testing:
+                return self.module.test_step(*inputs, **kwargs)
+            if trainer.sanity_checking or trainer.validating:
+                return self.module.validation_step(*inputs, **kwargs)
+            if trainer.predicting:
+                return self.module.predict_step(*inputs, **kwargs)
+        return self.module(*inputs, **kwargs)

     def on_post_move_to_device(self) -> None:
         pass
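Note: the rewritten `forward` is a state-based dispatcher over the attached trainer; early returns replace the single `output` variable and `elif` chain, and the plain `self.module(...)` call remains the fallback when no trainer is attached. A simplified, standalone sketch of the same control flow with hypothetical toy classes (no DDP, no precision wrappers), only to illustrate the dispatch order:

```python
from typing import Any, Optional


class ToyTrainer:
    def __init__(self, stage: str) -> None:
        self.training = stage == "train"
        self.testing = stage == "test"
        self.validating = stage == "validate"
        self.sanity_checking = False
        self.predicting = stage == "predict"


class ToyModule:
    def __init__(self) -> None:
        self.trainer: Optional[ToyTrainer] = None

    def training_step(self, batch: Any) -> str:
        return "train"

    def validation_step(self, batch: Any) -> str:
        return "validate"

    def test_step(self, batch: Any) -> str:
        return "test"

    def predict_step(self, batch: Any) -> str:
        return "predict"

    def __call__(self, batch: Any) -> str:
        return "plain forward"


class Wrapper:
    def __init__(self, module: ToyModule) -> None:
        self.module = module

    def forward(self, batch: Any) -> str:
        # Same dispatch order as the diff: train, test, validate/sanity, predict.
        trainer = self.module.trainer
        if trainer is not None:
            if trainer.training:
                return self.module.training_step(batch)
            if trainer.testing:
                return self.module.test_step(batch)
            if trainer.sanity_checking or trainer.validating:
                return self.module.validation_step(batch)
            if trainer.predicting:
                return self.module.predict_step(batch)
        return self.module(batch)


module = ToyModule()
wrapper = Wrapper(module)
print(wrapper.forward(None))  # "plain forward" because no trainer is attached
module.trainer = ToyTrainer("test")
print(wrapper.forward(None))  # "test"
```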
9 changes: 5 additions & 4 deletions pytorch_lightning/overrides/distributed.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, cast, Iterator, List, Sized, Union
+from typing import Any, cast, Iterable, Iterator, List, Sized, Union

 import torch
 from torch import Tensor
@@ -27,8 +27,9 @@
 class LightningDistributedModule(_LightningModuleWrapperBase):
     def __init__(self, pl_module: "pl.LightningModule") -> None:
         """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
-        ``training_step``, ``validation_step``, ``test_step`` or ``predict``. This class is used in combination
-        with :class:`~torch.nn.parallel.DistributedDataParallel` as shown in the example.
+        ``training_step``, ``validation_step``, ``test_step`` or ``predict``.
+
+        This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel`.

         Example:

@@ -163,5 +164,5 @@ def batch_size(self) -> int:
         return self._sampler.batch_size

     @property
-    def sampler(self) -> Sampler:
+    def sampler(self) -> Union[Sampler, Iterable]:
         return self._sampler.sampler
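Note: the return type is widened because the wrapped `_sampler.sampler` is not guaranteed to be a `Sampler` subclass; in reasonably recent PyTorch versions, `BatchSampler` accepts any iterable of indices. A small sketch of that behavior (assuming current `torch.utils.data` semantics):

```python
from typing import Iterable, Union

from torch.utils.data import BatchSampler, Sampler

# A plain iterable of indices is accepted where a Sampler used to be required.
indices: Iterable[int] = [0, 2, 4, 6, 1, 3, 5, 7]
batch_sampler = BatchSampler(indices, batch_size=3, drop_last=False)

inner: Union[Sampler, Iterable] = batch_sampler.sampler  # may be a Sampler or a plain iterable
print(list(batch_sampler))  # [[0, 2, 4], [6, 1, 3], [5, 7]]
```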
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -69,6 +69,7 @@ def backward(
             closure_loss: the loss value obtained from the closure
             optimizer: current optimizer being used. ``None`` if using manual optimization
         """
+        assert model.trainer is not None
         opt = optimizer or model.trainer.optimizers
         with amp.scale_loss(closure_loss, opt) as closure_loss:
             super().backward(model, closure_loss, optimizer, *args, **kwargs)
8 changes: 7 additions & 1 deletion pytorch_lightning/plugins/precision/deepspeed.py
@@ -46,6 +46,7 @@ def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any
                 "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
                 " the backward logic internally."
             )
+        assert model.trainer is not None
         deepspeed_engine: DeepSpeedEngine = model.trainer.model
         deepspeed_engine.backward(closure_loss, *args, **kwargs)

@@ -75,7 +76,12 @@ def optimizer_step(
                 "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
             )
         # DeepSpeed handles the optimizer step internally
-        deepspeed_engine = model.trainer.model if isinstance(model, pl.LightningModule) else model
+        deepspeed_engine: DeepSpeedEngine
+        if isinstance(model, pl.LightningModule):
+            assert model.trainer is not None
+            deepspeed_engine = model.trainer.model
+        else:
+            deepspeed_engine = model
         return deepspeed_engine.step(**kwargs)

     def clip_gradients(
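Note: the bare `deepspeed_engine: DeepSpeedEngine` annotation followed by branch-specific assignments is a standard way to give a type checker one declared type across an `isinstance` split. A generic sketch of the pattern with hypothetical stand-in classes rather than DeepSpeed itself:

```python
from typing import Union


class Engine:
    """Stand-in for a wrapped engine such as DeepSpeedEngine."""

    def step(self) -> str:
        return "stepped"


class Module:
    """Stand-in for a LightningModule that holds a reference to the engine."""

    def __init__(self, engine: Engine) -> None:
        self.engine = engine


def optimizer_step(model: Union[Module, Engine]) -> str:
    # Declare the variable's type once so both branches assign to the same
    # declared type and the checker does not infer a union type.
    engine: Engine
    if isinstance(model, Module):
        engine = model.engine
    else:
        engine = model
    return engine.step()


print(optimizer_step(Engine()))          # "stepped"
print(optimizer_step(Module(Engine())))  # "stepped"
```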
3 changes: 3 additions & 0 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -55,6 +55,7 @@ def pre_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Ten
             model: the model to be optimized
             closure_loss: the loss value obtained from the closure
         """
+        assert model.trainer is not None
         model.trainer._call_callback_hooks("on_before_backward", closure_loss)
         model.trainer._call_lightning_module_hook("on_before_backward", closure_loss)
         return closure_loss
@@ -89,6 +90,7 @@ def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Te
         """
         # once backward has been applied, release graph
         closure_loss = closure_loss.detach()
+        assert model.trainer is not None
         model.trainer._call_callback_hooks("on_after_backward")
         model.trainer._call_lightning_module_hook("on_after_backward")
         return closure_loss
@@ -112,6 +114,7 @@ def _after_closure(
         trainer._call_callback_hooks("on_before_optimizer_step", optimizer, optimizer_idx)
         trainer._call_lightning_module_hook("on_before_optimizer_step", optimizer, optimizer_idx)
         # TODO: this is done for the entire model but should be changed to per-optimizer
+
         if optimizer_idx == 0:
             self._track_grad_norm(trainer)
         self._clip_gradients(
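Note: these asserts sit in front of the `on_before_backward` / `on_after_backward` hook dispatch. For reference, a hedged example of a user module observing those hooks (hook names as documented for this Lightning release line; the model and data are made up):

```python
import torch
import pytorch_lightning as pl


class HookedModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def on_before_backward(self, loss: torch.Tensor) -> None:
        # Called by the precision plugin right before the backward pass.
        print(f"about to backpropagate loss={loss.item():.4f}")

    def on_after_backward(self) -> None:
        # Called after the backward pass; gradients are now populated.
        print("backward done, gradients available")
```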
2 changes: 2 additions & 0 deletions pytorch_lightning/strategies/bagua.py
@@ -161,6 +161,7 @@ def configure_ddp(self) -> None:
         self._model = self._setup_model(model)

         # start the background communication for async algorithm
+        assert self.lightning_module.trainer is not None
         if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
             self.model.bagua_algorithm.resume(self.model)  # type: ignore

@@ -188,6 +189,7 @@ def register_strategies(cls, strategy_registry: Dict) -> None:

     def teardown(self) -> None:
         # abort the background communication for async algorithm
+        assert self.lightning_module.trainer is not None
         if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
             self.model.bagua_algorithm.abort(self.model)  # type: ignore

8 changes: 3 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -596,10 +596,8 @@ def __init__(

         self._terminate_on_nan = terminate_on_nan
         self.gradient_clip_val: Union[int, float] = gradient_clip_val
-        self.gradient_clip_algorithm = (
-            GradClipAlgorithmType(gradient_clip_algorithm.lower())
-            if gradient_clip_algorithm is not None
-            else gradient_clip_algorithm
+        self.gradient_clip_algorithm: Optional[GradClipAlgorithmType] = (
+            GradClipAlgorithmType(gradient_clip_algorithm.lower()) if gradient_clip_algorithm is not None else None
         )
         self.track_grad_norm: float = float(track_grad_norm)

@@ -1184,7 +1182,7 @@ def _run(
         # ----------------------------
         # INSPECT THE CORE LOOPS
         # ----------------------------
-        fr"""
+        rf"""
         Lightning internal flow looks like this:
         {Trainer.fit} or {Trainer.test} or {Trainer.predict} ||
         | ||
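Note: the `fr"""` to `rf"""` change is purely stylistic; Python accepts string prefix letters in any order, so both spell the same raw f-string. A tiny illustration:

```python
pattern = r"\d+"
a = fr"value: {pattern}\n"  # raw f-string, "fr" prefix
b = rf"value: {pattern}\n"  # identical raw f-string, "rf" prefix
assert a == b == "value: \\d+\\n"
print(a)  # value: \d+\n  (the \n stays literal because the string is raw)
```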
4 changes: 3 additions & 1 deletion pytorch_lightning/utilities/data.py
@@ -170,7 +170,9 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
     return float("inf")


-def _update_dataloader(dataloader: DataLoader, sampler: Sampler, mode: Optional[RunningStage] = None) -> DataLoader:
+def _update_dataloader(
+    dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
+) -> DataLoader:
     dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
     dl_cls = type(dataloader)
     try:
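Note: `_update_dataloader` re-instantiates the loader class with the new sampler injected into its init kwargs. A simplified, hedged sketch of the idea using only a handful of public `DataLoader` arguments (the real `_get_dataloader_init_kwargs` reconstructs the full signature, including custom subclasses):

```python
from typing import Iterable, Union

import torch
from torch.utils.data import DataLoader, Sampler, SequentialSampler, TensorDataset


def rebuild_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable]) -> DataLoader:
    # Simplified re-instantiation: copy a few common kwargs and swap the sampler.
    # The real helper inspects the loader's __init__ signature and copies everything.
    return type(dataloader)(
        dataloader.dataset,
        batch_size=dataloader.batch_size,
        sampler=sampler,
        num_workers=dataloader.num_workers,
        collate_fn=dataloader.collate_fn,
        drop_last=dataloader.drop_last,
    )


dataset = TensorDataset(torch.arange(8).float().unsqueeze(1))
loader = DataLoader(dataset, batch_size=4)
new_loader = rebuild_dataloader(loader, SequentialSampler(dataset))
print(len(list(new_loader)))  # 2 batches of 4
```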
10 changes: 2 additions & 8 deletions pytorch_lightning/utilities/model_summary.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Utilities related to model weights summary."""
-
-import contextlib
 import logging
 from collections import OrderedDict
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -265,12 +263,8 @@ def _forward_example_input(self) -> None:
         mode = model.training
         model.eval()

-        if trainer is not None:
-            forward_context = trainer.precision_plugin.forward_context()
-        else:
-            forward_context = contextlib.nullcontext()
-
-        with torch.no_grad(), forward_context:
+        assert trainer is not None
+        with torch.no_grad(), trainer.precision_plugin.forward_context():
             # let the model hooks collect the input- and output shapes
             if isinstance(input_, (list, tuple)):
                 model(*input_)
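Note: the removed branch fell back to `contextlib.nullcontext()` when no trainer was attached; after this change the summary's example-input forward pass requires a trainer and always enters the precision plugin's forward context. A generic sketch of that context-manager pattern (stand-in classes, not the Lightning ones):

```python
import contextlib
from typing import ContextManager, Optional


class PrecisionPlugin:
    """Stand-in for a precision plugin exposing a forward context."""

    def forward_context(self) -> ContextManager:
        # In Lightning this would e.g. enable autocast; here it is a no-op.
        return contextlib.nullcontext()


class Trainer:
    precision_plugin = PrecisionPlugin()


def forward_example_input(trainer: Optional[Trainer]) -> None:
    # Old behavior: fall back to a null context when trainer is None.
    # New behavior (as in this diff): require a trainer and use its plugin.
    assert trainer is not None
    with trainer.precision_plugin.forward_context():
        print("running the example-input forward pass")


forward_example_input(Trainer())
```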