From 59fb789e8faf6ea1ef3e0fd97727cb48035e982a Mon Sep 17 00:00:00 2001
From: shayaharon
Date: Tue, 4 Apr 2023 12:08:33 +0300
Subject: [PATCH] added student adapter

---
 .../training/models/kd_modules/kd_module.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/src/super_gradients/training/models/kd_modules/kd_module.py b/src/super_gradients/training/models/kd_modules/kd_module.py
index 601dd9359f..2bf674af90 100644
--- a/src/super_gradients/training/models/kd_modules/kd_module.py
+++ b/src/super_gradients/training/models/kd_modules/kd_module.py
@@ -24,9 +24,10 @@ class implementing Knowledge Distillation logic as an SgModule
 
     run_teacher_on_eval: bool- whether to run self.teacher at eval mode regardless of self.train(mode)
     arch_params: HpmStruct- Architecture H.P.
-    Additionally, by passing teacher_input_adapter (torch.nn.Module) one can modify the teacher net s.t
-    teacher = torch.nn.Sequential(teacher_input_adapter, teacher). This is useful when teacher net expects a
-    different input format from the student (for example different normalization).
+    Additionally, by passing teacher_input_adapter (torch.nn.Module) one can modify the teacher net to act as if
+    teacher = torch.nn.Sequential(teacher_input_adapter, teacher). This is useful when the teacher net expects a
+    different input format from the student (for example, different normalization).
+    An equivalent argument for the student model can be passed through student_input_adapter.
 
     """
 
@@ -36,6 +37,7 @@ def __init__(self, arch_params: HpmStruct, student: SgModule, teacher: torch.nn.
         self.student = student
         self.teacher = teacher
         self.teacher_input_adapter = get_param(self.arch_params, "teacher_input_adapter")
+        self.student_input_adapter = get_param(self.arch_params, "student_input_adapter")
         self.run_teacher_on_eval = run_teacher_on_eval
 
         self._freeze_teacher()
@@ -62,10 +64,17 @@ def eval(self):
         self.teacher.eval()
 
     def forward(self, x):
+        if self.student_input_adapter is not None:
+            student_output = self.student(self.student_input_adapter(x))
+        else:
+            student_output = self.student(x)
+
         if self.teacher_input_adapter is not None:
-            return KDOutput(student_output=self.student(x), teacher_output=self.teacher(self.teacher_input_adapter(x)))
+            teacher_output = self.teacher(self.teacher_input_adapter(x))
         else:
-            return KDOutput(student_output=self.student(x), teacher_output=self.teacher(x))
+            teacher_output = self.teacher(x)
+
+        return KDOutput(student_output=student_output, teacher_output=teacher_output)
 
     def initialize_param_groups(self, lr: float, training_params: HpmStruct) -> list:
         return self.student.initialize_param_groups(lr, training_params)
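
A minimal usage sketch of the new student_input_adapter hook, for context. The Renormalize adapter and the placeholder student/teacher networks below are hypothetical, and the HpmStruct import path is assumed; only the "student_input_adapter" / "teacher_input_adapter" arch_params keys come from this patch.

    # Usage sketch: hypothetical adapter and placeholder models; in real use
    # `student` would be an SgModule and `teacher` a trained torch.nn.Module.
    import torch
    from super_gradients.training.models.kd_modules.kd_module import KDModule
    from super_gradients.training.utils import HpmStruct


    class Renormalize(torch.nn.Module):
        """Hypothetical adapter: re-normalize inputs preprocessed with the
        teacher's statistics so they match what the student was trained on."""

        def __init__(self, mean, std):
            super().__init__()
            self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
            self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

        def forward(self, x):
            return (x - self.mean) / self.std


    # Placeholder networks, just so the sketch runs end to end.
    student = torch.nn.Conv2d(3, 10, kernel_size=3)
    teacher = torch.nn.Conv2d(3, 10, kernel_size=3)

    arch_params = HpmStruct(
        student_input_adapter=Renormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    )
    kd = KDModule(arch_params=arch_params, student=student, teacher=teacher)

    out = kd(torch.randn(2, 3, 32, 32))
    # out is a KDOutput: out.student_output was computed from the adapted
    # input, out.teacher_output from the raw input.

One likely reason for passing the adapter through arch_params instead of wrapping the student in torch.nn.Sequential yourself: methods such as initialize_param_groups delegate to self.student, and a plain Sequential wrapper would not expose the SgModule interface.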