Remove deprecated 'terminate_on_nan' argument from Trainer #12553

2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -72,7 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed the deprecated `prepare_data_per_node` argument from the `Trainer` constructor ([#12536](https://github.com/PyTorchLightning/pytorch-lightning/pull/12536))


-
- Removed the deprecated `terminate_on_nan` argument from the `Trainer` constructor ([#12553](https://github.com/PyTorchLightning/pytorch-lightning/pull/12553))


-
9 changes: 0 additions & 9 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -27,12 +27,10 @@
_block_parallel_sync_behavior,
_build_training_step_kwargs,
_extract_hiddens,
check_finite_loss,
)
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

@@ -310,10 +308,6 @@ def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Call
def backward_fn(loss: Tensor) -> None:
self.trainer._call_strategy_hook("backward", loss, optimizer, opt_idx)

# check if model weights are nan
if self.trainer._terminate_on_nan:
detect_nan_parameters(self.trainer.lightning_module)

return backward_fn

def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
@@ -437,9 +431,6 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos
training_step_output, self.trainer.accumulate_grad_batches
)

if self.trainer._terminate_on_nan:
check_finite_loss(result.closure_loss)

if self.trainer.move_metrics_to_cpu:
# hiddens and the training step output are not moved as they are not considered "metrics"
assert self.trainer._results is not None
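
With these loop-level checks removed, a project that relied on `terminate_on_nan` can reproduce the behaviour in its own `LightningModule`. Below is a minimal sketch, not part of this diff, assuming a hypothetical module whose `training_step` obtains a `loss` from a stand-in `_compute_loss` helper; `detect_nan_parameters` is the same utility the deleted loop code called.

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters


class NaNGuardedModule(LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)  # hypothetical helper standing in for the real step
        # mirrors the removed `check_finite_loss(result.closure_loss)` check
        if not torch.isfinite(loss).all():
            raise ValueError(f"The loss returned in `training_step` is {loss}.")
        return loss

    def on_after_backward(self):
        # mirrors the removed `detect_nan_parameters(...)` call in `backward_fn`:
        # raises ValueError if any parameter contains nan or +/-inf
        detect_nan_parameters(self)
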
33 changes: 0 additions & 33 deletions pytorch_lightning/trainer/trainer.py
@@ -186,7 +186,6 @@ def __init__(
amp_level: Optional[str] = None,
move_metrics_to_cpu: bool = False,
multiple_trainloader_mode: str = "max_size_cycle",
terminate_on_nan: Optional[bool] = None,
) -> None:
r"""
Customize every aspect of training via flags.
@@ -386,16 +385,6 @@ def __init__(
sync_batchnorm: Synchronize batch norm layers between process groups/whole world.
Default: ``False``.

terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

.. deprecated:: v1.5
Trainer argument ``terminate_on_nan`` was deprecated in v1.5 and will be removed in 1.7.
Please use ``detect_anomaly`` instead.

detect_anomaly: Enable anomaly detection for the autograd engine.
Default: ``False``.

tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on (1)
Default: ``None``.

@@ -535,14 +524,6 @@ def __init__(
reload_dataloaders_every_n_epochs,
)

if terminate_on_nan is not None:
rank_zero_deprecation(
"Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7."
" Please use `Trainer(detect_anomaly=True)` instead."
)
if not isinstance(terminate_on_nan, bool):
raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")

# gradient clipping
if gradient_clip_val is not None and not isinstance(gradient_clip_val, (int, float)):
raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")
@@ -563,7 +544,6 @@ def __init__(
f"`track_grad_norm` must be a positive number or 'inf' (infinity norm). Got {track_grad_norm}."
)

self._terminate_on_nan = terminate_on_nan
self.gradient_clip_val: Union[int, float] = gradient_clip_val
self.gradient_clip_algorithm: Optional[GradClipAlgorithmType] = (
GradClipAlgorithmType(gradient_clip_algorithm.lower()) if gradient_clip_algorithm is not None else None
@@ -2797,19 +2777,6 @@ def configure_optimizers(self):
max_estimated_steps = min(max_estimated_steps, self.max_steps) if self.max_steps != -1 else max_estimated_steps
return max_estimated_steps

@property
def terminate_on_nan(self) -> bool:
rank_zero_deprecation("`Trainer.terminate_on_nan` is deprecated in v1.5 and will be removed in 1.7.")
return self._terminate_on_nan

@terminate_on_nan.setter
def terminate_on_nan(self, val: bool) -> None:
rank_zero_deprecation(
f"Setting `Trainer.terminate_on_nan = {val}` is deprecated in v1.5 and will be removed in 1.7."
f" Please set `Trainer(detect_anomaly={val})` instead."
)
self._terminate_on_nan = val


def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]:
if batches is None:
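
For users, the `Trainer` constructor no longer accepts `terminate_on_nan` at all; the replacement named in the old deprecation message is `detect_anomaly`. A minimal migration sketch:

from pytorch_lightning import Trainer

# previously (deprecated since v1.5, removed by this PR):
# trainer = Trainer(terminate_on_nan=True)

# now: enable autograd anomaly detection instead
trainer = Trainer(detect_anomaly=True)
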
17 changes: 0 additions & 17 deletions tests/deprecated_api/test_remove_1-7.py
@@ -119,23 +119,6 @@ def get_progress_bar_dict(self):
_ = trainer.progress_bar_dict


@pytest.mark.parametrize("terminate_on_nan", [True, False])
def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan):
with pytest.deprecated_call(
match="Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7"
):
trainer = Trainer(terminate_on_nan=terminate_on_nan)
assert trainer.terminate_on_nan is terminate_on_nan
assert trainer._detect_anomaly is False

trainer = Trainer()
with pytest.deprecated_call(match=r"`Trainer.terminate_on_nan` is deprecated in v1.5"):
_ = trainer.terminate_on_nan

with pytest.deprecated_call(match=r"Setting `Trainer.terminate_on_nan = True` is deprecated in v1.5"):
trainer.terminate_on_nan = True


def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
class CustomBoringModel(BoringModel):
def on_train_dataloader(self):
60 changes: 0 additions & 60 deletions tests/trainer/test_trainer.py
@@ -904,72 +904,12 @@ def validation_epoch_end(self, *args, **kwargs):
assert model.validation_epoch_end_invoked, "did not run `validation_epoch_end` with `fast_dev_run=True`"


@mock.patch("torch.Tensor.backward")
def test_nan_loss_detection(backward_mock, tmpdir):
class CurrentModel(BoringModel):
test_batch_inf = 3

def training_step(self, batch, batch_idx):
output = super().training_step(batch, batch_idx)
if batch_idx == self.test_batch_inf:
if isinstance(output, dict):
output["loss"] *= torch.tensor(math.inf) # make loss infinite
else:
output /= 0
return output

model = CurrentModel()

with pytest.deprecated_call(match="terminate_on_nan` was deprecated in v1.5"):
trainer = Trainer(default_root_dir=tmpdir, max_steps=(model.test_batch_inf + 1), terminate_on_nan=True)

with pytest.raises(ValueError, match=r".*The loss returned in `training_step` is.*"):
trainer.fit(model)
assert trainer.global_step == model.test_batch_inf
assert backward_mock.call_count == model.test_batch_inf

for param in model.parameters():
assert torch.isfinite(param).all()


def test_invalid_terminate_on_nan(tmpdir):
with pytest.raises(TypeError, match="`terminate_on_nan` should be a bool"), pytest.deprecated_call(
match="terminate_on_nan` was deprecated in v1.5"
):
Trainer(default_root_dir=tmpdir, terminate_on_nan="False")


@pytest.mark.parametrize("track_grad_norm", [0, torch.tensor(1), "nan"])
def test_invalid_track_grad_norm(tmpdir, track_grad_norm):
with pytest.raises(MisconfigurationException, match="`track_grad_norm` must be a positive number or 'inf'"):
Trainer(default_root_dir=tmpdir, track_grad_norm=track_grad_norm)


@mock.patch("torch.Tensor.backward")
def test_nan_params_detection(backward_mock, tmpdir):
class CurrentModel(BoringModel):
test_batch_nan = 3

def on_after_backward(self):
if self.global_step == self.test_batch_nan:
# simulate parameter that became nan
torch.nn.init.constant_(self.layer.bias, math.nan)

model = CurrentModel()

with pytest.deprecated_call(match="terminate_on_nan` was deprecated in v1.5"):
trainer = Trainer(default_root_dir=tmpdir, max_steps=(model.test_batch_nan + 1), terminate_on_nan=True)

with pytest.raises(ValueError, match=r".*Detected nan and/or inf values in `layer.bias`.*"):
trainer.fit(model)
assert trainer.global_step == model.test_batch_nan
assert backward_mock.call_count == model.test_batch_nan + 1

# after aborting the training loop, model still has nan-valued params
params = torch.cat([param.view(-1) for param in model.parameters()])
assert not torch.isfinite(params).all()


def test_on_exception_hook(tmpdir):
"""Test the on_exception callback hook and the trainer interrupted flag."""

20 changes: 20 additions & 0 deletions tests/utilities/test_finite_checks.py
@@ -0,0 +1,20 @@
import math

import pytest
import torch
import torch.nn as nn

from pytorch_lightning.utilities.finite_checks import detect_nan_parameters


@pytest.mark.parametrize("value", (math.nan, math.inf, -math.inf))
def test_detect_nan_parameters(value):
model = nn.Linear(2, 3)

detect_nan_parameters(model)

nn.init.constant_(model.bias, value)
assert not torch.isfinite(model.bias).all()

with pytest.raises(ValueError, match=r".*Detected nan and/or inf values in `bias`.*"):
detect_nan_parameters(model)
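
For reference, `detect_anomaly` catches non-finite values at the autograd level rather than by inspecting the loss and parameters after each step. A plain-PyTorch sketch of the mode it enables:

import torch

x = torch.tensor([0.0], requires_grad=True)
with torch.autograd.set_detect_anomaly(True):
    y = torch.log(x)   # log(0) -> -inf in the forward pass
    z = (y * 0).sum()  # -inf * 0 -> nan
    z.backward()       # anomaly mode raises a RuntimeError identifying the
                       # backward function that produced the nan gradient

`Trainer(detect_anomaly=True)` turns this mode on during training, so non-finite values surface as errors instead of propagating silently.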