
Raise an exception if using amp_level with native amp_backend #9755

Merged · 6 commits · Oct 1, 2021
Changes from 4 commits
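With this change, passing `amp_level` together with the native backend is rejected when the `Trainer` is constructed. A minimal sketch of the new behaviour (not part of the diff), assuming a single CUDA-capable GPU and 16-bit precision:

```python
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# amp_level is an apex-only option, so combining it with amp_backend="native" now raises.
with pytest.raises(MisconfigurationException, match="not supported with amp_backend='native'"):
    Trainer(gpus=1, precision=16, amp_backend="native", amp_level="O2")
```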
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -246,6 +246,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `HorovodPlugin.all_gather` to return a `torch.Tensor` instead of a list ([#9696](https://github.com/PyTorchLightning/pytorch-lightning/pull/9696))


- Raise an exception if using `amp_level` with native `amp_backend` ([#9755](https://github.com/PyTorchLightning/pytorch-lightning/pull/9755))


### Deprecated

- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
11 changes: 11 additions & 0 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -563,11 +563,18 @@ def select_precision_plugin(self) -> PrecisionPlugin:
                return TPUHalfPrecisionPlugin()

            if self.amp_type == AMPType.NATIVE:
                if self.amp_level is not None:
                    raise MisconfigurationException(
                        f"You have asked for amp_level={self.amp_level} which is not supported "
                        "with amp_backend='native'."
                    )

                log.info(f"Using native {self.precision} bit Automatic Mixed Precision")
                if self._is_sharded_training_type:
                    return ShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
                if self._is_fully_sharded_training_type:
                    return FullyShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)

                return NativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)

            if self.amp_type == AMPType.APEX:
@@ -581,6 +588,10 @@ def select_precision_plugin(self) -> PrecisionPlugin:
"Sharded Plugin is not supported with Apex AMP, please using native AMP for 16-bit precision."
)
log.info("Using APEX 16bit precision.")

if self.amp_level is None:
self.amp_level = "O2"
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

return ApexMixedPrecisionPlugin(self.amp_level)

raise MisconfigurationException(
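For the apex backend, `amp_level` is now optional: when it is left as `None`, the connector above falls back to `"O2"` before building the `ApexMixedPrecisionPlugin`. A hedged usage sketch, assuming NVIDIA apex is installed and a GPU is available:

```python
from pytorch_lightning import Trainer

# amp_level can be omitted with apex; select_precision_plugin() fills in "O2".
trainer = Trainer(gpus=1, precision=16, amp_backend="apex")
```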
9 changes: 5 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -172,7 +172,7 @@ def __init__(
        prepare_data_per_node: Optional[bool] = None,
        plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
        amp_backend: str = "native",
        amp_level: str = "O2",
        amp_level: Optional[str] = None,
        distributed_backend: Optional[str] = None,
        move_metrics_to_cpu: bool = False,
        multiple_trainloader_mode: str = "max_size_cycle",
@@ -190,7 +190,8 @@ def __init__(

            amp_backend: The mixed precision backend to use ("native" or "apex").

            amp_level: The optimization level to use (O1, O2, etc...).
            amp_level: The optimization level to use (O1, O2, etc...). By default it will be set to "O2"
                if ``amp_backend`` is set to "apex".

            auto_lr_find: If set to True, will make trainer.tune() run a learning rate finder,
                trying to optimize initial learning for faster convergence. trainer.tune() method will
@@ -203,7 +204,7 @@
                Additionally, can be set to either `power` that estimates the batch size through
                a power search or `binsearch` that estimates the batch size through a binary search.

            auto_select_gpus: If enabled and `gpus` is an integer, pick available
            auto_select_gpus: If enabled and ``gpus`` is an integer, pick available
                gpus automatically. This is especially useful when
                GPUs are configured to be in "exclusive mode", such
                that only one process at a time can access them.
@@ -228,7 +229,7 @@
            devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`,
                based on the accelerator type.

            distributed_backend: Deprecated. Please use 'accelerator'.
            distributed_backend: Deprecated. Please use ``accelerator``.

            fast_dev_run: Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
                of train, val and test to find any bugs (ie: a sort of unit test).
6 changes: 6 additions & 0 deletions tests/plugins/test_amp_plugins.py
@@ -224,3 +224,9 @@ def test_cpu_amp_precision_16_throws_error(tmpdir):
        default_root_dir=tmpdir,
        precision=16,
    )


@RunIf(min_gpus=1, amp_native=True)
def test_amp_level_raises_error_with_native(tmpdir):
    with pytest.raises(MisconfigurationException, match="not supported with amp_backend='native'"):
        _ = Trainer(default_root_dir=tmpdir, gpus=1, amp_level="O2", amp_backend="native", precision=16)
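For completeness, the valid native-AMP configuration after this change simply omits `amp_level` (a sketch, assuming one GPU):

```python
from pytorch_lightning import Trainer

# Native AMP takes no optimization level, so amp_level is left unset.
trainer = Trainer(gpus=1, precision=16, amp_backend="native")
```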