
[core] officially support SFT (Supervised Finetuning) #323

Merged: 33 commits merged into huggingface:main from add-sft-trainer on May 3, 2023

Conversation

younesbelkada (Contributor) commented Apr 26, 2023

What does this PR do?

This PR introduces the SFTTrainer class, a handy and easy-to-use class for supervised fine-tuning of models on instruction-based datasets.
The API is simple to use, yet modular enough for advanced users who want to customize their training.

You just need to pass a model id, and optionally a PeftConfig to train adapters only. Advanced users can also pass from_pretrained kwargs directly to SFTTrainer, for example to load the model in 8-bit mode.

The PR also introduces ConstantLengthDataset, a handy class for creating instruction-based datasets. Just pass a tokenizer, a dataset, and a function that specifies the formatting you want, and you should be good to go.
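
Before the quickstart below, a minimal sketch of that ConstantLengthDataset flow. The formatting function is hypothetical, and the keyword names (formatting_func, infinite, num_of_sequences, chars_per_token) are assumed to match the arguments discussed in this PR:

from datasets import load_dataset
from transformers import AutoTokenizer
from trl import ConstantLengthDataset

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("imdb", split="train")

# Hypothetical formatting function: maps each example to the text to train on.
def formatting_func(example):
    return f"Review: {example['text']}"

train_dataset = ConstantLengthDataset(
    tokenizer,
    dataset,
    formatting_func=formatting_func,
    infinite=False,         # yield the dataset once instead of cycling forever
    num_of_sequences=1024,  # buffer size used when packing sequences
    chars_per_token=3.6,    # rough characters-per-token estimate for the buffer
)

The resulting dataset yields packed, constant-length token sequences that can then be passed to SFTTrainer as train_dataset.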

Quickstart:

from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
)

trainer.train()
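
For the adapter path mentioned above, a minimal sketch, assuming peft's LoraConfig; the peft_config keyword matches the signature shown later in this PR's diff, and the LoRA hyperparameters are only illustrative:

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

# Only the LoRA adapter weights are trained; the base model stays frozen.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    "facebook/opt-350m",  # a model id is enough, SFTTrainer loads it internally
    train_dataset=dataset,
    dataset_text_field="text",
    peft_config=peft_config,
)
trainer.train()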

cc @lvwerra

HuggingFaceDocBuilderDev commented Apr 26, 2023

The documentation is not available anymore as the PR was closed or merged.

younesbelkada marked this pull request as ready for review April 26, 2023 15:34
younesbelkada requested review from lvwerra and lewtun April 26, 2023 15:34
lvwerra (Member) left a comment


Getting into great shape! Left a few comments.

peft_config: Optional[Dict] = None,
dataset_text_field: Optional[str] = None,
packing: Optional[bool] = True,
dataset_kwargs: Optional[Dict] = {},
Member

i would avoid kwargs fields as much as possible. what values can be passed here? can't have them as separate kwargs?

Contributor Author

The dataset_kwargs correspond to the kwargs of ConstantLengthDataset; there are 6 optional ones. I think we should move them to proper kwargs, since we can always modify that class, but for the prepare_int8_training kwargs I would maybe keep them as they are.

Contributor Author

Or maybe we should educate users to create their models outside the trainer if they want full control over that function, and remove prepare_int8_training_kwargs. Wdyt?


def _prepare_non_packed_dataloader(self, tokenizer, dataset, dataset_text_field, data_collator, max_seq_len):
# tokenize the dataset
dataset = dataset.map(
Member

I think you can just tokenize the dataset so you have the input ids. I would not pad at all (the collator will do this, and if a batch only has shorter elements it will be faster).

Then you can just pass the tokenized dataset along with the data collator (data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)).

We did something similar here: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt

Then we could maybe just call the function _tokenize_dataset.
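
A minimal sketch of that suggestion, using the standard datasets and transformers APIs (the truncation length is arbitrary):

from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("imdb", split="train")

# Tokenize only; no padding here, the collator pads each batch dynamically.
tokenized = dataset.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=512),
    batched=True,
    remove_columns=dataset.column_names,
)

# mlm=False means causal LM: labels are a copy of input_ids, padding is masked with -100.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)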

Contributor Author

Thanks for the great pointer!

younesbelkada requested a review from lvwerra April 28, 2023 14:08
lewtun (Member) left a comment


Thanks for adding this sweet feature @younesbelkada 🔥 !

I've left a few questions and suggestions for things that could help improve user understanding. I'll let the core maintainer approve this :)

Whether to use an infinite dataset or not. Defaults to `False`.
num_of_sequences (`Optional[int]`):
The number of sequences to use for the `ConstantLengthDataset`. Defaults to `1024`.
chars_per_token (`Optional[float]`):
Member

Does trl provide a helper method for this? If yes, it would be nice to see a small example in the docs of how this dataset works.

Contributor Author

I added a pointer to an example that uses this in 91b2643.

num_of_sequences: Optional[int] = 1024,
chars_per_token: Optional[float] = 3.6,
prepare_in_int8_kwargs: Optional[Dict] = {},
**pretrained_kwargs,
Member

I am slightly opposed to passing kwargs like that. I think if people want to use something other than the default they should just load the model outside; it's just one additional line of code for them. We, on the other hand, then need to worry about those kwargs: e.g. if someone has a typo in one of them they will get a weird error because it's passed to the model.

What do you think?
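
As a sketch of the "load the model outside" path (assuming bitsandbytes and accelerate are installed; prepare_model_for_int8_training comes from peft, and a LoRA config is needed because the 8-bit base weights themselves stay frozen):

from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_int8_training
from transformers import AutoModelForCausalLM
from trl import SFTTrainer

# Full control over from_pretrained: one extra line instead of forwarding kwargs.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    load_in_8bit=True,
    device_map="auto",
)
model = prepare_model_for_int8_training(model)

dataset = load_dataset("imdb", split="train")
trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    peft_config=LoraConfig(task_type="CAUSAL_LM"),
)
trainer.train()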

Contributor Author

Yeah totally aligned on this!

younesbelkada requested a review from lvwerra May 3, 2023 08:29
younesbelkada merged commit c60fd91 into huggingface:main May 3, 2023
younesbelkada deleted the add-sft-trainer branch May 3, 2023 08:42
@tigerinus

The IMDB dataset consists of sentences labelled 1 or 0, indicating whether each sentence is positive or negative feedback.

The question is: does SFTTrainer train against the labels or against the content of each sentence?

lvwerra (Member) commented Aug 10, 2023

No, the SFTTrainer only trains on the text with the causal language modeling objective.
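
To make that concrete, a small sketch: with the causal language modeling collator the IMDB label column plays no role, and the labels used for the loss are simply a copy of the input ids.

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

batch = collator([tokenizer("This movie was great!")])
print(batch["input_ids"])  # token ids of the review text
print(batch["labels"])     # same ids: every token is a prediction target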

@hezhiyang2000

Is this class suitable for training an instruction-following LLM? I reviewed the code and couldn't find any label handling that avoids computing the loss on the instruction and prompt tokens.

lvwerra (Member) commented Sep 4, 2023
