diff --git a/vllm/block.py b/vllm/block.py index 2cc6b947f2255..e7fb29c8c2c61 100644 --- a/vllm/block.py +++ b/vllm/block.py @@ -1,5 +1,7 @@ """Token blocks.""" -from typing import List +import weakref +from collections import defaultdict +from typing import Dict, List from vllm.utils import Device @@ -7,6 +9,35 @@ DEFAULT_LAST_ACCESSED_TIME = -1 +TokensBlock = List[int] + + +class BlockPool: + """A pool of physical blocks. + When requests come, we create a lot of logical blocks; + when requests are done, we destroy a lot of logical blocks. + It turns out that creating and destroying logical blocks can be expensive, + especially for the `token_ids` field, which is a list of integers. + To avoid this overhead, we use a pool to manage the logical blocks. + When an old request is done and a new request comes, we can reuse the + logical blocks from the old request to feed the new request. + """ + + def __init__(self) -> None: + # block size to list of token blocks + self.pool: Dict[int, List[TokensBlock]] = defaultdict(list) + + def alloc_block(self, block_size: int) -> TokensBlock: + if block_size in self.pool and self.pool[block_size]: + return self.pool[block_size].pop() + return [_BLANK_TOKEN_ID] * block_size + + def del_block(self, block: TokensBlock) -> None: + self.pool[len(block)].append(block) + + +_BLOCK_POOL = BlockPool() + class LogicalTokenBlock: """A block that stores a contiguous chunk of tokens from left to right. @@ -23,7 +54,13 @@ def __init__( self.block_number = block_number self.block_size = block_size - self.token_ids = [_BLANK_TOKEN_ID] * block_size + self.token_ids = _BLOCK_POOL.alloc_block(block_size) + # this finalizer is used to return the block to the pool when the object is deleted # noqa + # NOTE: don't use __del__ because it cannot guarantee the order of finalization, # noqa + # i.e. `self.token_ids` may be deleted before `self`, and we lose + # the opportunity to return the block to the pool + self._finalizer = weakref.finalize(self, _BLOCK_POOL.del_block, + self.token_ids) self.num_tokens = 0 def is_empty(self) -> bool: