Add push to Hub for PPOTrainer #68
Conversation
Thank you very much for working on this 💯!
To the best of my knowledge, once the PPO trainer has been trained, the only thing we want to retain is the trained model without its value head, so that you can perform any task with the remaining model that you could with a standard `CausalLM` model.
Note that you can currently also load a model trained with PPO using `transformers`:
```python
from transformers import AutoTokenizer
from trl import PPOTrainer, AutoModelForCausalLMWithValueHead

# `config`, `dataset` and `collater` come from your usual PPO setup
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collater)
# train ... then:
ppo_trainer.push_to_hub("dummy-trl-model")

# Load from the Hub using transformers
from transformers import AutoModelForCausalLM
hub_model = AutoModelForCausalLM.from_pretrained("dummy-trl-model")
```
This works because `save_pretrained` from `AutoModelForCausalLMWithValueHead` is just a wrapper around the `save_pretrained` method of the base model: https://github.com/lvwerra/trl/blob/master/trl/models/modeling_base.py#L105-L115
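A minimal sketch of what that wrapper amounts to (assuming the wrapped transformer is stored as `pretrained_model`; see the linked source for the actual implementation):

```python
# Hypothetical sketch, not the actual trl code:
def save_pretrained(self, *args, **kwargs):
    # Delegate serialization entirely to the wrapped base model;
    # note that the v_head weights are not saved here.
    return self.pretrained_model.save_pretrained(*args, **kwargs)
```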
The only reason I see to share the `v_head`'s weights would be to continue training, e.g.:
```python
from trl import PPOTrainer, AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained("trl-gpt2-intermediate-training")
ppo_trainer = PPOTrainer(model, ...)
for i in range(epochs):
    ppo_trainer.step(...)
```
(I am unsure whether this is what we want currently.)
If it is, we would certainly need to change some logic of the `AutoModelForCausalLMWithValueHead` class: maybe make it inherit from `transformers.PreTrainedModel` so that the `v_head` can also be loaded via `from_pretrained`, and probably add the class attribute `_keys_to_ignore_on_load_missing=["v_head.summary.weight", "v_head.summary.bias"]` so that an `AutoModelForCausalLMWithValueHead` can be initialized from any HF `CausalLM` model.
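A rough sketch of that proposal (class and module names are hypothetical, and the real class would also need custom `from_pretrained` logic to remap the base model's weights):

```python
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel, PreTrainedModel

class ValueHead(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.summary = nn.Linear(hidden_size, 1)  # scalar value per token

    def forward(self, hidden_states):
        return self.summary(hidden_states).squeeze(-1)

class CausalLMWithValueHead(PreTrainedModel):
    config_class = GPT2Config
    # Checkpoints without a value head load fine; the v_head is then
    # randomly initialized instead of raising on missing keys.
    _keys_to_ignore_on_load_missing = ["v_head.summary.weight", "v_head.summary.bias"]

    def __init__(self, config):
        super().__init__(config)
        self.pretrained_model = GPT2LMHeadModel(config)
        self.v_head = ValueHead(config.n_embd)

    def forward(self, input_ids, **kwargs):
        outputs = self.pretrained_model(input_ids, output_hidden_states=True, **kwargs)
        values = self.v_head(outputs.hidden_states[-1])
        return outputs.logits, values
```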
For now I can see that the `v_head` might not be needed for the `ref_model`, as the only place the forward pass of the `ref_model` is called is here. Also, I don't think we should push the ref model, since it is always frozen.
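For illustration, a rough sketch of the idea (variable names hypothetical): the frozen reference only contributes token log-probs to the KL penalty, so its value head output is never used:

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

# Frozen reference model: only its logits matter, no value head required.
ref_model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

def per_token_kl(policy_logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        ref_logits = ref_model(input_ids).logits
    policy_logprobs = F.log_softmax(policy_logits, dim=-1)
    ref_logprobs = F.log_softmax(ref_logits, dim=-1)
    # KL(policy || ref) per token, used as a reward penalty in PPO.
    return (policy_logprobs.exp() * (policy_logprobs - ref_logprobs)).sum(dim=-1)
```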
Let's wait for @lvwerra and @edbeeching to chime in, as I might be missing some details here!
Looks great, thanks! I left a comment regarding the value head.
Before we merge, could you add some tests?
trl/trainer/ppo_trainer.py (outdated)
```python
from trl import AutoModelForCausalLMWithValueHead

tokenizer = AutoTokenizer.from_pretrained({model_name})
model = AutoModelForCausalLMWithValueHead.from_pretrained({model_name})
```
Currently this will load a new value head, as the `save_pretrained` method of `AutoModelForCausalLMWithValueHead` does not save the value head. I think we should save it. I don't know if @younesbelkada wants to add the necessary changes in this PR or a follow-up.
Actually, to use the model you don't need TRL: you can just load it with `transformers`' `AutoModelForCausalLM`. Installing TRL and the current example could be a second example for users who want to continue training or read out the value head.
Yes I can take care of that in a follow-up PR
It has been taken care of in #86!
```
@@ -41,6 +42,42 @@
from trl.trainer.utils import AdaptiveKLController, FixedKLController

MODEL_CARD_TEMPLATE = """---
license: apache-2.0
```
Is this the desired default license for these models?
trl/trainer/ppo_trainer.py
Outdated
- trl | ||
- transformers | ||
- reinforcement-learning | ||
- human-feedback |
I picked this instead of `rlhf`, since the latter is cryptic and we've already agreed on `human-feedback` for datasets.
Maybe `human-feedback` should be added by the user manually, since you don't necessarily train from human feedback (e.g. all the sentiment examples).
Fixed in 47289f5
```python
model = AutoModelForCausalLMWithValueHead.from_pretrained({model_name})

inputs = tokenizer("Hello, my llama is cute", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
```
Is there a more natural way to do inference with these models?
You could also pass it to a `text-generation` pipeline.
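For example (a minimal sketch, reusing the hypothetical Hub repo name "dummy-trl-model" from above):

```python
from transformers import pipeline

# The pushed checkpoint loads like any causal LM, so the standard
# text-generation pipeline works out of the box.
generator = pipeline("text-generation", model="dummy-trl-model")
print(generator("Hello, my llama is cute", max_new_tokens=20)[0]["generated_text"])
```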
Done in 47289f5
```
@@ -658,3 +695,22 @@ def log_stats(
        dist.barrier()
        dist.all_reduce(rewards, op=torch.distributed.ReduceOp.SUM)

    def create_model_card(self, path: str, model_name: Optional[str] = "TRL Model") -> None:
```
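A rough sketch of what this method plausibly does, pieced together from the fragments visible in this diff (the exact template fields are assumptions):

```python
import os
from typing import Optional

from huggingface_hub import whoami

def create_model_card(self, path: str, model_name: Optional[str] = "TRL Model") -> None:
    """Write a README.md model card for the pushed model."""
    if not os.path.exists(path):
        os.makedirs(path)
    # whoami() requires a logged-in huggingface_hub session
    user = whoami()["name"]
    # MODEL_CARD_TEMPLATE and its format fields are assumed from the diff above
    card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{model_name}")
    with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
        f.write(card_content)
```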
For now, the model card is fairly limited in what users can configure. I can extend this if desired (e.g. closer to `transformers.Trainer`, which allows all the tags to be configured).
Sounds good. Could we potentially also reuse the same method, since `transformers` is already a dependency?
Unfortunately, `create_model_card` in `transformers` is a method of the `Trainer`, so it can't be reused. I suggest we keep it simple for now and see if the community wishes to expand it with more tags etc.
Thanks for the comments @lvwerra! I left a few questions that could do with your feedback; in the meantime I'll add some tests :)
@lvwerra @younesbelkada this PR looks good to go on my side. Note that I refactored the test filenames to reflect the repo structure; this is what we do in `transformers`, and I personally find it a good organising principle as the project grows. Happy to revert if you wish!
model_id = "gpt2" | ||
|
||
# get models and tokenizer | ||
self.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id) |
Do we need fresh models for each unit test? If not, I can move this to `setUpClass()`, which is only called once (and saves time).
Actually, I don't think we need a fresh model for each test, so yes, can you move it to `setUpClass()` instead? 🙏 Thanks!
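For reference, a minimal sketch of that refactor (the test class name is hypothetical):

```python
import unittest

from trl import AutoModelForCausalLMWithValueHead

class PPOTrainerTester(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Load the model once for the whole class rather than once per test.
        cls.model_id = "gpt2"
        cls.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(cls.model_id)
```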
Done in 6c4ab68
Thanks a lot for shipping this! 💯
Let's just refactor the test as you proposed in #68 (comment), otherwise LGTM.
I will also take care of refactoring the models test to cover `push_to_hub` in a follow-up PR once this is merged.
Looks good to me - thanks for adding!
🔴 Don't merge until I have a fix! Hmm, using the staging endpoint of the Hub for the test is causing some issues because I rely on `whoami()`.
```python
if not os.path.exists(path):
    os.makedirs(path)

user = whoami()["name"]
```
This line is the reason the unit test fails against the staging endpoint: by default `whoami()` points to hf.co, when we need it to point to https://hub-ci.huggingface.co.
One can do this by using `HfApi(endpoint="https://hub-ci.huggingface.co")`, but then we need to set an env var for the endpoint, which feels clunky.
For now, I propose we skip the unit test until we find a better solution for testing `push_to_hub` features.
OK with you @younesbelkada @lvwerra?
One way to solve this is by patching the function in the test: we could have a dummy function for `whoami` that returns whatever we want during testing.
Happy to merge it as is and figure it out later :)
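A quick sketch of that idea (the patch target and the call under test are assumptions and would need to match where `whoami` is actually imported):

```python
from unittest.mock import patch

# Replace whoami() for the duration of the test so it never hits hf.co.
with patch("trl.trainer.ppo_trainer.whoami", return_value={"name": "dummy-user"}):
    ppo_trainer.create_model_card("tmp_model_card_dir")  # hypothetical call under test
```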
Same as @lvwerra, definitely OK for me to skip this test for now and look at it later!
This PR adds basic `push_to_hub()` functionality to the `PPOTrainer`. Here's a model produced from it: https://huggingface.co/lewtun/dummy-trl-model

There's a few questions I have about the desired functionality for end users:

- How should someone do inference with a `trl` model? The examples in the codebase are focused on training, and I'm wondering if inference is simply a forward pass through the model with the value head, text generation, or something more complicated.

Here's a snippet that shows how this feature works (see the sketch below):
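A minimal sketch of the workflow, mirroring the example earlier in this thread (`config`, `dataset` and `collater` come from the usual PPO setup):

```python
from transformers import AutoTokenizer
from trl import PPOTrainer, AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collater)
# ... run PPO training, then push the trained model to the Hub:
ppo_trainer.push_to_hub("dummy-trl-model")
```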