
Switch to on_.*_start()
hadim committed Feb 24, 2020
1 parent 693fc00 commit 0c788d8
Showing 10 changed files with 74 additions and 74 deletions.
4 changes: 2 additions & 2 deletions docs/source/callbacks.rst
@@ -9,6 +9,6 @@ Callbacks
 _save_model,
 on_epoch_end,
 on_train_end,
-on_epoch_begin,
+on_epoch_start,
 check_monitor_top_k,
-on_train_begin,
+on_train_start,
2 changes: 1 addition & 1 deletion docs/source/loggers.rst
@@ -9,4 +9,4 @@ Loggers
 _save_model,
 on_epoch_end,
 on_train_end,
-on_epoch_begin,
+on_epoch_start,
14 changes: 7 additions & 7 deletions pytorch_lightning/callbacks/base.py
@@ -11,55 +11,55 @@
 class Callback(abc.ABC):
 """Abstract base class used to build new callbacks."""

-def on_init_begin(self, trainer, pl_module):
+def on_init_start(self, trainer, pl_module):
 """Called when the trainer initialization begins."""
 assert pl_module is None

 def on_init_end(self, trainer, pl_module):
 """Called when the trainer initialization ends."""
 pass

-def on_fit_begin(self, trainer, pl_module):
+def on_fit_start(self, trainer, pl_module):
 """Called when the fit begins."""
 pass

 def on_fit_end(self, trainer, pl_module):
 """Called when the fit ends."""
 pass

-def on_epoch_begin(self, trainer, pl_module):
+def on_epoch_start(self, trainer, pl_module):
 """Called when the epoch begins."""
 pass

 def on_epoch_end(self, trainer, pl_module):
 """Called when the epoch ends."""
 pass

-def on_batch_begin(self, trainer, pl_module):
+def on_batch_start(self, trainer, pl_module):
 """Called when the training batch begins."""
 pass

 def on_batch_end(self, trainer, pl_module):
 """Called when the training batch ends."""
 pass

-def on_train_begin(self, trainer, pl_module):
+def on_train_start(self, trainer, pl_module):
 """Called when the train begins."""
 pass

 def on_train_end(self, trainer, pl_module):
 """Called when the train ends."""
 pass

-def on_validation_begin(self, trainer, pl_module):
+def on_validation_start(self, trainer, pl_module):
 """Called when the validation loop begins."""
 pass

 def on_validation_end(self, trainer, pl_module):
 """Called when the validation loop ends."""
 pass

-def on_test_begin(self, trainer, pl_module):
+def on_test_start(self, trainer, pl_module):
 """Called when the test begins."""
 pass
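After this refactor, user callbacks override the *_start hooks instead of *_begin. A minimal sketch of how a custom callback would look with the new names; TimerCallback and MyModel are illustrative names, not part of the library, and passing callbacks= to the Trainer constructor is taken from the trainer.py hunk further below:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

class TimerCallback(Callback):
    """Toy callback that announces when each phase starts."""

    def on_train_start(self, trainer, pl_module):
        # previously on_train_begin
        print("training started")

    def on_epoch_start(self, trainer, pl_module):
        # previously on_epoch_begin
        print(f"epoch {trainer.current_epoch} started")

    def on_train_end(self, trainer, pl_module):
        print("training finished")

# Callbacks are handed to the Trainer constructor, which fires on_init_start()
# right after storing them (see the trainer.py hunk below).
trainer = Trainer(callbacks=[TimerCallback()])
# trainer.fit(MyModel())  # the hooks above fire during fit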
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -64,7 +64,7 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
 self.monitor_op = mode_dict[mode]
 self.min_delta *= 1 if self.monitor_op == np.greater else -1

-self.on_train_begin(None, None)
+self.on_train_start(None, None)

 def check_metrics(self, logs):
 monitor_val = logs.get(self.monitor)
@@ -82,7 +82,7 @@ def check_metrics(self, logs):

 return True

-def on_train_begin(self, trainer, pl_module):
+def on_train_start(self, trainer, pl_module):
 # Allow instances to be re-used
 self.wait = 0
 self.stopped_epoch = 0
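Note that EarlyStopping's __init__ ends by calling self.on_train_start(None, None): the start hook doubles as the state initializer, so wait and stopped_epoch are set at construction time and reset again whenever a new training run begins, which is what lets one instance be reused across fits. A stripped-down sketch of that pattern outside Lightning (names are illustrative):

class ReusableMonitor:
    def __init__(self):
        # reuse the start hook to establish the initial state
        self.on_train_start(None, None)

    def on_train_start(self, trainer, pl_module):
        # reset counters so the same instance can monitor several runs
        self.wait = 0
        self.stopped_epoch = 0

monitor = ReusableMonitor()
monitor.wait += 3                    # state accumulated during a first run
monitor.on_train_start(None, None)   # called again when the next run starts
assert monitor.wait == 0             # fresh state for the new run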
@@ -44,7 +44,7 @@ def __init__(self, scheduling: dict):
 self.scheduling = scheduling
 self.epochs = sorted(scheduling.keys())

-def on_epoch_begin(self, trainer, pl_module):
+def on_epoch_start(self, trainer, pl_module):
 # indexing epochs from 1 (until v0.6.x)
 # In v0.8.0, ` + 1` should be removed.
 epoch = trainer.current_epoch + 1
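This hunk appears to belong to the gradient-accumulation scheduler callback (its file header did not survive extraction). The callback keeps a dict mapping a starting epoch to an accumulation factor and picks the factor for the current epoch in on_epoch_start; the selection logic itself is collapsed in the diff. The following is only a plausible illustration of how such a scheduling dict could be consumed, not the library's actual code:

def choose_accumulation_factor(scheduling, current_epoch):
    # scheduling maps a 1-based starting epoch to a number of batches to
    # accumulate, e.g. {1: 1, 5: 4}: accumulate 1 batch from epoch 1, 4 from epoch 5
    epoch = current_epoch + 1  # epochs indexed from 1 until v0.6.x; the "+ 1" goes away in v0.8.0
    factor = 1
    for start_epoch in sorted(scheduling.keys()):
        if epoch >= start_epoch:
            factor = scheduling[start_epoch]
    return factor

print(choose_accumulation_factor({1: 1, 5: 4}, current_epoch=6))  # -> 4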
28 changes: 14 additions & 14 deletions pytorch_lightning/trainer/callback_hook.py
@@ -12,70 +12,70 @@ def __init__(self):
 self.callbacks: list[Callback] = []
 self.get_model: Callable = None

-def on_init_begin(self):
+def on_init_start(self):
 """Called when the trainer initialization begins."""
 for callback in self.callbacks:
-callback.on_init_begin(self, None)
+callback.on_init_start(self, None)

 def on_init_end(self):
 """Called when the trainer initialization ends."""
 for callback in self.callbacks:
 callback.on_init_end(self, self.get_model())

-def on_fit_begin(self):
+def on_fit_start(self):
 """Called when the fit begins."""
 for callback in self.callbacks:
-callback.on_fit_begin(self, self.get_model())
+callback.on_fit_start(self, self.get_model())

 def on_fit_end(self):
 """Called when the fit ends."""
 for callback in self.callbacks:
 callback.on_fit_end(self, self.get_model())

-def on_epoch_begin(self):
+def on_epoch_start(self):
 """Called when the epoch begins."""
 for callback in self.callbacks:
-callback.on_epoch_begin(self, self.get_model())
+callback.on_epoch_start(self, self.get_model())

 def on_epoch_end(self):
 """Called when the epoch ends."""
 for callback in self.callbacks:
 callback.on_epoch_end(self, self.get_model())

-def on_train_begin(self):
+def on_train_start(self):
 """Called when the train begins."""
 for callback in self.callbacks:
-callback.on_train_begin(self, self.get_model())
+callback.on_train_start(self, self.get_model())

 def on_train_end(self):
 """Called when the train ends."""
 for callback in self.callbacks:
 callback.on_train_end(self, self.get_model())

-def on_batch_begin(self):
+def on_batch_start(self):
 """Called when the training batch begins."""
 for callback in self.callbacks:
-callback.on_batch_begin(self, self.get_model())
+callback.on_batch_start(self, self.get_model())

 def on_batch_end(self):
 """Called when the training batch ends."""
 for callback in self.callbacks:
 callback.on_batch_end(self, self.get_model())

-def on_validation_begin(self):
+def on_validation_start(self):
 """Called when the validation loop begins."""
 for callback in self.callbacks:
-callback.on_validation_begin(self, self.get_model())
+callback.on_validation_start(self, self.get_model())

 def on_validation_end(self):
 """Called when the validation loop ends."""
 for callback in self.callbacks:
 callback.on_validation_end(self, self.get_model())

-def on_test_begin(self):
+def on_test_start(self):
 """Called when the test begins."""
 for callback in self.callbacks:
-callback.on_test_begin(self, self.get_model())
+callback.on_test_start(self, self.get_model())

 def on_test_end(self):
 """Called when the test ends."""
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -172,9 +172,9 @@ def __init__(self):
 self.use_tpu = None

 # Callback system
-self.on_validation_begin: Callable = None
+self.on_validation_start: Callable = None
 self.on_validation_end: Callable = None
-self.on_test_begin: Callable = None
+self.on_test_start: Callable = None
 self.on_test_end: Callable = None

 @abstractmethod
@@ -303,9 +303,9 @@ def run_evaluation(self, test=False):

 # Validation/Test begin callbacks
 if test:
-self.on_test_begin()
+self.on_test_start()
 else:
-self.on_validation_begin()
+self.on_validation_start()

 # hook
 model = self.get_model()
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -176,7 +176,7 @@ def __init__(
 Example::
 from pytorch_lightning.callbacks import Callback
 class PrintCallback(Callback):
-def on_train_begin(self):
+def on_train_start(self):
 print("Training is started!")
 def on_train_end(self):
 print(f"Training is done. The logs are: {self.trainer.logs}")
@@ -602,7 +602,7 @@ def on_train_end(self):

 # Init callbacks
 self.callbacks = callbacks
-self.on_init_begin()
+self.on_init_start()

 # Transfer params
 # Backward compatibility
@@ -925,7 +925,7 @@ def fit(
 _set_dataloader(model, test_dataloader, 'test_dataloader')

 # Fit begin callbacks
-self.on_fit_begin()
+self.on_fit_start()

 # when using multi-node or DDP within a node start each module in a separate process
 if self.use_ddp2:
14 changes: 7 additions & 7 deletions pytorch_lightning/trainer/training_loop.py
@@ -238,11 +238,11 @@ def __init__(self):
 # Callback system
 self.callbacks: list[Callback] = []
 self.max_steps = None
-self.on_train_begin: Callable = None
+self.on_train_start: Callable = None
 self.on_train_end: Callable = None
-self.on_batch_begin: Callable = None
+self.on_batch_start: Callable = None
 self.on_batch_end: Callable = None
-self.on_epoch_begin: Callable = None
+self.on_epoch_start: Callable = None
 self.on_epoch_end: Callable = None

 @property
@@ -323,7 +323,7 @@ def train(self):
 ' but will start from "0" in v0.8.0.', DeprecationWarning)

 # Train begin callbacks
-self.on_train_begin()
+self.on_train_start()

 # get model
 model = self.get_model()
@@ -372,7 +372,7 @@ def train(self):
 self.main_progress_bar.set_description(desc)

 # changing gradient according accumulation_scheduler
-self.accumulation_scheduler.on_epoch_begin(self, self.get_model())
+self.accumulation_scheduler.on_epoch_start(self, self.get_model())

 # -----------------
 # RUN TNG EPOCH
@@ -423,7 +423,7 @@ def train(self):
 def run_training_epoch(self):

 # Epoch begin callbacks
-self.on_epoch_begin()
+self.on_epoch_start()

 # before epoch hook
 if self.is_function_implemented('on_epoch_start'):
@@ -527,7 +527,7 @@ def run_training_batch(self, batch, batch_idx):
 return 0, grad_norm_dic, {}

 # Batch begin callbacks
-self.on_batch_begin()
+self.on_batch_start()

 # hook
 if self.is_function_implemented('on_batch_start'):
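Taken together, the hunks above indicate where each renamed hook fires: on_train_start at the top of train(), on_epoch_start at the top of run_training_epoch(), on_batch_start in run_training_batch(), and on_validation_start/on_test_start in run_evaluation(). A toy driver (plain Python, deliberately not Lightning's Trainer) that replays one plausible, simplified ordering; the real Trainer interleaves validation according to its own schedule:

class PrintingCallback:
    # same method names as the renamed hooks; the bodies just trace the order
    def on_train_start(self, trainer, pl_module): print("train start")
    def on_epoch_start(self, trainer, pl_module): print("  epoch start")
    def on_batch_start(self, trainer, pl_module): print("    batch start")
    def on_batch_end(self, trainer, pl_module): print("    batch end")
    def on_validation_start(self, trainer, pl_module): print("  validation start")
    def on_validation_end(self, trainer, pl_module): print("  validation end")
    def on_epoch_end(self, trainer, pl_module): print("  epoch end")
    def on_train_end(self, trainer, pl_module): print("train end")

cb = PrintingCallback()
cb.on_train_start(None, None)
for epoch in range(2):
    cb.on_epoch_start(None, None)
    for batch in range(3):
        cb.on_batch_start(None, None)
        cb.on_batch_end(None, None)
    cb.on_validation_start(None, None)
    cb.on_validation_end(None, None)
    cb.on_epoch_end(None, None)
cb.on_train_end(None, None)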

