[V1][WIP] 2nd try of Hybrid allocator for full attention & sliding window attention interleaved models #13296

Draft · wants to merge 56 commits into base: main
Conversation

heheda12345 (Collaborator) commented Feb 14, 2025

The goals of this PR are:

  1. For hybrid models with interleaved full attention & sliding window attention, only allocate blocks for the KV cache of full attention layers and for the KV cache inside the sliding window, while keeping prefix caching compatibility.
  2. For pure sliding window models, drop the KV cache outside the sliding window while keeping prefix caching compatibility.

High Level Idea of Hybrid Allocator

When the model becomes hybrid, we want each layer to have its own block_table, so that we can allocate a different number of blocks to different layers. For example, we allocate blocks for all tokens in full attention layers, but only for tokens inside the sliding window in sliding window layers.

However, allocating blocks layer by layer is inefficient, as we would need to perform the allocation num_layer times for every block_size tokens. So we introduce the KV Cache Group to group layers that can share the same block_table. We need to ensure the following properties:

  1. Layers in each group can share the same block_table; e.g., to avoid memory waste, a full attention layer and a sliding window layer should not share the same block_table.
  2. All groups have the same number of layers, so that the page size is uniform across all groups.

For example, mistralai/Ministral-8B-Instruct-2410 has 9 full attention layers and 27 sliding window layers. We can group the layers as:

  • Group 0: full.0, full.1, ..., full.8;
  • Group 1: sliding.0, sliding.1, ..., sliding.8;
  • Group 2: sliding.9, sliding.10, ..., sliding.17;
  • Group 3: sliding.18, sliding.19, ..., sliding.26.

Each group has 9 layers (property 2), and each group contains either only full attention layers or only sliding window layers (property 1). We can then allocate blocks per group and share the block_table among the layers in the same group.
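
As a minimal illustration of this grouping rule (the helper below is hypothetical and only mirrors the description above; it is not the actual vLLM implementation), the layers can be split by attention type and then chunked into equally sized groups:

# Hypothetical sketch of the grouping rule described above; not the actual
# vLLM implementation. Layers are split by attention type and chunked into
# groups whose size equals the number of full attention layers.
from typing import Dict, List

def group_layers(layer_types: Dict[str, str]) -> List[List[str]]:
    full = [name for name, t in layer_types.items() if t == "full"]
    sliding = [name for name, t in layer_types.items() if t == "sliding"]
    group_size = len(full)  # assumes a clean N:1 ratio, e.g. 27 sliding : 9 full
    groups = [full]
    for start in range(0, len(sliding), group_size):
        groups.append(sliding[start:start + group_size])
    return groups

# Ministral-8B-Instruct-2410: 9 full attention + 27 sliding window layers
layers = {f"full.{i}": "full" for i in range(9)}
layers.update({f"sliding.{i}": "sliding" for i in range(27)})
print([len(g) for g in group_layers(layers)])  # [9, 9, 9, 9]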

To achieve this, we add a group dimension to the block_table, changing it from List[int] to List[List[int]], and support block allocation & model execution with this group dimension. We also need an abstraction over the different layer types so that we can manage them in the same way.

For the layer abstraction, we introduce SpecializedManager with the following interfaces:

  • get_possible_cached_prefix to detect the cached prefix based on the hit rule of the layer;
  • get_num_new_blocks to get the number of new blocks that need to be allocated for the layer;
  • remove_useless_blocks to free the blocks that are no longer needed.

This PR implements this interface for full attention layers and sliding window layers. The abstraction also helps the current KVCacheManager support sliding window models.
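
A rough sketch of what such an interface could look like (the signatures below are illustrative assumptions, not the exact ones in the PR):

# Illustrative sketch of SpecializedManager; method signatures are assumptions.
from abc import ABC, abstractmethod
from typing import List, Optional

class SpecializedManager(ABC):
    """Per-layer-type policy used by the KV cache managers."""

    def __init__(self, block_size: int, sliding_window: Optional[int] = None):
        self.block_size = block_size
        self.sliding_window = sliding_window

    @abstractmethod
    def get_possible_cached_prefix(self, block_hashes: List[int]) -> List[int]:
        """Return the cached prefix blocks according to this layer type's hit rule."""

    @abstractmethod
    def get_num_new_blocks(self, num_computed_tokens: int, num_new_tokens: int,
                           num_allocated_blocks: int) -> int:
        """Number of new blocks to allocate for the next num_new_tokens tokens."""

    @abstractmethod
    def remove_useless_blocks(self, block_ids: List[int],
                              num_computed_tokens: int) -> List[int]:
        """Return blocks that are no longer needed, e.g. outside the sliding window."""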

For block allocation, we introduce HybridKVCacheManager with the same interface as KVCacheManager, but with support for multiple groups:

  • get_computed_blocks finds the longest prefix that is a cache hit for all groups;
  • allocate_slots sums up the number of new blocks needed by each group to check whether we can admit the given request, and then allocates blocks for each group;
  • free frees the blocks of each group, sorting them by eviction order and putting them back into the free_block_queue;
  • get_num_common_prefix_blocks is computed per group.

Note that the hybrid allocator still performs allocation / freeing / prefix caching at the block level, so we can reuse most operations on _block_pool, _free_block_queue, and _cached_block_hash_to_block. We move the operations on these objects from KVCacheManager into a new BlockPool class in v1/core/block_pool.py to avoid code duplication between the two managers.
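
For illustration, the per-group aggregation could look roughly like this (the attribute and method names used on BlockPool and on the request object are assumptions, not the PR's exact API):

# Illustrative sketch of HybridKVCacheManager; BlockPool and request attribute
# names are assumptions, not the PR's exact code.
from typing import List, Optional

class HybridKVCacheManager:
    def __init__(self, group_managers, block_pool):
        # One specialized manager per KV cache group, all sharing one BlockPool.
        self.group_managers = group_managers
        self.block_pool = block_pool

    def get_computed_blocks(self, block_hashes_per_group) -> List[List[int]]:
        # Each group reports its cached prefix; keep only the longest prefix
        # (in blocks) that is a hit for every group.
        prefixes = [m.get_possible_cached_prefix(h)
                    for m, h in zip(self.group_managers, block_hashes_per_group)]
        num_common = min(len(p) for p in prefixes)
        return [p[:num_common] for p in prefixes]

    def allocate_slots(self, request, num_new_tokens: int) -> Optional[List[List[int]]]:
        # Sum the per-group demand first, so we either admit the request for
        # all groups or for none of them.
        needed = [m.get_num_new_blocks(request.num_computed_tokens,
                                       num_new_tokens,
                                       len(request.block_ids[i]))
                  for i, m in enumerate(self.group_managers)]
        if sum(needed) > self.block_pool.get_num_free_blocks():
            return None  # not enough free blocks to schedule this request now
        return [self.block_pool.allocate(n) for n in needed]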

For model execution, we make the following changes:

  1. Introduce GroupedBlockTable to record the block_table on the device side. It contains multiple BlockTable instances, one per group. GroupedBlockTable has the same interface as BlockTable and broadcasts each operation to all BlockTable instances.
    # block_table.py
    class GroupedBlockTable:
        def __init__(self, ...):
            ...
            # Broadcast these methods to every per-group BlockTable.
            for f_name in ("move_row", "commit", "clear"):
                setattr(self, f_name, self._make_grouped_func(f_name))
            ...

        def _make_grouped_func(self, f_name: str) -> Callable[P, None]:
            # P is a ParamSpec; the returned wrapper forwards the call (with the
            # same arguments) to the BlockTable of every group.
            def grouped_func(*args: P.args, **kwargs: P.kwargs) -> None:
                for block_table in self.block_tables:
                    getattr(block_table, f_name)(*args, **kwargs)

            return grouped_func
  2. Change the attention metadata from one global object for all layers to a Dict[layer_name, AttentionMetadata], so that different layers can have different block_tables. The AttentionMetadata of each group is built once and shared by all layers in that group:
    # gpu_model_runner.py:
    for group_id, kv_cache_group in enumerate(self.kv_cache_config.groups):
        attn_metadata_of_group = FlashAttentionMetadata(...)
        for layer_name in kv_cache_group.layer_names:
            attn_metadata[layer_name] = attn_metadata_of_group
  3. The memory layout: instead of using num_layer separate tensors to store the KV cache at runtime, we now use num_layer / num_group tensors, each shared by one layer from every group. As different groups own different block_ids, the memory allocated for the groups does not overlap. For the Ministral model above, the layout becomes (a small sketch of this mapping follows the list):
    • Tensor.0: full.0, sliding.0, sliding.9, sliding.18;
    • Tensor.1: full.1, sliding.1, sliding.10, sliding.19;
    • ...
    • Tensor.8: full.8, sliding.8, sliding.17, sliding.26.
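
For the Ministral example, the layer-to-tensor mapping above can be reproduced with a few lines (a toy sketch, not the runner code): layer j of every group writes into physical tensor j, which is safe because the groups never share block ids.

# Toy sketch of the runtime memory layout described above (not the runner code).
groups = [
    [f"full.{i}" for i in range(9)],            # Group 0
    [f"sliding.{i}" for i in range(9)],         # Group 1
    [f"sliding.{i}" for i in range(9, 18)],     # Group 2
    [f"sliding.{i}" for i in range(18, 27)],    # Group 3
]
layer_to_tensor = {}
for group in groups:
    for tensor_id, layer_name in enumerate(group):
        # The j-th layer of every group shares Tensor.j.
        layer_to_tensor[layer_name] = tensor_id

assert layer_to_tensor["full.3"] == 3
assert layer_to_tensor["sliding.12"] == 3   # Tensor.3: full.3, sliding.3, sliding.12, sliding.21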

To avoid introducing overhead for uniform models, we keep the original KVCacheManager & BlockTable. When the model has only one group, we use KVCacheManager & BlockTable and pass a block_table of type List[int] between them in SchedulerOutput, the same as before. When the model has multiple groups, we use HybridKVCacheManager & GroupedBlockTable and pass a block_table of type List[List[int]]. Therefore, the type of block_table in NewRequestData and CachedRequestData becomes MayGroupedBlockIDs:

class GroupedBlockIDs:
    _block_ids: List[List[int]]

MayGroupedBlockIDs = Union[GroupedBlockIDs, List[int]]

Detailed Modifications

  • v1/core/specialized_manager.py abstracts the different logic of full attention layers and sliding window layers. Besides supporting the hybrid allocator, we also use this abstraction to support pure sliding window models in KVCacheManager (v1/core/kv_cache_manager.py).

  • [v1/core/kv_cache_manager.py, v1/core/hybrid_kv_cache_manager.py, v1/core/block_pool.py] Two KV cache managers with the same interface. Their common functionality lives in the BlockPool class in block_pool.py.

  • [v1/core/hybrid_kv_cache_manager.py] The HybridKVCacheManager; see the section above for its design.

  • [attention/layer.py & v1/worker/gpu_model_runner.py] Change the attention metadata to a Dict[layer_name, AttentionMetadata].

  • [v1/core/kv_cache_interface.py, v1/core/kv_cache_utils.py, v1/worker/gpu_model_runner.py]

    • Group the layers via is_kv_cache_page_size_uniform and _get_kv_cache_config_uniform_page_size, and tell the worker the KV cache memory layout with KVCacheConfig.tensors. For the Ministral example, it will be (a sketch of how a worker might materialize this follows the list):
    {
        "full.0": KVCacheNewTensor(size=xxx),
        "sliding.0": KVCacheReuseTensor(reused_layer_name="full.0"),
        "sliding.9": KVCacheReuseTensor(reused_layer_name="full.0"),
        "sliding.18": KVCacheReuseTensor(reused_layer_name="full.0"),
        ...
    }
    
  • [config.py] Small interface change to KVCacheSpec inside KVCacheConfig: moved from dict[layer_name, KVCacheSpec] to an attribute of each group (KVCacheGroup.kv_cache_spec).

  • [v1/core/kv_cache_utils.py]:

    • Add kv_cache_group_id to the block hash, so that we know which group a block belongs to.
    • Add notes on the null block with block_id = -1.
  • [v1/core/scheduler.py]

    • Change block_id from List[int] to MayGroupedBlockIDs.
    • When num_new_tokens == 0 and the hybrid manager is used, ignore all computed tokens as a temporary solution. (It would be easier to handle this corner case in HybridKVCacheManager than in the scheduler.)
  • [v1/worker/block_table.py, v1/worker/gpu_model_runner.py] Add GroupedBlockTable. Small interface changes to the BlockTable class to make the broadcast easier, e.g., remove the start_index argument of append_row.

    • Also move slot_mapping_cpu into BlockTable, as each group needs its own instance.
    • The number of block tables depends on the number of KV cache groups, so the InputBatch, which contains the BlockTable, is initialized in GPUModelRunner.__init__ instead of GPUModelRunner._initialize_kv_caches.
  • [forward_context.py] Once AttentionMetadata becomes a per-layer dict, it is difficult to access global information shared by all layers, e.g., num_input_tokens. Therefore, add ForwardMetadata to hold this global information.

  • [v1/request.py] cherry-pick [V1] Move KV block hashes from Request to KVCacheManager #12922

  • [v1/worker/gpu_model_runner.py] Remove model-related args like num_attn_layers, num_query_heads, num_kv_heads. They are available in the KV cache spec.
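
Referring back to the KVCacheConfig.tensors example above, a worker could materialize that layout roughly as follows. This is only a sketch: the dataclass fields (size in bytes, reused_layer_name) mirror the example above, and the init_kv_buffers helper is hypothetical, not the PR's actual allocation code.

# Sketch of resolving KVCacheConfig.tensors into shared buffers; the dataclass
# fields and helper below mirror the example above and are assumptions.
from dataclasses import dataclass
from typing import Dict, Union
import torch

@dataclass
class KVCacheNewTensor:
    size: int  # bytes

@dataclass
class KVCacheReuseTensor:
    reused_layer_name: str

def init_kv_buffers(tensors: Dict[str, Union[KVCacheNewTensor, KVCacheReuseTensor]],
                    device: str = "cuda") -> Dict[str, torch.Tensor]:
    buffers: Dict[str, torch.Tensor] = {}
    # First pass: allocate one fresh buffer per KVCacheNewTensor.
    for layer_name, spec in tensors.items():
        if isinstance(spec, KVCacheNewTensor):
            buffers[layer_name] = torch.empty(spec.size, dtype=torch.int8, device=device)
    # Second pass: reused layers alias the same underlying storage.
    for layer_name, spec in tensors.items():
        if isinstance(spec, KVCacheReuseTensor):
            buffers[layer_name] = buffers[spec.reused_layer_name]
    return buffers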

RFC #11382


hmellor (Collaborator) commented Feb 18, 2025

Is this PR also built on top of #12086?

heheda12345 (Collaborator, Author) replied:

> Is this PR also built on top of #12086?

Yes, this PR includes the commits from #12086.

WoosukKwon (Collaborator) commented Feb 25, 2025

Hi @heheda12345, thanks for the great work and sorry again for the delays in my review. This work is truly amazing!

Now I think I (almost) fully understand the idea. Most changes proposed in this PR look very reasonable and seem to be the "right" thing to do. The things I'm unsure about are:

  1. How to refine some concepts here. For example, "group" and "page" in this PR are a bit confusing.
  2. How to minimize the performance overheads, for both non-SWA models and hybrid models.
  3. How much we should/could overfit to the N:1 hybrid models.

Let me write down my understanding here:

To my understanding, this PR starts from the observation that if a certain pattern of layers repeats, we can use that symmetry to simplify the memory view. For instance, if a single type of attention repeats across all layers, we can view the model as if it only has a single layer. As another example, for models with an N:1 mix of sliding & global attention layers, we can treat the model as if it only has N+1 layers. This is essentially the concept of "group" in this PR.

[Confusion 1] However, these N+1 layers are NOT a "group". IIRC, there's no name for this set of layers. I feel like these N+1 layers should be defined as a "group"...

Then, the N+1 layers need to dynamically share a fixed amount of memory space. The good news is that, for a sliding & global attention mix, both types of attention require the same amount of memory per token (e.g., 2 * num_kv_heads * head_size) and use the same block size (i.e., number of tokens per "block"). Thanks to this, we can easily avoid memory fragmentation issues.

Let's say L0 uses global attention while L1, ..., LN use sliding window attention. Then, in this PR, we create a cache manager for each layer; N+1 managers in total. These managers are controlled by HybridKVCacheManager.

[Proposal 1] While this may contradict my previous suggestion, I now feel it's nicer to name this AttnCacheManager or MemoryManager and use it for all cases (i.e., no special case for global-attention-only models).

[Question 1] I understand that we need N+1 block tables. However, do we really need N+1 memory managers? I feel like we could have 2 managers, one for global attention and another for SWA, and have the SWA memory manager handle the blocks for all N layers. This could potentially reduce the overheads.

For allocation, freeing, and prefix caching, HybridKVCacheManager invokes the N+1 layer-wise memory managers and aggregates their outputs. Each memory manager gets/returns blocks from/to BlockPool. This process is pretty straightforward for global attention. For sliding window attention, however, it is quite complicated.

[Confusion 2] I don't fully understand how we manage the memory for sliding window attention (apologies if I missed something). Especially, I find it difficult to understand how the prefix caching works when it's mixed with global attention. I think we should be clearer about this.

[Question 2] Specifically, we need to provide clear answers to the following questions:

  1. What's the shape of the block table for SWA? Is it append-only?
  2. Let's say the window size is 1K and the input prompt length is 2K. How many slots do we allocate? Do we always allocate at most 1K slots? Or do we allocate 2K slots initially and drop the first 1K after the first iteration?
  3. How does it handle sliding window when block size > 1? What assumptions does it make on the attention kernel?
  4. How does it handle prefix caching?
  5. How does it handle cascade attention?

The changes in the model executor seem relatively easy to understand. We only need to do some extra work because we now have N+1 block tables instead of one.

[Proposal 2] Do we really need GroupedBlockTable (btw here the term group is confusing again 😅)? Can we just use a single tensor of shape [N+1, num_blocks, max_model_len // block_size]?
