Skip to content

Commit 845c31a

Browse files
shaydeciofrimasad
authored andcommitted
added student adapter (#820)
Co-authored-by: Ofri Masad <[email protected]>
1 parent 285686e commit 845c31a

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

src/super_gradients/training/models/kd_modules/kd_module.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ class implementing Knowledge Distillation logic as an SgModule
2424
run_teacher_on_eval: bool- whether to run self.teacher at eval mode regardless of self.train(mode)
2525
arch_params: HpmStruct- Architecture H.P.
2626
27-
Additionally, by passing teacher_input_adapter (torch.nn.Module) one can modify the teacher net s.t
28-
teacher = torch.nn.Sequential(teacher_input_adapter, teacher). This is useful when teacher net expects a
29-
different input format from the student (for example different normalization).
27+
Additionally, by passing teacher_input_adapter (torch.nn.Module) one can modify the teacher net to act as if
28+
teacher = torch.nn.Sequential(teacher_input_adapter, teacher). This is useful when teacher net expects a
29+
different input format from the student (for example different normalization).
30+
Equivalent arg for the student model, can be passed through student_input_adapter.
3031
3132
"""
3233

@@ -36,6 +37,7 @@ def __init__(self, arch_params: HpmStruct, student: SgModule, teacher: torch.nn.
3637
self.student = student
3738
self.teacher = teacher
3839
self.teacher_input_adapter = get_param(self.arch_params, "teacher_input_adapter")
40+
self.student_input_adapter = get_param(self.arch_params, "student_input_adapter")
3941
self.run_teacher_on_eval = run_teacher_on_eval
4042
self._freeze_teacher()
4143

@@ -62,10 +64,17 @@ def eval(self):
6264
self.teacher.eval()
6365

6466
def forward(self, x):
67+
if self.student_input_adapter is not None:
68+
student_output = self.student(self.student_input_adapter(x))
69+
else:
70+
student_output = self.student(x)
71+
6572
if self.teacher_input_adapter is not None:
66-
return KDOutput(student_output=self.student(x), teacher_output=self.teacher(self.teacher_input_adapter(x)))
73+
teacher_output = self.teacher(self.teacher_input_adapter(x))
6774
else:
68-
return KDOutput(student_output=self.student(x), teacher_output=self.teacher(x))
75+
teacher_output = self.teacher(x)
76+
77+
return KDOutput(student_output=student_output, teacher_output=teacher_output)
6978

7079
def initialize_param_groups(self, lr: float, training_params: HpmStruct) -> list:
7180
return self.student.initialize_param_groups(lr, training_params)

0 commit comments

Comments
 (0)