[Bug]: AssertionError when using automatic prefix caching and prompt_logprobs #8268

Open
novoselrok opened this issue Sep 7, 2024 · 20 comments
Labels
bug Something isn't working

Comments

@novoselrok

Your current environment

The output of `python collect_env.py`
Collecting environment information...
PyTorch version: 2.4.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.30.2
Libc version: glibc-2.31

Python version: 3.9.19 | packaged by conda-forge | (main, Mar 20 2024, 12:50:21)  [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.10.0-30-cloud-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB

Nvidia driver version: 525.105.17
cuDNN version: Probably one of the following:
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn.so.8.9.0
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.9.0
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.9.0
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.9.0
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.9.0
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.9.0
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Byte Order:                           Little Endian
Address sizes:                        46 bits physical, 48 bits virtual
CPU(s):                               96
On-line CPU(s) list:                  0-95
Thread(s) per core:                   2
Core(s) per socket:                   24
Socket(s):                            2
NUMA node(s):                         2
Vendor ID:                            GenuineIntel
CPU family:                           6
Model:                                85
Model name:                           Intel(R) Xeon(R) CPU @ 2.20GHz
Stepping:                             7
CPU MHz:                              2200.226
BogoMIPS:                             4400.45
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            1.5 MiB
L1i cache:                            1.5 MiB
L2 cache:                             48 MiB
L3 cache:                             77 MiB
NUMA node0 CPU(s):                    0-23,48-71
NUMA node1 CPU(s):                    24-47,72-95
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervi

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-ml-py==12.555.43
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.5.40
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] onnxruntime==1.18.1
[pip3] pyzmq==26.0.3
[pip3] sentence-transformers==3.0.1
[pip3] torch==2.4.0
[pip3] torchao==0.1
[pip3] torchtune==0.2.0.dev20240625+cpu
[pip3] torchvision==0.19.0
[pip3] transformers==4.43.4
[pip3] triton==3.0.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
[conda] nvidia-ml-py              12.555.43                pypi_0    pypi
[conda] nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.5.40                  pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
[conda] pyzmq                     26.0.3                   pypi_0    pypi
[conda] sentence-transformers     3.0.1                    pypi_0    pypi
[conda] torch                     2.4.0                    pypi_0    pypi
[conda] torchao                   0.1                      pypi_0    pypi
[conda] torchtune                 0.2.0.dev20240625+cpu          pypi_0    pypi
[conda] torchvision               0.19.0                   pypi_0    pypi
[conda] transformers              4.43.4                   pypi_0    pypi
[conda] triton                    3.0.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.0@32e7db25365415841ebc7c4215851743fbb1bad1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 CPU Affinity NUMA Affinity
GPU0  X  NV12 NV12 NV12 NV12 NV12 NV12 NV12 0-23,48-71 0
GPU1 NV12  X  NV12 NV12 NV12 NV12 NV12 NV12 0-23,48-71 0
GPU2 NV12 NV12  X  NV12 NV12 NV12 NV12 NV12 0-23,48-71 0
GPU3 NV12 NV12 NV12  X  NV12 NV12 NV12 NV12 0-23,48-71 0
GPU4 NV12 NV12 NV12 NV12  X  NV12 NV12 NV12 24-47,72-95 1
GPU5 NV12 NV12 NV12 NV12 NV12  X  NV12 NV12 24-47,72-95 1
GPU6 NV12 NV12 NV12 NV12 NV12 NV12  X  NV12 24-47,72-95 1
GPU7 NV12 NV12 NV12 NV12 NV12 NV12 NV12  X  24-47,72-95 1

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

I'm having issues using automatic prefix caching together with the prompt_logprobs option. The first call to the generate method goes through, but the second call fails with an AssertionError.

Reproduction code:

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model = LLM(model_path, tensor_parallel_size=8, dtype="bfloat16", gpu_memory_utilization=0.8, enable_prefix_caching=True)

sampling_params = SamplingParams(prompt_logprobs=1, max_tokens=1)
tokenizer = AutoTokenizer.from_pretrained(model_path)

chat_prompts = tokenizer.apply_chat_template([[{"role": "user", "content": "Test 1"}]], tokenize=False)
output = model.generate(chat_prompts, sampling_params, use_tqdm=False)

print("OK")

chat_prompts = tokenizer.apply_chat_template([[{"role": "user", "content": "Test 2"}]], tokenize=False)
output = model.generate(chat_prompts, sampling_params, use_tqdm=False) # ERROR!

Full stack trace:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[2], line 10
      7 print("OK")
      9 chat_prompts = tokenizer.apply_chat_template([[{"role": "user", "content": "Test 2"}]], tokenize=False)
---> 10 output = model.generate(chat_prompts, sampling_params, use_tqdm=False) # ERROR!

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/utils.py:1032, in deprecate_kwargs.<locals>.wrapper.<locals>.inner(*args, **kwargs)
   1025             msg += f" {additional_message}"
   1027         warnings.warn(
   1028             DeprecationWarning(msg),
   1029             stacklevel=3,  # The inner function takes up one level
   1030         )
-> 1032 return fn(*args, **kwargs)

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/entrypoints/llm.py:347, in LLM.generate(self, prompts, sampling_params, prompt_token_ids, use_tqdm, lora_request, prompt_adapter_request, guided_options_request)
    338     sampling_params = SamplingParams()
    340 self._validate_and_add_requests(
    341     inputs=inputs,
    342     params=sampling_params,
    343     lora_request=lora_request,
    344     prompt_adapter_request=prompt_adapter_request,
    345     guided_options=guided_options_request)
--> 347 outputs = self._run_engine(use_tqdm=use_tqdm)
    348 return LLMEngine.validate_outputs(outputs, RequestOutput)

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/entrypoints/llm.py:704, in LLM._run_engine(self, use_tqdm)
    702 total_out_toks = 0
    703 while self.llm_engine.has_unfinished_requests():
--> 704     step_outputs = self.llm_engine.step()
    705     for output in step_outputs:
    706         if output.finished:

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/engine/llm_engine.py:1551, in LLMEngine.step(self)
   1547 if allow_async_output_proc:
   1548     execute_model_req.async_callback = self.async_callbacks[
   1549         virtual_engine]
-> 1551 output = self.model_executor.execute_model(
   1552     execute_model_req=execute_model_req)
   1554 # We need to do this here so that last step's sampled_token_ids can
   1555 # be passed to the next iteration for PP.
   1556 if self.scheduler_config.is_multi_step:

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/executor/distributed_gpu_executor.py:78, in DistributedGPUExecutor.execute_model(self, execute_model_req)
     72     self.parallel_worker_tasks = self._run_workers(
     73         "start_worker_execution_loop",
     74         async_run_tensor_parallel_workers_only=True,
     75         **self.extra_execute_model_run_workers_kwargs)
     77 # Only the driver worker returns the sampling results.
---> 78 driver_outputs = self._driver_execute_model(execute_model_req)
     79 assert driver_outputs is not None
     80 return driver_outputs

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/executor/multiproc_gpu_executor.py:162, in MultiprocessingGPUExecutor._driver_execute_model(self, execute_model_req)
    154 def _driver_execute_model(
    155     self, execute_model_req: Optional[ExecuteModelRequest]
    156 ) -> Optional[List[SamplerOutput]]:
    157     """Run execute_model in the driver worker.
    158 
    159     Passing None will cause the driver to stop the model execution
    160     loop running in each of the remote workers.
    161     """
--> 162     return self.driver_worker.execute_model(execute_model_req)

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/worker/worker_base.py:327, in LocalOrDistributedWorkerBase.execute_model(self, execute_model_req)
    322     if (self.observability_config is not None
    323             and self.observability_config.collect_model_execute_time):
    324         orig_model_execute_time = intermediate_tensors.tensors.get(
    325             "model_execute_time", torch.tensor(0)).item()
--> 327 output = self.model_runner.execute_model(
    328     model_input=model_input,
    329     kv_caches=self.kv_cache[worker_input.virtual_engine]
    330     if self.kv_cache is not None else None,
    331     intermediate_tensors=intermediate_tensors,
    332     num_steps=num_steps,
    333     **kwargs,
    334 )
    336 model_execute_time = time.perf_counter() - start_time
    337 if not get_pp_group().is_last_rank:
    338     # output is IntermediateTensors

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/worker/model_runner.py:1493, in ModelRunner.execute_model(self, model_input, kv_caches, intermediate_tensors, num_steps)
   1490     model_input.async_callback()
   1492 # Sample the next token.
-> 1493 output: SamplerOutput = self.model.sample(
   1494     logits=logits,
   1495     sampling_metadata=model_input.sampling_metadata,
   1496 )
   1497 if (self.observability_config is not None
   1498         and self.observability_config.collect_model_forward_time
   1499         and output is not None):
   1500     model_forward_end.synchronize()

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/model_executor/models/llama.py:447, in LlamaForCausalLM.sample(self, logits, sampling_metadata)
    442 def sample(
    443     self,
    444     logits: torch.Tensor,
    445     sampling_metadata: SamplingMetadata,
    446 ) -> Optional[SamplerOutput]:
--> 447     next_tokens = self.sampler(logits, sampling_metadata)
    448     return next_tokens

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/model_executor/layers/sampler.py:305, in Sampler.forward(self, logits, sampling_metadata)
    301 if not sampling_metadata.skip_sampler_cpu_output:
    302     # Pythonize logprobs now (GPU -> CPU); do not defer.
    303     assert not isinstance(maybe_deferred_sample_results,
    304                           SampleResultArgsType)
--> 305     prompt_logprobs, sample_logprobs = get_logprobs(
    306         logprobs, sampling_metadata, maybe_deferred_sample_results)
    308 return _build_sampler_output(
    309     maybe_deferred_sample_results,
    310     sampling_metadata,
   (...)
    313     on_device_tensors=on_device_tensors,
    314     skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/model_executor/layers/sampler.py:1079, in get_logprobs(logprobs, sampling_metadata, sample_results)
   1074             largest_num_logprobs = max(largest_num_logprobs,
   1075                                        sampling_params.logprobs)
   1077         use_beam_search = use_beam_search or sampling_params.use_beam_search
-> 1079     assert len(next_token_ids) == len(query_indices)
   1081 if len(query_indices) == 0:
   1082     empty_sampled_logprob: SampleLogprobs = []

AssertionError: 

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
@novoselrok novoselrok added the bug Something isn't working label Sep 7, 2024
@hibukipanim

Probably a similar issue to #5344 (the same assert fails).

Some more related issues come up when searching for next_token_ids: https://github.com/vllm-project/vllm/issues?q=is%3Aissue+is%3Aopen+next_token_ids

@drubinstein

Not sure if it's any help, but I simplified the example a little bit. If the number of tokens in the prefix is > 16 and there's a full cache hit, then the assertion will trigger.

from vllm import LLM, SamplingParams, TokensPrompt

model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"

model = LLM(model_path, tensor_parallel_size=1, dtype="bfloat16", gpu_memory_utilization=0.8, enable_prefix_caching=True, enable_chunked_prefill=True,)
sampling_params = SamplingParams(prompt_logprobs=1,  max_tokens=1)

# works
# prompt = TokensPrompt(prompt_token_ids=list(range(16)))
# model.generate(prompt, sampling_params, use_tqdm=False)
# print("OK")
# model.generate(prompt, sampling_params, use_tqdm=False)
# print("OK")

# fails
prompt = TokensPrompt(prompt_token_ids=list(range(17)))
x = model.generate(prompt, sampling_params, use_tqdm=False)
print("OK")
y = model.generate(prompt, sampling_params, use_tqdm=False)
print("OK")

@drubinstein

Another update: it looks like the crash is related to the block size. If the number of tokens in the cached prefix is greater than the block size, then the assertion will be hit. 16 is the default, so that's why I saw it first. As per the example below, if I use a block size of 32, then I can increase the length of the TokensPrompt to 32.

Examples:

from vllm import LLM, SamplingParams, TokensPrompt

model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"

model = LLM(
    model_path,
    tensor_parallel_size=1,
    dtype="bfloat16",
    gpu_memory_utilization=0.8,
    enable_prefix_caching=True,
    enable_chunked_prefill=True,
    block_size=32
)
sampling_params = SamplingParams(prompt_logprobs=1, max_tokens=1)

# works
prompt = TokensPrompt(prompt_token_ids=list(range(31)))
x = model.generate(prompt, sampling_params, use_tqdm=False)
print(x[0].prompt_logprobs)
y = model.generate(prompt, sampling_params, use_tqdm=False)
print(x[0].prompt_logprobs)

# fails
prompt = TokensPrompt(prompt_token_ids=list(range(33)))
x = model.generate(prompt, sampling_params, use_tqdm=False)
print(x[0].prompt_logprobs)
y = model.generate(prompt, sampling_params, use_tqdm=False)
print(x[0].prompt_logprobs)

@drubinstein

Can you try out the new version of vLLM (0.6.3.post1)? I believe #9034 may have fixed this error by correctly populating Sequence.

@yejingfu

#9034 does not fix the issue; I patched in the PR but can still reproduce the problem.

@drubinstein

Unfortunately, I saw the same. I think I got lucky when it worked out.

@ccolas

ccolas commented Oct 31, 2024

Posted a fix in #3251 that solves some problems (maybe enough for you), but not all:
#3251 (comment)
Hope it helps.

@hibukipanim

@ccolas this looks great.
Can you please consider opening a PR with this fix? 🙏

@fxmarty-amd

fxmarty-amd commented Jan 21, 2025

Same issue on ROCm@c040f0e using the offline API.

Repro:

vllm serve unsloth/Llama-3.2-1B-Instruct \
    --tensor-parallel-size 2 \
    --enforce-eager \
    --distributed-executor-backend ray \
    --max-model-len 4000 \
    --enable-prefix-caching

and

curl http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{
        "model": "unsloth/Llama-3.2-1B-Instruct",
        "prompt": "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.",
        "max_tokens": 1,
        "temperature": 0, "logprobs": 1, "echo": true
    }'

and

curl http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{
        "model": "unsloth/Llama-3.2-1B-Instruct",
        "prompt": "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. This is a bit longer prompt",
        "max_tokens": 1,
        "temperature": 0, "logprobs": 1, "echo": true
    }'

As @ccolas mentioned, logprob caching is not supported at the moment in vLLM, and an error should be raised in case "logprobs": 1 is used together with prefix caching.

Even after fixing the wrong computed_len, vLLM will later error out at

logprobs = self._create_completion_logprobs(

due to the missing logprobs.
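
For reference, a guard of that kind could look roughly like this (a minimal sketch with a hypothetical helper name and placement; not actual vLLM validation code):

# Hypothetical guard, assuming access to the engine's cache_config and the
# request's SamplingParams; validate_logprobs_with_prefix_caching does not
# exist in vLLM, and the placement is illustrative only.
def validate_logprobs_with_prefix_caching(sampling_params, cache_config) -> None:
    if cache_config.enable_prefix_caching and sampling_params.prompt_logprobs is not None:
        raise ValueError(
            "prompt_logprobs is currently not supported together with automatic "
            "prefix caching; disable prefix caching (drop --enable-prefix-caching) "
            "or drop prompt_logprobs for this request.")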

@ccolas

ccolas commented Jan 21, 2025

If you only care about the logprobs of the end of your prompt (which was my case), then you can prevent caching of the last N block ids (I think the default is N=1, but you need to extend it to cover the max length of the prompts you care about).

I did that in vllm/core/block/prefix_cache_block:

def get_computed_block_ids(self,
                           prev_computed_block_ids: List[int],
                           block_ids: List[int],
                           skip_last_block_id: bool = True) -> List[int]:
    prev_prefix_size = len(prev_computed_block_ids)
    cur_size = len(block_ids)
    if skip_last_block_id:
        # Don't cache the last 12 * block_size tokens, no matter what,
        # so we can still compute logprobs for the end of the prompt.
        cur_size -= min(cur_size, 12)

But it seems the code has changed since then (I'm on vLLM 0.6.3).

I feel like when logprobs=1 and --enable-prefix-caching are used together, vLLM should cache logprobs; it should be negligible compared to the activations, right? Just one float per prompt token.

@mgoin this is the issue I was talking about. Is there a strong constraint against caching logprobs? It seems I'm not the only one who cares :)

see also #3251 (comment)

@fxmarty-amd

I came across this issue using vllm as a backend for local-completions with lm-evaluation-harness, which relies on logprobs. Given the multiple open issues, it is probably a common use case indeed.

I feel like when logprobs=1 and --enable-prefix-caching are used together, vLLM should cache logprobs; it should be negligible compared to the activations, right? Just one float per prompt token.

Well, if it is caching the logprob of the argmax token only, sure, but if you want to cache logprobs for the whole vocabulary, it can get substantially large.
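
For a rough sense of scale (my numbers, assuming a Llama 3-sized vocabulary of roughly 128k entries and fp32 storage; nothing measured in this thread):

# Back-of-the-envelope estimate with assumed numbers (~128k vocab, fp32 logprobs).
vocab_size = 128_256
bytes_per_float = 4

top1_bytes_per_token = 1 * bytes_per_float                 # 4 bytes per prompt token
full_vocab_bytes_per_token = vocab_size * bytes_per_float  # ~0.5 MB per prompt token

print(f"top-1: {top1_bytes_per_token} B/token, "
      f"full vocab: {full_vocab_bytes_per_token / 2**20:.2f} MiB/token")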

@ccolas

ccolas commented Jan 21, 2025

Hm, yes, of course. My use case (not sure it's the general one) is that I need the logprobs of the prompt tokens, not even those of the argmax tokens. I want to answer "what's the probability of this particular sentence under the model, conditioned on what came before?"

But yeah, maybe people need more than that. Storing the whole vocab is indeed a lot, but if that's what people truly want then it could be an option, and a user might decide to move along the tradeoff (less space for caching activations because more space is used to cache logprobs). This would mean being able to parameterize logprob caching: logprobs of the prompt tokens, logprobs of the argmax tokens, or logprobs of all tokens.

@fxmarty-amd

Yes, sorry, this is probably the most common use case. It appears, though, that passing "logprobs": 1, "logprobs": 2, etc. also populates choices -> prompt_logprobs and choices -> logprobs -> top_logprobs with the top-k logprobs on top of the next-token logprob, so you might need to handle a slightly tricky case where:

  • Your first request uses "logprobs": 1 and you cache your next-token logprobs (plus the most likely token's logprob if it is a different one).
  • A later POST request with the same prefix (which has been partially cached if it is longer than the block size) uses "logprobs": 2 => then you are stuck, as the sampler in the first request most likely did not output the second topmost logprob, so you did not cache it. In this case, you would need to recompute everything.

I think, though, that this is a bit of an edge case, and it may be safe to assume (or check and error out if not) that the logprobs parameter stays the same across requests when prefix caching + logprob caching is used.

I have a working (ugly) prototype of logprob caching relying on block_tables and slot_mapping (maybe there is a better way, though); I will try to clean it up and open a PR.

@fxmarty-amd

fxmarty-amd commented Jan 23, 2025

Actually, something I did not anticipate is that the scheduler needs to be modified for this to be doable, as the logprob cache is shifted compared to the KV cache.

Assume a block size of 4 for simplicity. A first query is computed, and we cache the logprobs:

request: x y p q    a b c d    e m
         | | | |    | | | |    |
cache:   y p q a    b c d e    m
         -------    -------    ------- 
         block 1    block 2    block 3

since the logprob of the first token cannot be known.

Then, the second query is:

x y p q    a b c d   g h b
-------    -------   

but with prefix caching focused on the aligned KV cache we schedule only

g h b

and we attempt to reuse the logprob cache from the first two blocks. But the last logprob in the cache is for e, so it is wrong. And at this point we have no way to know the logprob for g, since it is our start token.

Instead, what we should schedule in the second request is

a b c d   g h b
-------   

even though the KV cache for a b c d is known.

Said differently, the implementation of prefix caching described in https://docs.vllm.ai/en/v0.5.3/automatic_prefix_caching/details.html does not work well with logprob caching; we need an indirection like:

hash(prefix tokens + block tokens + first next token) <--> KV Block

Does this sound reasonable? If you have pointers as to where to modify the scheduler, that would be helpful.
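
For illustration, the indirection could look roughly like this (hypothetical names, not vLLM's actual block manager API; the key includes the first token after the block so that the shifted logprob entries stay valid):

from typing import Dict, List, Optional, Tuple

class CachedBlock:
    """Placeholder for a KV block handle plus the shifted logprobs computed with it."""
    def __init__(self, kv_handle: object, logprobs: List[float]) -> None:
        self.kv_handle = kv_handle
        self.logprobs = logprobs

# Key = (hash of all tokens up to and including this block, first token after the block).
BlockKey = Tuple[int, int]

class LogprobAwareBlockCache:
    def __init__(self) -> None:
        self._blocks: Dict[BlockKey, CachedBlock] = {}

    def _key(self, tokens_through_block: List[int], next_token: int) -> BlockKey:
        return (hash(tuple(tokens_through_block)), next_token)

    def lookup(self, tokens_through_block: List[int], next_token: int) -> Optional[CachedBlock]:
        # Hit only if this exact block was previously followed by the same token,
        # so the cached (shifted) logprobs are reusable as well.
        return self._blocks.get(self._key(tokens_through_block, next_token))

    def insert(self, tokens_through_block: List[int], next_token: int, block: CachedBlock) -> None:
        self._blocks[self._key(tokens_through_block, next_token)] = block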

@fxmarty-amd

fxmarty-amd commented Jan 23, 2025

Related slack thread - https://vllm-dev.slack.com/archives/C07QP347J4D/p1737660481543449

TL;DR: although there are multiple open/closed issues for this, logprob caching might be a bit too much of an edge case, even though it would enable true prefix caching for logprob requests. For now, @mgoin proposes simply disabling prefix caching for logprob requests.

@hibukipanim

Actually, it would be a great improvement if we could start vLLM with prefix caching enabled for all requests except the ones that require prompt logprobs 🙏

@ccolas

ccolas commented Jan 24, 2025

How about just throwing a readable error telling people this is not supported, so they know to set one of the arguments to False? Or, if caching is disabled in the background, users should at least be warned that this is happening. Maybe they need the two features together, and then they should know vLLM doesn't support it and can go use something else.

The fix I use lets me get the logprobs of the end of the prompt AND caching, by limiting caching to everything but the last N blocks, so the end doesn't get cached. I'm guessing this usage is common (e.g., you want the logprobs of different possible answers, but are happy to cache everything that comes before them: context / question). So having caching automatically disabled when I set logprobs=1 would actually be annoying for me.

Maybe another feature that could be interesting is to let users decide what to cache, but that's another matter.

@fxmarty-amd

@ccolas so you only get logprobs for the rightmost part of your query? How does this work with a query like

curl http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{
        "model": "unsloth/Llama-3.2-1B-Instruct",
        "prompt": "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.",
        "max_tokens": 1,
        "temperature": 0, "logprobs": 1, "echo": true
    }'

where at some point in vLLM's response logic you would hit

for i, token_id in enumerate(token_ids):
    step_top_logprobs = top_logprobs[i]
    if step_top_logprobs is None:
        token = tokenizer.decode(token_id)
        if self.return_tokens_as_token_ids:
            token = f"token_id:{token_id}"
        out_tokens.append(token)
        out_token_logprobs.append(None)
        out_top_logprobs.append(None)
    else:
        step_token = step_top_logprobs[token_id]
which, from what I understand, looks up logprobs for the full query?

@ccolas

ccolas commented Jan 24, 2025

Using the fix I presented here (#3251 (comment)), vLLM gives me a bunch of logprobs:

  • for the whole prompt, if nothing was cached
  • for the uncached part only, if something was cached

I only need the last X logprobs (corresponding to the last X tokens I care about), but sometimes some of these are cached, so vLLM doesn't send the corresponding logprobs. This is why I added the second fix (#3251 (comment)): it caches everything except the last N blocks, with N*16 tokens covering the longest token sequence I care about for my application.

This is very hacky.

Fix #1 could be added to vLLM without an issue, but it would be of limited use without control over what gets cached.

One way to support my use case without caching logprobs is to add a parameter, logprob_ids, that tells vLLM which tokens you want the logprobs for, e.g. logprob_ids=[False, False, True, True] to get the logprobs of the last two tokens of your 4-token prompt. This would default to logprob_ids=[True] * len(prompt_tokens).
Then you use that to steer caching: you reuse the cache, if possible, only up to the first True, and recompute from there.
With the default value of logprob_ids, the cache would in practice be disabled when logprobs=1 and logprob_ids is not set.
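
A rough sketch of how such a mask could bound cache reuse (logprob_ids is the hypothetical parameter proposed above; the helper below is illustrative, not existing vLLM code):

from typing import List

def max_reusable_blocks(logprob_ids: List[bool], block_size: int) -> int:
    """Number of leading full blocks that may be served from the prefix cache.

    Reuse has to stop before the first position whose logprob is requested,
    because cached blocks never re-emit logprobs.
    """
    try:
        first_needed = logprob_ids.index(True)
    except ValueError:
        first_needed = len(logprob_ids)  # no logprobs requested anywhere
    return first_needed // block_size

# Example: a 40-token prompt where only the last 6 tokens need logprobs.
mask = [False] * 34 + [True] * 6
print(max_reusable_blocks(mask, block_size=16))  # -> 2 (i.e. the first 32 tokens are reusable)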

@drubinstein

Is it possible to use the pooling models in vLLM to get the last N logprobs?
