From 989f4a458f70a8ab07b0444949ddef5097f4cd70 Mon Sep 17 00:00:00 2001 From: mgerstgrasser Date: Mon, 8 Jan 2024 16:52:25 -0800 Subject: [PATCH 1/2] Fix FSDP error Fixes error when `loss` field of model output is non-empty, and indexing as [0] returns loss instead of logits. Can happen with FSDP. --- trl/trainer/reward_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 44a5b79223..8b66f454d1 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -220,11 +220,11 @@ def compute_loss( rewards_chosen = model( input_ids=inputs["input_ids_chosen"], attention_mask=inputs["attention_mask_chosen"], - )[0] + )["logits"] rewards_rejected = model( input_ids=inputs["input_ids_rejected"], attention_mask=inputs["attention_mask_rejected"], - )[0] + )["logits"] # calculate loss, optionally modulate with margin if "margin" in inputs: loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean() From b4153bcbcaa4ae844d9f6230f2cbca5e13c7d84b Mon Sep 17 00:00:00 2001 From: mgerstgrasser Date: Tue, 9 Jan 2024 09:12:30 -0800 Subject: [PATCH 2/2] Apply suggestions from code review force return_dict Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- trl/trainer/reward_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 8b66f454d1..f2af50b634 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -220,10 +220,12 @@ def compute_loss( rewards_chosen = model( input_ids=inputs["input_ids_chosen"], attention_mask=inputs["attention_mask_chosen"], + return_dict=True, )["logits"] rewards_rejected = model( input_ids=inputs["input_ids_rejected"], attention_mask=inputs["attention_mask_rejected"], + return_dict=True, )["logits"] # calculate loss, optionally modulate with margin if "margin" in inputs: