@@ -24,9 +24,10 @@ class implementing Knowledge Distillation logic as an SgModule
24
24
run_teacher_on_eval: bool- whether to run self.teacher at eval mode regardless of self.train(mode)
25
25
arch_params: HpmStruct- Architecture H.P.
26
26
27
- Additionally, by passing teacher_input_adapter (torch.nn.Module) one can modify the teacher net s.t
28
- teacher = torch.nn.Sequential(teacher_input_adapter, teacher). This is useful when teacher net expects a
29
- different input format from the student (for example different normalization).
27
+ Additionally, by passing teacher_input_adapter (torch.nn.Module) one can modify the teacher net to act as if
28
+ teacher = torch.nn.Sequential(teacher_input_adapter, teacher). This is useful when teacher net expects a
29
+ different input format from the student (for example different normalization).
30
+ Equivalent arg for the student model, can be passed through student_input_adapter.
30
31
31
32
"""
32
33
@@ -36,6 +37,7 @@ def __init__(self, arch_params: HpmStruct, student: SgModule, teacher: torch.nn.
36
37
self .student = student
37
38
self .teacher = teacher
38
39
self .teacher_input_adapter = get_param (self .arch_params , "teacher_input_adapter" )
40
+ self .student_input_adapter = get_param (self .arch_params , "student_input_adapter" )
39
41
self .run_teacher_on_eval = run_teacher_on_eval
40
42
self ._freeze_teacher ()
41
43
@@ -62,10 +64,17 @@ def eval(self):
62
64
self .teacher .eval ()
63
65
64
66
def forward (self , x ):
67
+ if self .student_input_adapter is not None :
68
+ student_output = self .student (self .student_input_adapter (x ))
69
+ else :
70
+ student_output = self .student (x )
71
+
65
72
if self .teacher_input_adapter is not None :
66
- return KDOutput ( student_output = self . student ( x ), teacher_output = self .teacher (self .teacher_input_adapter (x ) ))
73
+ teacher_output = self .teacher (self .teacher_input_adapter (x ))
67
74
else :
68
- return KDOutput (student_output = self .student (x ), teacher_output = self .teacher (x ))
75
+ teacher_output = self .teacher (x )
76
+
77
+ return KDOutput (student_output = student_output , teacher_output = teacher_output )
69
78
70
79
def initialize_param_groups (self , lr : float , training_params : HpmStruct ) -> list :
71
80
return self .student .initialize_param_groups (lr , training_params )
0 commit comments