Getting rid of "module." heritage #1184

Merged: 25 commits, Jun 28, 2023

Commits (25)
2613ece
Getting rid of "module." heritage
BloodAxe Jun 16, 2023
719c140
Merge branch 'master' into feature/SG-000-fix-module
BloodAxe Jun 22, 2023
80b17b7
Remove import of "WrappedModel",
BloodAxe Jun 22, 2023
52886c9
Merge remote-tracking branch 'origin/feature/SG-000-fix-module' into …
BloodAxe Jun 22, 2023
c649db8
Merge branch 'master' into feature/SG-000-fix-module
BloodAxe Jun 26, 2023
7e72376
Remove remaining usages of .net.module
BloodAxe Jun 26, 2023
ba840fe
Merge remote-tracking branch 'origin/feature/SG-000-fix-module' into …
BloodAxe Jun 26, 2023
66618b7
Remove remaining usages of .net.module
BloodAxe Jun 26, 2023
8025b83
Remove remaining usages of .net.module
BloodAxe Jun 26, 2023
dc4924d
Remove remaining usages of .net.module
BloodAxe Jun 26, 2023
74f4259
Merge branch 'master' into feature/SG-000-fix-module
shaydeci Jun 26, 2023
84b9f34
Put back WrappedModel class in place, but add deprecation warning whe…
BloodAxe Jun 27, 2023
553192c
Fix _yolox_ckpt_solver (updating condition to account missing "module.")
BloodAxe Jun 27, 2023
c2e0a10
Change python3.8 to python
BloodAxe Jun 27, 2023
d0ee8c8
Merge remote-tracking branch 'origin/feature/SG-000-fix-module' into …
BloodAxe Jun 27, 2023
49bd2ff
Merge branch 'master' into feature/SG-000-fix-module
BloodAxe Jun 27, 2023
341cf32
Merge branch 'master' into feature/SG-000-fix-module
BloodAxe Jun 27, 2023
6e19867
Merge branch 'master' into feature/SG-000-fix-module
BloodAxe Jun 28, 2023
2781de7
Merge branch 'master' into feature/SG-000-fix-module
BloodAxe Jun 28, 2023
03540e6
Reorder tests
BloodAxe Jun 28, 2023
c340da8
Merge branch 'master' into feature/SG-000-fix-module
BloodAxe Jun 28, 2023
af468c3
Merge branch 'master' into feature/SG-000-fix-module
BloodAxe Jun 28, 2023
d15eb30
Add missing unwrap_model after merge with master
BloodAxe Jun 28, 2023
4040f87
Merge branch 'master' into feature/SG-000-fix-module
BloodAxe Jun 28, 2023
2df464a
Merge branch 'master' into feature/SG-000-fix-module
BloodAxe Jun 28, 2023
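
Note on commit 84b9f34 above: the WrappedModel class was put back for backward compatibility with a deprecation warning, but its exact code is not shown in this diff. Below is a minimal illustrative sketch of such a shim; the warning text and placement are assumptions, not the PR's actual implementation.

# Hypothetical sketch only; see commit 84b9f34 for the real change.
import warnings
import torch.nn as nn


class WrappedModel(nn.Module):
    def __init__(self, module: nn.Module):
        super().__init__()
        # Emit a deprecation warning: models no longer need a "module."-prefixing wrapper.
        warnings.warn(
            "WrappedModel is deprecated and will be removed in a future release.",
            DeprecationWarning,
        )
        self.module = module  # kept so old code and checkpoints that expect .module still work

    def forward(self, x):
        return self.module(x)
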
7 changes: 4 additions & 3 deletions src/super_gradients/training/kd_trainer/kd_trainer.py
@@ -27,6 +27,7 @@
 from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, load_checkpoint_to_model
 from super_gradients.training.utils.distributed_training_utils import setup_device
 from super_gradients.training.utils.ema import KDModelEMA
+from super_gradients.training.utils.utils import get_real_model
 
 logger = get_logger(__name__)
 
@@ -271,12 +272,12 @@ def _save_best_checkpoint(self, epoch, state):
         Overrides parent best_ckpt saving to modify the state dict so that we only save the student.
         """
         if self.ema:
-            best_net = core_utils.WrappedModel(self.ema_model.ema.module.student)
+            best_net = self.ema_model.ema.student
             state.pop("ema_net")
         else:
-            best_net = core_utils.WrappedModel(self.net.module.student)
+            best_net = self.net.student
 
-        state["net"] = best_net.state_dict()
+        state["net"] = get_real_model(best_net).state_dict()
         self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
 
     def train(
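
With WrappedModel removed, the KD network is a plain KDModule unless DDP wraps it, so the student is reached directly via .student (after unwrapping with get_real_model where a wrapper may be present). A small sketch of that pattern, with illustrative names not taken from the repo:

# Illustrative sketch (not from the repo): fetch the student's weights whether or not
# the KD network is wrapped in DistributedDataParallel.
from super_gradients.training.utils.utils import get_real_model


def student_state_dict(kd_net):
    kd_module = get_real_model(kd_net)      # unwraps DistributedDataParallel if present
    return kd_module.student.state_dict()   # KDModule exposes the student directly
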
19 changes: 9 additions & 10 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -66,7 +66,7 @@
 from super_gradients.training.utils.ema import ModelEMA
 from super_gradients.training.utils.optimizer_utils import build_optimizer
 from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params
-from super_gradients.training.utils.utils import fuzzy_idx_in_list
+from super_gradients.training.utils.utils import fuzzy_idx_in_list, get_real_model
 from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.utils import random_seed
@@ -405,9 +405,6 @@ def _net_to_device(self):
             local_rank = int(device_config.device.split(":")[1])
             self.net = torch.nn.parallel.DistributedDataParallel(self.net, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
 
-        else:
-            self.net = core_utils.WrappedModel(self.net)
-
     def _train_epoch(self, epoch: int, silent_mode: bool = False) -> tuple:
         """
         train_epoch - A single epoch training procedure
@@ -592,7 +589,7 @@ def _save_checkpoint(
         """
         # WHEN THE validation_results_tuple IS NONE WE SIMPLY SAVE THE state_dict AS LATEST AND Return
         if validation_results_tuple is None:
-            self.sg_logger.add_checkpoint(tag="ckpt_latest_weights_only.pth", state_dict={"net": self.net.state_dict()}, global_step=epoch)
+            self.sg_logger.add_checkpoint(tag="ckpt_latest_weights_only.pth", state_dict={"net": get_real_model(self.net).state_dict()}, global_step=epoch)
             return
 
         # COMPUTE THE CURRENT metric
@@ -604,15 +601,16 @@
         )
 
         # BUILD THE state_dict
-        state = {"net": self.net.state_dict(), "acc": metric, "epoch": epoch}
+        state = {"net": get_real_model(self.net).state_dict(), "acc": metric, "epoch": epoch}
+
         if optimizer is not None:
             state["optimizer_state_dict"] = optimizer.state_dict()
 
         if self.scaler is not None:
             state["scaler_state_dict"] = self.scaler.state_dict()
 
         if self.ema:
-            state["ema_net"] = self.ema_model.ema.state_dict()
+            state["ema_net"] = get_real_model(self.ema_model.ema).state_dict()
 
         if isinstance(self.net.module, HasPredict) and isinstance(self.valid_loader.dataset, HasPreprocessingParams):
             state["processing_params"] = self.valid_loader.dataset.get_dataset_preprocessing_params()
@@ -638,7 +636,7 @@
                 logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(metric))
 
             if self.training_params.average_best_models:
-                net_for_averaging = self.ema_model.ema if self.ema else self.net
+                net_for_averaging = get_real_model(self.ema_model.ema if self.ema else self.net)
                 state["net"] = self.model_weight_averaging.get_average_model(net_for_averaging, validation_results_tuple=validation_results_tuple)
                 self.sg_logger.add_checkpoint(tag=self.average_model_checkpoint_filename, state_dict=state, global_step=epoch)
 
@@ -1438,7 +1436,7 @@ def _validate_final_average_model(self, cleanup_snapshots_pkl_file=False):
         with wait_for_the_master(local_rank):
             average_model_sd = read_ckpt_state_dict(average_model_ckpt_path)["net"]
 
-        self.net.load_state_dict(average_model_sd)
+        get_real_model(self.net).load_state_dict(average_model_sd)
         # testing the averaged model and save instead of best model if needed
         averaged_model_results_tuple = self._validate_epoch(epoch=self.max_epochs)
 
@@ -1494,7 +1492,8 @@ def _re_build_model(self, arch_params={}):
 
         if device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
             logger.warning("Warning: distributed training is not supported in re_build_model()")
-        self.net = torch.nn.DataParallel(self.net, device_ids=get_device_ids()) if device_config.multi_gpu else core_utils.WrappedModel(self.net)
+        if device_config.multi_gpu == MultiGPUMode.DATA_PARALLEL:
+            self.net = torch.nn.DataParallel(self.net, device_ids=get_device_ids())
 
     @property
     def get_module(self):
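
The net effect of the sg_trainer changes is that every state dict is taken from the underlying module, so checkpoint keys are identical whether or not the model is wrapped. A short sketch illustrating that assumption with a toy model:

# Sketch (not from the repo): calling get_real_model(...) before state_dict() keeps
# checkpoint keys free of the "module." prefix that DataParallel/DDP would add.
import torch.nn as nn
from super_gradients.training.utils.utils import get_real_model

net = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
wrapped = nn.DataParallel(net)

plain_keys = set(get_real_model(net).state_dict().keys())
wrapped_keys = set(get_real_model(wrapped).state_dict().keys())
assert plain_keys == wrapped_keys  # no "module." prefix leaks into the checkpoint
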
7 changes: 7 additions & 0 deletions src/super_gradients/training/utils/checkpoint_utils.py
@@ -54,6 +54,13 @@ def adaptive_load_state_dict(net: torch.nn.Module, state_dict: dict, strict: Uni
     :return:
     """
     state_dict = state_dict["net"] if "net" in state_dict else state_dict
+
+    # This is a backward compatibility fix for checkpoints that were saved with DataParallel/DistributedDataParallel wrapper
+    # and contains "module." prefix in all keys
+    # If all keys start with "module.", then we remove it.
+    if all([key.startswith("module.") for key in state_dict.keys()]):
+        state_dict = collections.OrderedDict([(key[7:], value) for key, value in state_dict.items()])
+
     try:
         strict_bool = strict if isinstance(strict, bool) else strict != StrictLoad.OFF
         net.load_state_dict(state_dict, strict=strict_bool)
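
The backward-compatibility branch above is what lets checkpoints saved by older versions (where the net was always wrapped and every key carried a "module." prefix) load into an unwrapped model. A minimal sketch of that behaviour using a toy module, not part of the PR:

# Toy demonstration of the prefix-stripping logic added to adaptive_load_state_dict.
import collections
import torch.nn as nn

model = nn.Linear(4, 2)
old_ckpt = nn.DataParallel(model).state_dict()  # keys: "module.weight", "module.bias"

# Same idea as the new code path: strip "module." when every key starts with it.
if all(key.startswith("module.") for key in old_ckpt.keys()):
    old_ckpt = collections.OrderedDict((key[7:], value) for key, value in old_ckpt.items())

nn.Linear(4, 2).load_state_dict(old_ckpt)  # loads cleanly, keys now match the plain model
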
15 changes: 6 additions & 9 deletions src/super_gradients/training/utils/ema.py
@@ -7,7 +7,6 @@
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
-from super_gradients.training import utils as core_utils
 from super_gradients.training.models import SgModule
 from super_gradients.training.models.kd_modules.kd_module import KDModule
 from super_gradients.training.utils.ema_decay_schedules import IDecayFunction, EMA_DECAY_FUNCTIONS
@@ -172,15 +171,13 @@ def __init__(self, kd_model: KDModule, decay: float, decay_function: IDecayFunct
                        its final value. beta=15 is ~40% of the training process.
         """
         # Only work on the student (we don't want to update and to have a duplicate of the teacher)
-        super().__init__(model=core_utils.WrappedModel(kd_model.module.student), decay=decay, decay_function=decay_function)
+        super().__init__(model=kd_model.student, decay=decay, decay_function=decay_function)
 
         # Overwrite current ema attribute with combination of the student model EMA (current self.ema)
         # with already the instantiated teacher, to have the final KD EMA
-        self.ema = core_utils.WrappedModel(
-            KDModule(
-                arch_params=kd_model.module.arch_params,
-                student=self.ema.module,
-                teacher=kd_model.module.teacher,
-                run_teacher_on_eval=kd_model.module.run_teacher_on_eval,
-            )
+        self.ema = KDModule(
+            arch_params=kd_model.arch_params,
+            student=self.ema,
+            teacher=kd_model.teacher,
+            run_teacher_on_eval=kd_model.run_teacher_on_eval,
         )
20 changes: 14 additions & 6 deletions src/super_gradients/training/utils/utils.py
@@ -13,6 +13,7 @@
 from pathlib import Path
 from typing import Mapping, Optional, Tuple, Union, List, Dict, Any, Iterable
 from zipfile import ZipFile
+from torch.nn.parallel import DistributedDataParallel
 
 import numpy as np
 import torch
@@ -77,13 +78,20 @@ def validate(self):
         validate(self.__dict__, self.schema)
 
 
-class WrappedModel(nn.Module):
-    def __init__(self, module):
-        super(WrappedModel, self).__init__()
-        self.module = module  # that I actually define.
+def get_real_model(model: Union[nn.Module, nn.DataParallel, DistributedDataParallel]) -> nn.Module:
+    """
+    Get the real model from a model wrapper (DataParallel, DistributedDataParallel)
 
-    def forward(self, x):
-        return self.module(x)
+    :param model:
+    :return:
+    """
+    if isinstance(model, DistributedDataParallel):
+        return model.module
+    elif isinstance(model, nn.DataParallel):
+        return model.module
+    elif isinstance(model, nn.Module):
+        return model
+    raise ValueError(f"Unknown model type: {type(model)}")
 
 
 def arch_params_deprecated(func):
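
get_real_model replaces the old WrappedModel pattern: it is a no-op for plain modules, unwraps DataParallel/DDP, and rejects anything that is not a module. A brief usage sketch (assumes a super-gradients build containing this PR):

# Usage sketch for the helper defined above.
import torch.nn as nn
from super_gradients.training.utils.utils import get_real_model

model = nn.Linear(8, 3)
assert get_real_model(model) is model                    # plain module passes through
assert get_real_model(nn.DataParallel(model)) is model   # wrapper is unwrapped

try:
    get_real_model("not a module")
except ValueError as err:
    print(err)  # Unknown model type: <class 'str'>
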
4 changes: 2 additions & 2 deletions src/super_gradients/training/utils/weight_averaging_utils.py
@@ -2,7 +2,7 @@
 import torch
 import numpy as np
 from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict
-from super_gradients.training.utils.utils import move_state_dict_to_device
+from super_gradients.training.utils.utils import move_state_dict_to_device, get_real_model
 
 
 class ModelWeightAveraging:
@@ -63,7 +63,7 @@ def update_snapshots_dict(self, model, validation_results_tuple):
         require_update, update_ind = self._is_better(averaging_snapshots_dict, validation_results_tuple)
         if require_update:
             # moving state dict to cpu
-            new_sd = model.state_dict()
+            new_sd = get_real_model(model).state_dict()
             new_sd = move_state_dict_to_device(new_sd, "cpu")
 
             averaging_snapshots_dict["snapshot" + str(update_ind)] = new_sd