💔 Decouple loss computing and generation in GRPO #2762
@@ -366,32 +366,41 @@ def _set_signature_columns_if_needed(self):
        if self._signature_columns is None:
            self._signature_columns = ["prompt"]

    # Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device.
    # Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
        return inputs

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
    # Get the per-token log probabilities for the completions for the model and the reference model
    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
        # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
        logits = model(
            input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
        ).logits  # (B, L, V)
        logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

        # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
        per_token_logps = []
        for logits_row, input_ids_row in zip(logits, input_ids[:, -logits_to_keep:]):
            log_probs = logits_row.log_softmax(dim=-1)
            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
            per_token_logps.append(token_log_prob)
        return torch.stack(per_token_logps)
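For reference, the gather-based selection inside `_get_per_token_logps` can be reproduced on toy tensors. A minimal sketch, assuming random logits and token ids (shapes and values are illustrative only, not TRL code):

```python
import torch

batch_size, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch_size, seq_len, vocab_size)         # model outputs, already truncated
input_ids = torch.randint(vocab_size, (batch_size, seq_len))  # tokens whose log-probs we want

per_token_logps = []
for logits_row, ids_row in zip(logits, input_ids):  # row-by-row loop keeps the memory peak low
    log_probs = logits_row.log_softmax(dim=-1)                # (seq_len, vocab_size)
    token_logps = torch.gather(log_probs, dim=1, index=ids_row.unsqueeze(1)).squeeze(1)
    per_token_logps.append(token_logps)                       # (seq_len,)
per_token_logps = torch.stack(per_token_logps)                # (batch_size, seq_len)
```

The full `log_softmax` over the vocabulary is materialized one row at a time, which is why the method loops over the batch instead of vectorizing.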
    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
        device = self.accelerator.device
        prompts = [x["prompt"] for x in inputs]
        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
        prompt_inputs = self.processing_class(
            prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        if self.max_prompt_length is not None:
            prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
            prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]

        # Generate completions using either vLLM or regular generation
        if self.args.use_vllm:
            # First, have main process load weights if needed
            if self.state.global_step != self._last_loaded_step:
                with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
                    state_dict = unwrapped_model.state_dict()
                if self.accelerator.is_main_process:
                    llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
@@ -418,18 +427,21 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
            # Pad the completions, and concatenate them with the prompts
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
            prompt_inputs_repeated = torch.repeat_interleave(prompt_inputs["input_ids"], self.num_generations, dim=0)
            prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1)
            prompt_ids = torch.repeat_interleave(prompt_ids, self.num_generations, dim=0)
            prompt_mask = torch.repeat_interleave(prompt_mask, self.num_generations, dim=0)
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        else:
            # Regular generation path
            with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
            with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
                prompt_completion_ids = unwrapped_model.generate(
                    **prompt_inputs, generation_config=self.generation_config
                    prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
                )

        # Compute prompt length and extract completion ids
        prompt_length = prompt_inputs["input_ids"].size(1)
        completion_ids = prompt_completion_ids[:, prompt_length:]
            # Compute prompt length and extract completion ids
            prompt_length = prompt_ids.size(1)
            prompt_ids = prompt_completion_ids[:, :prompt_length]
            completion_ids = prompt_completion_ids[:, prompt_length:]
            prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
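The repetition above lines up each prompt with its `num_generations` completions before concatenation. A minimal sketch with hypothetical ids (completions assumed already padded to a common length):

```python
import torch

num_generations = 2
prompt_ids = torch.tensor([[11, 12], [21, 22]])          # 2 prompts
completion_ids = torch.tensor([[1, 2, 3], [4, 5, 0],     # 2 completions per prompt
                               [6, 7, 0], [8, 9, 0]])

prompt_ids = torch.repeat_interleave(prompt_ids, num_generations, dim=0)  # (4, 2)
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)    # (4, 5)
# row i now holds prompt i // num_generations followed by one of its completions
```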
        # Mask everything after the first EOS token
        is_eos = completion_ids == self.processing_class.eos_token_id

@@ -439,49 +451,28 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
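The lines that compute `eos_idx` and `sequence_indices` fall outside this hunk; a minimal sketch of one way to mask everything after the first EOS, consistent with the visible lines (hypothetical token ids, `eos_token_id` assumed to be 2):

```python
import torch

eos_token_id = 2
completion_ids = torch.tensor([[5, 7, 2, 9, 9, 9],
                               [4, 4, 4, 4, 4, 4]])  # second row never emits EOS

is_eos = completion_ids == eos_token_id  # (B, C) bool
# Index of the first EOS per row; rows without EOS keep the full completion length.
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# -> [[1, 1, 1, 0, 0, 0],
#     [1, 1, 1, 1, 1, 1]]
```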
        # Concatenate prompt_mask with completion_mask for logit computation
        prompt_mask_repeated = prompt_inputs["attention_mask"].repeat_interleave(self.num_generations, dim=0)
        attention_mask = torch.cat([prompt_mask_repeated, completion_mask], dim=1)  # (B*G, P+C)

        # Get the per-token log probabilities for the completions for the model and the reference model
        def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep):

Review comment: converted to method.

            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
            logits = model(
                input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
            ).logits  # (B, L, V)
            logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

            # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
            per_token_logps = []
            for logits_row, input_ids_row in zip(logits, input_ids[:, -logits_to_keep:]):
                log_probs = logits_row.log_softmax(dim=-1)
                token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
                per_token_logps.append(token_log_prob)
            return torch.stack(per_token_logps)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)

        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
        per_token_logps = get_per_token_logps(model, prompt_completion_ids, attention_mask, logits_to_keep)

        with torch.inference_mode():
            if self.ref_model is not None:
                ref_per_token_logps = get_per_token_logps(
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
                )
            else:
                with self.accelerator.unwrap_model(model).disable_adapter():
                    ref_per_token_logps = get_per_token_logps(
                        model, prompt_completion_ids, attention_mask, logits_to_keep
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    ref_per_token_logps = self._get_per_token_logps(
                        self.model, prompt_completion_ids, attention_mask, logits_to_keep
                    )

        # Compute the KL divergence between the model and the reference model
        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

Review comment: this is later computed in `compute_loss`.

        # Decode the generated completions
        completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if is_conversational(inputs[0]):
            completions = [[{"role": "assistant", "content": completion}] for completion in completions]

        # Compute the rewards
        prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
        prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]  # repeat prompts

        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, (reward_func, reward_processing_class) in enumerate(
@@ -521,15 +512,7 @@ def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep):
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
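The grouping itself (reshaping rewards into groups of `num_generations` and taking the per-group mean and std) precedes these lines and is not shown in the hunk. A minimal numerical sketch of the whole normalization, assuming rewards are laid out group by group:

```python
import torch

num_generations = 3
rewards = torch.tensor([1.0, 0.0, 0.5, 2.0, 2.0, 1.0])  # 2 prompts x 3 completions

grouped = rewards.view(-1, num_generations)
mean_grouped_rewards = grouped.mean(dim=1).repeat_interleave(num_generations, dim=0)
std_grouped_rewards = grouped.std(dim=1).repeat_interleave(num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
# each completion is scored relative to the other completions of the same prompt
```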
        # x - x.detach() allows for preserving gradients from x
        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

Review comment on lines -524 to -531: this is later computed in `compute_loss`.

        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
@@ -539,15 +522,51 @@ def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep):
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())

        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
        }

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
        # Compute the per-token log probabilities for the model

        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # Compute the KL divergence between the model and the reference model
        ref_per_token_logps = inputs["ref_per_token_logps"]
        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

        # x - x.detach() allows for preserving gradients from x
        advantages = inputs["advantages"]
        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        return loss
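The loss above can be exercised in isolation. A minimal sketch with toy tensors (the `beta` value and shapes are assumed for illustration); the KL term uses the estimator exp(ref - logp) - (ref - logp) - 1, which is non-negative per token:

```python
import torch

beta = 0.04  # assumed value, for illustration only
per_token_logps = torch.randn(2, 4, requires_grad=True)  # policy log-probs, (B*G, C)
ref_per_token_logps = torch.randn(2, 4)                   # reference log-probs, no grad
advantages = torch.tensor([0.7, -0.7])                    # one advantage per completion
completion_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])

per_token_kl = (
    torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
)
# exp(x - x.detach()) is identically 1 in value but keeps the gradient path through x alive,
# so the advantage term contributes a policy-gradient signal.
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
loss.backward()  # gradients flow only through per_token_logps
```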
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            with self.compute_loss_context_manager():
                loss = self.compute_loss(model, inputs)
Review comment: in this method, we now run `_prepare_inputs` (generation, reference log-probs, rewards and advantages) before calling `compute_loss`, which only computes the loss from the prepared batch.
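Under that split, the per-batch control flow reduces to something like the following (a rough sketch, not actual `Trainer` internals; `trainer` and `raw_inputs` are hypothetical names):

```python
# _prepare_inputs: everything that must not be differentiated through
# (generation, reference log-probs, reward scoring, advantages).
batch = trainer._prepare_inputs(raw_inputs)

# compute_loss: a single differentiable forward pass of the policy model
# over the prepared prompt+completion ids, followed by the GRPO loss.
loss = trainer.compute_loss(trainer.model, batch)
```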