[Bug]: AssertionError when using automatic prefix caching and prompt_logprobs #8268
Comments
Probably a similar issue to #5344 (same assert fails); some more related issues come up when searching for that assert.
Not sure if it's any help, but I simplified the example a little bit. If the number of tokens in the prefix is > 16 and there's a full cache hit, then the assertion will trigger.

```python
from vllm import LLM, SamplingParams, TokensPrompt

model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model = LLM(model_path, tensor_parallel_size=1, dtype="bfloat16", gpu_memory_utilization=0.8, enable_prefix_caching=True, enable_chunked_prefill=True)
sampling_params = SamplingParams(prompt_logprobs=1, max_tokens=1)

# works: 16 tokens fit in a single (default-size) block
# prompt = TokensPrompt(prompt_token_ids=list(range(16)))
# model.generate(prompt, sampling_params, use_tqdm=False)
# print("OK")
# model.generate(prompt, sampling_params, use_tqdm=False)
# print("OK")

# fails: 17 tokens span more than one block; the second call (full cache hit) triggers the assertion
prompt = TokensPrompt(prompt_token_ids=list(range(17)))
x = model.generate(prompt, sampling_params, use_tqdm=False)
print("OK")
y = model.generate(prompt, sampling_params, use_tqdm=False)
print("OK")
```
Another update: it looks like the crash is related to the block size. If the number of tokens in the cached prefix is greater than the block size, then the assertion is hit. 16 is the default, which is why I saw it first. As per the example below, if I use a block size of 32, then I can increase the length of the TokensPrompt to 32. Examples:

```python
from vllm import LLM, SamplingParams, TokensPrompt

model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model = LLM(
    model_path,
    tensor_parallel_size=1,
    dtype="bfloat16",
    gpu_memory_utilization=0.8,
    enable_prefix_caching=True,
    enable_chunked_prefill=True,
    block_size=32,
)
sampling_params = SamplingParams(prompt_logprobs=1, max_tokens=1)

# works: 31 tokens fit within a single 32-token block
prompt = TokensPrompt(prompt_token_ids=list(range(31)))
x = model.generate(prompt, sampling_params, use_tqdm=False)
print(x[0].prompt_logprobs)
y = model.generate(prompt, sampling_params, use_tqdm=False)
print(y[0].prompt_logprobs)

# fails: 33 tokens exceed one block; the second call hits the assertion
prompt = TokensPrompt(prompt_token_ids=list(range(33)))
x = model.generate(prompt, sampling_params, use_tqdm=False)
print(x[0].prompt_logprobs)
y = model.generate(prompt, sampling_params, use_tqdm=False)
print(y[0].prompt_logprobs)
```
Can you try out the new version of vLLM (0.6.3.post1)? I believe #9034 may have fixed this error by correctly populating Sequence.
#9034 does not fix the issue; I patched in this PR but can still reproduce it.
Unfortunately, I saw the same. I think I got lucky when it worked out.
posted a fix in #3251 that solves some problems (maybe enough for you), but not all |
@ccolas this looks great. |
Same issue on ROCm@c040f0e using the offline API. Repro: …

As @ccolas mentioned, logprob caching is not supported at the moment in vLLM, and an error should be raised in that case. Even fixing the wrong …
If you only care about the logprobs of the end of your prompt (which was my case), then you can prevent caching of the N last block ids (I think the default is N=1, but you need to extend it to cover the max length of prompts you care about). I did that in vllm/core/block/prefix_cache_block:
but it seems the code has changed since then (I'm on vllm 0.6.3), and I feel like …

@mgoin this is the issue I was talking about. Is there a strong constraint against caching logprobs? Seems I'm not the only one to care :) See also #3251 (comment).
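A minimal standalone sketch of that idea (holding the last N blocks back from prefix caching so their logprobs are always recomputed). This is not the actual patch and not vLLM's internal API; the function name and the `num_uncached_trailing_blocks` knob are hypothetical, for illustration only.

```python
# Sketch only: illustrates "don't cache the last N blocks" so the prompt
# logprobs for those tokens are recomputed on every request.
# Hypothetical names; not vLLM's internal API.

def cacheable_prefix_length(num_prompt_tokens: int,
                            block_size: int = 16,
                            num_uncached_trailing_blocks: int = 1) -> int:
    """Number of leading prompt tokens eligible for prefix caching.

    The last `num_uncached_trailing_blocks` full blocks are never marked
    as cached, so the model reruns them and their logprobs are available.
    """
    full_blocks = num_prompt_tokens // block_size  # only full blocks are cacheable
    cacheable_blocks = max(full_blocks - num_uncached_trailing_blocks, 0)
    return cacheable_blocks * block_size


# Example: 40-token prompt, default block size 16, hold back the last block.
# 2 full blocks exist; 1 is cacheable, so tokens 16..39 are always recomputed
# and their prompt logprobs can be returned even on a "cache hit".
print(cacheable_prefix_length(40))  # -> 16
```

Covering the longest suffix you care about then just means raising `num_uncached_trailing_blocks` so that N * block_size exceeds that suffix length.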
I came across this issue using vllm as a backend for local-completions with lm-evaluation-harness, which relies on logprobs. Given the multiple open issues, it is probably a common use case indeed.
Well, if it is caching the logprob of the argmax token only, sure, but if you somehow want to cache the whole vocab, it can get substantially large.
Hm yes, of course. My use case (not sure it's the general one) is that I need the logprobs of the prompt tokens, not even the ones of the argmax tokens: I want to answer "what's the probability of that particular sentence under the model, conditioned on what came before". But yeah, maybe people need more than that. Storing the whole vocab is indeed a lot, but maybe if that's what people truly want then it can be an option, and a user might decide to move along the tradeoff (less space for caching activations because more space is used to cache logprobs). This would mean being able to parameterize the logprob caching: logprobs of the prompt tokens, logprobs of the argmax tokens, or logprobs of all tokens.
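To put rough numbers on that tradeoff, here is a back-of-the-envelope sketch. The Llama-3.1-8B figures below (32 layers, 8 KV heads with GQA, head dim 128, vocab size 128256) are assumptions based on its published config, with the KV cache assumed in bf16 and cached logprobs assumed in fp32.

```python
# Rough per-token memory comparison: KV cache vs. full-vocab logprobs.
# Assumed Llama-3.1-8B config: 32 layers, 8 KV heads (GQA), head_dim 128,
# vocab size 128256; KV cache in bf16, logprobs stored in fp32.
num_layers, num_kv_heads, head_dim, vocab_size = 32, 8, 128, 128_256

kv_bytes_per_token = num_layers * 2 * num_kv_heads * head_dim * 2   # K and V, 2 bytes each
full_vocab_logprob_bytes_per_token = vocab_size * 4                 # fp32

print(kv_bytes_per_token)                  # 131072 bytes (~128 KiB)
print(full_vocab_logprob_bytes_per_token)  # 513024 bytes (~501 KiB)
# Caching all-vocab logprobs would cost roughly 4x the KV cache per token,
# whereas caching only the prompt-token (or argmax) logprob is a few bytes per token.
```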
Yes, sorry, this is probably the most common use case. It appears though that passing …
I think though that this is a bit of an edge case, and it may be safe to assume (or check and error out if not) that the …

I have a working, ugly prototype with logprob caching relying on …
Actually, something I did not anticipate is that the scheduler needs to be modified for this to be doable, as the logprob cache is shifted compared to the KV cache. Assume a block size of 4 for simplicity. A first query is computed, and we cache logprobs for … , as the first token's logprob cannot be known. Then, the second query is … , but with prefix caching focused on the aligned KV cache we schedule only … and attempt to reuse the logprob cache from the first two blocks. But the last logprob in cache is for … ; instead, what we should schedule in the second request is … , even though the KV cache for … is already there.

Said differently, the implementation of prefix caching described in https://docs.vllm.ai/en/v0.5.3/automatic_prefix_caching/details.html does not work well with logprob caching; we need an indirection like … .

Does this sound reasonable? If you have pointers as to where to modify the scheduler, that would be helpful.
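To make the misalignment concrete, here is a purely illustrative sketch (not vLLM code; the function and its semantics are hypothetical). In a causal LM the logprob of prompt token i comes from the logits at position i-1, so after a block-aligned cache hit of length L, the logprob of token L can only be produced by rerunning position L-1, even though its KV entries are already cached.

```python
# Illustrative sketch of the KV-cache / logprob-cache misalignment.
# Not vLLM code; names and behavior are hypothetical.

def positions_to_schedule(prompt_len: int, kv_cache_hit_len: int,
                          want_prompt_logprobs: bool) -> range:
    """Return the prompt positions that must actually be run.

    With a plain KV-cache hit of `kv_cache_hit_len` tokens we would run
    positions [kv_cache_hit_len, prompt_len). But the logprob of token i is
    produced by the logits at position i - 1, so the logprob of the first
    non-cached token (index kv_cache_hit_len) requires rerunning position
    kv_cache_hit_len - 1, even though its KV entries are already cached.
    """
    start = kv_cache_hit_len
    if want_prompt_logprobs and kv_cache_hit_len > 0:
        start = kv_cache_hit_len - 1  # rerun one cached position for its logits
    return range(start, prompt_len)


# Block size 4: the first request had 8 prompt tokens, so 2 full blocks are cached.
# A second request shares that 8-token prefix and adds 4 new tokens:
print(list(positions_to_schedule(prompt_len=12, kv_cache_hit_len=8,
                                 want_prompt_logprobs=True)))   # [7, 8, 9, 10, 11]
print(list(positions_to_schedule(prompt_len=12, kv_cache_hit_len=8,
                                 want_prompt_logprobs=False)))  # [8, 9, 10, 11]
```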
Related Slack thread: https://vllm-dev.slack.com/archives/C07QP347J4D/p1737660481543449

TL;DR: although there are multiple open/closed issues for this, logprob caching might be a bit too much of an edge case, even though it would enable true prefix caching for logprob requests. For now @mgoin proposes simply disabling prefix caching for logprob requests.
Actually this will be a great improvement if we can start vllm with prefix caching enabled for all requests except the ones which require prompt logprobs 🙏
How about just throwing a readable error telling people this is not supported, so they should set one of the arguments to False? Or, if caching is disabled in the background, users should at least be warned that this is happening. Maybe they need the two features together, and then they should know vllm doesn't support it and can go use something else.

The fix I use allows me to get the logprobs of the end of the prompt AND caching, by limiting caching to all but the last N blocks, so the end doesn't get cached. I'm guessing this usage is common (e.g., you want the logprobs of different possible answers, but would be happy caching everything that comes before: context / question). So having caching automatically disabled when I set logprobs = 1 would actually be annoying for me. Maybe another interesting feature would be to let users decide what to cache, but that's another matter.
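For concreteness, the "readable error" option could look roughly like the check below. This is only a sketch: the function name, its placement, and the message are assumptions, not existing vLLM code.

```python
# Hypothetical up-front validation: fail loudly instead of hitting an assert
# deep in the scheduler. Not actual vLLM code.
from vllm import SamplingParams


def validate_logprob_request(sampling_params: SamplingParams,
                             enable_prefix_caching: bool) -> None:
    if enable_prefix_caching and sampling_params.prompt_logprobs is not None:
        raise ValueError(
            "prompt_logprobs is not supported together with prefix caching; "
            "either pass enable_prefix_caching=False to LLM or drop "
            "prompt_logprobs from SamplingParams."
        )


# Example: this is exactly the combination the issue reproduces.
# validate_logprob_request(SamplingParams(prompt_logprobs=1, max_tokens=1),
#                          enable_prefix_caching=True)  # raises ValueError
```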
@ccolas so you get logprobs only for a partial rightmost part of your query? How does this work with a query like … , where at some point in the vllm response logic you would hit vllm/vllm/entrypoints/openai/serving_completion.py, lines 494 to 505 (at 3aec49e)?
Using the fix I presented here (#3251 (comment)), vllm gives me a bunch of logprobs: …
I only need the last X logprobs (corresponding to the last X tokens I care about). But sometimes some of these are cached, so vllm doesn't return the corresponding logprobs. This is why I added the second fix in #3251 (comment): to make sure I cache everything possible except the last N blocks, choosing N so that N*16 tokens covers the longest token sequence I care about for my application. This is very hacky. Fix #1 could be added to vllm without an issue, but it would only be of limited use without control over what gets cached or not. One way to allow for my use case without caching logprobs would be to add a parameter …
Is it possible to use the Pooling Models in VLLM to get the last N log probs? |
Your current environment
The output of `python collect_env.py`
🐛 Describe the bug
I'm having issues using automatic prefix caching with the prompt_logprobs option. The first call to the generate method goes through, but the second call errors with an AssertionError.

Reproduction code:
Full stack trace: