[Bug]: chunked prefill scheduler uses up swap on many n>=2 requests #5578

Closed
toslunar opened this issue Jun 16, 2024 · 6 comments · May be fixed by #13539
Labels
bug Something isn't working stale

Comments

@toslunar
Contributor

Your current environment

The output of `python collect_env.py`

🐛 Describe the bug

Sending many requests with n>=2 (or best_of>=2) fills up the CPU KV cache (swap space). This happens more often when chunked prefill is enabled.

`_schedule_chunked_prefill` schedules prefills even if there are swapped seq groups:
https://github.com/vllm-project/vllm/blob/v0.5.0.post1/vllm/core/scheduler.py#L871-L873
whereas `_schedule_default` does not:
https://github.com/vllm-project/vllm/blob/v0.5.0.post1/vllm/core/scheduler.py#L763-L766
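The difference between the two code paths can be pictured with a toy model. This is only a sketch of the gating decision as described above, not vLLM's code: `ToyState` and `admits_new_prefills` are hypothetical names, and the token budget, LoRA handling, and block accounting are all ignored.

```python
from dataclasses import dataclass


@dataclass
class ToyState:
    """Hypothetical, simplified scheduler state (not vLLM's data structures)."""
    num_swapped: int  # sequence groups currently swapped out to CPU
    num_waiting: int  # sequence groups that have not been prefilled yet


def admits_new_prefills(state: ToyState, chunked_prefill: bool) -> bool:
    """Return True if the toy policy would schedule new prefills this step."""
    if state.num_waiting == 0:
        return False  # nothing to prefill
    if chunked_prefill:
        # As described for _schedule_chunked_prefill in v0.5.0.post1:
        # new prefills are admitted even while swapped groups exist.
        return True
    # As described for _schedule_default: new prefills wait until
    # no swapped groups remain.
    return state.num_swapped == 0


state = ToyState(num_swapped=29, num_waiting=9938)
print(admits_new_prefills(state, chunked_prefill=True))
print(admits_new_prefills(state, chunked_prefill=False))
```

In this model the chunked-prefill path keeps admitting new work while swapped groups are still waiting to come back, which is the pattern the report attributes the growing swap usage to.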

To reproduce,

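import vllm
import vllm.core.scheduler  # imported explicitly so the __file__ lookup below does not rely on side effects
print(vllm.__version__)
from vllm import LLM, SamplingParams

# Build 10000 prompts of up to 1000 characters each from the scheduler source.
with open(vllm.core.scheduler.__file__) as f:
    long_text = f.read()
prompts = ["```python\n" + long_text[i:i+1000] for i in range(10000)]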

llm = LLM(
    model="facebook/opt-125m",
    enable_chunked_prefill=True,
    disable_log_stats=False,
    max_num_batched_tokens=4096,
    num_gpu_blocks_override=8192,
)

sampling_params = SamplingParams(max_tokens=1000, n=8)
llm.generate(prompts, sampling_params)

This consumes the CPU KV cache until swap space runs out (Running: 39 reqs, Swapped: 129 reqs, CPU KV cache usage: 91.0% in the last stats line before the error).

output
0.5.0.post1
/home/kataoka/venv1/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
INFO 06-16 21:23:26 config.py:707] Chunked prefill is enabled (EXPERIMENTAL).
INFO 06-16 21:23:26 llm_engine.py:161] Initializing an LLM engine (v0.5.0.post1) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=facebook/opt-125m)
INFO 06-16 21:23:31 weight_utils.py:218] Using model weights format ['*.bin']
INFO 06-16 21:23:32 model_runner.py:160] Loading model weights took 0.2389 GB
INFO 06-16 21:23:32 llm_engine.py:317] Overriding num_gpu_blocks=127899 with num_gpu_blocks_override=8192
INFO 06-16 21:23:32 gpu_executor.py:83] # GPU blocks: 8192, # CPU blocks: 7281
INFO 06-16 21:23:35 model_runner.py:889] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 06-16 21:23:35 model_runner.py:893] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 06-16 21:23:39 model_runner.py:965] Graph capturing finished in 4 secs.
Processed prompts:   0%|                  | 0/10000 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]INFO 06-16 21:23:45 metrics.py:341] Avg prompt throughput: 707.6 tokens/s, Avg generation throughput: 15.2 tokens/s, Running: 12 reqs, Swapped: 0 reqs, Pending: 9988 reqs, GPU KV cache usage: 3.4%, CPU KV cache usage: 0.0%.
INFO 06-16 21:23:50 metrics.py:341] Avg prompt throughput: 1640.1 tokens/s, Avg generation throughput: 15576.1 tokens/s, Running: 35 reqs, Swapped: 0 reqs, Pending: 9965 reqs, GPU KV cache usage: 68.1%, CPU KV cache usage: 0.0%.
WARNING 06-16 21:23:53 scheduler.py:1089] Sequence group 37 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_num_cumulative_preemption=1
INFO 06-16 21:23:55 metrics.py:341] Avg prompt throughput: 1913.5 tokens/s, Avg generation throughput: 15094.3 tokens/s, Running: 33 reqs, Swapped: 29 reqs, Pending: 9938 reqs, GPU KV cache usage: 99.6%, CPU KV cache usage: 25.7%.
WARNING 06-16 21:23:57 scheduler.py:1089] Sequence group 75 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_num_cumulative_preemption=51
INFO 06-16 21:24:00 metrics.py:341] Avg prompt throughput: 3919.1 tokens/s, Avg generation throughput: 9969.3 tokens/s, Running: 30 reqs, Swapped: 87 reqs, Pending: 9883 reqs, GPU KV cache usage: 99.5%, CPU KV cache usage: 75.7%.
WARNING 06-16 21:24:01 scheduler.py:1089] Sequence group 123 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_num_cumulative_preemption=101
Processed prompts:   0%|  | 24/10000 [00:18<1:18:53,  2.11it/s, est. speed input: 463.46 toks/s, output: 8753.30 toks/s]INFO 06-16 21:24:05 metrics.py:341] Avg prompt throughput: 929.1 tokens/s, Avg generation throughput: 11742.2 tokens/s, Running: 36 reqs, Swapped: 69 reqs, Pending: 9870 reqs, GPU KV cache usage: 68.0%, CPU KV cache usage: 36.8%.
Processed prompts:   0%|  | 32/10000 [00:24<1:34:23,  1.76it/s, est. speed input: 465.61 toks/s, output: 8705.86 toks/s]INFO 06-16 21:24:10 metrics.py:341] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 14147.8 tokens/s, Running: 37 reqs, Swapped: 61 reqs, Pending: 9870 reqs, GPU KV cache usage: 88.3%, CPU KV cache usage: 33.1%.
Processed prompts:   0%|  | 34/10000 [00:29<3:36:54,  1.31s/it, est. speed input: 407.31 toks/s, output: 7561.83 toks/s]INFO 06-16 21:24:15 metrics.py:341] Avg prompt throughput: 2706.1 tokens/s, Avg generation throughput: 13395.4 tokens/s, Running: 34 reqs, Swapped: 100 reqs, Pending: 9832 reqs, GPU KV cache usage: 99.2%, CPU KV cache usage: 70.4%.
WARNING 06-16 21:24:15 scheduler.py:1089] Sequence group 165 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_num_cumulative_preemption=151
Processed prompts:   0%|    | 47/10000 [00:35<42:17,  3.92it/s, est. speed input: 473.45 toks/s, output: 8904.44 toks/s]INFO 06-16 21:24:20 metrics.py:341] Avg prompt throughput: 3737.0 tokens/s, Avg generation throughput: 8906.6 tokens/s, Running: 37 reqs, Swapped: 137 reqs, Pending: 9779 reqs, GPU KV cache usage: 76.6%, CPU KV cache usage: 77.9%.
Processed prompts:   1%|   | 58/10000 [00:36<31:11,  5.31it/s, est. speed input: 562.67 toks/s, output: 10551.00 toks/s]INFO 06-16 21:24:25 metrics.py:341] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 14857.3 tokens/s, Running: 36 reqs, Swapped: 125 reqs, Pending: 9779 reqs, GPU KV cache usage: 76.2%, CPU KV cache usage: 71.7%.
Processed prompts:   1%|  | 61/10000 [00:43<2:07:09,  1.30it/s, est. speed input: 497.61 toks/s, output: 9320.60 toks/s]WARNING 06-16 21:24:29 scheduler.py:1089] Sequence group 222 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_num_cumulative_preemption=201
INFO 06-16 21:24:30 metrics.py:341] Avg prompt throughput: 627.9 tokens/s, Avg generation throughput: 14683.6 tokens/s, Running: 39 reqs, Swapped: 129 reqs, Pending: 9770 reqs, GPU KV cache usage: 99.3%, CPU KV cache usage: 91.0%.
Processed prompts:   1%|  | 63/10000 [00:46<2:32:04,  1.09it/s, est. speed input: 480.87 toks/s, output: 8954.46 toks/s][rank0]: Traceback (most recent call last):
[rank0]:   File "/home/kataoka/Untitled3.py", line 36, in <module>
[rank0]:     llm.generate(prompts, sampling_params)
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/utils.py", line 691, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 304, in generate
[rank0]:     outputs = self._run_engine(use_tqdm=use_tqdm)
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 556, in _run_engine
[rank0]:     step_outputs = self.llm_engine.step()
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 765, in step
[rank0]:     seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/core/scheduler.py", line 948, in schedule
[rank0]:     scheduler_outputs = self._schedule()
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/core/scheduler.py", line 921, in _schedule
[rank0]:     return self._schedule_chunked_prefill()
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/core/scheduler.py", line 857, in _schedule_chunked_prefill
[rank0]:     remaining_running, running_scheduled = self._schedule_running(
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/core/scheduler.py", line 434, in _schedule_running
[rank0]:     preempted_mode = self._preempt(victim_seq_group,
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/core/scheduler.py", line 1101, in _preempt
[rank0]:     self._preempt_by_swap(seq_group, blocks_to_swap_out)
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/core/scheduler.py", line 1122, in _preempt_by_swap
[rank0]:     self._swap_out(seq_group, blocks_to_swap_out)
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/core/scheduler.py", line 1142, in _swap_out
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: Aborted due to the lack of CPU swap space. Please increase the swap space to avoid this error.
Processed prompts:   1%|  | 63/10000 [00:47<2:05:35,  1.32it/s, est. speed input: 480.87 toks/s, output: 8954.46 toks/s]
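For context on how small the swap pool is here, the `# CPU blocks: 7281` figure in the log can be reproduced with back-of-the-envelope arithmetic. The sketch below assumes values that are not stated in the log: 12 layers and hidden size 768 for `facebook/opt-125m`, a block size of 16 tokens, and 4 GiB of swap space. Only the fp16 dtype is taken from the log. The helper names are hypothetical.

```python
def kv_block_bytes(num_layers: int, hidden_size: int,
                   block_size: int, dtype_bytes: int) -> int:
    """Bytes needed to hold one KV cache block across all layers."""
    # Factor 2 accounts for storing both keys and values.
    return 2 * num_layers * block_size * hidden_size * dtype_bytes


def num_cpu_blocks(swap_space_bytes: int, block_bytes: int) -> int:
    """Whole blocks that fit into the swap space."""
    return swap_space_bytes // block_bytes


# Assumed: 12 layers, hidden size 768, 16 tokens per block, 4 GiB swap.
# From the log: dtype=torch.float16, i.e. 2 bytes per element.
block_bytes = kv_block_bytes(num_layers=12, hidden_size=768,
                             block_size=16, dtype_bytes=2)
cpu_blocks = num_cpu_blocks(4 * 1024**3, block_bytes)
print(block_bytes, cpu_blocks)  # 589824 7281
```

Under these assumptions the result matches the logged 7281 CPU blocks, which is less than the 8192 GPU blocks set by `num_gpu_blocks_override`.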
@toslunar toslunar added the bug Something isn't working label Jun 16, 2024
@simon-mo
Collaborator

@rkooo567 any possible causes?

@toslunar
Contributor Author

To make my suggestion clear,

-        # Schedule new prefills.
-        remaining_waiting, prefills = self._schedule_prefills(
-            self.waiting, budget, curr_loras, enable_chunking=True)
+        if len(remaining_swapped) == 0:
+            # Schedule new prefills.
+            remaining_waiting, prefills = self._schedule_prefills(
+                self.waiting, budget, curr_loras, enable_chunking=True)

on https://github.com/vllm-project/vllm/blob/v0.5.0.post1/vllm/core/scheduler.py#L871-L873 fixes the issue.

However, the condition `if len(remaining_swapped) == 0` looks too strict and may hurt performance when most of the requests have n == best_of == 1. Something like "CPU KV cache usage < 50%" could be better.
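The two candidate conditions can be compared side by side. This is a minimal sketch, not a patch against vLLM: `allow_new_prefills` and its parameters are hypothetical, and the usage figure is assumed to be available as a fraction between 0 and 1.

```python
def allow_new_prefills(num_swapped: int, cpu_cache_usage: float,
                       strict: bool, max_cpu_usage: float = 0.5) -> bool:
    """Decide whether new prefills may be scheduled in this step.

    strict=True  -> the `len(remaining_swapped) == 0` condition.
    strict=False -> the relaxed "CPU KV cache usage < 50%" idea.
    """
    if strict:
        return num_swapped == 0
    return cpu_cache_usage < max_cpu_usage


# Values taken from two stats lines in the log above.
print(allow_new_prefills(29, 0.257, strict=True))    # False
print(allow_new_prefills(29, 0.257, strict=False))   # True
print(allow_new_prefills(129, 0.910, strict=False))  # False
```

The relaxed variant would still have admitted prefills at 25.7% usage with 29 swapped groups, so whether it prevents the failure depends on how quickly usage climbs past the threshold.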

@rkooo567
Collaborator

I think n>1 creates more sequences, so swapping/preemption becomes more likely (there is higher pressure on the KV cache). Checking `remaining_swapped == 0` actually makes sense to me. We should prioritize swapped requests over prefills anyway (and once all swapped requests are scheduled, the number of remaining swapped requests becomes 0 anyway). @toslunar would you like to create a PR?
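To put a rough number on that pressure, a pessimistic upper bound on the blocks one request can occupy is sketched below. The assumptions are loud ones: a block size of 16 tokens, a hypothetical prompt length of 300 tokens, and no sharing of prompt blocks between the n sequences, so the real figure is lower. `worst_case_blocks` is an illustrative helper, not a vLLM function.

```python
import math


def worst_case_blocks(prompt_tokens: int, max_tokens: int, n: int,
                      block_size: int = 16) -> int:
    """Upper bound on KV blocks for one request, ignoring prompt sharing."""
    per_sequence = math.ceil((prompt_tokens + max_tokens) / block_size)
    return n * per_sequence


gpu_blocks = 8192  # num_gpu_blocks_override from the repro script
for n in (1, 8):
    blocks = worst_case_blocks(prompt_tokens=300, max_tokens=1000, n=n)
    print(n, blocks, gpu_blocks // blocks)
# 1 82 99
# 8 656 12
```

Under this bound, far fewer n=8 requests than n=1 requests fit on the GPU at full length. The log shows 30 to 39 running requests because sequences do not all reach their maximum length at once, but the direction of the effect is the same.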

@toslunar
Contributor Author

Thank you @rkooo567. It makes sense.

I created a PR. The diff is slightly different from the one in my previous comment.


This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

@github-actions github-actions bot added the stale label Oct 26, 2024

This issue has been automatically closed due to inactivity. Please feel free to reopen if you feel it is still relevant. Thank you!

@github-actions github-actions bot closed this as not planned Nov 25, 2024