fix bugs in KTO implementation #1380
Conversation
thanks @kawine having a look now!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thank you very much for all these fixes @kawine - the PR looks great 🔥 !
I've left a few nits, but apart from that LGTM once the style checks pass
--lora_alpha 16 \
--evaluation_strategy "steps" \
--logging_first_step True \
--learning_rate 1e-3 \
For my general understanding, are large learning rates like this the typical values for KTO runs? I noticed in your config that you have 5e-7, which is closer to what I expect.
We aim for these simple examples to "just work", so having a good default here would be great
I just copied over the suggested commands in scripts/dpo.py and changed dpo -> kto. The command for DPO also seems to use a different learning rate than the 5e-7 that the original DPO repo uses?
All our experiments in the paper used the same training configs for DPO and KTO, so I'd recommend they be the same here as well (what exactly those configs should be for gpt2, I don't have a strong opinion on).
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 1e-3 \
Does KTO with LoRA work better with 10x larger learning rates than full training? If so, I would increase it here
As mentioned above, this was just copy-pasted from scripts/dpo.py. In all our experiments so far, we've used identical training configs for both, usually what the DPO repo recommended.
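For reference, here is a minimal sketch of the configuration being discussed, written with transformers.TrainingArguments purely for illustration: the output path and the 5e-7 value are assumptions taken from the discussion above, not the repo's final example defaults.

```python
# Minimal sketch for the learning-rate discussion above; not the repo's
# actual example script. 5e-7 follows the DPO-style configs the KTO paper
# reused, while the copied example currently uses 1e-3.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="kto-gpt2",            # hypothetical output path
    per_device_train_batch_size=4,
    max_steps=1000,
    learning_rate=5e-7,               # paper-style value instead of 1e-3
    evaluation_strategy="steps",
    logging_first_step=True,
)
# These arguments would then be passed to the KTO trainer as its `args`.
```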
@kawine feel free to just apply the suggestions, which are stylistic fixes, to get the CI to pass (apart from the
* add warning for imbalanced data
* update documentation
* update script commands to be same as in dpo
* use batch_size KL examples and batch_size target examples to calculate batch_size losses
* fix deepspeed issue
* speed up forward with no_grad for KL
* add some removed metrics
* Update trl/trainer/kto_trainer.py
* Update trl/trainer/kto_trainer.py
* Update trl/trainer/kto_trainer.py add reference to paper Co-authored-by: lewtun <[email protected]>
* Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <[email protected]>
* Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <[email protected]>
* Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <[email protected]>
* Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <[email protected]>
* Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <[email protected]>
* Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <[email protected]>
* Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <[email protected]>
* Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <[email protected]>
* Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <[email protected]>
* Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <[email protected]>
* Update trl/trainer/kto_trainer.py Co-authored-by: Kashif Rasul <[email protected]>
* add more detailed comments
* convert assert to ValueError
* Update kto_trainer.py
* precommit formatting

--------
Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: lewtun <[email protected]>
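One item above, "speed up forward with no_grad for KL", is worth unpacking: the KL estimate only serves as a reference point in the KTO loss, so the extra forward pass over the mismatched examples does not need gradients. A minimal sketch under that assumption, using a hypothetical compute_logps helper rather than the trainer's real method names:

```python
import torch

def kl_forward(model, kl_batch, compute_logps):
    # The KL estimate is used as a constant reference point in the loss, so
    # this forward pass can run under no_grad: it saves memory and compute
    # without changing any gradients of the KTO objective.
    with torch.no_grad():
        return compute_logps(model, kl_batch)  # hypothetical helper
```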
Hi there, KTO author here with a PR. The current KTO implementation puts the mismatched examples (x, y') used to calculate the KL term and the matched examples (x, y) from the data into the same stream. This causes two problems:
1. If there are no mismatched examples in a batch, you can get a NaN for the KL term and a NaN for the loss. This is not fully fixed by increasing the batch size (as suggested here): you can get a KL estimate that is not NaN but is still poor because it is made with only 1 or 2 examples.
2. The number of individual examples != the batch size. For example, with a batch of 32, depending on how many mismatched examples are used to calculate the KL term, you can end up with 10 positive/desirable and 2 negative/undesirable losses, or 0 positive and 8 negative, etc. This leads to high-variance updates.
(I also fixed a minor issue that caused problems when running KTO with accelerate.)
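To make the batching fix concrete, here is a minimal, hypothetical sketch (not the actual TRL code) of the idea behind "use batch_size KL examples and batch_size target examples to calculate batch_size losses": pair every prompt with the next example's completion, so a batch of size B always yields B matched target pairs and B mismatched KL pairs.

```python
def build_kl_pairs(prompts: list[str], completions: list[str]) -> list[tuple[str, str]]:
    """Pair each prompt with another example's completion (shift by one, wrap around).

    With batch size B this always produces B mismatched (x, y') pairs, so the
    KL estimate is never computed from just 1 or 2 samples the way it can be
    when mismatched examples share a stream with the matched data.
    """
    shifted = completions[1:] + completions[:1]  # roll completions by one position
    return list(zip(prompts, shifted))

# Example: a batch of 3 gives 3 matched target pairs and 3 mismatched KL pairs.
prompts = ["p0", "p1", "p2"]
completions = ["y0", "y1", "y2"]
print(build_kl_pairs(prompts, completions))
# [('p0', 'y1'), ('p1', 'y2'), ('p2', 'y0')]
```

The completion shifted onto each prompt comes from a different example, which is exactly what the KL term needs: a completion that was not written for that prompt.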