Deprecate dataloader_idx from on_train_batch_start/end (#9816)
* deprecate hooks

* dep todo

* explicit

* Apply suggestions from code review

* Apply suggestions from code review

* code review

* base
rohitgr7 authored Oct 7, 2021
1 parent 0561fd6 commit 4decbc0
Showing 31 changed files with 150 additions and 67 deletions.
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/accelerator.py
@@ -487,6 +487,7 @@ def on_train_end(self) -> None:
         """Called when train ends."""
         return self.training_type_plugin.on_train_end()
 
-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    # TODO: Update this in v1.7 (deprecation: #9816)
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Called in the training loop before anything happens for that batch."""
-        return self.training_type_plugin.on_train_batch_start(batch, batch_idx, dataloader_idx)
+        return self.training_type_plugin.on_train_batch_start(batch, batch_idx)
9 changes: 7 additions & 2 deletions pytorch_lightning/callbacks/base.py
@@ -97,7 +97,12 @@ def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningMod
         pass
 
     def on_train_batch_start(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        batch: Any,
+        batch_idx: int,
+        unused: Optional[int] = 0,
     ) -> None:
         """Called when the train batch begins."""
         pass
@@ -109,7 +114,7 @@ def on_train_batch_end(
         outputs: STEP_OUTPUT,
         batch: Any,
         batch_idx: int,
-        dataloader_idx: int,
+        unused: Optional[int] = 0,
     ) -> None:
         """Called when the train batch ends."""
         pass
3 changes: 1 addition & 2 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -135,7 +135,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
 
     @rank_zero_only
     def on_train_batch_start(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
     ) -> None:
         if self._log_stats.intra_step_time:
             self._snap_intra_step_time = time.time()
@@ -161,7 +161,6 @@ def on_train_batch_end(
         outputs: STEP_OUTPUT,
         batch: Any,
         batch_idx: int,
-        dataloader_idx: int,
     ) -> None:
         if self._log_stats.inter_step_time:
             self._snap_inter_step_time = time.time()
5 changes: 1 addition & 4 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -279,7 +279,6 @@ def on_train_batch_end(
         outputs: STEP_OUTPUT,
         batch: Any,
         batch_idx: int,
-        dataloader_idx: int,
     ) -> None:
         """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
         if self._should_skip_saving_checkpoint(trainer):
@@ -304,9 +303,7 @@ def on_train_batch_end(
 
         self.save_checkpoint(trainer)
 
-    def on_train_epoch_end(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", unused: Optional = None
-    ) -> None:
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Save a checkpoint at the end of the training epoch."""
         # as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates
         trainer.fit_loop.global_step -= 1
6 changes: 3 additions & 3 deletions pytorch_lightning/callbacks/progress/base.py
@@ -35,8 +35,8 @@ def __init__(self):
             def disable(self):
                 self.enable = False
 
-            def on_train_batch_end(self, trainer, pl_module, outputs):
-                super().on_train_batch_end(trainer, pl_module, outputs)  # don't forget this :)
+            def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx):
+                super().on_train_batch_end(trainer, pl_module, outputs, batch_idx)  # don't forget this :)
                 percent = (self.train_batch_idx / self.total_train_batches) * 100
                 sys.stdout.flush()
                 sys.stdout.write(f'{percent:.01f} percent complete \r')
@@ -161,7 +161,7 @@ def on_train_start(self, trainer, pl_module):
     def on_train_epoch_start(self, trainer, pl_module):
         self._train_batch_idx = trainer.fit_loop.epoch_loop.batch_progress.current.completed
 
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         self._train_batch_idx += 1
 
     def on_validation_start(self, trainer, pl_module):
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/progress/rich_progress.py
@@ -369,8 +369,8 @@ def on_predict_epoch_start(self, trainer, pl_module):
         super().on_predict_epoch_start(trainer, pl_module)
         self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)
 
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
         self._update(self.main_progress_bar_id)
 
     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/progress/tqdm_progress.py
@@ -231,8 +231,8 @@ def on_train_epoch_start(self, trainer, pl_module):
         reset(self.main_progress_bar, total=total_batches, current=self.train_batch_idx)
         self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
 
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
         total_batches = self.total_train_batches + self.total_val_batches
         total_batches = convert_inf(total_batches)
         if self._should_update(self.train_batch_idx, total_batches):
8 changes: 4 additions & 4 deletions pytorch_lightning/core/hooks.py
@@ -79,25 +79,25 @@ def on_pretrain_routine_end(self) -> None:
         - training_start
         """
 
-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int, unused: Optional[int] = 0) -> None:
         """Called in the training loop before anything happens for that batch.
 
         If you return -1 here, you will skip training for the rest of the current epoch.
 
         Args:
             batch: The batched data as it is returned by the training DataLoader.
             batch_idx: the index of the batch
-            dataloader_idx: the index of the dataloader
+            unused: Deprecated argument. Will be removed in v1.7.
         """
 
-    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, unused: Optional[int] = 0) -> None:
         """Called in the training loop after the batch.
 
         Args:
             outputs: The outputs of training_step_end(training_step(x))
             batch: The batched data as it is returned by the training DataLoader.
             batch_idx: the index of the batch
-            dataloader_idx: the index of the dataloader
+            unused: Deprecated argument. Will be removed in v1.7.
         """
 
     def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
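For readers tracking the deprecation, a minimal sketch of how a user-defined module migrates to the new hook signatures. The `MyModel` class below is hypothetical and not part of this commit; overrides that still accept `dataloader_idx` keep working until v1.7 but are flagged with a deprecation warning (see the `configuration_validator.py` change further down).

from typing import Any

import pytorch_lightning as pl
from pytorch_lightning.utilities.types import STEP_OUTPUT


class MyModel(pl.LightningModule):
    # New-style overrides: drop the deprecated `dataloader_idx`/`unused` argument.
    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
        print(f"starting batch {batch_idx}")

    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
        print(f"finished batch {batch_idx}")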
10 changes: 9 additions & 1 deletion pytorch_lightning/loops/batch/training_batch_loop.py
@@ -24,6 +24,7 @@
 from pytorch_lightning.loops.utilities import _get_active_optimizers
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import AttributeDict
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.warnings import WarningCache
 
 _OUTPUTS_TYPE = List[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]
@@ -76,7 +77,14 @@ def run(self, batch: Any, batch_idx: int) -> AttributeDict:
             return AttributeDict(signal=-1)
 
         # hook
-        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0)
+        # TODO: Update this in v1.7 (deprecation: #9816)
+        model_fx = self.trainer.lightning_module.on_train_batch_start
+        extra_kwargs = (
+            {"dataloader_idx": 0}
+            if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
+            else {}
+        )
+        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
         if response == -1:
             return AttributeDict(signal=-1)
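The branch above only forwards `dataloader_idx` when the user's override names it explicitly, as detected by `is_param_in_hook_signature`. The snippet below is a rough illustration only, not the actual helper in `pytorch_lightning/utilities/signature_utils.py`; judging by the `explicit=True` flag used here, the real helper presumably also distinguishes an explicitly named parameter from one swallowed by `*args`, which this sketch handles only loosely via the parameter-kind check.

import inspect
from typing import Callable


def names_param_explicitly(hook_fx: Callable, param: str) -> bool:
    """Illustrative sketch: True if `param` is explicitly named in the hook's signature."""
    parameters = inspect.signature(hook_fx).parameters
    return param in parameters and parameters[param].kind not in (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )


def old_style(self, batch, batch_idx, dataloader_idx):  # still names the deprecated argument
    ...


def new_style(self, batch, batch_idx):  # migrated signature
    ...


assert names_param_explicitly(old_style, "dataloader_idx")
assert not names_param_explicitly(new_style, "dataloader_idx")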
11 changes: 10 additions & 1 deletion pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -27,6 +27,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 
 _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]
 
@@ -170,7 +171,15 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
             automatic=self.trainer.lightning_module.trainer.lightning_module.automatic_optimization,
             num_optimizers=len(self.trainer.optimizers),
         )
-        self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, self.batch_idx, 0)
+
+        # TODO: Update this in v1.7 (deprecation: #9816)
+        model_fx = self.trainer.lightning_module.on_train_batch_end
+        extra_kwargs = (
+            {"dataloader_idx": 0}
+            if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
+            else {}
+        )
+        self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
         self.trainer.call_hook("on_batch_end")
         self.trainer.logger_connector.on_batch_end()
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ipu.py
@@ -285,7 +285,7 @@ def on_test_end(self):
     def on_predict_end(self):
         self._detach_models()
 
-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
         # Updates optimizer stats if LR scheduler modified the optimizer state
         optimizer = self.lightning_module.trainer.optimizers[0]
         self.poptorch_models[RunningStage.TRAINING].setOptimizer(optimizer)
@@ -345,7 +345,7 @@ def on_predict_end(self):
         """Called when predict ends."""
         pass
 
-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
         """Called in the training loop before anything happens for that batch."""
         pass
 
17 changes: 13 additions & 4 deletions pytorch_lightning/trainer/callback_hook.py
@@ -21,6 +21,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 
@@ -161,15 +162,23 @@ def on_batch_end(self):
         for callback in self.callbacks:
             callback.on_batch_end(self, self.lightning_module)
 
-    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+    # TODO: Update this in v1.7 (deprecation: #9816)
+    def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0):
         """Called when the training batch begins."""
         for callback in self.callbacks:
-            callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx)
+            if is_param_in_hook_signature(callback.on_train_batch_start, "dataloader_idx", explicit=True):
+                callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, 0)
+            else:
+                callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx)
 
-    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx):
+    # TODO: Update this in v1.7 (deprecation: #9816)
+    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx=0):
         """Called when the training batch ends."""
         for callback in self.callbacks:
-            callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)
+            if is_param_in_hook_signature(callback.on_train_batch_end, "dataloader_idx", explicit=True):
+                callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, 0)
+            else:
+                callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx)
 
     def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
         """Called when the validation batch begins."""
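The effect of the two branches above, sketched with two hypothetical callbacks: one that still declares `dataloader_idx` (it keeps receiving the argument, always as 0, and is flagged by the configuration validator below) and one already on the new signature. Neither class is part of this commit.

from pytorch_lightning.callbacks import Callback


class OldStyleCallback(Callback):
    # Still names the deprecated argument, so the trainer takes the first branch
    # and passes dataloader_idx=0 until the argument is removed in v1.7.
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        assert dataloader_idx == 0


class NewStyleCallback(Callback):
    # Migrated signature: the trainer takes the second branch and omits the argument.
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        pass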
17 changes: 17 additions & 0 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -50,6 +50,8 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None:
         self._check_on_post_move_to_device(model)
         # TODO: Delete _check_on_keyboard_interrupt in v1.7
         self._check_on_keyboard_interrupt()
+        # TODO: Remove this in v1.7 (deprecation: #9816)
+        self._check_dl_idx_in_on_train_batch_hooks(model)
 
     def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None:
         # -----------------------------------
@@ -261,3 +263,18 @@ def _check_on_keyboard_interrupt(self) -> None:
                     "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7."
                     " Please use the `on_exception` callback hook instead."
                 )
+
+    def _check_dl_idx_in_on_train_batch_hooks(self, model: "pl.LightningModule") -> None:
+        for hook in ("on_train_batch_start", "on_train_batch_end"):
+            if is_param_in_hook_signature(getattr(model, hook), "dataloader_idx", explicit=True):
+                rank_zero_deprecation(
+                    f"Base `LightningModule.{hook}` hook signature has changed in v1.5."
+                    " The `dataloader_idx` argument will be removed in v1.7."
+                )
+
+            for cb in self.trainer.callbacks:
+                if is_param_in_hook_signature(getattr(cb, hook), "dataloader_idx", explicit=True):
+                    rank_zero_deprecation(
+                        f"Base `Callback.{hook}` hook signature has changed in v1.5."
+                        " The `dataloader_idx` argument will be removed in v1.7."
+                    )
2 changes: 1 addition & 1 deletion pytorch_lightning/tuner/lr_finder.py
@@ -344,7 +344,7 @@ def on_batch_start(self, trainer, pl_module):
 
         self.lrs.append(trainer.lr_schedulers[0]["scheduler"].lr[0])
 
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         """Called when the training batch ends, logs the calculated loss."""
         if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return
4 changes: 2 additions & 2 deletions tests/accelerators/test_tpu_backend.py
@@ -165,7 +165,7 @@ def __init__(self):
         def should_update(self):
             return self.count % 2 == 0
 
-        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+        def on_train_batch_start(self, batch, batch_idx):
             self.called["on_train_batch_start"] += 1
             self.weight_before = self.layer.weight.clone()
 
@@ -181,7 +181,7 @@ def training_step(self, batch, batch_idx):
             opt.zero_grad()
             return loss
 
-        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+        def on_train_batch_end(self, outputs, batch, batch_idx):
             self.called["on_train_batch_end"] += 1
             after_before = self.layer.weight.clone()
             if self.should_update:
4 changes: 2 additions & 2 deletions tests/callbacks/test_callback_hook_outputs.py
@@ -22,7 +22,7 @@ def test_train_step_no_return(tmpdir, single_cb: bool):
     """Tests that only training_step can be used."""
 
     class CB(Callback):
-        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
             assert "loss" in outputs
 
         def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
@@ -32,7 +32,7 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal
             assert "x" in outputs
 
     class TestModel(BoringModel):
-        def on_train_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None:
+        def on_train_batch_end(self, outputs, batch, batch_idx: int) -> None:
             assert "loss" in outputs
 
         def on_validation_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None:
8 changes: 4 additions & 4 deletions tests/callbacks/test_progress_bar.py
@@ -185,12 +185,12 @@ class CurrentProgressBar(ProgressBar):
         val_batches_seen = 0
         test_batches_seen = 0
 
-        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
-            super().on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
+        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+            super().on_train_batch_start(trainer, pl_module, batch, batch_idx)
             assert self.train_batch_idx == trainer.fit_loop.batch_idx
 
-        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-            super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+            super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
             assert self.train_batch_idx == trainer.fit_loop.batch_idx + 1
             if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0:
                 assert self.main_progress_bar.n == self.train_batch_idx
4 changes: 2 additions & 2 deletions tests/core/test_lightning_optimizer.py
@@ -331,12 +331,12 @@ class TestModel(BoringModel):
         def configure_optimizers(self):
             return OptimizerWithHooks(self)
 
-        def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+        def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
             self.count_on_train_batch_start += 1
             optimizer = self.optimizers(use_pl_optimizer=False)
             assert len(optimizer._fwd_handles) == 1
 
-        def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+        def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None:
            self.count_on_train_batch_end += 1
            del self.trainer._lightning_optimizers
            gc.collect()  # not necessary, just in case