
Allow setting max length #176

Merged: 12 commits into huggingface:main on Dec 7, 2022

Conversation

blakechi
Contributor

This PR resolves #172.

  • Allow users to set max_length when calling SetFitTrainer.train.
  • If max_length exceeds the maximum number of tokens the model body can handle, it is capped at that maximum (a rough sketch follows this list).
  • Add relevant tests.
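
For illustration, a minimal sketch of the capping behaviour described above (the helper name and signature are hypothetical, not the PR's actual code):

from typing import Optional
import warnings


def resolve_max_length(max_length: Optional[int], model_max_length: int) -> int:
    # Hypothetical helper: fall back to the model body's limit when no value
    # is given, and cap values that exceed it, emitting a warning.
    if max_length is None:
        return model_max_length
    if max_length > model_max_length:
        warnings.warn(
            f"max_length={max_length} exceeds the model body's limit of "
            f"{model_max_length} tokens; using {model_max_length} instead."
        )
        return model_max_length
    return max_length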

A usage snippet:

from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss

from setfit import SetFitModel, SetFitTrainer


# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")

# Simulate the few-shot regime by sampling 8 examples per class
num_classes = 2
train_dataset = dataset["train"].shuffle(seed=42).select(range(8 * num_classes))
eval_dataset = dataset["validation"]

# Load a SetFit model from Hub
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    use_differentiable_head=True,
    head_params={"out_features": num_classes},
)

# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss_class=CosineSimilarityLoss,
    metric="accuracy",
    batch_size=16,
    num_iterations=20,
    num_epochs=1,
    column_mapping={"sentence": "text", "label": "label"}
)

# Train the differentiable head only
trainer.unfreeze(keep_body_frozen=True)

# Train with custom `max_length`
trainer.train(
    num_epochs=1,
    batch_size=16,
    body_learning_rate=1e-5,
    learning_rate=1e-2,
    l2_weight=0.0,
    max_length=64,  # set your preferred length here
)
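
As an aside, the limit that an oversized max_length falls back to can be inspected on the model body (assuming SetFitModel exposes its sentence-transformers body as model.model_body, which carries max_seq_length):

# Hedged aside: the model body's maximum sequence length, i.e. the value an
# oversized `max_length` is capped to (attribute path assumed).
print(model.model_body.max_seq_length)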

Member

@lewtun lewtun left a comment

Thanks for this nice quality-of-life improvement, @blakechi 🔥!

The PR is looking great, and I've left a few small comments that would be nice to address before we merge.

Resolved review comments (outdated) on src/setfit/modeling.py and src/setfit/trainer.py.
column_mapping={"text_new": "text", "label_new": "label"},
)
trainer.unfreeze(keep_body_frozen=True)
trainer.train(
Member

Should we think about a way to actually test that the large value has been overwritten with the model's max length?

Contributor Author

You're right, I will modify the test.

Contributor Author

@blakechi blakechi Nov 16, 2022

Just updated it. I tested that the warning is raised correctly and checked the overwritten value in test_modeling.py (roughly along the lines of the sketch below).
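
For context, a rough sketch of what such a check can look like (hypothetical test name and fixture; the actual tests added in this PR live in test_modeling.py):

import pytest


def test_max_length_is_capped_to_model_limit(trainer):
    # Hedged sketch: `trainer` is assumed to be a SetFitTrainer fixture with a
    # differentiable head. An oversized max_length should trigger a warning
    # and be replaced by the model body's maximum sequence length.
    trainer.unfreeze(keep_body_frozen=True)
    with pytest.warns(UserWarning):
        trainer.train(num_epochs=1, batch_size=2, max_length=4096)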

column_mapping={"text_new": "text", "label_new": "label"},
)
trainer.unfreeze(keep_body_frozen=True)
trainer.train(
Member

Same comment here - should we be testing the behaviour explicitly?

Contributor Author

Agree, will push an update later.

Contributor Author

Just updated it!

Member

@lewtun lewtun left a comment

Sorry for the delay and thanks for iterating @blakechi 🤘

I've left one nit and then think we can merge!

Resolved review comment (outdated) on src/setfit/modeling.py.
l2_weight=0.0,
max_length=max_length,
)
self.assertEqual(
Member

nice!

@blakechi
Contributor Author

> Sorry for the delay and thanks for iterating @blakechi 🤘
>
> I've left one nit and then think we can merge!

No worries at all! Thanks for the review 🙏🏻

@@ -279,6 +280,10 @@ def train(
If ignore, will be the same as `learning_rate`.
l2_weight (float, *optional*):
Temporary change the weight of L2 regularization for SetFitModel's differentiable head in logistic regression.
max_length (int, *optional*, defaults to `None`):
The maximum number of tokens for one data sample. Currently only for training the differentiable head.
If`None`, will use the maximum number of tokens the model body can accept.
Contributor

@PhilipMay PhilipMay Nov 30, 2022

After "If" there is a space missing.

            If`None`, will use the maximum number of tokens the model body can accept.

Contributor Author

Thanks for finding that, @PhilipMay!

Member

@lewtun lewtun left a comment

Thanks for iterating - looks great!

@lewtun lewtun merged commit 0f828e4 into huggingface:main Dec 7, 2022