Fix squared optimization steps bug in distillation trainer (#284)
Based on the bug fix from #280
tomaarsen authored Jan 23, 2023
1 parent 29c0348 commit 0cb8ffd
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions src/setfit/trainer_distillation.py
@@ -189,8 +189,6 @@ def train(
                 distance_metric=BatchHardTripletLossDistanceFunction.cosine_distance,
                 margin=0.25,
             )
-
-            train_steps = len(train_dataloader) * self.num_epochs
         else:
             train_examples = []
 
@@ -210,19 +208,18 @@
 
             train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
             train_loss = self.loss_class(self.student_model.model_body)
-            train_steps = len(train_dataloader) * num_epochs
 
+        total_train_steps = len(train_dataloader) * num_epochs
         logger.info("***** Running training *****")
         logger.info(f"  Num examples = {len(train_examples)}")
         logger.info(f"  Num epochs = {num_epochs}")
-        logger.info(f"  Total optimization steps = {train_steps}")
+        logger.info(f"  Total optimization steps = {total_train_steps}")
         logger.info(f"  Total train batch size = {batch_size}")
 
-        warmup_steps = math.ceil(train_steps * self.warmup_proportion)
+        warmup_steps = math.ceil(total_train_steps * self.warmup_proportion)
         self.student_model.model_body.fit(
             train_objectives=[(train_dataloader, train_loss)],
             epochs=num_epochs,
-            steps_per_epoch=train_steps,
             optimizer_params={"lr": learning_rate},
             warmup_steps=warmup_steps,
             show_progress_bar=show_progress_bar,
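
For context on the bug: `SentenceTransformer.fit()` treats `steps_per_epoch` as the number of batches to draw per epoch. The old code computed `train_steps` as the total step count but then also passed it as `steps_per_epoch` alongside `epochs=num_epochs`, so the trainer ran `len(train_dataloader) * num_epochs ** 2` optimizer steps, quadratic ("squared") in the epoch count. A minimal sketch of the arithmetic, with hypothetical sizes (the `warmup_proportion` value is assumed, not taken from the commit):

import math

# Hypothetical sizes, for illustration only.
batches_per_epoch = 50   # len(train_dataloader)
num_epochs = 10

# Before the fix: the total step count was also passed to fit() as
# steps_per_epoch, so each of the num_epochs epochs replayed it.
train_steps = batches_per_epoch * num_epochs        # 500
buggy_total = train_steps * num_epochs              # 5000 optimizer steps

# After the fix: steps_per_epoch is left at its default (one pass over
# the dataloader per epoch); the total is used only for logging and for
# sizing the linear warmup.
total_train_steps = batches_per_epoch * num_epochs                  # 500
warmup_proportion = 0.1  # assumed value for self.warmup_proportion
warmup_steps = math.ceil(total_train_steps * warmup_proportion)     # 50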
