diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index 7553d34d8a..348736de41 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -3,7 +3,7 @@
 import os
 from copy import deepcopy
 from pathlib import Path
-from typing import Union, Tuple, Mapping, Dict, Any, List
+from typing import Union, Tuple, Mapping, Dict, Any, List, Optional
 
 import hydra
 import numpy as np
@@ -604,8 +604,9 @@ def _save_checkpoint(
         if self.ema:
             state["ema_net"] = unwrap_model(self.ema_model.ema).state_dict()
 
-        if isinstance(unwrap_model(self.net), HasPredict) and isinstance(self.valid_loader.dataset, HasPreprocessingParams):
-            state["processing_params"] = self.valid_loader.dataset.get_dataset_preprocessing_params()
+        processing_params = self._get_preprocessing_from_valid_loader()
+        if processing_params is not None:
+            state["processing_params"] = processing_params
 
         # SAVES CURRENT MODEL AS ckpt_latest
         self.sg_logger.add_checkpoint(tag="ckpt_latest.pth", state_dict=state, global_step=epoch)
@@ -1239,7 +1240,10 @@ def forward(self, inputs, targets):
                 train_dataloader_len=len(self.train_loader),
             )
 
-            self._set_net_preprocessing_from_valid_loader()
+            processing_params = self._get_preprocessing_from_valid_loader()
+            if processing_params is not None:
+                self.net.module.set_dataset_processing_params(**processing_params)
+
             try:
                 # HEADERS OF THE TRAINING PROGRESS
                 if not silent_mode:
@@ -1344,10 +1348,12 @@ def forward(self, inputs, targets):
         if not self.ddp_silent_mode:
             self.sg_logger.close()
 
-    def _set_net_preprocessing_from_valid_loader(self):
-        if isinstance(unwrap_model(self.net), HasPredict) and isinstance(self.valid_loader.dataset, HasPreprocessingParams):
+    def _get_preprocessing_from_valid_loader(self) -> Optional[dict]:
+        valid_loader = self.valid_loader
+
+        if isinstance(unwrap_model(self.net), HasPredict) and isinstance(valid_loader.dataset, HasPreprocessingParams):
             try:
-                unwrap_model(self.net).set_dataset_processing_params(**self.valid_loader.dataset.get_dataset_preprocessing_params())
+                return valid_loader.dataset.get_dataset_preprocessing_params()
             except Exception as e:
                 logger.warning(
                     f"Could not set preprocessing pipeline from the validation dataset:\n {e}.\n Before calling
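
For context, the change above replaces a setter that applied preprocessing params directly to the net with a getter that returns them (or None), so both `_save_checkpoint` and the training loop share one lookup path. Below is a minimal self-contained sketch of that pattern; `SketchTrainer` and the duck-typed `hasattr` check are hypothetical stand-ins for the real `Trainer`, `HasPredict`, and `HasPreprocessingParams` types and are not SuperGradients API.

    from typing import Optional

    class SketchTrainer:
        def __init__(self, net, valid_loader):
            self.net = net
            self.valid_loader = valid_loader

        def _get_preprocessing_from_valid_loader(self) -> Optional[dict]:
            dataset = self.valid_loader.dataset
            # Only datasets that expose preprocessing params qualify;
            # anything else (or a failing getter) yields None so that
            # every caller can simply skip the step.
            if hasattr(dataset, "get_dataset_preprocessing_params"):
                try:
                    return dataset.get_dataset_preprocessing_params()
                except Exception:
                    return None
            return None

        def save_checkpoint(self) -> dict:
            state = {"net": self.net.state_dict()}
            # Same branch shape as the diff: store params only when present.
            processing_params = self._get_preprocessing_from_valid_loader()
            if processing_params is not None:
                state["processing_params"] = processing_params
            return state

Centralizing the lookup in a getter means the isinstance checks and the try/except live in one place, and call sites decide independently what to do with the result (persist it in the checkpoint, or push it onto the net).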