[Bug]: Models produce different output with different batch sizes #9567
I also encountered the issue of inconsistent inference results with different batch sizes. This was caused by a bug related to cudagraph, which is currently being fixed by #9549. I'm not sure if this is related to your problem.
Thanks for the pointer @jeejeelee! I just tested with the main branch and still see this behavior though :(
This also still occurs with
From looking at the logprobs returned, it seems like divergence happens pretty quickly in the sequence, and it's larger than just precision error. As an example, here are logprobs from a response from the batched run:
And the same request from the serial run:
@tlrmchlsmth I tried to dump out some intermediate states of the model in both batched and serial runs to check the outputs of the cutlass kernel, and at least where I spot-checked it looks like it gives the same results, so I haven't been able to track this down any further. Do you know if this is expected behavior? i.e. is the kernel supposed to sacrifice this much accuracy for speed when processing batches?
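For anyone trying to reproduce that kind of check, here is a minimal sketch of dumping intermediate activations with a PyTorch forward hook; the layer below is a toy stand-in, and the real module path inside the vLLM model is an assumption.

```python
import torch
import torch.nn as nn

captured = {}

def make_hook(name):
    def hook(module, inputs, output):
        # Detach and copy so later in-place ops can't change what we saved
        captured[name] = output.detach().clone()
    return hook

# Toy stand-in for one transformer sublayer; in a real run this would be a
# submodule of the vLLM model (the exact module path is an assumption).
layer = nn.Linear(16, 16)
handle = layer.register_forward_hook(make_hook("layer0"))

layer(torch.randn(4, 16))  # captured["layer0"] now holds the layer output
handle.remove()

# Comparing a dump from the batched run against the serial run would look like:
# torch.testing.assert_close(batched["layer0"], serial["layer0"], rtol=0, atol=0)
```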
@joerunde thanks for digging in further, we should have somebody from Neural Magic dig in as well. Is it possible that this could be a bug outside of GEMM as well? Losing accuracy with larger batch sizes is definitely not expected behavior. What can happen is that changes in the problem size can result in different block sizes used for the GEMM, which can affect the order of the accumulation. If this is on an A100 then we're using the Marlin FP8 kernel, so it could be a problem there. Do you know if the same thing happens on an H100?
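To illustrate the accumulation-order point with a small pure-Python sketch (the values and chunk size are arbitrary): summing the same numbers in two different groupings, roughly what differently-tiled kernels do, gives results that differ in the last bits, and with low-precision accumulators inside a GEMM the effect is correspondingly larger.

```python
import random

random.seed(0)
vals = [random.uniform(-1.0, 1.0) for _ in range(100_000)]

# Left-to-right accumulation: one possible reduction order
left_to_right = 0.0
for v in vals:
    left_to_right += v

# Chunked accumulation: per-block partial sums combined afterwards,
# roughly what a different GEMM tile size does to the K-dimension reduction
chunk = 512
partials = [sum(vals[i:i + chunk]) for i in range(0, len(vals), chunk)]
chunked = sum(partials)

# Same inputs, different grouping: the results typically differ by a few ULPs,
# which is enough to flip a near-tied token choice downstream.
print(left_to_right, chunked, left_to_right - chunked)
```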
Oof, H100s are hard to come by, but I'll ask around to see if I can snag some time on one to run this script and let you know
@joerunde no worries, I'll run it
H100 output on current main:
@tlrmchlsmth thanks for finding an H100!
Yeah, so I went and checked this with plain old
I can go check on some other non-llama models as well to see if this is a llama-specific issue, just having some GPU acquisition problems atm :/
Ah, actually when using dtype=float32 with
@tjohnson31415 reminded me that when looking for precision issues, it's actually the logits that we care about and not the logprobs that are calculated from the logits. It's possible that changing a logit to the next representable number will cause a much larger difference in the calculated logprob. It seems less than ideal that the benchmark scores change so much because of this on quantized models, but like @robertgshaw2-neuralmagic said, we were only using 250 samples to test.
Any objections to closing as working as expected?
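A toy example of that logit-vs-logprob point (the vocabulary size and logit values are made up): bumping one fp16 logit by a single ULP shifts the log-softmax outputs by an amount that looks large compared to the precision of the logprobs themselves.

```python
import torch

# Tiny made-up "vocabulary" of logits in half precision
logits = torch.tensor([12.0, 11.5, 3.0, -2.0], dtype=torch.float16)

# The fp16 spacing (ULP) around 12.0 is 2**-7 = 0.0078125; nudge one logit by that
bumped = logits.clone()
bumped[0] = bumped[0] + 2.0 ** -7

lp_a = torch.log_softmax(logits.float(), dim=-1)
lp_b = torch.log_softmax(bumped.float(), dim=-1)

print((bumped[0] - logits[0]).item())    # ~0.0078: one ULP at the logit level
print((lp_a - lp_b).abs().max().item())  # ~5e-3 shift across the logprobs
```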
@joerunde I don't have any objections. I think you've done a thorough job running this down. Good call on checking fp32 as well.
Your current environment
The output of `python collect_env.py`
Model Input Dumps
No response
🐛 Describe the bug
When the nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test model runs requests with temperature=0, the output changes depending on how the scheduler batches the requests. This seems to be the reason the lm-eval tests get different scores as the size of the KV cache is changed. Slack thread for more context: https://vllm-dev.slack.com/archives/C07R5PAL2L9/p1729409919734939
Here's a small repro script that uses --max-num-seqs to force different batch sizes: test_batch_weirdness.py
And the input data for that: request_data_small.json
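The attached script isn't reproduced here, but a minimal sketch of the same idea, assuming vLLM's offline LLM API and a simple list-of-prompts JSON file (the prompt handling and sampling settings are simplified assumptions):

```python
# Sketch: run the same greedy requests with two different max_num_seqs settings
# and diff the generated text. Not the attached script; prompt loading is assumed.
import json
from vllm import LLM, SamplingParams

MODEL = "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test"

def run(max_num_seqs: int) -> list[str]:
    # In practice each config may need its own process so GPU memory is released
    llm = LLM(model=MODEL, max_num_seqs=max_num_seqs)
    with open("request_data_small.json") as f:
        prompts = json.load(f)  # assumed to be a list of prompt strings
    params = SamplingParams(temperature=0.0, max_tokens=100, logprobs=1)
    return [out.outputs[0].text for out in llm.generate(prompts, params)]

batched = run(max_num_seqs=256)  # scheduler is free to batch requests together
serial = run(max_num_seqs=1)     # forces requests to run one at a time
for i, (a, b) in enumerate(zip(batched, serial)):
    if a != b:
        print(f"request {i} differs:\nbatched: {a!r}\nserial:  {b!r}")
```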
On my A100 machine this produces a diff like so: