Add warnings to on_before/after_batch_transfer hooks (#6059)
* Add warnings to hooks

* Add default idx to prevent signature change in the future

* Nothing to see here

* Add default val to transfer_batch_to_device hook

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <[email protected]>

* Revert "Add default val to transfer_batch_to_device hook"

This reverts commit 5c6a68f

Co-authored-by: Jirka Borovec <[email protected]>
SeanNaren and Borda authored Feb 18, 2021
1 parent d3a31bc commit 2cf39dc
Showing 5 changed files with 26 additions and 13 deletions.
11 changes: 9 additions & 2 deletions docs/source/extensions/datamodules.rst
@@ -314,11 +314,14 @@ Override to alter or apply augmentations to your batch before it is transferred
 .. testcode::

     class MNISTDataModule(LightningDataModule):
-        def on_before_batch_transfer(self, batch):
+        def on_before_batch_transfer(self, batch, dataloader_idx):
             batch['x'] = transforms(batch['x'])
             return batch


+.. warning::
+    Currently ``dataloader_idx`` is always 0 and will be updated to support the true index in the future.
+
 .. note:: This hook only runs on single GPU training and DDP (no data-parallel).


@@ -329,11 +332,15 @@ Override to alter or apply augmentations to your batch after it is transferred to the device
 .. testcode::

     class MNISTDataModule(LightningDataModule):
-        def on_after_batch_transfer(self, batch):
+        def on_after_batch_transfer(self, batch, dataloader_idx):
             batch['x'] = gpu_transforms(batch['x'])
             return batch


+.. warning::
+
+    Currently ``dataloader_idx`` is always 0 and will be updated to support the true index in the future.
+
 .. note::
     This hook only runs on single GPU training and DDP (no data-parallel). This hook
     will also be called when using CPU device, so adding augmentations here or in
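Taken together, the two documented hooks look like this under the new signature. A minimal sketch, assuming a dict-style batch; `cpu_augment` and `gpu_augment` are hypothetical stand-ins, not part of this diff:

```python
import torch
from pytorch_lightning import LightningDataModule


def cpu_augment(x):
    # Hypothetical CPU-side augmentation: add a little noise.
    return x + 0.1 * torch.randn_like(x)


def gpu_augment(x):
    # Hypothetical device-side transform: clamp to the unit interval.
    return x.clamp(0.0, 1.0)


class AugmentedDataModule(LightningDataModule):
    def on_before_batch_transfer(self, batch, dataloader_idx):
        # Runs before the batch is moved to the target device.
        # Note: dataloader_idx is currently always 0, per the warning above.
        batch['x'] = cpu_augment(batch['x'])
        return batch

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # Runs after the batch has been moved to the device (e.g. the GPU).
        batch['x'] = gpu_augment(batch['x'])
        return batch
```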
14 changes: 10 additions & 4 deletions pytorch_lightning/core/hooks.py
@@ -616,22 +616,25 @@ def transfer_batch_to_device(self, batch, device):
         device = device or self.device
         return move_data_to_device(batch, device)

-    def on_before_batch_transfer(self, batch):
+    def on_before_batch_transfer(self, batch, dataloader_idx):
         """
         Override to alter or apply batch augmentations to your batch before it is transferred to the device.

+        .. warning:: ``dataloader_idx`` is always 0 and will be updated to support the true index in the future.
+
         Note:
             This hook only runs on single GPU training and DDP (no data-parallel).

         Args:
             batch: A batch of data that needs to be altered or augmented.
+            dataloader_idx: DataLoader idx for batch

         Returns:
             A batch of data

         Example::

-            def on_before_batch_transfer(self, batch):
+            def on_before_batch_transfer(self, batch, dataloader_idx):
                 batch['x'] = transforms(batch['x'])
                 return batch
@@ -641,22 +644,25 @@ def on_before_batch_transfer(self, batch):
"""
return batch

def on_after_batch_transfer(self, batch):
def on_after_batch_transfer(self, batch, dataloader_idx):
"""
Override to alter or apply batch augmentations to your batch after it is transferred to the device.
.. warning:: ``dataloader_idx`` always returns 0, and will be updated to support the true ``idx`` in the future.
Note:
This hook only runs on single GPU training and DDP (no data-parallel).
Args:
batch: A batch of data that needs to be altered or augmented.
dataloader_idx: DataLoader idx for batch (Default: 0)
Returns:
A batch of data
Example::
def on_after_batch_transfer(self, batch):
def on_after_batch_transfer(self, batch, dataloader_idx):
batch['x'] = gpu_transforms(batch['x'])
return batch
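For context, `transfer_batch_to_device` (whose default body opens this hunk) sits between the two hooks and can itself be overridden for custom batch types. A sketch under the assumptions that `CustomBatch` is a user-defined wrapper and that `move_data_to_device` is importable from `pytorch_lightning.utilities`, as the default body above suggests:

```python
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import move_data_to_device


class CustomBatch:
    # Hypothetical batch wrapper that Lightning cannot move automatically.
    def __init__(self, samples, targets):
        self.samples = samples
        self.targets = targets

    def to(self, device):
        self.samples = self.samples.to(device)
        self.targets = self.targets.to(device)
        return self


class MyModel(LightningModule):
    def transfer_batch_to_device(self, batch, device):
        if isinstance(batch, CustomBatch):
            return batch.to(device)
        # Fall back to Lightning's default container handling.
        return move_data_to_device(batch, device)
```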
6 changes: 3 additions & 3 deletions pytorch_lightning/core/lightning.py
@@ -178,10 +178,10 @@ def logger(self):
""" Reference to the logger object in the Trainer. """
return self.trainer.logger if self.trainer else None

def _apply_batch_transfer_handler(self, batch: Any, device: Optional[torch.device] = None):
batch = self.on_before_batch_transfer(batch)
def _apply_batch_transfer_handler(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0):
batch = self.on_before_batch_transfer(batch, dataloader_idx)
batch = self.transfer_batch_to_device(batch, device)
batch = self.on_after_batch_transfer(batch)
batch = self.on_after_batch_transfer(batch, dataloader_idx)
return batch

def print(self, *args, **kwargs) -> None:
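The handler above fixes the ordering: before-hook, device transfer, after-hook, with the same `dataloader_idx` passed to both hooks. The same pipeline in isolation, with hypothetical stand-in functions (not Lightning API):

```python
import torch


def on_before_batch_transfer(batch, dataloader_idx):
    return batch * 2  # stand-in CPU-side hook


def on_after_batch_transfer(batch, dataloader_idx):
    return batch + 1  # stand-in device-side hook


def apply_batch_transfer(batch, device, dataloader_idx=0):
    # Mirrors _apply_batch_transfer_handler: before-hook, transfer, after-hook,
    # threading the same dataloader_idx through both hooks.
    batch = on_before_batch_transfer(batch, dataloader_idx)
    batch = batch.to(device)  # stand-in for transfer_batch_to_device
    batch = on_after_batch_transfer(batch, dataloader_idx)
    return batch


print(apply_batch_transfer(torch.ones(2), torch.device("cpu")))  # tensor([3., 3.])
```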
4 changes: 2 additions & 2 deletions tests/core/test_datamodules.py
@@ -438,13 +438,13 @@ class CurrentTestDM(LightningDataModule):
         on_before_batch_transfer_hook_rank = None
         on_after_batch_transfer_hook_rank = None

-        def on_before_batch_transfer(self, batch):
+        def on_before_batch_transfer(self, batch, dataloader_idx):
             self.on_before_batch_transfer_hook_rank = self.rank
             self.rank += 1
             batch.samples += 1
             return batch

-        def on_after_batch_transfer(self, batch):
+        def on_after_batch_transfer(self, batch, dataloader_idx):
             assert batch.samples.device == batch.targets.device == expected_device
             self.on_after_batch_transfer_hook_rank = self.rank
             self.rank += 1
4 changes: 2 additions & 2 deletions tests/models/test_hooks.py
@@ -160,13 +160,13 @@ class CurrentTestModel(BoringModel):
         on_before_batch_transfer_hook_rank = None
         on_after_batch_transfer_hook_rank = None

-        def on_before_batch_transfer(self, batch):
+        def on_before_batch_transfer(self, batch, dataloader_idx):
             self.on_before_batch_transfer_hook_rank = self.rank
             self.rank += 1
             batch.samples += 1
             return batch

-        def on_after_batch_transfer(self, batch):
+        def on_after_batch_transfer(self, batch, dataloader_idx):
             assert batch.samples.device == batch.targets.device == expected_device
             self.on_after_batch_transfer_hook_rank = self.rank
             self.rank += 1
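Both test classes rely on the same rank-counter pattern to assert that the before-hook fires before the after-hook. Stripped of the Trainer, the pattern reduces to this hypothetical standalone sketch:

```python
import torch


class HookOrderTracker:
    rank = 0
    on_before_batch_transfer_hook_rank = None
    on_after_batch_transfer_hook_rank = None

    def on_before_batch_transfer(self, batch, dataloader_idx):
        # Record the counter value when the CPU-side hook fires.
        self.on_before_batch_transfer_hook_rank = self.rank
        self.rank += 1
        return batch

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # Record the counter value when the device-side hook fires.
        self.on_after_batch_transfer_hook_rank = self.rank
        self.rank += 1
        return batch


tracker = HookOrderTracker()
batch = tracker.on_before_batch_transfer(torch.zeros(1), 0)
batch = tracker.on_after_batch_transfer(batch, 0)
assert tracker.on_before_batch_transfer_hook_rank == 0
assert tracker.on_after_batch_transfer_hook_rank == 1
```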
