diff --git a/src/setfit/trainer_distillation.py b/src/setfit/trainer_distillation.py
index 2546f7ea..318bb401 100644
--- a/src/setfit/trainer_distillation.py
+++ b/src/setfit/trainer_distillation.py
@@ -189,8 +189,6 @@ def train(
                 distance_metric=BatchHardTripletLossDistanceFunction.cosine_distance,
                 margin=0.25,
             )
-
-            train_steps = len(train_dataloader) * self.num_epochs
         else:
             train_examples = []
 
@@ -210,19 +208,18 @@ def train(
 
             train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
             train_loss = self.loss_class(self.student_model.model_body)
-            train_steps = len(train_dataloader) * num_epochs
+        total_train_steps = len(train_dataloader) * num_epochs
 
         logger.info("***** Running training *****")
         logger.info(f"  Num examples = {len(train_examples)}")
         logger.info(f"  Num epochs = {num_epochs}")
-        logger.info(f"  Total optimization steps = {train_steps}")
+        logger.info(f"  Total optimization steps = {total_train_steps}")
         logger.info(f"  Total train batch size = {batch_size}")
 
-        warmup_steps = math.ceil(train_steps * self.warmup_proportion)
+        warmup_steps = math.ceil(total_train_steps * self.warmup_proportion)
         self.student_model.model_body.fit(
             train_objectives=[(train_dataloader, train_loss)],
             epochs=num_epochs,
-            steps_per_epoch=train_steps,
             optimizer_params={"lr": learning_rate},
             warmup_steps=warmup_steps,
             show_progress_bar=show_progress_bar,
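
Note on why the removed steps_per_epoch argument matters: sentence-transformers' SentenceTransformer.fit treats steps_per_epoch as the number of steps per epoch, defaulting to len(dataloader) when left as None. Passing the total step count, as the removed steps_per_epoch=train_steps did, therefore multiplied the effective training length by num_epochs. A minimal sketch of the arithmetic, with hypothetical sizes chosen only for illustration:

    import math

    # Hypothetical values, for illustration only.
    steps_per_epoch = 100    # len(train_dataloader)
    num_epochs = 3
    warmup_proportion = 0.1  # stands in for self.warmup_proportion

    total_train_steps = steps_per_epoch * num_epochs                 # 300
    warmup_steps = math.ceil(total_train_steps * warmup_proportion)  # 30

    # Before the patch: fit(epochs=3, steps_per_epoch=300)
    # -> sentence-transformers runs 3 * 300 = 900 optimizer steps.
    steps_before = num_epochs * total_train_steps                    # 900

    # After the patch: fit(epochs=3) with steps_per_epoch left at None
    # -> one epoch equals len(dataloader) steps: 3 * 100 = 300 steps,
    #    matching the logged "Total optimization steps".
    steps_after = num_epochs * steps_per_epoch                       # 300

    print(steps_before, steps_after, warmup_steps)  # 900 300 30

The warmup_steps computation was already based on the total step count, so after the fix the linear warmup proportion again lines up with the number of steps actually run.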