diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index a6017c646e..8e7943b15d 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -64,7 +64,14 @@
     wait_for_the_master,
     DDPNotSetupException,
 )
-from super_gradients.common.environment.ddp_utils import get_local_rank, require_ddp_setup, is_ddp_subprocess, get_world_size, get_device_ids
+from super_gradients.common.environment.ddp_utils import (
+    get_local_rank,
+    require_ddp_setup,
+    is_ddp_subprocess,
+    get_world_size,
+    get_device_ids,
+    broadcast_from_master,
+)
 from super_gradients.training.utils.ema import ModelEMA
 from super_gradients.training.utils.optimizer_utils import build_optimizer
 from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params
@@ -1387,7 +1394,9 @@ def forward(self, inputs, targets):
             if not silent_mode:
                 logger.info(f"Started training for {self.max_epochs - self.start_epoch} epochs ({self.start_epoch}/" f"{self.max_epochs - 1})\n")
             for epoch in range(self.start_epoch, self.max_epochs):
-                if context.stop_training:
+                # broadcast_from_master is necessary here, since in DDP mode, only the master node will
+                # receive the Ctrl-C signal, and we want all nodes to stop training.
+                if broadcast_from_master(context.stop_training):
                     logger.info("Request to stop training has been received, stopping training")
                     break
 
@@ -1506,6 +1515,7 @@ def forward(self, inputs, targets):
             self.phase_callback_handler.on_average_best_models_validation_end(context)
 
         except KeyboardInterrupt:
+            context.update_context(stop_training=True)
             logger.info(
                 "\n[MODEL TRAINING EXECUTION HAS BEEN INTERRUPTED]... Please wait until SOFT-TERMINATION process "
                 "finishes and saves all of the Model Checkpoints and log files before terminating..."
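
Note: the fix works because broadcast_from_master makes rank 0's stop_training flag authoritative for every process, so the per-epoch check cannot diverge across ranks. The helper already exists in super_gradients.common.environment.ddp_utils; below is a minimal sketch of the semantics this patch relies on, assuming an implementation on top of torch.distributed.broadcast_object_list (the library's actual code may differ):

import torch.distributed as dist

def broadcast_from_master(value):
    # Hypothetical sketch, not the library's actual implementation: return
    # rank 0's value on every rank so all processes agree on whether to stop.
    if not (dist.is_available() and dist.is_initialized()):
        # Single-process (non-DDP) run: nothing to synchronize.
        return value
    payload = [value]
    # broadcast_object_list overwrites `payload` in-place on all ranks
    # with the object held by src (rank 0, the master).
    dist.broadcast_object_list(payload, src=0)
    return payload[0]

Since this is a collective call, every rank must reach the epoch-loop check together; otherwise the non-master ranks would block waiting for the broadcast. Placing it once per epoch, as the patch does, guarantees that.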