Commit

docstring changes in accelerators (#6327)
* docstring changes in accelerators

* docstrings moved

* whitespaces removed

* PEP8 correction[1]
AlKun25 authored Mar 4, 2021
1 parent 7acbd65 commit 49c579f
Showing 3 changed files with 18 additions and 1 deletion.
5 changes: 5 additions & 0 deletions pytorch_lightning/accelerators/cpu.py
@@ -12,6 +12,11 @@
 class CPUAccelerator(Accelerator):
 
     def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If AMP is used with CPU, or if the selected device is not CPU.
+        """
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
             raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option")

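A minimal sketch of the AMP-on-CPU guard that the new "Raises:" section records. The classes below are simplified stand-ins for illustration only, not the real Lightning classes (the real MixedPrecisionPlugin and MisconfigurationException live in pytorch_lightning):

# Illustration only: simplified stand-ins, to show the condition the docstring documents.
class MisconfigurationException(Exception):
    pass


class MixedPrecisionPlugin:
    pass


class TinyCPUAccelerator:
    def __init__(self, precision_plugin):
        self.precision_plugin = precision_plugin

    def setup(self):
        # AMP (a MixedPrecisionPlugin) is rejected on CPU, as the docstring says.
        if isinstance(self.precision_plugin, MixedPrecisionPlugin):
            raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option")


try:
    TinyCPUAccelerator(MixedPrecisionPlugin()).setup()
except MisconfigurationException as err:
    print(err)  # amp + cpu is not supported. Please use a GPU option
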
5 changes: 5 additions & 0 deletions pytorch_lightning/accelerators/gpu.py
@@ -17,6 +17,11 @@
 class GPUAccelerator(Accelerator):
 
     def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If the selected device is not GPU.
+        """
         if "cuda" not in str(self.root_device):
             raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
         self.set_nvidia_flags()
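For context on the check the docstring describes, a short hedged illustration of how the device-string test behaves with plain torch.device objects (no GPU is needed to construct them):

import torch

# str() of a CUDA device contains "cuda", a CPU device does not; this is the
# property GPUAccelerator.setup relies on when it inspects self.root_device.
print(str(torch.device("cuda", 0)))  # "cuda:0"  -> passes the check
print(str(torch.device("cpu")))      # "cpu"     -> setup would raise MisconfigurationException
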
9 changes: 8 additions & 1 deletion pytorch_lightning/accelerators/tpu.py
@@ -21,6 +21,11 @@
 class TPUAccelerator(Accelerator):
 
     def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If AMP is used with TPU, or if TPUs are not using a single TPU core or TPU spawn training.
+        """
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
             raise MisconfigurationException(
                 "amp + tpu is not supported. "
@@ -31,7 +36,9 @@ def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
return super().setup(trainer, model)

def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
def run_optimizer_step(
self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
) -> None:
xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})

def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
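The retyped run_optimizer_step forwards lambda_closure to xm.optimizer_step as a closure. As a hedged illustration of the same calling pattern with plain torch.optim (torch_xla is not imported here; the closure protocol is the standard Optimizer.step(closure=...) one):

import torch
import torch.nn.functional as F

# A small model and optimizer to demonstrate the closure-style step that
# run_optimizer_step delegates to xm.optimizer_step on TPU.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)


def lambda_closure():
    # The closure recomputes the loss and gradients, then returns the loss,
    # matching the signature optimizer.step(closure=...) expects.
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    return loss


optimizer.step(closure=lambda_closure)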
