force return_dict
mgerstgrasser committed Jan 9, 2024
1 parent 989f4a4 commit c9ed225
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions trl/trainer/reward_trainer.py
@@ -220,10 +220,12 @@ def compute_loss(
         rewards_chosen = model(
             input_ids=inputs["input_ids_chosen"],
             attention_mask=inputs["attention_mask_chosen"],
+            return_dict=True,
         )["logits"]
         rewards_rejected = model(
             input_ids=inputs["input_ids_rejected"],
             attention_mask=inputs["attention_mask_rejected"],
+            return_dict=True,
         )["logits"]
         # calculate loss, optionally modulate with margin
         if "margin" in inputs:
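
The commit message does not say why `return_dict=True` is forced. A likely reason is that the result is indexed with `["logits"]`, which only works on a dict-like output. The sketch below illustrates that failure mode with a toy model. `ToyRewardModel` and `score` are hypothetical names invented for this illustration and are not part of trl or transformers; the toy only assumes the common convention that a model returns a plain tuple when `return_dict` is false.

```python
from collections import OrderedDict


class ToyRewardModel:
    """Hypothetical stand-in for a reward model (not a real trl/transformers class).

    Returns a dict-like output when return_dict is true and a plain tuple
    otherwise, falling back to a config-style default when it is not passed.
    """

    def __init__(self, config_return_dict=True):
        self.config_return_dict = config_return_dict

    def __call__(self, input_ids, attention_mask=None, return_dict=None):
        # Use the configured default when the caller does not specify.
        if return_dict is None:
            return_dict = self.config_return_dict
        # Dummy "reward": the sum of token ids in each sequence.
        logits = [float(sum(ids)) for ids in input_ids]
        if return_dict:
            return OrderedDict(logits=logits)
        return (logits,)


def score(model, input_ids, force_return_dict):
    """Index the output by key, with or without forcing return_dict=True."""
    kwargs = {"return_dict": True} if force_return_dict else {}
    return model(input_ids=input_ids, **kwargs)["logits"]


# A model whose default is tuple output, as if configured with return_dict=False.
model = ToyRewardModel(config_return_dict=False)
batch = [[1, 2, 3], [4, 5]]

# Forcing return_dict=True makes key-based indexing work regardless of the default.
forced = score(model, batch, force_return_dict=True)

# Without forcing it, indexing a tuple with a string key raises TypeError.
try:
    score(model, batch, force_return_dict=False)
    unforced_error = None
except TypeError as exc:
    unforced_error = exc
```

Passing the flag at the call site, as the diff does, keeps the `["logits"]` lookup independent of how the wrapped model happens to be configured.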