Allow setting max length #176

Merged
12 commits merged on Dec 7, 2022
24 changes: 21 additions & 3 deletions src/setfit/modeling.py
@@ -1,5 +1,6 @@
import copy
import os
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Union
@@ -209,14 +210,15 @@ def fit(
learning_rate: Optional[float] = None,
body_learning_rate: Optional[float] = None,
l2_weight: Optional[float] = None,
max_length: Optional[int] = None,
show_progress_bar: Optional[bool] = None,
) -> None:
if isinstance(self.model_head, nn.Module): # train with pyTorch
device = self.model_body.device
self.model_body.train()
self.model_head.train()

dataloader = self._prepare_dataloader(x_train, y_train, batch_size)
dataloader = self._prepare_dataloader(x_train, y_train, batch_size, max_length)
criterion = self.model_head.get_loss_fn()
optimizer = self._prepare_optimizer(learning_rate, body_learning_rate, l2_weight)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
@@ -243,13 +245,29 @@ def fit(
self.model_head.fit(embeddings, y_train)

def _prepare_dataloader(
self, x_train: List[str], y_train: List[int], batch_size: int, shuffle: bool = True
self,
x_train: List[str],
y_train: List[int],
batch_size: int,
max_length: Optional[int] = None,
shuffle: bool = True,
) -> DataLoader:
max_acceptable_length = self.model_body.get_max_seq_length()
max_length = max_length or max_acceptable_length
if max_length > max_acceptable_length:
warnings.warn(
(
f"The specified `max_length`: {max_length} is greater than the maximum length of the current model body: {max_acceptable_length}. "
f"Change `max_length` to {max_acceptable_length}."
)
)
max_length = max_acceptable_length

dataset = SetFitDataset(
x_train,
y_train,
tokenizer=self.model_body.tokenizer,
max_length=self.model_body.get_max_seq_length(),
max_length=max_length,
)
dataloader = DataLoader(
dataset, batch_size=batch_size, collate_fn=SetFitDataset.collate_fn, shuffle=shuffle, pin_memory=True
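The clamping logic above is small enough to read in isolation. A minimal sketch of the same pattern, assuming only a model body that exposes `get_max_seq_length()` as in sentence-transformers; the helper name `resolve_max_length` is illustrative and not part of this PR:

```python
import warnings
from typing import Optional


def resolve_max_length(requested: Optional[int], model_max: int) -> int:
    # Fall back to the model's limit when max_length is unset,
    # and clamp it (with a warning) when it exceeds that limit.
    max_length = requested or model_max
    if max_length > model_max:
        warnings.warn(
            f"The specified `max_length`: {max_length} is greater than the "
            f"maximum length of the current model body: {model_max}. "
            f"Change `max_length` to {model_max}."
        )
        max_length = model_max
    return max_length


# resolve_max_length(4096, 256) warns and returns 256;
# resolve_max_length(None, 256) returns 256 silently.
```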
6 changes: 6 additions & 0 deletions src/setfit/trainer.py
@@ -259,6 +259,7 @@ def train(
learning_rate: Optional[float] = None,
body_learning_rate: Optional[float] = None,
l2_weight: Optional[float] = None,
max_length: Optional[int] = None,
trial: Union["optuna.Trial", Dict[str, Any]] = None,
):
"""
@@ -279,6 +280,10 @@
If ignored, will be the same as `learning_rate`.
l2_weight (float, *optional*):
Temporarily change the weight of L2 regularization for SetFitModel's differentiable head in logistic regression.
max_length (int, *optional*):
The maximum number of tokens for one data sample. Currently only used when training the differentiable head.
If ignored, the maximum number of tokens the model body can accept will be used.
If `max_length` is greater than the maximum number of tokens the model body can accept, it will be set to that maximum.
trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
The trial run or the hyperparameter dictionary for hyperparameter search.
"""
@@ -380,6 +385,7 @@ def train(
learning_rate=learning_rate,
body_learning_rate=body_learning_rate,
l2_weight=l2_weight,
max_length=max_length,
show_progress_bar=True,
)

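For reference, a hedged end-to-end sketch of passing the new `max_length` through `SetFitTrainer.train` with a differentiable head; the checkpoint and toy dataset mirror the tests added below rather than a recommended configuration:

```python
from datasets import Dataset
from setfit import SetFitModel, SetFitTrainer

# Toy 3-class dataset, as in the new tests.
dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2]})

model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-albert-small-v2",
    use_differentiable_head=True,
    head_params={"out_features": 3},
)

trainer = SetFitTrainer(
    model=model,
    train_dataset=dataset,
    num_iterations=1,
)
trainer.unfreeze(keep_body_frozen=True)

# max_length caps tokenization per sample; a value above the body's limit
# is clamped back to get_max_seq_length() with a warning.
trainer.train(
    num_epochs=1,
    batch_size=3,
    learning_rate=1e-2,
    l2_weight=0.0,
    max_length=64,
)
```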
47 changes: 47 additions & 0 deletions tests/test_trainer.py
@@ -172,6 +172,53 @@ def test_trainer_raises_error_with_wrong_warmup_proportion(self):
SetFitTrainer(warmup_proportion=-0.1)


class SetFitTrainerDifferentiableHeadTest(TestCase):
def setUp(self):
self.dataset = Dataset.from_dict(
{"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
)
self.model = SetFitModel.from_pretrained(
"sentence-transformers/paraphrase-albert-small-v2",
use_differentiable_head=True,
head_params={"out_features": 3},
)
self.num_iterations = 1

def test_trainer_max_length_exceeds_max_acceptable_length(self):
trainer = SetFitTrainer(
model=self.model,
train_dataset=self.dataset,
eval_dataset=self.dataset,
num_iterations=self.num_iterations,
column_mapping={"text_new": "text", "label_new": "label"},
)
trainer.unfreeze(keep_body_frozen=True)
trainer.train(
Review comment (Member): Should we think about a way to actually test the large value has been overwritten with the model max length?

Reply (blakechi, Contributor, Author): you're right. I will modify the test

Reply (blakechi, Contributor, Author, Nov 16, 2022): just updated it. I tested whether the warning raises up correctly and checked the overwritten value in test_modeling.py.

num_epochs=1,
batch_size=3,
learning_rate=1e-2,
l2_weight=0.0,
max_length=4096,
)

def test_trainer_max_length_is_smaller_than_max_acceptable_length(self):
trainer = SetFitTrainer(
model=self.model,
train_dataset=self.dataset,
eval_dataset=self.dataset,
num_iterations=self.num_iterations,
column_mapping={"text_new": "text", "label_new": "label"},
)
trainer.unfreeze(keep_body_frozen=True)
trainer.train(
Review comment (Member): Same comment here - should we be testing the behaviour explicitly?

Reply (blakechi, Contributor, Author): agree. Will push an update later

Reply (blakechi, Contributor, Author): just updated it!

num_epochs=1,
batch_size=3,
learning_rate=1e-2,
l2_weight=0.0,
max_length=32,
)
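Per the review discussion, the explicit check that an oversized value is clamped landed in test_modeling.py; a sketch of what such an assertion could look like, assuming SetFitDataset keeps the `max_length` it was constructed with (that attribute is not shown in this diff):

```python
    def test_trainer_warns_and_clamps_oversized_max_length(self):
        # An oversized max_length should warn and fall back to the body's limit.
        with self.assertWarns(UserWarning):
            dataloader = self.model._prepare_dataloader(
                ["a", "b", "c"], [0, 1, 2], batch_size=3, max_length=4096
            )
        self.assertEqual(
            dataloader.dataset.max_length,
            self.model.model_body.get_max_seq_length(),
        )
```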


class SetFitTrainerMultilabelTest(TestCase):
def setUp(self):
self.model = SetFitModel.from_pretrained(