From 2cf39dc442df4fd63265178e1c6b2ed05f8b82d2 Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Thu, 18 Feb 2021 19:24:19 +0000
Subject: [PATCH] Add warnings to on_before/after_batch_transfer hooks (#6059)

* Add warnings to hooks

* Add default idx to prevent signature change in the future

* Nothing to see here

* Add default val to transfer_batch_to_device hook

* Apply suggestions from code review

Co-authored-by: Jirka Borovec

* Revert "Add default val to transfer_batch_to_device hook"

This reverts commit 5c6a68f2

Co-authored-by: Jirka Borovec
---
 docs/source/extensions/datamodules.rst | 11 +++++++++--
 pytorch_lightning/core/hooks.py        | 14 ++++++++++----
 pytorch_lightning/core/lightning.py    |  6 +++---
 tests/core/test_datamodules.py         |  4 ++--
 tests/models/test_hooks.py             |  4 ++--
 5 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst
index 8a6a85eb4bb70..a6c083dc61fcf 100644
--- a/docs/source/extensions/datamodules.rst
+++ b/docs/source/extensions/datamodules.rst
@@ -314,11 +314,14 @@ Override to alter or apply augmentations to your batch before it is transferred

 .. testcode::

     class MNISTDataModule(LightningDataModule):
-        def on_before_batch_transfer(self, batch):
+        def on_before_batch_transfer(self, batch, dataloader_idx):
             batch['x'] = transforms(batch['x'])
             return batch

+.. warning::
+    Currently ``dataloader_idx`` is always 0 and will be updated to support the true index in the future.
+
 .. note::
     This hook only runs on single GPU training and DDP (no data-parallel).

@@ -329,11 +332,15 @@ Override to alter or apply augmentations to your batch after it is transferred t

 .. testcode::

     class MNISTDataModule(LightningDataModule):
-        def on_after_batch_transfer(self, batch):
+        def on_after_batch_transfer(self, batch, dataloader_idx):
             batch['x'] = gpu_transforms(batch['x'])
             return batch

+.. warning::
+
+    Currently ``dataloader_idx`` is always 0 and will be updated to support the true index in the future.
+
 .. note::
     This hook only runs on single GPU training and DDP (no data-parallel). This hook
     will also be called when using CPU device, so adding augmentations here or in
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 05fc9e9ec3cee..e0b33c1219e8b 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -616,22 +616,25 @@ def transfer_batch_to_device(self, batch, device):
         device = device or self.device
         return move_data_to_device(batch, device)

-    def on_before_batch_transfer(self, batch):
+    def on_before_batch_transfer(self, batch, dataloader_idx):
         """
         Override to alter or apply batch augmentations to your batch before it is transferred to the device.

+        .. warning:: Currently ``dataloader_idx`` is always 0 and will be updated to support the true index in the future.
+
         Note:
             This hook only runs on single GPU training and DDP (no data-parallel).

         Args:
             batch: A batch of data that needs to be altered or augmented.
+            dataloader_idx: The index of the dataloader to which the batch belongs.

         Returns:
             A batch of data

         Example::

-            def on_before_batch_transfer(self, batch):
+            def on_before_batch_transfer(self, batch, dataloader_idx):
                 batch['x'] = transforms(batch['x'])
                 return batch

@@ -641,22 +644,25 @@ def on_before_batch_transfer(self, batch):
         """
         return batch

-    def on_after_batch_transfer(self, batch):
+    def on_after_batch_transfer(self, batch, dataloader_idx):
         """
         Override to alter or apply batch augmentations to your batch after it is transferred to the device.

+        .. warning:: Currently ``dataloader_idx`` is always 0 and will be updated to support the true index in the future.
+
         Note:
             This hook only runs on single GPU training and DDP (no data-parallel).

         Args:
             batch: A batch of data that needs to be altered or augmented.
+            dataloader_idx: The index of the dataloader to which the batch belongs.

         Returns:
             A batch of data

         Example::

-            def on_after_batch_transfer(self, batch):
+            def on_after_batch_transfer(self, batch, dataloader_idx):
                 batch['x'] = gpu_transforms(batch['x'])
                 return batch

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 473a792d3ba44..bd97b7951cfa8 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -178,10 +178,10 @@ def logger(self):
         """ Reference to the logger object in the Trainer. """
         return self.trainer.logger if self.trainer else None

-    def _apply_batch_transfer_handler(self, batch: Any, device: Optional[torch.device] = None):
-        batch = self.on_before_batch_transfer(batch)
+    def _apply_batch_transfer_handler(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0):
+        batch = self.on_before_batch_transfer(batch, dataloader_idx)
         batch = self.transfer_batch_to_device(batch, device)
-        batch = self.on_after_batch_transfer(batch)
+        batch = self.on_after_batch_transfer(batch, dataloader_idx)
         return batch

     def print(self, *args, **kwargs) -> None:
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 299d196604be0..aa50405f87cd9 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -438,13 +438,13 @@ class CurrentTestDM(LightningDataModule):
         on_before_batch_transfer_hook_rank = None
         on_after_batch_transfer_hook_rank = None

-        def on_before_batch_transfer(self, batch):
+        def on_before_batch_transfer(self, batch, dataloader_idx):
             self.on_before_batch_transfer_hook_rank = self.rank
             self.rank += 1
             batch.samples += 1
             return batch

-        def on_after_batch_transfer(self, batch):
+        def on_after_batch_transfer(self, batch, dataloader_idx):
             assert batch.samples.device == batch.targets.device == expected_device
             self.on_after_batch_transfer_hook_rank = self.rank
             self.rank += 1
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index d714062fb7915..416a858537245 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -160,13 +160,13 @@ class CurrentTestModel(BoringModel):
         on_before_batch_transfer_hook_rank = None
         on_after_batch_transfer_hook_rank = None

-        def on_before_batch_transfer(self, batch):
+        def on_before_batch_transfer(self, batch, dataloader_idx):
            self.on_before_batch_transfer_hook_rank = self.rank
            self.rank += 1
            batch.samples += 1
            return batch

-        def on_after_batch_transfer(self, batch):
+        def on_after_batch_transfer(self, batch, dataloader_idx):
            assert batch.samples.device == batch.targets.device == expected_device
            self.on_after_batch_transfer_hook_rank = self.rank
            self.rank += 1
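
Reviewer note: as a sanity check on how the three hooks compose after this patch,
below is a minimal, self-contained sketch that mirrors
``_apply_batch_transfer_handler``. ``DemoModule``, the dict-shaped batch, and the
driver at the bottom are illustrative assumptions, not part of the diff; only the
hook names, their new ``dataloader_idx`` parameter, and the call order come from
the patch itself::

    from typing import Any, Optional

    import torch


    class DemoModule:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
            # CPU-side augmentation: runs before the batch is moved to the device.
            batch["x"] = batch["x"].float() / 255.0
            return batch

        def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
            device = device or self.device
            return {k: v.to(device) for k, v in batch.items()}

        def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
            # Device-side augmentation: runs after the move, so it may use the GPU.
            batch["x"] = batch["x"] + 0.01 * torch.randn_like(batch["x"])
            return batch

        def _apply_batch_transfer_handler(
            self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0
        ) -> Any:
            # Same order as LightningModule._apply_batch_transfer_handler after this
            # patch: dataloader_idx (currently always 0) is forwarded to both hooks.
            batch = self.on_before_batch_transfer(batch, dataloader_idx)
            batch = self.transfer_batch_to_device(batch, device)
            batch = self.on_after_batch_transfer(batch, dataloader_idx)
            return batch


    if __name__ == "__main__":
        module = DemoModule()
        raw = {"x": torch.randint(0, 256, (4, 1, 28, 28)), "y": torch.zeros(4, dtype=torch.long)}
        out = module._apply_batch_transfer_handler(raw)
        print(out["x"].device, out["x"].dtype)  # module's device, torch.float32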