diff --git a/examples/offline_inference/basic.py b/examples/offline_inference/basic.py index a6e96c0bb4339..9875b80a971cd 100644 --- a/examples/offline_inference/basic.py +++ b/examples/offline_inference/basic.py @@ -13,7 +13,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="google/gemma-2-2b-it") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) @@ -21,4 +21,4 @@ for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py index c874608e40a23..395e0d3da55da 100644 --- a/tests/core/block/e2e/test_correctness_sliding_window.py +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import random -from typing import List +from typing import List, Tuple import pytest @@ -120,7 +120,7 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, check_answers(indices, answer, test_texts) -def prep_prompts(batch_size: int): +def prep_prompts(batch_size: int, assign_range: Tuple[int, int] = (800, 1100)): """ Generate prompts which a bunch of assignments, then asking for the value of one of them. @@ -136,7 +136,7 @@ def prep_prompts(batch_size: int): indices.append(idx) prompt = "```python\n# We set a number of variables, " + \ f"x{idx} will be important later\n" - ln = random.randint(800, 1100) + ln = random.randint(*assign_range) for k in range(30, ln): v = random.randint(10, 99) if k == idx: @@ -148,7 +148,10 @@ def prep_prompts(batch_size: int): return prompts, answer, indices -def check_answers(indices: List[int], answer: List[int], outputs: List[str]): +def check_answers(indices: List[int], + answer: List[int], + outputs: List[str], + accept_rate=0.7): answer2 = [int(text[0:2].strip()) for text in outputs] print(list(zip(indices, zip(answer, answer2)))) numok = 0 @@ -157,7 +160,7 @@ def check_answers(indices: List[int], answer: List[int], outputs: List[str]): numok += 1 frac_ok = numok / len(answer) print(f"Num OK: {numok}/{len(answer)} {frac_ok}") - assert frac_ok > 0.7 + assert frac_ok >= accept_rate def check_window(prompts: List[str]): diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 60cf4384d3fde..5abbe47584964 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -5,10 +5,10 @@ from vllm.multimodal.inputs import MultiModalKwargs from vllm.sampling_params import SamplingParams from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, + KVCacheBlock, PrefixLengthRange, generate_block_hash_extra_keys, hash_block_tokens, - hash_request_tokens) + hash_request_tokens, intersect_ranges) from vllm.v1.request import Request @@ -49,7 +49,9 @@ def test_kv_cache_block(): assert block.ref_cnt == 0 # Test block hash setting and resetting - block_hash = BlockHashType(hash_value=123, token_ids=(1, 2, 3)) + block_hash = BlockHashType(hash_value=123, + kv_cache_group_id=0, + token_ids=(1, 2, 3)) block.block_hash = block_hash assert block.block_hash == block_hash 
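The hunk above and the hunks below update tests/v1/core/test_kv_cache_utils.py for the new kv_cache_group_id that BlockHashType carries and that hash_block_tokens and hash_request_tokens now take. The standalone sketch that follows only illustrates the hashing scheme these assertions encode; it is not the actual vllm.v1.core.kv_cache_utils implementation, and the field order, the extra_keys default, and the sample values are assumptions made for the example.

from typing import Any, NamedTuple, Optional, Tuple


class BlockHashType(NamedTuple):
    # Hash of a full block of tokens, now also keyed by its KV cache group.
    hash_value: int
    kv_cache_group_id: int
    token_ids: Tuple[int, ...]
    extra_keys: Optional[Any] = None


def hash_block_tokens(
        parent_block_hash: Optional[int],
        curr_block_token_ids: Tuple[int, ...],
        kv_cache_group_id: int,
        extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType:
    # Chaining on the parent hash ties a block's hash to its whole prefix.
    # Folding in the group id keeps identical token blocks that belong to
    # different KV cache groups (e.g. full-attention vs. sliding-window
    # layers) from colliding in the prefix cache.
    return BlockHashType(
        hash((parent_block_hash, curr_block_token_ids, kv_cache_group_id,
              extra_keys)), kv_cache_group_id, curr_block_token_ids,
        extra_keys)


# A usage example in the spirit of test_hash_block_tokens below.
block_hash = hash_block_tokens(123, (1, 2, 3), 0, ("key1", "key2"))
assert block_hash.hash_value == hash((123, (1, 2, 3), 0, ("key1", "key2")))
assert block_hash.token_ids == (1, 2, 3)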
@@ -190,11 +192,11 @@ def test_hash_block_tokens(): curr_block_token_ids = (1, 2, 3) extra_keys = ("key1", "key2") - block_hash = hash_block_tokens(parent_block_hash, curr_block_token_ids, + block_hash = hash_block_tokens(parent_block_hash, curr_block_token_ids, 0, extra_keys) assert isinstance(block_hash, BlockHashType) assert block_hash.hash_value == hash( - (parent_block_hash, curr_block_token_ids, extra_keys)) + (parent_block_hash, curr_block_token_ids, 0, extra_keys)) assert block_hash.token_ids == curr_block_token_ids assert block_hash.extra_keys == extra_keys @@ -214,7 +216,7 @@ def test_hash_request_tokens(): ) block_size = 3 - block_hashes = hash_request_tokens(block_size, request) + block_hashes = hash_request_tokens(block_size, request, 0) assert len(block_hashes) == 2 assert isinstance(block_hashes[0], BlockHashType) @@ -255,8 +257,8 @@ def test_hash_tokens_different_mm_input(): mm_hashes=["hash3", "hash2"], ) block_size = 3 - block_hashes1 = hash_request_tokens(block_size, request1) - block_hashes2 = hash_request_tokens(block_size, request2) + block_hashes1 = hash_request_tokens(block_size, request1, 0) + block_hashes2 = hash_request_tokens(block_size, request2, 0) assert block_hashes1[0] != block_hashes2[0] assert block_hashes1[1] != block_hashes2[1] @@ -270,10 +272,35 @@ def test_hash_request_tokens_no_mm_inputs(): ) block_size = 3 - block_hashes = hash_request_tokens(block_size, request) + block_hashes = hash_request_tokens(block_size, request, 0) assert len(block_hashes) == 2 assert block_hashes[0].token_ids == (0, 1, 2) assert block_hashes[0].extra_keys is None assert block_hashes[1].token_ids == (3, 4, 5) assert block_hashes[1].extra_keys is None + + +def test_prefix_length_range_intersection(): + range0 = [ + PrefixLengthRange(1, 5), + PrefixLengthRange(10, 14), + PrefixLengthRange(16, 18) + ] + range1 = [ + PrefixLengthRange(2, 6), + PrefixLengthRange(8, 12), + PrefixLengthRange(15, 17) + ] + range2 = [PrefixLengthRange(3, 11), PrefixLengthRange(13, 19)] + ranges = [range0, range1, range2] + + intersection = intersect_ranges(ranges) + assert intersection == [ + PrefixLengthRange(3, 5), + PrefixLengthRange(10, 11), + PrefixLengthRange(16, 17) + ] + + +# TODO: add tests for hash of kv_cache_group_id diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index a6c0162d3f308..2fc68fe6c85d0 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1,12 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 """Compare the with and without prefix caching.""" import pytest +import torch from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.utils import cdiv from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroup) def make_request(request_id, @@ -32,12 +35,21 @@ def make_request(request_id, ) +def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: + return KVCacheConfig( + num_blocks=num_blocks, + tensors={}, + groups=[ + KVCacheGroup(['layer'], + FullAttentionSpec(block_size, 1, 1, 1, torch.float32)) + ], + ) + + def test_prefill(): manager = KVCacheManager( - block_size=16, - num_gpu_blocks=10, + make_kv_cache_config(16, 10), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -51,17 +63,19 @@ def 
test_prefill(): all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(req0.kv_block_hashes) == 3 - assert not computed_blocks + assert len( + manager.managers.managers[0].req_to_block_hashes[req0.request_id]) == 3 + assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] + blocks = manager.allocate_slots(req0, 55, computed_blocks, + num_computed_tokens) + assert [b.block_id for b in blocks[0]] == [0, 1, 2, 3, 4] # Check full block metadata parent_block_hash = None for block_id in (0, 1, 2): block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16]) - block_hash = hash_block_tokens(parent_block_hash, block_tokens) + block_hash = hash_block_tokens(parent_block_hash, block_tokens, 0) assert manager.block_pool[block_id].block_hash == block_hash assert manager.block_pool[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value @@ -76,13 +90,15 @@ def test_prefill(): unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(req1.kv_block_hashes) == 3 - assert [b.block_id for b in computed_blocks] == [0, 1, 2] + assert len( + manager.managers.managers[0].req_to_block_hashes[req1.request_id]) == 3 + assert [b.block_id for b in computed_blocks[0]] == [0, 1, 2] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) - assert [b.block_id for b in blocks] == [5, 6] - for block in computed_blocks: + blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks, + num_computed_tokens) + assert [b.block_id for b in blocks[0]] == [5, 6] + for block in computed_blocks[0]: assert block.ref_cnt == 2 # At this point, we should have 3 free blocks left. @@ -107,12 +123,14 @@ def test_prefill(): unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(req2.kv_block_hashes) == 3 - assert [b.block_id for b in computed_blocks] == [0, 1, 2] + assert len( + manager.managers.managers[0].req_to_block_hashes[req2.request_id]) == 3 + assert [b.block_id for b in computed_blocks[0]] == [0, 1, 2] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) - assert [b.block_id for b in blocks] == [7, 8] + blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks, + num_computed_tokens) + assert [b.block_id for b in blocks[0]] == [7, 8] # Although we only have 5 free blocks, we have 8 blocks in # the free block queue due to lazy removal. @@ -128,11 +146,12 @@ def test_prefill(): # Cache miss and eviction. req3 = make_request("3", [99] * (16 * 9)) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks) + blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks, + num_computed_tokens) # This block ID order also checks the eviction order. 
- assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0] + assert [b.block_id for b in blocks[0]] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0] assert manager.free_block_queue.num_free_blocks == 0 assert manager.free_block_queue.free_list_head is None assert manager.free_block_queue.free_list_tail is None @@ -140,10 +159,8 @@ def test_prefill(): def test_decode(): manager = KVCacheManager( - block_size=16, - num_gpu_blocks=10, + make_kv_cache_config(16, 10), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -156,18 +173,20 @@ def test_decode(): unique_token_ids = [3] * 7 req0 = make_request("0", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] + blocks = manager.allocate_slots(req0, 55, computed_blocks, + num_computed_tokens) + assert [b.block_id for b in blocks[0]] == [0, 1, 2, 3, 4] # Append slots without allocating a new block. req0.num_computed_tokens = 55 for _ in range(4): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 4) - assert new_blocks is not None and len(new_blocks) == 0 - assert manager.req_to_blocks[req0.request_id][-2].block_hash is None + assert new_blocks is not None and len(new_blocks[0]) == 0 + assert manager.managers.managers[0].req_to_blocks[ + req0.request_id][-2].block_hash is None # Append slots without allocating a new block, but start using the # preallocated block. @@ -177,8 +196,9 @@ def test_decode(): for _ in range(5 + 10): req0.append_output_token_ids(7) new_blocks = manager.allocate_slots(req0, 15) - assert new_blocks is not None and len(new_blocks) == 0 - assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None + assert new_blocks is not None and len(new_blocks[0]) == 0 + assert manager.managers.managers[0].req_to_blocks[ + req0.request_id][-2].block_hash is not None # Append slots with allocating a new block. req0.num_computed_tokens = 74 @@ -188,15 +208,13 @@ def test_decode(): req0.append_output_token_ids(12) new_blocks = manager.allocate_slots(req0, 17) # Plus one preallocated block. - assert new_blocks is not None and len(new_blocks) == 2 + assert new_blocks is not None and len(new_blocks[0]) == 2 def test_evict(): manager = KVCacheManager( - block_size=16, - num_gpu_blocks=10, + make_kv_cache_config(16, 10), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -204,19 +222,21 @@ def test_evict(): last_token_id = 5 * 16 + 7 req0 = make_request("0", list(range(last_token_id))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks) - assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated + blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks, + num_computed_tokens) + assert len(blocks[0]) == 7 # 5 full + 1 partial + 1 preallocated # 3 blocks. 
req1 = make_request("1", list(range(last_token_id, last_token_id + 3 * 16))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks) - assert len(blocks) == 3 # 3 full blocks + blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks, + num_computed_tokens) + assert len(blocks[0]) == 3 # 3 full blocks last_token_id += 3 * 16 assert manager.free_block_queue.num_free_blocks == 0 @@ -231,10 +251,11 @@ def test_evict(): # Touch the first 2 blocks. req2 = make_request("2", list(range(2 * 16 + 3))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert [b.block_id for b in computed_blocks] == [0, 1] + assert [b.block_id for b in computed_blocks[0]] == [0, 1] assert num_computed_tokens == 2 * 16 - blocks = manager.allocate_slots(req2, 3, computed_blocks) - assert [b.block_id for b in blocks] == [6, 5] + blocks = manager.allocate_slots(req2, 3, computed_blocks, + num_computed_tokens) + assert [b.block_id for b in blocks[0]] == [6, 5] assert manager.free_block_queue.num_free_blocks == 6 @@ -245,10 +266,8 @@ def test_hash_block_correct_reuse(): """ block_size = 16 manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=1, + make_kv_cache_config(block_size, 1), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=0, ) @@ -257,10 +276,11 @@ def test_hash_block_correct_reuse(): num_tokens = block_size * 1 req = make_request("0", list(range(num_tokens))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req, num_tokens, computed_blocks) - assert len(blocks) == 1 + blocks = manager.allocate_slots(req, num_tokens, computed_blocks, + num_computed_tokens) + assert len(blocks[0]) == 1 # Deallocate the block. manager.free(req) @@ -269,12 +289,13 @@ def test_hash_block_correct_reuse(): # block is cleared. 
req = make_request("1", list(range(num_tokens - 1))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks) - assert len(blocks) == 1 + blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks, + num_computed_tokens) + assert len(blocks[0]) == 1 - assert manager.block_pool[blocks[0].block_id].block_hash is None + assert manager.block_pool[blocks[0][0].block_id].block_hash is None def test_computed_blocks_not_evicted(): @@ -284,10 +305,8 @@ def test_computed_blocks_not_evicted(): """ block_size = 16 manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=2, + make_kv_cache_config(block_size, 2), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=0, ) @@ -296,20 +315,22 @@ def test_computed_blocks_not_evicted(): num_tokens = block_size * 1 req0 = make_request("0", list(range(num_tokens))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, num_tokens, computed_blocks) - assert len(blocks) == 1 - assert blocks[0].block_id == 0 + blocks = manager.allocate_slots(req0, num_tokens, computed_blocks, + num_computed_tokens) + assert len(blocks[0]) == 1 + assert blocks[0][0].block_id == 0 # Allocate another block. req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, num_tokens, computed_blocks) - assert len(blocks) == 1 - assert blocks[0].block_id == 1 + blocks = manager.allocate_slots(req1, num_tokens, computed_blocks, + num_computed_tokens) + assert len(blocks[0]) == 1 + assert blocks[0][0].block_id == 1 # Free the blocks. manager.free(req0) @@ -319,14 +340,14 @@ def test_computed_blocks_not_evicted(): # cached block rather than the first one. 
req2 = make_request("2", list(range(num_tokens * 2))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(computed_blocks) == 1 - assert computed_blocks[0].block_id == 0 + assert len(computed_blocks[0]) == 1 + assert computed_blocks[0][0].block_id == 0 assert num_computed_tokens == block_size blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, - computed_blocks) - assert len(blocks) == 1 - assert blocks[0].block_id == 1 + computed_blocks, num_computed_tokens) + assert len(blocks[0]) == 1 + assert blocks[0][0].block_id == 1 def test_basic_prefix_caching_disabled(): @@ -335,10 +356,8 @@ def test_basic_prefix_caching_disabled(): """ block_size = 4 manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=4, + make_kv_cache_config(block_size, 4), max_model_len=8192, - sliding_window=None, enable_caching=False, num_preallocate_tokens=0, ) @@ -346,10 +365,11 @@ def test_basic_prefix_caching_disabled(): req1 = make_request("1", list(range(10))) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, 10, computed_blocks) - assert len(blocks) == 3 + blocks = manager.allocate_slots(req1, 10, computed_blocks, + num_computed_tokens) + assert len(blocks[0]) == 3 # Free the blocks. manager.free(req1) @@ -357,17 +377,19 @@ def test_basic_prefix_caching_disabled(): # No caching. req2 = make_request("2", list(range(16))) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req2, 16, computed_blocks) - assert len(blocks) == 4 + blocks = manager.allocate_slots(req2, 16, computed_blocks, + num_computed_tokens) + assert len(blocks[0]) == 4 # New requests should not have any blocks. req3 = make_request("3", list(range(4))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req3, 4, computed_blocks) + blocks = manager.allocate_slots(req3, 4, computed_blocks, + num_computed_tokens) assert not blocks @@ -378,10 +400,8 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): This tests that the preallocated blocks are correctly added. """ manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=10, + make_kv_cache_config(block_size, 10), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=num_preallocate_tokens, ) @@ -389,22 +409,23 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): req = make_request("0", list(range(block_size * 30))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 # Just ask for 1 block. - blocks = manager.allocate_slots(req, block_size, computed_blocks) + blocks = manager.allocate_slots(req, block_size, computed_blocks, + num_computed_tokens) req.num_computed_tokens = block_size - assert len(blocks) == 1 + num_preallocated_blocks + assert len(blocks[0]) == 1 + num_preallocated_blocks # Assume all computed, only when num_preallocate_tokens > 0, we need to # consume the previously preallocated blocks. 
if num_preallocated_blocks > 0: - manager.allocate_slots(req, block_size * (len(blocks) - 1)) - req.num_computed_tokens = block_size * len(blocks) + manager.allocate_slots(req, block_size * (len(blocks[0]) - 1)) + req.num_computed_tokens = block_size * len(blocks[0]) # Append 1 block. blocks = manager.allocate_slots(req, block_size) - assert len(blocks) == 1 + num_preallocated_blocks + assert len(blocks[0]) == 1 + num_preallocated_blocks def test_cache_blocks(): @@ -414,10 +435,8 @@ def test_cache_blocks(): """ block_size = 4 manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=5, + make_kv_cache_config(block_size, 5), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=0, ) @@ -430,12 +449,15 @@ def test_cache_blocks(): # Test that blocks are cached correctly for 2 full blocks from the start. blocks = [KVCacheBlock(block_id=i) for i in range(2)] + block_hashes = [] manager._cache_full_blocks( request=req, + block_hashes=block_hashes, blk_start_idx=0, full_blocks=blocks, prev_block=None, + kv_cache_group_id=0, ) assert len(manager.cached_block_hash_to_block) == 2 @@ -445,9 +467,11 @@ def test_cache_blocks(): blocks = [KVCacheBlock(block_id=2)] manager._cache_full_blocks( request=req, + block_hashes=block_hashes, blk_start_idx=2, full_blocks=blocks, prev_block=None, + kv_cache_group_id=0, ) assert len(manager.cached_block_hash_to_block) == 3 assert blocks[0].block_hash is not None @@ -458,10 +482,8 @@ def test_mm_prefix_caching(): This tests that the multi-modal prefix caching is correct. """ manager = KVCacheManager( - block_size=16, - num_gpu_blocks=10, + make_kv_cache_config(16, 10), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -492,26 +514,28 @@ def test_mm_prefix_caching(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 - assert len(req0.kv_block_hashes) == 3 - assert req0.kv_block_hashes[0].extra_keys == ("aaa", ) - assert req0.kv_block_hashes[1].extra_keys == ("aaa", "bbb") - assert req0.kv_block_hashes[2].extra_keys == ("bbb", ) - - blocks = manager.allocate_slots(req0, 59, computed_blocks) - assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] + block_hashes = manager.req_to_block_hashes[req0.request_id] + assert len(block_hashes[0]) == 3 + assert block_hashes[0][0].extra_keys == ("aaa", ) + assert block_hashes[0][1].extra_keys == ("aaa", "bbb") + assert block_hashes[0][2].extra_keys == ("bbb", ) + + blocks = manager.allocate_slots(req0, 59, computed_blocks, + num_computed_tokens) + assert [b.block_id for b in blocks[0]] == [0, 1, 2, 3, 4] req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 5) - assert new_blocks is not None and len(new_blocks) == 0 + assert new_blocks is not None and len(new_blocks[0]) == 0 # The just completed block should have hashes with extra keys. - assert len(req0.kv_block_hashes) == 4 - assert req0.kv_block_hashes[3].extra_keys == ("ccc", ) + assert len(block_hashes[0]) == 4 + assert block_hashes[0][3].extra_keys == ("ccc", ) # Cache hit. 
unique_token_ids = [-1] * 7 + [200] * 5 @@ -525,7 +549,7 @@ def test_mm_prefix_caching(): mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(computed_blocks) == 3 + assert len(computed_blocks[0]) == 3 assert num_computed_tokens == 3 * 16 @@ -538,10 +562,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): """ block_size = 16 manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=10, + make_kv_cache_config(block_size, 10), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=0, ) @@ -550,9 +572,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): common_token_ids = [i for i in range(3) for _ in range(16)] req0 = make_request("0", common_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 - manager.allocate_slots(req0, 48, computed_blocks) + manager.allocate_slots(req0, 48, computed_blocks, num_computed_tokens) block_part0 = manager.req_to_blocks[req0.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | @@ -560,21 +582,22 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks == block_part0 assert num_computed_tokens == 3 * 16 - manager.allocate_slots(req1, 48, computed_blocks) + manager.allocate_slots(req1, 48, computed_blocks, num_computed_tokens) block_part1 = manager.req_to_blocks[req1.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| ... | manager.free(req1) - assert {block.ref_cnt for block in block_part1[:3]} == {1} - assert {block.ref_cnt for block in block_part1[3:]} == {0} + assert {block.ref_cnt for block in block_part1[0][:3]} == {1} + assert {block.ref_cnt for block in block_part1[0][3:]} == {0} # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| Req2-0 | Req2-1 | ... | req2 = make_request("2", [7] * block_size * 2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert not computed_blocks + assert not computed_blocks[0] assert num_computed_tokens == 0 - manager.allocate_slots(req2, block_size * 2, computed_blocks) + manager.allocate_slots(req2, block_size * 2, computed_blocks, + num_computed_tokens) # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # but it cannot be allocated due to insufficient free blocks (2). @@ -585,11 +608,12 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): assert computed_blocks == block_part1 assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. - assert manager.allocate_slots(req3, 48, computed_blocks) is None + assert manager.allocate_slots(req3, 48, computed_blocks, + num_computed_tokens) is None # Block 0-2 are used by Req 1. - assert {block.ref_cnt for block in block_part1[:3]} == {1} + assert {block.ref_cnt for block in block_part1[0][:3]} == {1} # Block 3-5 are free. 
- assert {block.ref_cnt for block in block_part1[3:]} == {0} + assert {block.ref_cnt for block in block_part1[0][3:]} == {0} def test_reset_prefix_cache(): @@ -606,17 +630,21 @@ def test_reset_prefix_cache(): unique_token_ids = [3] * 7 all_token_ids = full_block_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) - blocks = manager.allocate_slots(req0, 55) - assert [b.block_id for b in blocks] == [0, 1, 2, 3] + computed_blocks, _ = manager.get_computed_blocks(req0) + assert len(req0.kv_block_hashes[0]) == 3 + assert len(computed_blocks[0]) == 0 + blocks = manager.allocate_slots(req0, 55, computed_blocks) + assert [b.block_id for b in blocks[0]] == [0, 1, 2, 3] unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids req1 = make_request("1", all_token_ids) computed_blocks, _ = manager.get_computed_blocks(req1) - assert len(req1.kv_block_hashes) == 3 - assert len(computed_blocks) == 3 + assert len( + manager.managers.managers[0].req_to_block_hashes[req1.request_id]) == 3 + assert len(computed_blocks[0]) == 3 blocks = manager.allocate_slots(req1, 7, computed_blocks) - assert [b.block_id for b in blocks] == [4] + assert [b.block_id for b in blocks[0]] == [4] # Failed to reset prefix cache because some blocks are not freed yet. assert not manager.reset_prefix_cache() diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py new file mode 100644 index 0000000000000..969b18afbe977 --- /dev/null +++ b/tests/v1/core/test_specialized_manager.py @@ -0,0 +1,116 @@ +from collections import deque +from typing import Deque + +import torch + +from vllm.v1.core.specialized_manager import (BlockPoolOperations, + SlidingWindowManager) +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, + PrefixLengthRange) +from vllm.v1.kv_cache_interface import SlidingWindowSpec + + +def test_sliding_window_possible_cached_prefix(): + sliding_window_spec = SlidingWindowSpec( + block_size=2, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=4, + ) + + block_pool_result: Deque[bool] = deque() + null_block = KVCacheBlock(-1, 0) + + def get_cached_block(_block_hash): + if isinstance(_block_hash, + BlockHashType) and _block_hash.hash_value == -1: + # the dummy block hash + return None + is_cached = block_pool_result.popleft() + if is_cached: + return 1 + else: + return None + + def get_null_block(): + return null_block + + manager = SlidingWindowManager( + sliding_window_spec, + BlockPoolOperations(get_cached_block, get_null_block)) + + block_pool_result.clear() + block_pool_result.extend([ + True, True, False, True, False, False, True, True, False, True, True, + True + ]) + ranges, computed_blocks = manager.get_possible_cached_prefix( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + assert ranges == [ + PrefixLengthRange(0, 4), + PrefixLengthRange(16, 16), + PrefixLengthRange(22, 24) + ] + assert computed_blocks == [ + 1, 1, null_block, 1, null_block, null_block, 1, 1, null_block, 1, 1, 1 + ] + + +def test_sliding_window_remove_useless_blocks(): + sliding_window_spec = SlidingWindowSpec( + block_size=2, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=4, + ) + + def get_cached_block(_block_hash): + # should not be called + raise NotImplementedError + + def get_null_block(): + return KVCacheBlock(-1, 0) + + manager = SlidingWindowManager( + sliding_window_spec, + BlockPoolOperations(get_cached_block, get_null_block)) + + def id_to_block_table(ids): + return [ + KVCacheBlock(id_, 0) if 
id_ != -1 else get_null_block() + for id_ in ids + ] + + def assert_block_id(block_table, ids): + for block, id_ in zip(block_table, ids): + if id_ == -1: + assert block == get_null_block() + else: + assert block.block_id == id_ + + block_table = id_to_block_table([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + removed = manager.remove_useless_blocks(block_table, 0) + assert_block_id(removed, []) + assert_block_id(block_table, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + removed = manager.remove_useless_blocks(block_table, 5) + assert_block_id(removed, []) + assert_block_id(block_table, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + removed = manager.remove_useless_blocks(block_table, 6) + assert_block_id(removed, [0]) + assert_block_id(block_table, [-1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + removed = manager.remove_useless_blocks(block_table, 7) + assert_block_id(removed, []) + assert_block_id(block_table, [-1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + removed = manager.remove_useless_blocks(block_table, 8) + assert_block_id(removed, [1]) + assert_block_id(block_table, [-1, -1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + removed = manager.remove_useless_blocks(block_table, 12) + assert_block_id(removed, [3, 2]) + assert_block_id(block_table, [-1, -1, -1, -1, 4, 5, 6, 7, 8, 9, 10]) diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py new file mode 100644 index 0000000000000..2bc6d3d1712fc --- /dev/null +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -0,0 +1,65 @@ +from dataclasses import dataclass +from typing import List, Tuple + +import pytest + +from vllm import LLM, SamplingParams + +from ...core.block.e2e.test_correctness_sliding_window import (check_answers, + prep_prompts) + + +@dataclass +class TestConfig: + sliding_window: int + assign_range: Tuple[int, int] + + +model_config = { + "bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)), + "google/gemma-2-2b-it": TestConfig(4096, (400, 800)), +} + + +@pytest.mark.parametrize("model", + ["bigcode/starcoder2-3b", "google/gemma-2-2b-it"]) +@pytest.mark.parametrize("batch_size", [5]) +@pytest.mark.parametrize("seed", [1]) +def test_sliding_window_retrival(monkeypatch, model, batch_size, seed): + """ + The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then + asks for value of one of them (which is outside the sliding window). + If we tell it upfront which we are going to be looking for, then + it answers correctly (mostly). 
+ """ + # TODO: implement check_window + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + test_config = model_config[model] + + llm = LLM(model=model) + sampling_params = SamplingParams(temperature=0.0, max_tokens=100) + + prompts, answer, indices = prep_prompts( + batch_size, assign_range=test_config.assign_range) + + # both starcoder2-3b and gemma-2-2b-it have 4096 sliding window + check_window(prompts, llm, test_config.sliding_window) + + responses = llm.generate(prompts, sampling_params) + check_answers(indices, + answer, + [response.outputs[0].text for response in responses], + accept_rate=1.0) + + +def check_window(prompts: List[str], llm: LLM, sliding_window: int): + tokenizer = llm.get_tokenizer() + max_model_len = llm.llm_engine.model_config.max_model_len + assert any( + len(tokenizer.encode(prompt)) > sliding_window + for prompt in prompts), "Prompt is too short for test" + assert all( + len(tokenizer.encode(prompt)) <= max_model_len + for prompt in prompts), "Prompt is too long for test" diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 5b40fbff8212e..561b4399b943e 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -168,7 +168,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): max_num_blocks_per_req=10, device=torch.device(device), pin_memory=is_pin_memory_available(), - vocab_size=1024) + vocab_size=1024, + num_kv_cache_groups=1) reqs: List[CachedRequestState] = [] req_id_reqs = {} req_id_output_token_ids = {} diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 19ee89630ffa4..466ea00a0c154 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -288,6 +288,8 @@ def unified_attention( attn_metadata = forward_context.attn_metadata self = forward_context.attn_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] return self.impl.forward(self, query, key, value, kv_cache, attn_metadata) @@ -320,6 +322,8 @@ def unified_attention_with_output( attn_metadata = forward_context.attn_metadata self = forward_context.attn_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] self.impl.forward(self, query, key, diff --git a/vllm/config.py b/vllm/config.py index bc4bf627b8e74..d5bf5050d1b61 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1125,7 +1125,7 @@ def _verify_prefix_caching(self) -> None: if not self.enable_prefix_caching: return - if self.sliding_window is not None: + if self.sliding_window is not None and not envs.VLLM_USE_V1: raise NotImplementedError( "Prefix caching is not supported with sliding window. 
" "Run with --disable-sliding-window to use prefix caching.") diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 10de8bc593ab8..48bf9694d806a 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -4,7 +4,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional, Union import torch @@ -26,14 +26,33 @@ @dataclass class ForwardContext: - # copy from vllm_config.compilation_config.static_forward_context + """ + Map from layer_name to all attention modules + copy from vllm_config.compilation_config.static_forward_context + """ attn_layers: Dict[str, Any] - # TODO: extend to support per-layer dynamic forward context - attn_metadata: "AttentionMetadata" # set dynamically for each forward pass - # TODO: remove after making all virtual_engines share the same kv cache + """ + Type AttentionMetadata for v0, + Type Dict[str, AttentionMetadata] for v1, mapping from layer_name to + AttentionMetadata of that layer + set dynamically for each forward pass + """ + attn_metadata: Union["AttentionMetadata", Dict[str, "AttentionMetadata"]] + """ + The virtual_engine for v0 pipeline parallelism + set dynamically for each forward pass + """ virtual_engine: int # set dynamically for each forward pass +@dataclass +class ForwardMetadata: + """ + Forward metadata for each forward pass + """ + num_input_tokens: int + + _forward_context: Optional[ForwardContext] = None @@ -48,7 +67,8 @@ def get_forward_context() -> ForwardContext: @contextmanager def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig, - virtual_engine: int = 0): + virtual_engine: int = 0, + forward_metadata: Optional[ForwardMetadata] = None): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -68,13 +88,14 @@ def set_forward_context(attn_metadata: Any, finally: global last_logging_time, batchsize_logging_interval if need_to_track_batchsize: - if hasattr(attn_metadata, "num_prefill_tokens"): + if not envs.VLLM_USE_V1: # for v0 attention backends batchsize = attn_metadata.num_prefill_tokens + \ attn_metadata.num_decode_tokens else: # for v1 attention backends - batchsize = attn_metadata.num_input_tokens + assert forward_metadata is not None + batchsize = forward_metadata.num_input_tokens # we use synchronous scheduling right now, # adding a sync point here should not affect # scheduling of the next batch diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py new file mode 100644 index 0000000000000..4264da7fd347d --- /dev/null +++ b/vllm/v1/core/block_pool.py @@ -0,0 +1,244 @@ +from collections import defaultdict +from typing import Dict, List, Optional + +from vllm.logger import init_logger +from vllm.v1.core.kv_cache_utils import BlockHashType, FreeKVCacheBlockQueue, KVCacheBlock, generate_block_hash_extra_keys, hash_block_tokens +from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class BlockPool: + + def __init__(self, num_gpu_blocks: int, enable_caching: bool): + self.num_gpu_blocks = num_gpu_blocks + self.enable_caching = enable_caching + # A Block pool of all kv-cache blocks. 
+ self._block_pool: List[KVCacheBlock] = [ + KVCacheBlock(idx) for idx in range(self.num_gpu_blocks) + ] + # Free block queue that constructs and manipulates a doubly linked + # list of free blocks (including eviction candidates when caching is + # enabled). + self._free_block_queue = FreeKVCacheBlockQueue(self._block_pool) + + # {block_hash: {block ID: block}}. A cached block is + # a full block with a block hash that can be used for prefix caching. + # The cached block may be used by running requests or in the + # free_block_queue that could potentially be evicted. + # NOTE: We currently don't de-duplicate the blocks in the cache, + # meaning that if a block becomes full and is cached, we don't check + # if there is already an identical block in the cache. This is because + # we want to make sure the allocated block IDs won't change so that + # block tables are append-only. + self._cached_block_hash_to_block: Dict[BlockHashType, Dict[ + int, KVCacheBlock]] = defaultdict(dict) + + self._null_block: KVCacheBlock = KVCacheBlock(-1) + + def get_cached_block(self, + block_hash: BlockHashType) -> Optional[KVCacheBlock]: + """Get a cached block by the block hash, or None if cache miss. + If there are duplicated blocks, we return the first block in the cache. + + Args: + block_hash: The hash value of the block. + + Returns: + The cached block if it exists, or None. + """ + if block_hash in self._cached_block_hash_to_block: + first_block_id = list( + self._cached_block_hash_to_block[block_hash].keys())[0] + return self._cached_block_hash_to_block[block_hash][first_block_id] + return None + + def cache_full_blocks( + self, + request: Request, + block_hashes: List[BlockHashType], + block_size: int, + blk_start_idx: int, + full_blocks: List[KVCacheBlock], + prev_block: Optional[KVCacheBlock], + kv_cache_group_id: int, + ) -> None: + """Cache a list of full blocks for prefix caching. + + This function takes a list of blocks that will have their block hash + metadata to be updated and cached. Given a request, it computes the + block hashes for the blocks starting from `blk_start_idx` to the end + of the request's full blocks, updating the metadata for each block + and caching them in the `_cached_block_hash_to_block`. + + Args: + request: The request to cache the blocks. + blk_start_idx: The index of the first block in the request's blocks + to cache. + full_blocks: The list of blocks to update hash metadata. + prev_block: The previous block in the chain. + kv_cache_group_id: The KV cache group that the blocks belong to + """ + num_cached_block_hashes = len(block_hashes) + + # Update the new blocks with the block hashes through the chain. + prev_block_hash_value = None + if prev_block is not None: + # Previous block must have a block hash because it must be + # a full, cached block. + assert prev_block.block_hash is not None + prev_block_hash_value = prev_block.block_hash.hash_value + + # Find the first uncached block. This case should only happen when + # speculative decoding is used. + offset = 0 + for blk in full_blocks: + if blk.block_hash is None: + break + else: + prev_block_hash_value = blk.block_hash.hash_value + offset += 1 + else: + # All blocks are cached. 
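+            # (for/else branch: reached only when the loop above completed
+            # without finding an uncached block)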
+ return + + for i, blk in enumerate(full_blocks[offset:]): + blk_idx = blk_start_idx + offset + i + assert blk.block_hash is None + + if blk_idx < num_cached_block_hashes: + # The block hash may already be computed in + # "get_computed_blocks" if the tokens are not generated by + # this request (either the prompt tokens or the previously + # generated tokens with preemption). In this case we simply + # reuse the block hash. + block_hash = block_hashes[blk_idx] + else: + # Otherwise compute the block hash and cache it in the request + # in case it will be preempted in the future. + start_token_idx = blk_idx * block_size + end_token_idx = (blk_idx + 1) * block_size + block_tokens = request.all_token_ids[ + start_token_idx:end_token_idx] + assert len(block_tokens) == block_size, ( + f"Expected {block_size} tokens, got " + f"{len(block_tokens)} at {blk_idx}th block for request " + f"{request.request_id}({request})") + + # Generate extra keys for multi-modal inputs. Note that since + # we reach to this branch only when the block is completed with + # generated tokens, we only need to consider the last mm input. + extra_keys, _ = generate_block_hash_extra_keys( + request, start_token_idx, end_token_idx, -1) + + # Compute the hash of the current block. + block_hash = hash_block_tokens(prev_block_hash_value, + block_tokens, kv_cache_group_id, + extra_keys) + block_hashes.append(block_hash) + + # Update and added the full block to the cache. + blk.block_hash = block_hash + self._cached_block_hash_to_block[block_hash][blk.block_id] = blk + prev_block_hash_value = block_hash.hash_value + + def get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: + """Get new blocks from the free block pool. + + Note that we do not check block cache in this function. + + Args: + num_blocks: The number of blocks to allocate. + + Returns: + A list of new block. + """ + if num_blocks > self._free_block_queue.num_free_blocks: + raise ValueError( + f"Cannot get {num_blocks} free blocks from the pool") + + ret: List[KVCacheBlock] = [] + idx = 0 + while idx < num_blocks: + # First allocate blocks. + curr_block = self._free_block_queue.popleft() + assert curr_block.ref_cnt == 0 + + # If the block is cached, evict it. + if self.enable_caching: + self._maybe_evict_cached_block(curr_block) + + curr_block.incr_ref() + ret.append(curr_block) + idx += 1 + + return ret + + def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: + """ + If a block is cached in `_cached_block_hash_to_block`, we reset its hash + metadata and evict it from the cache. + + Args: + block: The block to evict. + + Returns: + True if the block is evicted, False otherwise. + """ + block_hash = block.block_hash + if block_hash and block_hash in self._cached_block_hash_to_block: + block.reset_hash() + del self._cached_block_hash_to_block[block_hash][block.block_id] + + if len(self._cached_block_hash_to_block[block_hash]) == 0: + del self._cached_block_hash_to_block[block_hash] + + return True + return False + + def touch(self, blocks: List[KVCacheBlock]) -> None: + """Touch a block increases its reference count by 1, and may remove + the block from the free queue. This is used when a block is hit by + another request with the same prefix. + + Args: + blocks: A list of blocks to touch. + """ + for block in blocks: + # ref_cnt=0 means this block is in the free list (i.e. eviction + # candidate), so remove it. 
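+            # The null block is a shared placeholder that never enters the
+            # free queue, so it must not be removed from it here.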
+            if block.ref_cnt == 0 and block != self._null_block:
+                self._free_block_queue.remove(block)
+            block.incr_ref()
+
+    def free_blocks(self, ordered_blocks: List[KVCacheBlock]) -> None:
+        for block in ordered_blocks:
+            if block == self._null_block:
+                continue
+            block.decr_ref()
+            if block.ref_cnt == 0:
+                self._free_block_queue.append(block)
+
+    def get_num_free_blocks(self) -> int:
+        return self._free_block_queue.num_free_blocks
+
+    def get_null_block(self) -> KVCacheBlock:
+        return self._null_block
+
+    def reset_prefix_cache(self) -> bool:
+        num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks())
+        if num_used_blocks > 0:
+            logger.warning(
+                "Failed to reset prefix cache because some "
+                "blocks (%d) are not freed yet", num_used_blocks)
+            return False
+
+        # Remove all hashes so that no new blocks will hit.
+        self._cached_block_hash_to_block = defaultdict(dict)
+
+        # Remove all hashes from all blocks.
+        for block in self._block_pool:
+            block.reset_hash()
+
+        logger.info("Successfully reset prefix cache")
+        return True
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index de349ec120999..19c1405a09e1b 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -1,36 +1,35 @@
 # SPDX-License-Identifier: Apache-2.0
-from collections import defaultdict
-from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple
+from typing import List, Optional, Tuple
+from vllm.v1.core.block_pool import BlockPool
 from vllm.logger import init_logger
 from vllm.utils import cdiv
-from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
-                                         KVCacheBlock,
-                                         generate_block_hash_extra_keys,
-                                         hash_block_tokens,
-                                         hash_request_tokens)
+from vllm.v1.core.kv_cache_utils import MayGroupedKVCacheBlocks
+from vllm.v1.core.specialized_manager import get_specialized_manager
+from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.request import Request, RequestStatus
 logger = init_logger(__name__)
 class KVCacheManager:
+    """
+    The KVCacheManager for models with one KV cache type (e.g., Llama) and
+    thus one kv cache group (Refer to class `KVCacheConfig` for the meaning of
+    kv cache group).
+    """
     def __init__(
         self,
-        block_size: int,
-        num_gpu_blocks: int,
+        kv_cache_config: KVCacheConfig,
         max_model_len: int,
-        sliding_window: Optional[int] = None,
         enable_caching: bool = True,
         num_preallocate_tokens: int = 64,
     ) -> None:
-        self.block_size = block_size
-        self.num_gpu_blocks = num_gpu_blocks
+        self.kv_cache_config = kv_cache_config
+        self.num_gpu_blocks = kv_cache_config.num_blocks
         self.max_model_len = max_model_len
-        self.max_num_blocks_per_req = cdiv(max_model_len, block_size)
-        self.sliding_window = sliding_window
         self.enable_caching = enable_caching
         # NOTE(woosuk): To avoid frequent block allocation, we preallocate some
         # blocks for each request. For example, when a request reaches the end
@@ -42,43 +41,37 @@ def __init__(
         # the request gets N empty blocks, it starts to use the blocks without
         # further allocation. When it uses up all the N empty blocks, it gets
         # N new empty blocks.
+        # NOTE(Chen): For simplicity, we keep the number of preallocated blocks
+        # the same for all kv cache groups, which will result in different
+        # preallocated tokens for different groups if their block sizes are
+        # different.
         self.num_preallocate_tokens = num_preallocate_tokens
-        self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)
-
-        # A Block pool of all kv-cache blocks.
- self.block_pool: List[KVCacheBlock] = [ - KVCacheBlock(idx) for idx in range(num_gpu_blocks) - ] - # Free block queue that constructs and manipulates a doubly linked - # list of free blocks (including eviction candidates when caching is - # enabled). - self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool) - - # {block_hash: {block ID: block}}. A cached block is - # a full block with a block hash that can be used for prefix caching. - # The cached block may be used by running requests or in the - # free_block_queue that could potentially be evicted. - # NOTE: We currently don't de-duplicate the blocks in the cache, - # meaning that if a block becomes full and is cached, we don't check - # if there is already an identical block in the cache. This is because - # we want to make sure the allocated block IDs won't change so that - # block tables are append-only. - self.cached_block_hash_to_block: Dict[BlockHashType, Dict[ - int, KVCacheBlock]] = defaultdict(dict) - - # Mapping from request ID to blocks to track the blocks allocated - # for each request, so that we can free the blocks when the request - # is finished. - self.req_to_blocks: DefaultDict[str, - List[KVCacheBlock]] = defaultdict(list) + # NOTE(Chen): For simplicity, we keep the number of preallocated blocks + # the same for all kv cache groups, which will result in different + # preallocated tokens for different groups if their block sizes are + # different. + self.num_preallocate_blocks = cdiv( + num_preallocate_tokens, + max(g.kv_cache_spec.block_size for g in kv_cache_config.groups)) + + self.block_pool = BlockPool(self.num_gpu_blocks, self.enable_caching) + + # Specialized managers for each kv cache group, which handle the + # different kv cache management logic of different attention layers. + self.managers = get_specialized_manager( + kv_cache_config, + max_model_len, + enable_caching, + self.block_pool, + ) @property def usage(self) -> float: - return 1.0 - (self.free_block_queue.num_free_blocks / + return 1.0 - (self.block_pool.get_num_free_blocks() / self.num_gpu_blocks) def get_computed_blocks( - self, request: Request) -> Tuple[List[KVCacheBlock], int]: + self, request: Request) -> Tuple[MayGroupedKVCacheBlocks, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -87,51 +80,41 @@ def get_computed_blocks( Returns: A tuple containing: - - A list of blocks that are computed for the request. + - The blocks that are computed for the request - The number of computed tokens. """ if not self.enable_caching: # Prefix caching is disabled. - return [], 0 - - computed_blocks = [] + return self.managers.new_block_list(), 0 # The block hashes for the request may already be computed - # if the request was preempted and resumed. - if not request.kv_block_hashes: - request.set_kv_block_hashes( - hash_request_tokens(self.block_size, request)) - block_hashes = request.kv_block_hashes - - for block_hash in block_hashes: - # block_hashes is a chain of block hashes. If a block hash is not - # in the cached_block_hash_to_id, the following block hashes are - # not computed yet for sure. - if cached_block := self._get_cached_block(block_hash): - computed_blocks.append(cached_block) - else: - break - - # NOTE(woosuk): Since incomplete blocks are not eligible for - # sharing, `num_computed_tokens` is always a multiple of - # `block_size`. - num_computed_tokens = len(computed_blocks) * self.block_size + # if the scheduler has tried to schedule the request before. 
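+        # (the per-group managers keep them in req_to_block_hashes, so they
+        # are expected to be reused when the request is scheduled again)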
+ block_hashes = self.managers.hash_request_tokens(request) + + prefix_length, computed_blocks = self.managers.get_possible_cached_prefix( + block_hashes) + + num_computed_tokens = prefix_length[-1].end + + computed_blocks = self.managers.truncate_computed_blocks( + computed_blocks, num_computed_tokens) return computed_blocks, num_computed_tokens def allocate_slots( self, request: Request, num_tokens: int, - new_computed_blocks: Optional[List[KVCacheBlock]] = None - ) -> Optional[List[KVCacheBlock]]: + new_computed_blocks: Optional[MayGroupedKVCacheBlocks] = None, + num_new_computed_tokens: int = 0, + ) -> Optional[MayGroupedKVCacheBlocks]: """Add slots for a request with new tokens to append. Args: request: The request to allocate slots. num_tokens: The number of tokens to allocate. Note that this does not include the tokens that have already been computed. - new_computed_blocks: A list of new computed blocks just hitting the - prefix caching. + new_computed_blocks_all_groups: A list of new computed blocks + just hitting the prefix caching. Blocks layout: ----------------------------------------------------------------------- @@ -151,83 +134,51 @@ def allocate_slots( if num_tokens == 0: raise ValueError("num_tokens must be greater than 0") - new_computed_blocks = new_computed_blocks or [] + new_computed_blocks = new_computed_blocks or self.managers.new_block_list( + ) # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits num_computed_tokens = (request.num_computed_tokens + - len(new_computed_blocks) * self.block_size) - num_required_blocks = cdiv(num_computed_tokens + num_tokens, - self.block_size) - req_blocks = self.req_to_blocks[request.request_id] - num_new_blocks = (num_required_blocks - len(req_blocks) - - len(new_computed_blocks)) + num_new_computed_tokens) + + # We can free blocks that are no longer needed even if we cannot + # schedule this request due to the limit of free blocks. + # Should call this function before allocating new blocks to reduce + # the number of evicted blocks. + blocks_to_free = self.managers.remove_useless_blocks( + request, num_computed_tokens) + self.managers.free_blocks(blocks_to_free) + + num_new_blocks = self.managers.get_req_num_new_blocks( + request, new_computed_blocks, num_computed_tokens, num_tokens) + num_new_blocks = max(num_new_blocks, 0) # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it cannot be counted as a free block # when allocating this request. - num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks - if blk.ref_cnt == 0) - if (num_new_blocks > self.free_block_queue.num_free_blocks - + num_evictable_computed_blocks = sum( + 1 for blk in self.managers.iter_all(new_computed_blocks) + if blk.ref_cnt == 0) + + if (num_new_blocks > self.block_pool.get_num_free_blocks() - num_evictable_computed_blocks): - # Cannot allocate new blocks + # Cannot allocate new blocks. return None - # Touch the computed blocks to make sure they won't be evicted. - if self.enable_caching: - self._touch(new_computed_blocks) - else: - assert not new_computed_blocks, ( - "Computed blocks should be empty when " - "prefix caching is disabled") - - # Append the new computed blocks to the request blocks until now to - # avoid the case where the new blocks cannot be allocated. - req_blocks.extend(new_computed_blocks) + # Truncate the number of pre-allocated blocks to ensure that we can + # have at least `num_new_blocks` free blocks for each group. 
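+        # (the remaining free-block budget is shared by all groups, hence
+        # the division by the number of kv cache groups below)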
+ num_preallocate_blocks = min( + self.num_preallocate_blocks, + (self.block_pool.get_num_free_blocks() - + num_evictable_computed_blocks - num_new_blocks) // + len(self.kv_cache_config.groups)) - # Start to handle new blocks - - if num_new_blocks <= 0: - # No new block is needed. - new_blocks = [] - else: - # Get new blocks from the free block pool considering - # preallocated blocks. - num_new_blocks = min( - num_new_blocks + self.num_preallocate_blocks, - self.free_block_queue.num_free_blocks, - # Should not exceed the maximum number of blocks per request. - # This is especially because the block table has the shape - # [..., max_num_blocks_per_req]. - # TODO(woosuk): Check and reject requests if - # num_prompt_tokens + max_tokens > max_model_len. - self.max_num_blocks_per_req - len(req_blocks), - ) - assert num_new_blocks > 0 - - # Concatenate the computed block IDs and the new block IDs. - new_blocks = self._get_new_blocks(num_new_blocks) - req_blocks.extend(new_blocks) - - if not self.enable_caching: - return new_blocks - - # NOTE(rickyx): We are assuming the `num_tokens` are actual - # tokens rather than lookahead slots (e.g. for speculative decoding). - # TODO(rickyx): When supporting speculative decoding, we will need to - # differentiate between them so that we can know how many blocks are - # full after appending the actual tokens. - num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size - num_computed_full_blocks = num_computed_tokens // self.block_size - new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks] - if new_full_blocks: - self._cache_full_blocks( - request=request, - blk_start_idx=num_computed_full_blocks, - # The new full blocks are the full blocks that are not computed. - full_blocks=new_full_blocks, - prev_block=(req_blocks[num_computed_full_blocks - 1] - if num_computed_full_blocks > 0 else None)) + new_blocks = self.managers.allocate_slots(request, new_computed_blocks, + num_new_blocks, + num_preallocate_blocks, + num_computed_tokens, + num_tokens) return new_blocks @@ -239,18 +190,15 @@ def free(self, request: Request) -> None: Args: request: The request to free the blocks. """ - # Default to [] in case a request is freed (aborted) before alloc. - blocks = self.req_to_blocks.pop(request.request_id, []) - ordered_blocks: Iterable[KVCacheBlock] = blocks - if self.enable_caching: - # Free blocks in reverse order so that the tail blocks are - # freed first. - ordered_blocks = reversed(blocks) - - for block in ordered_blocks: - block.decr_ref() - if block.ref_cnt == 0: - self.free_block_queue.append(block) + # Default to None in case a request is freed (aborted) before alloc. + blocks = self.managers.pop_blocks_of_request(request.request_id) + if blocks is None: + # This request is freed before alloc. just return + return + else: + # Reverse the blocks so that the tail blocks can have higher + # eviction priority. + self.managers.free_blocks(blocks, need_reverse=True) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -261,29 +209,14 @@ def reset_prefix_cache(self) -> bool: bool: True if the prefix cache is successfully reset, False otherwise. """ - num_used_blocks = (self.num_gpu_blocks - - self.free_block_queue.num_free_blocks) - if num_used_blocks > 0: - logger.warning( - "Failed to reset prefix cache because some " - "blocks (%d) are not freed yet", num_used_blocks) - return False - - # Remove all hashes so that no new blocks will hit. 
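The clamp on preallocated blocks above can be shown with a rough arithmetic sketch; all numbers below are invented.

```python
num_preallocate_blocks = 4           # per-group preallocation target
num_free_blocks = 20                 # blocks currently free in the pool
num_evictable_computed_blocks = 2    # cache hits still sitting in the free queue
num_new_blocks = 6                   # required new blocks, summed over groups
num_groups = 2

# Preallocation must leave enough free blocks for the required allocation,
# and the leftover is shared equally by all KV cache groups.
clamped = min(
    num_preallocate_blocks,
    (num_free_blocks - num_evictable_computed_blocks - num_new_blocks)
    // num_groups,
)
assert clamped == 4                  # (20 - 2 - 6) // 2 = 6, so 4 wins
```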
- self.cached_block_hash_to_block = defaultdict(dict) - # Remove all hashes from all blocks. - for block in self.block_pool: - block.reset_hash() - - logger.info("Successfully reset prefix cache") - return True + return self.block_pool.reset_prefix_cache() def get_num_common_prefix_blocks( self, request: Request, num_running_requests: int, - ) -> int: + ) -> List[int]: """Calculate the number of common prefix blocks shared by all requests in the RUNNING state. @@ -317,184 +250,15 @@ def get_num_common_prefix_blocks( requests in the current step. Returns: - int: The number of common prefix blocks. - """ - assert request.status == RequestStatus.RUNNING - blocks = self.req_to_blocks[request.request_id] - num_common_blocks = 0 - for block in blocks: - if block.ref_cnt == num_running_requests: - num_common_blocks += 1 - else: - break - return num_common_blocks - - def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: - """Get new blocks from the free block pool. - - Note that we do not check block cache in this function. - - Args: - num_blocks: The number of blocks to allocate. - - Returns: - A list of new block. - """ - if num_blocks > self.free_block_queue.num_free_blocks: - raise ValueError( - f"Cannot get {num_blocks} free blocks from the pool") - - ret: List[KVCacheBlock] = [] - idx = 0 - while idx < num_blocks: - # First allocate blocks. - curr_block = self.free_block_queue.popleft() - assert curr_block.ref_cnt == 0 - - # If the block is cached, evict it. - if self.enable_caching: - self._maybe_evict_cached_block(curr_block) - - curr_block.incr_ref() - ret.append(curr_block) - idx += 1 - - return ret - - def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: - """ - If a block is cached in `cached_block_hash_to_block`, we reset its hash - metadata and evict it from the cache. - - Args: - block: The block to evict. - - Returns: - True if the block is evicted, False otherwise. + List[int]: The number of common prefix blocks per KV cache group. """ - block_hash = block.block_hash - if block_hash and block_hash in self.cached_block_hash_to_block: - block.reset_hash() - del self.cached_block_hash_to_block[block_hash][block.block_id] - - if len(self.cached_block_hash_to_block[block_hash]) == 0: - del self.cached_block_hash_to_block[block_hash] + return self.managers.get_num_common_prefix_blocks( + request, num_running_requests) - return True - return False + def free_block_hashes(self, request: Request) -> None: + """Discard the block hashes for the request. - def _get_cached_block(self, - block_hash: BlockHashType) -> Optional[KVCacheBlock]: - """Get a cached block by the block hash, or None if cache miss. - If there are duplicated blocks, we return the first block in the cache. - - Args: - block_hash: The hash value of the block. - - Returns: - The cached block if it exists, or None. + NOTE: Unlike `free`, this method should be called only when the request + is finished, not when it is preempted. """ - if block_hash in self.cached_block_hash_to_block: - first_block_id = list( - self.cached_block_hash_to_block[block_hash].keys())[0] - return self.cached_block_hash_to_block[block_hash][first_block_id] - return None - - def _touch(self, blocks: List[KVCacheBlock]) -> None: - """Touch a block increases its reference count by 1, and may remove - the block from the free queue. This is used when a block is hit by - another request with the same prefix. - - Args: - blocks: A list of blocks to touch. 
- """ - for block in blocks: - # ref_cnt=0 means this block is in the free list (i.e. eviction - # candidate), so remove it. - if block.ref_cnt == 0: - self.free_block_queue.remove(block) - block.incr_ref() - - def _cache_full_blocks( - self, - request: Request, - blk_start_idx: int, - full_blocks: List[KVCacheBlock], - prev_block: Optional[KVCacheBlock], - ) -> None: - """Cache a list of full blocks for prefix caching. - - This function takes a list of blocks that will have their block hash - metadata to be updated and cached. Given a request, it computes the - block hashes for the blocks starting from `blk_start_idx` to the end - of the request's full blocks, updating the metadata for each block - and caching them in the `cached_block_hash_to_block`. - - Args: - request: The request to cache the blocks. - blk_start_idx: The index of the first block in the request's blocks - to cache. - full_blocks: The list of blocks to update hash metadata. - prev_block: The previous block in the chain. - """ - num_cached_block_hashes = len(request.kv_block_hashes) - - # Update the new blocks with the block hashes through the chain. - prev_block_hash_value = None - if prev_block is not None: - # Previous block must have a block hash because it must be - # a full, cached block. - assert prev_block.block_hash is not None - prev_block_hash_value = prev_block.block_hash.hash_value - - # Find the first uncached block. This case should only happen when - # speculative decoding is used. - offset = 0 - for blk in full_blocks: - if blk.block_hash is None: - break - else: - prev_block_hash_value = blk.block_hash.hash_value - offset += 1 - else: - # All blocks are cached. - return - - for i, blk in enumerate(full_blocks[offset:]): - blk_idx = blk_start_idx + offset + i - assert blk.block_hash is None - - if blk_idx < num_cached_block_hashes: - # The block hash may already be computed in - # "get_computed_blocks" if the tokens are not generated by - # this request (either the prompt tokens or the previously - # generated tokens with preemption). In this case we simply - # reuse the block hash. - block_hash = request.kv_block_hashes[blk_idx] - else: - # Otherwise compute the block hash and cache it in the request - # in case it will be preempted in the future. - start_token_idx = blk_idx * self.block_size - end_token_idx = (blk_idx + 1) * self.block_size - block_tokens = request.all_token_ids[ - start_token_idx:end_token_idx] - assert len(block_tokens) == self.block_size, ( - f"Expected {self.block_size} tokens, got " - f"{len(block_tokens)} at {blk_idx}th block for request " - f"{request.request_id}({request})") - - # Generate extra keys for multi-modal inputs. Note that since - # we reach to this branch only when the block is completed with - # generated tokens, we only need to consider the last mm input. - extra_keys, _ = generate_block_hash_extra_keys( - request, start_token_idx, end_token_idx, -1) - - # Compute the hash of the current block. - block_hash = hash_block_tokens(prev_block_hash_value, - block_tokens, extra_keys) - request.append_kv_block_hashes(block_hash) - - # Update and added the full block to the cache. 
- blk.block_hash = block_hash - self.cached_block_hash_to_block[block_hash][blk.block_id] = blk - prev_block_hash_value = block_hash.hash_value + self.managers.pop_block_hashes_of_request(request.request_id) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index e0976ba8577b9..b8b741de0ccdb 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 """KV-Cache Utilities.""" +import math +from collections import defaultdict from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, List, NamedTuple, Optional, Tuple +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec, - KVCacheTensor) +from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroup, + KVCacheNewTensor, KVCacheReuseTensor, + KVCacheSpec) from vllm.v1.request import Request logger = init_logger(__name__) @@ -24,6 +27,8 @@ class BlockHashType(NamedTuple): hash_value: int # Token IDs in the block. token_ids: Tuple[int, ...] + # The KV cache group that the block belongs to. + kv_cache_group_id: int # Extra keys for the block. extra_keys: Optional[Any] = None @@ -31,7 +36,8 @@ class BlockHashType(NamedTuple): @dataclass class KVCacheBlock: """KV-cache block metadata.""" - # Block ID, ranging from 0 to num_gpu_blocks - 1. + # Block ID, ranging from 0 to num_gpu_blocks - 1, and a special null_block + # with block_id = -1. block_id: int # Reference count. ref_cnt: int = 0 @@ -64,6 +70,32 @@ def reset_hash(self): """Reset the block hash when the block is evicted.""" self._block_hash = None + def __repr__(self): + # print block_id instead of KVCacheBlock object to avoid printing the + # KVCacheBlock object recursively. + prev_block_id = self.prev_free_block.block_id \ + if self.prev_free_block else None + next_block_id = self.next_free_block.block_id \ + if self.next_free_block else None + return (f"KVCacheBlock(block_id={self.block_id}, " + f"ref_cnt={self.ref_cnt}), " + f"_block_hash={self._block_hash}, " + f"prev_free_block={prev_block_id}, " + f"next_free_block={next_block_id})") + + +"""When a model needs different types of kv_caches (e.g., full attention + +sliding window attention), the attention layers will be split to multiple +"KV cache groups", where layers in the same group has the same kv cache type and +can use the same KVCacheBlock. There will be only one group if all layers use +the same type of KV cache. +See KVCacheConfig class for more examples of "KV cache group". +List[KVCacheBlocks]: the blocks of one group of layer in one request +List[List[KVCacheBlocks]]: the blocks of all groups of layers in one request. +""" +MayGroupedKVCacheBlocks = Union[List["KVCacheBlock"], + List[List["KVCacheBlock"]]] + class FreeKVCacheBlockQueue: """This class organizes a list of KVCacheBlock objects to a doubly linked @@ -243,6 +275,7 @@ def generate_block_hash_extra_keys( def hash_block_tokens( parent_block_hash: Optional[int], curr_block_token_ids: Sequence[int], + kv_cache_group_id: int, extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). 
The hash value is used for @@ -274,18 +307,20 @@ def hash_block_tokens( curr_block_token_ids_tuple = tuple(curr_block_token_ids) return BlockHashType( - hash((parent_block_hash, curr_block_token_ids_tuple, extra_keys)), - curr_block_token_ids_tuple, extra_keys) + hash((parent_block_hash, curr_block_token_ids_tuple, kv_cache_group_id, + extra_keys)), curr_block_token_ids_tuple, kv_cache_group_id, + extra_keys) -def hash_request_tokens(block_size: int, - request: Request) -> List[BlockHashType]: +def hash_request_tokens(block_size: int, request: Request, + kv_cache_group_id: int) -> List[BlockHashType]: """Computes hash values of a chain of blocks given a sequence of token IDs. The hash value is used for prefix caching. Args: block_size: The size of each block. request: The request object. + kv_cache_group_id: The KV cache group that the blocks belong to Returns: The list of computed hash values. @@ -316,14 +351,15 @@ def hash_request_tokens(block_size: int, request, start, end, curr_mm_idx) block_hash = hash_block_tokens(parent_block_hash_value, - block_token_ids, extra_keys) + block_token_ids, kv_cache_group_id, + extra_keys) ret.append(block_hash) parent_block_hash_value = block_hash.hash_value return ret def check_enough_kv_cache_memory(vllm_config: VllmConfig, - kv_cache_spec: KVCacheSpec, + kv_cache_spec: Dict[str, KVCacheSpec], available_memory: int): """ Checks whether `available_memory` is enough for the KV cache to hold at @@ -331,7 +367,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, Args: vllm_config: The global VllmConfig - kv_cache_spec: The kv cache spec of the model + kv_cache_spec: The KVCacheSpec of each attention layer in the model available_memory: Memory available for KV cache in bytes. Raises: @@ -358,12 +394,12 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, f"`max_model_len` when initializing the engine.") -def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool: +def is_kv_cache_type_uniform(kv_cache_spec: Dict[str, KVCacheSpec]) -> bool: """ Whether all layers in the given KVCacheSpec have the same type of KV cache. Args: - kv_cache_spec: The KVCacheSpec of the model + kv_cache_spec: The KVCacheSpec of each attention layer in the model Returns: True if all layers have the same type, False otherwise. @@ -373,8 +409,52 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool: return len(layer_keys) == 1 +def is_kv_cache_page_size_uniform( + kv_cache_spec: Dict[str, KVCacheSpec]) -> bool: + """ + Whether all layers in the given KVCacheSpec have the same page size. + + Args: + kv_cache_spec: The KVCacheSpec of each attention layer in the model + + Returns: + True if all layers have the same page size, False otherwise. + """ + + page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} + return len(page_sizes) == 1 + + +def _create_kv_cache_groups( + kv_cache_spec: Dict[str, KVCacheSpec], + grouped_layers: List[List[str]]) -> List[KVCacheGroup]: + """ + Create KVCacheGroup objects for each group of layers. + The layers in one group should share the same KVCacheSpec. + + Args: + kv_cache_spec (Dict[str, KVCacheSpec]): + A mapping from each layer name to its corresponding KVCacheSpec. + grouped_layers (List[List[str]]): + A list of layer groups, where each element is a list of layer names + that belongs to one group and should share the same KVCacheSpec. + + Returns: + A list of KVCacheGroup objects, one for each group of layers. 
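Because `BlockHashType` now carries a `kv_cache_group_id` that is also mixed into the hash, identical token blocks in different groups cannot hit each other's cache entries. The snippet below is a self-contained toy version of that keying scheme, not the vLLM implementation.

```python
from typing import Any, NamedTuple, Optional, Tuple


class ToyBlockHash(NamedTuple):
    hash_value: int
    token_ids: Tuple[int, ...]
    kv_cache_group_id: int
    extra_keys: Optional[Any] = None


def toy_hash_block_tokens(parent_hash: Optional[int],
                          token_ids: Tuple[int, ...],
                          kv_cache_group_id: int,
                          extra_keys: Optional[Tuple[Any, ...]] = None
                          ) -> ToyBlockHash:
    # The group id is part of the hashed tuple and of the key itself.
    return ToyBlockHash(
        hash((parent_hash, token_ids, kv_cache_group_id, extra_keys)),
        token_ids, kv_cache_group_id, extra_keys)


full_attn_key = toy_hash_block_tokens(None, (1, 2, 3, 4), kv_cache_group_id=0)
sliding_key = toy_hash_block_tokens(None, (1, 2, 3, 4), kv_cache_group_id=1)
# Same tokens, different groups -> different cache keys.
assert full_attn_key != sliding_key
```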
+ """ + kv_cache_groups = [] + for layer_names in grouped_layers: + group_spec = kv_cache_spec[layer_names[0]] + assert all( + kv_cache_spec[layer_name] == group_spec + for layer_name in layer_names[1:]), ( + "All layers in a group must share the same KVCacheSpec.") + kv_cache_groups.append(KVCacheGroup(layer_names, group_spec)) + return kv_cache_groups + + def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, - kv_cache_spec: KVCacheSpec, + kv_cache_spec: Dict[str, KVCacheSpec], available_memory: int) -> KVCacheConfig: """ Generates the KV cache configuration for a model with one type of KV cache. @@ -382,7 +462,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, Args: vllm_config: The global VllmConfig - kv_cache_spec: The kv cache spec of the model + kv_cache_spec: The KVCacheSpec of each attention layer in the model available_memory: Memory available for KV cache in bytes. Returns: @@ -411,27 +491,88 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, vllm_config.model_config.max_model_len, max_concurrency) per_layer_size = page_size * num_blocks + grouped_layers = [[layer_name for layer_name in kv_cache_spec]] + + kv_cache_config = KVCacheConfig(num_blocks=num_blocks, + tensors={ + layer_name: + KVCacheNewTensor(size=per_layer_size) + for layer_name in kv_cache_spec + }, + groups=_create_kv_cache_groups( + kv_cache_spec, grouped_layers)) + return kv_cache_config + - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, - tensors={ - layer_name: KVCacheTensor(size=per_layer_size) - for layer_name in kv_cache_spec - }, - groups=[[layer_name for layer_name in kv_cache_spec]], - kv_cache_spec=kv_cache_spec) +def _get_kv_cache_config_uniform_page_size( + vllm_config: VllmConfig, kv_cache_spec: Dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: + """ + Generates the KV cache configuration for a model with one page size. + + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The KVCacheSpec of each attention layer in the model + available_memory: Memory available for KV cache in bytes. + + Returns: + The generated KVCacheConfig + """ + # Group all layers by type_id. + # E.g., 2 full attention layers and 4 sliding window attention layers, + # -> (full.0, full.1), (sw.0, sw.1, sw.2, sw.3). + same_type_layers: Dict[str, List[str]] = defaultdict(list) + for layer_name, layer_spec in kv_cache_spec.items(): + same_type_layers[layer_spec.type_id].append(layer_name) + + # Split each group into smaller groups, to make the number of layers in + # each group identical. + # E.g., (full.0, full.1), (sw.0, sw.1, sw.2, sw.3) is split to 3 groups: + # (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3). + group_size_gcd = math.gcd( + *[len(layers) for layers in same_type_layers.values()]) + grouped_layers = [] + for layers in same_type_layers.values(): + for i in range(0, len(layers), group_size_gcd): + grouped_layers.append(layers[i:i + group_size_gcd]) + + # Divide the available memory equally among all layers in the first group. + # The memory layout in the example will be: + # full.0: Tensor with size=available_memory//2 + # full.1: Tensor with size=available_memory//2 + kv_cache_spec_first_group = { + layer_name: kv_cache_spec[layer_name] + for layer_name in grouped_layers[0] + } + kv_cache_config = _get_kv_cache_config_uniform_type( + vllm_config, kv_cache_spec_first_group, available_memory) + + # Reuse the KV cache tensors of the first group for the other groups. 
+ # The memory layout in the example will be: + # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2 + # full.1, sw.1, sw.3: share another Tensor with size=available_memory//2 + # Layers of different groups have different block table, so they will + # use different parts of the shared Tensor. + for layers in grouped_layers[1:]: + for layer_name, layer_name_first_group in zip(layers, + grouped_layers[0]): + kv_cache_config.tensors[layer_name] = KVCacheReuseTensor( + reused_layer_name=layer_name_first_group) + + kv_cache_config.groups = _create_kv_cache_groups(kv_cache_spec, + grouped_layers) return kv_cache_config -def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, +def get_kv_cache_config(vllm_config: VllmConfig, + kv_cache_spec: Dict[str, KVCacheSpec], available_memory: int) -> KVCacheConfig: """ Generates the KV cache configuration for a model - TODO: support hybrid models with more than one type of KV cache. Args: vllm_config: The global VllmConfig - kv_cache_spec: The kv cache spec of the model + kv_cache_spec: The KVCacheSpec of each attention layer in the model available_memory: Memory available for KV cache in bytes. Returns: @@ -443,5 +584,72 @@ def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, # Allocate the same amount of memory for each layer. return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, available_memory) + elif is_kv_cache_page_size_uniform(kv_cache_spec): + # KV cache of all layers have the same page size. + return _get_kv_cache_config_uniform_page_size(vllm_config, + kv_cache_spec, + available_memory) else: raise NotImplementedError + + +@dataclass +class PrefixLengthRange: + """ + A closed interval [start, end] representing a range of valid prefix lengths. + """ + start: int + end: int + + +def intersect_two_ranges( + a: List[PrefixLengthRange], + b: List[PrefixLengthRange]) -> List[PrefixLengthRange]: + """ + Intersect two sorted lists of PrefixLengthRange intervals. + + Args: + a: List of intervals + b: List of intervals + Returns: + List of intervals that are intersections of a and b + """ + i, j = 0, 0 + result = [] + + while i < len(a) and j < len(b): + overlap_start = max(a[i].start, b[j].start) + overlap_end = min(a[i].end, b[j].end) + + if overlap_start <= overlap_end: + result.append(PrefixLengthRange(overlap_start, overlap_end)) + + if a[i].end < b[j].end: + i += 1 + else: + j += 1 + + return result + + +def intersect_ranges( + ranges: List[List[PrefixLengthRange]]) -> List[PrefixLengthRange]: + """ + Intersect multiple lists of PrefixLengthRange intervals, each is sorted. 
+ + Args: + ranges: A list of lists of intervals + Returns: + A list of intervals representing the intersection of all ranges + """ + if not ranges: + return [] + + current_intersection = ranges[0] + for i in range(1, len(ranges)): + current_intersection = intersect_two_ranges(current_intersection, + ranges[i]) + if not current_intersection: + break + + return current_intersection diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index fb5e83fe06274..47e27267c7c36 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -12,6 +12,9 @@ compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs +from vllm.v1.kv_cache_interface import (BlockIDGenerator, FullAttentionSpec, + GroupedBlockIDs, KVCacheConfig, + MayGroupedBlockIDs, MayGroupedInt) from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -31,6 +34,7 @@ def __init__( model_config: ModelConfig, cache_config: CacheConfig, lora_config: Optional[LoRAConfig], + kv_cache_config: KVCacheConfig, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config @@ -48,11 +52,10 @@ def __init__( assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 # Create the KV cache manager. self.kv_cache_manager = KVCacheManager( - block_size=self.cache_config.block_size, - num_gpu_blocks=num_gpu_blocks, + kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, - sliding_window=self.cache_config.sliding_window, enable_caching=self.cache_config.enable_prefix_caching) + BlockIDGenerator.num_kv_cache_groups = len(kv_cache_config.groups) self.block_size = self.cache_config.block_size # req_id -> Request @@ -107,7 +110,7 @@ def schedule(self) -> "SchedulerOutput": scheduled_running_reqs: List[Request] = [] preempted_reqs: List[Request] = [] - req_to_new_block_ids: Dict[str, List[int]] = {} + req_to_new_block_ids: Dict[str, GroupedBlockIDs] = {} num_scheduled_tokens: Dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. @@ -164,9 +167,9 @@ def schedule(self) -> "SchedulerOutput": # Schedule the request. scheduled_running_reqs.append(request) - req_to_new_block_ids[request.request_id] = [ - b.block_id for b in new_blocks - ] + req_to_new_block_ids[ + request.request_id] = BlockIDGenerator.generate(new_blocks) + num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens req_index += 1 @@ -203,9 +206,21 @@ def schedule(self) -> "SchedulerOutput": # is always a multiple of the block size. This limitation # can potentially be removed in the future to slightly # improve the performance. - num_computed_tokens -= self.block_size - num_new_tokens = self.block_size - computed_blocks.pop() + kv_groups = self.kv_cache_manager.kv_cache_config.groups + if len(kv_groups) > 1 or \ + not isinstance(kv_groups[0].kv_cache_spec, + FullAttentionSpec): + # It is difficult to handle the last block problem + # for hybrid models. Ignore all computed tokens as + # a temporary solution. 
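For readers who want to see the intersection in action, here is a self-contained re-implementation of the two-list case together with an example; `Range` mirrors `PrefixLengthRange`, and the concrete interval values are invented.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Range:
    """Closed interval [start, end], like PrefixLengthRange."""
    start: int
    end: int


def intersect_two(a: List[Range], b: List[Range]) -> List[Range]:
    i = j = 0
    result: List[Range] = []
    while i < len(a) and j < len(b):
        lo = max(a[i].start, b[j].start)
        hi = min(a[i].end, b[j].end)
        if lo <= hi:
            result.append(Range(lo, hi))
        # Advance whichever interval ends first.
        if a[i].end < b[j].end:
            i += 1
        else:
            j += 1
    return result


# Group 0 (full attention): any prefix up to 48 tokens is cached.
# Group 1 (sliding window): only prefixes in [0, 16] or [32, 48] are usable.
common = intersect_two([Range(0, 48)], [Range(0, 16), Range(32, 48)])
assert common == [Range(0, 16), Range(32, 48)]
```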
+ num_computed_tokens = 0 + num_new_tokens = request.num_tokens + computed_blocks = [[] for _ in kv_groups] + else: + block_size = kv_groups[0].kv_cache_spec.block_size + num_computed_tokens -= block_size + num_new_tokens = block_size + computed_blocks[0].pop() num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 @@ -219,7 +234,8 @@ def schedule(self) -> "SchedulerOutput": break new_blocks = self.kv_cache_manager.allocate_slots( - request, num_new_tokens, computed_blocks) + request, num_new_tokens, computed_blocks, + num_computed_tokens) if new_blocks is None: # The request cannot be scheduled. break @@ -234,9 +250,11 @@ def schedule(self) -> "SchedulerOutput": raise RuntimeError( f"Invalid request status: {request.status}") - req_to_new_block_ids[request.request_id] = [ - b.block_id for b in computed_blocks + new_blocks - ] + req_to_new_block_ids[ + request.request_id] = BlockIDGenerator.generate( + computed_blocks) + BlockIDGenerator.generate( + new_blocks) + num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -264,7 +282,7 @@ def schedule(self) -> "SchedulerOutput": # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. - num_common_prefix_blocks = 0 + num_common_prefix_blocks = [0] if self.running: any_request = self.running[0] num_common_prefix_blocks = ( @@ -315,7 +333,7 @@ def schedule(self) -> "SchedulerOutput": def _make_cached_request_data( self, request: Request, - new_block_ids: List[int], + new_block_ids: GroupedBlockIDs, num_computed_tokens: int, resumed_from_preemption: bool, ) -> "CachedRequestData": @@ -534,6 +552,7 @@ def finish_requests( def _free_request(self, request: Request) -> None: assert request.is_finished() self.kv_cache_manager.free(request) + self.kv_cache_manager.free_block_hashes(request) self.encoder_cache_manager.free(request) self._cached_reqs_data.pop(request.request_id, None) del self.requests[request.request_id] @@ -566,14 +585,14 @@ class NewRequestData: mm_hashes: List[str] mm_positions: List["PlaceholderRange"] sampling_params: SamplingParams - block_ids: List[int] + block_ids: MayGroupedBlockIDs num_computed_tokens: int @classmethod def from_request( cls, request: Request, - block_ids: List[int], + block_ids: MayGroupedBlockIDs, num_computed_tokens: int, ) -> "NewRequestData": return cls( @@ -597,7 +616,7 @@ class CachedRequestData: # the request's block IDs. If True, new_block_ids will be used as the # request's block IDs instead of appending to the existing block IDs. 
resumed_from_preemption: bool - new_block_ids: List[int] + new_block_ids: MayGroupedBlockIDs num_computed_tokens: int @classmethod @@ -605,7 +624,7 @@ def from_request( cls, request: Request, resumed_from_preemption: bool, - new_block_ids: List[int], + new_block_ids: MayGroupedBlockIDs, num_computed_tokens: int, ) -> "CachedRequestData": return cls( @@ -625,7 +644,9 @@ class SchedulerOutput: num_scheduled_tokens: Dict[str, int] total_num_scheduled_tokens: int scheduled_encoder_inputs: Dict[str, List[int]] - num_common_prefix_blocks: int + # Number of common prefix blocks per kv cache group + # See KVCacheConfig class for the meaning of "group" + num_common_prefix_blocks: MayGroupedInt finished_req_ids: Set[str] free_encoder_input_ids: List[Tuple[str, int]] diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py new file mode 100644 index 0000000000000..134e9357e7855 --- /dev/null +++ b/vllm/v1/core/specialized_manager.py @@ -0,0 +1,600 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import dataclass +from itertools import chain +import math +from typing import Callable, DefaultDict, Dict, Iterator, List, Optional, Tuple, Type, TypeVar + +from vllm.v1.core.block_pool import BlockPool +from vllm.utils import cdiv +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, + PrefixLengthRange, + hash_request_tokens, intersect_ranges) +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheSpec, SlidingWindowSpec) +from vllm.v1.request import Request, RequestStatus +from vllm.v1.utils import ConstantList + +T = TypeVar("T") + + +class SpecializedManager(ABC): + """ + An abstract base class for specialized managers that handle the kv + cache management logic of different attention layers. + """ + block_size: int + max_num_blocks_per_req: int + + def __init__( + self, + kv_cache_spec: KVCacheSpec, + max_model_len: int, + enable_caching: bool, + kv_cache_group_id: int, + block_pool: BlockPool, + ) -> None: + """ + Initializes the SpecializedManager. + + Args: + kv_cache_spec: The kv_cache_spec for this manager. + block_pool: The block pool. + + Returns: + None + """ + + self.block_size = kv_cache_spec.block_size + self.kv_cache_spec = kv_cache_spec + self.block_pool = block_pool + self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size) + self.enable_caching = enable_caching + self.kv_cache_group_id = kv_cache_group_id + + # Mapping from request ID to blocks to track the blocks allocated + # for each request, so that we can free the blocks when the request + # is finished. + self.req_to_blocks: DefaultDict[str, + List[KVCacheBlock]] = defaultdict(list) + + # Mapping from request ID to kv block hashes. + # This is to avoid recomputing the block hashes for each call of + # `get_computed_blocks` or `allocate_slots`. + self.req_to_block_hashes: DefaultDict[ + str, List[BlockHashType]] = defaultdict(list) + + def hash_request_tokens(self, request: Request) -> List[BlockHashType]: + """ + Hash the tokens of a request to block hashes. + + Args: + request: The request to hash. + + Returns: + List[BlockHashType]: The block hashes of the request. 
+ """ + block_hashes = self.req_to_block_hashes[request.request_id] + if not block_hashes: + block_hashes = hash_request_tokens(self.block_size, request, + self.kv_cache_group_id) + self.req_to_block_hashes[request.request_id] = block_hashes + return block_hashes + + def truncate_computed_blocks( + self, computed_blocks: List[KVCacheBlock], + num_computed_tokens: int) -> List[KVCacheBlock]: + # Truncate the computed blocks to the number of computed tokens. + # E.g., group 0 has 3 computed blocks, and group 1 has 4 computed + # blocks with the same block size, we truncate both groups to 3 blocks. + computed_blocks = computed_blocks[:num_computed_tokens // + self.block_size] + return computed_blocks + + def get_req_num_new_blocks(self, request: Request, + new_computed_blocks: List[KVCacheBlock], + num_computed_tokens: int, num_tokens: int): + req_blocks = self.req_to_blocks[request.request_id] + return self.get_num_new_blocks( + num_computed_tokens, num_tokens, + len(req_blocks) + len(new_computed_blocks)) + + def allocate_slots( + self, + request: Request, + new_computed_blocks: Optional[List[KVCacheBlock]], + num_new_blocks: int, + num_preallocate_blocks: int, + num_computed_tokens: int, + num_tokens: int, + ): + if new_computed_blocks is None: + new_computed_blocks = [] + # Touch the computed blocks to make sure they won't be evicted. + if self.enable_caching: + self.block_pool.touch(new_computed_blocks) + else: + assert len(new_computed_blocks) == 0, ( + "Computed blocks should be empty when " + "prefix caching is disabled") + + # Append the new computed blocks to the request blocks until now to + # avoid the case where the new blocks cannot be allocated. + req_blocks = self.req_to_blocks[request.request_id] + req_blocks.extend(new_computed_blocks) + + # Start to handle new blocks + if num_new_blocks <= 0: + # No new block is needed. + new_blocks = [] + else: + # Get new blocks from the free block pool considering + # preallocated blocks. + num_new_blocks = min( + num_new_blocks + num_preallocate_blocks, + # Should not exceed the maximum number of blocks per request + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + # TODO(woosuk): Check and reject requests if + # num_prompt_tokens + max_tokens > max_model_len. + self.max_num_blocks_per_req - len(req_blocks), + ) + + assert num_new_blocks >= 0 + + new_blocks = self.block_pool.get_new_blocks(num_new_blocks) + req_blocks.extend(new_blocks) + + if not self.enable_caching: + return new_blocks + + # NOTE(rickyx): We are assuming the `num_tokens` are actual + # tokens rather than lookahead slots (e.g. for speculative decoding). + # TODO(rickyx): When supporting speculative decoding, we will need to + # differentiate between them so that we can know how many blocks are + # full after appending the actual tokens. + num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size + num_computed_full_blocks = num_computed_tokens // self.block_size + + new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks] + if new_full_blocks: + self.block_pool.cache_full_blocks( + request=request, + block_hashes=self.req_to_block_hashes[request.request_id], + block_size=self.block_size, + blk_start_idx=num_computed_full_blocks, + # The new full blocks are the full blocks that are not + # computed. 
+ full_blocks=new_full_blocks, + prev_block=(req_blocks[num_computed_full_blocks - 1] + if num_computed_full_blocks > 0 else None), + kv_cache_group_id=self.kv_cache_group_id, + ) + + return new_blocks + + def get_num_common_prefix_blocks(self, request: Request, + num_running_requests: int) -> int: + assert request.status == RequestStatus.RUNNING + blocks = self.req_to_blocks[request.request_id] + num_common_blocks = 0 + for block in blocks: + if block.ref_cnt == num_running_requests: + num_common_blocks += 1 + else: + break + return num_common_blocks + + def new_block_list(self) -> List[KVCacheBlock]: + return [] + + def pop_blocks_of_request(self, + request_id: int) -> Optional[List[KVCacheBlock]]: + blocks = self.req_to_blocks.pop(request_id, None) + return blocks + + def free_blocks(self, + blocks_to_free: List[KVCacheBlock], + need_reverse: bool = False) -> None: + if need_reverse: + blocks_to_free = list(reversed(blocks_to_free)) + self.block_pool.free_blocks(blocks_to_free) + + def iter_all(self, x: List[T]) -> Iterator[T]: + return iter(x) + + def pop_block_hashes_of_request(self, request_id: int) -> None: + self.req_to_block_hashes.pop(request_id, None) + + @abstractmethod + def get_possible_cached_prefix( + self, block_hashes: ConstantList[BlockHashType] + ) -> Tuple[List[PrefixLengthRange], List[KVCacheBlock]]: + """ + Get the possible cached prefixes of a request based on its block hashes. + If no cached prefixes are found, returns a tuple with a prefix length + range of [0, 0] and an empty list of blocks. + + Args: + block_hashes: The block hashes of the request. + + Returns: + A tuple containing: + - A list of all possible cached prefix lengths. + - The computed blocks that are cached. + """ + + raise NotImplementedError + + @abstractmethod + def get_num_new_blocks(self, num_computed_tokens: int, + num_append_tokens: int, + num_allocated_blocks: int) -> int: + """ + Calculate the number of new blocks needed by this manager. + + Args: + num_computed_tokens: The number of tokens that have been computed. + num_append_tokens: The number of tokens that need to be appended. + num_allocated_blocks: The number of blocks that have already been + allocated. + + Returns: + int: The number of new blocks needed. + """ + raise NotImplementedError + + @abstractmethod + def remove_useless_blocks(self, request: Request, + num_computed_tokens: int) -> List[KVCacheBlock]: + """ + Update the `block_table` in place to remove blocks that are no longer + needed. Replace the removed blocks with null_block and returns the + removed blocks. + The removed blocks should be in the order of the + priority to be evicted, where the first block should have the highest + priority. + + Args: + block_table: The block table to be updated. + num_computed_tokens: The number of tokens that have been computed. + + Returns: + List[KVCacheBlock]: The removed blocks. + """ + raise NotImplementedError + + +class FullAttentionManager(SpecializedManager): + + def get_possible_cached_prefix( + self, block_hashes: ConstantList[BlockHashType] + ) -> Tuple[List[PrefixLengthRange], List[KVCacheBlock]]: + computed_blocks: List[KVCacheBlock] = [] + for block_hash in block_hashes: + # block_hashes is a chain of block hashes. If a block hash is not + # in the cached_block_hash_to_id, the following block hashes are + # not computed yet for sure. 
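The slice of newly full blocks used above follows from simple block arithmetic; the toy numbers below are invented but show which blocks become eligible for caching after an allocation.

```python
block_size = 16
num_computed_tokens = 32             # two blocks were already full
num_tokens = 40                      # tokens appended in this step

num_full_blocks = (num_computed_tokens + num_tokens) // block_size       # 4
num_computed_full_blocks = num_computed_tokens // block_size             # 2

req_blocks = ["blk0", "blk1", "blk2", "blk3", "blk4"]
new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks]
# blk0/blk1 were already cached, blk4 is only partially filled.
assert new_full_blocks == ["blk2", "blk3"]
```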
+ if cached_block := self.block_pool.get_cached_block(block_hash): + computed_blocks.append(cached_block) + else: + break + return [PrefixLengthRange(0, + len(computed_blocks) * self.block_size) + ], computed_blocks + + def get_num_new_blocks(self, num_computed_tokens: int, + num_append_tokens: int, + num_allocated_blocks: int) -> int: + num_required_blocks = cdiv(num_computed_tokens + num_append_tokens, + self.block_size) + num_new_blocks = num_required_blocks - num_allocated_blocks + return num_new_blocks + + def remove_useless_blocks(self, request: Request, + num_computed_tokens: int) -> List[KVCacheBlock]: + return [] + + +class SlidingWindowManager(FullAttentionManager): + + def __init__( + self, + kv_cache_spec: SlidingWindowSpec, + max_model_len: int, + enable_caching: bool, + kv_cache_group_id: int, + block_pool: BlockPool, + ) -> None: + super().__init__( + kv_cache_spec=kv_cache_spec, + max_model_len=max_model_len, + enable_caching=enable_caching, + kv_cache_group_id=kv_cache_group_id, + block_pool=block_pool, + ) + self.sliding_window = kv_cache_spec.sliding_window + self._null_block = block_pool.get_null_block() + + def get_possible_cached_prefix( + self, block_hashes: ConstantList[BlockHashType] + ) -> Tuple[List[PrefixLengthRange], List[KVCacheBlock]]: + # TODO: check the hit every num_block_sliding_window blocks, to optimize + # the time complexity from O(num_block) to + # O(num_block / num_block_sliding_window) + O(num_computed_block), + # which is good for low cache hit rate scenarios. + start = 0 + ranges = [] + computed_blocks: List[KVCacheBlock] = [] + + dummy_block_hash = BlockHashType(-1, (), -1) + # Add a dummy block hash to support the case that the last block is + # cached. + for i, block_hash in enumerate(chain(block_hashes, + [dummy_block_hash])): + if cached_block := self.block_pool.get_cached_block(block_hash): + computed_blocks.append(cached_block) + else: + if start == 0: + # All tokens between [0, i * block_size] are cached. + # All of them are possible cached prefix. + ranges.append(PrefixLengthRange(0, i * self.block_size)) + elif (i - start) * self.block_size >= self.sliding_window: + # All tokens between [start * block_size, + # i * block_size)] are cached. These tokens except the + # first `self.sliding_window - 1` ones are possible cached + # prefix. + first_cached_token = start * self.block_size + # should be first_cached_token + self.sliding_window - 1 + 1 + # +1 is for converting the token index to the prefix length. + first_possible_length = first_cached_token + \ + self.sliding_window + ranges.append( + PrefixLengthRange(first_possible_length, + i * self.block_size)) + computed_blocks.append(self._null_block) + start = i + 1 + computed_blocks = computed_blocks[:-1] # remove the dummy block + return ranges, computed_blocks + + def remove_useless_blocks(self, request: Request, + num_computed_tokens: int) -> List[KVCacheBlock]: + # Remove the blocks that are no longer be in the sliding window. + last_useful_token = num_computed_tokens - self.sliding_window + last_useful_block = last_useful_token // self.block_size + + block_table = self.req_to_blocks[request.request_id] + removed_blocks: List[KVCacheBlock] = [] + for i in range(last_useful_block - 1, -1, -1): + if block_table[i] == self._null_block: + # If the block is already a null block, the blocks before it + # should also be null blocks. 
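The range construction above can be hard to follow from the loop alone, so here is a standalone sketch of the same idea with invented numbers: for a sliding window layer, a prefix of length L is usable only if its last `sliding_window` tokens are cached, so cached runs that do not start at token 0 still contribute usable prefix lengths.

```python
from typing import List, Set, Tuple


def possible_prefix_ranges(cached_blocks: Set[int], num_blocks: int,
                           block_size: int,
                           sliding_window: int) -> List[Tuple[int, int]]:
    ranges: List[Tuple[int, int]] = []
    start = 0  # first block index of the current contiguous cached run
    for i in range(num_blocks + 1):  # extra iteration acts as a trailing miss
        if i < num_blocks and i in cached_blocks:
            continue
        if start == 0:
            # The cached run starts at token 0: any prefix inside it works.
            ranges.append((0, i * block_size))
        elif (i - start) * block_size >= sliding_window:
            # The run covers at least one full window: prefixes whose last
            # `sliding_window` tokens fall inside the run are usable.
            ranges.append((start * block_size + sliding_window,
                           i * block_size))
        start = i + 1
    return ranges


# Blocks 0,1 and 3,4,5 are cached, block 2 is not (block_size=4, window=8).
assert possible_prefix_ranges({0, 1, 3, 4, 5}, 6, 4, 8) == [(0, 8), (20, 24)]
```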
+ break + removed_blocks.append(block_table[i]) + block_table[i] = self._null_block + return removed_blocks + + +spec_manager_map: Dict[Type[KVCacheSpec], Type[SpecializedManager]] = { + FullAttentionSpec: FullAttentionManager, + SlidingWindowSpec: SlidingWindowManager +} + + +def transpose_output(outputs: List[Tuple[T]]) -> Tuple[List[T]]: + return tuple(map(list, zip(*outputs))) + + +class GroupedManager(SpecializedManager): + + def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, + enable_caching: bool, block_pool: BlockPool) -> None: + self.enable_caching = enable_caching + self.managers: List[SpecializedManager] = [] + self.kv_cache_config = kv_cache_config + for i, g in enumerate(kv_cache_config.groups): + manager_class = spec_manager_map[type(g.kv_cache_spec)] + manager = manager_class(g.kv_cache_spec, max_model_len, + enable_caching, i, block_pool) + self.managers.append(manager) + self.block_pool = block_pool + + # Simple broadcast functions + # TODO: a better way to handle the broadcast functions + def hash_request_tokens(self, + request: Request) -> List[List[BlockHashType]]: + return [ + manager.hash_request_tokens(request) for manager in self.managers + ] + + def truncate_computed_blocks( + self, computed_blocks: List[List[KVCacheBlock]], + num_computed_tokens: int) -> List[List[KVCacheBlock]]: + return [ + manager.truncate_computed_blocks(computed_blocks[i], + num_computed_tokens) + for i, manager in enumerate(self.managers) + ] + + def get_req_num_new_blocks(self, request: Request, + new_computed_blocks: List[List[KVCacheBlock]], + num_computed_tokens: int, num_tokens: int): + num_new_blocks = [ + manager.get_req_num_new_blocks(request, new_computed_blocks[i], + num_computed_tokens, num_tokens) + for i, manager in enumerate(self.managers) + ] + return sum(max(x, 0) for x in num_new_blocks) + + def get_num_common_prefix_blocks(self, request: Request, + num_running_requests: int): + return [ + manager.get_num_common_prefix_blocks(request, num_running_requests) + for manager in self.managers + ] + + def remove_useless_blocks(self, request: Request, + num_computed_tokens: int) -> None: + """ + Frees memory blocks that are not needed. E.g., sliding window + layer with window size 2 and block size 1, we have req_blocks as + [[1, 2, 3]], this function will free block 1 and change the req_blocks + to [[-1, 2, 3]] (-1 refers to null block) + + Args: + req_blocks: The KV cache blocks of one request. + num_computed_tokens: The number of computed tokens. + """ + return [ + manager.remove_useless_blocks(request, num_computed_tokens) + for i, manager in enumerate(self.managers) + ] + + def new_block_list(self) -> List[List[KVCacheBlock]]: + return [manager.new_block_list() for manager in self.managers] + + def pop_blocks_of_request( + self, request_id: int) -> Optional[List[List[KVCacheBlock]]]: + blocks = [ + manager.pop_blocks_of_request(request_id) + for manager in self.managers + ] + if all(blks is None for blks in blocks): + return None + assert all(blks is not None for blks in blocks) + return blocks + + def pop_block_hashes_of_request(self, request_id: int) -> None: + for manager in self.managers: + manager.pop_block_hashes_of_request(request_id) + + def allocate_slots(self, request: Request, + new_computed_blocks: Optional[List[List[KVCacheBlock]]], + num_new_blocks: int, num_preallocate_blocks: int, + num_computed_tokens: int, num_tokens: int): + # NOTE: the input _num_new_blocks is the sum of all groups. 
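The window-2 / block-size-1 example in the `remove_useless_blocks` docstring above can be run directly with a simplified standalone version; `NULL_BLOCK = -1` stands in for the pool's null block and block IDs are plain integers.

```python
from typing import List, Tuple

NULL_BLOCK = -1


def remove_useless_blocks(block_table: List[int], num_computed_tokens: int,
                          sliding_window: int,
                          block_size: int) -> Tuple[List[int], List[int]]:
    last_useful_token = num_computed_tokens - sliding_window
    last_useful_block = last_useful_token // block_size
    removed: List[int] = []
    for i in range(last_useful_block - 1, -1, -1):
        if block_table[i] == NULL_BLOCK:
            break  # blocks before an existing null block were already freed
        removed.append(block_table[i])
        block_table[i] = NULL_BLOCK
    return block_table, removed


# Window size 2, block size 1, 3 computed tokens: block 1 leaves the window.
assert remove_useless_blocks([1, 2, 3], num_computed_tokens=3,
                             sliding_window=2, block_size=1) == ([-1, 2, 3],
                                                                 [1])
```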
+ # recompute instead + num_new_blocks_per_group = [ + manager.get_req_num_new_blocks(request, new_computed_blocks[i], + num_computed_tokens, num_tokens) + for i, manager in enumerate(self.managers) + ] + assert num_new_blocks == sum( + max(x, 0) for x in num_new_blocks_per_group) + return [ + manager.allocate_slots( + request, new_computed_blocks[i] if new_computed_blocks + is not None else None, num_new_blocks_per_group[i], + num_preallocate_blocks, num_computed_tokens, num_tokens) + for i, manager in enumerate(self.managers) + ] + + def _get_common_prefix_length( + self, prefix_length: List[List[PrefixLengthRange]] + ) -> List[PrefixLengthRange]: + """ + Find the longest prefix that is cached by all KV cache groups. Returns + the number of tokens in that prefix. + + Args: + prefix_length (List[PrefixLength]): The valid cached prefix lengths + of each KV cache group. + + Returns: + The number of tokens in the common prefix. + """ + intersection = intersect_ranges(prefix_length) + + # Since incomplete blocks are not eligible for sharing, + # `num_computed_tokens` should be a multiple of `block_size` of + # all managers, so we take the least common multiple (LCM) of them + alignment = math.lcm( + *[manager.block_size for manager in self.managers]) + + aligned_intersection = [] + for range_ in intersection: + aligned_end = cdiv(range_.end, alignment) * alignment + if aligned_end >= range_.start: + aligned_intersection.append( + PrefixLengthRange(range_.start, aligned_end)) + + if len(aligned_intersection) == 0: + aligned_intersection.append(PrefixLengthRange(0, 0)) + + return aligned_intersection + + def iter_all(self, x: List[List[T]]) -> Iterator[T]: + return chain.from_iterable(x) + + def _sort_blocks_by_eviction_order( + self, blocks: List[List[KVCacheBlock]], + need_reverse: bool) -> List[KVCacheBlock]: + """ + Merge the blocks of different groups to one list. The returned blocks + are sorted by eviction order, with the first block having the highest + eviction priority. + + Args: + blocks: the blocks of each kv cache group, ordered by eviction + priority. + + Returns: + A list of KVCacheBlocks sorted by eviction order. + """ + if need_reverse: + blocks = [ + list(reversed(blocks_of_group)) for blocks_of_group in blocks + ] + + if self.enable_caching: + # NOTE (Chen): A simple strategy that interleaves the blocks of + # different KV cache groups. We can investigate more advanced + # strategies in the future. 
+ ordered_blocks = [] + max_len = max(len(blocks_of_group) for blocks_of_group in blocks) + for i in range(max_len): + for blocks_of_group in blocks: + if i < len(blocks_of_group): + ordered_blocks.append(blocks_of_group[i]) + else: + ordered_blocks = [] + for blocks_of_group in blocks: + ordered_blocks.extend(blocks_of_group) + + return ordered_blocks + + def free_blocks(self, + blocks_to_free: List[List[KVCacheBlock]], + need_reverse: bool = False) -> None: + ordered_blocks = self._sort_blocks_by_eviction_order( + blocks_to_free, need_reverse) + self.block_pool.free_blocks(ordered_blocks) + + def get_possible_cached_prefix( + self, block_hashes: List[List[BlockHashType]] + ) -> Tuple[List[PrefixLengthRange], List[List[KVCacheBlock]]]: + output = [ + manager.get_possible_cached_prefix(block_hashes[i]) + for i, manager in enumerate(self.managers) + ] + prefix_length, computed_blocks = transpose_output(output) + common_prefix_length = self._get_common_prefix_length(prefix_length) + + return common_prefix_length, computed_blocks + + def get_num_new_blocks(self, num_computed_tokens: int, + num_append_tokens: int, + num_allocated_blocks: int) -> int: + raise RuntimeError("This method should not be called") + + +def get_specialized_manager(kv_cache_config: KVCacheConfig, max_model_len: int, + enable_caching: bool, + block_pool: BlockPool) -> SpecializedManager: + if len(kv_cache_config.groups) == 1: + return spec_manager_map[type(kv_cache_config.groups[0].kv_cache_spec)]( + kv_cache_config.groups[0].kv_cache_spec, max_model_len, + enable_caching, 0, block_pool) + else: + return GroupedManager(kv_cache_config, max_model_len, enable_caching, + block_pool) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 29a9ac1868f27..19ee966c12213 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -6,7 +6,7 @@ import threading import time from multiprocessing.connection import Connection -from typing import List, Tuple, Type +from typing import List, Type import psutil import zmq @@ -25,6 +25,7 @@ EngineCoreRequestUnion, EngineCoreResetPrefixCache) from vllm.v1.engine.mm_input_mapper import MMInputMapperServer from vllm.v1.executor.abstract import Executor +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import PickleEncoder from vllm.version import __version__ as VLLM_VERSION @@ -51,10 +52,9 @@ def __init__( self.model_executor = executor_class(vllm_config) # Setup KV Caches and update CacheConfig after profiling. - num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches( - vllm_config) - vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks - vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks + kv_cache_config = self._initialize_kv_caches(vllm_config) + vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks + vllm_config.cache_config.num_cpu_blocks = 0 # Setup scheduler. 
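As a toy illustration of the interleaving strategy above: when caching is enabled, the freed blocks of the different groups are merged round-robin so that no single group monopolizes the front of the eviction queue. Block names are invented.

```python
group0 = ["f0", "f1", "f2"]     # blocks of the full-attention group
group1 = ["s0", "s1"]           # blocks of the sliding-window group

ordered = []
max_len = max(len(g) for g in (group0, group1))
for i in range(max_len):
    for group in (group0, group1):
        if i < len(group):
            ordered.append(group[i])

assert ordered == ["f0", "s0", "f1", "s1", "f2"]
```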
self.scheduler = Scheduler( @@ -62,13 +62,12 @@ def __init__( model_config=vllm_config.model_config, cache_config=vllm_config.cache_config, lora_config=vllm_config.lora_config, - ) + kv_cache_config=kv_cache_config) self.mm_input_mapper_server = MMInputMapperServer( vllm_config.model_config) - def _initialize_kv_caches(self, - vllm_config: VllmConfig) -> Tuple[int, int]: + def _initialize_kv_caches(self, vllm_config: VllmConfig) -> KVCacheConfig: start = time.time() # Get all kv cache needed by the model @@ -81,8 +80,6 @@ def _initialize_kv_caches(self, # Get the kv cache tensor size kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, availble_gpu_memory) - num_gpu_blocks = kv_cache_config.num_blocks - num_cpu_blocks = 0 # Initialize kv cache and warmup the execution self.model_executor.initialize(kv_cache_config) @@ -90,7 +87,7 @@ def _initialize_kv_caches(self, elapsed = time.time() - start logger.info(("init engine (profile, create kv cache, " "warmup model) took %.2f seconds"), elapsed) - return num_gpu_blocks, num_cpu_blocks + return kv_cache_config def add_request(self, request: EngineCoreRequest): """Add request to the scheduler.""" diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index ac10d43eb0d54..f273bcdd9da7a 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Type +from typing import Dict, Type from vllm.config import VllmConfig from vllm.executor.executor_base import ExecutorBase @@ -64,7 +64,7 @@ def determine_available_memory(self) -> int: # in bytes # operators can be applied to all workers. return min(output) - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self) -> Dict[str, KVCacheSpec]: output = self.collective_rpc("get_kv_cache_spec") for x in output: assert x == output[0] diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index eddfb5949ebe6..09be2463d9e9f 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -1,18 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Dict, List +from typing import TYPE_CHECKING, Dict, List, TypeVar, Union import torch from vllm.logger import init_logger from vllm.utils import cdiv, get_dtype_size +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_utils import KVCacheBlock logger = init_logger(__name__) @dataclass -class KVCacheSpecBase: +class KVCacheSpec: """ A base class for specifying the KV cache format of one layer. 
""" @@ -56,7 +58,8 @@ def bytes_for_tokens(self, num_tokens: int) -> int: @dataclass -class FullAttentionSpec(KVCacheSpecBase): +class FullAttentionSpec(KVCacheSpec): + num_heads: int num_kv_heads: int head_size: int dtype: torch.dtype @@ -74,19 +77,64 @@ def bytes_for_tokens(self, num_tokens: int) -> int: return cdiv(num_tokens, self.block_size) * self.page_size_bytes -KVCacheSpec = Dict[str, KVCacheSpecBase] +@dataclass +class SlidingWindowSpec(KVCacheSpec): + num_heads: int + num_kv_heads: int + head_size: int + dtype: torch.dtype + sliding_window: int + + @property + def type_id(self) -> str: + return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}" # noqa + + @property + def page_size_bytes(self) -> int: + return 2 * self.block_size * self.num_kv_heads * self.head_size \ + * get_dtype_size(self.dtype) + + def bytes_for_tokens(self, num_tokens: int) -> int: + num_tokens = min(num_tokens, self.sliding_window) + return cdiv(num_tokens, self.block_size) * self.page_size_bytes @dataclass -class KVCacheTensor: +class KVCacheTensorBase: """ A dataclass for specifying how the workers should initialize the KV cache - for a layer. Only contains the size of KV cache for that layer for now. Will - be extended to support multiple layers sharing the same memory pool. + for a layer. + """ + pass + + +@dataclass +class KVCacheNewTensor(KVCacheTensorBase): + """ + Initialize the KV cache with a tensor of `size` bytes. """ size: int # The size of KV cache Tensor in bytes +@dataclass +class KVCacheReuseTensor(KVCacheTensorBase): + """ + Reuse the KV cache tensor of `layer_name` for the current layer. + """ + reused_layer_name: str + + +@dataclass +class KVCacheGroup: + """ + A dataclass for specifying the KV cache group of a model. + """ + # The names of layers in this group + layer_names: List[str] + # The KV cache spec of this group + kv_cache_spec: KVCacheSpec + + @dataclass class KVCacheConfig: """ @@ -95,7 +143,7 @@ class KVCacheConfig: """The number of KV cache blocks""" num_blocks: int """layer_name -> how to initialize KV cache for that layer""" - tensors: Dict[str, KVCacheTensor] + tensors: Dict[str, KVCacheTensorBase] """ A list of kv-cache groups. Each group includes a set of layers with the same kv-cache spec, and the total page_size of layers inside a group @@ -108,6 +156,51 @@ class KVCacheConfig: 3. (not implemented yet) A model with 2 full attention layers and 4 sliding window attention layers: three groups, (full * 2), (sw * 2), (sw * 2). 
""" - groups: List[List[str]] - """the KVCacheSpec of the model""" - kv_cache_spec: KVCacheSpec + groups: List[KVCacheGroup] + + +@dataclass +class GroupedBlockIDs: + # A list of block IDs for each group of KV cache blocks + _block_ids: List[List[int]] + + def __init__(self, block_ids: List[List[int]]): + self._block_ids = block_ids + + @classmethod + def from_kv_cache_blocks( + cls, + kv_cache_blocks: List[List["KVCacheBlock"]]) -> "GroupedBlockIDs": + return cls( + block_ids=[[blk.block_id for blk in kv_cache_blocks_one_group] + for kv_cache_blocks_one_group in kv_cache_blocks]) + + def extend(self, new_block_ids: "GroupedBlockIDs") -> None: + for i, block_ids in enumerate(new_block_ids._block_ids): + self._block_ids[i].extend(block_ids) + + def __add__(self, other: "GroupedBlockIDs") -> "GroupedBlockIDs": + return GroupedBlockIDs(block_ids=[ + a + b for a, b in zip(self._block_ids, other._block_ids) + ]) + + def get_group(self, group_idx: int) -> List[int]: + return self._block_ids[group_idx] + + +MayGroupedBlockIDs = Union[GroupedBlockIDs, List[int]] +MayGroupedInt = Union[int, List[int]] + + +class BlockIDGenerator: + num_kv_cache_groups: int + + @classmethod + def generate( + cls, kv_cache_blocks: Union[List["KVCacheBlock"], + List[List["KVCacheBlock"]]] + ) -> MayGroupedBlockIDs: + if cls.num_kv_cache_groups == 1: + return [blk.block_id for blk in kv_cache_blocks] + else: + return GroupedBlockIDs.from_kv_cache_blocks(kv_cache_blocks) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 0519d9e787518..8c76d72f5d812 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -12,7 +12,6 @@ if TYPE_CHECKING: from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange - from vllm.v1.core.kv_cache_utils import BlockHashType class Request: @@ -63,11 +62,6 @@ def __init__( if self.mm_hashes: assert len(self.mm_inputs) == len(self.mm_hashes) - # Cache the computed kv block hashes of the request to avoid - # recomputing. - self._kv_block_hashes: List[BlockHashType] = [] - self.kv_block_hashes = ConstantList(self._kv_block_hashes) - # Read-only views # Prevent directly appending to the these lists since # they should also be updated simultaneously. 
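A minimal sketch of the grouped block-ID bookkeeping above: IDs are stored as one list per KV cache group and concatenation is group-wise. This mirrors only the shape of the data, not the full vLLM class.

```python
from typing import List


class ToyGroupedBlockIDs:
    def __init__(self, block_ids: List[List[int]]):
        self._block_ids = block_ids

    def __add__(self, other: "ToyGroupedBlockIDs") -> "ToyGroupedBlockIDs":
        # Concatenate group-wise, keeping one list per KV cache group.
        return ToyGroupedBlockIDs(
            [a + b for a, b in zip(self._block_ids, other._block_ids)])

    def get_group(self, group_idx: int) -> List[int]:
        return self._block_ids[group_idx]


computed = ToyGroupedBlockIDs([[0, 1], [10, 11]])   # prefix-cache hits
new = ToyGroupedBlockIDs([[2], [12, 13]])           # freshly allocated
merged = computed + new
assert merged.get_group(0) == [0, 1, 2]
assert merged.get_group(1) == [10, 11, 12, 13]
```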
@@ -124,13 +118,6 @@ def get_num_encoder_tokens(self, input_id: int) -> int: num_tokens = self.mm_positions[input_id]["length"] return num_tokens - def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None: - self._kv_block_hashes = value - self.kv_block_hashes = ConstantList(self._kv_block_hashes) - - def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None: - self._kv_block_hashes.append(block_hash) - class RequestStatus(enum.IntEnum): """Status of a request.""" diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 5494542c181d7..f18b08746b31f 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -86,6 +86,9 @@ def __contains__(self, item): def __len__(self): return len(self._x) + def __repr__(self): + return "ConstantList(" + repr(self._x) + ")" + class BackgroundProcHandle: """ diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index f520ee9586c5c..06dd0c593f51d 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -1,11 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List +from typing import Callable, List, Union import numpy as np import torch +from triton import cdiv +from typing_extensions import Concatenate, ParamSpec from vllm.logger import init_logger +from vllm.v1.kv_cache_interface import (GroupedBlockIDs, KVCacheConfig, + KVCacheSpec) logger = init_logger(__name__) @@ -16,23 +20,26 @@ def __init__( self, max_num_reqs: int, max_model_len: int, - max_num_blocks_per_req: int, + max_num_tokens: int, pin_memory: bool, device: torch.device, + kv_cache_spec: KVCacheSpec, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req + self.max_num_tokens = max_num_tokens + self.max_num_blocks_per_req = cdiv(max_model_len, + kv_cache_spec.block_size) self.pin_memory = pin_memory self.device = device self.block_table = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), + (max_num_reqs, self.max_num_blocks_per_req), device=self.device, dtype=torch.int32, ) self.block_table_cpu = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), + (max_num_reqs, self.max_num_blocks_per_req), device="cpu", dtype=torch.int32, pin_memory=pin_memory, @@ -40,20 +47,27 @@ def __init__( self.block_table_np = self.block_table_cpu.numpy() self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) + self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + def append_row( self, - row_idx: int, - start: int, block_ids: List[int], + row_idx: int, ) -> None: if not block_ids: return num_blocks = len(block_ids) + start = self.num_blocks_per_row[row_idx] self.block_table_np[row_idx, start:start + num_blocks] = block_ids self.num_blocks_per_row[row_idx] = start + num_blocks - def add_row(self, row_idx: int, block_ids: List[int]) -> None: - self.append_row(row_idx, 0, block_ids) + def add_row(self, block_ids: List[int], row_idx: int) -> None: + self.num_blocks_per_row[row_idx] = 0 + self.append_row(block_ids, row_idx) def move_row(self, src: int, tgt: int) -> None: num_blocks = self.num_blocks_per_row[src] @@ -80,3 +94,70 @@ def get_cpu_tensor(self) -> torch.Tensor: def get_numpy_array(self) -> np.ndarray: """Returns the numpy array of the block table.""" return self.block_table_np + + +P = ParamSpec("P") + + +class GroupedBlockTable: + move_row: Callable[P, None] + commit: Callable[P, None] + clear: Callable[P, None] + + 
append_row: Callable[Concatenate["GroupedBlockIDs", P], None] + add_row: Callable[Concatenate["GroupedBlockIDs", P], None] + + def __init__(self, max_num_reqs: int, max_model_len: int, + max_num_tokens: int, pin_memory: bool, device: torch.device, + kv_cache_config: KVCacheConfig) -> None: + self.block_tables = [ + BlockTable(max_num_reqs, max_model_len, max_num_tokens, pin_memory, + device, g.kv_cache_spec) for g in kv_cache_config.groups + ] + # For methods that just pass the arguments to each BlockTable. + for f_name in ("move_row", "commit", "clear"): + setattr(self, f_name, self._make_grouped_func(f_name)) + # For methods that require a block_ids as the first argument. + for f_name in ("append_row", "add_row"): + setattr(self, f_name, + self._make_grouped_func_with_block_ids(f_name)) + + def _make_grouped_func(self, f_name: str) -> Callable[P, None]: + + def grouped_func(*args: P.args, **kwargs: P.kwargs) -> None: + for block_table in self.block_tables: + getattr(block_table, f_name)(*args, **kwargs) + + return grouped_func + + def _make_grouped_func_with_block_ids( + self, + f_name: str) -> Callable[Concatenate["GroupedBlockIDs", P], None]: + + def grouped_func(block_ids: "GroupedBlockIDs", *args: P.args, + **kwargs: P.kwargs) -> None: + for i, block_table in enumerate(self.block_tables): + getattr(block_table, f_name)(block_ids.get_group(i), *args, + **kwargs) + + return grouped_func + + def __getitem__(self, idx: int) -> "BlockTable": + return self.block_tables[idx] + + +def initialize_block_table( + max_num_reqs: int, + max_model_len: int, + max_num_tokens: int, + pin_memory: bool, + device: torch.device, + kv_cache_config: KVCacheConfig, +) -> Union[BlockTable, GroupedBlockTable]: + if len(kv_cache_config.groups) == 1: + return BlockTable(max_num_reqs, max_model_len, max_num_tokens, + pin_memory, device, + kv_cache_config.groups[0].kv_cache_spec) + else: + return GroupedBlockTable(max_num_reqs, max_model_len, max_num_tokens, + pin_memory, device, kv_cache_config) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 39708f833fd58..bfad40ff96b7d 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -10,8 +10,9 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType +from vllm.v1.kv_cache_interface import KVCacheConfig, MayGroupedBlockIDs from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.worker.block_table import BlockTable +from vllm.v1.worker.block_table import initialize_block_table if TYPE_CHECKING: from vllm.multimodal.inputs import PlaceholderRange @@ -28,7 +29,7 @@ class CachedRequestState: sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: List[int] + block_ids: MayGroupedBlockIDs num_computed_tokens: int output_token_ids: List[int] @@ -46,14 +47,14 @@ def __init__( self, max_num_reqs: int, max_model_len: int, - max_num_blocks_per_req: int, + max_num_tokens: int, device: torch.device, pin_memory: bool, vocab_size: int, + kv_cache_config: KVCacheConfig, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req self.device = device self.pin_memory = pin_memory self.vocab_size = vocab_size @@ -77,12 +78,13 @@ def __init__( self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) # Block table. 
- self.block_table = BlockTable( + self.block_table = initialize_block_table( max_num_reqs=max_num_reqs, max_model_len=max_model_len, - max_num_blocks_per_req=max_num_blocks_per_req, + max_num_tokens=max_num_tokens, pin_memory=pin_memory, device=device, + kv_cache_config=kv_cache_config, ) # Sampling-related. @@ -194,7 +196,7 @@ def add_request( self.num_tokens[req_index] = request.num_tokens self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - self.block_table.add_row(req_index, request.block_ids) + self.block_table.add_row(request.block_ids, req_index) sampling_params = request.sampling_params self.temperature_cpu[req_index] = sampling_params.temperature diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7841fac1df34b..715b3c59bb7d3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -13,7 +13,7 @@ from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig from vllm.distributed.parallel_state import graph_capture -from vllm.forward_context import set_forward_context +from vllm.forward_context import ForwardMetadata, set_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding @@ -28,10 +28,12 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheSpec) + KVCacheNewTensor, KVCacheReuseTensor, + KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import bind_kv_cache +from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch if TYPE_CHECKING: @@ -61,7 +63,6 @@ def __init__( model_config = self.model_config cache_config = self.cache_config scheduler_config = self.scheduler_config - parallel_config = self.parallel_config self.device = device self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype @@ -73,19 +74,10 @@ def __init__( self.is_multimodal_model = model_config.is_multimodal_model self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len - self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs - # Model-related. - self.num_attn_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() # Multi-modal data support @@ -112,15 +104,6 @@ def __init__( # Request states. self.requests: Dict[str, CachedRequestState] = {} - # Persistent batch. 
- self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=model_config.get_vocab_size(), - ) self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE @@ -189,11 +172,14 @@ def __init__( device="cpu", pin_memory=self.pin_memory) self.positions_np = self.positions_cpu.numpy() - self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.slot_mapping_np = self.slot_mapping_cpu.numpy() + + self.kv_cache_config = cast(KVCacheConfig, + None) # Set by initialize_kv_cache + + # InputBatch depends on KVCacheConfig, assign a fake value here and + # initialize in `initialize_kv_cache``. + self.input_batch = cast(InputBatch, None) # Persistent batch. + self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, dtype=torch.int32, device="cpu", @@ -337,10 +323,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( req_data.num_computed_tokens) - start_index = len(req_state.block_ids) - len( - req_data.new_block_ids) - self.input_batch.block_table.append_row(req_index, start_index, - req_data.new_block_ids) + self.input_batch.block_table.append_row(req_data.new_block_ids, + req_index) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. @@ -425,24 +409,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - # Calculate the slot mapping. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` here - # because M (max_model_len) is not necessarily divisible by block_size. - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions_np // self.block_size) - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. - block_table_cpu = self.input_batch.block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - block_offsets = positions_np % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.slot_mapping_np[:total_num_scheduled_tokens]) - # Prepare the attention metadata. self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens @@ -469,102 +435,144 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.device, non_blocking=True) seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device, non_blocking=True) - slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to( - self.device, non_blocking=True).long() - - # Prepare for cascade attention if needed. - common_prefix_len = (scheduler_output.num_common_prefix_blocks * - self.block_size) - if common_prefix_len == 0: - # Common case. 
- use_cascade = False + + attn_metadata: Dict[str, FlashAttentionMetadata] = {} + + if len(self.kv_cache_config.groups) == 1: + may_grouped_unwrapper = lambda x, _group_id: x else: - # NOTE(woosuk): Cascade attention uses two attention kernels: one - # for the common prefix and the other for the rest. For the first - # kernel, we concatenate all the query tokens (possibly from - # different requests) and treat them as if they are from the same - # request. Then, we use bi-directional attention to process the - # common prefix in the KV cache. Importantly, this means that the - # first kernel does not do any masking. - - # Consider the following example: - # Request 1's input query: [D, E, X] - # Request 1's kv cache: [A, B, C, D, E, X] - # Request 1's num_computed_tokens: 3 (i.e., [A, B, C]) - # Request 2's input query: [E, Y] - # Request 2's kv cache: [A, B, C, D, E, Y] - # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D]) - - # If we use [A, B, C, D, E] as the common prefix, then the - # first kernel will compute the bi-directional attention between - # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E]. - # However, this is wrong because D in Request 1 should not attend to - # E in the common prefix (i.e., we need masking). - # To avoid this, [A, B, C, D] should be the common prefix. - # That is, the common prefix should be capped by the minimum - # num_computed_tokens among the requests, and plus one to include - # the first token of the query. - - # In practice, we use [A, B, C] as the common prefix, instead of - # [A, B, C, D] (i.e., the common prefix is capped by the minimum - # num_computed_tokens, without plus one). - # This is because of an implementation detail: We want to always - # use two kernels for cascade attention. Let's imagine: - # Request 3's input query: [D] - # Request 3's kv cache: [A, B, C, D] - # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D]) - # If we use [A, B, C, D] as the common prefix for Request 1-3, - # then Request 3 will be processed only by the first kernel, - # and the second kernel will get an empty input. While this is not - # a fundamental problem, our current implementation does not support - # this case. - common_prefix_len = min( - common_prefix_len, - self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) - # common_prefix_len should be a multiple of the block size. - common_prefix_len = (common_prefix_len // self.block_size * - self.block_size) - use_cascade = FlashAttentionBackend.use_cascade_attention( + may_grouped_unwrapper = lambda x, group_id: x[group_id] + + for group_id, kv_cache_group in enumerate(self.kv_cache_config.groups): + block_size = kv_cache_group.kv_cache_spec.block_size + block_table: BlockTable = may_grouped_unwrapper( + self.input_batch.block_table, group_id) + + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. + block_table_indices = ( + req_indices * block_table.max_num_blocks_per_req + + positions_np // block_size) + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. 
+ block_table_cpu = block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten( + )[block_table_indices].numpy() + block_offsets = positions_np % block_size + np.add( + block_numbers * block_size, + block_offsets, + out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) + slot_mapping = block_table.slot_mapping_cpu \ + [:total_num_scheduled_tokens] \ + .to(self.device, non_blocking=True).long() + + # Prepare for cascade attention if needed. + common_prefix_len = (may_grouped_unwrapper( + scheduler_output.num_common_prefix_blocks, group_id) * + block_size) + if common_prefix_len == 0: + # Common case. + use_cascade = False + else: + # NOTE(woosuk): Cascade attention uses two attention kernels: + # one for the common prefix and the other for the rest. For the + # first kernel, we concatenate all the query tokens (possibly + # from different requests) and treat them as if they are from + # the same request. Then, we use bi-directional attention to + # process the common prefix in the KV cache. Importantly, this + # means that the first kernel does not do any masking. + + # Consider the following example: + # Request 1's input query: [D, E, X] + # Request 1's kv cache: [A, B, C, D, E, X] + # Request 1's num_computed_tokens: 3 (i.e., [A, B, C]) + # Request 2's input query: [E, Y] + # Request 2's kv cache: [A, B, C, D, E, Y] + # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D]) + + # If we use [A, B, C, D, E] as the common prefix, then the + # first kernel will compute the bi-directional attention between + # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E]. + # However, this is wrong because D in Request 1 should not + # attend to E in the common prefix (i.e., we need masking). + # To avoid this, [A, B, C, D] should be the common prefix. + # That is, the common prefix should be capped by the minimum + # num_computed_tokens among the requests, and plus one to + # include the first token of the query. + + # In practice, we use [A, B, C] as the common prefix, instead of + # [A, B, C, D] (i.e., the common prefix is capped by the minimum + # num_computed_tokens, without plus one). + # This is because of an implementation detail: We want to always + # use two kernels for cascade attention. Let's imagine: + # Request 3's input query: [D] + # Request 3's kv cache: [A, B, C, D] + # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D]) + # If we use [A, B, C, D] as the common prefix for Request 1-3, + # then Request 3 will be processed only by the first kernel, + # and the second kernel will get an empty input. While this is + # not a fundamental problem, our current implementation does not + # support this case. + common_prefix_len = min( + common_prefix_len, + self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) + # common_prefix_len should be a multiple of the block size. + common_prefix_len = (common_prefix_len // block_size * + block_size) + kv_cache_spec = kv_cache_group.kv_cache_spec + assert isinstance(kv_cache_spec, + (FullAttentionSpec, SlidingWindowSpec)) + use_cascade = FlashAttentionBackend.use_cascade_attention( + common_prefix_len=common_prefix_len, + query_lens=num_scheduled_tokens, + num_query_heads=kv_cache_spec.num_heads, + num_kv_heads=kv_cache_spec.num_kv_heads, + use_alibi=False, # FIXME + use_sliding_window=self.sliding_window is not None, + num_sms=self.num_sms, + ) + + if use_cascade: + # TODO: Optimize. 
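+                # For the shared-prefix kernel, the whole batch is treated as
+                # one request: cu_prefix_query_lens spans [0, total tokens]
+                # and prefix_kv_lens holds only the common prefix length,
+                # while suffix_kv_lens gives each request's remaining KV
+                # length for the second, per-request kernel.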
+ cu_prefix_query_lens = torch.tensor( + [0, total_num_scheduled_tokens], + dtype=torch.int32, + device=self.device) + prefix_kv_lens = torch.tensor([common_prefix_len], + dtype=torch.int32, + device=self.device) + suffix_kv_lens = (self.seq_lens_np[:num_reqs] - + common_prefix_len) + suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to( + self.device) + else: + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + + attn_metadata_of_group = FlashAttentionMetadata( + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table.get_device_tensor()[:num_reqs], + slot_mapping=slot_mapping, + use_cascade=use_cascade, common_prefix_len=common_prefix_len, - query_lens=num_scheduled_tokens, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - use_alibi=False, # FIXME - use_sliding_window=self.sliding_window is not None, - num_sms=self.num_sms, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, ) - if use_cascade: - # TODO: Optimize. - cu_prefix_query_lens = torch.tensor( - [0, total_num_scheduled_tokens], - dtype=torch.int32, - device=self.device) - prefix_kv_lens = torch.tensor([common_prefix_len], - dtype=torch.int32, - device=self.device) - suffix_kv_lens = (self.seq_lens_np[:num_reqs] - common_prefix_len) - suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.device) - else: - cu_prefix_query_lens = None - prefix_kv_lens = None - suffix_kv_lens = None - - attn_metadata = FlashAttentionMetadata( - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - query_start_loc=query_start_loc, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table=( - self.input_batch.block_table.get_device_tensor()[:num_reqs]), - slot_mapping=slot_mapping, - use_cascade=use_cascade, - common_prefix_len=common_prefix_len, - cu_prefix_query_lens=cu_prefix_query_lens, - prefix_kv_lens=prefix_kv_lens, - suffix_kv_lens=suffix_kv_lens, - ) + for layer_name in kv_cache_group.layer_names: + attn_metadata[layer_name] = attn_metadata_of_group # NOTE(woosuk): Due to chunked prefills, the batch may contain partial # requests. While we should not sample any token from these partial # requests, we do so for simplicity. We will ignore the sampled @@ -758,7 +766,8 @@ def execute_model( else: # Eager mode. num_input_tokens = num_scheduled_tokens - attn_metadata.num_input_tokens = num_input_tokens + + forward_metadata = ForwardMetadata(num_input_tokens=num_input_tokens) if self.is_multimodal_model: # NOTE(woosuk): To unify token ids and soft tokens (vision @@ -784,7 +793,9 @@ def execute_model( # Run the decoder. # Use persistent buffers for CUDA graphs. - with set_forward_context(attn_metadata, self.vllm_config): + with set_forward_context(attn_metadata, + self.vllm_config, + forward_metadata=forward_metadata): positions = self.mrope_positions[:, :num_input_tokens] \ if self.model_config.uses_mrope \ else self.positions[:num_input_tokens] @@ -908,9 +919,11 @@ def profile_run(self) -> None: # it is important to create tensors inside the loop, rather than # multiplying the list, to avoid Dynamo from treating them as # tensor aliasing. 
+ num_attn_layers = self.model_config.get_num_layers_by_block_type( + self.parallel_config, LayerBlockType.attention) dummy_kv_caches = [ torch.tensor([], dtype=torch.float32, device=self.device) - for _ in range(self.num_attn_layers) + for _ in range(num_attn_layers) ] # Profile with multimodal encoder & encoder cache. @@ -1043,6 +1056,71 @@ def capture_model(self) -> None: logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) + def _initialize_kv_cache_buffer( + self, kv_cache_config: KVCacheConfig) -> Dict[str, torch.Tensor]: + """ + Initializes the KV cache buffer with the correct size. The buffer needs + to be reshaped to the desired shape before being used by the models. + + Args: + kv_cache_config: The KV cache config + + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ + kv_cache_raw_tensors: Dict[str, torch.Tensor] = {} + for layer_name, tensor_config in kv_cache_config.tensors.items(): + if isinstance(tensor_config, KVCacheNewTensor): + # A new tensor with `tensor_config.size` bytes + kv_cache_raw_tensors[layer_name] = torch.zeros( + tensor_config.size, dtype=torch.int8, device=self.device) + for layer_name, tensor_config in kv_cache_config.tensors.items(): + if isinstance(tensor_config, KVCacheReuseTensor): + # Reuse a tensor from `kv_cache_raw_tensors` + kv_cache_raw_tensors[layer_name] = kv_cache_raw_tensors[ + tensor_config.reused_layer_name] + assert len(kv_cache_raw_tensors) == len( + kv_cache_config.tensors), "Some layers are not initialized" + return kv_cache_raw_tensors + + def _setup_kv_cache_shapes( + self, + kv_cache_config: KVCacheConfig, + kv_cache_raw_tensors: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: + """ + Reshape the KV cache tensors to the desired shape. + + Args: + kv_cache_config: The KV cache config + kv_cache_raw_tensors: The KV cache buffer of each layer, with + correct size but uninitialized shape. + + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ + kv_caches: Dict[str, torch.Tensor] = {} + for kv_cache_group in kv_cache_config.groups: + kv_cache_spec = kv_cache_group.kv_cache_spec + for layer_name in kv_cache_group.layer_names: + raw_tensor = kv_cache_raw_tensors[layer_name] + assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 + num_blocks = raw_tensor.numel( + ) // kv_cache_spec.page_size_bytes + if isinstance(kv_cache_spec, + (FullAttentionSpec, SlidingWindowSpec)): + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + kv_caches[layer_name] = kv_cache_raw_tensors[ + layer_name].view(dtype).view(kv_cache_shape) + else: + raise NotImplementedError + return kv_caches + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. 
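To make the new tensor-sharing scheme concrete, here is an illustrative-only sketch of the two-pass logic in `_initialize_kv_cache_buffer` above. The layer names and byte sizes are invented, and CPU tensors stand in for device memory.

```python
from typing import Dict

import torch

from vllm.v1.kv_cache_interface import KVCacheNewTensor, KVCacheReuseTensor

# Hypothetical tensor config: layer 2 shares the buffer allocated for layer 1.
tensors = {
    "model.layers.0.attn": KVCacheNewTensor(size=2 * 1024 * 1024),
    "model.layers.1.attn": KVCacheNewTensor(size=2 * 1024 * 1024),
    "model.layers.2.attn":
    KVCacheReuseTensor(reused_layer_name="model.layers.1.attn"),
}

raw: Dict[str, torch.Tensor] = {}
# First pass: allocate an untyped int8 buffer for every KVCacheNewTensor.
for name, cfg in tensors.items():
    if isinstance(cfg, KVCacheNewTensor):
        raw[name] = torch.zeros(cfg.size, dtype=torch.int8)
# Second pass: reusing layers simply alias an already-allocated buffer.
for name, cfg in tensors.items():
    if isinstance(cfg, KVCacheReuseTensor):
        raw[name] = raw[cfg.reused_layer_name]

assert (raw["model.layers.2.attn"].data_ptr() ==
        raw["model.layers.1.attn"].data_ptr())
```

The two passes guarantee that every `KVCacheReuseTensor` resolves to a buffer that has already been allocated, regardless of dict order; `_setup_kv_cache_shapes` then views each raw buffer with the layer's dtype and the attention backend's KV cache shape.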
@@ -1050,34 +1128,29 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ - if len(kv_cache_config.groups) > 1: - raise NotImplementedError( - "Hybrid models with more than one KV cache type are not " - "supported yet.") - - kv_caches: Dict[str, torch.Tensor] = {} - - for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): - tensor_config = kv_cache_config.tensors[layer_name] - assert tensor_config.size % layer_spec.page_size_bytes == 0 - num_blocks = tensor_config.size // layer_spec.page_size_bytes - if isinstance(layer_spec, FullAttentionSpec): - kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( - num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, - layer_spec.head_size) - dtype = layer_spec.dtype - kv_caches[layer_name] = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) - else: - raise NotImplementedError - + self.kv_cache_config = kv_cache_config + # Initialize the memory buffer for KV cache + kv_cache_raw_tensors = self._initialize_kv_cache_buffer( + kv_cache_config) + # Change the memory buffer to the desired shape + kv_caches = self._setup_kv_cache_shapes(kv_cache_config, + kv_cache_raw_tensors) bind_kv_cache( kv_caches, self.vllm_config.compilation_config.static_forward_context, self.kv_caches) - def get_kv_cache_spec(self) -> KVCacheSpec: + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.vllm_config.model_config.get_vocab_size(), + kv_cache_config=kv_cache_config, + ) + + def get_kv_cache_spec(self) -> Dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each Attention module in the static forward context. @@ -1088,18 +1161,29 @@ def get_kv_cache_spec(self) -> KVCacheSpec: forward_ctx = self.vllm_config.compilation_config.static_forward_context block_size = self.vllm_config.cache_config.block_size - kv_cache_spec: KVCacheSpec = {} + kv_cache_spec: Dict[str, KVCacheSpec] = {} for layer_name, attn_module in forward_ctx.items(): # TODO: Support other attention modules, e.g., sliding window, # cross-attention, MLA. assert isinstance(attn_module, Attention) if attn_module.attn_type == AttentionType.DECODER: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=attn_module.dtype, - ) + if attn_module.sliding_window is not None: + kv_cache_spec[layer_name] = SlidingWindowSpec( + block_size=block_size, + num_heads=attn_module.num_heads, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype, + sliding_window=attn_module.sliding_window, + ) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_heads=attn_module.num_heads, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype, + ) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. 
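As a sanity check on the two spec classes that `get_kv_cache_spec` now emits, the sketch below uses made-up attention dimensions (32 query heads, 8 KV heads, head size 128, fp16, block size 16, window 4096) to show the page-size arithmetic and why a sliding-window layer's memory need is capped at `sliding_window` tokens.

```python
import torch

from vllm.v1.kv_cache_interface import FullAttentionSpec, SlidingWindowSpec

full = FullAttentionSpec(block_size=16, num_heads=32, num_kv_heads=8,
                         head_size=128, dtype=torch.float16)
sw = SlidingWindowSpec(block_size=16, num_heads=32, num_kv_heads=8,
                       head_size=128, dtype=torch.float16,
                       sliding_window=4096)

# Both layouts store K and V per block:
#   2 * block_size * num_kv_heads * head_size * dtype_size bytes per page.
assert full.page_size_bytes == sw.page_size_bytes == 2 * 16 * 8 * 128 * 2

# A sliding-window layer never needs more than `sliding_window` tokens of KV,
# so its per-request usage is capped while full attention keeps growing.
assert sw.bytes_for_tokens(100_000) == sw.bytes_for_tokens(4096)
assert full.bytes_for_tokens(100_000) > full.bytes_for_tokens(4096)
```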
@@ -1109,5 +1193,4 @@ def get_kv_cache_spec(self) -> KVCacheSpec: else: raise ValueError( f"Unknown attention type: {attn_module.attn_type}") - return kv_cache_spec diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 0adb69073397c..7f4407cd3723f 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -2,7 +2,7 @@ """A GPU worker class.""" import gc import os -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, Optional import torch import torch.distributed @@ -192,7 +192,7 @@ def determine_available_memory(self) -> int: return int(available_kv_cache_memory) - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self) -> Dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
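Putting the pieces together: the worker now reports a `Dict[str, KVCacheSpec]`, and later receives a `KVCacheConfig` whose `groups` partition the layers by spec. Below is a hand-written config for a hypothetical 4-layer hybrid model (2 full-attention and 2 sliding-window layers); the layer names, sizes, and dimensions are illustrative and not produced by the real allocator.

```python
import torch

from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroup, KVCacheNewTensor,
                                        SlidingWindowSpec)

full_spec = FullAttentionSpec(block_size=16, num_heads=32, num_kv_heads=8,
                              head_size=128, dtype=torch.float16)
sw_spec = SlidingWindowSpec(block_size=16, num_heads=32, num_kv_heads=8,
                            head_size=128, dtype=torch.float16,
                            sliding_window=1024)

kv_cache_config = KVCacheConfig(
    num_blocks=1024,
    # Every layer gets its own buffer here; the real allocator may instead
    # emit KVCacheReuseTensor entries so that grouped layers share memory.
    tensors={
        f"model.layers.{i}.attn":
        KVCacheNewTensor(size=1024 * full_spec.page_size_bytes)
        for i in range(4)
    },
    # One group per distinct spec: full-attention layers together,
    # sliding-window layers together.
    groups=[
        KVCacheGroup(
            layer_names=["model.layers.0.attn", "model.layers.2.attn"],
            kv_cache_spec=full_spec),
        KVCacheGroup(
            layer_names=["model.layers.1.attn", "model.layers.3.attn"],
            kv_cache_spec=sw_spec),
    ],
)
```

Given such a config, `initialize_kv_cache` first allocates the raw per-layer buffers, then reshapes them per group, and only then constructs the `InputBatch`, which is why `input_batch` moves out of the model runner's `__init__` in this diff.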