
Add CPOTrainer #1382

Merged · Mar 22, 2024 · 35 commits
Changes from 26 commits
Commits
35 commits
9d45ce3
add CPOTrainer
fe1ixxu Feb 29, 2024
ed1e7c6
add docs
kashif Mar 4, 2024
0e76df1
fix formatting
kashif Mar 4, 2024
8c95468
removed precompute_ref_log_probs arg
kashif Mar 4, 2024
551a0df
remove precompute_ref_log_probs
kashif Mar 4, 2024
4850ac9
typos
kashif Mar 7, 2024
7c4e6b4
finish cpo trainer doc
fe1ixxu Mar 7, 2024
ba760de
remove redundant lines
fe1ixxu Mar 7, 2024
44256a5
typo
fe1ixxu Mar 7, 2024
f4d07cb
formatting
kashif Mar 15, 2024
8ca08f8
compute chosen nll loss also for enc-dec models
fe1ixxu Mar 15, 2024
4d2c292
fix gradient error of inplace operation for enc-dec models
fe1ixxu Mar 16, 2024
abd4a45
formatting
kashif Mar 17, 2024
2137ac9
use CPOConfig
kashif Mar 17, 2024
606f294
formatting
kashif Mar 17, 2024
afd02c2
use model_init_kwargs from CPOConfig
kashif Mar 17, 2024
a1836ac
comments in example
kashif Mar 17, 2024
da4a4ee
fix doc string
kashif Mar 17, 2024
ef1cd23
fix typo in docstring
kashif Mar 17, 2024
2aabcf4
update year
kashif Mar 17, 2024
18975f1
Merge branch 'huggingface:main' into cpo-trainer
kashif Mar 17, 2024
bcb6cd7
fixed typo
kashif Mar 17, 2024
27672c3
Merge branch 'main' into cpo-trainer
kashif Mar 18, 2024
39c8e61
use preference dataset
kashif Mar 18, 2024
93591bd
fix learning rate
kashif Mar 19, 2024
6b24fd1
move dataset_num_proc to configs
kashif Mar 19, 2024
6def5dd
Update cpo paper link from HF: cpo_trainer.mdx
fe1ixxu Mar 21, 2024
d9a48de
update description for CPO: cpo_trainer.mdx
fe1ixxu Mar 21, 2024
b300473
remove _prepare_deepspeed for cpo
fe1ixxu Mar 21, 2024
4a0dcd0
Add explanation to CPO loss
fe1ixxu Mar 21, 2024
25c5495
format
fe1ixxu Mar 21, 2024
328434e
fix bug when lengths are given
kashif Mar 22, 2024
dd9344a
Merge remote-tracking branch 'upstream/main' into cpo-trainer
kashif Mar 22, 2024
8c842be
add CPOTrainer to README
kashif Mar 22, 2024
2ff65bd
fix grammer
kashif Mar 22, 2024
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -33,6 +33,8 @@
title: DPO Trainer
- local: kto_trainer
title: KTO Trainer
- local: cpo_trainer
title: CPO Trainer
- local: ddpo_trainer
title: Denoising Diffusion Policy Optimization
- local: iterative_sft_trainer
102 changes: 102 additions & 0 deletions docs/source/cpo_trainer.mdx
@@ -0,0 +1,102 @@
# CPO Trainer

Contrastive Preference Optimization (CPO) was introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://arxiv.org/pdf/2401.08417.pdf) by Haoran Xu, Amr Sharaf, Yunmo Chen, Weiting Tan, Lingfeng Shen, Benjamin Van Durme, Kenton Murray, and Young Jin Kim. At a high level, CPO trains models to avoid generating translations that are adequate but not perfect in the machine translation (MT) task.

CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT's methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Second, SFT lacks a mechanism to teach the model to reject mistakes in translations. The CPO objective is derived from the DPO objective.
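
As a sketch following the paper's formulation (the TRL implementation may differ in notation and details), the CPO loss combines a reference-free preference term with a negative log-likelihood term on the chosen responses:

$$
\mathcal{L}_{\text{CPO}} = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \Big[ \log \sigma \big( \beta \log \pi_\theta(y_w \mid x) - \beta \log \pi_\theta(y_l \mid x) \big) \Big] \; - \; \mathbb{E}_{(x, y_w) \sim \mathcal{D}} \big[ \log \pi_\theta(y_w \mid x) \big]
$$

where \\(y_w\\) and \\(y_l\\) are the chosen and rejected responses for prompt \\(x\\), \\(\pi_\theta\\) is the policy being trained, and \\(\sigma\\) is the sigmoid function. Unlike DPO, no frozen reference policy appears in the objective.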

## Expected dataset format

The CPO trainer expects a dataset format identical to that of the DPO trainer, with three entries named as follows:

- `prompt`
- `chosen`
- `rejected`

for example:

```py
cpo_dataset_dict = {
"prompt": [
"hello",
"how are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
],
"chosen": [
"hi nice to meet you",
"I am fine",
"My name is Mary",
"My name is Mary",
"Python",
"Python",
"Java",
],
"rejected": [
"leave me alone",
"I am not fine",
"Whats it to you?",
"I dont have a name",
"Javascript",
"C++",
"C++",
],
}
```
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses, and `rejected` contains the corresponding negative (rejected) responses. As can be seen, a prompt can have multiple responses, and this is reflected by the repeated entries in the dictionary's value arrays.


## Expected model format
The CPO trainer expects a model of type `AutoModelForCausalLM`, in contrast to PPO, which expects `AutoModelForCausalLMWithValueHead` for the value function.
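
For example, a causal language model and its tokenizer can be loaded as follows (a minimal sketch; `gpt2` is only an illustrative checkpoint):

```py
from transformers import AutoModelForCausalLM, AutoTokenizer

# any causal LM checkpoint works; "gpt2" is used here purely for illustration
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# chosen/rejected pairs are padded during training, so a pad token must be defined
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
```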

## Using the `CPOTrainer`
For a detailed example, have a look at the `examples/scripts/cpo.py` script. At a high level, we need to initialize the `CPOTrainer` with the `model` we wish to train. **Note that the CPOTrainer eliminates the need for a reference model, simplifying the optimization process.** The `beta` hyperparameter controls the implicit reward, and the dataset contains the three entries listed above.

```py
cpo_config = CPOConfig(
beta=0.1,
)

cpo_trainer = CPOTrainer(
model,
args=cpo_config,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
```
After this, one can then call:

```py
cpo_trainer.train()
```

## Loss functions

Given the preference data, the `CPOTrainer` uses a sigmoid loss on the normalized likelihood via `logsigmoid` to fit a logistic regression.

The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://arxiv.org/abs/2305.10425) paper. The `CPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument, in which case `beta` is the reciprocal of the margin.

The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of preference optimization algorithms, identify an issue with overfitting, and propose an alternative loss, which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, so the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over the log-likelihoods of the completion (unlike CPO, which only sums over them).
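
As a quick illustration, selecting a loss amounts to setting `loss_type` in the config (a minimal sketch; the `beta` values shown are illustrative, and the default loss is the sigmoid loss described above):

```py
from trl import CPOConfig

# default sigmoid (logistic) loss
config_sigmoid = CPOConfig(beta=0.1)

# hinge loss from SLiC/RSO; here beta is the reciprocal of the margin
config_hinge = CPOConfig(beta=0.1, loss_type="hinge")

# IPO loss; here beta is the reciprocal of the log-likelihood-ratio gap
config_ipo = CPOConfig(beta=0.1, loss_type="ipo")
```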


## Logging

During training and evaluation we record the following reward metrics:

* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
* `rewards/accuracies`: the mean of how often the chosen rewards are greater than the corresponding rejected rewards
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
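
As an illustrative sketch of how these reward metrics relate to each other (the tensor names below are hypothetical and not the trainer's internal variables):

```py
import torch

beta = 0.1
# hypothetical per-sequence log probabilities of the policy model
policy_chosen_logps = torch.tensor([-12.3, -9.8, -20.1])
policy_rejected_logps = torch.tensor([-15.1, -11.2, -18.4])

chosen_rewards = beta * policy_chosen_logps        # rewards/chosen
rejected_rewards = beta * policy_rejected_logps    # rewards/rejected
accuracies = (chosen_rewards > rejected_rewards).float().mean()  # rewards/accuracies
margins = (chosen_rewards - rejected_rewards).mean()             # rewards/margins
```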

## CPOTrainer

[[autodoc]] CPOTrainer


## CPOConfig

[[autodoc]] CPOConfig
121 changes: 121 additions & 0 deletions examples/scripts/cpo.py
@@ -0,0 +1,121 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the CPO training script with one of the following commands, which include some example arguments.
In general, the optimal configuration for CPO will be similar to that of DPO:

# regular:
python examples/scripts/cpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-6 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-aligned-cpo" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns

# peft:
python examples/scripts/cpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-5 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-lora-aligned-cpo" \
--optim rmsprop \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
"""

import multiprocessing
from dataclasses import dataclass, field

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import CPOConfig, CPOTrainer, ModelConfig, get_peft_config


@dataclass
class ScriptArguments:
dataset: str = field(
default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The name of the dataset to use."}
)


if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig))
args, cpo_args, model_config = parser.parse_args_into_dataclasses()

################
# Model & Tokenizer
################
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path)
peft_config = get_peft_config(model_config)
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

################
# Dataset
################
ds = load_dataset(args.dataset)
if cpo_args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

def process(row):
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row

ds = ds.map(
process,
num_proc=1 if cpo_args.debug else multiprocessing.cpu_count(),
load_from_cache_file=False,
)
train_dataset = ds["train"]
eval_dataset = ds["test"]

################
# Training
################
trainer = CPOTrainer(
model,
args=cpo_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=peft_config,
)

# train and save the model
trainer.train()
trainer.save_model(cpo_args.output_dir)
3 changes: 2 additions & 1 deletion examples/scripts/kto.py
@@ -148,5 +148,6 @@ def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_d
peft_config=get_peft_config(model_args),
)

# 5. train
# 5. train and save the model
kto_trainer.train()
kto_trainer.save_model(kto_args.output_dir)