Update kto_trainer.mdx with hparam recs
kawine authored Oct 1, 2024
1 parent 847790d commit 4a8645a
Showing 1 changed file with 9 additions and 3 deletions: docs/source/kto_trainer.mdx
@@ -5,6 +5,7 @@ For a full example have a look at [`examples/scripts/kto.py`].

Depending on how good your base model is, you may or may not need to do SFT before KTO.
This is different from standard RLHF and DPO, which always require SFT.
You can also train with imbalanced data (more desirable than undesirable examples, or vice versa), but you will need to adjust hyperparameters accordingly (see below).

## Expected dataset format

@@ -49,7 +50,8 @@ kto_dataset_dict = {
```

where the `prompt` contains the context inputs, `completion` contains the corresponding responses, and `label` contains a flag indicating whether the completion is desired (`True`) or undesired (`False`).
A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. It is required that the dataset contains at least one desirable and one undesirable completion.
A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
In principle, the dataset should contain at least one desirable and one undesirable completion; in practice, however, some people have had success running KTO on _only_ desirable or undesirable data (in the latter case, it is best to use a conservative learning rate).
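For illustration, a minimal dictionary in this format might look as follows (the strings are hypothetical examples, not taken from the library):

```py
# A minimal sketch of the expected format (hypothetical example data).
# Note how the same prompt appears twice because it has two completions,
# one desirable (True) and one undesirable (False).
kto_dataset_dict = {
    "prompt": [
        "What is the capital of France?",
        "What is the capital of France?",
    ],
    "completion": [
        "The capital of France is Paris.",
        "I don't know.",
    ],
    "label": [
        True,
        False,
    ],
}
```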


## Expected model format
@@ -59,13 +61,17 @@ The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that

For a detailed example have a look at the `examples/scripts/kto.py` script. At a high level we need to initialize the `KTOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.

The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
The `beta` refers to the hyperparameter that controls how quickly the loss saturates, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (i.e., decoder-only or encoder-decoder).

The `desirable_weight` and `undesirable_weight` refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3.
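As a hypothetical worked example (plain arithmetic, not TRL code): suppose you have 600 desirable and 200 undesirable completions. Keeping `desirable_weight = 1.0`, the ratio constraint pins `undesirable_weight` to the interval [2.25, 3]:

```py
# Hypothetical worked example: pick undesirable_weight so that
# (desirable_weight * n_desirable) : (undesirable_weight * n_undesirable)
# lies between 1:1 and 4:3.
n_desirable, n_undesirable = 600, 200
desirable_weight = 1.0

base = desirable_weight * n_desirable / n_undesirable  # = 3.0
min_weight = base / (4 / 3)  # ratio of exactly 4:3 -> 2.25
max_weight = base / 1.0      # ratio of exactly 1:1 -> 3.0
print(f"undesirable_weight should lie in [{min_weight}, {max_weight}]")
```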

<Tip>
It is strongly recommended you use a learning rate between `5e-7` and `5e-6` with an effective batch size between `8` and `32`, for both LoRA and full finetuning. Even if you are working with a small dataset, we do not recommend using a learning rate outside this range; instead, using smaller batch sizes and/or more training epochs will give you better results.
Every choice of `beta` has a maximum learning rate it will tolerate before learning degenerates. For the default `beta = 0.1`, this learning rate is `1e-6` for most models. The lower `beta` is, the lower your learning rate should be. In general, we strongly recommend a learning rate between `5e-7` and `5e-6`. Even if you are working with a small dataset, we do not recommend using a learning rate outside this range; instead, use more epochs.
</Tip>

<Tip>
Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. Even if your effective batch size is large, if your per-step batch size is too small, the KL estimate in KTO will be poor.
</Tip>
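In concrete terms (a hypothetical sketch; the values are illustrative and the argument names follow the standard `transformers` training arguments):

```py
# Effective batch size = per-device batch size * gradient accumulation steps
#                        * number of processes.
per_device_train_batch_size = 8   # per-step batch size, keep at least 4
gradient_accumulation_steps = 2
num_processes = 2                 # e.g. 2 GPUs (hypothetical setup)

effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * num_processes
)
assert 16 <= effective_batch_size <= 128  # 32 here, within the recommended range
```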

```py
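# A hedged sketch of the initialization described in this section; the exact
# argument names (e.g. `tokenizer`, `train_dataset`) are assumed from the
# TRL API rather than taken verbatim.
training_args = KTOConfig(
    beta=0.1,               # controls how quickly the loss saturates
    desirable_weight=1.0,   # upweight the rarer label type if data is imbalanced
    undesirable_weight=1.0,
    learning_rate=5e-7,     # within the recommended 5e-7 to 5e-6 range
    output_dir="./kto-model",
)

kto_trainer = KTOTrainer(
    model,                  # the model to train
    ref_model,              # reference model used for the implicit rewards
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
kto_trainer.train()
```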
