fix bugs in KTO implementation #1380

Merged: 27 commits into huggingface:main, Feb 29, 2024
Conversation

@kawine (Contributor) commented Feb 28, 2024

Hi there, KTO author here with a PR. The current KTO implementation puts the mismatched examples (x, y') used to calculate the KL term and the matched examples (x, y) from the data into the same stream. This causes two problems:

  • If there are no mismatched examples in a batch, you get a NaN for the KL term and a NaN for the loss. This is not fully fixed by increasing the batch size (as suggested here): you can get a KL estimate that is not NaN but still poor, because it is made with only 1 or 2 examples.

  • The number of individual losses != the batch size. For example, if you have a batch of 32, then depending on how many mismatched examples are used to calculate the KL term, you can end up with 10 positive/desirable and 2 negative/undesirable losses, or 0 positive and 8 negative, etc. This leads to high-variance updates.

(I also fixed a minor issue that caused problems when running KTO with accelerate.)
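As a rough illustration of the separated-stream idea (a minimal sketch, not the actual KTOTrainer internals; the function name, signature, and beta default are illustrative): the KL baseline is estimated on its own full-size batch of mismatched pairs, so every matched example still yields a loss.

```python
import torch

def kto_loss(policy_logps, ref_logps, policy_kl_logps, ref_kl_logps,
             is_desirable, beta=0.1):
    # KL baseline estimated on a *separate* batch of mismatched (x, y') pairs,
    # so it never runs out of examples; detached since no gradient flows through it
    kl = (policy_kl_logps - ref_kl_logps).mean().clamp(min=0).detach()

    # per-example policy/reference log-ratio for the matched (x, y) batch
    log_ratio = policy_logps - ref_logps

    # desirable outputs are pushed above the KL baseline, undesirable ones below;
    # all batch_size matched examples contribute a loss
    losses = torch.where(
        is_desirable,
        1 - torch.sigmoid(beta * (log_ratio - kl)),
        1 - torch.sigmoid(beta * (kl - log_ratio)),
    )
    return losses.mean(), kl
```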

@kawine (Contributor, Author) commented Feb 28, 2024

cc @kashif @lewtun

@kashif (Collaborator) commented Feb 28, 2024

thanks @kawine, having a look now!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@lewtun (Member) left a comment

Thank you very much for all these fixes @kawine - the PR looks great 🔥 !

I've left a few nits, but apart from that LGTM once the style checks pass

--lora_alpha 16 \
--evaluation_strategy "steps" \
--logging_first_step True \
--learning_rate 1e-3 \
@lewtun (Member) commented:

For my general understanding, are large learning rates like this typical for KTO runs? I noticed in your config that you have 5e-7, which is closer to what I expected.

We aim for these simple examples to "just work", so having a good default here would be great

@kawine (Contributor, Author) commented Feb 28, 2024

I just copied over the suggested commands in scripts/dpo.py and changed dpo -> kto. The command for DPO also seems to use a different learning rate than the 5e-7 that the original DPO repo uses?

All our experiments in the paper used the same training configs for DPO and KTO, so I'd recommend they be the same here as well (what exactly those configs should be for gpt2, I don't have a strong opinion on).
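For concreteness, a hypothetical invocation that mirrors the flags in the snippets quoted in this review but uses the paper's 5e-7 learning rate (the script path is an assumption, not verified against the repo):

```
python examples/scripts/kto.py \
    --model_name_or_path=gpt2 \
    --per_device_train_batch_size 4 \
    --max_steps 1000 \
    --learning_rate 5e-7 \
    --evaluation_strategy "steps" \
    --logging_first_step True
```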

--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 1e-3 \
@lewtun (Member) commented:

Does KTO with LoRA work better with 10x larger learning rates than full training? If so, I would increase it here

@kawine (Contributor, Author) commented:

As mentioned above, this was just copy-pasted from scripts/dpo.py. In all our experiments so far, we've used identical training configs for both, usually what the DPO repo recommended.

@kashif (Collaborator) commented Feb 28, 2024

@kawine feel free to just apply the suggestions, which are stylistic fixes, to get the CI to pass (apart from the assert issue flagged above).
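The assert-to-ValueError change referenced here follows the usual pattern; the condition, field name, and message below are illustrative, not the exact code:

```python
def validate_config(desirable_weight: float) -> None:
    # before: `assert desirable_weight > 0, "..."` is stripped when Python runs
    # with -O, so the check would silently disappear in optimized mode
    # after: the check always runs and raises an actionable error
    if desirable_weight <= 0:  # hypothetical config field, for illustration
        raise ValueError("desirable_weight must be positive")
```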

@kashif kashif merged commit 14e0d78 into huggingface:main Feb 29, 2024
9 checks passed
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL (see the sketch after this commit list)

* add some removed metrics

* Update trl/trainer/kto_trainer.py (×3; one adds a reference to the paper)

Co-authored-by: lewtun <[email protected]>

* Update trl/trainer/kto_trainer.py (×11, applying review suggestions)

Co-authored-by: Kashif Rasul <[email protected]>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

---------

Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: lewtun <[email protected]>
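Regarding the "speed up forward with no_grad for KL" commit above: since the KL estimate is only used as a detached scalar baseline, the KL-batch forward pass can skip gradient tracking. A sketch under that assumption (compute_logps is a hypothetical helper, passed in for illustration):

```python
import torch

def estimate_kl(policy_model, ref_model, kl_batch, compute_logps):
    # no gradient ever flows through the KL term, so wrapping both forward
    # passes in no_grad saves memory and time without changing the result
    with torch.no_grad():
        policy_kl_logps = compute_logps(policy_model, kl_batch)
        ref_kl_logps = compute_logps(ref_model, kl_batch)
    return (policy_kl_logps - ref_kl_logps).mean().clamp(min=0)
```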