Skip to content

Commit

Permalink
ddp backend refactor (#3204)
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon authored Aug 26, 2020
1 parent f3384d0 commit ff3c2f4
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 10 deletions.
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ class Accelerator(object):
def __init__(self, trainer):
    # Keep a back-reference to the owning Trainer so the accelerator's
    # setup/train/teardown hooks can reach trainer state.
    self.trainer = trainer

def setup(self):
pass

def teardown(self):
pass

Expand Down
15 changes: 12 additions & 3 deletions pytorch_lightning/accelerators/ddp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,24 @@

class DDPBackend(Accelerator):

def __init__(self, trainer):
def __init__(self, trainer, mode: str = 'ddp'):
    """Create a DDP accelerator backend.

    Args:
        trainer: owning Trainer instance, stored by the base class.
        mode: which DDP flavour this backend runs under
            ('ddp', 'slurm_ddp' or 'torchelastic_ddp'); selects the
            environment-specific work done in ``setup()``.
    """
    super().__init__(trainer)
    self.mode = mode
    # Local rank within the node; filled in by the mode-specific setup.
    self.task_idx = None
    self._has_spawned_children = False

def slurm_setup(self):
def setup(self):
    """Run the environment-specific preparation for the selected mode.

    Plain 'ddp' (and any unrecognized mode) needs no extra work, matching
    the original if/elif chain where only the SLURM and torchelastic
    modes had setup steps.
    """
    mode_hooks = {
        'slurm_ddp': self.__slurm_setup,
        'torchelastic_ddp': self.__torchelastic_setup,
    }
    hook = mode_hooks.get(self.mode)
    if hook is not None:
        hook()

def __slurm_setup(self):
    """Read the node-local rank assigned by the SLURM scheduler."""
    # SLURM_LOCALID is set by SLURM for each task on a node; a missing
    # variable raises KeyError, same as the original direct lookup.
    local_id = os.environ['SLURM_LOCALID']
    self.task_idx = int(local_id)

def torchelastic_setup(self):
def __torchelastic_setup(self):
    """Read the node-local rank exported by the torchelastic launcher."""
    # LOCAL_RANK is set per-worker by torchelastic; a missing variable
    # raises KeyError, same as the original direct lookup.
    local_rank = os.environ['LOCAL_RANK']
    self.task_idx = int(local_rank)

def train(self, model):
Expand Down
14 changes: 7 additions & 7 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,19 +1040,19 @@ def fit(
if self.use_ddp2:
self.accelerator_backend = DDP2Backend(self)
self.accelerator_backend.setup()
self.accelerator_backend.train(model)
results = self.accelerator_backend.train(model)
self.accelerator_backend.teardown()

elif use_slurm_ddp:
self.accelerator_backend = DDPBackend(self)
self.accelerator_backend.slurm_setup()
self.accelerator_backend.train(model)
self.accelerator_backend = DDPBackend(self, mode='slurm_ddp')
self.accelerator_backend.setup()
results = self.accelerator_backend.train(model)
self.accelerator_backend.teardown()

elif use_torchelastic_ddp:
self.accelerator_backend = DDPBackend(self)
self.accelerator_backend.torchelastic_setup()
self.accelerator_backend.train(model)
self.accelerator_backend = DDPBackend(self, mode='torchelastic_ddp')
self.accelerator_backend.setup()
results = self.accelerator_backend.train(model)
self.accelerator_backend.teardown()

# regular ddp using .spawn
Expand Down

0 comments on commit ff3c2f4

Please sign in to comment.