From 6cce24f84eb442765805a8e86bcf82e04c0766de Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sun, 14 Feb 2021 03:36:25 +0100
Subject: [PATCH] fix cyclic import

---
 pytorch_lightning/accelerators/accelerator.py              | 5 ++---
 pytorch_lightning/plugins/training_type/rpc_sequential.py  | 5 ++---
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 77062f350ca09..321c2fd78aefa 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -18,7 +18,6 @@
 from torch.utils.data import DataLoader
 
 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.trainer.trainer import Trainer
 from pytorch_lightning.plugins.precision import (
     ApexMixedPrecisionPlugin,
     NativeMixedPrecisionPlugin,
@@ -63,7 +62,7 @@ def __init__(
         self.lr_schedulers = None
         self.optimizer_frequencies = None
 
-    def setup(self, trainer: "Trainer", model: LightningModule) -> None:
+    def setup(self, trainer, model: LightningModule) -> None:
         """
         Connects the plugins to the training process, creates optimizers
 
@@ -302,7 +301,7 @@ def on_train_end(self) -> None:
         """Hook to do something at the end of the training"""
         pass
 
-    def setup_optimizers(self, trainer: "Trainer"):
+    def setup_optimizers(self, trainer):
         """creates optimizers and schedulers
 
         Args:
diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py
index edfda8fce284d..36db16c4894a0 100644
--- a/pytorch_lightning/plugins/training_type/rpc_sequential.py
+++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -21,7 +21,6 @@
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim import Optimizer
 
-from pytorch_lightning.trainer.trainer import Trainer
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.distributed import LightningDistributedModule
 from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin
@@ -330,11 +329,11 @@ def post_training(self):
         if self.main_rpc_process:
             super().post_training()
 
-    def start_training(self, trainer: 'Trainer') -> None:
+    def start_training(self, trainer) -> None:
         if self.main_rpc_process:
             super().start_training(trainer)
 
-    def start_testing(self, trainer: 'Trainer') -> None:
+    def start_testing(self, trainer) -> None:
         if self.main_rpc_process:
             super().start_testing(trainer)
 
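
Note (not part of the patch above): the change breaks the cycle by removing the runtime import of Trainer and dropping the annotations from the affected signatures. A common alternative that keeps the hints is to import Trainer only under typing.TYPE_CHECKING and use string annotations; the snippet below is a minimal illustrative sketch of that pattern against the same names, not code from this change.

# Sketch only, assuming the same import path as the removed line.
# typing.TYPE_CHECKING is False at runtime, so loading this module no longer
# imports pytorch_lightning.trainer.trainer; the string annotation is resolved
# only by static type checkers.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from pytorch_lightning.trainer.trainer import Trainer


class Accelerator:
    def setup_optimizers(self, trainer: "Trainer"):
        # method body unchanged; the annotation is purely informational
        ...

The trade-off versus the patch as written is that type checkers and IDEs still see the Trainer type, at the cost of the extra TYPE_CHECKING block in each module that needs it.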