Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🤖 Properly unwrap torch.compile-ed models in GRPO #2750

Merged
merged 10 commits into from
Feb 4, 2025

Conversation

winglian
Copy link
Contributor

@winglian winglian commented Feb 3, 2025

What does this PR do?

when using torch compile, there is one more layer to unwrap before we can send the state dict to vlllm

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@shirinyamani
Copy link
Contributor

That's a very good point! Thanks for pointing it out!, but quick follow-up for my knowledge, I could not find torch.compile() in the main grpo_trainer, could you please help me understand which part of code is specifying the memory hierarchy-aware like torch.compile in your viewpoint ?

bc to the best of my knowledge we could do sth like model = torch.compile(model) but of course we gotta make sure the compatibility with rest of the computation.

@winglian @kashif

@winglian
Copy link
Contributor Author

winglian commented Feb 4, 2025

The base TrainingArguments from transformers includes a torch_compile option, so you can simply set that on GRPOConfig

That's a very good point! Thanks for pointing it out!, but quick follow-up for my knowledge, I could not find torch.compile() in the main grpo_trainer, could you please help me understand which part of code is specifying the memory hierarchy-aware like torch.compile in your viewpoint ?

bc to the best of my knowledge we could do sth like model = torch.compile(model) but of course we gotta make sure the compatibility with rest of the computation.

@winglian
Copy link
Contributor Author

winglian commented Feb 4, 2025

@qgallouedec I rebased this so the merge conflict should be resolved. thanks!

@winglian
Copy link
Contributor Author

winglian commented Feb 4, 2025

I could also move this into the unwrap_model_for_generation function, but I'm not 100% on the deepspeed behavior.
Screenshot 2025-02-04 at 11 50 18 AM

Copy link
Member

@qgallouedec qgallouedec left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @winglian, I've added a test, and made sure that it's also compatible with reward models. Can be merged once the CI is green :)

@qgallouedec qgallouedec changed the title properly unwrap torch.compile-ed models with GRPO 🤖 Properly unwrap torch.compile-ed models in GRPO Feb 4, 2025
@qgallouedec qgallouedec merged commit bd946f9 into huggingface:main Feb 4, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants