Commit 617e798

Raise an exception if using amp_level with native amp_backend (#9755)

* add exception

* chlog

* code review

* Apply suggestions from code review

Co-authored-by: thomas chaton <[email protected]>
rohitgr7 and tchaton authored Oct 1, 2021
1 parent 9d98208 commit 617e798
Showing 4 changed files with 24 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -243,6 +243,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `HorovodPlugin.all_gather` to return a `torch.Tensor` instead of a list ([#9696](https://github.com/PyTorchLightning/pytorch-lightning/pull/9696))
 
 
+- Raise an exception if using `amp_level` with native `amp_backend` ([#9755](https://github.com/PyTorchLightning/pytorch-lightning/pull/9755))
+
+
 ### Deprecated
 
 - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
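A quick reproduction sketch of the behavior described by the new changelog entry (illustration only, not part of the commit; it assumes a machine with at least one CUDA GPU so the native 16-bit path is selected, mirroring the test added at the end of this diff):

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    # Passing amp_level together with the native backend now fails fast at Trainer construction.
    Trainer(gpus=1, precision=16, amp_backend="native", amp_level="O2")
except MisconfigurationException as err:
    print(err)
    # You have asked for `amp_level='O2'` which is not supported with `amp_backend='native'`.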
10 changes: 10 additions & 0 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -563,11 +563,18 @@ def select_precision_plugin(self) -> PrecisionPlugin:
                 return TPUHalfPrecisionPlugin()
 
             if self.amp_type == AMPType.NATIVE:
+                if self.amp_level is not None:
+                    raise MisconfigurationException(
+                        f"You have asked for `amp_level={repr(self.amp_level)}` which is not supported"
+                        " with `amp_backend='native'`."
+                    )
+
                 log.info(f"Using native {self.precision} bit Automatic Mixed Precision")
                 if self._is_sharded_training_type:
                     return ShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
                 if self._is_fully_sharded_training_type:
                     return FullyShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
+
                 return NativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
 
             if self.amp_type == AMPType.APEX:
@@ -581,6 +588,9 @@ def select_precision_plugin(self) -> PrecisionPlugin:
                         "Sharded Plugin is not supported with Apex AMP, please using native AMP for 16-bit precision."
                     )
                 log.info("Using APEX 16bit precision.")
+
+                self.amp_level = self.amp_level or "O2"
+
                 return ApexMixedPrecisionPlugin(self.amp_level)
 
         raise MisconfigurationException(
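The two hunks above add a guard for the native backend and move the "O2" default into the apex branch. A minimal standalone sketch of that selection rule, simplified for illustration (the real connector additionally chooses between regular, sharded, and fully-sharded precision plugins):

from typing import Optional

from pytorch_lightning.utilities.exceptions import MisconfigurationException

def resolve_amp_level(amp_backend: str, amp_level: Optional[str] = None) -> Optional[str]:
    """Simplified view of the new selection logic, for illustration only."""
    if amp_backend == "native":
        if amp_level is not None:
            # amp_level is an apex-only optimization level, so combining it with native AMP fails fast.
            raise MisconfigurationException(
                f"You have asked for `amp_level={amp_level!r}` which is not supported"
                " with `amp_backend='native'`."
            )
        return None
    # apex keeps the previous default of "O2" when no level was given.
    return amp_level or "O2"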
9 changes: 5 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -172,7 +172,7 @@ def __init__(
         prepare_data_per_node: Optional[bool] = None,
         plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
         amp_backend: str = "native",
-        amp_level: str = "O2",
+        amp_level: Optional[str] = None,
         distributed_backend: Optional[str] = None,
         move_metrics_to_cpu: bool = False,
         multiple_trainloader_mode: str = "max_size_cycle",
@@ -190,7 +190,8 @@ def __init__(
             amp_backend: The mixed precision backend to use ("native" or "apex").
-            amp_level: The optimization level to use (O1, O2, etc...).
+            amp_level: The optimization level to use (O1, O2, etc...). By default it will be set to "O2"
+                if ``amp_backend`` is set to "apex".
             auto_lr_find: If set to True, will make trainer.tune() run a learning rate finder,
                 trying to optimize initial learning for faster convergence. trainer.tune() method will
@@ -203,7 +204,7 @@ def __init__(
                 Additionally, can be set to either `power` that estimates the batch size through
                 a power search or `binsearch` that estimates the batch size through a binary search.
-            auto_select_gpus: If enabled and `gpus` is an integer, pick available
+            auto_select_gpus: If enabled and ``gpus`` is an integer, pick available
                 gpus automatically. This is especially useful when
                 GPUs are configured to be in "exclusive mode", such
                 that only one process at a time can access them.
@@ -228,7 +229,7 @@ def __init__(
             devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`,
                 based on the accelerator type.
-            distributed_backend: Deprecated. Please use 'accelerator'.
+            distributed_backend: Deprecated. Please use ``accelerator``.
             fast_dev_run: Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
                 of train, val and test to find any bugs (ie: a sort of unit test).
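With the signature change above, `amp_level` now defaults to None and is only meaningful together with the apex backend. A usage sketch (assumes a CUDA GPU; the apex line additionally assumes NVIDIA apex is installed):

from pytorch_lightning import Trainer

# Native AMP: omit amp_level entirely.
trainer = Trainer(gpus=1, precision=16, amp_backend="native")

# Apex AMP: amp_level stays optional and falls back to "O2" inside the accelerator connector.
trainer = Trainer(gpus=1, precision=16, amp_backend="apex", amp_level="O1")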
6 changes: 6 additions & 0 deletions tests/accelerators/test_accelerator_connector.py
@@ -636,3 +636,9 @@ def test_validate_precision_type(tmpdir, precision):
 
     with pytest.raises(MisconfigurationException, match=f"Precision {precision} is invalid"):
         Trainer(precision=precision)
+
+
+@RunIf(min_gpus=1, amp_native=True)
+def test_amp_level_raises_error_with_native(tmpdir):
+    with pytest.raises(MisconfigurationException, match="not supported with `amp_backend='native'`"):
+        _ = Trainer(default_root_dir=tmpdir, gpus=1, amp_level="O2", amp_backend="native", precision=16)
