🤖 Properly unwrap torch.compile-ed models in GRPO #2750

Merged · 10 commits · Feb 4, 2025
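Background for this fix: `torch.compile` does not modify a module in place; it returns an `OptimizedModule` wrapper that keeps the original module in its `_orig_mod` attribute and prefixes every `state_dict` key with `_orig_mod.`. GRPO pushes weights to vLLM by parameter name, so a compiled policy has to be unwrapped first. A minimal standalone sketch of the wrapper behavior (illustrative, not TRL code):

import torch
from torch import nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)

# The wrapper keeps a handle to the original module...
assert compiled._orig_mod is model
# ...and its state_dict keys gain an `_orig_mod.` prefix,
# which no longer matches the names vLLM expects.
assert all(k.startswith("_orig_mod.") for k in compiled.state_dict())
assert set(model.state_dict()) == {"weight", "bias"}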
33 changes: 33 additions & 0 deletions tests/test_grpo_trainer.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import tempfile
import unittest

@@ -390,3 +391,35 @@ def test_training_vllm(self):
        for n, param in previous_trainable_params.items():
            new_param = trainer.model.get_parameter(n)
            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")  # compiling seems to be broken on Windows
    def test_training_torch_compile(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=32,  # reduce the completion length to reduce memory usage
                torch_compile=True,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
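To exercise just this test locally, something like `pytest tests/test_grpo_trainer.py -k test_training_torch_compile` should work, assuming the repository's test dependencies are installed. The high learning rate and the clone-then-compare pattern simply make parameter movement detectable after a very short training run.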
5 changes: 2 additions & 3 deletions trl/trainer/grpo_config.py
@@ -66,7 +66,7 @@ class GRPOConfig(TrainingArguments):
            improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
            during initialization.
        vllm_dtype (`str`, *optional*, defaults to `"auto"`):
            Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
            based on the model configuration. Find the supported values in the vLLM documentation.

> Parameters that control the training
@@ -147,15 +147,14 @@ class GRPOConfig(TrainingArguments):
"out-of-memory (OOM) errors during initialization."
},
)

vllm_dtype: Optional[str] = field(
default="auto",
metadata={
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
"determined based on the model configuration. Find the supported values in the vLLM documentation."
},
)

# Parameters that control the training
learning_rate: float = field(
default=1e-6,
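As a usage sketch of the field documented above (the dtype value here is illustrative; any dtype listed in the vLLM documentation is valid):

from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="out",
    use_vllm=True,          # generate completions with vLLM
    vllm_dtype="bfloat16",  # pin the dtype instead of letting "auto" infer it from the model config
)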
11 changes: 8 additions & 3 deletions trl/trainer/grpo_trainer.py
@@ -23,8 +23,10 @@
import torch.utils.data
import transformers
from accelerate.utils import broadcast_object_list, gather_object
from accelerate.utils.other import is_compiled_module
from datasets import Dataset, IterableDataset
from packaging import version
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
@@ -402,7 +404,10 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
        # First, have main process load weights if needed
        if self.state.global_step != self._last_loaded_step:
            with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
                if is_compiled_module(unwrapped_model):
                    state_dict = unwrapped_model._orig_mod.state_dict()
                else:
                    state_dict = unwrapped_model.state_dict()
            if self.accelerator.is_main_process:
                llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                llm_model.load_weights(state_dict.items())
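This hunk is the core of the fix: a compiled policy reports `_orig_mod.`-prefixed keys that vLLM's `load_weights` cannot match, so the weights are read from the wrapped original instead. The same logic as a standalone helper (helper name hypothetical):

import torch
from accelerate.utils import is_compiled_module

def vllm_compatible_state_dict(model: torch.nn.Module) -> dict:
    # Read weights from the wrapped original so the keys match the
    # uncompiled parameter names that vLLM expects.
    if is_compiled_module(model):
        return model._orig_mod.state_dict()
    return model.state_dict()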
@@ -479,7 +484,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            if isinstance(reward_func, nn.Module):  # nn.Module instead of PreTrainedModel for compat with compiled models
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
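The broadened check matters because a compiled reward model is an `OptimizedModule`, which is an `nn.Module` but not a `PreTrainedModel`, so the old `isinstance` test would route it to the custom-reward-function branch. A quick standalone demonstration:

import torch
from torch import nn
from transformers import AutoModelForSequenceClassification, PreTrainedModel

reward_model = AutoModelForSequenceClassification.from_pretrained(
    "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
)
compiled = torch.compile(reward_model)

assert isinstance(reward_model, PreTrainedModel)
assert not isinstance(compiled, PreTrainedModel)  # the wrapper is not a PreTrainedModel
assert isinstance(compiled, nn.Module)            # but it is still an nn.Module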
@@ -516,7 +521,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
        # Log the metrics
        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, nn.Module):  # nn.Module instead of PreTrainedModel for compat with compiled models
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
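Note that `reward_func.config._name_or_path` still resolves on a compiled reward model: `OptimizedModule` forwards unknown attribute lookups to `_orig_mod`, which is what makes the broadened `nn.Module` branch safe here. A tiny sketch of that forwarding (the class is hypothetical):

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.tag = "toy-reward"

compiled = torch.compile(Toy())
assert compiled.tag == "toy-reward"  # delegated to `_orig_mod` via __getattr__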