Commit f064d74
refactored dataloader process hook (#3139)
williamFalcon authored Aug 25, 2020
1 parent 229b876 commit f064d74
Showing 4 changed files with 12 additions and 14 deletions.
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/base_backend.py
@@ -22,3 +22,6 @@ def test_step_end(self, output):
 
     def validation_step_end(self, output):
         return output
+
+    def process_dataloader(self, dataloader):
+        return dataloader
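
The new hook establishes a simple contract: the trainer hands each dataloader to the active accelerator backend, which may return it unchanged (the default above) or replace it with a wrapped version. Below is a minimal sketch of that contract in isolation; the LoggingBackend class and the list standing in for a DataLoader are illustrative, not part of this commit.

class LoggingBackend:
    """Hypothetical backend that inspects the dataloader, then returns it unchanged."""

    def process_dataloader(self, dataloader):
        print(f"processing a dataloader with {len(dataloader)} batches")
        return dataloader


# the trainer calls the hook once, right before iterating
backend = LoggingBackend()
train_loader = backend.process_dataloader([("x", "y")] * 4)  # list stands in for a DataLoader
for batch in train_loader:
    pass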
7 changes: 7 additions & 0 deletions pytorch_lightning/accelerators/tpu_backend.py
@@ -27,6 +27,7 @@
     import torch_xla
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
+    import torch_xla.distributed.parallel_loader as xla_pl
 except ImportError:
     XLA_AVAILABLE = False
 else:
@@ -139,6 +140,12 @@ def test_step(self, args):
         output = self.trainer.model.test_step(*args)
         return output
 
+    def process_dataloader(self, dataloader):
+        device = xm.xla_device(self.trainer.tpu_id)
+        dataloader = xla_pl.ParallelLoader(dataloader, [device])
+        dataloader = dataloader.per_device_loader(device)
+        return dataloader
+
     def to_device(self, batch):
         """
         Transfers the data to the TPU.
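
The TPU override above wraps the incoming dataloader in ParallelLoader and returns its per-device iterator, so batches arrive already on the XLA device. A standalone sketch of the same wrapping, assuming torch_xla is installed and a TPU runtime is available; the toy dataset is illustrative only.

import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl

dataset = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
dataloader = DataLoader(dataset, batch_size=8)

device = xm.xla_device()  # the backend passes self.trainer.tpu_id here
loader = xla_pl.ParallelLoader(dataloader, [device])
per_device_loader = loader.per_device_loader(device)

for x, y in per_device_loader:
    # batches are already transferred to the XLA device
    assert x.device.type == "xla"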
5 changes: 1 addition & 4 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -260,10 +260,7 @@ def _evaluate(
             dl_outputs = []
 
             # on TPU we have to wrap it under the ParallelLoader
-            if self.use_tpu:
-                device = xm.xla_device(self.tpu_id)
-                dataloader = xla_pl.ParallelLoader(dataloader, [device])
-                dataloader = dataloader.per_device_loader(device)
+            dataloader = self.accelerator_backend.process_dataloader(dataloader)
 
             # each dataloader has a max num batches
             dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
11 changes: 1 addition & 10 deletions pytorch_lightning/trainer/training_loop.py
@@ -427,15 +427,6 @@ def train(self):
 
         self.run_training_teardown()
 
-    def prepare_train_loop_dataloader(self, train_dataloader):
-        # on TPU we have to wrap it under the ParallelLoader
-        if self.use_tpu:
-            device = xm.xla_device(self.tpu_id)
-            train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
-            train_dataloader = train_dataloader.per_device_loader(device)
-
-        return train_dataloader
-
     def run_on_epoch_start_hook(self, model):
         # Epoch start events
         with self.profiler.profile('on_epoch_start'):
@@ -464,7 +455,7 @@ def run_training_epoch(self):
         self.run_on_epoch_start_hook(model)
 
         # modify dataloader if needed (ddp, etc...)
-        train_dataloader = self.prepare_train_loop_dataloader(self.train_dataloader)
+        train_dataloader = self.accelerator_backend.process_dataloader(self.train_dataloader)
 
         # bookkeeping
         num_optimizers = len(self._get_optimizers_iterable())
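
With prepare_train_loop_dataloader removed, neither loop carries a use_tpu branch anymore: both call the same hook and rely on the backend to decide whether wrapping is required. A sketch of that dispatch, where CPUBackend and FakeTPUBackend are illustrative stand-ins rather than real Lightning classes.

class CPUBackend:
    def process_dataloader(self, dataloader):
        return dataloader  # nothing to wrap on CPU


class FakeTPUBackend:
    def process_dataloader(self, dataloader):
        return iter(list(dataloader))  # stands in for ParallelLoader.per_device_loader


def run_epoch(accelerator_backend, train_dataloader):
    # identical loop body for every accelerator
    train_dataloader = accelerator_backend.process_dataloader(train_dataloader)
    for batch in train_dataloader:
        pass


run_epoch(CPUBackend(), [1, 2, 3])
run_epoch(FakeTPUBackend(), [1, 2, 3])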
