GRPO memory bottleneck from num_generations in compute_loss #2709
Comments
Same here, still suffering from OOM when running a 7B model.
We were also noticing bottlenecks due to the number of generations. What we found to be causing issues was the forward pass that computes logits for the reference model. One workaround is to break that forward pass apart into a series of smaller forward calls and then concatenate the results, i.e. going from a single call over the whole batch to a loop over chunks of it.
Same issue.
This didn't seem to work for me on 8xH100 with a 3B model; still OOM.
It doesn't solve the problem, but having this loop avoids the peak that you would otherwise have here.
It might indeed avoid the big decoding peak. Let me try to profile this. -- Another option is grad checkpointing. See #2671
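For anyone wanting to try grad checkpointing: it should just be the standard `TrainingArguments` flag that `GRPOConfig` inherits. A minimal sketch (config values here are illustrative only):

```python
from trl import GRPOConfig

# Minimal sketch: gradient checkpointing via the inherited TrainingArguments flags.
# The non-reentrant variant is generally the recommended one.
training_args = GRPOConfig(
    output_dir="grpo-output",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```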
You can probably merge the two loops btw:

```python
import torch


def get_per_token_logps(model, input_ids, num_logits_to_keep):
    batch_size = input_ids.size(0)
    mini_batch_size = 4  # This could be made configurable
    per_token_logps = []
    for i in range(0, batch_size, mini_batch_size):
        batch_end = min(i + mini_batch_size, batch_size)
        mini_batch = input_ids[i:batch_end]
        # Compute logits with an extra token
        mini_batch_logits = model(mini_batch, num_logits_to_keep=num_logits_to_keep + 1).logits  # (B, L, V)
        # Exclude the last logit
        mini_batch_logits = mini_batch_logits[:, :-1, :]  # (B, L-1, V)
        # Compute log probabilities
        log_probs = mini_batch_logits.log_softmax(dim=-1)
        # Select the relevant tokens
        input_ids_trimmed = mini_batch[:, -num_logits_to_keep:]
        token_log_probs = torch.gather(log_probs, dim=2, index=input_ids_trimmed.unsqueeze(2)).squeeze(2)
        per_token_logps.append(token_log_probs)
    return torch.cat(per_token_logps, dim=0)
```
Some profiling of the two methods:

```python
import time

import torch
from transformers import AutoModelForCausalLM


def get_per_token_logps_old(model, input_ids, num_logits_to_keep):
    # We add 1 to `num_logits_to_keep` because the last logit of the sequence is later excluded
    logits = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1).logits  # (B, L, V)
    logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
    # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)


def get_per_token_logps_new(model, input_ids, num_logits_to_keep, mini_batch_size):
    per_token_logps = []
    batch_size = input_ids.size(0)
    for i in range(0, batch_size, mini_batch_size):
        batch_end = min(i + mini_batch_size, batch_size)
        mini_input_ids = input_ids[i:batch_end]
        mini_logits = model(mini_input_ids, num_logits_to_keep=num_logits_to_keep + 1).logits
        mini_logits = mini_logits[:, :-1, :]  # exclude the last logit: it corresponds to the next token pred
        # Compute the log probabilities for the input tokens
        log_probs = mini_logits.log_softmax(dim=-1)
        labels = mini_input_ids[:, -num_logits_to_keep:].unsqueeze(2)
        token_log_prob = torch.gather(log_probs, dim=2, index=labels).squeeze(2)
        per_token_logps.append(token_log_prob)
    return torch.cat(per_token_logps, dim=0)


model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B").to("cuda")

# Simulate a batch of 8 sequences of length 512, where half is the prompt and half is the completion
input_ids = torch.randint(10, 100, (8, 512), device="cuda")
num_logits_to_keep = 256
mini_batch_size = 4  # minibatch size used by the new method

# Call the old method
times = []
per_token_logps = get_per_token_logps_old(model, input_ids, num_logits_to_keep)  # Warmup
for _ in range(10):
    start = time.time()
    per_token_logps = get_per_token_logps_old(model, input_ids, num_logits_to_keep)
    times.append(time.time() - start)
print("Time taken (get_per_token_logps_old):", sum(times) / len(times))

# Call the new method
times = []
per_token_logps = get_per_token_logps_new(model, input_ids, num_logits_to_keep, mini_batch_size)  # Warmup
for _ in range(10):
    start = time.time()
    per_token_logps = get_per_token_logps_new(model, input_ids, num_logits_to_keep, mini_batch_size)
    times.append(time.time() - start)
print("Time taken (get_per_token_logps_new):", sum(times) / len(times))
```

We also profile the memory usage of the two methods:

```python
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo


def print_gpu_utilization():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used / 1024**3:.2f} GB.")


per_token_logps = get_per_token_logps_old(model, input_ids, num_logits_to_keep)
print_gpu_utilization()
```

Plotting the results:

```python
import matplotlib.pyplot as plt

time_take = [0.205, 0.121, 0.102, 0.090]
ref_time = 0.088
memory = [15.2, 15.97, 17.14, 17.07]
ref_memory = 15.92
minibatch_sizes = [1, 2, 4, 8]

# Plot with 2 y-axes: time and memory
fig, ax1 = plt.subplots()
color = 'tab:red'
ax1.set_xlabel('minibatch_size')
ax1.set_ylabel('time', color=color)
ax1.plot(minibatch_sizes, time_take, color=color)
ax1.set_ylim(0, 0.3)
ax1.hlines(ref_time, 1, 8, colors='r', linestyles='dashed', label='ref_time')
ax1.tick_params(axis='y', labelcolor=color)

ax2 = ax1.twinx()
color = 'tab:blue'
ax2.set_ylabel('memory', color=color)
ax2.plot(minibatch_sizes, memory, color=color)
ax2.hlines(ref_memory, 1, 8, colors='b', linestyles='dashed', label='ref_memory')
ax2.set_ylim(0, 18)
ax2.tick_params(axis='y', labelcolor=color)

fig.tight_layout()
fig.savefig('plot.png')
```

Observation: the memory peak of the new method is bigger as soon as you use a minibatch of 2. This probably means that the peak is not related to the decoder, but to the softmax.
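A caveat on these numbers (my own assumption about the measurement): `nvmlDeviceGetMemoryInfo` reports driver-level memory held by the process, which with PyTorch's caching allocator is not quite the same thing as the peak of live tensors. It may be worth double-checking with the allocator's own stats, roughly like this:

```python
import torch


def measure_peak(fn, *args, **kwargs):
    # Reset the allocator's high-water mark, run the function, then read the peak back.
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    out = fn(*args, **kwargs)
    torch.cuda.synchronize()
    print(f"Peak allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
    return out


# e.g. measure_peak(get_per_token_logps_old, model, input_ids, num_logits_to_keep)
#      measure_peak(get_per_token_logps_new, model, input_ids, num_logits_to_keep, 4)
```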
I tried with this:

```python
import torch


def get_per_token_logps_new2(model, input_ids, num_logits_to_keep, mini_batch_size):
    per_token_logps = []
    batch_size = input_ids.size(0)
    for i in range(0, batch_size, mini_batch_size):
        batch_end = min(i + mini_batch_size, batch_size)
        mini_input_ids = input_ids[i:batch_end]
        mini_logits = model(mini_input_ids, num_logits_to_keep=num_logits_to_keep + 1).logits
        mini_logits = mini_logits[:, :-1, :]  # exclude the last logit: it corresponds to the next token pred
        for logits_row, input_ids_row in zip(mini_logits, mini_input_ids[:, -num_logits_to_keep:]):
            log_probs = logits_row.log_softmax(dim=-1)
            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
            per_token_logps.append(token_log_prob)
    # stack (not cat): each entry is one row's log-probs, so this keeps the (B, L) shape
    return torch.stack(per_token_logps, dim=0)
```

Not much better:

```python
import matplotlib.pyplot as plt

time_take1 = [0.205, 0.121, 0.102, 0.090]
time_take2 = [0.217, 0.120, 0.101, 0.089]
ref_time = 0.088
memory1 = [15.2, 15.97, 17.14, 17.07]
memory_2 = [15.13, 15.75, 16.65, 16.65]
ref_memory = 15.92
minibatch_sizes = [1, 2, 4, 8]

# Same data as above, but split into two plots, one above the other
fig, ax = plt.subplots(2, 1)
fig.set_size_inches(4, 7)

ax[0].set_xlabel("Minibatch size")
ax[0].set_ylabel("time")
ax[0].plot(minibatch_sizes, time_take1, label="New 1")
ax[0].plot(minibatch_sizes, time_take2, label="New 2")
ax[0].set_ylim(0, 0.3)
ax[0].hlines(ref_time, 1, 8, linestyles="dashed", label="Old")
ax[0].tick_params(axis="y")
ax[0].set_title("Time (lower is better)")
ax[0].legend(loc="lower right")

ax[1].set_ylabel("memory")
ax[1].plot(minibatch_sizes, memory1, label="New 1")
ax[1].plot(minibatch_sizes, memory_2, label="New 2")
ax[1].hlines(ref_memory, 1, 8, linestyles="dashed", label="Old")
ax[1].set_ylim(0, 18)
ax[1].tick_params(axis="y")
ax[1].set_title("Memory (lower is better)")
ax[1].legend(loc="lower right")

fig.tight_layout()
fig.savefig("plot2.png")
```
Hey @qgallouedec, thanks for digging into this. We recently started using a larger model and did indeed start to run into OOMs during the […].

The memory benchmarking results look a little surprising to me. Are you running with […]?

All that to say: I think this is a good change that would potentially benefit a number of people running into OOMs. Thanks for digging into it.
Agree with @tgaddair that this fix should indeed help resolve some OOM issues. I'll make a PR to integrate that change as well (flag-controlled).
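Not a spec for the eventual PR, just a sketch of how a flag-controlled variant might look (the argument name and default below are made up for illustration):

```python
import torch


def get_per_token_logps(model, input_ids, num_logits_to_keep, mini_batch_size=None):
    # Hypothetical flag: None keeps the current behaviour (one forward pass over the
    # whole batch); an integer splits the forward pass into minibatches of that size.
    if mini_batch_size is None:
        mini_batch_size = input_ids.size(0)
    per_token_logps = []
    for i in range(0, input_ids.size(0), mini_batch_size):
        mini_input_ids = input_ids[i : i + mini_batch_size]
        logits = model(mini_input_ids, num_logits_to_keep=num_logits_to_keep + 1).logits[:, :-1, :]
        labels = mini_input_ids[:, -num_logits_to_keep:].unsqueeze(2)
        per_token_logps.append(torch.gather(logits.log_softmax(dim=-1), dim=2, index=labels).squeeze(2))
    return torch.cat(per_token_logps, dim=0)
```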
The `compute_loss` implementation in GRPOTrainer seems to be bottlenecked by `num_generations` in a way that isn't addressed by grad accumulation, though it's not clear how best to resolve this while still subclassing Trainer, since grad accumulation operates on a per-prompt level for GRPO rather than per-sample (and you need all sample rewards to compute the advantage inside of `compute_loss`).

I don't think that looping over the samples in the post-inference forward pass really solves this (trl/trl/trainer/grpo_trainer.py, line 435 in 801582e): you still need the logprobs for all samples in memory in order to compute the backward pass.

It seems addressable from a PyTorch/algorithmic perspective: the reward/advantage computation could be done before `get_per_token_logps` is called, but that would break the implicit assumption of Trainer that all batching can occur outside of `compute_loss`.

This seems to be a pretty major barrier to using GRPO with models much larger than 1B, regardless of GPU availability. Even at bs=1, I'm still getting OOM on 8xH100s with:

- num_completions=4
- max_prompt_length=200
- max_completion_length=200
- grad_checkpointing (though this breaks with ZeRO-3 for me)

PEFT causes issues on multi-GPU, as others have mentioned (#2698).

Any thoughts on a good way to enable something like nested grad accumulation across samples? Would be happy to take a stab at implementing it if so.
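To make the idea concrete, a rough sketch of what I mean by nested accumulation, outside the Trainer abstraction (everything here, including `slice_loss_fn`, is hypothetical):

```python
import torch


def nested_accumulation_step(model, completion_slices, advantage_slices, optimizer, slice_loss_fn):
    """Rough sketch only. Advantages are computed up front for every completion of a
    prompt (which needs all rewards), so the differentiable forward/backward can run
    over small slices of completions; gradients simply accumulate across slices."""
    optimizer.zero_grad()
    total = sum(ids.size(0) for ids in completion_slices)
    for ids, adv in zip(completion_slices, advantage_slices):
        # slice_loss_fn is a stand-in for the per-slice GRPO objective (policy log-probs,
        # reference log-probs, KL term); it returns a scalar loss for this slice.
        loss = slice_loss_fn(model, ids, adv) * (ids.size(0) / total)
        loss.backward()  # only this slice's activations are live at any one time
    optimizer.step()
```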