[KTO] fix interleaving, reporting, hanging bugs #1499

Merged · 64 commits · Apr 3, 2024

Changes from all commits
6ee3be4
add warning for imbalanced data
kawine Feb 25, 2024
22dd810
update documentation
kawine Feb 25, 2024
8d14930
update script commands to be same as in dpo
kawine Feb 25, 2024
8a490af
use batch_size KL examples and batch_size target examples to calculat…
kawine Feb 25, 2024
f826600
fix deepspeed issue
kawine Feb 25, 2024
688ed6c
speed up forward with no_grad for KL
kawine Feb 26, 2024
587517b
Merge branch 'huggingface:main' into main
kawine Feb 28, 2024
e128f09
add some removed metrics
kawine Feb 28, 2024
2d860b8
Update trl/trainer/kto_trainer.py
kashif Feb 28, 2024
48d25ff
Update trl/trainer/kto_trainer.py
kashif Feb 28, 2024
392bcc0
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
a42049f
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
5696814
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
000d5d8
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
2738d1f
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
d7f63c5
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
824da55
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
4399af4
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
69094be
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
73f7ed7
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
5b95aca
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
3102901
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
ca68f24
add more detailed comments
kawine Feb 28, 2024
94fb375
convert assert to ValueError
kawine Feb 28, 2024
8f7e788
Update kto_trainer.py
kawine Feb 29, 2024
ed19ed5
precommit formatting
kawine Feb 29, 2024
310bd97
Merge branch 'main' of https://github.com/kawine/trl into main
kawine Feb 29, 2024
639f4de
Merge branch 'huggingface:main' into main
kawine Mar 1, 2024
ee7d6a4
remove nans in metrics by gathering across machines
kawine Mar 1, 2024
7ae95c2
fix formatting
kawine Mar 1, 2024
1b96b2d
fix choice of mismatched examples for KL term
kawine Mar 4, 2024
81b60da
describe weights
kawine Mar 4, 2024
1f145b9
fix hanging issue in distributed training
kawine Mar 6, 2024
83ed882
linting
kawine Mar 7, 2024
9c5480d
Merge branch 'main' of https://github.com/kawine/trl into main
kawine Mar 7, 2024
15251ff
move metrics to cpu
kawine Mar 7, 2024
8f9fdfe
Update trl/trainer/kto_trainer.py
kawine Mar 7, 2024
600aad8
Update trl/trainer/kto_trainer.py
kashif Mar 8, 2024
8b5367e
Update trl/trainer/kto_trainer.py
kashif Mar 8, 2024
5cc6fed
Merge branch 'huggingface:main' into main
kawine Mar 9, 2024
03dfe90
Merge branch 'huggingface:main' into main
kawine Mar 11, 2024
1680de6
fix tokenization error: lack of bos
kawine Mar 11, 2024
80fa86d
change user warning for weight hyperparams
kawine Mar 11, 2024
8f112ce
minor update to docs
kawine Mar 11, 2024
0cc2d8f
reshape attention mask
kawine Mar 12, 2024
eed3044
reformat
kawine Mar 12, 2024
5d7fdd1
Merge branch 'main' of https://github.com/kawine/trl into main
kawine Mar 12, 2024
0bfd326
add test for bos/eos tokens
kawine Mar 12, 2024
86af5dc
Merge branch 'huggingface:main' into main
kawine Mar 14, 2024
a1dfa81
move dependency location
kawine Mar 14, 2024
19afc89
Update tests/test_kto_trainer.py
kashif Mar 14, 2024
2d4039e
Merge branch 'huggingface:main' into main
kawine Mar 24, 2024
bbd5715
don't report nan metrics
kawine Mar 24, 2024
f603aeb
Merge branch 'main' of https://github.com/kawine/trl into main
kawine Mar 24, 2024
856b796
don't report nan metrics and remove data interleaving
kawine Mar 24, 2024
7f0bea8
merge latest changes in trl
kawine Mar 31, 2024
8cf28a6
fix bugs in calculating metrics
kawine Mar 31, 2024
aef50f1
no need to gather KL term
kawine Mar 31, 2024
3e10bae
minor changes
kawine Apr 1, 2024
2a38b15
use nanmean for losses
kawine Apr 1, 2024
e1b6132
Merge branch 'huggingface:main' into main
kawine Apr 2, 2024
7130212
remove disabling of wandb
kawine Apr 2, 2024
2fb641f
revert changes
kawine Apr 3, 2024
0b44e42
Merge branch 'main' of https://github.com/kawine/trl into main
kawine Apr 3, 2024
109 changes: 57 additions & 52 deletions trl/trainer/kto_trainer.py
@@ -28,7 +28,7 @@
import torch.nn.functional as F
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available, tqdm
- from datasets import Dataset, concatenate_datasets, interleave_datasets
+ from datasets import Dataset, concatenate_datasets
from torch.utils.data import DataLoader, SequentialSampler
from transformers import (
AutoModelForCausalLM,
@@ -485,6 +485,10 @@ def make_inputs_require_grad(module, input, output):
self.undesirable_weight = args.undesirable_weight

with PartialState().local_main_process_first():
+ # Shuffle the datasets
+ train_dataset = train_dataset.shuffle(seed=args.data_seed)
+ if eval_dataset is not None:
+ eval_dataset = eval_dataset.shuffle(seed=args.data_seed)
# Tokenize and prepare the training datasets
train_dataset = train_dataset.map(
_tokenize,
@@ -500,8 +504,8 @@
raise ValueError(
"Batch size is 1 (too small). KTO will not work properly because the KL term will be equivalent to the implied reward."
)
- # Note: for best results, mismatched outputs y' used to estimate the KL term for a batch should be the
- # same as the matched outputs y used to estimate the rewards in that batch, just paired with different x
+ # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
+ # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
train_kl_dataset = train_dataset.map(
_get_kl_dataset, batched=True, batch_size=total_batch_size, desc="Extracting KL train dataset"
)
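For context on the new comment: the mismatched pairs are built per batch by reversing the order of the completions relative to the prompts, so every prompt keeps its position but is scored against another example's completion. This is also why the ValueError above rejects a total batch size of 1: with a single example, the "mismatched" completion is the example's own completion, and the KL estimate collapses into the implied reward. A minimal sketch of the idea (the column names here are illustrative, not necessarily the ones produced by trl's _tokenize):

def get_kl_batch(batch):
    # Reverse the completion columns within the batch:
    # (x_1, y_1), ..., (x_n, y_n) -> (x_1, y_n), ..., (x_n, y_1)
    for key in ("completion_input_ids", "completion_attention_mask"):
        if key in batch:
            batch[key] = batch[key][::-1]
    return batch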
@@ -601,30 +605,12 @@ def make_inputs_require_grad(module, input, output):
UserWarning,
)

- # split the dataset and interleave them together with equal probability of choosing chosen or rejected
- interleaved_train_dataset = interleave_datasets(
- [desirable, undesirable],
- stopping_strategy="all_exhausted",
- )
- interleaved_train_dataset = interleaved_train_dataset.shuffle(seed=args.data_seed)

- if eval_dataset is not None:
- interleaved_eval_dataset = interleave_datasets(
- [
- eval_dataset.filter(lambda x: x["label"], num_proc=args.dataset_num_proc),
- eval_dataset.filter(lambda x: not x["label"], num_proc=args.dataset_num_proc),
- ],
- stopping_strategy="all_exhausted",
- )
- else:
- interleaved_eval_dataset = None

super().__init__(
model=model,
args=args,
data_collator=data_collator,
- train_dataset=interleaved_train_dataset,
- eval_dataset=interleaved_eval_dataset,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
tokenizer=tokenizer,
model_init=model_init,
compute_metrics=compute_metrics,
@@ -974,11 +960,11 @@ def kto_loss(
"""Compute the KTO loss for a batch of policy and reference model log probabilities.

Args:
- policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
- policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
- reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
- reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
+ reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
+ reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)

Returns:
@@ -996,17 +982,17 @@
chosen_rewards = self.beta * chosen_logratios.detach()
else:
# lists can't be empty -- if they are, then accelerate.gather will hang
- chosen_losses = torch.Tensor([torch.nan]).to(self.accelerator.device)
- chosen_rewards = torch.Tensor([torch.nan]).to(self.accelerator.device)
+ chosen_losses = torch.Tensor([]).to(self.accelerator.device)
+ chosen_rewards = torch.Tensor([]).to(self.accelerator.device)

if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
rejected_logratios = policy_rejected_logps - reference_rejected_logps
rejected_losses = 1 - F.sigmoid(self.beta * (KL - rejected_logratios))
rejected_rewards = self.beta * rejected_logratios.detach()
else:
# lists can't be empty -- if they are, then accelerate.gather will hang
- rejected_losses = torch.Tensor([torch.nan]).to(self.accelerator.device)
- rejected_rewards = torch.Tensor([torch.nan]).to(self.accelerator.device)
+ rejected_losses = torch.Tensor([]).to(self.accelerator.device)
+ rejected_rewards = torch.Tensor([]).to(self.accelerator.device)

losses = torch.cat(
(self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
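For reference, the losses concatenated here implement the KTO objective as written in this hunk: with r = log(pi_theta(y|x) / pi_ref(y|x)) for a completion y, and z_ref the batch-level KL estimate computed earlier in kto_loss (already gathered across processes; not shown in this hunk),

loss(desirable y) = desirable_weight * (1 - sigmoid(beta * (r - z_ref)))
loss(undesirable y) = undesirable_weight * (1 - sigmoid(beta * (z_ref - r)))

With the empty tensors above, a batch that contains only one class simply contributes nothing for the missing class instead of a NaN placeholder.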
@@ -1061,7 +1047,7 @@ def get_batch_loss_metrics(
reference_KL_logps,
) = self.forward(self.ref_model, batch)

- losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
+ losses, chosen_rewards, rejected_rewards, KL = self.kto_loss(
policy_chosen_logps,
policy_rejected_logps,
policy_KL_logps,
@@ -1070,25 +1056,38 @@
reference_KL_logps,
)

- mean_chosen_reward = chosen_rewards.nanmean().detach()
- mean_rejected_reward = rejected_rewards.nanmean().detach()
- mean_chosen_logps = policy_chosen_logps.nanmean().detach()
- mean_rejected_logps = policy_rejected_logps.nanmean().detach()
+ num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
+ num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

+ all_num_chosen = self.accelerator.gather(num_chosen)
+ all_num_rejected = self.accelerator.gather(num_rejected)

prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather(mean_chosen_reward).nanmean().cpu()
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather(mean_rejected_reward).nanmean().cpu()
metrics[f"{prefix}rewards/margins"] = metrics[f"{prefix}rewards/chosen"] - metrics[f"{prefix}rewards/rejected"]
metrics[f"{prefix}kl"] = kl.item() # has already been gathered in kto_loss
metrics[f"{prefix}logps/chosen"] = self.accelerator.gather(mean_chosen_logps).nanmean().cpu()
metrics[f"{prefix}logps/rejected"] = self.accelerator.gather(mean_rejected_logps).nanmean().cpu()

loss = (
losses.mean()
if losses.shape[0] != 0
else torch.tensor(float("nan"), requires_grad=True).to(self.accelerator.device)
)
return loss, metrics

+ if all_num_chosen.sum().item() > 0:
+ metrics[f"{prefix}rewards/chosen"] = (
+ (self.accelerator.gather(chosen_rewards.mean()) * all_num_chosen).nansum() / all_num_chosen.sum()
+ ).item()
+ metrics[f"{prefix}logps/chosen"] = (
+ (self.accelerator.gather(policy_chosen_logps.mean()) * all_num_chosen).nansum() / all_num_chosen.sum()
+ ).item()

+ if all_num_rejected.sum().item() > 0:
+ metrics[f"{prefix}rewards/rejected"] = (
+ (self.accelerator.gather(rejected_rewards.mean()) * all_num_rejected).nansum() / all_num_rejected.sum()
+ ).item()
+ metrics[f"{prefix}logps/rejected"] = (
+ (self.accelerator.gather(policy_rejected_logps.mean()) * all_num_rejected).nansum()
+ / all_num_rejected.sum()
+ ).item()

+ metrics[f"{prefix}kl"] = KL.item()
+ if all_num_chosen.sum().item() > 0 and all_num_rejected.sum().item() > 0:
+ metrics[f"{prefix}rewards/margins"] = (
+ metrics[f"{prefix}rewards/chosen"] - metrics[f"{prefix}rewards/rejected"]
+ )

+ return losses.nanmean(), metrics
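The reporting path added here gathers one per-process mean and one per-process count for each side and forms a count-weighted average, rather than gathering variable-length per-example tensors (the shape mismatch across processes is what the old NaN placeholders in kto_loss were guarding accelerate.gather against). A self-contained sketch of that aggregation, with illustrative values standing in for accelerate.gather:

import torch

def weighted_mean(all_means: torch.Tensor, all_counts: torch.Tensor) -> float:
    # nansum skips processes whose local mean is NaN because they saw no
    # chosen (or rejected) examples; weighting by counts recovers the global mean.
    return ((all_means * all_counts).nansum() / all_counts.sum()).item()

# e.g. process 0 averaged 2 chosen rewards to 0.5, process 1 saw none
# (NaN mean, count 0), process 2 averaged 4 chosen rewards to -0.25:
means = torch.tensor([0.5, float("nan"), -0.25])
counts = torch.tensor([2.0, 0.0, 4.0])
print(weighted_mean(means, counts))  # (2 * 0.5 + 4 * -0.25) / 6 = 0.0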

def compute_loss(
self,
@@ -1101,9 +1100,13 @@ def compute_loss(
"compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
+ compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

- loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
+ with compute_loss_context_manager():
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

+ # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
+ loss = loss.to(self.args.device)
# force log the metrics
if self.accelerator.is_main_process:
self.store_metrics(metrics, train_eval="train")
@@ -1114,10 +1117,12 @@

def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
for key, value in metrics.items():
- self._stored_metrics[train_eval][key].append(value)
+ if isinstance(value, list):
+ self._stored_metrics[train_eval][key].extend(value)
+ else:
+ self._stored_metrics[train_eval][key].append(value)

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
- # We use a sequential sampler for training as the order of the interleaved dataset is important
if self.train_dataset is None or not has_length(self.train_dataset):
return None
return SequentialSampler(self.train_dataset)