[Core] Sliding window for block manager v2 #4545

Merged May 28, 2024 (58 commits; changes shown from 44 commits)

Commits
6b1d3b2
merge
Apr 10, 2024
d55db5e
enable sliding window (doesn't work)
mmoskal Apr 30, 2024
be53001
add null_block
mmoskal May 2, 2024
ac070a1
ruff, ruff
mmoskal May 2, 2024
b37f028
yapf
mmoskal May 2, 2024
368c3ee
add AttentionBackend.zero_block()
mmoskal May 2, 2024
c4e533d
zero-out null_block
mmoskal May 2, 2024
7bc88f2
comment out debug assertions
mmoskal May 2, 2024
2825cd7
merge main
mmoskal May 2, 2024
0ade169
fix mypy
mmoskal May 2, 2024
f703af2
basic correctness test for sliding window
mmoskal May 3, 2024
9661776
allocate block progressively for chunked prefill
mmoskal May 3, 2024
a0459b4
Revert "allocate block progressively for chunked prefill"
mmoskal May 3, 2024
6326fbc
fix sliding_window+chunked_prefill
mmoskal May 3, 2024
785aa19
add test for chunked prefill + sliding window
mmoskal May 3, 2024
165b7a8
merge main
mmoskal May 3, 2024
e26631e
spelling + formatting
mmoskal May 3, 2024
57678f4
remove junk
mmoskal May 3, 2024
22e9bb8
simplify test
mmoskal May 3, 2024
be984d1
testcase PR feedback
mmoskal May 8, 2024
34b0fef
Revert "add AttentionBackend.zero_block()"
mmoskal May 8, 2024
b0261fd
zero-out whole KV cache on alloc
mmoskal May 8, 2024
0a3d3b6
add comments
mmoskal May 8, 2024
c947103
require num_computed_slots for sliding window
mmoskal May 8, 2024
a28116f
formatting
mmoskal May 8, 2024
2a5436d
add NullBlock proxy class
mmoskal May 9, 2024
cc1467f
improve comments
mmoskal May 9, 2024
1edc9be
assert sliding window size
mmoskal May 9, 2024
85d8fd9
add docstring
mmoskal May 9, 2024
29fd0d5
add comment
mmoskal May 9, 2024
09c192a
format
mmoskal May 9, 2024
54a5e93
bump test size
mmoskal May 9, 2024
609a9ce
start on sliding window support in paged attn decode kernel
mmoskal May 9, 2024
5079d9a
Revert "start on sliding window support in paged attn decode kernel"
mmoskal May 9, 2024
83f82d9
construct correct block tables in sliding window decode phase
mmoskal May 10, 2024
0bb1f67
remove debug out
mmoskal May 10, 2024
7e7bd02
bump test len again
mmoskal May 10, 2024
728b722
ruff
mmoskal May 10, 2024
96b71a0
add sliding window v2 test
mmoskal May 10, 2024
0ac37c0
Merge branch 'main' into sliding_window_v2
mmoskal May 11, 2024
e086854
fix possible issue with kernel
mmoskal May 11, 2024
bdad3bb
make test pass
mmoskal May 11, 2024
d9521ba
update comments
mmoskal May 11, 2024
24844dd
ruff
mmoskal May 11, 2024
f4ede62
rename: block_sliding_window => max_block_sliding_window
mmoskal May 23, 2024
63d9e50
merge main
mmoskal May 23, 2024
7f00204
finish merge
mmoskal May 24, 2024
fa1ea2f
fix conftest import
mmoskal May 24, 2024
97aa968
allow v2+chunked_prefill+sliding_window
mmoskal May 24, 2024
ce04cde
restore assert
mmoskal May 24, 2024
0a5a73d
formatting
mmoskal May 24, 2024
60f0d25
fix sliding window computation
mmoskal May 25, 2024
2855de6
formatting
mmoskal May 25, 2024
14f33b1
force rebuild
mmoskal May 25, 2024
d58aefd
Merge branch 'main' into sliding_window_v2
mmoskal May 25, 2024
507d2dc
prefix caching fix
mmoskal May 25, 2024
cc6cab4
add assertion message
mmoskal May 25, 2024
dffca79
re-run build
mmoskal May 26, 2024
26 changes: 26 additions & 0 deletions tests/core/block/e2e/conftest.py
@@ -1,3 +1,5 @@
from typing import Callable, Iterable, Optional

import pytest

from tests.conftest import cleanup
@@ -39,3 +41,27 @@ def generator_inner():
for llm in generator_inner():
yield llm
del llm


def get_text_from_llm_generator(llm_generator: Iterable[LLM],
prompts,
sampling_params,
llm_cb: Optional[Callable[[LLM],
None]] = None):
for llm in llm_generator:
if llm_cb:
llm_cb(llm)
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
text = [output.outputs[0].text for output in outputs]
del llm

return text


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
for llm in llm_generator:
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
del llm

return token_ids
10 changes: 1 addition & 9 deletions tests/core/block/e2e/test_correctness.py
@@ -1,6 +1,7 @@
from itertools import cycle

import pytest
from conftest import get_token_ids_from_llm_generator

from vllm import SamplingParams

@@ -444,12 +445,3 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
assert expected_token_ids == actual_token_ids

assert baseline_token_ids == test_token_ids


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
for llm in llm_generator:
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
del llm

return token_ids
167 changes: 167 additions & 0 deletions tests/core/block/e2e/test_correctness_sliding_window.py
@@ -0,0 +1,167 @@
import random
from typing import List

import pytest
from conftest import get_text_from_llm_generator

from vllm import LLM, SamplingParams

# relatively small model with 4k sliding window
MODEL = "bigcode/starcoder2-3b"
BLOCK_SIZE = 16


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": MODEL,

# skip cuda graph creation for fast test.
"enforce_eager": True,
"block_size": BLOCK_SIZE,
# needed due to https://github.com/vllm-project/vllm/issues/1908#issuecomment-2101122008
"num_gpu_blocks_override": 100000 // BLOCK_SIZE,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
batch_size, seed):
"""
The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
asks for the value of one of them (which is outside the sliding window).
If we tell it upfront which one we are going to be looking for, then
it answers correctly (mostly).

Additionally, we compare the results of the v1 and v2 managers.
"""
sampling_params = SamplingParams(
max_tokens=1024,
ignore_eos=True,
temperature=0.0,
)

prompts, answer, indices = prep_prompts(batch_size)

print('Getting token ids from block manager v1')
baseline_texts = get_text_from_llm_generator(baseline_llm_generator,
prompts,
sampling_params,
llm_cb=check_window(prompts))

check_answers(indices, answer, baseline_texts)

print('Getting token ids from block manager v2')
test_texts = get_text_from_llm_generator(test_llm_generator, prompts,
sampling_params)
check_answers(indices, answer, test_texts)

cmp = [
expected_text == actual_text
for expected_text, actual_text in zip(baseline_texts, test_texts)
]
print(cmp)
# make sure it's mostly OK; this is possibly because https://github.com/vllm-project/vllm/pull/4768
# however, https://github.com/vllm-project/vllm/issues/3385#issuecomment-1995924290
# states that xformers and flash_attn have different ideas about the window
# size anyways
assert sum(cmp) > 0.7 * len(cmp)
Collaborator:
Hmm, but this compares block manager v1 vs. v2? Doesn't that mean everything else is equivalent (i.e., kernels and things like that)?

cc @cadedaniel do you think it is expected that the outputs are different?

Contributor Author:
Unfortunately, the kernel arguments are not the same. With the v2 manager we're passing up to 15 more KV entries to the paged attention kernel. With v1, the KV entries "wrap around" in the blocks, but they do not with v2 (which is needed for prefix caching etc.); see the description in #4768.
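
A minimal sketch of where that "up to 15" comes from (illustration only, not code from the PR): the v2 table keeps whole blocks, so when the exact window boundary falls mid-block, up to block_size - 1 pre-window tokens are still handed to the kernel; with BLOCK_SIZE = 16 that is at most 15.

# Illustration only (not part of the PR): extra KV entries the kernel sees when
# the sliding-window boundary is rounded down to a block boundary.
def extra_kv_entries(seq_len: int, sliding_window: int, block_size: int) -> int:
    window_start = max(seq_len - sliding_window, 0)  # first token inside the window
    first_kept_block_start = (window_start // block_size) * block_size
    return window_start - first_kept_block_start

assert extra_kv_entries(seq_len=10_005, sliding_window=4096, block_size=16) == 5
assert max(extra_kv_entries(n, 4096, 16) for n in range(5000, 5016)) == 15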

Collaborator (@cadedaniel, May 23, 2024):
Yeah, the reason this is so hard to test is that the semantics of blocks changed in V2. E.g., we now have a clear distinction between mutable and immutable blocks, so the kernels that previously would overwrite blocks (causing U.B. with copy-on-write in V1) now don't; the downside is that they capture additional context since we don't yet have masking.

The tests in this PR are not really good enough to catch correctness issues with the sliding window block mapping. The error tolerance in this test is very high, and the unit test test_sliding_window only checks that the number of consumed blocks is correct.

That said, this is an improvement over the previous sliding window tests and I think we can merge and follow up later..

FWIW the way I'd test this is one/both of the following:

  • Modify this test to use block_size=1. This avoids the masking issue entirely and we should expect exact equality between v1 and v2 for most prompts.
  • Add a stronger unit test for block_manager_v2 or block_allocator that verifies the correct sliding window block mapping
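
A rough sketch of the second suggestion (a hypothetical helper, not part of the PR): a unit test could compute which logical block indices should already have been swapped for the null block, using the same rule that BlockTable.append_token_ids applies later in this diff, and compare that against the allocator's state.

# Hypothetical helper for such a unit test; it mirrors the dropping rule
#   end_block_idx = num_computed_slots // block_size - block_sliding_window
# used by BlockTable.append_token_ids in this PR.
from typing import List


def expected_null_block_indices(num_computed_slots: int, block_size: int,
                                block_sliding_window: int) -> List[int]:
    end_block_idx = num_computed_slots // block_size - block_sliding_window
    return list(range(max(end_block_idx, 0)))


# With block_size=16 and a 4096-token window, BlockSpaceManagerV2 uses
# block_sliding_window = 4096 // 16 + 2 = 258, so after 5000 computed slots
# the first 312 - 258 = 54 blocks should already be the null block.
assert expected_null_block_indices(5000, 16, 258) == list(range(54))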



@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": MODEL,

# skip cuda graph creation for fast test.
"enforce_eager": True,
"block_size": BLOCK_SIZE,
"num_gpu_blocks_override": 100000 // BLOCK_SIZE,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"use_v2_block_manager": True,
"enable_chunked_prefill": True
}])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):
"""
This is similar to test_sliding_window_retrival, however, it doesn't
compare against the v1 block manager since v1 doesn't support
chunked prefill with sliding window.

The results with and without chunked prefill are not the same due to
numerical instabilities.
"""
sampling_params = SamplingParams(
max_tokens=10,
Collaborator:
I'm afraid these tests won't catch issues with the block mapping. E.g., I expect errors to accumulate over many tokens before we see a significant divergence in attention scores. 10/4096 tokens is not very much; the same goes for 128/4096, although it's better.

WDYT? Is my intuition right? Should we test with a larger generation size? Another option is to patch sliding_window to be smaller (e.g., two blocks) so that the impact of any error is larger. If we go with patching sliding_window we could even use one of the 68m models for a faster test.
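
As a back-of-the-envelope check of that intuition (my own sketch, reusing the dropping rule from the block_table.py change below): with a ~10k-token prompt, a 4096-token window and block_size=16, generating 10 tokens triggers at most one additional block drop, while generating 1024 tokens triggers about 64, so a longer generation exercises the sliding-window path far more.

# Sketch only: how many block-drop events happen during decode, assuming the
# rule end_block_idx = num_computed_slots // block_size - block_sliding_window.
def blocks_dropped_so_far(num_computed: int, block_size: int,
                          block_sliding_window: int) -> int:
    return max(num_computed // block_size - block_sliding_window, 0)


def drop_events_during_decode(prompt_len: int, gen_len: int, block_size: int,
                              block_sliding_window: int) -> int:
    return (blocks_dropped_so_far(prompt_len + gen_len, block_size,
                                  block_sliding_window) -
            blocks_dropped_so_far(prompt_len, block_size, block_sliding_window))


# ~10k-token prompt, 4096-token window -> block_sliding_window = 258, block_size=16
assert drop_events_during_decode(10_000, 10, 16, 258) <= 1
assert drop_events_during_decode(10_000, 1024, 16, 258) == 64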

Collaborator:
+1.

Contributor Author:
There was an issue with the block tables passed to single-token decode, which was causing different output at 1024 tokens. I have now fixed that and bumped the test size to 1024.

However, it's still slightly incorrect because the decode kernel does not support sliding window natively - as it works now, it simply takes all the blocks passed in (up to seq_len). With the v1 manager the sliding window uses blocks in a "ring buffer" fashion, so this is not a problem. With the new block manager we potentially need to start the attention computation in the middle of a block, otherwise we attend to a few tokens too many. It doesn't seem to affect this test, though.

I have started fixing the decode kernel, but I think that should be a separate PR.
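
To make the mid-block issue concrete, here is a small illustration (a sketch with toy numbers, not the PR's actual data structures) of what the decode-time block table looks like once the front of the sequence has slid out of the window: dropped entries become the shared null block, and the exact window boundary generally falls inside the first real block.

# Illustration only; the physical ids are made up and NULL stands for the
# shared null block.
NULL = "null"


def v2_decode_block_table(num_computed_slots, block_size, block_sliding_window,
                          physical_block_ids):
    # Same dropping rule as BlockTable.append_token_ids in this PR: blocks in
    # front of the window become the null block instead of being reused
    # ring-buffer style as with the v1 manager.
    end_block_idx = max(
        num_computed_slots // block_size - block_sliding_window, 0)
    return [NULL] * end_block_idx + list(physical_block_ids[end_block_idx:])


# Toy numbers: block_size=16, sliding_window=32 -> block_sliding_window = 4.
table = v2_decode_block_table(num_computed_slots=100, block_size=16,
                              block_sliding_window=4,
                              physical_block_ids=[10, 11, 12, 13, 14, 15, 16])
assert table == [NULL, NULL, 12, 13, 14, 15, 16]
# The exact window starts at token 100 - 32 = 68, i.e. 4 tokens into block 4,
# so an exact kernel would have to start mid-block rather than at a block
# boundary - the "few tokens too many" described above.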

Collaborator:
I haven't looked at the changes, but want to say that yes, this problem is known and we should fix it eventually (awesome if you want to do it). Let's get this PR in with good tests for where we're at, and a future PR can fix the decode kernel.

ignore_eos=True,
temperature=0.0,
)

prompts, answer, indices = prep_prompts(batch_size)

# We don't compare with the baseline model here, since the results
# are slightly different due to different tailing in attention.
test_texts = get_text_from_llm_generator(test_llm_generator,
prompts,
sampling_params,
llm_cb=check_window(prompts))
check_answers(indices, answer, test_texts)


def prep_prompts(batch_size: int):
"""
Generate prompts with a bunch of assignments,
then ask for the value of one of them.
The prompt is just under 10k tokens; sliding window is 4k
Collaborator:
Any good way to assert that the returned prompt token length is > 4k?

so the answer is outside sliding window, but should still be correct.
"""
prompts: List[str] = []
answer: List[int] = []
indices: List[int] = []
random.seed(1)
for _ in range(batch_size):
idx = random.randint(30, 90)
indices.append(idx)
prompt = "```python\n# We set a number of variables, " + \
f"x{idx} will be important later\n"
ln = random.randint(800, 1100)
for k in range(30, ln):
v = random.randint(10, 99)
if k == idx:
answer.append(v)
prompt += f"x{k} = {v}\n"
prompt += f"# Now, we check the value of x{idx}:\n"
prompt += f"assert x{idx} == "
prompts.append(prompt)
return prompts, answer, indices


def check_answers(indices: List[int], answer: List[int], outputs: List[str]):
answer2 = [int(text[0:2].strip()) for text in outputs]
print(list(zip(indices, zip(answer, answer2))))
numok = 0
for a1, a2 in zip(answer, answer2):
if a1 == a2:
numok += 1
frac_ok = numok / len(answer)
print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
assert frac_ok > 0.7
Collaborator:
Is this the highest we can get? Or is it more of an arbitrary value?

Contributor Author:
It's arbitrary; I've mostly been getting 4/5 right, sometimes 5/5.



def check_window(prompts: List[str]):

def inner(llm: LLM):
sliding_window = llm.llm_engine.model_config.get_sliding_window()
assert sliding_window and sliding_window > 0
assert any(
len(llm.get_tokenizer().tokenize(prompt)) > sliding_window
for prompt in prompts)

return inner
69 changes: 69 additions & 0 deletions tests/core/block/test_block_manager_v2.py
@@ -101,3 +101,72 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append,
range(prompt_len + num_slots_to_append + num_lookahead_slots)),
block_size)) - len(chunk_list(list(range(prompt_len)), block_size))
assert num_consumed_blocks == expected_consumed_blocks


@pytest.mark.parametrize("block_size", [8, 16])
@pytest.mark.parametrize("prompt_len", [10, 300, 1000])
@pytest.mark.parametrize("num_slots_to_append", [50])
@pytest.mark.parametrize("sliding_window", [20, 32, 200, 512])
def test_sliding_window(block_size, prompt_len, num_slots_to_append,
sliding_window):
"""Verify append_slots consumes the correct number of blocks from the block
table.
"""

num_gpu_blocks = 1024
watermark = 0.1
block_manager = BlockSpaceManagerV2(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=0,
watermark=watermark,
sliding_window=sliding_window,
)

def check_used(min_n, max_n=None):
if max_n is None:
max_n = min_n
used = num_gpu_blocks - block_manager.get_num_free_gpu_blocks()
#print("check", min_n, used, max_n)
assert min_n <= used
assert used <= max_n

def num_blocks(num_tokens):
return (num_tokens + block_size - 1) // block_size

check_used(0)

seq_group = create_seq_group(
seq_prompt_len=prompt_len,
seq_output_lens=[0],
)

check_used(0)

# Allocate seq
assert block_manager.can_allocate(seq_group)
block_manager.allocate(seq_group)

check_used(num_blocks(prompt_len))

# Set seq to RUNNING
seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING

seq.data.update_num_computed_tokens(prompt_len)
check_used(num_blocks(prompt_len))

# this is how we compute it in BlockSpaceManagerV2.__init__
sliding_blocks = (sliding_window // block_size) + 2
# plus one block for null block
sliding_blocks += 1

# Append tokens to the sequence
for token_id in range(num_slots_to_append):
seq.append_token_id(token_id, {token_id: Logprob(0.0)})
seq.data.update_num_computed_tokens(1)
block_manager.append_slots(seq, num_lookahead_slots=0)
if prompt_len < sliding_window + 10:
check_used(0, sliding_blocks + 1)
else:
check_used(sliding_blocks, sliding_blocks + 1)
6 changes: 5 additions & 1 deletion vllm/attention/ops/prefix_prefill.py
@@ -697,6 +697,10 @@ def context_attention_fwd(q,

grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,

# 0 means "disable"
if sliding_window is None or sliding_window <= 0:
sliding_window = 0

num_warps = 8 if Lk <= 64 else 8
if alibi_slopes is not None:
_fwd_kernel_alibi[grid](
@@ -794,7 +798,7 @@ def context_attention_fwd(q,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
SLIDING_WINDOW=sliding_window if sliding_window is not None else 0,
SLIDING_WINDOW=sliding_window,
num_warps=num_warps,
num_stages=1,
)
34 changes: 32 additions & 2 deletions vllm/core/block/block_table.py
@@ -20,6 +20,10 @@ class BlockTable:
_blocks (Optional[List[Block]], optional): An optional list of existing
blocks to initialize the BlockTable with. If not provided, an empty
BlockTable is created.
block_sliding_window (Optional[int], optional): The number of blocks to
keep around for each sequence. If None, all blocks are kept
(e.g., when sliding window is not used).
It should at least fit the sliding window size of the model.

Attributes:
_block_size (int): The maximum number of tokens that can be stored in a
@@ -37,13 +41,15 @@
block_size: int,
block_allocator: DeviceAwareBlockAllocator,
_blocks: Optional[List[Block]] = None,
block_sliding_window: Optional[int] = None,
):
self._block_size = block_size
self._allocator = block_allocator
if _blocks is None:
_blocks = []
self._blocks: List[Block] = _blocks

self._block_sliding_window = block_sliding_window
# Use helper method instead of directly calculating, as blocks
# may not be allocated.
self._num_full_slots = len(self._get_all_token_ids())
@@ -89,7 +95,8 @@ def allocate(self,

def append_token_ids(self,
token_ids: List[int],
num_lookahead_slots: int = 0) -> None:
num_lookahead_slots: int = 0,
num_computed_slots: Optional[int] = None) -> None:
"""Appends a sequence of token IDs to the existing blocks in the
BlockTable.

@@ -104,13 +111,35 @@ def append_token_ids(self,

Args:
token_ids (List[int]): The sequence of token IDs to be appended.
num_computed_slots (Optional[int]): The number of KV cache slots
that are already filled (computed).
When sliding window is enabled, this is used to compute how many
blocks to drop at the front of the sequence.
Without sliding window, None can be passed.
Without chunked prefill, it should be the same as
_num_full_slots.
"""
assert self._is_allocated
assert self._is_allocated, "no blocks have been allocated"
assert len(self._blocks) > 0

# Drop blocks that are no longer needed due to sliding window
if self._block_sliding_window is not None:
null_block = self._allocator.allocate_or_get_null_block()
assert num_computed_slots is not None
end_block_idx = (num_computed_slots //
self._block_size) - self._block_sliding_window
for idx in range(0, end_block_idx):
b = self._blocks[idx]
if b is not null_block:
self._allocator.free(b)
self._blocks[idx] = null_block

# Ensure there are enough empty slots for the new tokens plus
# lookahead slots
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots)

# Update the blocks with the new tokens
blocks = self._blocks[self._num_full_slots // self._block_size:]
token_blocks = self._chunk_token_blocks_for_append(token_ids)

@@ -168,6 +197,7 @@ def fork(self) -> "BlockTable":
block_size=self._block_size,
block_allocator=self._allocator,
_blocks=forked_blocks,
block_sliding_window=self._block_sliding_window,
)

def free(self) -> None: