Commit 2546323

BloodAxe and shaydeci authored
Getting rid of "module." heritage (#1184)
* Getting rid of "module." heritage
* Remove import of "WrappedModel"
* Remove remaining usages of .net.module
* Put back the WrappedModel class, but add a deprecation warning when someone is about to use it, to indicate it is deprecated
* Fix _yolox_ckpt_solver (update the condition to account for a missing "module." prefix)
* Change python3.8 to python
* Reorder tests
* Add missing unwrap_model after merge with master

Co-authored-by: Shay Aharon <[email protected]>
1 parent ab2e792 commit 2546323

File tree

12 files changed: +114 -71 lines changed
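Most of the changes below replace direct accesses to self.net.module with the unwrap_model helper imported from super_gradients.training.utils.utils. The helper's implementation is not part of this diff; the snippet below is only a minimal sketch of what it is assumed to do (return the inner module of a DataParallel/DistributedDataParallel wrapper and pass plain modules through unchanged), shown here to make the per-file diffs easier to follow.

# Minimal sketch (not the real implementation) of the unwrap_model helper used throughout this commit.
import torch.nn as nn

def unwrap_model_sketch(model: nn.Module) -> nn.Module:
    # DP/DDP wrappers expose the real network as `.module`; plain modules pass through unchanged.
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model

net = nn.Linear(4, 2)
assert unwrap_model_sketch(net) is net                   # unwrapped model is returned as-is
assert unwrap_model_sketch(nn.DataParallel(net)) is net  # wrapper is stripped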

Makefile

+5 -5
@@ -8,9 +8,9 @@ yolo_nas_integration_tests:
 	python -m unittest tests/integration_tests/yolo_nas_integration_test.py
 
 recipe_accuracy_tests:
-	python3.8 src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
-	python3.8 src/super_gradients/train_from_recipe.py --config-name=coco2017_pose_dekr_w32_no_dc experiment_name=shortened_coco2017_pose_dekr_w32_ap_test epochs=1 batch_size=4 val_batch_size=8 training_hyperparams.lr_warmup_steps=0 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=1000 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4
-	python3.8 src/super_gradients/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
-	python3.8 src/super_gradients/train_from_recipe.py --config-name=coco2017_yolox experiment_name=shortened_coco2017_yolox_n_map_test epochs=10 architecture=yolox_n training_hyperparams.loss=yolox_fast_loss training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
-	python3.8 src/super_gradients/train_from_recipe.py --config-name=cityscapes_regseg48 experiment_name=shortened_cityscapes_regseg48_iou_test epochs=10 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
+	python src/super_gradients/train_from_recipe.py --config-name=coco2017_pose_dekr_w32_no_dc experiment_name=shortened_coco2017_pose_dekr_w32_ap_test epochs=1 batch_size=4 val_batch_size=8 training_hyperparams.lr_warmup_steps=0 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=1000 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4
+	python src/super_gradients/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
+	python src/super_gradients/train_from_recipe.py --config-name=coco2017_yolox experiment_name=shortened_coco2017_yolox_n_map_test epochs=10 architecture=yolox_n training_hyperparams.loss=yolox_fast_loss training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
+	python src/super_gradients/train_from_recipe.py --config-name=cityscapes_regseg48 experiment_name=shortened_cityscapes_regseg48_iou_test epochs=10 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
+	python src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
 	coverage run --source=super_gradients -m unittest tests/deci_core_recipe_test_suite_runner.py

src/super_gradients/training/kd_trainer/kd_trainer.py

+5 -4
@@ -27,6 +27,7 @@
 from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, load_checkpoint_to_model
 from super_gradients.training.utils.distributed_training_utils import setup_device
 from super_gradients.training.utils.ema import KDModelEMA
+from super_gradients.training.utils.utils import unwrap_model
 
 logger = get_logger(__name__)
 
@@ -211,7 +212,7 @@ def _load_checkpoint_to_model(self):
         the entire KD network following the same logic as in Trainer.
         """
         teacher_checkpoint_path = get_param(self.checkpoint_params, "teacher_checkpoint_path")
-        teacher_net = self.net.module.teacher
+        teacher_net = unwrap_model(self.net).teacher
 
         if teacher_checkpoint_path is not None:
@@ -271,12 +272,12 @@ def _save_best_checkpoint(self, epoch, state):
         Overrides parent best_ckpt saving to modify the state dict so that we only save the student.
         """
         if self.ema:
-            best_net = core_utils.WrappedModel(self.ema_model.ema.module.student)
+            best_net = self.ema_model.ema.student
             state.pop("ema_net")
         else:
-            best_net = core_utils.WrappedModel(self.net.module.student)
+            best_net = self.net.student
 
-        state["net"] = best_net.state_dict()
+        state["net"] = unwrap_model(best_net).state_dict()
         self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
 
     def train(
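The _save_best_checkpoint change above no longer wraps the student in WrappedModel before saving. As a rough illustration of the new flow (KDDemo is a made-up stand-in for the real KD network; only the unwrap_model import path is taken from this diff), extracting the student through unwrap_model yields a wrapper-free state dict:

import torch.nn as nn
from super_gradients.training.utils.utils import unwrap_model  # import path as added in this commit

class KDDemo(nn.Module):  # hypothetical stand-in for the KD network with student/teacher members
    def __init__(self):
        super().__init__()
        self.student = nn.Linear(4, 2)
        self.teacher = nn.Linear(4, 2)

wrapped = nn.DataParallel(KDDemo())                   # training-time wrapper adds a "module." level
best_net = unwrap_model(wrapped).student              # reach through the wrapper, keep only the student
state = {"net": unwrap_model(best_net).state_dict()}  # same pattern as the new _save_best_checkpoint
print(list(state["net"].keys()))                      # ['weight', 'bias'] -- no "module." prefix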

src/super_gradients/training/sg_trainer/sg_trainer.py

+19 -18
@@ -66,7 +66,7 @@
 from super_gradients.training.utils.ema import ModelEMA
 from super_gradients.training.utils.optimizer_utils import build_optimizer
 from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params
-from super_gradients.training.utils.utils import fuzzy_idx_in_list
+from super_gradients.training.utils.utils import fuzzy_idx_in_list, unwrap_model
 from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.utils import random_seed
@@ -396,9 +396,6 @@ def _net_to_device(self):
             local_rank = int(device_config.device.split(":")[1])
             self.net = torch.nn.parallel.DistributedDataParallel(self.net, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
 
-        else:
-            self.net = core_utils.WrappedModel(self.net)
-
     def _train_epoch(self, epoch: int, silent_mode: bool = False) -> tuple:
         """
         train_epoch - A single epoch training procedure
@@ -601,15 +598,16 @@ def _save_checkpoint(
         metric = validation_results_dict[self.metric_to_watch]
 
         # BUILD THE state_dict
-        state = {"net": self.net.state_dict(), "acc": metric, "epoch": epoch}
+        state = {"net": unwrap_model(self.net).state_dict(), "acc": metric, "epoch": epoch}
+
         if optimizer is not None:
             state["optimizer_state_dict"] = optimizer.state_dict()
 
         if self.scaler is not None:
             state["scaler_state_dict"] = self.scaler.state_dict()
 
         if self.ema:
-            state["ema_net"] = self.ema_model.ema.state_dict()
+            state["ema_net"] = unwrap_model(self.ema_model.ema).state_dict()
 
         processing_params = self._get_preprocessing_from_valid_loader()
         if processing_params is not None:
@@ -636,7 +634,7 @@ def _save_checkpoint(
             logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(metric))
 
         if self.training_params.average_best_models:
-            net_for_averaging = self.ema_model.ema if self.ema else self.net
+            net_for_averaging = unwrap_model(self.ema_model.ema if self.ema else self.net)
 
             state["net"] = self.model_weight_averaging.get_average_model(net_for_averaging, validation_results_dict=validation_results_dict)
             self.sg_logger.add_checkpoint(tag=self.average_model_checkpoint_filename, state_dict=state, global_step=epoch)
@@ -652,7 +650,7 @@ def _prep_net_for_train(self) -> None:
         self._net_to_device()
 
         # SET THE FLAG FOR DIFFERENT PARAMETER GROUP OPTIMIZER UPDATE
-        self.update_param_groups = hasattr(self.net.module, "update_param_groups")
+        self.update_param_groups = hasattr(unwrap_model(self.net), "update_param_groups")
 
         self.checkpoint = {}
         self.strict_load = core_utils.get_param(self.training_params, "resume_strict_load", StrictLoad.ON)
@@ -1161,7 +1159,9 @@ def forward(self, inputs, targets):
         if not self.ddp_silent_mode:
             if self.training_params.dataset_statistics:
                 dataset_statistics_logger = DatasetStatisticsTensorboardLogger(self.sg_logger)
-                dataset_statistics_logger.analyze(self.train_loader, all_classes=self.classes, title="Train-set", anchors=self.net.module.arch_params.anchors)
+                dataset_statistics_logger.analyze(
+                    self.train_loader, all_classes=self.classes, title="Train-set", anchors=unwrap_model(self.net).arch_params.anchors
+                )
                 dataset_statistics_logger.analyze(self.valid_loader, all_classes=self.classes, title="val-set")
 
         sg_trainer_utils.log_uncaught_exceptions(logger)
@@ -1175,7 +1175,7 @@ def forward(self, inputs, targets):
         if isinstance(self.training_params.optimizer, str) or (
             inspect.isclass(self.training_params.optimizer) and issubclass(self.training_params.optimizer, torch.optim.Optimizer)
         ):
-            self.optimizer = build_optimizer(net=self.net, lr=self.training_params.initial_lr, training_params=self.training_params)
+            self.optimizer = build_optimizer(net=unwrap_model(self.net), lr=self.training_params.initial_lr, training_params=self.training_params)
         elif isinstance(self.training_params.optimizer, torch.optim.Optimizer):
             self.optimizer = self.training_params.optimizer
         else:
@@ -1248,7 +1248,7 @@ def forward(self, inputs, targets):
 
         processing_params = self._get_preprocessing_from_valid_loader()
         if processing_params is not None:
-            self.net.module.set_dataset_processing_params(**processing_params)
+            unwrap_model(self.net).set_dataset_processing_params(**processing_params)
 
         try:
             # HEADERS OF THE TRAINING PROGRESS
@@ -1295,7 +1295,7 @@ def forward(self, inputs, targets):
                 num_gpus=get_world_size(),
             )
 
-            # model switch - we replace self.net.module with the ema model for the testing and saving part
+            # model switch - we replace self.net with the ema model for the testing and saving part
             # and then switch it back before the next training epoch
             if self.ema:
                 self.ema_model.update_attr(self.net)
@@ -1355,7 +1355,7 @@ def forward(self, inputs, targets):
     def _get_preprocessing_from_valid_loader(self) -> Optional[dict]:
         valid_loader = self.valid_loader
 
-        if isinstance(self.net.module, HasPredict) and isinstance(valid_loader.dataset, HasPreprocessingParams):
+        if isinstance(unwrap_model(self.net), HasPredict) and isinstance(valid_loader.dataset, HasPreprocessingParams):
             try:
                 return valid_loader.dataset.get_dataset_preprocessing_params()
             except Exception as e:
@@ -1413,7 +1413,7 @@ def _initialize_mixed_precision(self, mixed_precision_enabled: bool):
             def hook(module, _):
                 module.forward = MultiGPUModeAutocastWrapper(module.forward)
 
-            self.net.module.register_forward_pre_hook(hook=hook)
+            unwrap_model(self.net).register_forward_pre_hook(hook=hook)
 
         if self.load_checkpoint:
             scaler_state_dict = core_utils.get_param(self.checkpoint, "scaler_state_dict")
@@ -1439,7 +1439,7 @@ def _validate_final_average_model(self, cleanup_snapshots_pkl_file=False):
         with wait_for_the_master(local_rank):
             average_model_sd = read_ckpt_state_dict(average_model_ckpt_path)["net"]
 
-        self.net.load_state_dict(average_model_sd)
+        unwrap_model(self.net).load_state_dict(average_model_sd)
         # testing the averaged model and save instead of best model if needed
         averaged_model_results_dict = self._validate_epoch(epoch=self.max_epochs)
 
@@ -1462,7 +1462,7 @@ def get_arch_params(self):
 
     @property
     def get_structure(self):
-        return self.net.module.structure
+        return unwrap_model(self.net).structure
 
     @property
     def get_architecture(self):
@@ -1494,7 +1494,8 @@ def _re_build_model(self, arch_params={}):
 
         if device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
             logger.warning("Warning: distributed training is not supported in re_build_model()")
-        self.net = torch.nn.DataParallel(self.net, device_ids=get_device_ids()) if device_config.multi_gpu else core_utils.WrappedModel(self.net)
+        if device_config.multi_gpu == MultiGPUMode.DATA_PARALLEL:
+            self.net = torch.nn.DataParallel(self.net, device_ids=get_device_ids())
 
     @property
     def get_module(self):
@@ -1635,7 +1636,7 @@ def _initialize_sg_logger_objects(self, additional_configs_to_log: Dict = None):
         if "model_name" in get_callable_param_names(sg_logger_cls.__init__):
             if sg_logger_params.get("model_name") is None:
                 # Use the model name used in `models.get(...)` if relevant
-                sg_logger_params["model_name"] = get_model_name(self.net.module)
+                sg_logger_params["model_name"] = get_model_name(unwrap_model(self.net))
 
             if sg_logger_params["model_name"] is None:
                 raise ValueError(
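All of the checkpoint-related edits above share one motivation: calling state_dict() on a wrapped network prefixes every key with "module.", while calling it on the unwrapped network does not. A quick plain-PyTorch illustration (not SuperGradients code):

import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4))
wrapped = nn.DataParallel(net)

print(list(wrapped.state_dict().keys()))         # ['module.0.weight', 'module.0.bias']
print(list(wrapped.module.state_dict().keys()))  # ['0.weight', '0.bias'] -- what gets saved now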

src/super_gradients/training/utils/__init__.py

-2
@@ -1,7 +1,6 @@
 from super_gradients.training.utils.utils import (
     Timer,
     HpmStruct,
-    WrappedModel,
     convert_to_tensor,
     get_param,
     tensor_container_to_device,
@@ -17,7 +16,6 @@
 __all__ = [
     "Timer",
     "HpmStruct",
-    "WrappedModel",
     "convert_to_tensor",
     "get_param",
     "tensor_container_to_device",

src/super_gradients/training/utils/callbacks/callbacks.py

+10 -6
@@ -23,7 +23,7 @@
 from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization
 from super_gradients.common.environment.ddp_utils import multi_process_safe
 from super_gradients.common.environment.checkpoints_dir_utils import get_project_checkpoints_dir_path
-
+from super_gradients.training.utils.utils import unwrap_model
 
 logger = get_logger(__name__)
 
@@ -75,7 +75,7 @@ def __init__(self, model_name: str, input_dimensions: Sequence[int], primary_bat
         self.atol = kwargs.get("atol", 1e-05)
 
     def __call__(self, context: PhaseContext):
-        model = copy.deepcopy(context.net.module)
+        model = copy.deepcopy(unwrap_model(context.net))
         model = model.cpu()
         model.eval()  # Put model into eval mode
 
@@ -204,12 +204,12 @@ def __call__(self, context: PhaseContext) -> None:
         :param context: Training phase context
         """
         try:
-            model = copy.deepcopy(context.net)
+            model = copy.deepcopy(unwrap_model(context.net))
             model_state_dict_path = os.path.join(context.ckpt_dir, self.ckpt_name)
             model_state_dict = torch.load(model_state_dict_path)["net"]
             model.load_state_dict(state_dict=model_state_dict)
 
-            model = model.module.cpu()
+            model = model.cpu()
             if hasattr(model, "prep_model_for_conversion"):
                 model.prep_model_for_conversion(input_size=self.input_dimensions)
 
@@ -267,7 +267,9 @@ def perform_scheduling(self, context: PhaseContext):
 
     def update_lr(self, optimizer, epoch, batch_idx=None):
         if self.update_param_groups:
-            param_groups = self.net.module.update_param_groups(optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len)
+            param_groups = unwrap_model(self.net).update_param_groups(
+                optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
+            )
             optimizer.param_groups = param_groups
         else:
             # UPDATE THE OPTIMIZERS PARAMETER
@@ -373,7 +375,9 @@ def update_lr(self, optimizer, epoch, batch_idx=None):
         :return:
         """
         if self.update_param_groups:
-            param_groups = self.net.module.update_param_groups(optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len)
+            param_groups = unwrap_model(self.net).update_param_groups(
+                optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
+            )
             optimizer.param_groups = param_groups
         else:
             # UPDATE THE OPTIMIZERS PARAMETER
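The two conversion callbacks above now deep-copy the unwrapped network before moving it to CPU, which is why the later .module.cpu() hop could be dropped. A condensed sketch of that pattern follows (DemoNet and on_conversion_check are illustrative names, not part of the library; only the unwrap_model import path comes from this diff):

import copy
import torch.nn as nn
from super_gradients.training.utils.utils import unwrap_model  # import path as added in this commit

class DemoNet(nn.Module):  # hypothetical model standing in for context.net
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)

def on_conversion_check(net: nn.Module) -> nn.Module:
    model = copy.deepcopy(unwrap_model(net))  # copy the bare model, not the DP/DDP wrapper
    model = model.cpu()
    model.eval()                              # put the copy into eval mode before conversion checks
    return model

print(type(on_conversion_check(nn.DataParallel(DemoNet()))).__name__)  # DemoNet -- no wrapper left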

src/super_gradients/training/utils/checkpoint_utils.py

+12 -2
@@ -14,6 +14,7 @@
 from super_gradients.module_interfaces import HasPredict
 from super_gradients.training.pretrained_models import MODEL_URLS
 from super_gradients.training.utils.distributed_training_utils import get_local_rank, wait_for_the_master
+from super_gradients.training.utils.utils import unwrap_model
 
 try:
     from torch.hub import download_url_to_file, load_state_dict_from_url
@@ -54,6 +55,13 @@ def adaptive_load_state_dict(net: torch.nn.Module, state_dict: dict, strict: Uni
     :return:
     """
     state_dict = state_dict["net"] if "net" in state_dict else state_dict
+
+    # This is a backward compatibility fix for checkpoints that were saved with DataParallel/DistributedDataParallel wrapper
+    # and contains "module." prefix in all keys
+    # If all keys start with "module.", then we remove it.
+    if all([key.startswith("module.") for key in state_dict.keys()]):
+        state_dict = collections.OrderedDict([(key[7:], value) for key, value in state_dict.items()])
+
     try:
         strict_bool = strict if isinstance(strict, bool) else strict != StrictLoad.OFF
         net.load_state_dict(state_dict, strict=strict_bool)
@@ -217,6 +225,8 @@ def load_checkpoint_to_model(
     if isinstance(strict, str):
         strict = StrictLoad(strict)
 
+    net = unwrap_model(net)
+
     if load_backbone and not hasattr(net, "backbone"):
         raise ValueError("No backbone attribute in net - Can't load backbone weights")
 
@@ -239,7 +249,7 @@
     message_model = "model" if not load_backbone else "model's backbone"
     logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)
 
-    if (isinstance(net, HasPredict) or (hasattr(net, "module") and isinstance(net.module, HasPredict))) and load_processing_params:
+    if (isinstance(net, HasPredict)) and load_processing_params:
         if "processing_params" not in checkpoint.keys():
             raise ValueError("Can't load processing params - could not find any stored in checkpoint file.")
         try:
@@ -275,7 +285,7 @@ def _yolox_ckpt_solver(ckpt_key, ckpt_val, model_key, model_val):
 
     if (
         ckpt_val.shape != model_val.shape
-        and ckpt_key == "module._backbone._modules_list.0.conv.conv.weight"
+        and (ckpt_key == "module._backbone._modules_list.0.conv.conv.weight" or ckpt_key == "_backbone._modules_list.0.conv.conv.weight")
         and model_key == "_backbone._modules_list.0.conv.weight"
     ):
         model_val.data[:, :, ::2, ::2] = ckpt_val.data[:, :3]
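The backward-compatibility branch added to adaptive_load_state_dict keeps old checkpoints (saved through a DP/DDP wrapper and therefore carrying "module."-prefixed keys) loadable. A standalone sketch of that key rewrite; the helper name below is illustrative, since in the diff the check is written inline:

import collections

def strip_module_prefix(state_dict: dict) -> dict:  # illustrative helper mirroring the inline check
    # Rewrite keys only when *every* key carries the wrapper prefix, as in old DP/DDP checkpoints.
    if all(key.startswith("module.") for key in state_dict.keys()):
        return collections.OrderedDict((key[len("module."):], value) for key, value in state_dict.items())
    return state_dict

old_ckpt = {"module.backbone.conv.weight": 0, "module.head.fc.bias": 1}
print(list(strip_module_prefix(old_ckpt).keys()))  # ['backbone.conv.weight', 'head.fc.bias']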
