diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index a46ee15817f4c..8d0554b0f4f05 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -46,6 +46,7 @@ def main(args: argparse.Namespace): load_format=args.load_format, distributed_executor_backend=args.distributed_executor_backend, otlp_traces_endpoint=args.otlp_traces_endpoint, + enable_prefix_caching=args.enable_prefix_caching, ) sampling_params = SamplingParams( @@ -220,6 +221,9 @@ def run_to_completion(profile_dir: Optional[str] = None): action='store_true', help='If True, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') + parser.add_argument("--enable-prefix-caching", + action='store_true', + help="Enable automatic prefix caching") parser.add_argument('--use-v2-block-manager', action='store_true') parser.add_argument( "--ray-workers-use-nsight", diff --git a/tests/conftest.py b/tests/conftest.py index 0bd24905efab8..ac802d03b1c85 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -474,7 +474,7 @@ def generate( req_sample_output_strs: List[str] = [] for sample in req_output.outputs: output_str = sample.text - output_ids = sample.token_ids + output_ids = list(sample.token_ids) req_sample_output_ids.append(prompt_ids + output_ids) req_sample_output_strs.append(prompt_str + output_str) outputs.append((req_sample_output_ids, req_sample_output_strs)) diff --git a/tests/core/block/test_block_table.py b/tests/core/block/test_block_table.py index 496774c8de53c..e2391a5680b36 100644 --- a/tests/core/block/test_block_table.py +++ b/tests/core/block/test_block_table.py @@ -373,8 +373,9 @@ def test_cow(block_size: int, sequence_len: int, append_len: int, block_size) - (sequence_len // block_size) original_block_table.allocate(token_ids=token_ids, device=Device.GPU) - original_block_ids = original_block_table.physical_block_ids + original_block_ids = original_block_table.physical_block_ids[:] + print("original_block_ids = 
{}".format(original_block_ids)) forked_block_table = original_block_table.fork() # Expect no additional allocation (copy on _write_). @@ -457,7 +458,7 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int, # Allocate lookahead slots. original_block_table.ensure_num_empty_slots(lookahead_slots) - original_block_ids = original_block_table.physical_block_ids + original_block_ids = original_block_table.physical_block_ids[:] forked_block_table = original_block_table.fork() diff --git a/tests/core/block/test_cpu_gpu_block_allocator.py b/tests/core/block/test_cpu_gpu_block_allocator.py index 44a5be6c181a0..15b76d9093c63 100644 --- a/tests/core/block/test_cpu_gpu_block_allocator.py +++ b/tests/core/block/test_cpu_gpu_block_allocator.py @@ -8,8 +8,8 @@ @pytest.mark.parametrize("num_gpu_blocks", [1024]) @pytest.mark.parametrize("block_size", [16]) @pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_allocate_mutable(num_cpu_blocks: int, num_gpu_blocks: int, - block_size: int, allocator_type: str): +def test_allocate_mutable_block(num_cpu_blocks: int, num_gpu_blocks: int, + block_size: int, allocator_type: str): allocator = CpuGpuBlockAllocator.create( allocator_type=allocator_type, num_gpu_blocks=num_gpu_blocks, @@ -21,14 +21,14 @@ def test_allocate_mutable(num_cpu_blocks: int, num_gpu_blocks: int, assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks cpu_blocks = [ - allocator.allocate_mutable(prev_block=None, device=Device.CPU) + allocator.allocate_mutable_block(prev_block=None, device=Device.CPU) for _ in range(num_cpu_blocks) ] assert allocator.get_num_free_blocks(Device.CPU) == 0 assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks gpu_blocks = [ - allocator.allocate_mutable(prev_block=None, device=Device.GPU) + allocator.allocate_mutable_block(prev_block=None, device=Device.GPU) for _ in range(num_gpu_blocks) ] assert allocator.get_num_free_blocks(Device.CPU) == 0 @@ -47,8 +47,8 @@ def 
test_allocate_mutable(num_cpu_blocks: int, num_gpu_blocks: int, @pytest.mark.parametrize("num_gpu_blocks", [1024]) @pytest.mark.parametrize("block_size", [2]) @pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_allocate_immutable(num_cpu_blocks: int, num_gpu_blocks: int, - block_size: int, allocator_type: str): +def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int, + block_size: int, allocator_type: str): allocator = CpuGpuBlockAllocator.create( allocator_type=allocator_type, num_gpu_blocks=num_gpu_blocks, @@ -67,18 +67,18 @@ def test_allocate_immutable(num_cpu_blocks: int, num_gpu_blocks: int, assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks cpu_blocks = [ - allocator.allocate_immutable(prev_block=None, - token_ids=token_ids, - device=Device.CPU) + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids, + device=Device.CPU) for token_ids in cpu_token_ids ] assert allocator.get_num_free_blocks(Device.CPU) == 0 assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks gpu_blocks = [ - allocator.allocate_immutable(prev_block=None, - token_ids=token_ids, - device=Device.GPU) + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids, + device=Device.GPU) for token_ids in gpu_token_ids ] assert allocator.get_num_free_blocks(Device.CPU) == 0 diff --git a/tests/core/block/test_naive_block.py b/tests/core/block/test_naive_block.py index edcdc0c7d4f98..9821ac41b8342 100644 --- a/tests/core/block/test_naive_block.py +++ b/tests/core/block/test_naive_block.py @@ -14,11 +14,11 @@ def create_allocate_lambda(allocate_type: str, prev_block: Optional[Block], token_ids: List[int]): if allocate_type == "immutable": - allocate_block = lambda: allocator.allocate_immutable( + allocate_block = lambda: allocator.allocate_immutable_block( prev_block=prev_block, token_ids=token_ids) elif allocate_type == "mutable": - allocate_block = lambda: 
allocator.allocate_mutable(prev_block= - prev_block) + allocate_block = lambda: allocator.allocate_mutable_block( + prev_block=prev_block) else: raise ValueError() diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index fcf32cbe99472..95858268a964f 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -26,11 +26,10 @@ def test_first_block_has_correct_content_hash(seed: int, block_size: int, token_ids = list(range(num_to_fill)) mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator) - block_with_prev = PrefixCachingBlock( - prev_block=None, - token_ids=token_ids, - block_size=block_size, - prefix_caching_allocator=mock_allocator) + block_with_prev = PrefixCachingBlock(prev_block=None, + token_ids=token_ids, + block_size=block_size, + allocator=mock_allocator) if is_curr_block_full: # Expect hash since block is full. @@ -71,7 +70,7 @@ def test_nth_block_has_correct_content_hash(seed: int, block_size: int, prev_block=previous_block, token_ids=token_ids, block_size=block_size, - prefix_caching_allocator=mock_allocator, + allocator=mock_allocator, ) if is_curr_block_full and prev_block_has_hash: @@ -138,7 +137,7 @@ def create_chain(block_size: int, prev_block=prev_block, token_ids=[], block_size=block_size, - prefix_caching_allocator=allocator, + allocator=allocator, ) tokens_to_append = token_ids[block_number * @@ -159,11 +158,11 @@ def create_allocate_lambda(allocate_type: str, allocator: BlockAllocator, prev_block: Optional[Block], token_ids: List[int]): if allocate_type == "immutable": - allocate_block = lambda: allocator.allocate_immutable( + allocate_block = lambda: allocator.allocate_immutable_block( prev_block=prev_block, token_ids=token_ids) elif allocate_type == "mutable": - allocate_block = lambda: allocator.allocate_mutable(prev_block= - prev_block) + allocate_block = lambda: allocator.allocate_mutable_block( + prev_block=prev_block) 
else: raise ValueError() @@ -233,12 +232,13 @@ def test_allocate_immutable_ooms_many_hash(num_blocks: int, # Expect allocation with unseen hash to fail. with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocator.allocate_immutable(prev_block=chain[-1], - token_ids=list(range(block_size))) + allocator.allocate_immutable_block(prev_block=chain[-1], + token_ids=list( + range(block_size))) # Expect mutable allocation to fail. with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocator.allocate_mutable(prev_block=chain[-1]) + allocator.allocate_mutable_block(prev_block=chain[-1]) # Expect allocation of exact same chain to pass. second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( @@ -270,7 +270,7 @@ def test_free_prevents_oom(num_blocks: int, block_size: int): # Expect mutable allocation to fail. with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocator.allocate_mutable(prev_block=None) + allocator.allocate_mutable_block(prev_block=None) block_to_free = chain[-1] @@ -280,11 +280,11 @@ def test_free_prevents_oom(num_blocks: int, block_size: int): allocator.free(block_to_free) assert block_to_free.block_id is None, i - new_block = allocator.allocate_mutable(prev_block=None) + new_block = allocator.allocate_mutable_block(prev_block=None) assert new_block.block_id == block_id, i with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocator.allocate_mutable(prev_block=None) + allocator.allocate_mutable_block(prev_block=None) block_to_free = new_block @@ -376,7 +376,6 @@ def test_get_common_computed_block_ids(num_blocks: int, block_size: int, # Create token ids that will exhaust all blocks. 
token_ids = list(range(num_blocks_to_consume * block_size)) - blocks = list(range(num_blocks_to_consume)) first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( block_size=block_size, @@ -384,9 +383,6 @@ def test_get_common_computed_block_ids(num_blocks: int, block_size: int, allocator=allocator, ) - # mark all blocks in first chain as computed - allocator.mark_blocks_as_computed(blocks) - # After zero_point, second_chain's token_ids would be set -1, which # make it different from here comparing with first_chain zero_point = random.randint(1, len(token_ids) - 1) @@ -424,15 +420,16 @@ def test_alloc_promotion(num_blocks: int, block_size: int, seed: int): block_size=block_size) token_ids = list(range(block_size)) - block = allocator.allocate_immutable(prev_block=None, - token_ids=token_ids) + block = allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids) assert allocator._refcounter.get(block.block_id) == 1 - m = allocator.allocate_mutable(prev_block=None) + m = allocator.allocate_mutable_block(prev_block=None) block_id = m.block_id for i in range(block_size): m.append_token_ids([i]) + # After block get promoted to immutable from mutable, if there is # already same content hash block, then it shall be released into # hashless_allocator @@ -452,48 +449,79 @@ def test_eviction_alloc_mixed(num_blocks: int, block_size: int, seed: int): all_blocks_list = [i for i in range(num_blocks)] zero_ref = {i: 0 for i in range(num_blocks)} + one_ref = {i: 1 for i in range(num_blocks)} allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, block_size=block_size) token_ids = list(range(num_blocks * block_size)) - # now we have num_blocks free blocks in hashless allocator - # with internal tracking list _blocks _cached_blocks and evictor - # empty and block's ref shall be 0 + # Verify initial/pre-alloc state + + # Ensure all blocks are free inside hashless allocator assert list(allocator._hashless_allocator._free_block_indices ) == 
all_blocks_list - assert len(allocator._blocks.keys()) == 0 + # Ensure no tracked blocks + assert len(allocator._block_tracker.keys()) == num_blocks + for block_id in range(num_blocks): + assert not allocator._block_tracker[block_id].active + # Ensure no cached blocks assert len(allocator._cached_blocks.values()) == 0 + # Ensure no evicted blocks assert len(allocator.evictor.free_table.keys()) == 0 + # Ensure 0s ref counts for all blocks assert allocator._refcounter._refcounts == zero_ref # Allocate immutable chains with only one block residuled in new_block = [] for i in range(num_blocks): - block = allocator.allocate_immutable( + block = allocator.allocate_immutable_block( prev_block=None, token_ids=token_ids[block_size * i:block_size * (i + 1)]) new_block.append(block) + # Verify post-alloc state + + # Ensure no blocks are free inside hashless allocator + assert (len(allocator._hashless_allocator._free_block_indices) == 0) + # Ensure all blocks are tracked + assert len(allocator._block_tracker.keys()) == num_blocks + for block_id in range(num_blocks): + assert allocator._block_tracker[block_id].active + # Ensure all blocks are cached (all promoted) + assert len(allocator._cached_blocks.values()) == num_blocks + # Ensure no evicted blocks + assert len(allocator.evictor.free_table.keys()) == 0 + # Ensure 1s ref counts for all blocks + assert allocator._refcounter._refcounts == one_ref + # Free all blocks, and now all blocks shall be in the evictor - # there shall be no tracking data left in _blocks + # there shall be no tracking data left in _block_tracker # all blocks shall be tracked in _cached_blocks # all blocks' ref shall be zero for block in new_block: allocator.free(block) - assert len(allocator._blocks.keys()) == 0 + # Verify post-free state + + # Ensure no tracked blocks + assert len(allocator._block_tracker.keys()) == num_blocks + for block_id in range(num_blocks): + assert not allocator._block_tracker[block_id].active + # Ensure no blocks in hashless 
allocator (all promoted) assert len(allocator._hashless_allocator._free_block_indices) == 0 + # Ensure all blocks are cached assert list(allocator._cached_blocks.values()) == all_blocks_list + # Ensure all blocks are inside the evictor assert list(allocator.evictor.free_table.keys()) == all_blocks_list + # Ensure 0s refcounts assert allocator._refcounter._refcounts == zero_ref # Allocate a mutable block, and the first block shall be evicted # and set its content hash into None, ref to 1 - mutable = allocator.allocate_mutable(prev_block=None) + mutable = allocator.allocate_mutable_block(prev_block=None) assert mutable.block_id == 0 assert mutable.content_hash is None - assert 0 in allocator._blocks + assert allocator._block_tracker[0].active assert allocator._refcounter.get(0) == 1 assert 0 not in allocator._cached_blocks assert 0 not in allocator.evictor @@ -502,27 +530,27 @@ def test_eviction_alloc_mixed(num_blocks: int, block_size: int, seed: int): # hashless allocator allocator.free(mutable) - assert len(allocator._blocks.keys()) == 0 + assert not allocator._block_tracker[0].active assert allocator._refcounter._refcounts == zero_ref assert 0 not in allocator._cached_blocks assert 0 not in allocator.evictor assert 0 in allocator._hashless_allocator._free_block_indices - # when allocate immutable with first block_size tokens, we + # When allocate immutable with first block_size tokens, we # shall get free block from hashless allocator, thus no block left # in hashless - block = allocator.allocate_immutable(prev_block=None, - token_ids=token_ids[:block_size]) + block = allocator.allocate_immutable_block( + prev_block=None, token_ids=token_ids[:block_size]) assert block.block_id == 0 assert len(allocator._hashless_allocator._free_block_indices) == 0 - assert 0 in allocator._blocks + assert allocator._block_tracker[0].active assert 0 in allocator._cached_blocks.values() assert allocator._refcounter.get(0) == 1 assert 0 not in allocator.evictor # allocate mutable 
block again, it shall be popped from evictor - mutable = allocator.allocate_mutable(prev_block=None) + mutable = allocator.allocate_mutable_block(prev_block=None) assert len(allocator._hashless_allocator._free_block_indices) == 0 assert mutable.block_id not in allocator.evictor.free_table assert allocator._refcounter.get(mutable.block_id) == 1 @@ -619,7 +647,7 @@ def create_immutable_chain( block_token_ids = token_ids[block_number * block_size:(block_number + 1) * block_size] - prev_block = allocator.allocate_immutable( + prev_block = allocator.allocate_immutable_block( prev_block=prev_block, token_ids=block_token_ids) blocks.append(prev_block) diff --git a/tests/spec_decode/test_batch_expansion.py b/tests/spec_decode/test_batch_expansion.py index 42dd90422ec47..c350a2c55396e 100644 --- a/tests/spec_decode/test_batch_expansion.py +++ b/tests/spec_decode/test_batch_expansion.py @@ -90,10 +90,10 @@ def test_create_single_target_seq_group_metadata(k: int): assert output.request_id == input_seq_group_metadata.request_id assert len(output.seq_data) == 1 - assert output.seq_data[target_seq_id].get_prompt_token_ids( - ) == prompt_tokens - assert output.seq_data[target_seq_id].get_output_token_ids( - ) == prev_output_tokens + token_ids + assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple( + prompt_tokens) + assert output.seq_data[target_seq_id].get_output_token_ids() == tuple( + prev_output_tokens + token_ids) assert len(output.block_tables) == 1 assert output.block_tables[ diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index d705f3d91a074..49e63c23155b8 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -1,5 +1,6 @@ from typing import List, Optional +from vllm.core.block.common import BlockList from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator from vllm.utils import Device, cdiv, chunk_list @@ -47,12 +48,10 @@ def __init__( self._allocator = block_allocator if _blocks 
is None: _blocks = [] - self._blocks: List[Block] = _blocks + self._blocks: BlockList = BlockList(_blocks) self._max_block_sliding_window = max_block_sliding_window - # Use helper method instead of directly calculating, as blocks - # may not be allocated. - self._num_full_slots = len(self._get_all_token_ids()) + self._num_full_slots = self._get_num_token_ids() @staticmethod def get_num_required_blocks(token_ids: List[int], block_size: int) -> int: @@ -88,11 +87,18 @@ def allocate(self, """ assert not self._is_allocated assert token_ids - self._blocks = self._allocate_blocks_for_token_ids(prev_block=None, - token_ids=token_ids, - device=device) + blocks = self._allocate_blocks_for_token_ids(prev_block=None, + token_ids=token_ids, + device=device) + self.update(blocks) self._num_full_slots = len(token_ids) + def update(self, blocks: List[Block]) -> None: + """Resets the table to the newly provided blocks + (with their corresponding block ids) + """ + self._blocks.update(blocks) + def append_token_ids(self, token_ids: List[int], num_lookahead_slots: int = 0, @@ -140,11 +146,11 @@ def append_token_ids(self, num_lookahead_slots) # Update the blocks with the new tokens - blocks = self._blocks[self._num_full_slots // self._block_size:] + first_block_idx = self._num_full_slots // self._block_size token_blocks = self._chunk_token_blocks_for_append(token_ids) - for block, token_block in zip(blocks, token_blocks): - block.append_token_ids(token_block) + for i, token_block in enumerate(token_blocks): + self._blocks.append_token_ids(first_block_idx + i, token_block) self._num_full_slots += len(token_ids) @@ -174,8 +180,8 @@ def ensure_num_empty_slots(self, num_empty_slots: int) -> None: for _ in range(blocks_to_allocate): assert len(self._blocks) > 0 self._blocks.append( - self._allocator.allocate_mutable(prev_block=self._blocks[-1], - device=device)) + self._allocator.allocate_mutable_block( + prev_block=self._blocks[-1], device=device)) def fork(self) -> "BlockTable": 
"""Creates a new BlockTable instance with a copy of the blocks from the @@ -209,12 +215,12 @@ def free(self) -> None: is set to `None`. """ assert self._is_allocated - for block in self._blocks: + for block in self.blocks: self._allocator.free(block) - self._blocks = [] + self._blocks.reset() @property - def physical_block_ids(self) -> List[Optional[int]]: + def physical_block_ids(self) -> List[int]: """Returns a list of physical block indices for the blocks in the BlockTable. @@ -228,7 +234,7 @@ def physical_block_ids(self) -> List[Optional[int]]: BlockTable. """ assert self._is_allocated - return [block.block_id for block in self._blocks] + return self._blocks.ids() def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: """Get the number of "unseen" tokens in the sequence. @@ -253,17 +259,31 @@ def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], token_ids: List[int], device: Device) -> List[Block]: blocks: List[Block] = [] - for block_token_ids in chunk_list(token_ids, self._block_size): - if len(block_token_ids) == self._block_size: - # If the block is full, create an immutable block. - prev_block = self._allocator.allocate_immutable( - prev_block, token_ids=block_token_ids, device=device) + + block_token_ids = [] + tail_token_ids = [] + for cur_token_ids in chunk_list(token_ids, self._block_size): + if len(cur_token_ids) == self._block_size: + block_token_ids.append(cur_token_ids) else: - # Else, partially fill a mutable block with token ids. 
- prev_block = self._allocator.allocate_mutable( - prev_block=prev_block, device=device) - prev_block.append_token_ids(block_token_ids) - blocks.append(prev_block) + tail_token_ids.append(cur_token_ids) + + if block_token_ids: + blocks.extend( + self._allocator.allocate_immutable_blocks( + prev_block, block_token_ids=block_token_ids, + device=device)) + prev_block = blocks[-1] + + if tail_token_ids: + assert len(tail_token_ids) == 1 + cur_token_ids = tail_token_ids[0] + + block = self._allocator.allocate_mutable_block( + prev_block=prev_block, device=device) + block.append_token_ids(cur_token_ids) + + blocks.append(block) return blocks @@ -274,18 +294,25 @@ def _get_all_token_ids(self) -> List[int]: if not self._is_allocated: return token_ids - for block in self._blocks: + for block in self.blocks: token_ids.extend(block.token_ids) return token_ids + def _get_num_token_ids(self) -> int: + res = 0 + for block in self.blocks: + res += len(block.token_ids) + + return res + @property def _is_allocated(self) -> bool: return len(self._blocks) > 0 @property - def blocks(self) -> Optional[List[Block]]: - return self._blocks + def blocks(self) -> List[Block]: + return self._blocks.list() @property def _num_empty_slots(self) -> int: diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index d2787d69616f0..1e808e21b72e5 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -1,4 +1,5 @@ -from typing import Dict, Iterable, List, Optional, Protocol, Tuple +from collections import deque +from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple from vllm.core.block.interfaces import Block, BlockAllocator @@ -95,64 +96,40 @@ class CopyOnWriteTracker: The CopyOnWriteTracker class maintains a mapping of source block indices to their corresponding copy-on-write destination block indices. It works in - conjunction with a RefCounter and a BlockAllocator to handle reference - counting and block allocation. 
+ conjunction with a RefCounter. Args: refcounter (RefCounter): The reference counter used to track block reference counts. - allocator (BlockAllocator): The block allocator used to allocate and - free blocks. """ - def __init__( - self, - refcounter: RefCounterProtocol, - allocator: BlockAllocator, - ): + def __init__(self, refcounter: RefCounterProtocol): self._copy_on_writes: List[Tuple[BlockId, BlockId]] = [] self._refcounter = refcounter - self._allocator = allocator - - def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: - """Performs a copy-on-write operation on the given block if it is not - appendable. - - This method checks the reference count of the given block. If the - reference count is greater than 1, indicating that the block is shared, - a copy-on-write operation is performed. The original block is freed, - and a new block is allocated with the same content. The new block index - is returned. - - Args: - block (Block): The block to check for copy-on-write. - Returns: - Optional[BlockId]: The block index of the new block if a copy-on - -write operation was performed, or the original block index if - no copy-on-write was necessary. + def is_appendable(self, block: Block) -> bool: + """Checks if the block is shared or not. If shared, then it cannot + be appended and needs to be duplicated via copy-on-write """ block_id = block.block_id if block_id is None: - return block_id + return True refcount = self._refcounter.get(block_id) - assert refcount != 0 - if refcount > 1: - src_block_id = block_id - # Decrement refcount of the old block. - self._allocator.free(block) - - # Allocate a fresh new block. - block_id = self._allocator.allocate_mutable( - prev_block=block.prev_block).block_id + return refcount <= 1 - # Track src/dst copy. 
class BlockPool:
    """Pool of pre-allocated, reusable block objects.

    Block objects are created up front and re-initialized in place on every
    ``init_block`` call, in order to avoid excessive python object
    allocations/deallocations. The pool starts from ``pool_size`` objects and
    grows on demand once every pooled object is in use.

    Note that multiple block objects may point to the same physical block id,
    which is why this pool is needed, so that it will be easier to support
    prefix caching and more complicated sharing of physical blocks.

    Args:
        block_size: Block size passed to the factory for each pooled object.
        create_block: Factory used to construct the pooled block objects.
        allocator: Allocator handed to every pooled block object.
        pool_size: Initial number of pooled objects. Must be >= 0; an empty
            pool grows on first use.
    """

    def __init__(self, block_size: int, create_block: "Block.Factory",
                 allocator: "BlockAllocator", pool_size: int):
        assert pool_size >= 0

        self._block_size = block_size
        self._create_block = create_block
        self._allocator = allocator
        self._pool_size = pool_size

        # Pool slots that are currently not handed out.
        self._free_ids: Deque[int] = deque(range(self._pool_size))
        self._pool: List["Block"] = []
        self._create_pool_blocks(self._pool_size)

    def _create_pool_blocks(self, count: int) -> None:
        """Appends `count` blank block objects to the pool."""
        for _ in range(count):
            self._pool.append(
                self._create_block(prev_block=None,
                                   token_ids=[],
                                   block_size=self._block_size,
                                   allocator=self._allocator,
                                   block_id=None))

    def increase_pool(self):
        """Doubles the internal pool size (an empty pool grows to one slot).
        """
        cur_pool_size = self._pool_size
        # Doubling zero would stay at zero and leave `init_block` without a
        # free slot, so always grow by at least one object.
        new_pool_size = max(1, cur_pool_size * 2)
        self._pool_size = new_pool_size

        self._free_ids.extend(range(cur_pool_size, new_pool_size))
        self._create_pool_blocks(new_pool_size - cur_pool_size)

    def init_block(self, prev_block: Optional["Block"], token_ids: List[int],
                   block_size: int,
                   physical_block_id: Optional[int]) -> "Block":
        """Takes a free pooled object and re-initializes it in place.

        Args:
            prev_block: The previous block in the sequence, if any.
            token_ids: Token ids to initialize the block with.
            block_size: Block size to initialize the block with.
            physical_block_id: Physical block id assigned to the block.

        Returns:
            Block: The re-initialized pooled object, tagged with the
                ``pool_id`` of the slot it occupies.
        """
        if len(self._free_ids) == 0:
            self.increase_pool()
            assert len(self._free_ids) > 0

        pool_id = self._free_ids.popleft()

        block = self._pool[pool_id]
        # Re-run the constructor on the existing object instead of creating
        # a new one; the allocator the object was built with is kept.
        block.__init__(  # type: ignore[misc]
            prev_block=prev_block,
            token_ids=token_ids,
            block_size=block_size,
            allocator=block._allocator,  # type: ignore[attr-defined]
            block_id=physical_block_id)
        block.pool_id = pool_id  # type: ignore[attr-defined]
        return block

    def free_block(self, block: "Block") -> None:
        """Returns the block's pool slot to the front of the free list, so
        it is the next slot handed out by `init_block`.
        """
        self._free_ids.appendleft(block.pool_id)  # type: ignore[attr-defined]


class BlockList:
    """This class is an optimization to allow fast-access to physical
    block ids. It maintains a block id list that is updated with the
    block list and this avoids the need to reconstruct the block id
    list on every iteration of the block manager.

    `list()` and `ids()` return the internal lists without copying; callers
    that need a stable snapshot must copy the result themselves.

    Args:
        blocks: Initial blocks. The list object is stored as-is (not
            copied), and every block must already have a block id.
    """

    def __init__(self, blocks: List["Block"]):
        self._blocks: List["Block"] = []
        self._block_ids: List[int] = []

        self.update(blocks)

    def _add_block_id(self, block_id: Optional["BlockId"]) -> None:
        assert block_id is not None
        self._block_ids.append(block_id)

    def _update_block_id(self, block_index: int,
                         new_block_id: Optional["BlockId"]) -> None:
        assert new_block_id is not None
        self._block_ids[block_index] = new_block_id

    def update(self, blocks: List["Block"]):
        """Replaces the tracked blocks and rebuilds the cached block ids.

        The ids are validated before anything is stored, so a block without
        an id leaves the previous blocks and ids untouched.
        """
        block_ids: List[int] = []
        for block in blocks:
            block_id = block.block_id
            assert block_id is not None
            block_ids.append(block_id)

        # Commit both lists together so they never go out of sync.
        self._blocks = blocks
        self._block_ids = block_ids

    def append_token_ids(self, block_index: int, token_ids: List[int]) -> None:
        """Appends token ids to the block at `block_index` and refreshes the
        cached id if the block's id changed as a result.
        """
        block = self._blocks[block_index]
        prev_block_id = block.block_id

        block.append_token_ids(token_ids)

        # CoW or promotion may update the internal block_id
        if prev_block_id != block.block_id:
            self._update_block_id(block_index, block.block_id)

    def append(self, new_block: "Block"):
        """Adds a block (which must have a block id) to the end."""
        self._blocks.append(new_block)
        self._add_block_id(new_block.block_id)

    def __len__(self) -> int:
        return len(self._blocks)

    def __getitem__(self, block_index: int) -> "Block":
        return self._blocks[block_index]

    def __setitem__(self, block_index: int, new_block: "Block") -> None:
        self._blocks[block_index] = new_block
        self._update_block_id(block_index, new_block.block_id)

    def reset(self):
        """Drops all blocks and cached ids."""
        self._blocks = []
        self._block_ids = []

    def list(self) -> List["Block"]:
        """Returns the internal block list (not a copy)."""
        return self._blocks

    def ids(self) -> List[int]:
        """Returns the internal cached block id list (not a copy)."""
        return self._block_ids
diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 255aae9d17318..5287cd9c1bfb3 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -113,11 +113,11 @@ def __init__(self, cpu_block_allocator: BlockAllocator, def allocate_or_get_null_block(self) -> Block: if self._null_block is None: self._null_block = NullBlock( - self.allocate_mutable(None, Device.GPU)) + self.allocate_mutable_block(None, Device.GPU)) return self._null_block - def allocate_mutable(self, prev_block: Optional[Block], - device: Device) -> Block: + def allocate_mutable_block(self, prev_block: Optional[Block], + device: Device) -> Block: """Allocates a new mutable block on the specified device. Args: @@ -128,10 +128,31 @@ def allocate_mutable(self, prev_block: Optional[Block], Returns: Block: The newly allocated mutable block. """ - return self._allocators[device].allocate_mutable(prev_block) + return self._allocators[device].allocate_mutable_block(prev_block) - def allocate_immutable(self, prev_block: Optional[Block], - token_ids: List[int], device: Device) -> Block: + def allocate_immutable_blocks(self, prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Optional[Device]) -> List[Block]: + """Allocates a new group of immutable blocks with the provided block + token IDs on the specified device. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + Used for prefix hashing. + block_token_ids (List[int]): The list of block token IDs to be + stored in the new blocks. + device (Device): The device on which to allocate the new block. + + Returns: + List[Block]: The newly allocated list of immutable blocks + containing the provided block token IDs. 
+ """ + return self._allocators[device].allocate_immutable_blocks( + prev_block, block_token_ids) + + def allocate_immutable_block(self, prev_block: Optional[Block], + token_ids: List[int], + device: Device) -> Block: """Allocates a new immutable block with the provided token IDs on the specified device. @@ -146,7 +167,7 @@ def allocate_immutable(self, prev_block: Optional[Block], Block: The newly allocated immutable block containing the provided token IDs. """ - return self._allocators[device].allocate_immutable( + return self._allocators[device].allocate_immutable_block( prev_block, token_ids) def free(self, block: Block) -> None: @@ -161,7 +182,7 @@ def free(self, block: Block) -> None: block_id = block.block_id assert block_id is not None allocator = self._block_ids_to_allocator[block_id] - return allocator.free(block) + allocator.free(block) def fork(self, last_block: Block) -> List[Block]: """Creates a new sequence of blocks that shares the same underlying @@ -210,8 +231,8 @@ def get_physical_block_id(self, device: Device, absolute_id: int) -> int: """ return self._allocators[device].get_physical_block_id(absolute_id) - def swap(self, blocks: List[Block], source_device: Device, - dest_device: Device) -> Dict[int, int]: + def swap(self, blocks: List[Block], src_device: Device, + dst_device: Device) -> Dict[int, int]: """Execute the swap for the given blocks from source_device on to dest_device, save the current swap mapping and append them to the accumulated `self._swap_mapping` for each @@ -219,23 +240,23 @@ def swap(self, blocks: List[Block], source_device: Device, Args: blocks: List of blocks to be swapped. - source_device (Device): Device to swap the 'blocks' from. - dest_device (Device): Device to swap the 'blocks' to. + src_device (Device): Device to swap the 'blocks' from. + dst_device (Device): Device to swap the 'blocks' to. Returns: Dict[int, int]: Swap mapping from source_device on to dest_device. 
""" - source_block_ids = [block.block_id for block in blocks] - self._allocators[source_device].swap_out(blocks) - self._allocators[dest_device].swap_in(blocks) - dest_block_ids = [block.block_id for block in blocks] + src_block_ids = [block.block_id for block in blocks] + self._allocators[src_device].swap_out(blocks) + self._allocators[dst_device].swap_in(blocks) + dst_block_ids = [block.block_id for block in blocks] current_swap_mapping: Dict[int, int] = {} - for src, dest in zip(source_block_ids, dest_block_ids): - if src is not None and dest is not None: - self._swap_mapping[src] = dest - current_swap_mapping[src] = dest + for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids): + if src_block_id is not None and dst_block_id is not None: + self._swap_mapping[src_block_id] = dst_block_id + current_swap_mapping[src_block_id] = dst_block_id return current_swap_mapping def get_num_blocks_touched(self, @@ -283,23 +304,25 @@ def mark_blocks_as_computed(self, block_ids: List[int]) -> None: device = Device.GPU return self._allocators[device].mark_blocks_as_computed(block_ids) + def get_computed_block_ids(self, prev_computed_block_ids: List[int], + block_ids: List[int], + skip_last_block_id: bool) -> List[int]: + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].get_computed_block_ids( + prev_computed_block_ids, block_ids, skip_last_block_id) + def get_common_computed_block_ids( - self, seq_block_ids: List[List[int]]) -> List[int]: + self, computed_seq_block_ids: List[List[int]]) -> List[int]: # Prefix caching only supported on GPU. 
device = Device.GPU return self._allocators[device].get_common_computed_block_ids( - seq_block_ids) + computed_seq_block_ids) @property def all_block_ids(self) -> FrozenSet[int]: return frozenset(self._block_ids_to_allocator.keys()) - def promote_to_immutable_block(self, block: Block) -> BlockId: - raise NotImplementedError - - def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: - raise NotImplementedError - def get_and_reset_swaps(self) -> List[Tuple[int, int]]: """Returns and clears the mapping of source to destination block IDs. Will be called after every swapping operations for now, and after every @@ -341,6 +364,11 @@ def block_id(self, value: Optional[BlockId]): def token_ids(self) -> List[BlockId]: return self._proxy.token_ids + @property + def num_tokens_total(self) -> int: + raise NotImplementedError( + "num_tokens_total is not used for null block") + @property def num_empty_slots(self) -> BlockId: return self._proxy.num_empty_slots diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 4b20856a1b42d..ab39832bc1f6e 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -28,6 +28,13 @@ def block_id(self, value: Optional[int]) -> None: def token_ids(self) -> List[int]: pass + @property + @abstractmethod + def num_tokens_total(self) -> int: + """The number of tokens till the current block (inclusive) + """ + pass + @property @abstractmethod def num_empty_slots(self) -> int: @@ -92,12 +99,18 @@ def content_hash(self) -> Optional[int]: class BlockAllocator(ABC): @abstractmethod - def allocate_mutable(self, prev_block: Optional[Block]) -> Block: + def allocate_mutable_block(self, prev_block: Optional[Block]) -> Block: pass @abstractmethod - def allocate_immutable(self, prev_block: Optional[Block], - token_ids: List[int]) -> Block: + def allocate_immutable_block(self, prev_block: Optional[Block], + token_ids: List[int]) -> Block: + pass + + @abstractmethod + def allocate_immutable_blocks( 
+ self, prev_block: Optional[Block], + block_token_ids: List[List[int]]) -> List[Block]: pass @abstractmethod @@ -146,13 +159,19 @@ def mark_blocks_as_accessed(self, block_ids: List[int], def mark_blocks_as_computed(self, block_ids: List[int]) -> None: pass + @abstractmethod + def get_computed_block_ids(self, prev_computed_block_ids: List[int], + block_ids: List[int], + skip_last_block_id: bool) -> List[int]: + pass + @abstractmethod def get_common_computed_block_ids( - self, seq_block_ids: List[List[int]]) -> List[int]: + self, computed_seq_block_ids: List[List[int]]) -> List[int]: pass @abstractmethod - def cow_block_if_not_appendable(self, block: Block) -> Optional["BlockId"]: + def cow_block_if_not_appendable(self, block: Block) -> BlockId: """NOTE: This should not be used besides Block""" pass @@ -174,13 +193,20 @@ class NoFreeBlocksError(ValueError): class DeviceAwareBlockAllocator(ABC): @abstractmethod - def allocate_mutable(self, prev_block: Optional[Block], - device: Device) -> Block: + def allocate_mutable_block(self, prev_block: Optional[Block], + device: Device) -> Block: + pass + + @abstractmethod + def allocate_immutable_block(self, prev_block: Optional[Block], + token_ids: List[int], + device: Device) -> Block: pass @abstractmethod - def allocate_immutable(self, prev_block: Optional[Block], - token_ids: List[int], device: Device) -> Block: + def allocate_immutable_blocks(self, prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device) -> List[Block]: pass @abstractmethod @@ -217,9 +243,15 @@ def mark_blocks_as_accessed(self, block_ids: List[int], def mark_blocks_as_computed(self, block_ids: List[int]) -> None: pass + @abstractmethod + def get_computed_block_ids(self, prev_computed_block_ids: List[int], + block_ids: List[int], + skip_last_block_id: bool) -> List[int]: + pass + @abstractmethod def get_common_computed_block_ids( - self, seq_block_ids: List[List[int]]) -> List[int]: + self, computed_seq_block_ids: List[List[int]]) 
-> List[int]: pass @abstractmethod @@ -230,8 +262,8 @@ def get_num_blocks_touched(self, pass @abstractmethod - def swap(self, blocks: List[Block], source_device: Device, - dest_device: Device) -> Dict[int, int]: + def swap(self, blocks: List[Block], src_device: Device, + dst_device: Device) -> Dict[int, int]: pass @abstractmethod diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 50f27bab33776..0c1e883141716 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -1,6 +1,7 @@ -from typing import FrozenSet, Iterable, List, Optional, Set, Tuple +from collections import deque +from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple -from vllm.core.block.common import (CopyOnWriteTracker, RefCounter, +from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, get_all_blocks_recursively) from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device from vllm.utils import cdiv @@ -31,28 +32,39 @@ def __init__( num_blocks: int, block_size: int, block_ids: Optional[Iterable[int]] = None, + block_pool: Optional[BlockPool] = None, ): if block_ids is None: block_ids = range(num_blocks) - self._free_block_indices: Set[BlockId] = set(block_ids) + self._free_block_indices: Deque[BlockId] = deque(block_ids) self._all_block_indices = frozenset(block_ids) assert len(self._all_block_indices) == num_blocks self._refcounter = RefCounter( all_block_indices=self._free_block_indices) - self._create_block = create_block self._block_size = block_size self._cow_tracker = CopyOnWriteTracker( - refcounter=self._refcounter.as_readonly(), - allocator=self, - ) - - def allocate_immutable(self, - prev_block: Optional[Block], - token_ids: List[int], - device: Optional[Device] = None) -> Block: + refcounter=self._refcounter.as_readonly()) + + if block_pool is None: + extra_factor = 4 + # Pre-allocate "num_blocks * extra_factor" block objects. 
+ # The "* extra_factor" is a buffer to allow more block objects + # than physical blocks + self._block_pool = BlockPool(self._block_size, create_block, self, + num_blocks * extra_factor) + else: + # In this case, the block pool is provided by the caller, + # which means that there is most likely a need to share + # a block pool between allocators + self._block_pool = block_pool + + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Optional[Device] = None) -> Block: """Allocates a new immutable block with the given token IDs, linked to the previous block. @@ -66,13 +78,36 @@ def allocate_immutable(self, Block: The newly allocated immutable block. """ assert device is None - block = self.allocate_mutable(prev_block=prev_block) + block = self.allocate_mutable_block(prev_block=prev_block) block.append_token_ids(token_ids) return block - def allocate_mutable(self, - prev_block: Optional[Block], - device: Optional[Device] = None) -> Block: + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Optional[Device] = None) -> List[Block]: + assert device is None + num_blocks = len(block_token_ids) + + block_ids = [] + for i in range(num_blocks): + block_ids.append(self._allocate_block_id()) + + blocks = [] + for i in range(num_blocks): + prev_block = self._block_pool.init_block( + prev_block=prev_block, + token_ids=block_token_ids[i], + block_size=self._block_size, + physical_block_id=block_ids[i]) + blocks.append(prev_block) + + return blocks + + def allocate_mutable_block(self, + prev_block: Optional[Block], + device: Optional[Device] = None) -> Block: """Allocates a new mutable block, linked to the previous block. Args: @@ -84,20 +119,39 @@ def allocate_mutable(self, Block: The newly allocated mutable block. 
""" assert device is None - block_id = self._allocate_new_block_id() - return self._create_block( - prev_block=prev_block, - token_ids=[], - block_id=block_id, - block_size=self._block_size, - allocator=self, - ) - - def free(self, block: Block) -> None: - assert block.block_id is not None - self._free_block_id(block.block_id) + block_id = self._allocate_block_id() + block = self._block_pool.init_block(prev_block=prev_block, + token_ids=[], + block_size=self._block_size, + physical_block_id=block_id) + return block + + def _allocate_block_id(self) -> BlockId: + if not self._free_block_indices: + raise BlockAllocator.NoFreeBlocksError() + + block_id = self._free_block_indices.popleft() + self._refcounter.incr(block_id) + return block_id + + def _free_block_id(self, block: Block) -> None: + block_id = block.block_id + assert block_id is not None + + refcount = self._refcounter.decr(block_id) + if refcount == 0: + self._free_block_indices.appendleft(block_id) + block.block_id = None + def free(self, block: Block, keep_block_object: bool = False) -> None: + # Release the physical block id + self._free_block_id(block) + + # Release the block object + if not keep_block_object: + self._block_pool.free_block(block) + def fork(self, last_block: Block) -> List[Block]: """Creates a new sequence of blocks that shares the same underlying memory as the original sequence. 
@@ -120,14 +174,13 @@ def fork(self, last_block: Block) -> List[Block]: refcount = self._refcounter.incr(block.block_id) assert refcount != 1, "can't fork free'd block" - forked_blocks.append( - self._create_block( - prev_block=prev_block, - token_ids=block.token_ids, - block_id=block.block_id, - block_size=self._block_size, - allocator=self, - )) + forked_block = self._block_pool.init_block( + prev_block=prev_block, + token_ids=block.token_ids, + block_size=self._block_size, + physical_block_id=block.block_id) + + forked_blocks.append(forked_block) prev_block = forked_blocks[-1] return forked_blocks @@ -138,20 +191,6 @@ def get_num_free_blocks(self) -> int: def get_num_total_blocks(self) -> int: return len(self._all_block_indices) - def _allocate_new_block_id(self) -> BlockId: - if not self._free_block_indices: - raise BlockAllocator.NoFreeBlocksError() - - block_id = next(iter(self._free_block_indices)) - self._refcounter.incr(block_id) - self._free_block_indices.remove(block_id) - return block_id - - def _free_block_id(self, block_id: BlockId) -> None: - refcount = self._refcounter.decr(block_id) - if refcount == 0: - self._free_block_indices.add(block_id) - def get_physical_block_id(self, absolute_id: int) -> int: """Returns the zero-offset block id on certain block allocator given the absolute block id. @@ -173,7 +212,7 @@ def refcounter(self): def all_block_ids(self) -> FrozenSet[int]: return self._all_block_indices - def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: + def cow_block_if_not_appendable(self, block: Block) -> BlockId: """Performs a copy-on-write operation on the given block if it is not appendable. @@ -181,11 +220,22 @@ def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: block (Block): The block to check for copy-on-write. 
Returns: - Optional[BlockId]: The block index of the new block if a copy-on - -write operation was performed, or the original block index if + BlockId: The block index of the new block if a copy-on-write + operation was performed, or the original block index if no copy-on-write was necessary. """ - return self._cow_tracker.cow_block_if_not_appendable(block) + src_block_id = block.block_id + assert src_block_id is not None + + if self._cow_tracker.is_appendable(block): + return src_block_id + + self._free_block_id(block) + trg_block_id = self._allocate_block_id() + + self._cow_tracker.record_cow(src_block_id, trg_block_id) + + return trg_block_id def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: """Returns the copy-on-write source->destination mapping and clears it. @@ -213,8 +263,15 @@ def mark_blocks_as_computed(self, block_ids: List[int]) -> None: """ pass + def get_computed_block_ids(self, prev_computed_block_ids: List[int], + block_ids: List[int], + skip_last_block_id: bool) -> List[int]: + """No prefix caching here => return empty list + """ + return [] + def get_common_computed_block_ids( - self, seq_block_ids: List[List[int]]) -> List[int]: + self, computed_seq_block_ids: List[List[int]]) -> List[int]: """Determine blocks that can be skipped in prefill. Since the naive allocator does not support prefix caching, always return @@ -223,7 +280,7 @@ def get_common_computed_block_ids( return [] def promote_to_immutable_block(self, block: Block) -> BlockId: - raise NotImplementedError + raise NotImplementedError("There is no promotion for naive blocks") def get_num_blocks_touched(self, blocks: List[Block], @@ -263,17 +320,27 @@ def get_num_blocks_touched(self, def swap_out(self, blocks: List[Block]) -> None: for block in blocks: - self.free(block) + self._free_block_id(block) def swap_in(self, blocks: List[Block]) -> None: for block in blocks: + # Here we allocate either immutable or mutable block and then + # extract its block_id. 
Note that the block object is released + # and the block_id is assigned to "block" to allow reusing the + # existing "block" object if block.is_full: - alloc = self.allocate_immutable(block.prev_block, - block.token_ids) + tmp_block = self.allocate_immutable_block( + prev_block=block.prev_block, token_ids=block.token_ids) else: - alloc = self.allocate_mutable(block.prev_block) - alloc.append_token_ids(block.token_ids) - block.block_id = alloc.block_id + tmp_block = self.allocate_mutable_block( + prev_block=block.prev_block) + tmp_block.append_token_ids(block.token_ids) + + block_id = tmp_block.block_id + tmp_block.block_id = None + self._block_pool.free_block(tmp_block) + + block.block_id = block_id # Assign block_id class NaiveBlock(Block): @@ -315,11 +382,12 @@ def __init__(self, self._append_token_ids_no_cow(token_ids) def append_token_ids(self, token_ids: List[int]) -> None: - """Appends the given token IDs to the block, instructing the allocator - to perform a copy-on-write if necessary. + """Appends the given token IDs to the block and performs a + copy-on-write if necessary. Args: - token_ids (List[int]): The token IDs to be appended to the block. + token_ids (Optional[List[int]]): The token IDs to be appended + to the block. """ self._append_token_ids_no_cow(token_ids) @@ -328,7 +396,16 @@ def append_token_ids(self, token_ids: List[int]) -> None: self._cow_target)) def _append_token_ids_no_cow(self, token_ids: List[int]) -> None: - assert self.num_empty_slots >= len(token_ids) + """Appends the given token IDs to the block + + Args: + token_ids (List[int]): The token IDs to be appended to the block. 
+ """ + if len(token_ids) == 0: + return + + assert len(token_ids) <= self.num_empty_slots + self._token_ids.extend(token_ids) @property @@ -361,12 +438,17 @@ def is_full(self) -> bool: @property def num_empty_slots(self) -> int: - return self._block_size - len(self._token_ids) + return self._block_size - len(self.token_ids) @property def token_ids(self) -> List[int]: return self._token_ids + @property + def num_tokens_total(self) -> int: + raise NotImplementedError( + "num_tokens_total is not used for naive block") + @property def block_size(self) -> int: return self._block_size diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 2df7d74e4ff19..f272e23ee6088 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -1,13 +1,13 @@ """Token blocks.""" -from itertools import takewhile from os.path import commonprefix from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple from vllm.core.block.common import (CopyOnWriteTracker, get_all_blocks_recursively) from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device -from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator +from vllm.core.block.naive_block import (BlockPool, NaiveBlock, + NaiveBlockAllocator) from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor from vllm.utils import cdiv @@ -19,6 +19,30 @@ _DEFAULT_LAST_ACCESSED_TIME = -1 +class BlockTracker: + """Used to track the status of a block inside the prefix caching allocator + """ + __slots__ = ("active", "last_accessed", "computed") + + def reset(self): + self.last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME + self.computed: bool = False + + def __init__(self): + self.active: bool = False + self.reset() + + def enable(self): + assert not self.active + self.active = True + self.reset() + + def disable(self): + assert self.active + self.active = False + self.reset() + + class 
PrefixCachingBlockAllocator(BlockAllocator): """A block allocator that implements prefix caching. @@ -41,12 +65,26 @@ def __init__( block_ids: Optional[Iterable[int]] = None, eviction_policy: EvictionPolicy = EvictionPolicy.LRU, ): + if block_ids is None: + block_ids = range(num_blocks) + + self._block_size = block_size + # A mapping of prefix hash to block index. All blocks which have a # prefix hash will be in this dict, even if they have refcount 0. self._cached_blocks: Dict[PrefixHash, BlockId] = {} - # A mapping of blockId to Block to track those cached blocks - self._blocks: Dict[BlockId, Block] = {} + # Used to track status of each physical block id + self._block_tracker: Dict[BlockId, BlockTracker] = {} + for block_id in block_ids: + self._block_tracker[block_id] = BlockTracker() + + # Pre-allocate "num_blocks * extra_factor" block objects. + # The "* extra_factor" is a buffer to allow more block objects + # than physical blocks + extra_factor = 4 + self._block_pool = BlockPool(self._block_size, self._create_block, + self, num_blocks * extra_factor) # An allocator for blocks that do not have prefix hashes. self._hashless_allocator = NaiveBlockAllocator( @@ -54,10 +92,9 @@ def __init__( num_blocks=num_blocks, block_size=block_size, block_ids=block_ids, + block_pool=self._block_pool, # Share block pool here ) - self._block_size = block_size - # Evitor used to maintain how we want to handle those computed blocks # if we find memory pressure is high. self.evictor: Evictor = make_evictor(eviction_policy) @@ -68,9 +105,7 @@ def __init__( self._refcounter = self._hashless_allocator.refcounter self._cow_tracker = CopyOnWriteTracker( - refcounter=self._refcounter.as_readonly(), - allocator=self, - ) + refcounter=self._refcounter.as_readonly()) # Implements Block.Factory. 
def _create_block( @@ -90,14 +125,14 @@ def _create_block( token_ids=token_ids, block_size=block_size, block_id=block_id, - prefix_caching_allocator=allocator, + allocator=allocator, computed=computed, ) - def allocate_immutable(self, - prev_block: Optional[Block], - token_ids: List[int], - device: Optional[Device] = None) -> Block: + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Optional[Device] = None) -> Block: """Allocates an immutable block with the given token IDs, reusing cached blocks if possible. @@ -111,29 +146,41 @@ def allocate_immutable(self, assert device is None assert_prefix_caching_block_or_none(prev_block) - block = self._create_block( - prev_block=prev_block, - token_ids=token_ids, - block_size=self._block_size, - allocator=self, - ) + # First, try to create a block that points to cached data + block = self._block_pool.init_block(prev_block=prev_block, + token_ids=token_ids, + block_size=self._block_size, + physical_block_id=None) assert block.content_hash is not None cached_block_id = self._cached_blocks.get(block.content_hash, None) if cached_block_id is not None: block.block_id = cached_block_id - self._incr_refcount_cached_block(block, block.block_id) + self._incr_refcount_cached_block(block) return block + self._block_pool.free_block(block) - block = self.allocate_mutable(prev_block) + # No cached block => Allocate a new block + block = self.allocate_mutable_block(prev_block) block.append_token_ids(token_ids) - assert block.content_hash is not None - return block - def allocate_mutable(self, - prev_block: Optional[Block], - device: Optional[Device] = None) -> Block: + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Optional[Device] = None) -> List[Block]: + blocks = [] + for token_ids in block_token_ids: + prev_block = self.allocate_immutable_block(prev_block=prev_block, + token_ids=token_ids, + device=device) + 
blocks.append(prev_block) + return blocks + + def allocate_mutable_block(self, + prev_block: Optional[Block], + device: Optional[Device] = None) -> Block: """Allocates a mutable block. If there are no free blocks, this will evict unused cached blocks. @@ -147,116 +194,154 @@ def allocate_mutable(self, assert device is None assert_prefix_caching_block_or_none(prev_block) - try: - block = self._hashless_allocator.allocate_mutable( - prev_block=prev_block) - - assert block.block_id not in self._blocks - assert block.block_id is not None - self._blocks[block.block_id] = block - return block - except BlockAllocator.NoFreeBlocksError: - # We must check the unused cached blocks before raising OOM. - pass - - # If the evictor has blocks available for eviction, evict a block - # and return it. - if self.evictor.num_blocks > 0: - # here we get an evicted block, which is only added - # into evictor if its ref counter is 0 - # and since its content would be changed, we need - # to remove it from _cached_blocks's tracking list - block_id, content_hash_to_evict = self.evictor.evict() - - _block_id = self._cached_blocks[content_hash_to_evict] - assert self._refcounter.get(_block_id) == 0 - assert _block_id == block_id - - self._cached_blocks.pop(content_hash_to_evict) - - self._refcounter.incr(block_id) - - # Now this block is pop from evictor and ready to write - # with new content which most probably different with - # original content. So need to tell worker to recompute - # its kvcache - block = self._create_block( - prev_block=prev_block, - token_ids=[], - block_size=self._block_size, - allocator=self, - block_id=block_id, - computed=False, - ) - assert block.content_hash is None - - assert block.block_id not in self._blocks - assert block.block_id is not None - self._blocks[block.block_id] = block - return block - - # No block available in hashless allocator, nor in unused cache blocks. 
- raise BlockAllocator.NoFreeBlocksError() + block_id = self._allocate_block_id() + block = self._block_pool.init_block(prev_block=prev_block, + token_ids=[], + block_size=self._block_size, + physical_block_id=block_id) + assert not block.computed + assert block.content_hash is None + return block - def _incr_refcount_cached_block(self, block: Block, - block_id: BlockId) -> None: - # now _incr_refcount_cached_block comes from two place - # allocate_immutable/promote_to_immutable_block where hit - # _cached_blocks hash key. - # In both cases, it means that already exists a already - # computed block which shared with block now + def _incr_refcount_cached_block(self, block: Block) -> None: + # Set this block to be "computed" since it is pointing to a + # cached block id (which was already computed) block.computed = True + block_id = block.block_id + assert block_id is not None + refcount = self._refcounter.incr(block_id) if refcount == 1: - # if block get referred, then it shall not be in evictor - # and put it into _blocks for tracking + # In case a cached block was evicted, restore its tracking if block_id in self.evictor: self.evictor.remove(block_id) - self._blocks[block_id] = block - def free(self, block: Block) -> None: - """Decrement the refcount of the block. If the decremented refcount is - zero, store the block in the freelist. + self._track_block_id(block_id, computed=True) - If the block has a content hash (meaning it is immutable), then we will - keep the block around in case future allocations require it. 
- """ - assert (block.block_id - is not None), "freeing unallocated block is undefined" + def _decr_refcount_cached_block(self, block: Block) -> None: + # Ensure this is immutable/cached block + assert block.content_hash is not None + + block_id = block.block_id + assert block_id is not None + + refcount = self._refcounter.decr(block_id) + if refcount > 0: + block.block_id = None + return + else: + assert refcount == 0 - self._free_block_id_for_block(block.block_id, block) + # No longer used + assert block.content_hash in self._cached_blocks + + # Add the cached block to the evictor + # (This keeps the cached block around so it can be reused) + self.evictor.add(block_id, block.content_hash, block.num_tokens_total, + self._block_tracker[block_id].last_accessed) + + # Stop tracking the block + self._untrack_block_id(block_id) block.block_id = None - def _free_block_id_for_block(self, block_id: BlockId, - block: Block) -> None: - assert isinstance(block, PrefixCachingBlock) - - # if we comes from promote_to_immutable_block, it means that - # block.content_hash is never None. - # However we need to release the same content block, so that - # physical block could get reused. 
- if block.block_id != block_id or block.content_hash is None: - refcount = self._refcounter.get(block_id) - # We have fork case where block would get more than one ref, - # so we cannot free it from tracking if ref cnt large than 1 - assert block.block_id is not None - refcount = self._refcounter.get(block.block_id) - if refcount == 1: - del self._blocks[block.block_id] - - return self._hashless_allocator.free(block) + def _decr_refcount_hashless_block(self, block: Block) -> None: + block_id = block.block_id + assert block_id is not None - refcount = self._refcounter.decr(block_id) + # We may have a fork case where block is shared, + # in which case, we cannot remove it from tracking + refcount = self._refcounter.get(block_id) + if refcount == 1: + self._untrack_block_id(block_id) - # If no longer used, add the block to the evictor. - if refcount == 0: - assert block.content_hash in self._cached_blocks - assert block.block_id is not None - del self._blocks[block.block_id] - self.evictor.add(block.block_id, block.content_hash, - block.num_tokens_total, block.last_accessed) + # Decrement refcount of the block_id, but do not free the block object + # itself (will be handled by the caller) + self._hashless_allocator.free(block, keep_block_object=True) + + def _allocate_block_id(self) -> BlockId: + """First tries to allocate a block id from the hashless allocator, + and if there are no blocks, then tries to evict an unused cached block. + """ + hashless_block_id = self._maybe_allocate_hashless_block_id() + if hashless_block_id is not None: + return hashless_block_id + + evicted_block_id = self._maybe_allocate_evicted_block_id() + if evicted_block_id is not None: + return evicted_block_id + + # No block available in hashless allocator, nor in unused cache blocks. 
+ raise BlockAllocator.NoFreeBlocksError() + + def _maybe_allocate_hashless_block_id(self) -> Optional[BlockId]: + try: + # Allocate mutable block and extract its block_id + block = self._hashless_allocator.allocate_mutable_block( + prev_block=None) + block_id = block.block_id + self._block_pool.free_block(block) + + self._track_block_id(block_id, computed=False) + return block_id + except BlockAllocator.NoFreeBlocksError: + return None + + def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]: + if self.evictor.num_blocks == 0: + return None + + # Here we get an evicted block, which is only added + # into evictor if its ref counter is 0 + # and since its content would be changed, we need + # to remove it from _cached_blocks's tracking list + block_id, content_hash_to_evict = self.evictor.evict() + + # Sanity checks + assert content_hash_to_evict in self._cached_blocks + _block_id = self._cached_blocks[content_hash_to_evict] + assert self._refcounter.get(_block_id) == 0 + assert _block_id == block_id + + self._cached_blocks.pop(content_hash_to_evict) + + self._refcounter.incr(block_id) + self._track_block_id(block_id, computed=False) + + return block_id + + def _free_block_id(self, block: Block) -> None: + """Decrements the refcount of the block. The block may be in two + possible states: (1) immutable/cached or (2) mutable/hashless. + In the first case, the refcount is decremented directly and the block + may be possibly added to the evictor. In other case, hashless + allocator free(..) 
with keep_block_object=True is called to only free + the block id (since the block object may be reused by the caller) + """ + block_id = block.block_id + assert block_id is not None, "Freeing unallocated block is undefined" + + if block.content_hash is not None: + # Immutable: This type of block is always cached, and we want to + # keep it in the evictor for future reuse + self._decr_refcount_cached_block(block) + else: + # Mutable: This type of block is not cached, so we release it + # directly to the hashless allocator + self._decr_refcount_hashless_block(block) + + assert block.block_id is None + + def free(self, block: Block, keep_block_object: bool = False) -> None: + """Release the block (look at free_block_id(..) docs) + """ + # Release the physical block index + self._free_block_id(block) + + # Release the block object to the pool + if not keep_block_object: + self._block_pool.free_block(block) def fork(self, last_block: Block) -> List[Block]: """Creates a new sequence of blocks that shares the same underlying @@ -274,17 +359,20 @@ def fork(self, last_block: Block) -> List[Block]: forked_blocks: List[Block] = [] prev_block = None for block in source_blocks: - refcount = self._refcounter.incr(block.block_id) - assert refcount != 1, "can't fork free'd block" - - forked_blocks.append( - self._create_block( - prev_block=prev_block, - token_ids=block.token_ids, - block_id=block.block_id, - block_size=self._block_size, - allocator=self, - )) + block_id = block.block_id + assert block_id is not None + + refcount = self._refcounter.incr(block_id) + assert refcount != 1, "can't fork free'd block_id = {}".format( + block_id) + + forked_block = self._block_pool.init_block( + prev_block=prev_block, + token_ids=block.token_ids, + block_size=self._block_size, + physical_block_id=block_id) + + forked_blocks.append(forked_block) prev_block = forked_blocks[-1] return forked_blocks @@ -329,7 +417,7 @@ def promote_to_immutable_block(self, block: Block) -> BlockId: Note that 
if we already have a cached block with the same content, we will replace the newly-promoted block's mapping with the existing cached - block. + block id. Args: block: The mutable block to be promoted. @@ -338,23 +426,30 @@ def promote_to_immutable_block(self, block: Block) -> BlockId: BlockId: Either the original block index, or the block index of the previously cached block matching the same content. """ + # Ensure block can be promoted assert block.content_hash is not None assert block.block_id is not None assert self._refcounter.get(block.block_id) > 0 - # If the content hash does not have a corresponding cached block, - # set this block as the cached block. if block.content_hash not in self._cached_blocks: + # No cached content hash => Set this block as cached + # (Note that this block is not computed yet => + # Will be computed after free()) self._cached_blocks[block.content_hash] = block.block_id - else: - self._free_block_id_for_block( - self._cached_blocks[block.content_hash], block) - self._incr_refcount_cached_block( - block, self._cached_blocks[block.content_hash]) + return block.block_id - return self._cached_blocks[block.content_hash] + # Reuse the cached content hash + self._decr_refcount_hashless_block(block) + block.block_id = self._cached_blocks[block.content_hash] - def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: + # Increment refcount of the cached block and (possibly) restore + # it from the evictor. + # Note that in this case, the block is marked as computed + self._incr_refcount_cached_block(block) + + return block.block_id + + def cow_block_if_not_appendable(self, block: Block) -> BlockId: """Performs a copy-on-write operation on the given block if it is not appendable. @@ -362,11 +457,22 @@ def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: block (Block): The block to check for copy-on-write. 
Returns: - Optional[BlockId]: The block index of the new block if a copy-on - -write operation was performed, or the original block index if + BlockId: The block index of the new block if a copy-on-write + operation was performed, or the original block index if no copy-on-write was necessary. """ - return self._cow_tracker.cow_block_if_not_appendable(block) + src_block_id = block.block_id + assert src_block_id is not None + + if self._cow_tracker.is_appendable(block): + return src_block_id + + self._free_block_id(block) + trg_block_id = self._allocate_block_id() + + self._cow_tracker.record_cow(src_block_id, trg_block_id) + + return trg_block_id def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: """Returns the copy-on-write source->destination mapping and clears it. @@ -386,8 +492,8 @@ def mark_blocks_as_accessed(self, block_ids: List[int], """ for block_id in block_ids: - if block_id in self._blocks: - self._blocks[block_id].last_accessed = now + if self._block_tracker[block_id].active: + self._block_tracker[block_id].last_accessed = now elif block_id in self.evictor: self.evictor.update(block_id, now) else: @@ -395,25 +501,46 @@ def mark_blocks_as_accessed(self, block_ids: List[int], "Mark block as accessed which is not belonged to GPU") def mark_blocks_as_computed(self, block_ids: List[int]) -> None: - """Mark blocks as computed, used in prefix caching.""" + raise NotImplementedError("Marking as computed is incremental") - for block_id in block_ids: - if block_id in self._blocks: - # only those full block is valid for prefix caching - if self._blocks[block_id].is_full: - self._blocks[block_id].computed = True - elif block_id not in self.evictor: - raise ValueError(f"Mark {block_id=} as computed which " - "is not belonged to GPU") + def _track_block_id(self, block_id: Optional[BlockId], + computed: bool) -> None: + assert block_id is not None + self._block_tracker[block_id].enable() + self._block_tracker[block_id].computed = computed + + def 
_untrack_block_id(self, block_id: Optional[BlockId]) -> None: + assert block_id is not None + self._block_tracker[block_id].disable() def block_is_computed(self, block_id: int) -> bool: - if block_id in self._blocks: - return self._blocks[block_id].computed + if self._block_tracker[block_id].active: + return self._block_tracker[block_id].computed else: return block_id in self.evictor + def get_computed_block_ids(self, + prev_computed_block_ids: List[int], + block_ids: List[int], + skip_last_block_id: bool = True) -> List[int]: + prev_prefix_size = len(prev_computed_block_ids) + cur_size = len(block_ids) + if skip_last_block_id: + cur_size -= 1 + + # Sanity checks + assert cur_size >= 0 + assert prev_prefix_size <= cur_size + + ret = prev_computed_block_ids + for i in range(prev_prefix_size, cur_size): + block_id = block_ids[i] + if self.block_is_computed(block_id): + ret.append(block_id) + return ret + def get_common_computed_block_ids( - self, seq_block_ids: List[List[int]]) -> List[int]: + self, computed_seq_block_ids: List[List[int]]) -> List[int]: """Return the block ids that are common for a given sequence group. Only those blocks that are immutable and already be marked @@ -424,14 +551,9 @@ def get_common_computed_block_ids( # prompt is cached. This would cause erroneous behavior in model # runner. - ids_list = [ - list( - takewhile(lambda block_id: self.block_is_computed(block_id), - seq[:-1])) for seq in seq_block_ids - ] # It returns a list of int although type annotation says list of string. return commonprefix([ - ids for ids in ids_list # type: ignore + ids for ids in computed_seq_block_ids # type: ignore if ids != [] ]) @@ -473,10 +595,10 @@ def swap_out(self, blocks: List[Block]) -> None: blocks: List of blocks to be swapped out. """ for block in blocks: - self.free(block) + self._free_block_id(block) def swap_in(self, blocks: List[Block]) -> None: - """Execute the swap int actions. Change the block id from + """Execute the swap in actions. 
Change the block id from old allocator to current allocator for each block to finish the block table update. @@ -484,13 +606,22 @@ def swap_in(self, blocks: List[Block]) -> None: blocks: List of blocks to be swapped in. """ for block in blocks: + # Here we allocate either immutable or mutable block and then + # extract its block_id. Note that the block object is released + # and the block_id is assigned to "block" to allow reusing the + # existing "block" object if block.is_full: - alloc = self.allocate_immutable(block.prev_block, - block.token_ids) + tmp_block = self.allocate_immutable_block( + prev_block=block.prev_block, token_ids=block.token_ids) else: - alloc = self.allocate_mutable(block.prev_block) - alloc.append_token_ids(block.token_ids) - block.block_id = alloc.block_id + tmp_block = self.allocate_mutable_block( + prev_block=block.prev_block) + tmp_block.append_token_ids(block.token_ids) + + block_id = tmp_block.block_id + self._block_pool.free_block(tmp_block) + + block.block_id = block_id # Assign block_id class PrefixCachingBlock(Block): @@ -507,7 +638,7 @@ class PrefixCachingBlock(Block): token_ids (List[int]): The initial token IDs to be stored in the block. block_size (int): The maximum number of token IDs that can be stored in the block. - prefix_caching_allocator (BlockAllocator): The prefix + allocator (BlockAllocator): The prefix caching block allocator associated with this block. block_id (Optional[int], optional): The physical block index of this block. Defaults to None. 
@@ -518,31 +649,55 @@ def __init__( prev_block: Optional[Block], token_ids: List[int], block_size: int, - prefix_caching_allocator: BlockAllocator, + allocator: BlockAllocator, block_id: Optional[int] = None, computed: bool = False, ): - assert isinstance(prefix_caching_allocator, - PrefixCachingBlockAllocator), ( - "Currently this class is only tested with " - "PrefixCachingBlockAllocator.") + assert isinstance(allocator, PrefixCachingBlockAllocator), ( + "Currently this class is only tested with " + "PrefixCachingBlockAllocator. Got instead allocator = {}".format( + allocator)) assert_prefix_caching_block_or_none(prev_block) self._prev_block = prev_block self._cached_content_hash: Optional[int] = None - self._cached_num_tokens_total: Optional[int] = None - self._prefix_caching_allocator = prefix_caching_allocator + self._cached_num_tokens_total: int = 0 + self._allocator = allocator self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME self._computed = computed - self._block = NaiveBlock( - prev_block=prev_block, - token_ids=token_ids, - block_size=block_size, - block_id=block_id, - allocator=prefix_caching_allocator, - _cow_target=self, - ) + # On the first time, we create the block object, and next we only + # reinitialize it + if hasattr(self, "_block"): + self._block.__init__( # type: ignore[has-type] + prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + allocator=self._allocator) + else: + self._block = NaiveBlock(prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + allocator=self._allocator) + + self._update_num_tokens_total() + + def _update_num_tokens_total(self): + """Incrementally computes the number of tokens that there is + till the current block (included) + """ + res = 0 + + # Add all previous blocks + if self._prev_block is not None: + res += self._prev_block.num_tokens_total + + # Add current block + res += len(self.token_ids) + + self._cached_num_tokens_total = 
res @property def computed(self) -> bool: @@ -564,22 +719,28 @@ def append_token_ids(self, token_ids: List[int]) -> None: """Appends the given token IDs to the block and registers the block as immutable if the block becomes full. - Internally, the naive block handles CoW. - Args: token_ids (List[int]): The token IDs to be appended to the block. """ - assert token_ids + # Ensure this is mutable block (not promoted) + assert self.content_hash is None + assert not self.computed + + if len(token_ids) == 0: + return - # naive block handles CoW. + # Ensure there are input tokens + assert token_ids, "Got token_ids = {}".format(token_ids) + + # Naive block handles CoW. self._block.append_token_ids(token_ids) + self._update_num_tokens_total() # If the content hash is present, then the block can be made immutable. # Register ourselves with the allocator, potentially replacing the # physical block index. if self.content_hash is not None: - self.block_id = (self._prefix_caching_allocator. - promote_to_immutable_block(self)) + self.block_id = self._allocator.promote_to_immutable_block(self) @property def block_id(self) -> Optional[int]: @@ -599,23 +760,6 @@ def num_empty_slots(self) -> int: @property def num_tokens_total(self) -> int: - """return the total tokens so far. - - Here we iterate the block chain till to the first block, while - cache the result in local to prevent repeated computations. - """ - if self._cached_num_tokens_total is not None: - return self._cached_num_tokens_total - - _block: Optional[Block] = self - self._cached_num_tokens_total = 0 - - # TODO: current implement here take O(N^2), we expect future - # we have O(1) here - while _block is not None: - self._cached_num_tokens_total += len(_block.token_ids) - _block = _block.prev_block - return self._cached_num_tokens_total @property @@ -638,7 +782,6 @@ def content_hash(self) -> Optional[int]: For the content-based hash to be defined, the current block must be full. 
""" - # If the hash is already computed, return it. if self._cached_content_hash is not None: return self._cached_content_hash @@ -688,7 +831,129 @@ def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int], return hash((is_first_block, prev_block_hash, *cur_block_token_ids)) +class ComputedBlocksTracker: + """Handles caching of per-sequence computed block ids. + When a sequence appears for the first time, it traverses all of the + blocks and detects the prefix of blocks that is computed. On the + subsequent times, it only traverses the new blocks that were added + and updates the already recorded prefix of blocks with the newly + computed blocks. + + To avoid redundant traversals, the algorithm also detects when there + is a "gap" in the computed prefix. For example, if we have blocks = + [1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then + we won't try to add more computed blocks to [1,2,3] in this sequence + iteration, and will add more computed blocks only after the sequence is + freed and reused again. 
+ + Note that currently, for a given sequence, we also skip the last + block id for caching purposes, to avoid caching of a full sequence + """ + + def __init__(self, allocator): + self._allocator = allocator + self._cached_computed_seq_blocks: Dict[int, Tuple[List[int], + bool]] = {} + + def add_seq(self, seq_id: int) -> None: + """Start tracking seq_id + """ + assert seq_id not in self._cached_computed_seq_blocks + self._cached_computed_seq_blocks[seq_id] = ([], False) + + def remove_seq(self, seq_id: int) -> None: + """Stop tracking seq_id + """ + assert seq_id in self._cached_computed_seq_blocks + del self._cached_computed_seq_blocks[seq_id] + + def get_cached_computed_blocks_and_update( + self, seq_id: int, block_ids: List[int]) -> List[int]: + """ Look at the class documentation for details + """ + # Ensure seq_id is already tracked + assert seq_id in self._cached_computed_seq_blocks + + # Get cached data (may be empty on the first time) + prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[ + seq_id] + + if has_gap: + # When gap is detected, we do not add more computed blocks at this + # sequence iteration + return prev_computed_block_ids + + # We do not consider the last block id for caching purposes. + num_cur_blocks = len(block_ids) - 1 + assert num_cur_blocks >= 0 + + if len(prev_computed_block_ids) >= num_cur_blocks: + # Cache HIT + assert len(prev_computed_block_ids) == num_cur_blocks + return prev_computed_block_ids + + # If here, then we may possibly add more computed blocks. As a result, + # traverse the additional blocks after prev_computed_block_ids to + # detect more computed blocks and add them. 
+ + # Incremental init for seq_id => Look only at the new blocks + computed_block_ids = self._allocator.get_computed_block_ids( # noqa: E501 + prev_computed_block_ids, + block_ids, + skip_last_block_id= + True, # We skip last block id to avoid caching of full seq + ) + + # Detect if there is a "gap" + has_gap = len(computed_block_ids) < num_cur_blocks + + # Record + self._cached_computed_seq_blocks[seq_id] = (computed_block_ids, + has_gap) + + return computed_block_ids + + +class LastAccessBlocksTracker: + """Manages the last access time of the tracked sequences, in order to allow + an efficient update of allocator's block last access times + """ + + def __init__(self, allocator): + self._allocator = allocator + self._seq_last_access: Dict[int, Optional[float]] = {} + + def add_seq(self, seq_id: int) -> None: + """Start tracking seq_id + """ + assert seq_id not in self._seq_last_access + self._seq_last_access[seq_id] = None + + def remove_seq(self, seq_id: int) -> None: + """Stop tracking seq_id + """ + assert seq_id in self._seq_last_access + del self._seq_last_access[seq_id] + + def update_last_access(self, seq_id: int, time: float) -> None: + assert seq_id in self._seq_last_access + self._seq_last_access[seq_id] = time + + def update_seq_blocks_last_access(self, seq_id: int, + block_ids: List[int]) -> None: + assert seq_id in self._seq_last_access + + ts = self._seq_last_access[seq_id] + + if ts is None: + # No last access was recorded, no need to update. 
+ return + + self._allocator.mark_blocks_as_accessed(block_ids, ts) + + def assert_prefix_caching_block_or_none(block: Optional[Block]): if block is None: return - assert isinstance(block, PrefixCachingBlock) + assert isinstance(block, + PrefixCachingBlock), "Got block = {}".format(block) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 309775237a715..6a6eebc39c58e 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -7,6 +7,8 @@ from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator from vllm.core.block.interfaces import Block +from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, + LastAccessBlocksTracker) from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus @@ -100,6 +102,11 @@ def __init__( self.block_tables: Dict[SeqId, BlockTable] = {} self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} + self._computed_blocks_tracker = ComputedBlocksTracker( + self.block_allocator) + self._last_access_blocks_tracker = LastAccessBlocksTracker( + self.block_allocator) + def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. @@ -157,10 +164,18 @@ def allocate(self, seq_group: SequenceGroup) -> None: block_table: BlockTable = self._allocate_sequence(seq) self.block_tables[seq.seq_id] = block_table + # Track seq + self._computed_blocks_tracker.add_seq(seq.seq_id) + self._last_access_blocks_tracker.add_seq(seq.seq_id) + # Assign the block table for each sequence. 
for seq in waiting_seqs[1:]: self.block_tables[seq.seq_id] = block_table.fork() + # Track seq + self._computed_blocks_tracker.add_seq(seq.seq_id) + self._last_access_blocks_tracker.add_seq(seq.seq_id) + # Allocate cross-attention block table for encoder sequence # # NOTE: Here we assume that all sequences in the group have the same @@ -224,11 +239,23 @@ def append_slots( return new_cows def free(self, seq: Sequence) -> None: - if seq.seq_id not in self.block_tables: + seq_id = seq.seq_id + + if seq_id not in self.block_tables: # Already freed or haven't been scheduled yet. return - self.block_tables[seq.seq_id].free() - del self.block_tables[seq.seq_id] + + # Update seq block ids with the latest access time + self._last_access_blocks_tracker.update_seq_blocks_last_access( + seq_id, self.block_tables[seq.seq_id].physical_block_ids) + + # Untrack seq + self._last_access_blocks_tracker.remove_seq(seq_id) + self._computed_blocks_tracker.remove_seq(seq_id) + + # Free table/blocks + self.block_tables[seq_id].free() + del self.block_tables[seq_id] def free_cross(self, seq_group: SequenceGroup) -> None: request_id = seq_group.request_id @@ -239,9 +266,7 @@ def free_cross(self, seq_group: SequenceGroup) -> None: del self.cross_block_tables[request_id] def get_block_table(self, seq: Sequence) -> List[int]: - assert seq.seq_id in self.block_tables block_ids = self.block_tables[seq.seq_id].physical_block_ids - assert all(b is not None for b in block_ids) return block_ids # type: ignore def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: @@ -252,20 +277,14 @@ def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: return block_ids # type: ignore def access_all_blocks_in_seq(self, seq: Sequence, now: float): - # Update the last accessed time of all the blocks accessed - # in this step. 
- # And the accessed time is only useful for prefix caching now, - # as it support internal evictor policy for which cached - # block could be refilled, to keep cached content could be reused - # at max extend. if self.enable_caching: - block_table = self.block_tables[seq.seq_id] - block_ids: List[Optional[int]] = [] - for block_id in block_table.physical_block_ids: - block_ids.append(block_id) - self.block_allocator.mark_blocks_as_accessed( - block_ids, # type: ignore - now) + # Record the latest access time for the sequence. The actual update + # of the block ids is deferred to the sequence free(..) call, since + # only during freeing of block ids, the blocks are actually added to + # the evictor (which is when the most updated time is required) + # (This avoids expensive calls to mark_blocks_as_accessed(..)) + self._last_access_blocks_tracker.update_last_access( + seq.seq_id, now) def mark_blocks_as_computed(self, seq_group: SequenceGroup): # The only need for mark block as computed is for prefix caching, @@ -285,17 +304,26 @@ def get_common_computed_block_ids( This method determines which blocks can be safely skipped for all sequences in the sequence group. """ - seq_block_ids = [ - self.block_tables[seq.seq_id].physical_block_ids for seq in seqs - ] + computed_seq_block_ids = [] + for seq in seqs: + computed_seq_block_ids.append( + self._computed_blocks_tracker. + get_cached_computed_blocks_and_update( + seq.seq_id, + self.block_tables[seq.seq_id].physical_block_ids)) + # NOTE(sang): This assumes seq_block_ids doesn't contain any None. 
return self.block_allocator.get_common_computed_block_ids( - seq_block_ids) # type: ignore + computed_seq_block_ids) # type: ignore def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: src_block_table = self.block_tables[parent_seq.seq_id] self.block_tables[child_seq.seq_id] = src_block_table.fork() + # Track child seq + self._computed_blocks_tracker.add_seq(child_seq.seq_id) + self._last_access_blocks_tracker.add_seq(child_seq.seq_id) + def can_swap_in(self, seq_group: SequenceGroup, num_lookahead_slots: int) -> AllocStatus: """Returns the AllocStatus for the given sequence_group @@ -323,19 +351,31 @@ def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: List[Tuple[int, int]]: The mapping of swapping block from CPU to GPU. """ - blocks = self._get_blocks_for_swap(seq_group, SequenceStatus.SWAPPED) - current_swap_mapping = self.block_allocator.swap( - blocks=blocks, source_device=Device.CPU, dest_device=Device.GPU) - - block_number_mapping = { - self.block_allocator.get_physical_block_id(Device.CPU, - cpu_block_id): - self.block_allocator.get_physical_block_id(Device.GPU, - gpu_block_id) - for cpu_block_id, gpu_block_id in current_swap_mapping.items() - } - # convert to list of tuples once here - return list(block_number_mapping.items()) + physical_block_id_mapping = [] + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + blocks = self.block_tables[seq.seq_id].blocks + if len(blocks) == 0: + continue + + seq_swap_mapping = self.block_allocator.swap(blocks=blocks, + src_device=Device.CPU, + dst_device=Device.GPU) + + # Refresh the block ids of the table (post-swap) + self.block_tables[seq.seq_id].update(blocks) + + seq_physical_block_id_mapping = { + self.block_allocator.get_physical_block_id( + Device.CPU, cpu_block_id): + self.block_allocator.get_physical_block_id( + Device.GPU, gpu_block_id) + for cpu_block_id, gpu_block_id in seq_swap_mapping.items() + } + + physical_block_id_mapping.extend( + 
list(seq_physical_block_id_mapping.items())) + + return physical_block_id_mapping def can_swap_out(self, seq_group: SequenceGroup) -> bool: """Returns whether we can swap out the given sequence_group @@ -355,7 +395,7 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool: return True return False - def swap_out(self, sequence_group: SequenceGroup) -> List[Tuple[int, int]]: + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: """Returns the block id mapping (from GPU to CPU) generated by swapping out the given sequence_group with num_lookahead_slots. @@ -366,19 +406,31 @@ def swap_out(self, sequence_group: SequenceGroup) -> List[Tuple[int, int]]: List[Tuple[int, int]]: The mapping of swapping block from GPU to CPU. """ - blocks = self._get_blocks_for_swap(sequence_group, - SequenceStatus.RUNNING) - current_swap_mapping = self.block_allocator.swap( - blocks=blocks, source_device=Device.GPU, dest_device=Device.CPU) - block_number_mapping = { - self.block_allocator.get_physical_block_id(Device.GPU, - gpu_block_id): - self.block_allocator.get_physical_block_id(Device.CPU, - cpu_block_id) - for gpu_block_id, cpu_block_id in current_swap_mapping.items() - } - # convert to list of tuples once here - return list(block_number_mapping.items()) + physical_block_id_mapping = [] + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + blocks = self.block_tables[seq.seq_id].blocks + if len(blocks) == 0: + continue + + seq_swap_mapping = self.block_allocator.swap(blocks=blocks, + src_device=Device.GPU, + dst_device=Device.CPU) + + # Refresh the block ids of the table (post-swap) + self.block_tables[seq.seq_id].update(blocks) + + seq_physical_block_id_mapping = { + self.block_allocator.get_physical_block_id( + Device.GPU, gpu_block_id): + self.block_allocator.get_physical_block_id( + Device.CPU, cpu_block_id) + for gpu_block_id, cpu_block_id in seq_swap_mapping.items() + } + + physical_block_id_mapping.extend( + 
list(seq_physical_block_id_mapping.items())) + + return physical_block_id_mapping def get_num_free_gpu_blocks(self) -> int: return self.block_allocator.get_num_free_blocks(Device.GPU) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f7e38c0e6b948..30aae30c70c6b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -177,7 +177,8 @@ def __init__( "enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s)", + "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " + "enable_prefix_caching=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -204,6 +205,8 @@ def __init__( observability_config, model_config.seed, model_config.served_model_name, + scheduler_config.use_v2_block_manager, + cache_config.enable_prefix_caching, ) # TODO(woosuk): Print more configs in debug mode. diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 8741893c92716..1bd0956553884 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -345,7 +345,7 @@ def request_output_to_completion_response( out_logprobs = prompt_logprobs output_text = prompt_text elif request.echo and request.max_tokens > 0: - token_ids = prompt_token_ids + output.token_ids + token_ids = prompt_token_ids + list(output.token_ids) out_logprobs = (prompt_logprobs + output.logprobs if request.logprobs is not None else None) output_text = prompt_text + output.text diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index f95de56f39b57..ad5fb13176edc 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -427,8 +427,8 @@ def from_sampling_metadata( if seq_group.do_sample: for seq_id in seq_ids: seq_data = seq_group.seq_data[seq_id] - 
prompt_tokens.append(seq_data.prompt_token_ids) - output_tokens.append(seq_data.output_token_ids) + prompt_tokens.append(list(seq_data.prompt_token_ids)) + output_tokens.append(list(seq_data.output_token_ids)) sampling_tensors = SamplingTensors.from_lists( temperatures, top_ps, top_ks, min_ps, presence_penalties, diff --git a/vllm/outputs.py b/vllm/outputs.py index 49f526b5f9300..4cb7f06bdb8c7 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,6 +1,6 @@ import time from dataclasses import dataclass -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union from vllm.lora.request import LoRARequest from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, @@ -28,7 +28,7 @@ class CompletionOutput: index: int text: str - token_ids: List[int] + token_ids: Tuple[int, ...] cumulative_logprob: float logprobs: Optional[SampleLogprobs] finish_reason: Optional[str] = None diff --git a/vllm/sequence.py b/vllm/sequence.py index 22cb26dc08ef7..21c558d4483da 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -116,41 +116,66 @@ def __init__( prompt_token_ids: List[int], output_token_ids: Optional[List[int]] = None, ) -> None: - if output_token_ids is None: - output_token_ids = [] + self._prompt_token_ids: List[int] = list(prompt_token_ids) + self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids) + self._output_token_ids: List[int] = ( + list(output_token_ids) if output_token_ids is not None else []) - self.prompt_token_ids = prompt_token_ids - self._prompt_token_ids_tuple = tuple(prompt_token_ids) - self.output_token_ids = output_token_ids self.cumulative_logprob = 0.0 # The number of tokens that are computed (that run against the model). 
self._num_computed_tokens = 0 self._stage: SequenceStage = SequenceStage.PREFILL + self._update_cached_all_tokens() + + def _update_cached_all_tokens(self): + self._cached_all_token_ids: List[int] = (self._prompt_token_ids + + self._output_token_ids) + + @property + def prompt_token_ids(self) -> Tuple[int, ...]: + return self._prompt_token_ids_tuple + + @prompt_token_ids.setter + def prompt_token_ids(self, new_prompt_token_ids) -> None: + self._prompt_token_ids = list(new_prompt_token_ids) + self._prompt_token_ids_tuple = tuple(new_prompt_token_ids) + self._update_cached_all_tokens() + + @property + def output_token_ids(self) -> Tuple[int, ...]: + return tuple(self._output_token_ids) + + @output_token_ids.setter + def output_token_ids(self, new_output_token_ids) -> None: + self._output_token_ids = list(new_output_token_ids) + self._update_cached_all_tokens() + def append_token_id(self, token_id: int, logprob: float) -> None: - self.output_token_ids.append(token_id) + self._output_token_ids.append(token_id) + self._cached_all_token_ids.append(token_id) self.cumulative_logprob += logprob def get_len(self) -> int: - return len(self.output_token_ids) + len(self.prompt_token_ids) + return len(self._output_token_ids) + len(self._prompt_token_ids) def get_prompt_len(self) -> int: - return len(self.prompt_token_ids) + return len(self._prompt_token_ids) def get_output_len(self) -> int: - return len(self.output_token_ids) + return len(self._output_token_ids) def get_token_ids(self) -> List[int]: - return self.prompt_token_ids + self.output_token_ids + return self._cached_all_token_ids def get_prefix_token_ids( self, num_tokens: int ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: """Get prefix tokens, and make the return value hashable""" - prompt_length = len(self.prompt_token_ids) + prompt_length = self.get_prompt_len() if num_tokens > prompt_length: return (self._prompt_token_ids_tuple, - tuple(self.output_token_ids[:num_tokens - prompt_length])) + 
tuple(self._output_token_ids[:num_tokens - prompt_length])) else: return (self._prompt_token_ids_tuple[:num_tokens], None) @@ -183,14 +208,14 @@ def get_num_uncomputed_tokens(self) -> int: return self.get_len() - self.get_num_computed_tokens() def get_last_token_id(self) -> int: - if not self.output_token_ids: - return self.prompt_token_ids[-1] - return self.output_token_ids[-1] + if not self._output_token_ids: + return self._prompt_token_ids[-1] + return self._output_token_ids[-1] - def get_prompt_token_ids(self) -> List[int]: + def get_prompt_token_ids(self) -> Tuple[int, ...]: return self.prompt_token_ids - def get_output_token_ids(self) -> List[int]: + def get_output_token_ids(self) -> Tuple[int, ...]: return self.output_token_ids @property @@ -199,8 +224,8 @@ def stage(self) -> SequenceStage: def __repr__(self) -> str: return (f"SequenceData(" - f"prompt_token_ids={self.prompt_token_ids}, " - f"output_token_ids={self.output_token_ids}, " + f"prompt_token_ids={self._prompt_token_ids}, " + f"output_token_ids={self._output_token_ids}, " f"cumulative_logprob={self.cumulative_logprob})") @@ -306,14 +331,14 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self.data.get_token_ids() - def get_prompt_token_ids(self) -> List[int]: + def get_prompt_token_ids(self) -> Tuple[int, ...]: return self.data.get_prompt_token_ids() def get_last_token_id(self) -> int: return self.data.get_last_token_id() - def get_output_token_ids(self) -> List[int]: - return self.data.output_token_ids + def get_output_token_ids(self) -> Tuple[int, ...]: + return self.data.get_output_token_ids() def get_cumulative_logprob(self) -> float: return self.data.cumulative_logprob