From de36cfc89479699dfb6b996c199dc36e8b63ece2 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Tue, 12 Sep 2023 18:00:44 +0300 Subject: [PATCH] Correctly handle Ctrl-C event (#1461) --- .../training/sg_trainer/sg_trainer.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index a6017c646e..8e7943b15d 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -64,7 +64,14 @@ wait_for_the_master, DDPNotSetupException, ) -from super_gradients.common.environment.ddp_utils import get_local_rank, require_ddp_setup, is_ddp_subprocess, get_world_size, get_device_ids +from super_gradients.common.environment.ddp_utils import ( + get_local_rank, + require_ddp_setup, + is_ddp_subprocess, + get_world_size, + get_device_ids, + broadcast_from_master, +) from super_gradients.training.utils.ema import ModelEMA from super_gradients.training.utils.optimizer_utils import build_optimizer from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params @@ -1387,7 +1394,9 @@ def forward(self, inputs, targets): if not silent_mode: logger.info(f"Started training for {self.max_epochs - self.start_epoch} epochs ({self.start_epoch}/" f"{self.max_epochs - 1})\n") for epoch in range(self.start_epoch, self.max_epochs): - if context.stop_training: + # broadcast_from_master is necessary here, since in DDP mode, only the master node will + # receive the Ctrl-C signal, and we want all nodes to stop training. + if broadcast_from_master(context.stop_training): logger.info("Request to stop training has been received, stopping training") break @@ -1506,6 +1515,7 @@ def forward(self, inputs, targets): self.phase_callback_handler.on_average_best_models_validation_end(context) except KeyboardInterrupt: + context.update_context(stop_training=True) logger.info( "\n[MODEL TRAINING EXECUTION HAS BEEN INTERRUPTED]... Please wait until SOFT-TERMINATION process " "finishes and saves all of the Model Checkpoints and log files before terminating..."