from super_gradients.training.utils.ema import ModelEMA
from super_gradients.training.utils.optimizer_utils import build_optimizer
from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params
- from super_gradients.training.utils.utils import fuzzy_idx_in_list
+ from super_gradients.training.utils.utils import fuzzy_idx_in_list, unwrap_model
from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging
from super_gradients.training.metrics import Accuracy, Top5
from super_gradients.training.utils import random_seed
@@ -396,9 +396,6 @@ def _net_to_device(self):
local_rank = int(device_config.device.split(":")[1])
self.net = torch.nn.parallel.DistributedDataParallel(self.net, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

- else:
- self.net = core_utils.WrappedModel(self.net)
-
def _train_epoch(self, epoch: int, silent_mode: bool = False) -> tuple:
"""
train_epoch - A single epoch training procedure
@@ -601,15 +598,16 @@ def _save_checkpoint(
metric = validation_results_dict[self.metric_to_watch]

# BUILD THE state_dict
- state = {"net": self.net.state_dict(), "acc": metric, "epoch": epoch}
+ state = {"net": unwrap_model(self.net).state_dict(), "acc": metric, "epoch": epoch}
+
if optimizer is not None:
state["optimizer_state_dict"] = optimizer.state_dict()

if self.scaler is not None:
state["scaler_state_dict"] = self.scaler.state_dict()

if self.ema:
- state["ema_net"] = self.ema_model.ema.state_dict()
+ state["ema_net"] = unwrap_model(self.ema_model.ema).state_dict()

processing_params = self._get_preprocessing_from_valid_loader()
if processing_params is not None:
@@ -636,7 +634,7 @@ def _save_checkpoint(
logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(metric))

if self.training_params.average_best_models:
- net_for_averaging = self.ema_model.ema if self.ema else self.net
+ net_for_averaging = unwrap_model(self.ema_model.ema if self.ema else self.net)

state["net"] = self.model_weight_averaging.get_average_model(net_for_averaging, validation_results_dict=validation_results_dict)
self.sg_logger.add_checkpoint(tag=self.average_model_checkpoint_filename, state_dict=state, global_step=epoch)
@@ -652,7 +650,7 @@ def _prep_net_for_train(self) -> None:
self._net_to_device()

# SET THE FLAG FOR DIFFERENT PARAMETER GROUP OPTIMIZER UPDATE
- self.update_param_groups = hasattr(self.net.module, "update_param_groups")
+ self.update_param_groups = hasattr(unwrap_model(self.net), "update_param_groups")

self.checkpoint = {}
self.strict_load = core_utils.get_param(self.training_params, "resume_strict_load", StrictLoad.ON)
@@ -1161,7 +1159,9 @@ def forward(self, inputs, targets):
if not self.ddp_silent_mode:
if self.training_params.dataset_statistics:
dataset_statistics_logger = DatasetStatisticsTensorboardLogger(self.sg_logger)
- dataset_statistics_logger.analyze(self.train_loader, all_classes=self.classes, title="Train-set", anchors=self.net.module.arch_params.anchors)
+ dataset_statistics_logger.analyze(
+ self.train_loader, all_classes=self.classes, title="Train-set", anchors=unwrap_model(self.net).arch_params.anchors
+ )
dataset_statistics_logger.analyze(self.valid_loader, all_classes=self.classes, title="val-set")

sg_trainer_utils.log_uncaught_exceptions(logger)
@@ -1175,7 +1175,7 @@ def forward(self, inputs, targets):
if isinstance(self.training_params.optimizer, str) or (
inspect.isclass(self.training_params.optimizer) and issubclass(self.training_params.optimizer, torch.optim.Optimizer)
):
- self.optimizer = build_optimizer(net=self.net, lr=self.training_params.initial_lr, training_params=self.training_params)
+ self.optimizer = build_optimizer(net=unwrap_model(self.net), lr=self.training_params.initial_lr, training_params=self.training_params)
elif isinstance(self.training_params.optimizer, torch.optim.Optimizer):
self.optimizer = self.training_params.optimizer
else:
@@ -1248,7 +1248,7 @@ def forward(self, inputs, targets):

processing_params = self._get_preprocessing_from_valid_loader()
if processing_params is not None:
- self.net.module.set_dataset_processing_params(**processing_params)
+ unwrap_model(self.net).set_dataset_processing_params(**processing_params)

try:
# HEADERS OF THE TRAINING PROGRESS
@@ -1295,7 +1295,7 @@ def forward(self, inputs, targets):
num_gpus=get_world_size(),
)

- # model switch - we replace self.net.module with the ema model for the testing and saving part
+ # model switch - we replace self.net with the ema model for the testing and saving part
# and then switch it back before the next training epoch
if self.ema:
self.ema_model.update_attr(self.net)
@@ -1355,7 +1355,7 @@ def forward(self, inputs, targets):
def _get_preprocessing_from_valid_loader(self) -> Optional[dict]:
valid_loader = self.valid_loader

- if isinstance(self.net.module, HasPredict) and isinstance(valid_loader.dataset, HasPreprocessingParams):
+ if isinstance(unwrap_model(self.net), HasPredict) and isinstance(valid_loader.dataset, HasPreprocessingParams):
try:
return valid_loader.dataset.get_dataset_preprocessing_params()
except Exception as e:
@@ -1413,7 +1413,7 @@ def _initialize_mixed_precision(self, mixed_precision_enabled: bool):
def hook(module, _):
module.forward = MultiGPUModeAutocastWrapper(module.forward)

- self.net.module.register_forward_pre_hook(hook=hook)
+ unwrap_model(self.net).register_forward_pre_hook(hook=hook)

if self.load_checkpoint:
scaler_state_dict = core_utils.get_param(self.checkpoint, "scaler_state_dict")
@@ -1439,7 +1439,7 @@ def _validate_final_average_model(self, cleanup_snapshots_pkl_file=False):
with wait_for_the_master(local_rank):
average_model_sd = read_ckpt_state_dict(average_model_ckpt_path)["net"]

- self.net.load_state_dict(average_model_sd)
+ unwrap_model(self.net).load_state_dict(average_model_sd)
# testing the averaged model and save instead of best model if needed
averaged_model_results_dict = self._validate_epoch(epoch=self.max_epochs)

@@ -1462,7 +1462,7 @@ def get_arch_params(self):

@property
def get_structure(self):
- return self.net.module.structure
+ return unwrap_model(self.net).structure

@property
def get_architecture(self):
@@ -1494,7 +1494,8 @@ def _re_build_model(self, arch_params={}):

if device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
logger.warning("Warning: distributed training is not supported in re_build_model()")
- self.net = torch.nn.DataParallel(self.net, device_ids=get_device_ids()) if device_config.multi_gpu else core_utils.WrappedModel(self.net)
+ if device_config.multi_gpu == MultiGPUMode.DATA_PARALLEL:
+ self.net = torch.nn.DataParallel(self.net, device_ids=get_device_ids())

@property
def get_module(self):
@@ -1635,7 +1636,7 @@ def _initialize_sg_logger_objects(self, additional_configs_to_log: Dict = None):
if "model_name" in get_callable_param_names(sg_logger_cls.__init__):
if sg_logger_params.get("model_name") is None:
# Use the model name used in `models.get(...)` if relevant
- sg_logger_params["model_name"] = get_model_name(self.net.module)
+ sg_logger_params["model_name"] = get_model_name(unwrap_model(self.net))

if sg_logger_params["model_name"] is None:
raise ValueError(
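Note on the recurring change above: every direct access to self.net.module is routed through unwrap_model(self.net), so the same code path works whether self.net is a bare nn.Module or a DataParallel/DistributedDataParallel wrapper, and the explicit core_utils.WrappedModel fallback is dropped. The sketch below is an illustration only, under the assumption that the helper in super_gradients.training.utils.utils behaves this way; the real implementation may differ.

    import torch.nn as nn

    def unwrap_model(model: nn.Module) -> nn.Module:
        # Sketch of the expected behavior, not the library's exact code:
        # peel off a (Distributed)DataParallel wrapper if present,
        # otherwise return the model unchanged.
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            return model.module
        return model

    # Usage mirroring the checkpoint hunk: the state_dict is taken from the
    # unwrapped model, so saved keys carry no "module." prefix.
    # state = {"net": unwrap_model(self.net).state_dict(), "acc": metric, "epoch": epoch}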