Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/huggingface/setfit into refactor_v2
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen committed Jan 23, 2023
2 parents 9fc55a6 + 0cb8ffd commit 7d4ad00
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions src/setfit/trainer_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ def train_embeddings(
 distance_metric=args.distance_metric,
 margin=args.margin,
 )
-
-train_steps = len(train_dataloader) * args.embedding_num_epochs
 else:
 train_examples = []

Expand All @@ -144,19 +142,18 @@ def train_embeddings(
 batch_size = args.embedding_batch_size
 train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
 train_loss = self.loss_class(self.student_model.model_body)
-train_steps = len(train_dataloader) * args.embedding_num_epochs
 
+total_train_steps = len(train_dataloader) * args.embedding_num_epochs
 logger.info("***** Running training *****")
 logger.info(f" Num examples = {len(train_examples)}")
 logger.info(f" Num epochs = {args.embedding_num_epochs}")
-logger.info(f" Total optimization steps = {train_steps}")
+logger.info(f" Total optimization steps = {total_train_steps}")
 logger.info(f" Total train batch size = {batch_size}")
 
-warmup_steps = math.ceil(train_steps * args.warmup_proportion)
+warmup_steps = math.ceil(total_train_steps * args.warmup_proportion)
 self.student_model.model_body.fit(
 train_objectives=[(train_dataloader, train_loss)],
 epochs=args.embedding_num_epochs,
-steps_per_epoch=train_steps,
 optimizer_params={"lr": args.body_embedding_learning_rate},
 warmup_steps=warmup_steps,
 show_progress_bar=args.show_progress_bar,
Expand Down

0 comments on commit 7d4ad00

Please sign in to comment.