[core / xxxTrainer] Automatic tagging #1329
Conversation
Verified that this PR does not create any conflict with the previous tagging logic already in place:

```python
import datasets
import peft
import transformers
import trl

model_dir = "HuggingFaceM4/tiny-random-LlamaForCausalLM"

tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = transformers.AutoModelForCausalLM.from_pretrained(model_dir)

ds_train = datasets.load_dataset("imdb", split="train[:10]")

trainer = trl.SFTTrainer(
    model=model,
    args=transformers.TrainingArguments(
        output_dir="test-automatic-tagging-from-trainer",
        max_steps=1,
        remove_unused_columns=True,
    ),
    peft_config=peft.LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=8,
        bias="none",
        task_type="CAUSAL_LM",
    ),
    train_dataset=ds_train,
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=8,
)

model.push_to_hub("ybelkada/test-automatic-tagging")
trainer.push_to_hub()
```

https://huggingface.co/ybelkada/test-automatic-tagging-from-trainer
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Good idea!
Very nice!
cc @kashif to also add in KTO Trainer
Ah yes, thanks! Added to #1181
* automatic tagging
* add comments
* fix tests
* fix
What does this PR do?
This PR injects `trl` / `dpo` / `sft` etc. tags on the model at the trainer's init. That way, models pushed with `model.push_to_hub()` will also get the correct tags, not only models pushed by users who call `trainer.push_to_hub()`.
cc @lvwerra @osanseviero for awareness
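The mechanism described above can be sketched in plain Python, with no dependency on TRL itself. This is a hypothetical, simplified illustration (the class names `TinyModel`, `TinySFTTrainer`, and the method `add_model_tags` are stand-ins, not the actual TRL/transformers implementation): the trainer attaches its tags to the model at construction time, so any later push of the model carries them along.

```python
# Hypothetical sketch of trainer-side automatic tagging, not TRL's real code.

class TinyModel:
    """Stand-in for a Hub model that accumulates model-card tags."""

    def __init__(self):
        self.model_tags = []

    def add_model_tags(self, tags):
        # Deduplicate while preserving order, so tagging twice
        # (e.g. by two trainers) does not create conflicts.
        for tag in tags:
            if tag not in self.model_tags:
                self.model_tags.append(tag)


class TinySFTTrainer:
    """Stand-in trainer that injects its tags at init time."""

    _tag_names = ["trl", "sft"]

    def __init__(self, model):
        self.model = model
        # Tags are attached here, in __init__, so that a later
        # model.push_to_hub() picks them up even if the user never
        # calls trainer.push_to_hub().
        self.model.add_model_tags(self._tag_names)


model = TinyModel()
trainer = TinySFTTrainer(model)
print(model.model_tags)  # ['trl', 'sft']
```

Tagging at init (rather than at push time) is what makes both push paths consistent: the tags live on the model object itself, so whichever `push_to_hub` the user calls sees the same metadata.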