
[V1][Core] Autotune encoder cache budget #11895

Merged · 15 commits · Jan 15, 2025
13 changes: 13 additions & 0 deletions vllm/config.py
@@ -1379,6 +1379,16 @@ class SchedulerConfig:

is_multimodal_model: bool = False

# NOTE: The following multimodal encoder budget will be initialized to
# max_num_batched_tokens and overridden in case max multimodal embedding
# size is larger.
# TODO (ywang96): Make these configurable.
# Multimodal encoder compute budget, only used in V1
max_num_encoder_input_tokens: int = field(default=None) # type: ignore

# Multimodal encoder cache size, only used in V1
encoder_cache_size: int = field(default=None) # type: ignore

# Whether to perform preemption by swapping or
# recomputation. If not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than
@@ -1451,6 +1461,9 @@ def __post_init__(self) -> None:
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
)

self.max_num_encoder_input_tokens = self.max_num_batched_tokens
self.encoder_cache_size = self.max_num_batched_tokens

if self.enable_chunked_prefill:
logger.info(
"Chunked prefill is enabled with max_num_batched_tokens=%d.",
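For illustration, here is a minimal, self-contained sketch (not part of this diff; the class name and default value are made up) of how the two new SchedulerConfig fields start out equal to max_num_batched_tokens in __post_init__ and can later be raised by the encoder budget computation:

from dataclasses import dataclass

@dataclass
class SchedulerConfigSketch:
    """Simplified stand-in for vllm.config.SchedulerConfig (illustrative only)."""
    max_num_batched_tokens: int = 8192

    # Multimodal encoder compute budget, only used in V1.
    max_num_encoder_input_tokens: int = None  # type: ignore
    # Multimodal encoder cache size, only used in V1.
    encoder_cache_size: int = None  # type: ignore

    def __post_init__(self) -> None:
        # Both budgets default to the decoder batch token budget; they are
        # overridden later if the largest multimodal item needs more tokens.
        self.max_num_encoder_input_tokens = self.max_num_batched_tokens
        self.encoder_cache_size = self.max_num_batched_tokens

cfg = SchedulerConfigSketch(max_num_batched_tokens=4096)
assert cfg.max_num_encoder_input_tokens == 4096
assert cfg.encoder_cache_size == 4096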
94 changes: 30 additions & 64 deletions vllm/v1/core/encoder_cache_manager.py
@@ -1,8 +1,7 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, List, Set, Tuple

from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import cdiv
from vllm.v1.request import Request

if TYPE_CHECKING:
@@ -56,10 +55,10 @@ def get_freed_ids(self) -> List[Tuple[str, int]]:
return freed


def compute_encoder_cache_budget(
def compute_encoder_budget(
model_config: "ModelConfig",
scheduler_config: "SchedulerConfig",
) -> int:
) -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations.

@@ -68,26 +67,28 @@ def compute_encoder_cache_budget(
scheduler_config: Scheduler configuration.

Returns:
The encoder cache budget, in unit of number of tokens
in the input sequence.
- Compute budget for encoder execution, in unit of number of tokens
in the input sequence.
- Space budget for encoder cache size, in unit of number of tokens
in the input sequence.
"""

encoder_cache_budget = 0

if not model_config.is_multimodal_model:
return encoder_cache_budget
return 0, 0

# TODO: handle encoder-decoder models once we support them.
encoder_cache_budget, _, _ = compute_encoder_cache_budget_multimodal(
model_config, scheduler_config)
(
encoder_compute_budget,
encoder_cache_size,
) = _compute_encoder_budget_multimodal(model_config, scheduler_config)

return encoder_cache_budget
return encoder_compute_budget, encoder_cache_size


def compute_encoder_cache_budget_multimodal(
def _compute_encoder_budget_multimodal(
model_config: "ModelConfig",
Collaborator: QQ: Why do we need this separate function?

Member Author: Eventually we might need to use compute_encoder_budget for enc-dec models (and thus different encoder budget logic), so separating it out for now so that it's not tied to multimodal models.

scheduler_config: "SchedulerConfig",
) -> tuple[int, Optional[str], int]:
) -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations for a multimodal model.

@@ -96,14 +97,12 @@ def compute_encoder_cache_budget_multimodal(
scheduler_config: Scheduler configuration.

Returns:
- The encoder cache budget, in unit of number of tokens in the
input sequence.
- The modality of the multimodal item that requires the most tokens.
- The number of multimodal items used to compute the encoder cache
budget.
- Compute budget for encoder execution, in unit of number of tokens
in the input sequence.
- Space budget for encoder cache size, in unit of number of tokens
in the input sequence.
"""

encoder_cache_budget = 0
max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501
model_config)

@@ -112,47 +111,14 @@
"All non-text modalities supported by the model have been "
"explicitly disabled via limit_mm_per_prompt. Encoder cache will "
"not be initialized.")
return encoder_cache_budget, None, 0

modality, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
key=lambda item: item[1])

max_num_batched_tokens = scheduler_config.max_num_batched_tokens
max_num_reqs = scheduler_config.max_num_seqs

# The biggest possible multimodal item cannot be fully prefilled in a
# batch, so every batch can partially prefill at most one of such item.
if max_tokens_per_mm_item > max_num_batched_tokens:
num_items = 1

# A batch can fully cover multiple biggest possible multimodal items, and
# one that will be partially prefilled.
else:
num_items = cdiv(max_num_batched_tokens, max_tokens_per_mm_item)

# NOTE: We need the encoder cache to be able to compute & hold ONE
# ADDITIONAL multimodal item, and is required only when:
# - Two requests in the current batch share the same prefix with such item
# as part of the prefix.
# - AND the prefix length is divisible by the block size, triggering the
# recomputation of the last block.
# - AND the part of the embeddings of the item is in this last block.

# This issue can be fundamentally resolved by supporting num_new_tokens=0
# on the model runner.
num_items += 1

# Number of items needed cannot be bigger than max number of running
# requests * max number of multimodal items per request.
max_mm_items_per_req = max(
MULTIMODAL_REGISTRY.get_mm_limits_per_prompt(model_config).values())

num_items = min(num_items, max_num_reqs * max_mm_items_per_req)
encoder_cache_budget = num_items * max_tokens_per_mm_item

logger.info(
"Encoder cache will be initialized with a budget of %s tokens,"
" and profiled with %s %s items of the maximum feature size.",
encoder_cache_budget, num_items, modality)

return encoder_cache_budget, modality, num_items
return 0, 0

_, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
key=lambda item: item[1])

encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens,
max_tokens_per_mm_item)
encoder_cache_size = max(scheduler_config.encoder_cache_size,
max_tokens_per_mm_item)

return encoder_compute_budget, encoder_cache_size
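Restated as a standalone helper (names are illustrative, not the vLLM API), the autotuning rule above simply lifts each budget to at least the token count of the single largest multimodal item, so that one such item can always be executed and cached in full:

def autotune_encoder_budget(
    max_num_encoder_input_tokens: int,
    encoder_cache_size: int,
    max_tokens_per_mm_item: int,
) -> tuple[int, int]:
    # Compute budget: at least one full item must be executable per step.
    compute_budget = max(max_num_encoder_input_tokens, max_tokens_per_mm_item)
    # Cache budget: at least one full item's embeddings must fit in the cache.
    cache_size = max(encoder_cache_size, max_tokens_per_mm_item)
    return compute_budget, cache_size

# The defaults (max_num_batched_tokens) are too small for the largest item,
# so both budgets are raised to fit it.
assert autotune_encoder_budget(8192, 8192, 16384) == (16384, 16384)
# The largest item already fits, so the defaults are kept.
assert autotune_encoder_budget(8192, 8192, 576) == (8192, 8192)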
12 changes: 7 additions & 5 deletions vllm/v1/core/scheduler.py
@@ -7,7 +7,7 @@
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_cache_budget)
compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.outputs import ModelRunnerOutput
@@ -74,18 +74,20 @@ def __init__(
# NOTE: For now we use the same budget for both compute and space.
# This can be changed when we make encoder cache for embedding caching
# across requests.
encoder_cache_budget = compute_encoder_cache_budget(
model_config, scheduler_config)
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config,
scheduler_config=scheduler_config,
)

# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if needed). Currently, we assume that the encoder also
# has the Transformer architecture (e.g., ViT).
self.max_num_encoder_input_tokens = encoder_cache_budget
self.max_num_encoder_input_tokens = encoder_compute_budget
# NOTE: For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized because cache size is 0
# for these models.
self.encoder_cache_manager = EncoderCacheManager(
cache_size=encoder_cache_budget)
cache_size=encoder_cache_size)

def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm:
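For context, a hedged sketch of how a scheduler-like class wires in the two returned budgets (simplified from the diff above; the compute budget caps per-step encoder tokens, the cache size is handed to the EncoderCacheManager, and both are currently equal):

from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                compute_encoder_budget)

class MiniScheduler:
    """Illustrative skeleton only; the real vllm.v1.core.scheduler.Scheduler
    does much more."""

    def __init__(self, model_config, scheduler_config):
        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
        )
        # Per-step cap on how many encoder input tokens may be executed.
        self.max_num_encoder_input_tokens = encoder_compute_budget
        # For text-only models the size is 0, so the cache stays uninitialized.
        self.encoder_cache_manager = EncoderCacheManager(
            cache_size=encoder_cache_size)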
50 changes: 41 additions & 9 deletions vllm/v1/worker/gpu_model_runner.py
@@ -20,8 +20,7 @@
is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import (
compute_encoder_cache_budget, compute_encoder_cache_budget_multimodal)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
@@ -90,8 +89,12 @@ def __init__(
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
self.mm_input_mapper_profiling.use_cache = False

self.encoder_cache_budget = compute_encoder_cache_budget(
self.model_config, self.scheduler_config)
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config,
scheduler_config=scheduler_config,
)
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size

# Lazy initialization
# self.model: nn.Module # Set after load_model
@@ -724,15 +727,44 @@ def profile_run(self) -> None:

# Profile with multimodal encoder & encoder cache.
# TODO: handle encoder-decoder models once we support them.
if self.is_multimodal_model and self.encoder_cache_budget > 0:
if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
and self.encoder_cache_size > 0):

# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
_, dummy_data_modality, max_num_mm_items = compute_encoder_cache_budget_multimodal( # noqa: E501
self.model_config,
self.scheduler_config,
)
max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501
self.model_config)
dummy_data_modality, max_tokens_per_mm_item = max(
max_tokens_by_modality_dict.items(), key=lambda item: item[1])

# Check how many items of this modality can be supported by
# the encoder budget.
encoder_cache_budget = min(self.max_num_encoder_input_tokens,
self.encoder_cache_size)

max_num_mm_items_encoder_budget = cdiv(encoder_cache_budget,
max_tokens_per_mm_item)

# Check how many items of this modality can be supported by
# the decoder budget.
max_mm_items_per_req = max(
self.mm_registry.get_mm_limits_per_prompt(
self.model_config).values())

# NOTE: We do not consider max_num_batched_tokens on purpose
# because the multimodal embeddings can be generated in advance
# and chunked prefilled.
max_num_mm_items_decoder_budget = self.max_num_reqs * \
max_mm_items_per_req

max_num_mm_items = min(max_num_mm_items_encoder_budget,
max_num_mm_items_decoder_budget)

logger.info(
"Encoder cache will be initialized with a budget of %s tokens,"
" and profiled with %s %s items of the maximum feature size.",
encoder_cache_budget, max_num_mm_items, dummy_data_modality)

# Create dummy batch of multimodal inputs.
dummy_request_data = self.input_registry.dummy_data_for_profiling(
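As a small numeric walk-through of the profiling math above (the helper below is illustrative, not vLLM code): the number of dummy multimodal items is capped by whichever is tighter, the encoder budgets or the decoder's request limits.

def cdiv(a: int, b: int) -> int:
    # Ceiling division, mirroring vllm.utils.cdiv.
    return -(-a // b)

def max_dummy_mm_items(max_num_encoder_input_tokens: int,
                       encoder_cache_size: int,
                       max_tokens_per_mm_item: int,
                       max_num_reqs: int,
                       max_mm_items_per_req: int) -> int:
    # Items that fit under the tighter of the two encoder budgets.
    encoder_budget = min(max_num_encoder_input_tokens, encoder_cache_size)
    max_items_encoder = cdiv(encoder_budget, max_tokens_per_mm_item)
    # Items the decoder side could ever reference at once; chunked prefill
    # means max_num_batched_tokens is deliberately not a constraint here.
    max_items_decoder = max_num_reqs * max_mm_items_per_req
    return min(max_items_encoder, max_items_decoder)

# Budgets of 16384 tokens, largest image = 2048 tokens, 128 requests with at
# most one image each: min(cdiv(16384, 2048), 128 * 1) = min(8, 128) = 8.
assert max_dummy_mm_items(16384, 16384, 2048, 128, 1) == 8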