[core] Push v_head when using AutoModelForCausalLMWithValueHead #86
Conversation
The documentation is not available anymore as the PR was closed or merged.
trl/models/modeling_base.py (Outdated)

    # state_dict is removed from the model after loading it.
    # since this is only done in edge cases we force the user to do it only
    # if they explicitly set resume_training to True
    if resume_training:
If someone is using AutoModelForCausalLMWithValueHead to load a model, I think we can assume that they want to use the value head, so we should always load it if it's available. What do you think?
OK for me! We should advise users to use AutoModelForCausalLM when using the model outside trl (i.e. not for training with PPO).
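For illustration, consuming such a checkpoint outside trl could then look like this (a minimal sketch; the repo id is a placeholder):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "my-user/my-ppo-model"  # placeholder Hub repo id
    # Plain transformers loading: any extra v_head weights in the
    # checkpoint are simply ignored (with a warning).
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    inputs = tokenizer("Hello, world", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))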
Yes, exactly!
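For concreteness, a minimal sketch of the agreed behaviour, assuming value-head weights are stored under a "v_head." key prefix (the helper name and prefix are illustrative, not the actual trl implementation):

    import torch.nn as nn

    def load_v_head_if_available(v_head: nn.Module, state_dict: dict) -> bool:
        # Gather any weights saved under the assumed "v_head." prefix.
        v_head_state = {
            k[len("v_head."):]: v
            for k, v in state_dict.items()
            if k.startswith("v_head.")
        }
        if not v_head_state:
            # Nothing to restore, e.g. a plain transformers checkpoint.
            return False
        # Restore the value head whenever its weights exist; no
        # resume_training flag involved.
        v_head.load_state_dict(v_head_state)
        return True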
Added more description in 853d52d
- update based on comments
- add more tests
- update docs
What does this PR do?
This PR addresses saving v_head weights when calling save_pretrained. Before this PR, if a user wanted to share intermediate-trained v_head weights, it was not possible. The workaround that I found loads the state dict twice: once through the from_pretrained method that populates the pretrained_model attribute, and once afterwards to manually load the v_head.

This PR introduces a new kwarg on the from_pretrained method, termed resume_training, as I think this is quite an edge case. Also added some tests.
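To make the before/after concrete, a hedged sketch (paths, the checkpoint file name, and the "v_head." key prefix are assumptions; resume_training is the kwarg as proposed in this PR, though the review discussion above favours always loading v_head when available):

    import torch
    from trl import AutoModelForCausalLMWithValueHead

    # Before this PR: v_head weights had to be restored by hand after
    # reading the checkpoint a second time.
    model = AutoModelForCausalLMWithValueHead.from_pretrained("my-checkpoint")
    state_dict = torch.load("my-checkpoint/pytorch_model.bin")
    v_head_state = {
        k[len("v_head."):]: v
        for k, v in state_dict.items()
        if k.startswith("v_head.")
    }
    model.v_head.load_state_dict(v_head_state)

    # With this PR: save_pretrained also writes the v_head weights, and
    # from_pretrained can restore them via the proposed kwarg.
    model.save_pretrained("my-checkpoint")
    model = AutoModelForCausalLMWithValueHead.from_pretrained(
        "my-checkpoint", resume_training=True
    )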
TODOs:
- push_to_hub
cc @lvwerra @lewtun