model.forward requires num_logits_to_keep, not logits_to_keep #2770

Closed
richardwth opened this issue Feb 5, 2025 · 10 comments · Fixed by #2773
Labels
🐛 bug Something isn't working 🏋 GRPO Related to GRPO

Comments

@richardwth

richardwth commented Feb 5, 2025

Reproduction

In the _get_per_token_logps method of grpo_trainer.py, the model is called as

logits = model(
    input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
).logits  # (B, L, V)

But on Transformers 4.48, Qwen 2's model.forward has no logits_to_keep argument. It should be num_logits_to_keep, i.e.,

logits = model(
    input_ids=input_ids, attention_mask=attention_mask, num_logits_to_keep=logits_to_keep + 1
).logits  # (B, L, V)
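
A possible workaround until a fix lands is to pick the keyword based on what the installed model's forward actually accepts. This is only a sketch, not the eventual TRL fix; it reuses model, input_ids, attention_mask, and logits_to_keep from the snippet above and assumes the forward signature is introspectable:

import inspect

# transformers 4.48 exposes `num_logits_to_keep`; newer versions rename it to
# `logits_to_keep`. Inspect the signature and use whichever name exists.
forward_params = inspect.signature(model.forward).parameters
keep_kwarg = "logits_to_keep" if "logits_to_keep" in forward_params else "num_logits_to_keep"

logits = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    **{keep_kwarg: logits_to_keep + 1},
).logits  # (B, L, V)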

System Info

  • Platform: Linux-5.10.134-16.3.al8.x86_64-x86_64-with-glibc2.35
  • Python version: 3.10.16
  • PyTorch version: 2.5.1
  • CUDA device(s): NVIDIA L20Z, NVIDIA L20Z, NVIDIA L20Z, NVIDIA L20Z, NVIDIA L20Z, NVIDIA L20Z, NVIDIA L20Z, NVIDIA L20Z
  • Transformers version: 4.48.0
  • Accelerate version: 1.2.1
  • Accelerate config: not found
  • Datasets version: 3.2.0
  • HF Hub version: 0.27.1
  • TRL version: 0.15.0.dev0
  • bitsandbytes version: 0.45.0
  • DeepSpeed version: 0.16.2
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: 1.60.0
  • PEFT version: 0.6.2

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete
@github-actions github-actions bot added 🐛 bug Something isn't working 🏋 GRPO Related to GRPO labels Feb 5, 2025
@qgallouedec
Member

What model do you use?

@richardwth
Author

> What model do you use?

I am using Qwen2.5 7B and Llama3 8B.

@qgallouedec
Member

Can't reproduce 🤔 (with both transformers 4.49 dev and 4.48):

>>> from transformers import AutoModelForCausalLM
>>> import torch
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B").to("cuda")
>>> input_ids = torch.randint(100, 200, (4, 256), device="cuda")
>>> model(input_ids, logits_to_keep=128)
CausalLMOutputWithPast(loss=None, logits=tensor([[[ 3.1937,  5.0300,  5.1271,  ...,  0.9952,  0.9950,  0.9953],
...
         [-5.3295, -7.6368, -1.7083,  ...,  4.3898,  4.3899,  4.3898]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>), past_key_values=DynamicCache(), hidden_states=None, attentions=None)
>>>

@Co1lin

Co1lin commented Feb 5, 2025

@qgallouedec Hey, could you check the shape of the returned logits?

out = model(input_ids, logits_to_keep=128)
out.logits.shape

On my side, logits_to_keep=128 returns torch.Size([4, 256, 152064]), whereas num_logits_to_keep=128 gives torch.Size([4, 128, 152064]).

@Co1lin

Co1lin commented Feb 5, 2025

@qgallouedec Also, do you think this is related to #2731?

@qgallouedec
Member

Nice find. Indeed, with 4.48 you'll get 256, while with 4.49 you'll get 128, as expected.
You'll also need to clear the cache after upgrading, or load with force_download=True.
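
For reference, a minimal sketch of the reload step mentioned above (model name taken from the earlier repro; this simply bypasses the locally cached files):

from transformers import AutoModelForCausalLM

# Re-download the checkpoint from the Hub instead of reusing the local cache.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", force_download=True)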

@richardwth
Author

Right... Transformers' latest modeling_qwen2.py says:

@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")

So this is a bug that affects some models on some Transformers versions. I've modified my issue to reflect this.
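
In other words, which keyword works depends on the installed transformers version: 4.48 only accepts num_logits_to_keep, while 4.49+ renames it to logits_to_keep and keeps the old name as a deprecated alias until 4.50. A rough version gate, assuming the rename landed in 4.49.0, could look like:

from packaging import version
import transformers

# Assumption: the rename to `logits_to_keep` shipped in transformers 4.49.0,
# per the deprecation note above; adjust the cut-off if that turns out to differ.
if version.parse(transformers.__version__) >= version.parse("4.49.0"):
    keep_kwarg = "logits_to_keep"
else:
    keep_kwarg = "num_logits_to_keep"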

@nopepper

nopepper commented Feb 5, 2025

!!! Thank you so much

I was baffled and spent hours debugging my code because of this. After applying the fix, the difference is obvious:

Before: [image]

After: [image]

This is a pretty major bug that makes the method unusable for Qwen models (and maybe others) using the latest stable transformers version.

It would also be great to add some sort of check or assertion that the shape of the returned logits matches what is expected.
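
A rough sketch of such a guard, assuming the call from the issue body (it reuses model, input_ids, attention_mask, and logits_to_keep, and is not an actual TRL check):

out = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    logits_to_keep=logits_to_keep + 1,
)
# With the wrong kwarg name the keyword is silently swallowed and the model
# returns logits for the full sequence instead of the last `logits_to_keep + 1` positions.
expected_shape = (input_ids.shape[0], logits_to_keep + 1, model.config.vocab_size)
assert out.logits.shape == expected_shape, (
    f"Unexpected logits shape {tuple(out.logits.shape)}, expected {expected_shape}; "
    "the logits_to_keep kwarg may have been ignored by this transformers version."
)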

@Co1lin

Co1lin commented Feb 5, 2025

@nopepper Hey, could you share the versions of transformers and trl that you are using now? Also, which model are you using?

@Superskyyy
Contributor

Superskyyy commented Feb 6, 2025

Interestingly, even with this bug I was able to train a Qwen2.5 model and it converged; everything looked normal... now my brain exploded. I checked the code and it was indeed as you found: 256, not 128.

Update: maybe it's because I was using a commit (without num_) from the prompt cache PR, which ran the prompt part under no_grad and offset the issue. The only visible effect was that the bug increased HBM usage.
