Skip to content

Commit 1fa5f59

Browse files
authored
Bug/sg 000 update kd train from config (#638)
* black
* black on kdtrainer
1 parent 027c0a8 commit 1fa5f59

File tree

1 file changed

+28
-9
lines changed

1 file changed

+28
-9
lines changed

src/super_gradients/training/kd_trainer/kd_trainer.py

+28-9
Original file line number · Diff line number · Diff line change
@@ -1,15 +1,16 @@
11
import hydra
22
import torch.nn
3-
from omegaconf import DictConfig
3+
from omegaconf import DictConfig, OmegaConf
44
from torch.utils.data import DataLoader
55

6+
from super_gradients.training.utils.distributed_training_utils import setup_device
67
from super_gradients.common import MultiGPUMode
78
from super_gradients.training.dataloaders import dataloaders
89
from super_gradients.training.models import SgModule
910
from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
1011
from super_gradients.training.models.kd_modules.kd_module import KDModule
1112
from super_gradients.training.sg_trainer import Trainer
12-
from typing import Union
13+
from typing import Union, Dict
1314
from super_gradients.common.abstractions.abstract_logger import get_logger
1415
from super_gradients.training import utils as core_utils, models
1516
from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
@@ -25,7 +26,6 @@
2526
)
2627
from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
2728
from super_gradients.training.utils.ema import KDModelEMA
28-
from super_gradients.training.utils.sg_trainer_utils import parse_args
2929

3030
logger = get_logger(__name__)
3131

@@ -47,11 +47,16 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
4747
@return: output of kd_trainer.train(...) (i.e results tuple)
4848
"""
4949
# INSTANTIATE ALL OBJECTS IN CFG
50-
cfg = hydra.utils.instantiate(cfg)
50+
setup_device(
51+
device=core_utils.get_param(cfg, "device"),
52+
multi_gpu=core_utils.get_param(cfg, "multi_gpu"),
53+
num_gpus=core_utils.get_param(cfg, "num_gpus"),
54+
)
5155

52-
kwargs = parse_args(cfg, cls.__init__)
56+
# INSTANTIATE ALL OBJECTS IN CFG
57+
cfg = hydra.utils.instantiate(cfg)
5358

54-
trainer = KDTrainer(**kwargs)
59+
trainer = KDTrainer(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir)
5560

5661
# INSTANTIATE DATA LOADERS
5762
train_dataloader = dataloaders.get(
@@ -80,6 +85,8 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
8085
load_backbone=cfg.teacher_checkpoint_params.load_backbone,
8186
)
8287

88+
recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
89+
8390
# TRAIN
8491
trainer.train(
8592
training_params=cfg.training_hyperparams,
@@ -90,6 +97,7 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
9097
run_teacher_on_eval=cfg.run_teacher_on_eval,
9198
train_loader=train_dataloader,
9299
valid_loader=val_dataloader,
100+
additional_configs_to_log=recipe_logged_cfg,
93101
)
94102

95103
def _validate_args(self, arch_params, architecture, checkpoint_params, **kwargs):
@@ -275,20 +283,22 @@ def _save_best_checkpoint(self, epoch, state):
275283
def train(
276284
self,
277285
model: KDModule = None,
278-
training_params: dict = dict(),
286+
training_params: Dict = None,
279287
student: SgModule = None,
280288
teacher: torch.nn.Module = None,
281289
kd_architecture: Union[KDModule.__class__, str] = "kd_module",
282-
kd_arch_params: dict = dict(),
290+
kd_arch_params: Dict = None,
283291
run_teacher_on_eval=False,
284292
train_loader: DataLoader = None,
285293
valid_loader: DataLoader = None,
294+
additional_configs_to_log: Dict = None,
286295
*args,
287296
**kwargs,
288297
):
289298
"""
290299
Trains the student network (wrapped in KDModule network).
291300
301+
292302
:param model: KDModule, network to train. When none is given will initialize KDModule according to kd_architecture,
293303
student and teacher (default=None)
294304
:param training_params: dict, Same as in Trainer.train()
@@ -299,12 +309,21 @@ def train(
299309
:param run_teacher_on_eval: bool- whether to run self.teacher at eval mode regardless of self.train(mode)
300310
:param train_loader: Dataloader for train set.
301311
:param valid_loader: Dataloader for validation.
312+
:param additional_configs_to_log: Dict, dictionary containing configs that will be added to the training's
313+
sg_logger. Format should be {"Config_title_1": {...}, "Config_title_2":{..}}, (optional, default=None)
302314
"""
303315
kd_net = self.net or model
316+
kd_arch_params = kd_arch_params or dict()
304317
if kd_net is None:
305318
if student is None or teacher is None:
306319
raise ValueError("Must pass student and teacher models or net (KDModule).")
307320
kd_net = self._instantiate_kd_net(
308321
arch_params=HpmStruct(**kd_arch_params), architecture=kd_architecture, run_teacher_on_eval=run_teacher_on_eval, student=student, teacher=teacher
309322
)
310-
super(KDTrainer, self).train(model=kd_net, training_params=training_params, train_loader=train_loader, valid_loader=valid_loader)
323+
super(KDTrainer, self).train(
324+
model=kd_net,
325+
training_params=training_params,
326+
train_loader=train_loader,
327+
valid_loader=valid_loader,
328+
additional_configs_to_log=additional_configs_to_log,
329+
)

0 commit comments

Comments (0)