[V1][Prefix Cache] Move the logic of num_computed_tokens into KVCacheManager #12003

Merged (4 commits) on Jan 15, 2025
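
The change in this PR: `KVCacheManager.get_computed_blocks()` now returns the number of computed tokens together with the computed blocks, so the scheduler no longer derives that count itself. A minimal, self-contained sketch of the old and new call patterns (the BLOCK_SIZE constant and the two stub functions are illustrative stand-ins, not vLLM code):

from typing import List, Tuple

BLOCK_SIZE = 16  # assumed block size, matching the tests in the diff below


def get_computed_blocks_old(cached_block_ids: List[int]) -> List[int]:
    # Pre-PR contract: only the computed blocks are returned; the caller
    # (the scheduler) multiplied len(blocks) by the block size itself.
    return cached_block_ids


def get_computed_blocks_new(cached_block_ids: List[int]) -> Tuple[List[int], int]:
    # Post-PR contract: the manager also reports the number of computed tokens,
    # which is always a multiple of the block size because only full blocks are
    # eligible for sharing.
    return cached_block_ids, len(cached_block_ids) * BLOCK_SIZE


if __name__ == "__main__":
    computed_blocks, num_computed_tokens = get_computed_blocks_new([0, 1, 2])
    assert num_computed_tokens == 3 * BLOCK_SIZE  # mirrors the updated tests
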
Changes from all commits
71 changes: 47 additions & 24 deletions tests/v1/core/test_prefix_caching.py
@@ -49,9 +49,10 @@ def test_prefill():
unique_token_ids = [3] * 7
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
- computed_blocks = manager.get_computed_blocks(req0)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(req0.kv_block_hashes) == 3
assert not computed_blocks
+ assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]

@@ -73,9 +74,10 @@ def test_prefill():
# Incomplete 1 block (5 tokens)
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
- computed_blocks = manager.get_computed_blocks(req1)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(req1.kv_block_hashes) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
+ assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5, 6]
@@ -91,7 +93,7 @@ def test_prefill():
# All blocks should be available.
assert manager.free_block_queue.num_free_blocks == 10
# The order should be
- # [unallocated (7, 8)]
+ # [unallocated (7, 8, 9)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
@@ -103,9 +105,10 @@ def test_prefill():
# Incomplete 1 block (6 tokens)
unique_token_ids = [3] * 6
req2 = make_request("2", common_token_ids + unique_token_ids)
- computed_blocks = manager.get_computed_blocks(req2)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(req2.kv_block_hashes) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
+ assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [7, 8]
@@ -123,8 +126,9 @@ def test_prefill():

# Cache miss and eviction.
req3 = make_request("3", [99] * (16 * 9))
- computed_blocks = manager.get_computed_blocks(req3)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks
+ assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
# This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
@@ -150,8 +154,9 @@ def test_decode():
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids)
- computed_blocks = manager.get_computed_blocks(req0)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks
+ assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]

@@ -197,16 +202,18 @@ def test_evict():

last_token_id = 5 * 16 + 7
req0 = make_request("0", list(range(last_token_id)))
- computed_blocks = manager.get_computed_blocks(req0)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks
+ assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated

# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
last_token_id + 3 * 16)))
- computed_blocks = manager.get_computed_blocks(req1)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks
+ assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
assert len(blocks) == 3 # 3 full blocks
last_token_id += 3 * 16
@@ -222,8 +229,9 @@ def test_evict():

# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
- computed_blocks = manager.get_computed_blocks(req2)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_blocks] == [0, 1]
+ assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [6, 5]
assert manager.free_block_queue.num_free_blocks == 6
@@ -247,8 +255,9 @@ def test_hash_block_correct_reuse():
# Allocate 1 block and cache it.
num_tokens = block_size * 1
req = make_request("0", list(range(num_tokens)))
- computed_blocks = manager.get_computed_blocks(req)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks
+ assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
assert len(blocks) == 1

@@ -258,8 +267,9 @@ def test_hash_block_correct_reuse():
# Allocate a new block that's not full, make sure hash info on the
# block is cleared.
req = make_request("1", list(range(num_tokens - 1)))
- computed_blocks = manager.get_computed_blocks(req)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks
+ assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
assert len(blocks) == 1

@@ -284,16 +294,18 @@ def test_computed_blocks_not_evicted():
# Allocate a block and cache it.
num_tokens = block_size * 1
req0 = make_request("0", list(range(num_tokens)))
- computed_blocks = manager.get_computed_blocks(req0)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks
+ assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 0

# Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
- computed_blocks = manager.get_computed_blocks(req1)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks
+ assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 1
@@ -305,9 +317,10 @@ def test_computed_blocks_not_evicted():
# Now if we have a cache hit on the first block, we should evict the second
# cached block rather than the first one.
req2 = make_request("2", list(range(num_tokens * 2)))
- computed_blocks = manager.get_computed_blocks(req2)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks) == 1
assert computed_blocks[0].block_id == 0
+ assert num_computed_tokens == block_size

blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
computed_blocks)
@@ -331,8 +344,9 @@ def test_basic_prefix_caching_disabled():

req1 = make_request("1", list(range(10))) # 2 blocks and some more

- computed_blocks = manager.get_computed_blocks(req1)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks
+ assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 10, computed_blocks)
assert len(blocks) == 3

@@ -341,15 +355,17 @@ def test_basic_prefix_caching_disabled():

# No caching.
req2 = make_request("2", list(range(16))) # shared prefix
- computed_blocks = manager.get_computed_blocks(req2)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks
+ assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 16, computed_blocks)
assert len(blocks) == 4

# New requests should not have any blocks.
req3 = make_request("3", list(range(4)))
- computed_blocks = manager.get_computed_blocks(req3)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks
+ assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 4, computed_blocks)
assert not blocks

@@ -371,8 +387,9 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size)

req = make_request("0", list(range(block_size * 30)))
- computed_blocks = manager.get_computed_blocks(req)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks
+ assert num_computed_tokens == 0
# Just ask for 1 block.
blocks = manager.allocate_slots(req, block_size, computed_blocks)
req.num_computed_tokens = block_size
@@ -469,10 +486,11 @@ def test_mm_prefix_caching():
all_token_ids,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
- computed_blocks = manager.get_computed_blocks(req0)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)

# Completed block should have hashes with extra keys.
assert not computed_blocks
+ assert num_computed_tokens == 0
assert len(req0.kv_block_hashes) == 3
assert req0.kv_block_hashes[0].extra_keys == ("aaa", )
assert req0.kv_block_hashes[1].extra_keys == ("aaa", "bbb")
@@ -503,8 +521,9 @@ def test_mm_prefix_caching():
all_token_ids,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
- computed_blocks = manager.get_computed_blocks(req1)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(computed_blocks) == 3
+ assert num_computed_tokens == 3 * 16


def test_prefill_not_enough_free_blocks_with_computed_blocks():
@@ -527,15 +546,17 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | ... |
common_token_ids = [i for i in range(3) for _ in range(16)]
req0 = make_request("0", common_token_ids)
- computed_blocks = manager.get_computed_blocks(req0)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks
+ assert num_computed_tokens == 0
manager.allocate_slots(req0, 48, computed_blocks)
block_part0 = manager.req_to_blocks[req0.request_id]

# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1 = make_request("1", common_token_ids * 2)
- computed_blocks = manager.get_computed_blocks(req1)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert computed_blocks == block_part0
+ assert num_computed_tokens == 3 * 16
manager.allocate_slots(req1, 48, computed_blocks)
block_part1 = manager.req_to_blocks[req1.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
@@ -547,17 +568,19 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2 = make_request("2", [7] * block_size * 2)
- computed_blocks = manager.get_computed_blocks(req2)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks
+ assert num_computed_tokens == 0
manager.allocate_slots(req2, block_size * 2, computed_blocks)

# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# but it cannot be allocated due to insufficient free blocks (2).
# In this case, the ref_cnt of the computed blocks should not be changed.
assert manager.free_block_queue.num_free_blocks == 5
req3 = make_request("3", common_token_ids * 3)
- computed_blocks = manager.get_computed_blocks(req3)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert computed_blocks == block_part1
+ assert num_computed_tokens == 6 * 16
# Req3 cannot be allocated.
assert manager.allocate_slots(req3, 48, computed_blocks) is None
# Block 0-2 are used by Req 1.
17 changes: 12 additions & 5 deletions vllm/v1/core/kv_cache_manager.py
@@ -1,5 +1,5 @@
from collections import defaultdict
- from typing import Dict, Iterable, List, Optional
+ from typing import Dict, Iterable, List, Optional, Tuple

from vllm.logger import init_logger
from vllm.utils import cdiv
@@ -69,19 +69,22 @@ def __init__(
# is finished.
self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {}

- def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
+ def get_computed_blocks(
+         self, request: Request) -> Tuple[List[KVCacheBlock], int]:
Comment on lines +72 to +73 (Collaborator): Please update the docstring about the return type.

"""Get the computed (cached) blocks for the request.
Note that the computed blocks must be full.

Args:
request: The request to get the computed blocks.

Returns:
- A list of blocks that are computed for the request.
+ A tuple containing:
+     - A list of blocks that are computed for the request.
+     - The number of computed tokens.
"""
if not self.enable_caching:
# Prefix caching is disabled.
- return []
+ return [], 0

computed_blocks = []

@@ -101,7 +104,11 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
else:
break

- return computed_blocks
+ # NOTE(woosuk): Since incomplete blocks are not eligible for
+ # sharing, `num_computed_tokens` is always a multiple of
+ # `block_size`.
+ num_computed_tokens = len(computed_blocks) * self.block_size
+ return computed_blocks, num_computed_tokens

def append_slots(
self,
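
For context, here is the invariant the NOTE above describes in standalone form: computed blocks are matched one full block at a time, so the returned token count is always `len(computed_blocks) * block_size`. The ToyKVCacheManager below is a hypothetical illustration (it only hashes full blocks into a dict), not the class in this file:

from typing import Dict, List, Tuple


class ToyKVCacheManager:
    """Illustrative stand-in; the real manager hashes blocks and tracks ref counts."""

    def __init__(self, block_size: int) -> None:
        self.block_size = block_size
        # Map from a full block of token ids to a cached block id.
        self.cached_blocks: Dict[Tuple[int, ...], int] = {}

    def get_computed_blocks(self, token_ids: List[int]) -> Tuple[List[int], int]:
        computed: List[int] = []
        # Only full blocks can be reused, so walk the prompt block by block
        # and stop at the first miss; a trailing partial block never counts.
        for start in range(0, len(token_ids) - self.block_size + 1, self.block_size):
            key = tuple(token_ids[start:start + self.block_size])
            block_id = self.cached_blocks.get(key)
            if block_id is None:
                break
            computed.append(block_id)
        # Hence num_computed_tokens is always a multiple of block_size.
        return computed, len(computed) * self.block_size

With a 16-token block size and three cached prefix blocks, this returns three block ids and a count of 48, which is exactly what the updated tests above assert (3 * 16).
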
8 changes: 2 additions & 6 deletions vllm/v1/core/scheduler.py
@@ -183,12 +183,8 @@ def schedule(self) -> "SchedulerOutput":

request = self.waiting[0]
# Get already-cached tokens.
- computed_blocks = self.kv_cache_manager.get_computed_blocks(
-     request)
- # NOTE(woosuk): Since incomplete blocks are not eligible for
- # sharing, `num_computed_tokens` is always a multiple of
- # `block_size`.
- num_computed_tokens = len(computed_blocks) * self.block_size
+ computed_blocks, num_computed_tokens = \
+     self.kv_cache_manager.get_computed_blocks(request)
# Number of tokens to be scheduled.
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed requests,
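
On the consuming side, the scheduler now just unpacks the tuple and uses the count to decide how many prompt tokens still need to be scheduled. A brief hypothetical sketch (a stand-in for part of this step, not the real Scheduler.schedule()):

from typing import Callable, List, Tuple


def num_new_tokens_to_schedule(
    num_prompt_tokens: int,
    get_computed_blocks: Callable[[], Tuple[List[int], int]],
) -> int:
    # The manual `len(computed_blocks) * block_size` computation removed in
    # this diff is no longer needed here; the manager reports the count.
    computed_blocks, num_computed_tokens = get_computed_blocks()
    # computed_blocks would be passed on to the allocation step; only the
    # count matters for this calculation.
    return num_prompt_tokens - num_computed_tokens


if __name__ == "__main__":
    # A 53-token prompt with three cached 16-token blocks, as in the tests above.
    assert num_new_tokens_to_schedule(53, lambda: ([0, 1, 2], 3 * 16)) == 5
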