
GRPO memory bottleneck from num_generations in compute_loss #2709

Open
willccbb opened this issue Jan 31, 2025 · 10 comments
Labels
🐛 bug Something isn't working 🏋 GRPO Related to GRPO ⚡ PEFT Related to PEFT

Comments

@willccbb

willccbb commented Jan 31, 2025

The compute_loss implementation in GRPOTrainer seems to be bottlenecked by num_generations in a way that isn't addressed by grad accumulation. It's not clear how best to resolve this while still subclassing Trainer, since grad accumulation operates at the per-prompt level for GRPO rather than per-sample (and you need the rewards of all samples in a group to compute advantages inside compute_loss).
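
For context, a minimal sketch (not TRL code) of why all num_generations completions of a prompt have to be visible in the same compute_loss call: GRPO normalizes each reward against its own group, so one sample's advantage depends on every other sample generated for the same prompt. The values and the epsilon below are illustrative.

import torch

# Rewards for the num_generations completions of a single prompt (values illustrative)
rewards = torch.tensor([0.0, 1.0, 0.5, 1.0])

# Group-relative advantage: mean and std are taken over the whole group, so splitting
# the group across grad-accumulation steps would change the result
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)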

I don't think that looping over the samples in the post-inference forward pass really solves this:

# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.

You still need all the logprobs in memory for all samples in order to compute the backward pass.

It seems addressable from a PyTorch/algorithmic perspective: the reward/advantage computation could be done before get_per_token_logps is called, but that would break the implicit assumption of Trainer that all batching can occur outside of compute_loss.

This seems to be a pretty major barrier to using GRPO with models much larger than 1B, regardless of GPU availability.

Even at bs=1, I'm still getting OOM on 8xH100s with the following (a rough config sketch follows the list):

  • 7 training GPUs (1 for vLLM) + multi_gpu and/or deepspeed zero 3
  • num_completions=4
  • max_prompt_length=200
  • max_completion_length=200
  • with or without grad_checkpointing (though this breaks with zero 3 for me)
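
For reference, the settings above as a rough GRPOConfig sketch (argument names assume TRL's GRPOConfig and the underlying TrainingArguments, and may differ across TRL versions):

from trl import GRPOConfig

# Rough reproduction of the settings listed above; adjust names for your TRL version
training_args = GRPOConfig(
    output_dir="grpo-oom-repro",
    per_device_train_batch_size=1,   # bs=1
    num_generations=4,               # "num_completions=4" above
    max_prompt_length=200,
    max_completion_length=200,
    use_vllm=True,                   # one of the 8 GPUs reserved for vLLM generation
    gradient_checkpointing=True,     # tried with and without
)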

PEFT causes issues on multi-GPU, as others have mentioned (#2698).

Any thoughts on a good way to enable something like a nested grad accumulation across samples? Would be happy to take a stab at implementing if so.
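
To make the question concrete, here is a sketch of what a nested grad accumulation across samples could look like, assuming group advantages are computed up front and completions are then backpropagated in chunks. per_completion_inputs, compute_grpo_loss, and chunk_size are hypothetical names, not TRL APIs.

# Hypothetical: advantages for the full group are computed from the rewards first,
# then per-completion losses are backpropagated chunk by chunk to bound peak memory
chunk_size = 2
num_samples = len(per_completion_inputs)
for start in range(0, num_samples, chunk_size):
    inputs_chunk = per_completion_inputs[start:start + chunk_size]
    adv_chunk = advantages[start:start + chunk_size]
    loss = compute_grpo_loss(model, inputs_chunk, adv_chunk)  # hypothetical helper returning the chunk-mean loss
    (loss * len(inputs_chunk) / num_samples).backward()       # scale so gradients match a full-group backward
optimizer.step()
optimizer.zero_grad()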

@github-actions github-actions bot added 🏋 GRPO Related to GRPO ⚡ PEFT Related to PEFT 🐛 bug Something isn't working labels Jan 31, 2025
@fkxie

fkxie commented Jan 31, 2025

Same here, still hitting OOM when running a 7B model.

@tgaddair

We were also noticing bottlenecks due to the number of generations. What we found was causing issues was the forward pass that computes logits for the reference model. One workaround is to break the forward pass apart into a series of smaller forward calls and then concatenate.

So going from:

def get_per_token_logps(model, input_ids, num_logits_to_keep):
            # We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded
            logits = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1).logits  # (B, L, V)
            logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

            # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
            per_token_logps = []
            for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
                log_probs = logits_row.log_softmax(dim=-1)
                token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
                per_token_logps.append(token_log_prob)
            return torch.stack(per_token_logps)

To:

def get_per_token_logps(model, input_ids, num_logits_to_keep):
            # Process input_ids in mini-batches of size 4
            batch_size = input_ids.size(0)
            mini_batch_size = 4  # whatever the max you can handle; this could be made configurable
            all_logits = []

            for i in range(0, batch_size, mini_batch_size):
                batch_end = min(i + mini_batch_size, batch_size)
                mini_batch = input_ids[i:batch_end]

                # We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded
                mini_batch_logits = model(mini_batch, num_logits_to_keep=num_logits_to_keep + 1).logits  # (B, L, V)
                all_logits.append(mini_batch_logits)

            # Concatenate all mini-batch results
            logits = torch.cat(all_logits, dim=0)
            logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

            # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
            per_token_logps = []
            for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
                log_probs = logits_row.log_softmax(dim=-1)
                token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
                per_token_logps.append(token_log_prob)
            return torch.stack(per_token_logps)

@zaddy6

zaddy6 commented Jan 31, 2025

same issue

@zaddy6

zaddy6 commented Jan 31, 2025

> (quoting @tgaddair's mini-batched get_per_token_logps suggestion from above)

This didn't seem to work for me on 8xH100 with a 3B model; still OOM.

@qgallouedec
Member

> I don't think that looping over the samples in the post-inference forward pass really solves this:

It doesn't solve it, but having this loop avoids the peak that you would otherwise have here.

> One workaround is to break apart the forward pass into a series of smaller forward calls and then concatenate.

It might indeed avoid the big decoding peak. Let me try to profile this.

--

Another option is grad checkpointing. See #2671
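
For reference, enabling it is a config switch (a sketch; the exact kwargs depend on your transformers/TRL versions):

from trl import GRPOConfig

# Sketch only: activation checkpointing via the standard TrainingArguments fields
training_args = GRPOConfig(
    output_dir="grpo-grad-ckpt",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},  # non-reentrant variant tends to play nicer with PEFT/ZeRO
)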

@qgallouedec
Member

You can probably merge the two loops btw:

def get_per_token_logps(model, input_ids, num_logits_to_keep):
    batch_size = input_ids.size(0)
    mini_batch_size = 4  # This could be made configurable
    per_token_logps = []

    for i in range(0, batch_size, mini_batch_size):
        batch_end = min(i + mini_batch_size, batch_size)
        mini_batch = input_ids[i:batch_end]

        # Compute logits with an extra token
        mini_batch_logits = model(mini_batch, num_logits_to_keep=num_logits_to_keep + 1).logits  # (B, L, V)

        # Exclude the last logit
        mini_batch_logits = mini_batch_logits[:, :-1, :]  # (B, L-1, V)

        # Compute log probabilities
        log_probs = mini_batch_logits.log_softmax(dim=-1)

        # Select the relevant tokens
        input_ids_trimmed = mini_batch[:, -num_logits_to_keep:]
        token_log_probs = torch.gather(log_probs, dim=2, index=input_ids_trimmed.unsqueeze(2)).squeeze(2)

        per_token_logps.append(token_log_probs)

    return torch.cat(per_token_logps, dim=0)

@qgallouedec
Member

Some profiling of get_per_token_logps. Here we use mini-batches for the forward pass to reduce the memory peak.

import torch
import time
from transformers import AutoModelForCausalLM


def get_per_token_logps_old(model, input_ids, num_logits_to_keep):
    # We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded
    logits = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1).logits  # (B, L, V)
    logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

    # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)


def get_per_token_logps_new(model, input_ids, num_logits_to_keep, mini_batch_size):
    per_token_logps = []
    batch_size = input_ids.size(0)
    for i in range(0, batch_size, mini_batch_size):
        batch_end = min(i + mini_batch_size, batch_size)
        mini_input_ids = input_ids[i:batch_end]
        mini_logits = model(mini_input_ids, num_logits_to_keep=num_logits_to_keep + 1).logits
        mini_logits = mini_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred

        # Compute the log probabilities for the input tokens
        log_probs = mini_logits.log_softmax(dim=-1)
        labels = mini_input_ids[:, -num_logits_to_keep:].unsqueeze(2)
        token_log_prob = torch.gather(log_probs, dim=2, index=labels).squeeze(2)
        per_token_logps.append(token_log_prob)
    return torch.cat(per_token_logps, dim=0)


model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B").to("cuda")

# Simulate a batch of 8 sequences of length 512, where half is the prompt and half is the completion
input_ids = torch.randint(10, 100, (8, 512), device="cuda")
num_logits_to_keep = 256



# Call the old method
times = []
per_token_logps = get_per_token_logps_old(model, input_ids, num_logits_to_keep) # Warmup
for _ in range(10):
    start = time.time()
    per_token_logps = get_per_token_logps_old(model, input_ids, num_logits_to_keep)
    times.append(time.time() - start)
print("Time taken (get_per_token_logps_old):", sum(times) / len(times))


# Call the new method (mini_batch_size of 4 shown here; the profile below sweeps 1, 2, 4, 8)
times = []
per_token_logps = get_per_token_logps_new(model, input_ids, num_logits_to_keep, mini_batch_size=4) # Warmup
for _ in range(10):
    start = time.time()
    per_token_logps = get_per_token_logps_new(model, input_ids, num_logits_to_keep, mini_batch_size=4)
    times.append(time.time() - start)
print("Time taken (get_per_token_logps_new):", sum(times) / len(times))

We also profile the memory usage of the two methods.

from pynvml import *

def print_gpu_utilization():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used/1024**3:.2f} GB.")

per_token_logps = get_per_token_logps_old(model, input_ids, num_logits_to_keep)
print_gpu_utilization()
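
Plotting the time and memory measured across mini-batch sizes 1, 2, 4, and 8, with the old method as a dashed reference: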
import matplotlib.pyplot as plt

time_take = [0.205, 0.121, 0.102, 0.090]
ref_time = 0.088

memory = [15.2, 15.97, 17.14, 17.07]
ref_memory = 15.92

minibatch_sizes = [1, 2, 4, 8]

# Plot with 2 y axis: time and memory:
fig, ax1 = plt.subplots()

color = 'tab:red'
ax1.set_xlabel('minibatch_size')
ax1.set_ylabel('time', color=color)
ax1.plot(minibatch_sizes, time_take, color=color)
ax1.set_ylim(0, 0.3)
ax1.hlines(ref_time, 1, 8, colors='r', linestyles='dashed', label='ref_time')
ax1.tick_params(axis='y', labelcolor=color)

ax2 = ax1.twinx()
color = 'tab:blue'
ax2.set_ylabel('memory', color=color)
ax2.plot(minibatch_sizes, memory, color=color)
ax2.hlines(ref_memory, 1, 8, colors='b', linestyles='dashed', label='ref_memory')
ax2.set_ylim(0, 18)
ax2.tick_params(axis='y', labelcolor=color)

fig.tight_layout()
fig.savefig('plot.png')

[Plot: time (red, left axis) and memory (blue, right axis) vs. minibatch_size, with the old method as dashed reference lines]

Observation: the memory peak of the new method is higher as soon as you use a minibatch size of 2 or more. This probably means that the peak is not related to the decoder forward pass, but to the softmax.
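
If the softmax is indeed the peak, one option (a sketch, not what the snippets above do) is to never materialize the full (B, L, V) log-probability tensor: gather the chosen-token logits and subtract a logsumexp over the vocabulary, which gives the same per-token log-probs.

import torch

def per_token_logps_from_logits(logits, token_ids):
    # logits: (B, L, V), token_ids: (B, L)
    # log_softmax(x)[t] == x[t] - logsumexp(x), so no (B, L, V) log-prob tensor is kept
    token_logits = torch.gather(logits, dim=-1, index=token_ids.unsqueeze(-1)).squeeze(-1)  # (B, L)
    return token_logits - torch.logsumexp(logits, dim=-1)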

@qgallouedec
Member

I tried with this:

def get_per_token_logps_new2(model, input_ids, num_logits_to_keep, mini_batch_size):
    per_token_logps = []
    batch_size = input_ids.size(0)
    for i in range(0, batch_size, mini_batch_size):
        batch_end = min(i + mini_batch_size, batch_size)
        mini_input_ids = input_ids[i:batch_end]
        mini_logits = model(mini_input_ids, num_logits_to_keep=num_logits_to_keep + 1).logits
        mini_logits = mini_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
        
        for logits_row, input_ids_row in zip(mini_logits, mini_input_ids[:, -num_logits_to_keep:]):
            log_probs = logits_row.log_softmax(dim=-1)
            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
            per_token_logps.append(token_log_prob)
    return torch.cat(per_token_logps, dim=0)

Not much better:

[Plots: time (top) and memory (bottom) vs. minibatch size for "New 1", "New 2", and the old method (dashed)]

import matplotlib.pyplot as plt

time_take1 = [0.205, 0.121, 0.102, 0.090]
time_take2 = [0.217, 0.120, 0.101, 0.089]
ref_time = 0.088

memory1 = [15.2, 15.97, 17.14, 17.07]
memory_2 = [15.13, 15.75, 16.65, 16.65]
ref_memory = 15.92

minibatch_sizes = [1, 2, 4, 8]

# Same data as above, split into two stacked subplots (time on top, memory below)
fig, ax = plt.subplots(2, 1)
# Set the size
fig.set_size_inches(4, 7)

colors = ["blue", "orange", "green", "red"]
ax[0].set_xlabel("Minibatch size")
ax[0].set_ylabel("time")
ax[0].plot(minibatch_sizes, time_take1, label="New 1")
ax[0].plot(minibatch_sizes, time_take2, label="New 2")
ax[0].set_ylim(0, 0.3)
ax[0].hlines(ref_time, 1, 8, linestyles="dashed", label="Old")
ax[0].tick_params(axis="y")
ax[0].set_title("Time (lower is better)")
ax[0].legend(loc="lower right")

ax[1].set_ylabel("memory")
ax[1].plot(minibatch_sizes, memory1, label="New 1")
ax[1].plot(minibatch_sizes, memory_2, label="New 2")
ax[1].hlines(ref_memory, 1, 8, linestyles="dashed", label="Old")
ax[1].set_ylim(0, 18)
ax[1].tick_params(axis="y")
ax[1].set_title("Memory (lower is better)")
ax[1].legend(loc="lower right")


fig.tight_layout()
fig.savefig("plot2.png")```

@tgaddair

tgaddair commented Feb 1, 2025

Hey @qgallouedec, thanks for digging into this. We recently started using a larger model and did indeed start to run into OOMs during the log_softmax operation. However, this implementation (I believe the second one you tried) resolved the issue for us:

def get_per_token_logps(model, input_ids, num_logits_to_keep):
            # Process input_ids in mini-batches of size 1 and compute log probs
            batch_size = input_ids.size(0)
            mini_batch_size = 1
            per_token_logps = []

            for i in range(0, batch_size, mini_batch_size):
                batch_end = min(i + mini_batch_size, batch_size)
                mini_batch = input_ids[i:batch_end]

                # We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded
                mini_batch_logits = model(mini_batch, num_logits_to_keep=num_logits_to_keep + 1).logits  # (B, L, V)
                logits = mini_batch_logits[:, :-1, :]  # (B, L-1, V), exclude the last logit
                
                # Compute log probs for this mini-batch
                log_probs = logits.log_softmax(dim=-1)
                mini_batch_ids = mini_batch[:, -num_logits_to_keep:]
                token_log_prob = torch.gather(log_probs, dim=2, 
                                            index=mini_batch_ids.unsqueeze(2)).squeeze(2)
                per_token_logps.append(token_log_prob)

            return torch.cat(per_token_logps, dim=0)

The memory benchmarking results look a little surprising to me. Are you running with CUDA_LAUNCH_BLOCKING=1? If not, I would suspect that the async execution might be throwing things off if you read the memory usage before execution has completed. Another way to work around that would be to call torch.cuda.synchronize() before calculating memory usage. It also looks like you're capturing final memory, not peak memory, which might be misleading if the log_softmax's peak usage is higher than what remains allocated at the end.
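
For instance, a sketch of capturing the peak with PyTorch's allocator counters (these report PyTorch allocations, not total device memory as pynvml does):

import torch

torch.cuda.reset_peak_memory_stats()
per_token_logps = get_per_token_logps_old(model, input_ids, num_logits_to_keep)
torch.cuda.synchronize()  # make sure all kernels have finished before reading the stats
print(f"Peak allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")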

All that to say: I think this is a good change that would potentially benefit a number of people running into OOMs. Thanks for digging into it.

@andyl98
Contributor

andyl98 commented Feb 4, 2025

Agree with @tgaddair that this fix should indeed help resolve some OOM issues. I'll make a PR to integrate that change as well (flag-controlled).
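
Purely as an illustration of what "flag-controlled" could mean (the names are hypothetical, not the eventual PR's API): when the flag is unset, fall back to a single forward pass; otherwise chunk it.

# Hypothetical wrapper reusing get_per_token_logps_new from the profiling script above
def get_per_token_logps_flagged(model, input_ids, num_logits_to_keep, mini_batch_size=None):
    if mini_batch_size is None:
        mini_batch_size = input_ids.size(0)  # flag off: one forward pass, old behaviour
    return get_per_token_logps_new(model, input_ids, num_logits_to_keep, mini_batch_size)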
