Allow setting max length #176

Merged: 12 commits, Dec 7, 2022
23 changes: 20 additions & 3 deletions src/setfit/modeling.py
@@ -209,14 +209,15 @@ def fit(
learning_rate: Optional[float] = None,
body_learning_rate: Optional[float] = None,
l2_weight: Optional[float] = None,
max_length: Optional[int] = None,
show_progress_bar: Optional[bool] = None,
) -> None:
if isinstance(self.model_head, nn.Module): # train with PyTorch
device = self.model_body.device
self.model_body.train()
self.model_head.train()

dataloader = self._prepare_dataloader(x_train, y_train, batch_size)
dataloader = self._prepare_dataloader(x_train, y_train, batch_size, max_length)
criterion = self.model_head.get_loss_fn()
optimizer = self._prepare_optimizer(learning_rate, body_learning_rate, l2_weight)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
@@ -243,13 +244,29 @@ def fit(
self.model_head.fit(embeddings, y_train)

def _prepare_dataloader(
self, x_train: List[str], y_train: List[int], batch_size: int, shuffle: bool = True
self,
x_train: List[str],
y_train: List[int],
batch_size: int,
max_length: Optional[int] = None,
shuffle: bool = True,
) -> DataLoader:
max_acceptable_length = self.model_body.get_max_seq_length()
max_length = max_length or max_acceptable_length
(blakechi marked a review conversation on these lines as resolved.)
if max_length > max_acceptable_length:
logger.warning(
(
f"The specified `max_length`: {max_length} is greater than the maximum length of the current model body: {max_acceptable_length}. "
f"Using {max_acceptable_length} instead."
)
)
max_length = max_acceptable_length

dataset = SetFitDataset(
x_train,
y_train,
tokenizer=self.model_body.tokenizer,
max_length=self.model_body.get_max_seq_length(),
max_length=max_length,
)
dataloader = DataLoader(
dataset, batch_size=batch_size, collate_fn=SetFitDataset.collate_fn, shuffle=shuffle, pin_memory=True
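
The change above boils down to clamping the user-supplied `max_length` to whatever the model body accepts. A minimal standalone sketch of that behavior (the helper name `resolve_max_length` and the 512-token limit are illustrative, not part of the diff):

from typing import Optional

def resolve_max_length(max_length: Optional[int], max_acceptable_length: int) -> int:
    # Fall back to the body's limit when no value is given.
    max_length = max_length or max_acceptable_length
    # Values above the limit are clamped (the real code also logs a warning here).
    if max_length > max_acceptable_length:
        return max_acceptable_length
    return max_length

assert resolve_max_length(None, 512) == 512   # no value given: use the body's limit
assert resolve_max_length(4096, 512) == 512   # too large: clamped
assert resolve_max_length(32, 512) == 32      # within range: kept
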
6 changes: 6 additions & 0 deletions src/setfit/trainer.py
@@ -259,6 +259,7 @@ def train(
learning_rate: Optional[float] = None,
body_learning_rate: Optional[float] = None,
l2_weight: Optional[float] = None,
max_length: Optional[int] = None,
trial: Union["optuna.Trial", Dict[str, Any]] = None,
):
"""
@@ -279,6 +280,10 @@ def train(
If ignored, will be the same as `learning_rate`.
l2_weight (float, *optional*):
Temporarily change the weight of L2 regularization for SetFitModel's differentiable head in logistic regression.
max_length (int, *optional*, defaults to `None`):
The maximum number of tokens for one data sample. Currently only for training the differentiable head.
If`None`, will use the maximum number of tokens the model body can accept.
PhilipMay (Contributor) commented on Nov 30, 2022:

After "If" there is a space missing.

            If`None`, will use the maximum number of tokens the model body can accept.

blakechi (Contributor, PR author) replied:

Thanks for finding that, @PhilipMay!
If `max_length` is greater than the maximum number of acceptable tokens the model body can accept, it will be set to the maximum number of acceptable tokens.
trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
The trial run or the hyperparameter dictionary for hyperparameter search.
"""
@@ -380,6 +385,7 @@ def train(
learning_rate=learning_rate,
body_learning_rate=body_learning_rate,
l2_weight=l2_weight,
max_length=max_length,
show_progress_bar=True,
)

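
For context, here is a sketch of how a caller could exercise the new argument end to end. It is assembled only from the APIs used in the tests below; the toy dataset and the `max_length` value of 32 are illustrative, not part of the diff.

from datasets import Dataset
from setfit.modeling import SetFitModel
from setfit.trainer import SetFitTrainer

# Tiny three-class toy dataset, mirroring the test fixtures below.
dataset = Dataset.from_dict({"text_new": ["a", "b", "c"], "label_new": [0, 1, 2]})
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-albert-small-v2",
    use_differentiable_head=True,
    head_params={"out_features": 3},
)
trainer = SetFitTrainer(
    model=model,
    train_dataset=dataset,
    eval_dataset=dataset,
    num_iterations=1,
    column_mapping={"text_new": "text", "label_new": "label"},
)
trainer.unfreeze(keep_body_frozen=True)  # only the differentiable head is trained
trainer.train(
    num_epochs=1,
    batch_size=3,
    learning_rate=1e-2,
    l2_weight=0.0,
    max_length=32,  # values above model_body.get_max_seq_length() are clamped with a warning
)
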
14 changes: 14 additions & 0 deletions tests/test_modeling.py
@@ -126,6 +126,8 @@ def setUpClass(cls):

cls.model = model
cls.out_features = num_classes
cls.x_train = x_train
cls.y_train = y_train

@staticmethod
def _build_model(num_classes: int) -> SetFitModel:
@@ -169,6 +171,18 @@ def test_setfit_model_backward(self):
assert not param.grad.isnan().any().item(), f"Gradients of {name} in the model body have NaN."
assert not param.grad.isinf().any().item(), f"Gradients of {name} in the model body have Inf."

def test_max_length_is_larger_than_max_acceptable_length(self):
max_length = int(1e6)
dataloader = self.model._prepare_dataloader(self.x_train, self.y_train, batch_size=1, max_length=max_length)

assert dataloader.dataset.max_length == self.model.model_body.get_max_seq_length()

def test_max_length_is_smaller_than_max_acceptable_length(self):
max_length = 32
dataloader = self.model._prepare_dataloader(self.x_train, self.y_train, batch_size=1, max_length=max_length)

assert dataloader.dataset.max_length == max_length


def test_setfit_from_pretrained_local_model_without_head(tmp_path):
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
73 changes: 73 additions & 0 deletions tests/test_trainer.py
@@ -7,11 +7,16 @@
from transformers.testing_utils import require_optuna
from transformers.utils.hp_naming import TrialShortNamer

from setfit import logging
from setfit.modeling import SetFitModel, SupConLoss
from setfit.trainer import SetFitTrainer
from setfit.utils import BestRun


logging.set_verbosity_warning()
logging.enable_propagation()


class SetFitTrainerTest(TestCase):
def setUp(self):
self.model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
@@ -172,6 +177,74 @@ def test_trainer_raises_error_with_wrong_warmup_proportion(self):
SetFitTrainer(warmup_proportion=-0.1)


class SetFitTrainerDifferentiableHeadTest(TestCase):
def setUp(self):
self.dataset = Dataset.from_dict(
{"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
)
self.model = SetFitModel.from_pretrained(
"sentence-transformers/paraphrase-albert-small-v2",
use_differentiable_head=True,
head_params={"out_features": 3},
)
self.num_iterations = 1

def test_trainer_max_length_exceeds_max_acceptable_length(self):
trainer = SetFitTrainer(
model=self.model,
train_dataset=self.dataset,
eval_dataset=self.dataset,
num_iterations=self.num_iterations,
column_mapping={"text_new": "text", "label_new": "label"},
)
trainer.unfreeze(keep_body_frozen=True)
with self.assertLogs(level=logging.WARNING) as cm:
max_length = 4096
max_acceptable_length = self.model.model_body.get_max_seq_length()
trainer.train(
num_epochs=1,
batch_size=3,
learning_rate=1e-2,
l2_weight=0.0,
max_length=max_length,
)
self.assertEqual(
(A Member commented on this line: nice!)
cm.output,
[
(
f"WARNING:setfit.modeling:The specified `max_length`: {max_length} is greater than the maximum length "
f"of the current model body: {max_acceptable_length}. Using {max_acceptable_length} instead."
)
],
)

def test_trainer_max_length_is_smaller_than_max_acceptable_length(self):
trainer = SetFitTrainer(
model=self.model,
train_dataset=self.dataset,
eval_dataset=self.dataset,
num_iterations=self.num_iterations,
column_mapping={"text_new": "text", "label_new": "label"},
)
trainer.unfreeze(keep_body_frozen=True)

# An alternative to `assertNoLogs`, which is new in Python 3.10
try:
with self.assertLogs(level=logging.WARNING) as cm:
max_length = 32
trainer.train(
num_epochs=1,
batch_size=3,
learning_rate=1e-2,
l2_weight=0.0,
max_length=max_length,
)
self.assertEqual(cm.output, [])
except AssertionError as e:
if e.args[0] != "no logs of level WARNING or higher triggered on root":
raise AssertionError(e)
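
On Python 3.10 and later, the try/except workaround above could be replaced with `assertNoLogs`; a minimal sketch (illustrative, not part of the diff):

# Python >= 3.10 only: fails the test if any WARNING-or-higher record is emitted.
with self.assertNoLogs(level=logging.WARNING):
    trainer.train(
        num_epochs=1,
        batch_size=3,
        learning_rate=1e-2,
        l2_weight=0.0,
        max_length=32,
    )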


class SetFitTrainerMultilabelTest(TestCase):
def setUp(self):
self.model = SetFitModel.from_pretrained(