Fix FSDP error (huggingface#1196)
* Fix FSDP error

Fixes an error that occurs when the `loss` field of the model output is non-empty: indexing the output with `[0]` then returns the loss instead of the logits. This can happen with FSDP.

* Apply suggestions from code review

force return_dict

Co-authored-by: Younes Belkada <[email protected]>

---------

Co-authored-by: Younes Belkada <[email protected]>
2 people authored and Andrew Lapp committed May 10, 2024
1 parent 29cf6c8 commit f30932e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions trl/trainer/reward_trainer.py
@@ -220,11 +220,13 @@ def compute_loss(
         rewards_chosen = model(
             input_ids=inputs["input_ids_chosen"],
             attention_mask=inputs["attention_mask_chosen"],
-        )[0]
+            return_dict=True,
+        )["logits"]
         rewards_rejected = model(
             input_ids=inputs["input_ids_rejected"],
             attention_mask=inputs["attention_mask_rejected"],
-        )[0]
+            return_dict=True,
+        )["logits"]
         # calculate loss, optionally modulate with margin
         if "margin" in inputs:
             loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
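The failure mode described in the commit message can be illustrated with a small sketch. `TinyOutput` below is a hypothetical, simplified stand-in for a transformers-style model output and is not the library's class; it only assumes that integer indexing goes through a tuple view that skips `None` fields, which is why position 0 changes meaning once `loss` is populated, while key access stays stable.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class TinyOutput:
    """Hypothetical stand-in for a model output with optional fields."""

    loss: Optional[Any] = None
    logits: Optional[Any] = None

    def to_tuple(self):
        # Fields that are None are skipped, so positions shift
        # depending on which fields are populated.
        values = (getattr(self, f.name) for f in fields(self))
        return tuple(v for v in values if v is not None)

    def __getitem__(self, key):
        # String keys look up the field by name; integers index the tuple view.
        if isinstance(key, str):
            return getattr(self, key)
        return self.to_tuple()[key]


# Without a loss, position 0 is the logits.
no_loss = TinyOutput(logits=[0.7])
# With a loss present, position 0 is the loss instead.
with_loss = TinyOutput(loss=0.25, logits=[0.7])

first_without_loss = no_loss[0]
first_with_loss = with_loss[0]
logits_by_key = with_loss["logits"]
```

Accessing the result by name, as the commit does with `return_dict=True` and `["logits"]`, does not depend on which other fields happen to be set.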