Bug/sg 000 update kd train from config #638

Merged · 3 commits · Jan 19, 2023
37 changes: 28 additions & 9 deletions src/super_gradients/training/kd_trainer/kd_trainer.py
@@ -1,15 +1,16 @@
import hydra
import torch.nn
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader

+from super_gradients.training.utils.distributed_training_utils import setup_device
from super_gradients.common import MultiGPUMode
from super_gradients.training.dataloaders import dataloaders
from super_gradients.training.models import SgModule
from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
from super_gradients.training.models.kd_modules.kd_module import KDModule
from super_gradients.training.sg_trainer import Trainer
-from typing import Union
+from typing import Union, Dict
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training import utils as core_utils, models
from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
@@ -25,7 +26,6 @@
)
from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
from super_gradients.training.utils.ema import KDModelEMA
-from super_gradients.training.utils.sg_trainer_utils import parse_args

logger = get_logger(__name__)

@@ -47,11 +47,16 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
@return: output of kd_trainer.train(...) (i.e results tuple)
"""
-# INSTANTIATE ALL OBJECTS IN CFG
-cfg = hydra.utils.instantiate(cfg)
+setup_device(
+device=core_utils.get_param(cfg, "device"),
+multi_gpu=core_utils.get_param(cfg, "multi_gpu"),
+num_gpus=core_utils.get_param(cfg, "num_gpus"),
+)

-kwargs = parse_args(cfg, cls.__init__)
+# INSTANTIATE ALL OBJECTS IN CFG
+cfg = hydra.utils.instantiate(cfg)

-trainer = KDTrainer(**kwargs)
+trainer = KDTrainer(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir)

# INSTANTIATE DATA LOADERS
train_dataloader = dataloaders.get(
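
For context, train_from_config is typically driven by a Hydra-composed recipe, which is why setup_device can read device, multi_gpu and num_gpus from the raw config before anything is instantiated. A minimal sketch of such a launcher script, assuming a hypothetical config_path and recipe name (neither is part of this PR):

import hydra
from omegaconf import DictConfig

from super_gradients.training.kd_trainer.kd_trainer import KDTrainer


@hydra.main(config_path="recipes", config_name="my_kd_recipe")  # hypothetical recipe location
def main(cfg: DictConfig) -> None:
    # Hand the raw DictConfig to the trainer; instantiation happens inside train_from_config.
    KDTrainer.train_from_config(cfg)


if __name__ == "__main__":
    main()
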
@@ -80,6 +85,8 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
load_backbone=cfg.teacher_checkpoint_params.load_backbone,
)

+recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}

# TRAIN
trainer.train(
training_params=cfg.training_hyperparams,
@@ -90,6 +97,7 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
run_teacher_on_eval=cfg.run_teacher_on_eval,
train_loader=train_dataloader,
valid_loader=val_dataloader,
+additional_configs_to_log=recipe_logged_cfg,
)

def _validate_args(self, arch_params, architecture, checkpoint_params, **kwargs):
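
For reference, OmegaConf.to_container(cfg, resolve=True) converts the DictConfig into plain Python dicts and lists with interpolations resolved, so the recipe can be serialized by the sg_logger. A small standalone sketch, with keys made up purely for illustration:

from omegaconf import OmegaConf

# Toy config standing in for a full training recipe.
cfg = OmegaConf.create({"epochs": 10, "experiment_name": "kd_example", "training_hyperparams": {"max_epochs": "${epochs}"}})

# resolve=True expands interpolations such as ${epochs} before logging.
recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
print(recipe_logged_cfg)
# {'recipe_config': {'epochs': 10, 'experiment_name': 'kd_example', 'training_hyperparams': {'max_epochs': 10}}}
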
@@ -275,20 +283,22 @@ def _save_best_checkpoint(self, epoch, state):
def train(
self,
model: KDModule = None,
-training_params: dict = dict(),
+training_params: Dict = None,
student: SgModule = None,
teacher: torch.nn.Module = None,
kd_architecture: Union[KDModule.__class__, str] = "kd_module",
-kd_arch_params: dict = dict(),
+kd_arch_params: Dict = None,
run_teacher_on_eval=False,
train_loader: DataLoader = None,
valid_loader: DataLoader = None,
+additional_configs_to_log: Dict = None,
*args,
**kwargs,
):
"""
Trains the student network (wrapped in KDModule network).


:param model: KDModule, network to train. When none is given will initialize KDModule according to kd_architecture,
student and teacher (default=None)
:param training_params: dict, Same as in Trainer.train()
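
As background on the signature change above (this is general Python behavior, not something introduced by this PR): a mutable default such as dict() is created once at function-definition time and shared across every call, which is why None plus an `or dict()` fallback in the body is the safer pattern. A small sketch:

def counts_calls(params: dict = dict()):
    # The same dict object is reused on every call that relies on the default.
    params["calls"] = params.get("calls", 0) + 1
    return params

print(counts_calls())  # {'calls': 1}
print(counts_calls())  # {'calls': 2}  <- state leaked between calls

def counts_calls_safe(params: dict = None):
    params = params or dict()
    params["calls"] = params.get("calls", 0) + 1
    return params

print(counts_calls_safe())  # {'calls': 1}
print(counts_calls_safe())  # {'calls': 1}
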
@@ -299,12 +309,21 @@ def train(
:param run_teacher_on_eval: bool- whether to run self.teacher at eval mode regardless of self.train(mode)
:param train_loader: Dataloader for train set.
:param valid_loader: Dataloader for validation.
+:param additional_configs_to_log: Dict, dictionary containing configs that will be added to the training's
+sg_logger. Format should be {"Config_title_1": {...}, "Config_title_2":{..}}, (optional, default=None)
"""
kd_net = self.net or model
+kd_arch_params = kd_arch_params or dict()
if kd_net is None:
if student is None or teacher is None:
raise ValueError("Must pass student and teacher models or net (KDModule).")
kd_net = self._instantiate_kd_net(
arch_params=HpmStruct(**kd_arch_params), architecture=kd_architecture, run_teacher_on_eval=run_teacher_on_eval, student=student, teacher=teacher
)
-super(KDTrainer, self).train(model=kd_net, training_params=training_params, train_loader=train_loader, valid_loader=valid_loader)
+super(KDTrainer, self).train(
+model=kd_net,
+training_params=training_params,
+train_loader=train_loader,
+valid_loader=valid_loader,
+additional_configs_to_log=additional_configs_to_log,
+)
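
Finally, a hedged sketch of how the updated train() signature could be called directly; the student and teacher models, hyperparameter dict, and dataloaders below are placeholders rather than anything defined in this PR:

trainer = KDTrainer(experiment_name="kd_example")
trainer.train(
    student=student_model,                  # placeholder SgModule
    teacher=teacher_model,                  # placeholder torch.nn.Module
    training_params=training_hyperparams,   # same structure as in Trainer.train()
    train_loader=train_loader,
    valid_loader=valid_loader,
    # Extra configs are logged under their title, per the docstring format.
    additional_configs_to_log={"distillation_setup": {"temperature": 4.0, "alpha": 0.5}},
)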