Skip to content

Commit

Permalink
Deprecate `terminate_on_nan` Trainer argument in favor of `detect_anomaly` (#9175)
Browse files Browse the repository at this point in the history

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
  • Loading branch information
3 people authored Oct 11, 2021
1 parent 6a0c47a commit 173f4c8
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated trainer argument `terminate_on_nan` in favor of `detect_anomaly` ([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175))


- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`


Expand Down
15 changes: 10 additions & 5 deletions pytorch_lightning/trainer/connectors/training_trick_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from typing import Optional, Union

from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException


Expand All @@ -26,10 +26,15 @@ def on_trainer_init(
gradient_clip_val: Union[int, float],
gradient_clip_algorithm: str,
track_grad_norm: Union[int, float, str],
terminate_on_nan: bool,
terminate_on_nan: Optional[bool],
):
if not isinstance(terminate_on_nan, bool):
raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")
if terminate_on_nan is not None:
rank_zero_deprecation(
"Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7."
" Please use `Trainer(detect_anomaly=True)` instead."
)
if not isinstance(terminate_on_nan, bool):
raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")

# gradient clipping
if not isinstance(gradient_clip_val, (int, float)):
Expand Down
10 changes: 8 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __init__(
reload_dataloaders_every_epoch: bool = False,
auto_lr_find: Union[bool, str] = False,
replace_sampler_ddp: bool = True,
terminate_on_nan: bool = False,
detect_anomaly: bool = False,
auto_scale_batch_size: Union[str, bool] = False,
prepare_data_per_node: Optional[bool] = None,
plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
Expand All @@ -177,7 +177,7 @@ def __init__(
move_metrics_to_cpu: bool = False,
multiple_trainloader_mode: str = "max_size_cycle",
stochastic_weight_avg: bool = False,
detect_anomaly: bool = False,
terminate_on_nan: Optional[bool] = None,
):
r"""
Customize every aspect of training via flags.
Expand Down Expand Up @@ -351,6 +351,12 @@ def __init__(
terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
.. deprecated:: v1.5
Trainer argument ``terminate_on_nan`` was deprecated in v1.5 and will be removed in 1.7.
Please use ``detect_anomaly`` instead.
detect_anomaly: Enable anomaly detection for the autograd engine.
tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]
ipus: How many IPUs to train on.
Expand Down
10 changes: 10 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,16 @@ def test_v1_7_0_stochastic_weight_avg_trainer_constructor(tmpdir):
_ = Trainer(stochastic_weight_avg=True)


@pytest.mark.parametrize("terminate_on_nan", [True, False])
def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan):
    """Check that passing ``terminate_on_nan`` raises the v1.5 deprecation warning and still sets the attribute."""
    deprecation_pattern = "Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7"
    with pytest.deprecated_call(match=deprecation_pattern):
        trainer = Trainer(terminate_on_nan=terminate_on_nan)
    # the deprecated flag must still be honored, and must not implicitly enable anomaly detection
    assert trainer.terminate_on_nan is terminate_on_nan
    assert trainer._detect_anomaly is False


def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
class CustomBoringModel(BoringModel):
def on_train_dataloader(self):
Expand Down

0 comments on commit 173f4c8

Please sign in to comment.