diff --git a/pytorch_lightning/accelerators/base_backend.py b/pytorch_lightning/accelerators/base_backend.py
index fd0da54181484..45fa90a48445f 100644
--- a/pytorch_lightning/accelerators/base_backend.py
+++ b/pytorch_lightning/accelerators/base_backend.py
@@ -8,6 +8,9 @@ class Accelerator(object):
     def __init__(self, trainer):
         self.trainer = trainer
 
+    def setup(self):
+        pass
+
     def teardown(self):
         pass
 
diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py
index 614241625a7a5..f44d9f20a2b7c 100644
--- a/pytorch_lightning/accelerators/ddp_backend.py
+++ b/pytorch_lightning/accelerators/ddp_backend.py
@@ -43,15 +43,24 @@ class DDPBackend(Accelerator):
 
-    def __init__(self, trainer):
+    def __init__(self, trainer, mode: str = 'ddp'):
         super().__init__(trainer)
         self.task_idx = None
         self._has_spawned_children = False
+        self.mode = mode
 
-    def slurm_setup(self):
+    def setup(self):
+        if self.mode == 'ddp':
+            pass
+        elif self.mode == 'slurm_ddp':
+            self.__slurm_setup()
+        elif self.mode == 'torchelastic_ddp':
+            self.__torchelastic_setup()
+
+    def __slurm_setup(self):
         self.task_idx = int(os.environ['SLURM_LOCALID'])
 
-    def torchelastic_setup(self):
+    def __torchelastic_setup(self):
         self.task_idx = int(os.environ['LOCAL_RANK'])
 
     def train(self, model):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 9b758604ea4f6..7b5f7463d63bc 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1040,19 +1040,19 @@ def fit(
         if self.use_ddp2:
             self.accelerator_backend = DDP2Backend(self)
             self.accelerator_backend.setup()
-            self.accelerator_backend.train(model)
+            results = self.accelerator_backend.train(model)
             self.accelerator_backend.teardown()
 
         elif use_slurm_ddp:
-            self.accelerator_backend = DDPBackend(self)
-            self.accelerator_backend.slurm_setup()
-            self.accelerator_backend.train(model)
+            self.accelerator_backend = DDPBackend(self, mode='slurm_ddp')
+            self.accelerator_backend.setup()
+            results = self.accelerator_backend.train(model)
             self.accelerator_backend.teardown()
 
         elif use_torchelastic_ddp:
-            self.accelerator_backend = DDPBackend(self)
-            self.accelerator_backend.torchelastic_setup()
-            self.accelerator_backend.train(model)
+            self.accelerator_backend = DDPBackend(self, mode='torchelastic_ddp')
+            self.accelerator_backend.setup()
+            results = self.accelerator_backend.train(model)
             self.accelerator_backend.teardown()
 
     # regular ddp using .spawn
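For review context, here is a minimal, self-contained sketch of the call surface this diff converges on: every backend exposes the same `setup()`/`train()`/`teardown()` lifecycle, and `DDPBackend` selects its setup path from the `mode` string instead of exposing mode-specific public methods. This is not the real implementation: the class bodies are trimmed, the two private setup helpers are inlined into `setup()` for brevity, and `trainer=None`/`model=None` are placeholders for illustration only.

```python
import os


class Accelerator:
    """Trimmed stand-in for pytorch_lightning.accelerators.base_backend.Accelerator."""

    def __init__(self, trainer):
        self.trainer = trainer

    def setup(self):
        # no-op default; subclasses override as needed
        pass

    def teardown(self):
        pass


class DDPBackend(Accelerator):
    """Trimmed stand-in showing the new mode-dispatch pattern."""

    def __init__(self, trainer, mode: str = 'ddp'):
        super().__init__(trainer)
        self.task_idx = None
        self.mode = mode

    def setup(self):
        # plain 'ddp' needs no task index; the cluster variants read
        # theirs from the environment their launcher populates
        if self.mode == 'slurm_ddp':
            self.task_idx = int(os.environ['SLURM_LOCALID'])
        elif self.mode == 'torchelastic_ddp':
            self.task_idx = int(os.environ['LOCAL_RANK'])

    def train(self, model):
        ...  # stand-in for the real training entry point


# The caller no longer needs to know which setup variant applies:
os.environ.setdefault('SLURM_LOCALID', '0')  # only so the sketch runs outside a SLURM job
backend = DDPBackend(trainer=None, mode='slurm_ddp')  # trainer=None is a placeholder
backend.setup()
results = backend.train(model=None)
backend.teardown()
```

The trainer-side change follows the same logic: all three branches in `fit()` now run the identical `setup()` → `train()` → `teardown()` sequence, and capturing `results` from `train()` lines these branches up with the backends whose results `fit()` already propagates.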