Add push to Hub for PPOTrainer #68

Merged · 22 commits merged into huggingface:main · Jan 24, 2023

Conversation

lewtun (Member) commented on Dec 31, 2022

This PR adds basic push_to_hub() functionality to the PPOTrainer. Here's a model produced from it: https://huggingface.co/lewtun/dummy-trl-model

There are a few questions I have about the desired functionality for end users:

  • What is the expected usage for a trained trl model? The examples in the codebase are focused on training, and I'm wondering whether inference is simply a forward pass through the model with the value head, text-generation, or something more complicated.
  • Should the reference model also be pushed to the Hub? For now, I've just pushed the active model with the value head.

Here's a snippet that shows how this feature works:

import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from datasets import load_dataset
from trl.core import LengthSampler

config = PPOConfig(
    model_name="lvwerra/gpt2-imdb",
    learning_rate=1.41e-5,
)

# Build the IMDB dataset of short prompts ("queries") used by the PPO trainer
def build_dataset(config, dataset_name="imdb", input_min_text_length=2, input_max_text_length=8):
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    ds = load_dataset(dataset_name, split='train')
    ds = ds.rename_columns({'text': 'review'})
    ds = ds.filter(lambda x: len(x["review"])>200, batched=False)
    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        # truncate each review to a randomly sampled prompt length
        sample["input_ids"] = tokenizer.encode(sample["review"])[:input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type='torch')
    return ds

dataset = build_dataset(config)

def collater(data):
    # Collate a list of dicts into a dict of lists, as expected by PPOTrainer
    return dict((key, [d[key] for d in data]) for key in data[0])

model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collater)

# (the PPO training loop would normally run here before pushing)
ppo_trainer.push_to_hub("dummy-trl-model")

# Load from Hub - note it retains the value-head
hub_model = AutoModelForCausalLMWithValueHead.from_pretrained("dummy-trl-model")

HuggingFaceDocBuilderDev commented on Dec 31, 2022

The documentation is not available anymore as the PR was closed or merged.

lewtun mentioned this pull request on Dec 31, 2022

younesbelkada (Contributor) left a comment

Thank you very much for working on this 💯!
To the best of my knowledge, once the PPO trainer has finished training, the only thing we want to retain is the trained model without its value head; you can then perform any task with it that you can do with a standard CausalLM model.
Note that currently you can also load a PPO-trained model using transformers:

# (reusing config, dataset and collater from the snippet in the PR description)
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collater)

# train ... then:
ppo_trainer.push_to_hub("dummy-trl-model")

# Load from Hub using transformers
from transformers import AutoModelForCausalLM

hub_model = AutoModelForCausalLM.from_pretrained("dummy-trl-model")

since save_pretrained from AutoModelForCausalLMWithValueHead is just a wrapper around the save_pretrained method from the base model: https://github.com/lvwerra/trl/blob/master/trl/models/modeling_base.py#L105-L115
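
For illustration only, the kind of thin delegation described there looks roughly like this (a hedged sketch, not the actual trl source; the pretrained_model attribute name is an assumption):

class CausalLMWithValueHeadSketch:
    """Illustrative wrapper: a base causal LM plus a small value head."""

    def __init__(self, pretrained_model, v_head):
        self.pretrained_model = pretrained_model  # underlying transformers model
        self.v_head = v_head                      # extra head used only during PPO training

    def save_pretrained(self, *args, **kwargs):
        # Forwards straight to the base model's save_pretrained, which is why the
        # saved checkpoint can be reloaded with plain AutoModelForCausalLM
        # (and why the v_head weights are not written to disk here).
        return self.pretrained_model.save_pretrained(*args, **kwargs)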

The only reason I see to share the v_head's weights would be to continue training. For example:

from trl import PPOTrainer, AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained("trl-gpt2-intermediate-training")
ppo_trainer = PPOTrainer(model, ...)
for i in range(epochs):
    ppo_trainer.step(...)

(I am unsure if this is what we want currently)

If this is what we want, we would need to change some logic in the AutoModelForCausalLMWithValueHead class: perhaps make it inherit from transformers.PreTrainedModel so that the v_head can also be loaded via from_pretrained, and probably add the class attribute _keys_to_ignore_on_load_missing=["v_head.summary.weight", "v_head.summary.bias"] so that an AutoModelForCausalLMWithValueHead can be initialized from any HF CausalLM model.
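
To make the idea concrete, here is a rough sketch of that direction (purely illustrative; the real class would need more plumbing around configs and weight loading, and all names below are assumptions):

from torch import nn
from transformers import PreTrainedModel, AutoModelForCausalLM

class ValueHeadSketch(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.summary = nn.Linear(hidden_size, 1)  # scalar value per token

    def forward(self, hidden_states):
        return self.summary(hidden_states).squeeze(-1)

class CausalLMWithValueHeadSketch(PreTrainedModel):
    # Missing v_head.* weights are ignored at load time, so the class could be
    # initialised from any HF CausalLM checkpoint that has no value head.
    _keys_to_ignore_on_load_missing = ["v_head.summary.weight", "v_head.summary.bias"]

    def __init__(self, config):
        super().__init__(config)
        self.pretrained_model = AutoModelForCausalLM.from_config(config)
        self.v_head = ValueHeadSketch(config.hidden_size)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.pretrained_model(
            input_ids, attention_mask=attention_mask, output_hidden_states=True
        )
        value = self.v_head(outputs.hidden_states[-1])
        return outputs.logits, value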

For now, I can see that the v_head might not be needed for the ref_model, as there is only one place where the ref_model's forward pass is called. Also, I don't think we should push the ref model, since it is always frozen.

Let's wait for @lvwerra and @edbeeching to chime in, as I might be missing some details here!

lewtun marked this pull request as ready for review on January 5, 2023 at 01:30

lvwerra (Member) left a comment

Looks great, thanks! Left a comment regarding the value head.

Before we merge, could you add some tests?

from trl import AutoModelForCausalLMWithValueHead

tokenizer = AutoTokenizer.from_pretrained({model_name})
model = AutoModelForCausalLMWithValueHead.from_pretrained({model_name})

Member:

Currently this will load a new value head, as the save_pretrained method of AutoModelForCausalLMWithValueHead does not save the value head. I think we should save it. I don't know if @younesbelkada wants to add the necessary changes in this PR or in a follow-up PR.
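
One possible shape for that change, sketched very roughly (hypothetical code, not what was ultimately implemented; the file name and attribute names are assumptions):

import os
import torch

def save_with_value_head(model, save_directory):
    # delegate the base model and config to transformers as before
    model.pretrained_model.save_pretrained(save_directory)
    # additionally store the value-head weights next to the checkpoint
    torch.save(model.v_head.state_dict(), os.path.join(save_directory, "v_head.bin"))

def load_value_head(model, save_directory):
    # restore the value-head weights after the base model has been loaded
    state_dict = torch.load(os.path.join(save_directory, "v_head.bin"))
    model.v_head.load_state_dict(state_dict)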

Member:

Actually, to use the model you don't need TRL: you can just load it with transformers' AutoModelForCausalLM. Installing TRL and using the current example could be a second example for when you want to continue training or read out the value head.

Contributor:

Yes I can take care of that in a follow-up PR

Contributor:

It has been taken care of in #86!

@@ -41,6 +42,42 @@
from trl.trainer.utils import AdaptiveKLController, FixedKLController


MODEL_CARD_TEMPLATE = """---
license: apache-2.0

Member Author:

Is this the desired default license for these models?

- trl
- transformers
- reinforcement-learning
- human-feedback

Member Author:

I picked this instead of rlhf, since the latter is cryptic and we've already agreed on human-feedback for datasets.

Member:

Maybe human-feedback should be added by the user manually, since you don't necessarily need to train from human feedback (e.g. all the sentiment examples).

Member Author:

Fixed in 47289f5

model = AutoModelForCausalLMWithValueHead.from_pretrained({model_name})

inputs = tokenizer("Hello, my llama is cute", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])

Member Author:

Is there a more natural way to do inference with these models?

Member:

You could also pass it to a text-generation pipeline.
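
For instance, something along these lines (a sketch; the repo id is the dummy model from this PR's description, and the prompt is arbitrary):

from transformers import pipeline

# the pushed policy behaves like any other causal LM checkpoint
generator = pipeline("text-generation", model="lewtun/dummy-trl-model")
print(generator("This movie was really", max_new_tokens=20))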

Member Author:

Done in 47289f5

@@ -658,3 +695,22 @@ def log_stats(

dist.barrier()
dist.all_reduce(rewards, op=torch.distributed.ReduceOp.SUM)

def create_model_card(self, path: str, model_name: Optional[str] = "TRL Model") -> None:

Member Author:

For now, the model card is fairly limited in what users can configure. I can extend it if desired (e.g. bring it closer to transformers.Trainer, which allows all the tags to be configured).

Member:

Sounds good. Could we potentially also reuse the same method, since transformers is already a dependency?

Member Author:

Unfortunately, create_model_card in transformers is a method of the Trainer, so it can't be re-used. I suggest we keep it simple for now and see whether the community wants to expand it with more tags etc.

lewtun (Member Author) commented on Jan 5, 2023

Thanks for the comments @lvwerra! I left a few questions that could use your feedback; in the meantime I'll add some tests :)

lewtun (Member Author) left a comment

@lvwerra @younesbelkada this PR looks good to go on my side. Note that I refactored the test filenames to reflect the repo structure; this is what we do in transformers, and I personally find it a good organising principle as the project grows. Happy to revert if you wish!

model_id = "gpt2"

# get models and tokenizer
self.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)

Member Author:

Do we need fresh models for each unit test? If not, I can move this to setUpClass(), which is only called once (and saves time).

Contributor:

Actually, I don't think we need a fresh model for each test; yes, please move it to setUpClass() instead 🙏 Thanks!
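
A minimal sketch of what that refactor might look like (hypothetical test code, not the actual trl test file):

import unittest

from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

class PPOTrainerTester(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # loaded once for the whole test class instead of once per test method
        cls.model_id = "gpt2"
        cls.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(cls.model_id)
        cls.gpt2_tokenizer = AutoTokenizer.from_pretrained(cls.model_id)

    def test_model_is_loaded(self):
        self.assertIsNotNone(self.gpt2_model)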

Member Author:

Done in 6c4ab68

younesbelkada (Contributor) left a comment

Thanks a lot for shipping this! 💯
Let's just refactor the test as you proposed in #68 (comment); otherwise LGTM.
I will also take care of refactoring the models tests to cover push_to_hub in a follow-up PR once this is merged.

lvwerra (Member) left a comment

Looks good to me - thanks for adding!

lewtun (Member Author) commented on Jan 23, 2023

🔴 Don't merge until I have a fix!

Hmm, using the Hub's staging endpoint for the test is causing some issues: I rely on whoami() to get the username for the model card, and that method doesn't let me distinguish between endpoints.

if not os.path.exists(path):
    os.makedirs(path)

user = whoami()["name"]

Member Author:

This line is the reason the unit test fails against the staging endpoint: by default whoami() points to hf.co, whereas we need it to point to https://hub-ci.huggingface.co.

One can do this with HfApi(endpoint="https://hub-ci.huggingface.co"), but then we need to set an env var for the endpoint, which feels clunky.

For now, I propose we skip the unit test until we find a better solution for testing push_to_hub features.

OK with you @younesbelkada @lvwerra?

Member:

One way to solve this is by patching the function in the test: we could have a dummy whoami that returns whatever we want during testing.

Happy to merge it as is and figure it out later :)
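
A rough sketch of that patching idea (hypothetical test code; the patch target assumes whoami is imported into trl's PPO trainer module, and the ppo_trainer fixture is made up):

from unittest import mock

def test_create_model_card(tmp_path, ppo_trainer):
    # replace whoami with a dummy that returns a fixed user, so the test never
    # needs to talk to hf.co or the staging endpoint
    with mock.patch("trl.trainer.ppo_trainer.whoami", return_value={"name": "test-user"}):
        ppo_trainer.create_model_card(str(tmp_path), model_name="test-model")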

Contributor:

Same as @lvwerra: definitely OK for me to skip this test for now and look at it later!

lewtun merged commit 11ac263 into huggingface:main on Jan 24, 2023
lewtun deleted the push-to-hub branch on January 24, 2023 at 10:49