
Commit c340da8
Merge branch 'master' into feature/SG-000-fix-module
# Conflicts:
#	src/super_gradients/training/sg_trainer/sg_trainer.py
BloodAxe committed Jun 28, 2023
2 parents 03540e6 + cb01101
Showing 1 changed file with 13 additions and 7 deletions.
src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -3,7 +3,7 @@
 import os
 from copy import deepcopy
 from pathlib import Path
-from typing import Union, Tuple, Mapping, Dict, Any, List
+from typing import Union, Tuple, Mapping, Dict, Any, List, Optional
 
 import hydra
 import numpy as np
@@ -604,8 +604,9 @@ def _save_checkpoint(
         if self.ema:
             state["ema_net"] = unwrap_model(self.ema_model.ema).state_dict()
 
-        if isinstance(unwrap_model(self.net), HasPredict) and isinstance(self.valid_loader.dataset, HasPreprocessingParams):
-            state["processing_params"] = self.valid_loader.dataset.get_dataset_preprocessing_params()
+        processing_params = self._get_preprocessing_from_valid_loader()
+        if processing_params is not None:
+            state["processing_params"] = processing_params
 
         # SAVES CURRENT MODEL AS ckpt_latest
         self.sg_logger.add_checkpoint(tag="ckpt_latest.pth", state_dict=state, global_step=epoch)
@@ -1239,7 +1240,10 @@ def forward(self, inputs, targets):
             train_dataloader_len=len(self.train_loader),
         )
 
-        self._set_net_preprocessing_from_valid_loader()
+        processing_params = self._get_preprocessing_from_valid_loader()
+        if processing_params is not None:
+            self.net.module.set_dataset_processing_params(**processing_params)
+
         try:
             # HEADERS OF THE TRAINING PROGRESS
             if not silent_mode:
@@ -1344,10 +1348,12 @@ def forward(self, inputs, targets):
         if not self.ddp_silent_mode:
             self.sg_logger.close()
 
-    def _set_net_preprocessing_from_valid_loader(self):
-        if isinstance(unwrap_model(self.net), HasPredict) and isinstance(self.valid_loader.dataset, HasPreprocessingParams):
+    def _get_preprocessing_from_valid_loader(self) -> Optional[dict]:
+        valid_loader = self.valid_loader
+
+        if isinstance(unwrap_model(self.net), HasPredict) and isinstance(valid_loader.dataset, HasPreprocessingParams):
             try:
-                unwrap_model(self.net).set_dataset_processing_params(**self.valid_loader.dataset.get_dataset_preprocessing_params())
+                return valid_loader.dataset.get_dataset_preprocessing_params()
             except Exception as e:
                 logger.warning(
                     f"Could not set preprocessing pipeline from the validation dataset:\n {e}.\n Before calling"
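For context: the change replaces a side-effecting helper (_set_net_preprocessing_from_valid_loader) with a getter (_get_preprocessing_from_valid_loader) that returns Optional[dict], so both call sites (checkpoint saving in _save_checkpoint and model setup in train()) share one lookup and each decides what to do with the result. The sketch below is a minimal, runnable illustration of that pattern, not the SuperGradients implementation: ToyDataset, ToyModel, and get_preprocessing_params are hypothetical stand-ins, and only the HasPredict / HasPreprocessingParams names mirror the diff.

from typing import Any, Dict, Optional


class HasPreprocessingParams:
    """Datasets that can export their preprocessing parameters."""

    def get_dataset_preprocessing_params(self) -> Dict[str, Any]:
        raise NotImplementedError


class HasPredict:
    """Models that accept preprocessing parameters for predict()."""

    def set_dataset_processing_params(self, **params: Any) -> None:
        raise NotImplementedError


class ToyDataset(HasPreprocessingParams):
    def get_dataset_preprocessing_params(self) -> Dict[str, Any]:
        return {"image_size": (640, 640)}


class ToyModel(HasPredict):
    def set_dataset_processing_params(self, **params: Any) -> None:
        self.processing_params = params


def get_preprocessing_params(model: Any, dataset: Any) -> Optional[Dict[str, Any]]:
    # Pure lookup, no side effects: None means "nothing to propagate".
    if isinstance(model, HasPredict) and isinstance(dataset, HasPreprocessingParams):
        try:
            return dataset.get_dataset_preprocessing_params()
        except Exception:
            return None  # the real code logs a warning at this point
    return None


model, dataset = ToyModel(), ToyDataset()

# Call site 1 (mirrors train()): apply the params to the model.
params = get_preprocessing_params(model, dataset)
if params is not None:
    model.set_dataset_processing_params(**params)

# Call site 2 (mirrors _save_checkpoint): store the params in the checkpoint.
state: Dict[str, Any] = {}
params = get_preprocessing_params(model, dataset)
if params is not None:
    state["processing_params"] = params

Splitting "compute" from "apply" this way removes the duplicated isinstance checks and lets each caller handle the None case in its own way.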
