diff --git a/documentation/source/PhaseCallbacks.md b/documentation/source/PhaseCallbacks.md
index 58c16b8cf3..cd8512ef2c 100644
--- a/documentation/source/PhaseCallbacks.md
+++ b/documentation/source/PhaseCallbacks.md
@@ -1,7 +1,8 @@
 # Phase Callbacks
 
 Integrating your own code into an already existing training pipeline can draw much effort on the user's end.
-To tackle this challenge, a list of callables triggered at specific points of the training code can be passed through `phase_calbacks_list` inside `training_params` when calling `Trainer.train(...)`.
+To tackle this challenge, a list of callables triggered at specific points of the training code can
+be passed through `training_params.phase_callbacks` when calling `Trainer.train(...)`.
 
 SG's `super_gradients.training.utils.callbacks` module implements some common use cases as callbacks:
 
@@ -20,12 +21,12 @@ SG's `super_gradients.training.utils.callbacks` module implements some common us
     TrainingStageSwitchCallbackBase
     YoloXTrainingStageSwitchCallback
 
-For example, the YoloX's COCO detection training recipe uses `YoloXTrainingStageSwitchCallback` to turn off augmentations and incorporate L1 loss starting from epoch 285:
+For example, the YoloX COCO detection training recipe uses `YoloXTrainingStageSwitchCallback` to turn
+off augmentations and incorporate L1 loss starting from epoch 285:
 
 `super_gradients/recipes/training_hyperparams/coco2017_yolox_train_params.yaml`:
 ```yaml
-
 max_epochs: 300
 
 ...
 
@@ -39,118 +40,155 @@ phase_callbacks:
 
 ...
 ```
 
-Another example is how we use `BinarySegmentationVisualizationCallback` to visualize predictions during training in our [Segmentation Transfer Learning Notebook](https://bit.ly/3qKwMbe):
-
-
-## Integrating Your Code with Callbacks
+Another example is how we use `BinarySegmentationVisualizationCallback` to visualize predictions
+during training in the [Segmentation Transfer Learning Notebook](https://bit.ly/3qKwMbe):
 
-Integrating your code requires implementing a callback that `Trainer` would trigger in the proper phases inside SG's training pipeline.
-So let's first get familiar with `super_gradients.training.utils.callbacks.base_callbacks.Callback` class.
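+
+Any of these built-in callbacks can also be attached in plain python rather than through a recipe.
+A minimal sketch (the `next_stage_start_epoch` argument below is an assumption based on the epoch
+mentioned above; check the recipe for the exact constructor arguments):
+
+```python
+from super_gradients.training.utils.callbacks import YoloXTrainingStageSwitchCallback
+
+training_params = {
+    # ... the rest of the training hyper-parameters
+    "phase_callbacks": [YoloXTrainingStageSwitchCallback(next_stage_start_epoch=285)],
+}
+```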
+### How Callbacks work
 
-It implements the following methods:
+`Callback` implements the following methods:
 
 ```python
-    on_training_start(self, context: PhaseContext) -> None
-    on_train_loader_start(self, context: PhaseContext) -> None:
-    on_train_batch_start(self, context: PhaseContext) -> None:
-    on_train_batch_loss_end(self, context: PhaseContext) -> None:
-    on_train_batch_backward_end(self, context: PhaseContext) -> None:
-    on_train_batch_gradient_step_start(self, context: PhaseContext) -> None:
-    on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
-    on_train_batch_end(self, context: PhaseContext) -> None:
-    on_train_loader_end(self, context: PhaseContext) -> None:
-    on_validation_loader_start(self, context: PhaseContext) -> None:
-    on_validation_batch_start(self, context: PhaseContext) -> None:
-    on_validation_batch_end(self, context: PhaseContext) -> None:
-    on_validation_loader_end(self, context: PhaseContext) -> None:
-    on_validation_end_best_epoch(self, context: PhaseContext) -> None:
-    on_test_loader_start(self, context: PhaseContext) -> None:
-    on_test_batch_start(self, context: PhaseContext) -> None:
-    on_test_batch_end(self, context: PhaseContext) -> None:
-    on_test_loader_end(self, context: PhaseContext) -> None:
-    on_training_end(self, context: PhaseContext) -> None:
+# super_gradients.training.utils.callbacks.base_callbacks.Callback
+class Callback:
+    def on_training_start(self, context: PhaseContext) -> None: pass
+    def on_train_loader_start(self, context: PhaseContext) -> None: pass
+    def on_train_batch_start(self, context: PhaseContext) -> None: pass
+    def on_train_batch_loss_end(self, context: PhaseContext) -> None: pass
+    def on_train_batch_backward_end(self, context: PhaseContext) -> None: pass
+    def on_train_batch_gradient_step_start(self, context: PhaseContext) -> None: pass
+    def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None: pass
+    def on_train_batch_end(self, context: PhaseContext) -> None: pass
+    def on_train_loader_end(self, context: PhaseContext) -> None: pass
+    def on_validation_loader_start(self, context: PhaseContext) -> None: pass
+    def on_validation_batch_start(self, context: PhaseContext) -> None: pass
+    def on_validation_batch_end(self, context: PhaseContext) -> None: pass
+    def on_validation_loader_end(self, context: PhaseContext) -> None: pass
+    def on_validation_end_best_epoch(self, context: PhaseContext) -> None: pass
+    def on_test_loader_start(self, context: PhaseContext) -> None: pass
+    def on_test_batch_start(self, context: PhaseContext) -> None: pass
+    def on_test_batch_end(self, context: PhaseContext) -> None: pass
+    def on_test_loader_end(self, context: PhaseContext) -> None: pass
+    def on_training_end(self, context: PhaseContext) -> None: pass
+```
+
+The order of the events is as follows:
+```python
+on_training_start(context)                          # called once before training starts, good for setting up the warmup LR
+
+    for epoch in range(epochs):
+        on_train_loader_start(context)
+        for batch in train_loader:
+            on_train_batch_start(context)
+            on_train_batch_loss_end(context)                # called after loss has been computed
+            on_train_batch_backward_end(context)            # called after .backward() was called
+            on_train_batch_gradient_step_start(context)     # called before the optimizer step is about to happen (gradient clipping, logging of gradients)
+            on_train_batch_gradient_step_end(context)       # called after the gradient step was done, good place to update LR (for step-based schedulers)
+            on_train_batch_end(context)
+        on_train_loader_end(context)
+
+        on_validation_loader_start(context)
+        for batch in validation_loader:
+            on_validation_batch_start(context)
+            on_validation_batch_end(context)
+        on_validation_loader_end(context)
+        on_validation_end_best_epoch(context)
+
+    on_test_loader_start(context)
+    for batch in test_loader:
+        on_test_batch_start(context)
+        on_test_batch_end(context)
+    on_test_loader_end(context)
+
+on_training_end(context)                            # called once after training ends.
 ```
 
-Our callback needs to inherit from the above class and override the appropriate methods according to the points at which we would like to trigger it.
+Callbacks are implemented by inheriting from this `Callback` class and overriding any of the above-mentioned
+methods with the desired behavior.
 
-To understand which methods we need to override, we need to understand better when are the above methods triggered.
+### Phase Context
 
-From the class docs, the order of the events is as follows:
-```python
-    on_training_start(context)                              # called once before training starts, good for setting up the warmup LR
-
-    for epoch in range(epochs):
-        on_train_loader_start(context)
-        for batch in train_loader:
-            on_train_batch_start(context)
-            on_train_batch_loss_end(context)               # called after loss has been computed
-            on_train_batch_backward_end(context)           # called after .backward() was called
-            on_train_batch_gradient_step_start(context)    # called before the optimizer step about to happen (gradient clipping, logging of gradients)
-            on_train_batch_gradient_step_end(context)      # called after gradient step was done, good place to update LR (for step-based schedulers)
-            on_train_batch_end(context)
-        on_train_loader_end(context)
-
-        on_validation_loader_start(context)
-        for batch in validation_loader:
-            on_validation_batch_start(context)
-            on_validation_batch_end(context)
-        on_validation_loader_end(context)
-        on_validation_end_best_epoch(context)
-
-    on_test_start(context)
-    for batch in test_loader:
-        on_test_batch_start(context)
-        on_test_batch_end(context)
-    on_test_end(context)
-
-    on_training_end(context)                    # called once after training ends.
+You may have noticed that `Callback`'s methods expect a single argument - a `PhaseContext` instance.
+`PhaseContext` exposes attributes describing a wide range of training entities at a given point of the training:
 ```
+  - epoch
+  - batch_idx
+  - optimizer
+  - metrics_dict
+  - inputs
+  - preds
+  - target
+  - metrics_compute_fn
+  - loss_avg_meter
+  - loss_log_items
+  - criterion
+  - device
+  - experiment_name
+  - ckpt_dir
+  - net
+  - lr_warmup_epochs
+  - sg_logger
+  - train_loader
+  - valid_loader
+  - test_loader
+  - training_params
+  - ddp_silent_mode
+  - checkpoint_params
+  - architecture
+  - arch_params
+  - metric_to_watch
+  - valid_metrics
+  - ema_model
+  - loss_logging_items_names
+```
+
+Each of these attributes is set to `None` by default, up until the point it is computed or defined in the training pipeline.
+- E.g. `epoch` will be `None` within `on_training_start` because, as explained above, this step happens before the first epoch begins.
 
-As you noticed, all of `Callback`'s methods expect a single argument - a `PhaseContext` instance.
-This argument gives access to some variables at the points mentioned above in the code through its attributes.
-We can discover what variables are exposed by looking at the documentation of the `Callback`'s specific methods we need to override.
+You can find which context attributes are set by looking at each method's docstring:
 
-For example:
-
 ```python
-...
+class Callback:
+
+    ...
+    def on_training_start(self, context: PhaseContext) -> None:
         """
-        Called once before the start of the first epoch
-        At this point, the context argument is guaranteed to have the following attributes:
-        - optimizer
-        - net
-        - checkpoints_dir_path
-        - criterion
-        - sg_logger
-        - train_loader
-        - valid_loader
-        - training_params
-        - checkpoint_params
-        - architecture
-        - arch_params
-        - metric_to_watch
-        - device
-        - ema_model
-        ...
-        :return:
+        Called once before start of the first epoch
+        At this point, the context argument will have the following attributes:
+        - optimizer
+        - criterion
+        - device
+        - experiment_name
+        - ckpt_dir
+        - net
+        - sg_logger
+        - train_loader
+        - valid_loader
+        - training_params
+        - checkpoint_params
+        - arch_params
+        - metric_to_watch
+        - valid_metrics
+
+        The corresponding Phase enum value for this event is Phase.PRE_TRAINING.
+        :param context:
         """
+        pass
 ```
 
-Now let's implement our callback.
-Suppose we would like to implement a simple callback that saves the first batch of images in each epoch for both training and validation
-in a new folder called "batch_images" under our local checkpoints directory.
+### Build your own Callback
+Suppose we would like to implement a simple callback that saves the first batch of images in each epoch for both
+training and validation in a new folder called "batch_images" under the local checkpoints directory.
 
-Our callback needs to be triggered in 3 places:
-1. At the start of training, create a new "batch_images" under our local checkpoints directory.
-2. Before passing a train image batch through the network.
-3. Before passing a validation image batch through the network.
+This callback needs to be triggered in 3 places:
+1. At the start of training, create a new "batch_images" folder under the local checkpoints directory.
+2. Before passing a train image batch through the network, save it in the new folder.
+3. Before passing a validation image batch through the network, save it in the new folder.
 
-Therefore, our callback will override `Callback`'s `on_training_start`, `on_train_batch_start`, and `on_validation_batch_start` methods:
+Therefore, the callback will override `Callback`'s `on_training_start`, `on_train_batch_start`, and `on_validation_batch_start` methods:
 
 ```python
 from super_gradients.training.utils.callbacks import Callback, PhaseContext
 
@@ -179,35 +217,41 @@ class SaveFirstBatchCallback(Callback):
         if context.batch_idx == 0 and not self.saved_first_validation_batch:
             save_image(context.inputs, os.path.join(self.outputs_path, f"first_validation_batch_epoch_{context.epoch}.png"))
             self.saved_first_validation_batch = True
-
-
 ```
 
+**IMPORTANT**
 
-Note the `@multi_process_safe` decorator, which allows the callback to be triggered precisely once when running distributed training.
+When training on multiple nodes (see [DDP](device.md)), the callback will be called at each step once for every
+node you are working with. This behaviour may be useful in some specific cases, but in general you will
+want each method to be triggered only once per step. You can add the `@multi_process_safe` decorator to ensure
+that only the main node triggers the callback.
 
-For coded training scripts (i.e., not [using configuration files](configuration_files.md)), we can pass an instance of the callback through `phase_callbacks`:
+In our example, we want each method to be triggered only once per step, so we add the `@multi_process_safe`
+decorator, as sketched below.
 
- ```python
-...
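+
+Applied to our example, the decorated callback would look along these lines (a minimal sketch; the exact
+import path of `multi_process_safe` varies between SG versions, so treat the one below as an assumption):
+
+```python
+from super_gradients.common.environment.env_helpers import multi_process_safe  # path may differ per SG version
+from super_gradients.training.utils.callbacks import Callback, PhaseContext
+
+
+class SaveFirstBatchCallback(Callback):
+    @multi_process_safe
+    def on_training_start(self, context: PhaseContext) -> None:
+        ...  # create the "batch_images" folder once, from the main process only
+
+    @multi_process_safe
+    def on_train_batch_start(self, context: PhaseContext) -> None:
+        ...  # save the first train batch of the epoch
+
+    @multi_process_safe
+    def on_validation_batch_start(self, context: PhaseContext) -> None:
+        ...  # save the first validation batch of the epoch
+```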
+### Using Custom Callback within Python Script
+The callback can directly be passed through `training_params.phase_callbacks`:
 
-...
+```python
 
 trainer = Trainer("my_experiment")
 train_dataloader = ...
 valid_dataloader = ...
 model = ...
 
 train_params = {
-    ...
     "loss": "cross_entropy",
-    "criterion_params": {}
-    ...
+    "criterion_params": {},
     "phase_callbacks": [SaveFirstBatchCallback()],
+    ...
 }
 
 trainer.train(training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)
 ```
 
-Otherwise, for training with configuration files, we need to register our new callback by decorating it with the `register_loss` decorator:
+### Using Custom Callback in a Recipe
+If you are working with [configuration files](configuration_files.md), one extra step is required.
+This is the same as for any custom object used in a recipe, and is described in detail in the above-mentioned
+[configuration files guide](configuration_files.md).
+
+To summarize, you need to register the new callback by decorating it with the `register_callback` decorator,
+so that SuperGradients knows how to instantiate it from the `.yaml` recipe.
 
 ```python
 from super_gradients.training.utils.callbacks import Callback, PhaseContext
 
@@ -254,7 +298,7 @@ phase_callbacks:
   - SaveFirstBatchCallback
 ```
 
-Last, in your ``my_train_from_recipe_script.py`` file, import the newly registered class (even though the class itself is unused, just to trigger the registry):
+Last, make sure to import `SaveFirstBatchCallback` in the script you use to launch training from config:
 
 ```python
 
@@ -278,3 +322,6 @@ Last, in your ``my_train_from_recipe_script.py`` file, import the newly register
 
 if __name__ == "__main__":
     run()
 ```
+
+This is required: otherwise `SaveFirstBatchCallback` would not be imported at all, and SuperGradients
+would therefore fail to recognize and instantiate it.
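+
+To recap the whole flow: register the callback, reference it by name under `phase_callbacks` in the recipe,
+and import it in the launch script. The registration step itself looks along these lines (a sketch; the
+`register_callback` import path below is an assumption, verify it for your SG version):
+
+```python
+from super_gradients.common.registry import register_callback  # import path may differ per SG version
+from super_gradients.training.utils.callbacks import Callback, PhaseContext
+
+
+@register_callback()  # with no explicit name, the class is registered under its own class name
+class SaveFirstBatchCallback(Callback):
+    ...
+```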
diff --git a/src/super_gradients/training/utils/callbacks/base_callbacks.py b/src/super_gradients/training/utils/callbacks/base_callbacks.py index bbb5c628e1..daa790af97 100644 --- a/src/super_gradients/training/utils/callbacks/base_callbacks.py +++ b/src/super_gradients/training/utils/callbacks/base_callbacks.py @@ -1,6 +1,12 @@ from enum import Enum from typing import List +from typing import Optional +import torch +from torchmetrics.collections import MetricCollection +from torch.utils.data.dataloader import DataLoader +from torch.nn.modules.loss import _Loss + __all__ = ["Phase", "PhaseCallback", "PhaseContext", "CallbackHandler", "Callback"] @@ -28,35 +34,35 @@ class PhaseContext: def __init__( self, - epoch=None, - batch_idx=None, - optimizer=None, + epoch: Optional[int] = None, + batch_idx: Optional[int] = None, + optimizer: Optional[torch.optim.Optimizer] = None, metrics_dict=None, - inputs=None, - preds=None, - target=None, - metrics_compute_fn=None, - loss_avg_meter=None, - loss_log_items=None, - criterion=None, - device=None, - experiment_name=None, - ckpt_dir=None, - net=None, - lr_warmup_epochs=None, - sg_logger=None, - train_loader=None, - valid_loader=None, - test_loader=None, - training_params=None, - ddp_silent_mode=None, - checkpoint_params=None, - architecture=None, - arch_params=None, - metric_to_watch=None, - valid_metrics=None, - ema_model=None, - loss_logging_items_names=None, + inputs: Optional[torch.Tensor] = None, + preds: Optional[torch.Tensor] = None, + target: Optional[torch.Tensor] = None, + metrics_compute_fn: Optional[MetricCollection] = None, + loss_avg_meter: Optional["AverageMeter"] = None, # noqa: ignore + loss_log_items: Optional[torch.Tensor] = None, + criterion: Optional[_Loss] = None, + device: Optional[str] = None, + experiment_name: Optional[str] = None, + ckpt_dir: Optional[str] = None, + net: Optional["SgModule"] = None, # noqa: ignore + lr_warmup_epochs: Optional[int] = None, + sg_logger: Optional["BaseSGLogger"] = None, # noqa: ignore + train_loader: Optional[DataLoader] = None, + valid_loader: Optional[DataLoader] = None, + test_loader: Optional[DataLoader] = None, + training_params: Optional["TrainingParams"] = None, # noqa: ignore + ddp_silent_mode: Optional[bool] = None, + checkpoint_params: Optional["HpmStruct"] = None, # noqa: ignore + architecture: Optional = None, + arch_params: Optional["HpmStruct"] = None, # noqa: ignore + metric_to_watch: Optional[str] = None, + valid_metrics: Optional[MetricCollection] = None, # noqa: ignore + ema_model: Optional["SgModule"] = None, # noqa: ignore + loss_logging_items_names: Optional[List[str]] = None, ): self.epoch = epoch self.batch_idx = batch_idx @@ -165,36 +171,47 @@ class Callback: def on_training_start(self, context: PhaseContext) -> None: """ Called once before start of the first epoch - At this point, the context argument is guaranteed to have the following attributes: - - optimizer - - net - - checkpoints_dir_path - - criterion - - sg_logger - - train_loader - - valid_loader - - training_params - - checkpoint_params - - architecture - - arch_params - - metric_to_watch - - device - - ema_model + At this point, the context argument will have the following attributes: + - optimizer + - criterion + - device + - experiment_name + - ckpt_dir + - net + - sg_logger + - train_loader + - valid_loader + - training_params + - checkpoint_params + - arch_params + - metric_to_watch + - valid_metrics The corresponding Phase enum value for this event is Phase.PRE_TRAINING. 
:param context: - :return: """ pass def on_train_loader_start(self, context: PhaseContext) -> None: """ Called each epoch at the start of train data loader (before getting the first batch). - At this point, the context argument is guaranteed to have the following attributes: - - epoch + At this point, the context argument will have the following attributes: + - optimizer + - criterion + - device + - experiment_name + - ckpt_dir + - net + - sg_logger + - train_loader + - valid_loader + - training_params + - checkpoint_params + - arch_params + - metric_to_watch + - valid_metrics The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_START. :param context: - :return: """ pass @@ -203,37 +220,103 @@ def on_train_batch_start(self, context: PhaseContext) -> None: Called at each batch after getting batch of data from data loader and moving it to target device. This event triggered AFTER Trainer.pre_prediction_callback call (If it was defined). - At this point the context argument is guaranteed to have the following attributes: - - batch_idx - - inputs - - targets - - **additional_batch_items + At this point, the context argument will have the following attributes: + - epoch + - batch_idx + - optimizer + - inputs + - target + - metrics_compute_fn + - loss_avg_meter + - criterion + - device + - stop_training + - experiment_name + - ckpt_dir + - net + - lr_warmup_epochs + - sg_logger + - train_loader + - valid_loader + - training_params + - ddp_silent_mode + - checkpoint_params + - arch_params + - metric_to_watch + - valid_metrics :param context: - :return: """ pass def on_train_batch_loss_end(self, context: PhaseContext) -> None: """ Called after model forward and loss computation has been done. - At this point the context argument is guaranteed to have the following attributes: - - preds - - loss_log_items + At this point, the context argument will have the following attributes: + - epoch + - batch_idx + - optimizer + - inputs + - preds + - target + - metrics_compute_fn + - loss_avg_meter + - loss_log_items + - criterion + - device + - stop_training + - experiment_name + - ckpt_dir + - net + - lr_warmup_epochs + - sg_logger + - train_loader + - valid_loader + - training_params + - ddp_silent_mode + - checkpoint_params + - arch_params + - metric_to_watch + - valid_metrics + - loss_logging_items_names The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_END. :param context: - :return: """ - pass def on_train_batch_backward_end(self, context: PhaseContext) -> None: """ Called after loss.backward() method was called for a given batch + At this point, the context argument will have the following attributes: + - epoch + - batch_idx + - optimizer + - inputs + - preds + - target + - metrics_compute_fn + - loss_avg_meter + - loss_log_items + - criterion + - device + - stop_training + - experiment_name + - ckpt_dir + - net + - lr_warmup_epochs + - sg_logger + - train_loader + - valid_loader + - training_params + - ddp_silent_mode + - checkpoint_params + - arch_params + - metric_to_watch + - valid_metrics + - loss_logging_items_names :param context: - :return: """ pass @@ -241,143 +324,550 @@ def on_train_batch_gradient_step_start(self, context: PhaseContext) -> None: """ Called before the graadient step is about to happen. Good place to clip gradients (with respect to scaler), log gradients to data ratio, etc. 
+ At this point, the context argument will have the following attributes: + - epoch + - batch_idx + - optimizer + - inputs + - preds + - target + - metrics_compute_fn + - loss_avg_meter + - loss_log_items + - criterion + - device + - stop_training + - experiment_name + - ckpt_dir + - net + - lr_warmup_epochs + - sg_logger + - train_loader + - valid_loader + - training_params + - ddp_silent_mode + - checkpoint_params + - arch_params + - metric_to_watch + - valid_metrics + - loss_logging_items_names + :param context: - :return: """ pass def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None: """ Called after gradient step has been performed. Good place to update LR (for step-based schedulers) + At this point, the context argument will have the following attributes: + - epoch + - batch_idx + - inputs + - target + - metrics_compute_fn + - loss_avg_meter + - criterion + - device + - stop_training + - net + - lr_warmup_epochs + - sg_logger + - train_loader + - valid_loader + - loss_logging_items_names + The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_STEP. :param context: - :return: """ pass def on_train_batch_end(self, context: PhaseContext) -> None: """ Called after all forward/backward/optimizer steps have been performed for a given batch and there is nothing left to do. + At this point, the context argument will have the following attributes: + - epoch + - batch_idx + - optimizer + - inputs + - preds + - target + - metrics_dict + - metrics_compute_fn + - loss_avg_meter + - loss_log_items + - criterion + - device + - stop_training + - experiment_name + - ckpt_dir + - net + - lr_warmup_epochs + - sg_logger + - train_loader + - valid_loader + - training_params + - ddp_silent_mode + - checkpoint_params + - arch_params + - metric_to_watch + - valid_metrics + - loss_logging_items_names :param context: - :return: """ - pass def on_train_loader_end(self, context: PhaseContext) -> None: """ Called each epoch at the end of train data loader (after processing the last batch). + At this point, the context argument will have the following attributes: + - epoch + - batch_idx + - optimizer + - inputs + - preds + - target + - metrics_dict + - metrics_compute_fn + - loss_avg_meter + - loss_log_items + - criterion + - device + - stop_training + - experiment_name + - ckpt_dir + - net + - lr_warmup_epochs + - sg_logger + - train_loader + - valid_loader + - training_params + - ddp_silent_mode + - checkpoint_params + - arch_params + - metric_to_watch + - valid_metrics + - loss_logging_items_names + The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_END. :param context: - :return: """ - pass def on_validation_loader_start(self, context: PhaseContext) -> None: """ Called each epoch at the start of validation data loader (before getting the first batch). 
+ At this point, the context argument will have the following attributes: + - epoch + - batch_idx + - optimizer + - inputs + - preds + - target + - metrics_dict + - metrics_compute_fn + - loss_avg_meter + - loss_log_items + - criterion + - device + - stop_training + - experiment_name + - ckpt_dir + - net + - lr_warmup_epochs + - sg_logger + - train_loader + - valid_loader + - training_params + - ddp_silent_mode + - checkpoint_params + - arch_params + - metric_to_watch + - valid_metrics + - loss_logging_items_names + :param context: - :return: """ - pass def on_validation_batch_start(self, context: PhaseContext) -> None: """ Called at each batch after getting batch of data from validation loader and moving it to target device. + At this point, the context argument will have the following attributes: + - epoch + - batch_idx + - inputs + - target + - metrics_compute_fn + - loss_avg_meter + - criterion + - device + - stop_training + - net + - lr_warmup_epochs + - sg_logger + - train_loader + - valid_loader + - loss_logging_items_names + :param context: - :return: """ pass def on_validation_batch_end(self, context: PhaseContext) -> None: """ Called after all forward step / loss / metric computation have been performed for a given batch and there is nothing left to do. + At this point, the context argument will have the following attributes: + - epoch + - batch_idx + - inputs + - preds + - target + - metrics_compute_fn + - loss_avg_meter + - loss_log_items + - criterion + - device + - stop_training + - net + - lr_warmup_epochs + - sg_logger + - train_loader + - valid_loader + - loss_logging_items_names + The corresponding Phase enum value for this event is Phase.VALIDATION_BATCH_END. :param context: - :return: """ pass def on_validation_loader_end(self, context: PhaseContext) -> None: """ Called each epoch at the end of validation data loader (after processing the last batch). + At this point, the context argument will have the following attributes: + - epoch + - batch_idx + - optimizer + - inputs + - preds + - target + - metrics_dict + - metrics_compute_fn + - loss_avg_meter + - loss_log_items + - criterion + - device + - stop_training + - experiment_name + - ckpt_dir + - net + - lr_warmup_epochs + - sg_logger + - train_loader + - valid_loader + - training_params + - ddp_silent_mode + - checkpoint_params + - arch_params + - metric_to_watch + - valid_metrics + - loss_logging_items_names + The corresponding Phase enum value for this event is Phase.VALIDATION_EPOCH_END. :param context: - :return: """ pass def on_validation_end_best_epoch(self, context: PhaseContext) -> None: """ Called each epoch after validation has been performed and the best metric has been achieved. + At this point, the context argument will have the following attributes: + - epoch + - batch_idx + - optimizer + - inputs + - preds + - target + - metrics_dict + - metrics_compute_fn + - loss_avg_meter + - loss_log_items + - criterion + - device + - stop_training + - experiment_name + - ckpt_dir + - net + - lr_warmup_epochs + - sg_logger + - train_loader + - valid_loader + - training_params + - ddp_silent_mode + - checkpoint_params + - arch_params + - metric_to_watch + - valid_metrics + - loss_logging_items_names + The corresponding Phase enum value for this event is Phase.VALIDATION_END_BEST_EPOCH. :param context: - :return: """ pass def on_test_loader_start(self, context: PhaseContext) -> None: """ Called once at the start of test data loader (before getting the first batch). 
+ At this point, the context argument will have the following attributes: + - epoch + - batch_idx + - optimizer + - inputs + - preds + - target + - metrics_dict + - metrics_compute_fn + - loss_avg_meter + - loss_log_items + - criterion + - device + - stop_training + - experiment_name + - ckpt_dir + - net + - lr_warmup_epochs + - sg_logger + - train_loader + - valid_loader + - training_params + - ddp_silent_mode + - checkpoint_params + - arch_params + - metric_to_watch + - valid_metrics + - loss_logging_items_names + :param context: - :return: """ - pass def on_test_batch_start(self, context: PhaseContext) -> None: """ Called at each batch after getting batch of data from test loader and moving it to target device. + At this point, the context argument will have the following attributes: + - epoch + - batch_idx + - optimizer + - inputs + - preds + - target + - metrics_dict + - metrics_compute_fn + - loss_avg_meter + - loss_log_items + - criterion + - device + - stop_training + - experiment_name + - ckpt_dir + - net + - lr_warmup_epochs + - sg_logger + - train_loader + - valid_loader + - training_params + - ddp_silent_mode + - checkpoint_params + - arch_params + - metric_to_watch + - valid_metrics + - loss_logging_items_names + :param context: - :return: """ pass def on_test_batch_end(self, context: PhaseContext) -> None: """ Called after all forward step have been performed for a given batch and there is nothing left to do. + At this point, the context argument will have the following attributes: + - epoch + - batch_idx + - optimizer + - inputs + - preds + - target + - metrics_dict + - metrics_compute_fn + - loss_avg_meter + - loss_log_items + - criterion + - device + - stop_training + - experiment_name + - ckpt_dir + - net + - lr_warmup_epochs + - sg_logger + - train_loader + - valid_loader + - training_params + - ddp_silent_mode + - checkpoint_params + - arch_params + - metric_to_watch + - valid_metrics + - loss_logging_items_names + The corresponding Phase enum value for this event is Phase.TEST_BATCH_END. :param context: - :return: """ pass def on_test_loader_end(self, context: PhaseContext) -> None: """ Called once at the end of test data loader (after processing the last batch). + At this point, the context argument will have the following attributes: + - epoch + - batch_idx + - optimizer + - inputs + - preds + - target + - metrics_dict + - metrics_compute_fn + - loss_avg_meter + - loss_log_items + - criterion + - device + - stop_training + - experiment_name + - ckpt_dir + - net + - lr_warmup_epochs + - sg_logger + - train_loader + - valid_loader + - training_params + - ddp_silent_mode + - checkpoint_params + - arch_params + - metric_to_watch + - valid_metrics + - loss_logging_items_names + The corresponding Phase enum value for this event is Phase.TEST_END. :param context: - :return: """ pass def on_average_best_models_validation_start(self, context: PhaseContext) -> None: """ Called once after the test was end before the training loop has finished. 
+        At this point, the context argument will have the following attributes:
+        - epoch
+        - batch_idx
+        - optimizer
+        - inputs
+        - preds
+        - target
+        - metrics_dict
+        - metrics_compute_fn
+        - loss_avg_meter
+        - loss_log_items
+        - criterion
+        - device
+        - stop_training
+        - experiment_name
+        - ckpt_dir
+        - net
+        - lr_warmup_epochs
+        - sg_logger
+        - train_loader
+        - valid_loader
+        - training_params
+        - ddp_silent_mode
+        - checkpoint_params
+        - arch_params
+        - metric_to_watch
+        - valid_metrics
+        - loss_logging_items_names
+
+        The corresponding Phase enum value for this event is Phase.AVERAGE_BEST_MODELS_VALIDATION_START.
         :param context:
-        :return:
         """
         pass
 
     def on_average_best_models_validation_end(self, context: PhaseContext) -> None:
         """
         Called once after the average model validation has finished.
+        At this point, the context argument will have the following attributes:
+        - epoch
+        - batch_idx
+        - optimizer
+        - inputs
+        - preds
+        - target
+        - metrics_dict
+        - metrics_compute_fn
+        - loss_avg_meter
+        - loss_log_items
+        - criterion
+        - device
+        - stop_training
+        - experiment_name
+        - ckpt_dir
+        - net
+        - lr_warmup_epochs
+        - sg_logger
+        - train_loader
+        - valid_loader
+        - training_params
+        - ddp_silent_mode
+        - checkpoint_params
+        - arch_params
+        - metric_to_watch
+        - valid_metrics
+        - loss_logging_items_names
         :param context:
-        :return:
         """
         pass
 
     def on_training_end(self, context: PhaseContext) -> None:
         """
         Called once after the training loop has finished (Due to reaching optimization criterion or because of an error.)
+        At this point, the context argument will have the following attributes:
+        - epoch
+        - batch_idx
+        - optimizer
+        - inputs
+        - preds
+        - target
+        - metrics_compute_fn
+        - loss_avg_meter
+        - loss_log_items
+        - criterion
+        - device
+        - stop_training
+        - experiment_name
+        - ckpt_dir
+        - net
+        - lr_warmup_epochs
+        - sg_logger
+        - train_loader
+        - valid_loader
+        - training_params
+        - ddp_silent_mode
+        - checkpoint_params
+        - arch_params
+        - metric_to_watch
+        - valid_metrics
+        - loss_logging_items_names
+
+        The corresponding Phase enum value for this event is Phase.POST_TRAINING.
         :param context:
-        :return:
         """
         pass