Commit

🪆 Fix for Incorrect ValueError Handling in reward_weights in grpo_trainer.py (#2843)

- Fixed a bug where an extra `len` call inside the error message caused a `TypeError` instead of the expected `ValueError`.
- Replaced `len(len(args.reward_weights))` with the correct `len(args.reward_weights)` to properly calculate the number of reward weights.
- Ensured that a `ValueError` is now raised with an accurate and clear message when the number of reward weights does not match the number of reward functions.

This fix prevents confusion during debugging and ensures proper error handling during validation.
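
For context, here is a minimal standalone sketch (plain Python, not the trainer code) of why the original expression failed: the inner `len()` returns an int, and calling `len()` on an int raises a `TypeError` before the intended `ValueError` message can even be formatted.

```python
# Minimal repro of the bug, outside the trainer: the inner len() returns an int,
# and len() of an int raises TypeError, masking the intended ValueError.
weights = [0.5, 0.5]

try:
    msg = f"Number of reward weights ({len(len(weights))}) ..."  # buggy double len()
except TypeError as err:
    print(err)  # object of type 'int' has no len()

# The corrected expression formats cleanly with a single len() call:
print(f"Number of reward weights ({len(weights)}) ...")
```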

Tested with cases where:
- `args.reward_weights` is None (default case).
- `args.reward_weights` has mismatched lengths with `reward_funcs`.
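
As a rough illustration of those two paths, here is a self-contained sketch that mirrors the validation in the diff below; `resolve_reward_weights` is a hypothetical stand-in rather than the actual TRL API, and the uniform-weight default for the `None` case is an assumption made for this example.

```python
# Hypothetical stand-in that mirrors the validation logic from the diff; not the TRL API.
import torch

def resolve_reward_weights(reward_weights, reward_funcs):
    if reward_weights is not None:
        if len(reward_weights) != len(reward_funcs):
            raise ValueError(
                f"Number of reward weights ({len(reward_weights)}) must match number of reward "
                f"functions ({len(reward_funcs)})"
            )
        return torch.tensor(reward_weights, dtype=torch.float32)
    # Assumed default for this sketch: equal weighting when no weights are given.
    return torch.ones(len(reward_funcs), dtype=torch.float32)

reward_funcs = [lambda completions: 0.0, lambda completions: 1.0]

# Case 1: reward_weights is None (default) -> no error, uniform weights.
print(resolve_reward_weights(None, reward_funcs))

# Case 2: mismatched lengths -> ValueError with the corrected, readable message.
try:
    resolve_reward_weights([0.5], reward_funcs)
except ValueError as err:
    print(err)
```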
loveychen authored Feb 13, 2025
1 parent b0f513c commit 8830786
2 changes: 1 addition & 1 deletion trl/trainer/grpo_trainer.py
@@ -275,7 +275,7 @@ def __init__(
         if args.reward_weights is not None:
             if len(args.reward_weights) != len(reward_funcs):
                 raise ValueError(
-                    f"Number of reward weights ({len(len(args.reward_weights))}) must match number of reward "
+                    f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
                     f"functions ({len(reward_funcs)})"
                 )
             self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
