
Optional Additional Loss to Center Reward Models' Outputs #1932

Merged: 13 commits merged into huggingface:main on Aug 17, 2024
Conversation

@RylanSchaeffer (Contributor) commented Aug 15, 2024

In issue #1931, I requested an optional additional loss to center the rewards output by a trained reward model.

This PR is a sketch of what this might look like.

Please let me know if this seems sensible and what changes, if any, are appropriate.
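For context, here is a minimal sketch of what such an auxiliary centering term could look like on top of the usual pairwise reward loss. This is illustrative only (the function and argument names are placeholders, not the exact diff in this PR); the idea is to penalize the batch mean of the rewards drifting away from zero.

import torch
import torch.nn.functional as F

def pairwise_reward_loss(rewards_chosen, rewards_rejected, center_rewards_coefficient=None):
    # Standard Bradley-Terry style loss on the chosen/rejected reward margin.
    loss = -F.logsigmoid(rewards_chosen - rewards_rejected).mean()
    if center_rewards_coefficient is not None:
        # Optional auxiliary term: penalize the squared mean reward so the
        # trained reward model's outputs stay centered around zero.
        loss = loss + center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
    return loss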

@qgallouedec linked an issue on Aug 17, 2024 that may be closed by this pull request
@qgallouedec (Member)
I'm currently training a model with this new parameter to see what it looks like.

@RylanSchaeffer (Contributor, Author) commented Aug 17, 2024 via email

Co-authored-by: Quentin Gallouédec <[email protected]>
@qgallouedec (Member)
> I'm hesitant to set the default.

I wouldn't set it as default either.

@qgallouedec (Member)
This one #1932 (comment) is about suggesting a value in the doc, not setting it as default.

@RylanSchaeffer (Contributor, Author)
Right. I was more asking you for guidance about what default value we should choose. I think you and I are on the same page.

@RylanSchaeffer (Contributor, Author) commented Aug 17, 2024

@qgallouedec for my education and for posterity's sake, can you share your experimental results here (once completed)?

@qgallouedec (Member) commented Aug 17, 2024

Sure!

Training

Here are the wandb runs:

- center_rewards_coefficient = None: https://wandb.ai/huggingface/trl/runs/u6zob8ml (brown)
- center_rewards_coefficient = 0.01: https://wandb.ai/huggingface/trl/runs/d73qlevz (green)
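For reference, the new option amounts to a single config value. A hedged sketch of how it might be passed (the parameter name is taken from this thread; the other argument is a placeholder and the exact API may differ by TRL version):

from trl import RewardConfig

training_args = RewardConfig(
    output_dir="reward_modeling_anthropic_hh_crc",   # placeholder output directory
    center_rewards_coefficient=0.01,                  # None disables the extra centering loss
)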

(Screenshot, Aug 17 2024: wandb loss curves for the two runs.)

As expected, the loss is a bit larger with the centering term.

Playing with the trained reward model

- center_rewards_coefficient = None: https://huggingface.co/qgallouedec/reward_modeling_anthropic_hh
- center_rewards_coefficient = 0.01: https://huggingface.co/qgallouedec/reward_modeling_anthropic_hh_crc

from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "qgallouedec/reward_modeling_anthropic_hh"  # without the coef
# model_id = "qgallouedec/reward_modeling_anthropic_hh_crc"  # with the coef

tokenizer = AutoTokenizer.from_pretrained(model_id)
# The reward model is a sequence classifier with a single output logit (the scalar reward).
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1)

dataset = load_dataset("Anthropic/hh-rlhf", split="test")

# Take a small batch of preference pairs from the test split.
examples = dataset[:8]

input_chosen = tokenizer(examples["chosen"], return_tensors="pt", padding=True)
input_rejected = tokenizer(examples["rejected"], return_tensors="pt", padding=True)

output_chosen = model(**input_chosen)
output_rejected = model(**input_rejected)

# Average reward over the batch for chosen and rejected completions;
# with the centering coefficient these should sit closer to zero.
mean_chosen = output_chosen.logits.mean().item()
mean_rejected = output_rejected.logits.mean().item()
print(mean_chosen, mean_rejected)
| center_rewards_coefficient | Rejected | Chosen  |
|----------------------------|----------|---------|
| None                       | -3.3083  | -2.1824 |
| 0.01                       | -0.6140  | 0.2871  |

So overall it's looking good!

@qgallouedec (Member)
I'll just add a little piece of documentation and we're good to merge!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec merged commit 42933fa into huggingface:main on Aug 17, 2024
9 checks passed
@qgallouedec (Member)
Thanks for your first contribution @RylanSchaeffer!
