💔 Decouple loss computing and generation in GRPO #2762

Merged · 1 commit · Feb 4, 2025
127 changes: 73 additions & 54 deletions trl/trainer/grpo_trainer.py
@@ -366,32 +366,41 @@ def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
self._signature_columns = ["prompt"]

# Trainer "prepares" the inputs before calling `compute_loss`. It converts them to tensors and moves them to the device.
# Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
Member Author

In this method, we now:

  • generate completions
  • compute the rewards (see the example below)
  • compute the reference per-token log probs
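
To make the reward step concrete, here is a small self-contained demo of the group-relative advantage normalization whose output (`advantages`) `_prepare_inputs` now hands over to `compute_loss`. This is only an illustration, not the trainer code: the reward values are toys and `num_generations` is assumed to be 4.

```python
import torch

num_generations = 4  # completions sampled per prompt (assumed value for this demo)

# Toy rewards for 2 prompts x 4 completions, already flattened as in the trainer
rewards = torch.tensor([0.1, 0.9, 0.5, 0.5,   # group for prompt 0
                        1.0, 0.0, 0.2, 0.8])  # group for prompt 1

# Normalize each reward against the mean/std of its own prompt's group
mean_g = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0)
std_g = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations, dim=0)
advantages = (rewards - mean_g) / (std_g + 1e-4)
print(advantages.view(2, num_generations))
```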

return inputs

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
# Get the per-token log probabilities for the completions for the model and the reference model
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
# We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
logits = model(
input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
).logits # (B, L, V)
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids[:, -logits_to_keep:]):
log_probs = logits_row.log_softmax(dim=-1)
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps)

def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
prompt_inputs = self.processing_class(
prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

if self.max_prompt_length is not None:
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]

# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
@@ -418,18 +427,21 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
prompt_inputs_repeated = torch.repeat_interleave(prompt_inputs["input_ids"], self.num_generations, dim=0)
prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1)
prompt_ids = torch.repeat_interleave(prompt_ids, self.num_generations, dim=0)
prompt_mask = torch.repeat_interleave(prompt_mask, self.num_generations, dim=0)
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
else:
# Regular generation path
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
**prompt_inputs, generation_config=self.generation_config
prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
)

# Compute prompt length and extract completion ids
prompt_length = prompt_inputs["input_ids"].size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]
# Compute prompt length and extract completion ids
prompt_length = prompt_ids.size(1)
prompt_ids = prompt_completion_ids[:, :prompt_length]
completion_ids = prompt_completion_ids[:, prompt_length:]
prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)

# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
@@ -439,49 +451,28 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

# Concatenate prompt_mask with completion_mask for logit computation
prompt_mask_repeated = prompt_inputs["attention_mask"].repeat_interleave(self.num_generations, dim=0)
attention_mask = torch.cat([prompt_mask_repeated, completion_mask], dim=1) # (B*G, P+C)

# Get the per-token log probabilities for the completions for the model and the reference model
def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep):
Member Author

Converted to a method (`_get_per_token_logps`).
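
For reviewers, a minimal runnable illustration of what the extracted method computes: gathering, for each position, the log probability of the token that was actually sampled. The trainer loops row by row to reduce peak memory; this toy version is vectorized and uses random tensors purely for shape-checking, so it is not the trainer's real input.

```python
import torch

# Toy shapes: batch of 2 sequences, vocab of 5, keep the last 3 positions
logits = torch.randn(2, 3, 5)             # (B, logits_to_keep, V), already sliced
token_ids = torch.randint(0, 5, (2, 3))   # tokens whose log probs we want

log_probs = logits.log_softmax(dim=-1)    # (B, L, V)
per_token_logps = torch.gather(
    log_probs, dim=2, index=token_ids.unsqueeze(-1)
).squeeze(-1)                             # (B, L): log prob of each sampled token
print(per_token_logps.shape)              # torch.Size([2, 3])
```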

# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
logits = model(
input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
).logits # (B, L, V)
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids[:, -logits_to_keep:]):
log_probs = logits_row.log_softmax(dim=-1)
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)

logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
per_token_logps = get_per_token_logps(model, prompt_completion_ids, attention_mask, logits_to_keep)

with torch.inference_mode():
if self.ref_model is not None:
ref_per_token_logps = get_per_token_logps(
ref_per_token_logps = self._get_per_token_logps(
self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
)
else:
with self.accelerator.unwrap_model(model).disable_adapter():
ref_per_token_logps = get_per_token_logps(
model, prompt_completion_ids, attention_mask, logits_to_keep
with self.accelerator.unwrap_model(self.model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(
self.model, prompt_completion_ids, attention_mask, logits_to_keep
)

# Compute the KL divergence between the model and the reference model
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
Member Author

This is now computed later, in `compute_loss`.
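
For context, the line being moved is the standard non-negative per-token KL estimator: writing $\Delta_t = \log \pi_{\mathrm{ref}}(o_t \mid q, o_{<t}) - \log \pi_\theta(o_t \mid q, o_{<t})$, it computes

$$
\hat{D}_{\mathrm{KL}}\big[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\big]_t
= e^{\Delta_t} - \Delta_t - 1
= \frac{\pi_{\mathrm{ref}}(o_t \mid q, o_{<t})}{\pi_\theta(o_t \mid q, o_{<t})}
- \log \frac{\pi_{\mathrm{ref}}(o_t \mid q, o_{<t})}{\pi_\theta(o_t \mid q, o_{<t})} - 1,
$$

which is zero when the two policies agree at that token and strictly positive otherwise.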


# Decode the generated completions
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = [[{"role": "assistant", "content": completion}] for completion in completions]

# Compute the rewards
prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
prompts = [prompt for prompt in prompts for _ in range(self.num_generations)] # repeat prompts

rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
@@ -521,15 +512,7 @@ def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep):
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

# x - x.detach() allows for preserving gradients from x
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics["completion_length"].append(completion_length)
Comment on lines -524 to -531
Member Author

This is now computed later, in `compute_loss`.
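
A quick sanity check on the `x - x.detach()` trick used in the (moved) per-token loss: the ratio `exp(per_token_logps - per_token_logps.detach())` is exactly 1 in the forward pass, but it still carries the gradient of `per_token_logps`, so backprop sees `-advantage * ∇ log π_θ` plus the KL term. A standalone demo:

```python
import torch

x = torch.tensor([0.3, -1.2], requires_grad=True)
ratio = torch.exp(x - x.detach())  # forward value is exactly 1 everywhere
print(ratio)                       # tensor([1., 1.], grad_fn=...)

ratio.sum().backward()
print(x.grad)                      # tensor([1., 1.]): d/dx exp(x - c) evaluated at x == c
```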


reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
@@ -539,15 +522,51 @@ def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep):
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())

self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

return {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"ref_per_token_logps": ref_per_token_logps,
"advantages": advantages,
}

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
# Compute the per-token log probabilities for the model

prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens

per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

# Compute the KL divergence between the model and the reference model
ref_per_token_logps = inputs["ref_per_token_logps"]
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

# x - x.detach() allows for preserving gradients from x
advantages = inputs["advantages"]
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics["completion_length"].append(completion_length)

mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

return loss

def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
inputs = self._prepare_inputs(inputs)
with torch.no_grad():
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)