@@ -1,15 +1,16 @@
 import hydra
 import torch.nn
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 from torch.utils.data import DataLoader
 
+from super_gradients.training.utils.distributed_training_utils import setup_device
 from super_gradients.common import MultiGPUMode
 from super_gradients.training.dataloaders import dataloaders
 from super_gradients.training.models import SgModule
 from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
 from super_gradients.training.models.kd_modules.kd_module import KDModule
 from super_gradients.training.sg_trainer import Trainer
-from typing import Union
+from typing import Union, Dict
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.training import utils as core_utils, models
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
@@ -25,7 +26,6 @@
 )
 from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
 from super_gradients.training.utils.ema import KDModelEMA
-from super_gradients.training.utils.sg_trainer_utils import parse_args
 
 logger = get_logger(__name__)
 
@@ -47,11 +47,16 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
         @return: output of kd_trainer.train(...) (i.e results tuple)
         """
         # INSTANTIATE ALL OBJECTS IN CFG
-        cfg = hydra.utils.instantiate(cfg)
+        setup_device(
+            device=core_utils.get_param(cfg, "device"),
+            multi_gpu=core_utils.get_param(cfg, "multi_gpu"),
+            num_gpus=core_utils.get_param(cfg, "num_gpus"),
+        )
 
-        kwargs = parse_args(cfg, cls.__init__)
+        # INSTANTIATE ALL OBJECTS IN CFG
+        cfg = hydra.utils.instantiate(cfg)
 
-        trainer = KDTrainer(**kwargs)
+        trainer = KDTrainer(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir)
 
         # INSTANTIATE DATA LOADERS
         train_dataloader = dataloaders.get(
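
The hunk above reorders `train_from_config`: `setup_device` now resolves the device / multi-GPU settings before `hydra.utils.instantiate` runs, and the trainer is built straight from `cfg.experiment_name` / `cfg.ckpt_root_dir` instead of going through `parse_args`. A minimal sketch of launching this entrypoint; the recipe name, config path, and exact import path are illustrative assumptions, not taken from this diff:

```python
import hydra
from omegaconf import DictConfig

from super_gradients.training.kd_trainer import KDTrainer  # import path may differ by version


@hydra.main(config_path="recipes", config_name="imagenet_resnet50_kd")  # hypothetical recipe
def main(cfg: DictConfig) -> None:
    # Device selection (cpu/cuda, DP/DDP, num_gpus) is handled inside
    # train_from_config via setup_device(), so no manual .to(device) is needed here.
    KDTrainer.train_from_config(cfg)


if __name__ == "__main__":
    main()
```
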
@@ -80,6 +85,8 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
             load_backbone=cfg.teacher_checkpoint_params.load_backbone,
         )
 
+        recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
+
         # TRAIN
         trainer.train(
             training_params=cfg.training_hyperparams,
@@ -90,6 +97,7 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
             run_teacher_on_eval=cfg.run_teacher_on_eval,
             train_loader=train_dataloader,
             valid_loader=val_dataloader,
+            additional_configs_to_log=recipe_logged_cfg,
         )
 
     def _validate_args(self, arch_params, architecture, checkpoint_params, **kwargs):
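
With the two hunks above, the fully resolved recipe is logged alongside the run: `OmegaConf.to_container(cfg, resolve=True)` converts the `DictConfig` into plain Python dicts/lists with every `${...}` interpolation resolved, and the result is forwarded via `additional_configs_to_log`. A small standalone sketch of that conversion (toy config, not from this diff):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"lr": 0.1, "warmup_lr": "${lr}"})  # toy config for illustration
recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
print(recipe_logged_cfg)  # {'recipe_config': {'lr': 0.1, 'warmup_lr': 0.1}}
```
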
@@ -275,20 +283,22 @@ def _save_best_checkpoint(self, epoch, state):
     def train(
         self,
         model: KDModule = None,
-        training_params: dict = dict(),
+        training_params: Dict = None,
         student: SgModule = None,
         teacher: torch.nn.Module = None,
         kd_architecture: Union[KDModule.__class__, str] = "kd_module",
-        kd_arch_params: dict = dict(),
+        kd_arch_params: Dict = None,
         run_teacher_on_eval=False,
         train_loader: DataLoader = None,
         valid_loader: DataLoader = None,
+        additional_configs_to_log: Dict = None,
         *args,
         **kwargs,
     ):
         """
         Trains the student network (wrapped in KDModule network).
 
+
         :param model: KDModule, network to train. When none is given will initialize KDModule according to kd_architecture,
             student and teacher (default=None)
         :param training_params: dict, Same as in Trainer.train()
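
The signature above swaps the mutable defaults (`dict = dict()`) for `Dict = None`, paired with `kd_arch_params = kd_arch_params or dict()` in the body below. This sidesteps Python's shared-mutable-default pitfall; a toy illustration (names are mine, not from this diff):

```python
def counter(acc: dict = dict()):  # the default dict is created once, at definition time
    acc["n"] = acc.get("n", 0) + 1
    return acc

print(counter())  # {'n': 1}
print(counter())  # {'n': 2} -- state leaks across calls through the shared default

def counter_fixed(acc: dict = None):
    acc = acc or dict()  # fresh dict per call, the same pattern used in this diff
    acc["n"] = acc.get("n", 0) + 1
    return acc

print(counter_fixed())  # {'n': 1}
print(counter_fixed())  # {'n': 1}
```
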
@@ -299,12 +309,21 @@ def train(
         :param run_teacher_on_eval: bool- whether to run self.teacher at eval mode regardless of self.train(mode)
         :param train_loader: Dataloader for train set.
         :param valid_loader: Dataloader for validation.
+        :param additional_configs_to_log: Dict, dictionary containing configs that will be added to the training's
+            sg_logger. Format should be {"Config_title_1": {...}, "Config_title_2":{..}}, (optional, default=None)
         """
         kd_net = self.net or model
+        kd_arch_params = kd_arch_params or dict()
         if kd_net is None:
             if student is None or teacher is None:
                 raise ValueError("Must pass student and teacher models or net (KDModule).")
             kd_net = self._instantiate_kd_net(
                 arch_params=HpmStruct(**kd_arch_params), architecture=kd_architecture, run_teacher_on_eval=run_teacher_on_eval, student=student, teacher=teacher
             )
-        super(KDTrainer, self).train(model=kd_net, training_params=training_params, train_loader=train_loader, valid_loader=valid_loader)
+        super(KDTrainer, self).train(
+            model=kd_net,
+            training_params=training_params,
+            train_loader=train_loader,
+            valid_loader=valid_loader,
+            additional_configs_to_log=additional_configs_to_log,
+        )
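
End to end, `train()` now forwards `additional_configs_to_log` to `Trainer.train()`, which hands it to the run's sg_logger. A hedged usage sketch for calling it directly; every variable here is a placeholder, not defined in this diff:

```python
trainer = KDTrainer(experiment_name="kd_example")  # illustrative name; ckpt_root_dir left to defaults
trainer.train(
    student=student_model,                 # SgModule instance (placeholder)
    teacher=teacher_model,                 # torch.nn.Module instance (placeholder)
    training_params=training_hyperparams,  # same schema as Trainer.train()
    train_loader=train_loader,
    valid_loader=valid_loader,
    additional_configs_to_log={"distillation_setup": {"temperature": 4.0}},
)
```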