From 0a830b9d4b300fc63e4806f42b884fa742d14798 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 28 Dec 2024 14:22:08 +0100 Subject: [PATCH 1/2] revert orpo changes --- trl/trainer/orpo_trainer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 65d80802be..344b79b580 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -775,12 +775,17 @@ def cross_entropy_loss(logits, labels): loss = loss_fct(logits, labels) return loss - labels = concatenated_batch["concatenated_labels"].clone() + if self.is_encoder_decoder: + labels = concatenated_batch["concatenated_labels"].clone() + else: + labels = concatenated_batch["concatenated_input_ids"].clone() + attention_mask = concatenated_batch["concatenated_attention_mask"] + labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) all_logps = self.get_batch_logps( all_logits, - labels, + concatenated_batch["concatenated_labels"], average_log_prob=True, is_encoder_decoder=self.is_encoder_decoder, label_pad_token_id=self.label_pad_token_id, From 8e7648b54b9da2c215ed93823ce8fc255285c0f7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 8 Jan 2025 15:50:00 +0100 Subject: [PATCH 2/2] add comment --- trl/trainer/orpo_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index ddac34555f..803bda6699 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -784,6 +784,7 @@ def cross_entropy_loss(logits, labels): labels = concatenated_batch["concatenated_input_ids"].clone() attention_mask = concatenated_batch["concatenated_attention_mask"] labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) + # orpo chosen nll loss is computed over the full prompt and response chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) all_logps = self.get_batch_logps(