Skip to content

Commit

Permalink
fix cyclic import
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored and Borda committed Feb 16, 2021
1 parent 829a822 commit 6cce24f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
5 changes: 2 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from torch.utils.data import DataLoader

from pytorch_lightning.core import LightningModule
from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.plugins.precision import (
ApexMixedPrecisionPlugin,
NativeMixedPrecisionPlugin,
Expand Down Expand Up @@ -63,7 +62,7 @@ def __init__(
self.lr_schedulers = None
self.optimizer_frequencies = None

def setup(self, trainer: "Trainer", model: LightningModule) -> None:
def setup(self, trainer, model: LightningModule) -> None:
"""
Connects the plugins to the training process, creates optimizers
Expand Down Expand Up @@ -302,7 +301,7 @@ def on_train_end(self) -> None:
"""Hook to do something at the end of the training"""
pass

def setup_optimizers(self, trainer: "Trainer"):
def setup_optimizers(self, trainer):
"""creates optimizers and schedulers
Args:
Expand Down
5 changes: 2 additions & 3 deletions pytorch_lightning/plugins/training_type/rpc_sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer

from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.distributed import LightningDistributedModule
from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin
Expand Down Expand Up @@ -330,11 +329,11 @@ def post_training(self):
if self.main_rpc_process:
super().post_training()

def start_training(self, trainer: 'Trainer') -> None:
def start_training(self, trainer) -> None:
if self.main_rpc_process:
super().start_training(trainer)

def start_testing(self, trainer: 'Trainer') -> None:
def start_testing(self, trainer) -> None:
if self.main_rpc_process:
super().start_testing(trainer)

Expand Down

0 comments on commit 6cce24f

Please sign in to comment.