[Core] Sliding window for block manager v2 #4545

Merged
merged 58 commits, May 28, 2024

Changes from 19 commits

Commits (58)
6b1d3b2
merge
Apr 10, 2024
d55db5e
enable sliding window (doesn't work)
mmoskal Apr 30, 2024
be53001
add null_block
mmoskal May 2, 2024
ac070a1
ruff, ruff
mmoskal May 2, 2024
b37f028
yapf
mmoskal May 2, 2024
368c3ee
add AttentionBackend.zero_block()
mmoskal May 2, 2024
c4e533d
zero-out null_block
mmoskal May 2, 2024
7bc88f2
comment out debug assertions
mmoskal May 2, 2024
2825cd7
merge main
mmoskal May 2, 2024
0ade169
fix mypy
mmoskal May 2, 2024
f703af2
basic correctness test for sliding window
mmoskal May 3, 2024
9661776
allocate block progressively for chunked prefill
mmoskal May 3, 2024
a0459b4
Revert "allocate block progressively for chunked prefill"
mmoskal May 3, 2024
6326fbc
fix sliding_window+chunked_prefill
mmoskal May 3, 2024
785aa19
add test for chunked prefill + sliding window
mmoskal May 3, 2024
165b7a8
merge main
mmoskal May 3, 2024
e26631e
spelling + formatting
mmoskal May 3, 2024
57678f4
remove junk
mmoskal May 3, 2024
22e9bb8
simplify test
mmoskal May 3, 2024
be984d1
testcase PR feedback
mmoskal May 8, 2024
34b0fef
Revert "add AttentionBackend.zero_block()"
mmoskal May 8, 2024
b0261fd
zero-out whole KV cache on alloc
mmoskal May 8, 2024
0a3d3b6
add comments
mmoskal May 8, 2024
c947103
require num_computed_slots for sliding window
mmoskal May 8, 2024
a28116f
formatting
mmoskal May 8, 2024
2a5436d
add NullBlock proxy class
mmoskal May 9, 2024
cc1467f
improve comments
mmoskal May 9, 2024
1edc9be
assert sliding window size
mmoskal May 9, 2024
85d8fd9
add docstring
mmoskal May 9, 2024
29fd0d5
add comment
mmoskal May 9, 2024
09c192a
format
mmoskal May 9, 2024
54a5e93
bump test size
mmoskal May 9, 2024
609a9ce
start on sliding window support in paged attn decode kernel
mmoskal May 9, 2024
5079d9a
Revert "start on sliding window support in paged attn decode kernel"
mmoskal May 9, 2024
83f82d9
construct correct block tables in sliding window decode phase
mmoskal May 10, 2024
0bb1f67
remove debug out
mmoskal May 10, 2024
7e7bd02
bump test len again
mmoskal May 10, 2024
728b722
ruff
mmoskal May 10, 2024
96b71a0
add sliding window v2 test
mmoskal May 10, 2024
0ac37c0
Merge branch 'main' into sliding_window_v2
mmoskal May 11, 2024
e086854
fix possible issue with kernel
mmoskal May 11, 2024
bdad3bb
make test pass
mmoskal May 11, 2024
d9521ba
update comments
mmoskal May 11, 2024
24844dd
ruff
mmoskal May 11, 2024
f4ede62
rename: block_sliding_window => max_block_sliding_window
mmoskal May 23, 2024
63d9e50
merge main
mmoskal May 23, 2024
7f00204
finish merge
mmoskal May 24, 2024
fa1ea2f
fix conftest import
mmoskal May 24, 2024
97aa968
allow v2+chunked_prefill+sliding_window
mmoskal May 24, 2024
ce04cde
restore assert
mmoskal May 24, 2024
0a5a73d
formatting
mmoskal May 24, 2024
60f0d25
fix sliding window computation
mmoskal May 25, 2024
2855de6
formatting
mmoskal May 25, 2024
14f33b1
force rebuild
mmoskal May 25, 2024
d58aefd
Merge branch 'main' into sliding_window_v2
mmoskal May 25, 2024
507d2dc
prefix caching fix
mmoskal May 25, 2024
cc6cab4
add assertion message
mmoskal May 25, 2024
dffca79
re-run build
mmoskal May 26, 2024
130 changes: 130 additions & 0 deletions tests/core/block/e2e/test_correctness_sliding_window.py
@@ -0,0 +1,130 @@
import random
from typing import Iterable, List

import pytest

from vllm import LLM, SamplingParams

# relatively small model with 4k sliding window
MODEL = "bigcode/starcoder2-3b"


# the prompt is just under 10k tokens; sliding window is 4k
# so the answer is outside sliding window, but should still be correct
def prep_prompts(batch_size: int):
    prompts: List[str] = []
    answer: List[int] = []
    indices: List[int] = []
    random.seed(1)
    for _ in range(batch_size):
        idx = random.randint(30, 90)
        indices.append(idx)
        prompt = "```python\n# We set a number of variables, " + \
            f"x{idx} will be important later\n"
        ln = random.randint(800, 1100)
        for k in range(30, ln):
            v = random.randint(10, 99)
            if k == idx:
                answer.append(v)
            prompt += f"x{k} = {v}\n"
        prompt += f"# Now, we check the value of x{idx}:\n"
        prompt += f"assert x{idx} == "
        prompts.append(prompt)
    return prompts, answer, indices


def check_answers(indices: List[int], answer: List[int], outputs: List[str]):
    answer2 = [int(text[0:2].strip()) for text in outputs]
    print(list(zip(indices, zip(answer, answer2))))
    numok = 0
    for a1, a2 in zip(answer, answer2):
        if a1 == a2:
            numok += 1
    frac_ok = numok / len(answer)
    print(f"Numok: {numok}/{len(answer)} {frac_ok}")
    assert frac_ok > 0.7


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": MODEL,

        # skip cuda graph creation for fast test.
        "enforce_eager": True,
        "block_size": 16,
        "num_gpu_blocks_override": 100000 // 16,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
                                 batch_size, seed):
    sampling_params = SamplingParams(
        max_tokens=128,
        ignore_eos=True,
        temperature=0.0,
    )

    prompts, answer, indices = prep_prompts(batch_size)

    print('Getting token ids from block manager v1')
    baseline_texts = get_text_from_llm_generator(baseline_llm_generator,
                                                 prompts, sampling_params)

    check_answers(indices, answer, baseline_texts)

    print('Getting token ids from block manager v2')
    test_texts = get_text_from_llm_generator(test_llm_generator, prompts,
                                             sampling_params)
    check_answers(indices, answer, test_texts)

    for expected_text, actual_text in zip(baseline_texts, test_texts):
        assert expected_text == actual_text


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": MODEL,

        # skip cuda graph creation for fast test.
        "enforce_eager": True,
        "block_size": 16,
        "num_gpu_blocks_override": 100000 // 16,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "use_v2_block_manager": True,
    "enable_chunked_prefill": True
}])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):
    sampling_params = SamplingParams(
        max_tokens=10,

Collaborator:
I'm afraid these tests won't catch issues with the block mapping. E.g. I expect errors to accumulate over many tokens before we see a significant divergence in attention scores. Generating 10 of 4096 window tokens is not very much; the same goes for 128/4096, although it's better.

WDYT? Is my intuition right? Should we test with a larger generation size? Another option is to patch sliding_window to be smaller (e.g. two blocks) so the impact of any error is larger. If we go with patching sliding_window we could even use one of the 68m models for a faster test.

Collaborator:
+1.

Contributor Author (mmoskal):
There was an issue with the block tables passed to single-token decode, which was causing different output with 1024 tokens. I have now fixed that and bumped the test size to 1024.

However, it's still slightly incorrect, because the decode kernel does not support the sliding window natively: as it works now, it just takes all the blocks passed in (up to seq_len). With the v1 manager, the sliding window uses blocks in a "ring buffer" fashion, so this is not a problem. With the new block manager we potentially need to start the attention computation in the middle of a block; otherwise we attend to a few tokens too many. It doesn't seem to affect this test, though.

I have started fixing the decode kernel, but I think that should be a separate PR.

Collaborator:
I haven't looked at the changes, but want to say that yes, this problem is known and we should fix it eventually (awesome if you want to do it). Let's get this PR in with good tests for where we're at, and a future PR can fix the decode kernel.

        ignore_eos=True,
        temperature=0.0,
    )

    prompts, answer, indices = prep_prompts(batch_size)

    # We don't compare with the baseline model here, since the results are
    # slightly different due to different tailing in attention.
    test_texts = get_text_from_llm_generator(test_llm_generator, prompts,
                                             sampling_params)
    check_answers(indices, answer, test_texts)


def get_text_from_llm_generator(llm_generator: Iterable[LLM], prompts,
                                sampling_params):
    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        text = [output.outputs[0].text for output in outputs]
        del llm

    return text
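
The review discussion above notes that, with the v2 block manager, decode may attend to "a few tokens too many", because block tables are block-granular while the window boundary is not. A rough, self-contained sketch of that effect — all names and numbers below are illustrative and not part of this PR:

# Sketch: how many extra tokens a block-granular window attends to,
# assuming a 16-token block size and a 4096-token sliding window.
BLOCK_SIZE = 16
SLIDING_WINDOW = 4096

def extra_tokens_attended(seq_len: int) -> int:
    # Exact start of the window for the current decode position.
    window_start = max(0, seq_len - SLIDING_WINDOW)
    # Block-granular start: beginning of the block containing window_start.
    block_start = (window_start // BLOCK_SIZE) * BLOCK_SIZE
    return window_start - block_start

for seq_len in (4096, 4100, 5000, 10000):
    print(seq_len, extra_tokens_attended(seq_len))  # at most BLOCK_SIZE - 1 extra

The extra context is bounded by one block (at most 15 tokens here), which is consistent with the observation above that it does not change the outcome of these tests.
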
8 changes: 8 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -45,6 +45,14 @@ def copy_blocks(
    ) -> None:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def zero_block(
        kv_caches: List[torch.Tensor],
        block_id: int,
    ) -> None:
        raise NotImplementedError


@dataclass
class AttentionMetadataPerStage:
7 changes: 7 additions & 0 deletions vllm/attention/backends/flash_attn.py
@@ -52,6 +52,13 @@ def copy_blocks(
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def zero_block(
        kv_caches: List[torch.Tensor],
        block_id: int,
    ) -> None:
        PagedAttention.zero_block(kv_caches, block_id)


@dataclass
class FlashAttentionMetadata(AttentionMetadataPerStage,
7 changes: 7 additions & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -50,6 +50,13 @@ def copy_blocks(
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def zero_block(
        kv_caches: List[torch.Tensor],
        block_id: int,
    ) -> None:
        PagedAttention.zero_block(kv_caches, block_id)


@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
7 changes: 7 additions & 0 deletions vllm/attention/backends/torch_sdpa.py
@@ -48,6 +48,13 @@ def copy_blocks(
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def zero_block(
        kv_caches: List[torch.Tensor],
        block_id: int,
    ) -> None:
        PagedAttention.zero_block(kv_caches, block_id)


@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
7 changes: 7 additions & 0 deletions vllm/attention/backends/xformers.py
@@ -53,6 +53,13 @@ def copy_blocks(
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def zero_block(
        kv_caches: List[torch.Tensor],
        block_id: int,
    ) -> None:
        PagedAttention.zero_block(kv_caches, block_id)


@dataclass
class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
8 changes: 8 additions & 0 deletions vllm/attention/ops/paged_attn.py
@@ -214,3 +214,11 @@ def copy_blocks(
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)

    @staticmethod
    def zero_block(
        kv_caches: List[torch.Tensor],
        block_id: int,
    ) -> None:
        for kv_cache in kv_caches:
            kv_cache[:, block_id, :].zero_()
23 changes: 21 additions & 2 deletions vllm/core/block/block_table.py
@@ -37,13 +37,15 @@ def __init__(
        block_size: int,
        block_allocator: DeviceAwareBlockAllocator,
        _blocks: Optional[List[Block]] = None,
        block_sliding_window: Optional[int] = None,
    ):
        self._block_size = block_size
        self._allocator = block_allocator
        if _blocks is None:
            _blocks = []
        self._blocks: List[Block] = _blocks

        self._block_sliding_window = block_sliding_window
        # Use helper method instead of directly calculating, as blocks
        # may not be allocated.
        self._num_full_slots = len(self._get_all_token_ids())
@@ -89,7 +91,8 @@ def allocate(self,

    def append_token_ids(self,
                         token_ids: List[int],
                         num_lookahead_slots: int = 0,
                         num_computed_slots: Optional[int] = None) -> None:
        """Appends a sequence of token IDs to the existing blocks in the
        BlockTable.

@@ -105,12 +108,27 @@ def append_token_ids(self,
        Args:
            token_ids (List[int]): The sequence of token IDs to be appended.
        """
        assert self._is_allocated, "no blocks have been allocated"
        assert len(self._blocks) > 0

        if self._block_sliding_window is not None:
            null_block = self._allocator.null_block
            if num_computed_slots is None:

Collaborator: can we have some comments on this code? (what is this branch for?)

Contributor Author (@mmoskal, May 10, 2024): added test

                num_computed_slots = self._num_full_slots
            end_idx = (num_computed_slots //
                       self._block_size) - self._block_sliding_window
            for idx in range(0, end_idx):
                b = self._blocks[idx]
                if b is not null_block:
                    self._allocator.free(b)
                    self._blocks[idx] = null_block

        # Ensure there are enough empty slots for the new tokens plus
        # lookahead slots
        self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
                                    num_lookahead_slots)

        # Update the blocks with the new tokens
        blocks = self._blocks[self._num_full_slots // self._block_size:]
        token_blocks = self._chunk_token_blocks_for_append(token_ids)

@@ -168,6 +186,7 @@ def fork(self) -> "BlockTable":
            block_size=self._block_size,
            block_allocator=self._allocator,
            _blocks=forked_blocks,
            block_sliding_window=self._block_sliding_window,
        )

    def free(self) -> None:
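
To make the freeing rule in append_token_ids above concrete, here is a small standalone sketch of the index arithmetic; the numbers are made up for illustration and are not taken from the PR:

# Blocks with index below num_computed_slots // block_size - block_sliding_window
# can no longer be touched by the window, so they are handed back to the
# allocator and replaced with the shared null block.
block_size = 16
block_sliding_window = 4096 // block_size + 2  # mirrors BlockManagerV2 (258)
num_computed_slots = 4600                      # tokens already computed (example)

end_idx = num_computed_slots // block_size - block_sliding_window
print(end_idx)  # 287 - 258 = 29 -> blocks [0, 29) become the null block
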
16 changes: 16 additions & 0 deletions vllm/core/block/cpu_gpu_block_allocator.py
@@ -105,11 +105,24 @@ def __init__(
            Device.GPU: gpu_block_allocator,
        }

        self._null_block: Optional[Block] = None

        self._block_ids_to_allocator: Dict[int, BlockAllocator] = {}
        for _, allocator in self._allocators.items():
            for block_id in allocator.all_block_ids:
                self._block_ids_to_allocator[block_id] = allocator

    @property
    def null_block(self) -> Block:
        if self._null_block is None:
            self._null_block = self.allocate_mutable(None, Device.GPU)

Collaborator: Not super familiar with the interface, but if we use allocate_immutable, isn't this supposed to guarantee the hack below?

Contributor Author: It doesn't seem to prevent modifications, at least for the NaiveBlocks - we now use a NullBlock wrapper anyway. allocate_immutable() makes the block participate in the hash tables of valid prefixes. I'm not sure we want the null block there (and even then, it's probably not very useful), so I kept allocate_mutable().

            def fail(token_ids: List[int]):
                raise ValueError("null_block should not be modified")

            self._null_block.append_token_ids = fail  # type: ignore
        return self._null_block

    def allocate_mutable(self, prev_block: Optional[Block],
                         device: Device) -> Block:
        """Allocates a new mutable block on the specified device.
@@ -149,6 +162,8 @@ def free(self, block: Block) -> None:
        Args:
            block (Block): The block to be freed.
        """
        if block is self._null_block:
            return
        block_id = block.block_id
        assert block_id is not None
        allocator = self._block_ids_to_allocator[block_id]
@@ -165,6 +180,7 @@ def fork(self, last_block: Block) -> List[Block]:
            List[Block]: A new list of blocks that shares the same memory as the
            original sequence.
        """
        assert last_block is not self._null_block
        block_id = last_block.block_id
        assert block_id is not None
        allocator = self._block_ids_to_allocator[block_id]
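
The thread above mentions that the monkey-patched append_token_ids was later replaced by a NullBlock wrapper class. A minimal sketch of that proxy idea, under the assumption that it only needs to forward reads and reject writes (the actual NullBlock added later in this PR may differ):

class NullBlock:
    # Illustrative proxy: forwards attribute reads to the wrapped block
    # and rejects any attempt to write token IDs into it.
    def __init__(self, proxied):
        self._proxied = proxied

    def append_token_ids(self, token_ids):
        raise ValueError("null_block should not be modified")

    def __getattr__(self, name):
        # Everything else (block_id, token_ids, ...) comes from the real block.
        return getattr(self._proxied, name)
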
5 changes: 5 additions & 0 deletions vllm/core/block/interfaces.py
@@ -203,3 +203,8 @@ def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        pass

    @property
    @abstractmethod
    def null_block(self) -> Block:
        pass
19 changes: 15 additions & 4 deletions vllm/core/block_manager_v2.py
@@ -65,9 +65,13 @@ def __init__(
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks

        assert sliding_window is None, "Sliding window not yet supported"

        self.sliding_window = sliding_window
        # block_sliding_window is the max number of blocks that need to be
        # allocated.
        # We generally need up to 1 block more due to the way BlockTable works.
        self.block_sliding_window = None
        if sliding_window is not None:
            self.block_sliding_window = sliding_window // block_size + 2

        self.watermark = watermark
        assert watermark >= 0.0
@@ -82,6 +86,12 @@ def __init__(
            num_cpu_blocks=num_cpu_blocks,
            block_size=block_size,
        )
        if self.sliding_window is not None:
            # Allocate the null_block first, so it gets ID of 0.
            # CacheEngine makes sure the first block is always zeroed-out
            # so we don't get some nasty NaNs in there.
            null_block = self.block_allocator.null_block
            assert null_block.block_id == 0

        self.block_tables: Dict[SeqId, BlockTable] = {}

@@ -95,7 +105,6 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
            block_size=self.block_size,
        )

        assert self.block_sliding_window is None
        if self.block_sliding_window is not None:
            num_required_blocks = min(num_required_blocks,
                                      self.block_sliding_window)
@@ -124,8 +133,9 @@ def allocate(self, seq_group: SequenceGroup) -> None:
        block_table = BlockTable(
            block_size=self.block_size,
            block_allocator=self.block_allocator,
            block_sliding_window=self.block_sliding_window,
        )
        assert self.block_sliding_window is None

        block_table.allocate(seq.get_token_ids())
        self.block_tables[seq.seq_id] = block_table

@@ -173,6 +183,7 @@ def append_slots(
        block_table.append_token_ids(
            token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
            num_lookahead_slots=num_lookahead_slots,
            num_computed_slots=seq.data.get_num_computed_tokens(),
        )

        # Return any new copy-on-writes.
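
A quick numerical sketch of how the sliding-window cap in can_allocate above changes the block requirement for a long prompt; the values are chosen for illustration only:

# With a 4k window and 16-token blocks, a 10k-token prompt needs only the
# capped number of blocks rather than one block per 16 prompt tokens.
block_size = 16
sliding_window = 4096
prompt_len = 10000

num_required_blocks = -(-prompt_len // block_size)        # ceil division: 625
block_sliding_window = sliding_window // block_size + 2   # 258
print(min(num_required_blocks, block_sliding_window))     # 258
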
3 changes: 3 additions & 0 deletions vllm/worker/cache_engine.py
@@ -48,6 +48,9 @@ def __init__(
        # Initialize the cache.
        self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
        # Zero out the first block in the cache, in case it gets used as
        # 'null_block' in the CpuGpuBlockAllocator
        self.attn_backend.zero_block(self.gpu_cache, 0)

    def _allocate_kv_cache(
        self,