From 37e474e1e60cf705e4d91dbe7c1248b234efc5c0 Mon Sep 17 00:00:00 2001
From: shayaharon
Date: Thu, 19 Jan 2023 11:41:36 +0200
Subject: [PATCH 1/2] black

---
 .../training/kd_trainer/kd_trainer.py | 32 ++++++++++++++-----
 1 file changed, 24 insertions(+), 8 deletions(-)

diff --git a/src/super_gradients/training/kd_trainer/kd_trainer.py b/src/super_gradients/training/kd_trainer/kd_trainer.py
index 33680d03e6..fa32e506d6 100644
--- a/src/super_gradients/training/kd_trainer/kd_trainer.py
+++ b/src/super_gradients/training/kd_trainer/kd_trainer.py
@@ -3,13 +3,14 @@
 from omegaconf import DictConfig
 from torch.utils.data import DataLoader
 
+from super_gradients.training.utils.distributed_training_utils import setup_device
 from super_gradients.common import MultiGPUMode
 from super_gradients.training.dataloaders import dataloaders
 from super_gradients.training.models import SgModule
 from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
 from super_gradients.training.models.kd_modules.kd_module import KDModule
 from super_gradients.training.sg_trainer import Trainer
-from typing import Union
+from typing import Union, Dict
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.training import utils as core_utils, models
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
@@ -25,7 +26,6 @@
 )
 from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
 from super_gradients.training.utils.ema import KDModelEMA
-from super_gradients.training.utils.sg_trainer_utils import parse_args
 
 logger = get_logger(__name__)
 
@@ -47,11 +47,16 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
         @return: output of kd_trainer.train(...) (i.e results tuple)
         """
         # INSTANTIATE ALL OBJECTS IN CFG
-        cfg = hydra.utils.instantiate(cfg)
+        setup_device(
+            device=core_utils.get_param(cfg, "device"),
+            multi_gpu=core_utils.get_param(cfg, "multi_gpu"),
+            num_gpus=core_utils.get_param(cfg, "num_gpus"),
+        )
 
-        kwargs = parse_args(cfg, cls.__init__)
+        # INSTANTIATE ALL OBJECTS IN CFG
+        cfg = hydra.utils.instantiate(cfg)
 
-        trainer = KDTrainer(**kwargs)
+        trainer = KDTrainer(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir)
 
         # INSTANTIATE DATA LOADERS
         train_dataloader = dataloaders.get(
@@ -275,20 +280,22 @@ def _save_best_checkpoint(self, epoch, state):
     def train(
         self,
         model: KDModule = None,
-        training_params: dict = dict(),
+        training_params: Dict = None,
         student: SgModule = None,
         teacher: torch.nn.Module = None,
         kd_architecture: Union[KDModule.__class__, str] = "kd_module",
-        kd_arch_params: dict = dict(),
+        kd_arch_params: Dict = None,
         run_teacher_on_eval=False,
         train_loader: DataLoader = None,
         valid_loader: DataLoader = None,
+        additional_configs_to_log: Dict = None,
         *args,
         **kwargs,
     ):
         """
         Trains the student network (wrapped in KDModule network).
 
+
         :param model: KDModule, network to train. When none is given will initialize KDModule according to kd_architecture,
             student and teacher (default=None)
         :param training_params: dict, Same as in Trainer.train()
@@ -299,12 +306,21 @@ def train(
         :param run_teacher_on_eval: bool- whether to run self.teacher at eval mode regardless of self.train(mode)
         :param train_loader: Dataloader for train set.
         :param valid_loader: Dataloader for validation.
+        :param additional_configs_to_log: Dict, dictionary containing configs that will be added to the training's
+            sg_logger. Format should be {"Config_title_1": {...}, "Config_title_2":{..}}, (optional, default=None)
         """
         kd_net = self.net or model
+        kd_arch_params = kd_arch_params or dict()
         if kd_net is None:
             if student is None or teacher is None:
                 raise ValueError("Must pass student and teacher models or net (KDModule).")
             kd_net = self._instantiate_kd_net(
                 arch_params=HpmStruct(**kd_arch_params), architecture=kd_architecture, run_teacher_on_eval=run_teacher_on_eval, student=student, teacher=teacher
             )
-        super(KDTrainer, self).train(model=kd_net, training_params=training_params, train_loader=train_loader, valid_loader=valid_loader)
+        super(KDTrainer, self).train(
+            model=kd_net,
+            training_params=training_params,
+            train_loader=train_loader,
+            valid_loader=valid_loader,
+            additional_configs_to_log=additional_configs_to_log,
+        )

From ee97704e79dd7c74a84e88e0d25218fac57f8557 Mon Sep 17 00:00:00 2001
From: shayaharon
Date: Thu, 19 Jan 2023 11:52:39 +0200
Subject: [PATCH 2/2] black on kdtrainer

---
 src/super_gradients/training/kd_trainer/kd_trainer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/super_gradients/training/kd_trainer/kd_trainer.py b/src/super_gradients/training/kd_trainer/kd_trainer.py
index fa32e506d6..ab003fcac5 100644
--- a/src/super_gradients/training/kd_trainer/kd_trainer.py
+++ b/src/super_gradients/training/kd_trainer/kd_trainer.py
@@ -1,6 +1,6 @@
 import hydra
 import torch.nn
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 from torch.utils.data import DataLoader
 
 from super_gradients.training.utils.distributed_training_utils import setup_device
@@ -85,6 +85,8 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
             load_backbone=cfg.teacher_checkpoint_params.load_backbone,
         )
 
+        recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
+
         # TRAIN
         trainer.train(
             training_params=cfg.training_hyperparams,
@@ -95,6 +97,7 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
             run_teacher_on_eval=cfg.run_teacher_on_eval,
             train_loader=train_dataloader,
             valid_loader=val_dataloader,
+            additional_configs_to_log=recipe_logged_cfg,
         )
 
     def _validate_args(self, arch_params, architecture, checkpoint_params, **kwargs):
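
Reviewer note (not part of the patches): below is a minimal, self-contained sketch of the additional_configs_to_log flow that PATCH 2/2 wires up. Only the OmegaConf.to_container call and the {"Config_title": {...}} shape come from the diffs above; the config values and the extra "Hardware_notes" section are illustrative assumptions, not part of this change.

    from omegaconf import OmegaConf

    # Stand-in for the Hydra-instantiated recipe cfg used in train_from_config.
    cfg = OmegaConf.create({"experiment_name": "kd_demo", "initial_lr": 0.1})

    # PATCH 2/2 wraps the resolved recipe under a single config title before logging:
    recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}

    # Per the new docstring, additional_configs_to_log expects
    # {"Config_title_1": {...}, "Config_title_2": {...}} -- one dict per titled section.
    additional_configs_to_log = {
        **recipe_logged_cfg,
        "Hardware_notes": {"num_gpus": 2},  # hypothetical extra section, for illustration only
    }

    # KDTrainer.train(...) now forwards this dict to Trainer.train(...), which
    # attaches it to the run's sg_logger:
    # trainer.train(..., additional_configs_to_log=additional_configs_to_log)

Wrapping the whole resolved recipe under one "recipe_config" title keeps the logged experiment configuration reproducible as a unit, rather than flattening recipe keys into the top level of the log.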