Allow setting max length #176
Conversation
Thanks for this nice quality-of-life improvement @blakechi 🔥!
The PR is looking great, and I've left a few small comments that would be nice to address before we merge.
tests/test_trainer.py
Outdated
column_mapping={"text_new": "text", "label_new": "label"},
)
trainer.unfreeze(keep_body_frozen=True)
trainer.train(
Should we think about a way to actually test that the large value has been overwritten with the model's max length?
You're right, I will modify the test.
Just updated it. I tested whether the warning is raised correctly and checked the overwritten value in test_modeling.py.
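For readers following the thread, a warning like this is typically asserted with `unittest`'s `assertWarns`. A minimal, self-contained sketch of that testing pattern — `train_head` here is a hypothetical stand-in for the behavior under review, not SetFit's actual code:

```python
import unittest
import warnings


def train_head(max_length, model_max_length=512):
    """Hypothetical stand-in for the PR's behavior: warn and fall back
    to the model body's token limit when max_length is too large."""
    if max_length > model_max_length:
        warnings.warn(
            f"max_length={max_length} is too large; using {model_max_length} instead."
        )
        max_length = model_max_length
    return max_length


class TestMaxLengthWarning(unittest.TestCase):
    def test_warns_and_overwrites(self):
        # Both the warning and the overwritten value are checked explicitly.
        with self.assertWarns(UserWarning):
            self.assertEqual(train_head(4096), 512)


# Run the test case without sys.exit(), so this snippet can be pasted anywhere.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestMaxLengthWarning)
unittest.TextTestRunner(verbosity=0).run(suite)
```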
tests/test_trainer.py
Outdated
column_mapping={"text_new": "text", "label_new": "label"},
)
trainer.unfreeze(keep_body_frozen=True)
trainer.train(
Same comment here - should we be testing the behaviour explicitly?
Agree. Will push an update later.
Just updated it!
Co-authored-by: lewtun <[email protected]>
Sorry for the delay and thanks for iterating @blakechi 🤘
I've left one nit and then I think we can merge!
l2_weight=0.0,
max_length=max_length,
)
self.assertEqual(
nice!
Co-authored-by: lewtun <[email protected]>
No worries at all! Thanks for the review 🙏🏻
src/setfit/trainer.py
Outdated
@@ -279,6 +280,10 @@ def train(
If ignore, will be the same as `learning_rate`.
l2_weight (float, *optional*):
Temporary change the weight of L2 regularization for SetFitModel's differentiable head in logistic regression.
max_length (int, *optional*, defaults to `None`):
The maximum number of tokens for one data sample. Currently only for training the differentiable head.
If`None`, will use the maximum number of tokens the model body can accept.
After "If" there is a space missing.
If`None`, will use the maximum number of tokens the model body can accept.
Thanks for finding that, @PhilipMay!
Thanks for iterating - looks great!
This PR is opened to resolve #172. It allows setting `max_length` when calling `SetFitTrainer.train`. If `max_length` is too large, it will be set to the maximum number of tokens the model body can handle. A snippet of usage:
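The original usage snippet did not survive the export. In its place, a minimal runnable sketch of the behavior the PR describes — `resolve_max_length` is a hypothetical helper, not SetFit's actual implementation; with SetFit itself you would simply pass `max_length` to `SetFitTrainer.train`:

```python
import warnings


def resolve_max_length(max_length, model_max_length):
    """Hypothetical helper mirroring the PR's described behavior:
    fall back to the model body's limit when max_length is None,
    and warn + overwrite when it exceeds that limit."""
    if max_length is None:
        return model_max_length
    if max_length > model_max_length:
        warnings.warn(
            f"max_length={max_length} exceeds the model body's limit; "
            f"using {model_max_length} instead."
        )
        return model_max_length
    return max_length


# e.g. with a model body that accepts at most 512 tokens:
print(resolve_max_length(None, 512))   # falls back to the body's limit
print(resolve_max_length(4096, 512))   # warns and clamps to the limit
print(resolve_max_length(128, 512))    # small enough, kept as-is
```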