[KTOTrainer] add BCO (reward shift and underlying distribution matching) #1599
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@seanexp can you kindly run the formatting command in the root folder of TRL to clean up the formatting?
thanks @seanexp! One question: would it make sense in the example script to have an option to try different embedding models? Or are things hard-coded for nomic-embed?
@seanexp you might need to add the scikit-learn dep here: https://github.com/huggingface/trl/blob/main/setup.py#L72
oh sorry @seanexp, I thought scikit-learn was only used in the tests, but it's used in the main trainer... so we need to add it as a regular dependency! 🙇🏽
I found that adding it works. Should I still add it?
@seanexp yeah, let's remove the dep from the tests then... you can see here that even though transformers is installed, the tests fail: https://github.com/huggingface/trl/actions/runs/8891165233/job/24412685192#step:5:2843
Seems like the dep in the tests is necessary. Without it, the tests fail. Shall we leave it as is? @kashif
@seanexp ok, let's leave the dependency in the tests and add a helper here: https://github.com/huggingface/trl/blob/main/trl/import_utils.py for scikit-learn. Then in the trainer we can check whether it is available when BCO is the loss type and, if not, ask the user to install it.
@seanexp something like this in the KTOConfig: https://github.com/huggingface/trl/blob/main/trl/trainer/ddpo_config.py#L116-L120
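A minimal sketch of what such an availability check could look like, following the `is_*_available` pattern used elsewhere in `trl/import_utils.py` (the function names and the `"bco"` loss-type string here are illustrative assumptions, not the merged implementation):

```python
import importlib.util


def is_sklearn_available() -> bool:
    """Return True if scikit-learn can be imported (illustrative helper)."""
    return importlib.util.find_spec("sklearn") is not None


def require_sklearn_for_bco(loss_type: str) -> None:
    """Raise a helpful error when BCO is requested without scikit-learn.

    The "bco" loss-type string is an assumption for this sketch.
    """
    if loss_type == "bco" and not is_sklearn_available():
        raise ImportError(
            "The BCO loss requires scikit-learn for underlying distribution "
            "matching. Install it with `pip install scikit-learn`."
        )
```

Guarding the optional dependency this way keeps scikit-learn out of the core install while still failing fast, with an actionable message, for users who select the BCO loss.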
Users can try different embedding models. Please note that users have to modify the embedding-related parts of the script accordingly.
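One way the example script could expose the embedding model as an option instead of hard-coding nomic-embed is a small argument dataclass. This is a hypothetical sketch; the field names and defaults below are assumptions, not the actual arguments of `examples/scripts/bco.py`:

```python
from dataclasses import dataclass, field


@dataclass
class EmbeddingArguments:
    """Hypothetical CLI arguments for selecting the UDM embedding model."""

    embedding_model_name: str = field(
        default="nomic-ai/nomic-embed-text-v1",  # assumed default for this sketch
        metadata={"help": "Hub id of the model used to embed prompts for UDM."},
    )
    trust_remote_code: bool = field(
        default=True,
        metadata={"help": "Some embedding models ship custom modeling code."},
    )
```

A dataclass like this can be parsed with `HfArgumentParser`, so switching the embedding model becomes a command-line flag rather than an edit to the script.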
LGTM!
Thanks so much for this great contribution @seanexp ! Overall looks very clean, some CI is failing though:
```
=========================== short test summary info ============================
FAILED tests/test_kto_trainer.py::KTOTrainerTester::test_kto_lora_save - AttributeError: 'NoneType' object has no attribute 'gradient_accumulation_kwargs'
FAILED tests/test_kto_trainer.py::KTOTrainerTester::test_kto_trainer_0_gpt2 - AttributeError: 'NoneType' object has no attribute 'gradient_accumulation_kwargs'
FAILED tests/test_kto_trainer.py::KTOTrainerTester::test_kto_trainer_1_gpt2 - AttributeError: 'NoneType' object has no attribute 'gradient_accumulation_kwargs'
FAILED tests/test_kto_trainer.py::KTOTrainerTester::test_kto_trainer_2_gpt2 - AttributeError: 'NoneType' object has no attribute 'gradient_accumulation_kwargs'
FAILED tests/test_kto_trainer.py::KTOTrainerTester::test_kto_trainer_3_gpt2 - AttributeError: 'NoneType' object has no attribute 'gradient_accumulation_kwargs'
FAILED tests/test_kto_trainer.py::KTOTrainerTester::test_kto_trainer_4_gpt2 - AttributeError: 'NoneType' object has no attribute 'gradient_accumulation_kwargs'
FAILED tests/test_kto_trainer.py::KTOTrainerTester::test_kto_trainer_5_gpt2 - AttributeError: 'NoneType' object has no attribute 'gradient_accumulation_kwargs'
FAILED tests/test_kto_trainer.py::KTOTrainerTester::test_kto_trainer_6_gpt2 - AttributeError: 'NoneType' object has no attribute 'gradient_accumulation_kwargs'
FAILED tests/test_kto_trainer.py::KTOTrainerTester::test_kto_trainer_7_gpt2 - AttributeError: 'NoneType' object has no attribute 'gradient_accumulation_kwargs'
FAILED tests/test_kto_trainer.py::KTOTrainerTester::test_kto_trainer_bco_udm - AttributeError: 'NoneType' object has no attribute 'gradient_accumulation_kwargs'
FAILED tests/test_kto_trainer.py::KTOTrainerTester::test_kto_trainer_without_providing_ref_model - AttributeError: 'NoneType' object has no attribute 'gradient_accumulation_kwargs'
FAILED tests/test_kto_trainer.py::KTOTrainerTester::test_kto_trainer_without_providing_ref_model_with_lora - AttributeError: 'NoneType' object has no attribute 'gradient_accumulation_kwargs'
FAILED tests/test_kto_trainer.py::KTOTrainerTester::test_tokenize_and_process_tokens - AttributeError: 'NoneType' object has no attribute 'gradient_accumulation_kwargs'
FAILED tests/test_cli.py::test_sft_cli - AssertionError: An error occured while running the CLI, please double check
```
Can you have a look before we merge it? 🙏 Thanks !
hmm... I'll take a look in a few hours!
Not calling that fixed it. The tests now run properly on my machine :) @younesbelkada
good catch @seanexp 🥇
Woah, great catch! Thanks again for adding this nice feature!
Add the Binary Classifier Optimization (BCO) loss function from https://arxiv.org/abs/2404.04656.
Implements the BCE loss, reward shift, and underlying distribution matching.
Also adds an example script at examples/scripts/bco.py.
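For reference, the BCO objective from the paper treats alignment as binary classification of the implicit reward: desirable completions are labeled 1, undesirable ones 0, and the reward is shifted by a tracked mean before applying binary cross-entropy. A simplified pure-Python sketch of that idea (the actual trainer operates on batched tensors and tracks the shift as a running mean):

```python
import math


def bco_loss(policy_logps, ref_logps, labels, beta=0.1, delta=0.0):
    """Binary cross-entropy over shifted implicit rewards (simplified sketch).

    policy_logps / ref_logps: per-sequence log-probs under policy and reference.
    labels: 1 for desirable completions, 0 for undesirable ones.
    delta: the reward shift; in practice a running mean of observed rewards.
    """
    total = 0.0
    for lp, ref_lp, y in zip(policy_logps, ref_logps, labels):
        reward = beta * (lp - ref_lp) - delta  # shifted implicit reward
        z = reward if y == 1 else -reward      # flip sign for the 0 label
        total += math.log1p(math.exp(-z))      # -log(sigmoid(z)), stable for z > 0
    return total / len(labels)
```

The loss falls as the policy assigns relatively more likelihood than the reference to desirable completions and relatively less to undesirable ones, which is the classification view of alignment that BCO optimizes.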