added student adapter #820

Merged · 5 commits · Apr 13, 2023
19 changes: 14 additions & 5 deletions src/super_gradients/training/models/kd_modules/kd_module.py
@@ -24,9 +24,10 @@ class implementing Knowledge Distillation logic as an SgModule
     run_teacher_on_eval: bool- whether to run self.teacher at eval mode regardless of self.train(mode)
     arch_params: HpmStruct- Architecture H.P.
 
-    Additionally, by passing teacher_input_adapter (torch.nn.Module) one can modify the teacher net s.t
-    teacher = torch.nn.Sequential(teacher_input_adapter, teacher). This is useful when teacher net expects a
-    different input format from the student (for example different normalization).
+    Additionally, by passing teacher_input_adapter (torch.nn.Module) one can modify the teacher net to act as if
+    teacher = torch.nn.Sequential(teacher_input_adapter, teacher). This is useful when the teacher net expects a
+    different input format from the student (for example, different normalization).
+    An equivalent argument for the student model can be passed through student_input_adapter.
 
     """

@@ -36,6 +37,7 @@ def __init__(self, arch_params: HpmStruct, student: SgModule, teacher: torch.nn.
         self.student = student
         self.teacher = teacher
         self.teacher_input_adapter = get_param(self.arch_params, "teacher_input_adapter")
+        self.student_input_adapter = get_param(self.arch_params, "student_input_adapter")
         self.run_teacher_on_eval = run_teacher_on_eval
         self._freeze_teacher()

@@ -62,10 +64,17 @@ def eval(self):
         self.teacher.eval()
 
     def forward(self, x):
+        if self.student_input_adapter is not None:
+            student_output = self.student(self.student_input_adapter(x))
+        else:
+            student_output = self.student(x)
+
         if self.teacher_input_adapter is not None:
-            return KDOutput(student_output=self.student(x), teacher_output=self.teacher(self.teacher_input_adapter(x)))
+            teacher_output = self.teacher(self.teacher_input_adapter(x))
         else:
-            return KDOutput(student_output=self.student(x), teacher_output=self.teacher(x))
+            teacher_output = self.teacher(x)
+
+        return KDOutput(student_output=student_output, teacher_output=teacher_output)
 
     def initialize_param_groups(self, lr: float, training_params: HpmStruct) -> list:
         return self.student.initialize_param_groups(lr, training_params)
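
Since `get_param` falls back to `None` when a key is absent (assuming its usual default), both adapters are optional and the merged `forward` checks each one independently. Below is a self-contained sketch of that routing with stub modules standing in for the student and teacher; none of these names come from the PR:

```python
import torch
import torch.nn as nn


class AddConstant(nn.Module):
    # Stub adapter: shifts the input by a constant so the output reveals
    # which adapter (if any) a branch passed through.
    def __init__(self, c):
        super().__init__()
        self.c = c

    def forward(self, x):
        return x + self.c


def kd_forward(student, teacher, x, student_adapter=None, teacher_adapter=None):
    # Mirrors the merged forward(): each net gets its own (optionally
    # adapted) view of the same batch, and both outputs are returned together.
    student_output = student(student_adapter(x) if student_adapter is not None else x)
    teacher_output = teacher(teacher_adapter(x) if teacher_adapter is not None else x)
    return student_output, teacher_output


x = torch.zeros(2, 3)
s_out, t_out = kd_forward(
    nn.Identity(), nn.Identity(), x,
    student_adapter=AddConstant(1.0), teacher_adapter=AddConstant(2.0),
)
assert torch.all(s_out == 1.0) and torch.all(t_out == 2.0)  # each branch adapted independently
```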