[core] officially support SFT (Supervised Finetuning) #323
Conversation
The documentation is not available anymore as the PR was closed or merged.
Getting into great shape! Left a few comments.
trl/trainer/sft_trainer.py (Outdated)
peft_config: Optional[Dict] = None,
dataset_text_field: Optional[str] = None,
packing: Optional[bool] = True,
dataset_kwargs: Optional[Dict] = {},
I would avoid kwargs fields as much as possible. What values can be passed here? Can't we have them as separate kwargs?
The `dataset_kwargs` corresponds to the kwargs of `ConstantLengthDataset`; there are 6 optional ones. I think we should move them to proper kwargs, since we can always modify that class, but for the `prepare_int8_training` kwargs I would maybe keep them as they are.
Or maybe we should educate users to create models outside the trainer in case they want to have full control over that function, and remove `prepare_int8_training_kwargs`. Wdyt?
trl/trainer/sft_trainer.py (Outdated)
def _prepare_non_packed_dataloader(self, tokenizer, dataset, dataset_text_field, data_collator, max_seq_len):
    # tokenize the dataset
    dataset = dataset.map(
I think you can just tokenize the dataset so you have the input ids. I would not pad at all (the collator will do this, and if a batch has elements that are all shorter it will be faster). Then you can just pass the tokenized dataset along with the data collator (`data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)`).
We did something similar here: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
Then we can maybe just call the function `_tokenize_dataset`.
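A rough sketch of that suggestion, assuming a text column named "text" and a placeholder model and max length:

```python
# Tokenize without padding; let the LM collator pad each batch and build labels.
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("imdb", split="train")

def tokenize(examples):
    # truncate only; no padding here, the collator pads dynamically per batch
    return tokenizer(examples["text"], truncation=True, max_length=512)

tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)

# mlm=False -> causal LM; the collator copies input_ids into labels (pad tokens masked to -100)
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
```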
Thanks for the great pointer!
Co-authored-by: Leandro von Werra <[email protected]>
Thanks for adding this sweet feature @younesbelkada 🔥 !
I've left a few questions and suggestions for things that could help improve user understanding. I'll let the core maintainer approve this :)
    Whether to use an infinite dataset or not. Defaults to `False`.
num_of_sequences (`Optional[int]`):
    The number of sequences to use for the `ConstantLengthDataset`. Defaults to `1024`.
chars_per_token (`Optional[float]`):
Does `trl` provide a helper method for this? If yes, it would be nice to see a small example in the docs of how this dataset works.
I added a pointer to an example that uses this in 91b2643
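Not the referenced example itself, but a rough sketch of how this dataset is used; the import path and argument names follow the merged library and may have differed in this PR's revision:

```python
# Illustrative sketch: pack raw text into fixed-length token blocks with ConstantLengthDataset.
from datasets import load_dataset
from transformers import AutoTokenizer
from trl.trainer import ConstantLengthDataset

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
raw_dataset = load_dataset("imdb", split="train")

packed_dataset = ConstantLengthDataset(
    tokenizer,
    raw_dataset,
    dataset_text_field="text",  # column holding the raw text
    seq_length=512,             # length of every packed example
    num_of_sequences=1024,      # sequences buffered before packing
    chars_per_token=3.6,        # rough chars-per-token estimate used to size the buffer
)

# iterating yields dicts of input_ids/labels, each exactly `seq_length` tokens long
example = next(iter(packed_dataset))
```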
trl/trainer/sft_trainer.py (Outdated)
num_of_sequences: Optional[int] = 1024,
chars_per_token: Optional[float] = 3.6,
prepare_in_int8_kwargs: Optional[Dict] = {},
**pretrained_kwargs,
I am slightly opposed to passing kwargs like that. I think if people want to use something other than the default they should just load the model outside; it's just one additional line of code for them. We, on the other hand, then need to worry about those kwargs: e.g. if someone has a typo in one of them they will get a weird error because it's passed on to the model.
What do you think?
Yeah totally aligned on this!
Co-authored-by: Leandro von Werra <[email protected]>
…into add-sft-trainer
The IMDB dataset consists of sentences labelled 1 or 0, indicating whether each sentence is positive or negative feedback. The question is: does this SFTTrainer train against the labels or the text of each sentence?
No, the SFTTrainer only trains on the text with the causal language modeling objective.
Is this function suitable for training an instruction-following LLM? I reviewed the code and can't find the part that handles labels so that the loss is not computed over the instruction/prompt tokens.
See the docs here: https://huggingface.co/docs/trl/sft_trainer#train-on-completions-only
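For reference, a minimal sketch along the lines of the linked docs; the dataset, column names, and response template below are made up, and `DataCollatorForCompletionOnlyLM` was added after this PR, so argument names may differ across trl versions:

```python
# Sketch of "train on completions only": mask prompt tokens so the loss is
# computed only on the answer part, per the linked trl docs.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("your-org/your-instruction-dataset", split="train")  # hypothetical dataset

def formatting_prompts_func(examples):
    # build one training string per sample; "### Answer:" marks where the completion starts
    return [
        f"### Question: {q}\n### Answer: {a}"
        for q, a in zip(examples["question"], examples["answer"])  # hypothetical columns
    ]

collator = DataCollatorForCompletionOnlyLM("### Answer:", tokenizer=tokenizer)

trainer = SFTTrainer(
    model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
    packing=False,  # completion-only masking requires the non-packed path
)
trainer.train()
```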
What does this PR do?
This PR introduces the `SFTTrainer` class, a handy and easy-to-use class to train supervised fine-tuned models on instruction-based datasets. The API is easy to use, and also modular enough for advanced users who want to customize their training.
You just need to pass a model id, and optionally a `PeftConfig` to train adapters only. Advanced users can also pass `from_pretrained` kwargs directly to `SFTTrainer`, for example to load the model in 8bit mode.
The PR also introduces `ConstantLengthDataset`, a handy class to create instruction-based datasets. Just pass a tokenizer, a dataset, and a function to specify the formatting you want to have, and you should be good to go.
Quickstart:
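The original quickstart snippet is not reproduced here; as a stand-in, a minimal sketch in the spirit of the merged API (model id, dataset, and arguments like `dataset_text_field`/`max_seq_length` are assumptions and may differ slightly from this PR's initial revision):

```python
# Minimal SFT quickstart sketch: fine-tune a small causal LM on raw text.
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

trainer = SFTTrainer(
    "facebook/opt-350m",        # model id; from_pretrained kwargs can be forwarded too
    train_dataset=dataset,
    dataset_text_field="text",  # column containing the training text
    max_seq_length=512,
    peft_config=LoraConfig(task_type="CAUSAL_LM"),  # optional: train LoRA adapters only
)
trainer.train()
```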
cc @lvwerra