[WIP] Hybrid allocator for full attention & sliding window attention interleaved models #12655

Draft: heheda12345 wants to merge 55 commits into base: main

Commits (55)
0f8a54c
can run
heheda12345 Jan 15, 2025
990d086
fix tests
heheda12345 Jan 15, 2025
e46fff5
format
heheda12345 Jan 15, 2025
36a649a
fix bug
heheda12345 Jan 15, 2025
9c36e7d
add comments
heheda12345 Jan 15, 2025
da6b549
format
heheda12345 Jan 15, 2025
4030199
Merge branch 'main' of github.com:vllm-project/vllm into grouped_bloc…
heheda12345 Jan 17, 2025
2d8213e
Merge branch 'main' of github.com:vllm-project/vllm into grouped_bloc…
heheda12345 Jan 17, 2025
41bc571
update code
heheda12345 Jan 20, 2025
a939b6d
Merge branch 'main' of github.com:vllm-project/vllm into grouped_bloc…
heheda12345 Jan 20, 2025
34c9d74
can run
heheda12345 Jan 20, 2025
cfcf2b4
update comments
heheda12345 Jan 20, 2025
4898973
init kv cache for group allocation
heheda12345 Jan 21, 2025
ef9dc9d
can run, result a little strange
heheda12345 Jan 21, 2025
5b71ccd
fix small bug
heheda12345 Jan 22, 2025
99de9f8
Merge branch 'main' of github.com:vllm-project/vllm into grouped_bloc…
heheda12345 Jan 22, 2025
6a0eb69
cleanup SpecializedManager
heheda12345 Jan 29, 2025
14ad04e
add test and fix bug for sliding window manager
heheda12345 Jan 29, 2025
eb34a44
remove useless code
heheda12345 Jan 29, 2025
f53e824
fix several bugs
heheda12345 Jan 31, 2025
0ecf3fa
update sliding window test
heheda12345 Jan 31, 2025
3998e92
small fix
heheda12345 Jan 31, 2025
446e99d
small fix, can run gemma2
heheda12345 Jan 31, 2025
d97c1b0
add test for range_intersect
heheda12345 Jan 31, 2025
5ebfeac
clean up get_computed_blocks, append_slots, allocate_slots
heheda12345 Jan 31, 2025
4e0dc48
finish the clean up of kv cache manager
heheda12345 Jan 31, 2025
cd4f8e2
clean up the code
heheda12345 Jan 31, 2025
2d7bbca
fix some tests
heheda12345 Feb 1, 2025
68fe2db
remove print kvcacheconfig
heheda12345 Feb 1, 2025
30e9837
move files
heheda12345 Feb 1, 2025
05d8b0d
Merge branch 'main' of github.com:vllm-project/vllm into grouped_bloc…
heheda12345 Feb 2, 2025
e6016e5
add docstrings
heheda12345 Feb 2, 2025
b369fa2
fix pre-commit
heheda12345 Feb 2, 2025
beb5d08
Merge branch 'main' of github.com:vllm-project/vllm into grouped_bloc…
heheda12345 Feb 4, 2025
ea65e60
add request.py
heheda12345 Feb 4, 2025
f4d1c93
Merge branch 'main' of github.com:vllm-project/vllm into grouped_bloc…
heheda12345 Feb 4, 2025
42c391d
remove small comment
heheda12345 Feb 4, 2025
f6d2bfd
avoid loop in block table
heheda12345 Feb 5, 2025
b614b42
clean up attn_metadata
heheda12345 Feb 6, 2025
ca91b30
BlockIDList
heheda12345 Feb 6, 2025
0475e9f
forward metadata
heheda12345 Feb 6, 2025
bcfc994
cleanup
heheda12345 Feb 6, 2025
bcab7af
fix pre-commit
heheda12345 Feb 6, 2025
8fa6f8f
Merge remote-tracking branch 'heheda/grouped_block_table' into hybrid…
heheda12345 Feb 7, 2025
0a9701e
fix
heheda12345 Feb 7, 2025
a7173a2
fix
heheda12345 Feb 7, 2025
5e2d3bd
cherry-pick: [V1] Move KV block hashes from Request to KVCacheManager…
WoosukKwon Feb 8, 2025
09782a2
small fix
heheda12345 Feb 8, 2025
1e44abd
add manager group, can run, result strange
heheda12345 Feb 8, 2025
bf302c5
fix some bug
heheda12345 Feb 9, 2025
be926dd
move out block pool
heheda12345 Feb 9, 2025
4278b01
unify get_possible_cached_prefix
heheda12345 Feb 9, 2025
9853e25
unify get_req_num_new_blocks
heheda12345 Feb 9, 2025
33dfa2b
dynamic group
heheda12345 Feb 9, 2025
57cb55b
small fix
heheda12345 Feb 9, 2025
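The test changes below show that every block hash now carries a kv_cache_group_id, so a block cached for one KV cache group can never satisfy a lookup from another group even when the token ids match. A minimal sketch of that keying scheme, with the field order inferred from the updated tests; the real BlockHashType and hash_block_tokens in vllm.v1.core.kv_cache_utils may differ:

from typing import Any, NamedTuple, Optional, Tuple

class BlockHashType(NamedTuple):
    hash_value: int
    kv_cache_group_id: int
    token_ids: Tuple[int, ...]
    extra_keys: Optional[Any] = None

def hash_block_tokens(parent_block_hash: Optional[int],
                      curr_block_token_ids: Tuple[int, ...],
                      kv_cache_group_id: int,
                      extra_keys: Optional[Any] = None) -> BlockHashType:
    # Folding the group id into the hash keeps a sliding-window group from
    # ever matching a full-attention block holding the same token ids.
    return BlockHashType(
        hash((parent_block_hash, curr_block_token_ids, kv_cache_group_id,
              extra_keys)),
        kv_cache_group_id, curr_block_token_ids, extra_keys)

# Same tokens, different groups: the cache keys differ.
assert hash_block_tokens(None, (1, 2, 3), 0) != hash_block_tokens(None, (1, 2, 3), 1)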
Files changed (3)
examples/offline_inference/basic.py (4 changes: 2 additions & 2 deletions)

@@ -13,12 +13,12 @@
 sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
 # Create an LLM.
-llm = LLM(model="facebook/opt-125m")
+llm = LLM(model="google/gemma-2-2b-it")
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
 # Print the outputs.
 for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
\ No newline at end of file
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
tests/core/block/e2e/test_correctness_sliding_window.py (13 changes: 8 additions & 5 deletions)

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import random
-from typing import List
+from typing import List, Tuple
 
 import pytest
 
@@ -120,7 +120,7 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
     check_answers(indices, answer, test_texts)
 
 
-def prep_prompts(batch_size: int):
+def prep_prompts(batch_size: int, assign_range: Tuple[int, int] = (800, 1100)):
     """
     Generate prompts which a bunch of assignments,
     then asking for the value of one of them.
@@ -136,7 +136,7 @@ def prep_prompts(batch_size: int):
         indices.append(idx)
         prompt = "```python\n# We set a number of variables, " + \
             f"x{idx} will be important later\n"
-        ln = random.randint(800, 1100)
+        ln = random.randint(*assign_range)
         for k in range(30, ln):
             v = random.randint(10, 99)
             if k == idx:
@@ -148,7 +148,10 @@ def prep_prompts(batch_size: int):
     return prompts, answer, indices
 
 
-def check_answers(indices: List[int], answer: List[int], outputs: List[str]):
+def check_answers(indices: List[int],
+                  answer: List[int],
+                  outputs: List[str],
+                  accept_rate=0.7):
     answer2 = [int(text[0:2].strip()) for text in outputs]
     print(list(zip(indices, zip(answer, answer2))))
     numok = 0
@@ -157,7 +160,7 @@ def check_answers(indices: List[int], answer: List[int], outputs: List[str]):
             numok += 1
     frac_ok = numok / len(answer)
     print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
-    assert frac_ok > 0.7
+    assert frac_ok >= accept_rate
 
 
 def check_window(prompts: List[str]):
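The widened helper signatures above presumably let hybrid-allocator tests reuse them with different prompt lengths and tolerances. A hypothetical call, not taken from this diff, assuming the repo root is on sys.path; the fake outputs just exercise the two-digit format that check_answers parses from text[0:2]:

from tests.core.block.e2e.test_correctness_sliding_window import (
    check_answers, prep_prompts)

prompts, answer, indices = prep_prompts(batch_size=4, assign_range=(300, 500))
# Stand-in for real model outputs: every answer is "correct", so the
# observed fraction is 1.0 and the accept_rate=0.8 check passes.
outputs = [f"{v} is the value" for v in answer]
check_answers(indices, answer, outputs, accept_rate=0.8)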
tests/v1/core/test_kv_cache_utils.py (45 changes: 36 additions & 9 deletions)

@@ -5,10 +5,10 @@
 from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
-                                         KVCacheBlock,
+                                         KVCacheBlock, PrefixLengthRange,
                                          generate_block_hash_extra_keys,
                                          hash_block_tokens,
-                                         hash_request_tokens)
+                                         hash_request_tokens, intersect_ranges)
 from vllm.v1.request import Request
 
 
@@ -49,7 +49,9 @@ def test_kv_cache_block():
     assert block.ref_cnt == 0
 
     # Test block hash setting and resetting
-    block_hash = BlockHashType(hash_value=123, token_ids=(1, 2, 3))
+    block_hash = BlockHashType(hash_value=123,
+                               kv_cache_group_id=0,
+                               token_ids=(1, 2, 3))
     block.block_hash = block_hash
     assert block.block_hash == block_hash
 
@@ -190,11 +192,11 @@ def test_hash_block_tokens():
     curr_block_token_ids = (1, 2, 3)
     extra_keys = ("key1", "key2")
 
-    block_hash = hash_block_tokens(parent_block_hash, curr_block_token_ids,
+    block_hash = hash_block_tokens(parent_block_hash, curr_block_token_ids, 0,
                                    extra_keys)
     assert isinstance(block_hash, BlockHashType)
     assert block_hash.hash_value == hash(
-        (parent_block_hash, curr_block_token_ids, extra_keys))
+        (parent_block_hash, curr_block_token_ids, 0, extra_keys))
     assert block_hash.token_ids == curr_block_token_ids
     assert block_hash.extra_keys == extra_keys
 
@@ -214,7 +216,7 @@ def test_hash_request_tokens():
     )
 
     block_size = 3
-    block_hashes = hash_request_tokens(block_size, request)
+    block_hashes = hash_request_tokens(block_size, request, 0)
 
     assert len(block_hashes) == 2
     assert isinstance(block_hashes[0], BlockHashType)
@@ -255,8 +257,8 @@ def test_hash_tokens_different_mm_input():
         mm_hashes=["hash3", "hash2"],
     )
     block_size = 3
-    block_hashes1 = hash_request_tokens(block_size, request1)
-    block_hashes2 = hash_request_tokens(block_size, request2)
+    block_hashes1 = hash_request_tokens(block_size, request1, 0)
+    block_hashes2 = hash_request_tokens(block_size, request2, 0)
     assert block_hashes1[0] != block_hashes2[0]
     assert block_hashes1[1] != block_hashes2[1]
 
@@ -270,10 +272,35 @@ def test_hash_request_tokens_no_mm_inputs():
     )
 
     block_size = 3
-    block_hashes = hash_request_tokens(block_size, request)
+    block_hashes = hash_request_tokens(block_size, request, 0)
 
     assert len(block_hashes) == 2
     assert block_hashes[0].token_ids == (0, 1, 2)
     assert block_hashes[0].extra_keys is None
    assert block_hashes[1].token_ids == (3, 4, 5)
     assert block_hashes[1].extra_keys is None
+
+
+def test_prefix_length_range_intersection():
+    range0 = [
+        PrefixLengthRange(1, 5),
+        PrefixLengthRange(10, 14),
+        PrefixLengthRange(16, 18)
+    ]
+    range1 = [
+        PrefixLengthRange(2, 6),
+        PrefixLengthRange(8, 12),
+        PrefixLengthRange(15, 17)
+    ]
+    range2 = [PrefixLengthRange(3, 11), PrefixLengthRange(13, 19)]
+    ranges = [range0, range1, range2]
+
+    intersection = intersect_ranges(ranges)
+    assert intersection == [
+        PrefixLengthRange(3, 5),
+        PrefixLengthRange(10, 11),
+        PrefixLengthRange(16, 17)
+    ]
+
+
+# TODO: add tests for hash of kv_cache_group_id
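test_prefix_length_range_intersection pins down the semantics of intersect_ranges: each argument is a sorted list of disjoint inclusive ranges, and the result is the set of prefix lengths admitted by every list. A minimal sketch consistent with the test, assuming PrefixLengthRange is an inclusive start/end pair; the actual implementation in vllm.v1.core.kv_cache_utils may differ:

from dataclasses import dataclass
from typing import List

@dataclass(frozen=True)
class PrefixLengthRange:
    start: int  # inclusive
    end: int    # inclusive

def _intersect_two(a: List[PrefixLengthRange],
                   b: List[PrefixLengthRange]) -> List[PrefixLengthRange]:
    # Two-pointer sweep; assumes each list is sorted and non-overlapping.
    out, i, j = [], 0, 0
    while i < len(a) and j < len(b):
        lo = max(a[i].start, b[j].start)
        hi = min(a[i].end, b[j].end)
        if lo <= hi:
            out.append(PrefixLengthRange(lo, hi))
        # Advance whichever range ends first.
        if a[i].end < b[j].end:
            i += 1
        else:
            j += 1
    return out

def intersect_ranges(
        ranges: List[List[PrefixLengthRange]]) -> List[PrefixLengthRange]:
    result = ranges[0]
    for r in ranges[1:]:
        result = _intersect_two(result, r)
    return result

# Reproduces the expected result from the test above:
assert intersect_ranges([
    [PrefixLengthRange(1, 5), PrefixLengthRange(10, 14), PrefixLengthRange(16, 18)],
    [PrefixLengthRange(2, 6), PrefixLengthRange(8, 12), PrefixLengthRange(15, 17)],
    [PrefixLengthRange(3, 11), PrefixLengthRange(13, 19)],
]) == [PrefixLengthRange(3, 5), PrefixLengthRange(10, 11), PrefixLengthRange(16, 17)]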