[V1][Core] Autotune encoder cache budget #11895

Merged (15 commits) on Jan 15, 2025
15 changes: 10 additions & 5 deletions vllm/config.py
@@ -1383,13 +1383,15 @@ class SchedulerConfig:

is_multimodal_model: bool = False

# FIXME(woosuk & ywang96): Below are placeholder values. We need to
# calculate the actual values from the configurations.
# Multimodal encoder run compute budget, only used in V1
max_num_encoder_input_tokens = 16384
# NOTE: The following multimodal encoder budget will be initialized to
# max_num_batched_tokens and overridden in case max multimodal embedding
# size is larger.
# TODO (ywang96): Make these configurable.
# Multimodal encoder compute budget, only used in V1
max_num_encoder_input_tokens: int = field(default=None) # type: ignore

# Multimodal encoder cache size, only used in V1
encoder_cache_size = 16384
encoder_cache_size: int = field(default=None) # type: ignore

# Whether to perform preemption by swapping or
# recomputation. If not specified, we determine the mode as follows:
@@ -1463,6 +1465,9 @@ def __post_init__(self) -> None:
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
)

self.max_num_encoder_input_tokens = self.max_num_batched_tokens
self.encoder_cache_size = self.max_num_batched_tokens

if self.enable_chunked_prefill:
logger.info(
"Chunked prefill is enabled with max_num_batched_tokens=%d.",
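For reference, a minimal standalone sketch of the new initialization flow in config.py. The class below is a simplified stand-in, not the real SchedulerConfig, and the 16384 value is only illustrative:

from dataclasses import dataclass
from typing import Optional


@dataclass
class SchedulerConfigSketch:
    """Simplified stand-in showing only the encoder-budget fields."""

    max_num_batched_tokens: int = 8192
    # Both budgets start unset and are filled in during __post_init__,
    # mirroring the diff above.
    max_num_encoder_input_tokens: Optional[int] = None
    encoder_cache_size: Optional[int] = None

    def __post_init__(self) -> None:
        # Default both budgets to the batched-token budget; V1 components
        # may later raise them via compute_encoder_budget.
        self.max_num_encoder_input_tokens = self.max_num_batched_tokens
        self.encoder_cache_size = self.max_num_batched_tokens


cfg = SchedulerConfigSketch(max_num_batched_tokens=16384)
assert cfg.max_num_encoder_input_tokens == 16384
assert cfg.encoder_cache_size == 16384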
29 changes: 24 additions & 5 deletions vllm/multimodal/registry.py
@@ -252,11 +252,8 @@ def get_max_tokens_per_item_by_modality(
model_config: "ModelConfig",
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality
for profiling the memory usage of a model.

Note:
This is currently directly used only in V1.
Get the maximum number of tokens per data item from each modality based
on underlying model configuration.
"""
if self.has_processor(model_config):
tokenizer = cached_get_tokenizer(
@@ -272,6 +269,28 @@ def get_max_tokens_per_item_by_modality(
for key, plugin in self._plugins.items()
}

def get_max_tokens_per_item_by_nonzero_modality(
self,
model_config: "ModelConfig",
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
on underlying model configuration, excluding modalities that user
explicitly disabled via `limit_mm_per_prompt`.

Note:
This is currently directly used only in V1 for profiling the memory
usage of a model.
"""
limits_per_plugin = self._limits_by_model[model_config]

return {
key: max_tokens_per_mm_item
for key, max_tokens_per_mm_item in
self.get_max_tokens_per_item_by_modality(model_config).items()
if limits_per_plugin[key] > 0
}

def get_max_tokens_by_modality(
self,
model_config: "ModelConfig",
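A minimal sketch of the filtering rule the new method applies, using plain dicts in place of the registry and model config; the modality names and numbers are made up:

# Per-item token maxima reported by the registry (hypothetical values).
max_tokens_per_item = {"image": 2048, "video": 8192, "audio": 512}
# User-provided limits; a limit of 0 disables the modality entirely.
limit_mm_per_prompt = {"image": 4, "video": 0, "audio": 1}

# Keep only modalities the user has not explicitly disabled.
nonzero_modalities = {
    modality: max_tokens
    for modality, max_tokens in max_tokens_per_item.items()
    if limit_mm_per_prompt.get(modality, 1) > 0
}
assert nonzero_modalities == {"image": 2048, "audio": 512}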
78 changes: 77 additions & 1 deletion vllm/v1/core/encoder_cache_manager.py
@@ -1,7 +1,14 @@
from typing import Dict, List, Set, Tuple
from typing import TYPE_CHECKING, Dict, List, Set, Tuple

from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.v1.request import Request

if TYPE_CHECKING:
from vllm.config import ModelConfig, SchedulerConfig

logger = init_logger(__name__)


class EncoderCacheManager:

@@ -46,3 +53,72 @@ def get_freed_ids(self) -> List[Tuple[str, int]]:
freed = self.freed
self.freed = []
return freed


def compute_encoder_budget(
model_config: "ModelConfig",
scheduler_config: "SchedulerConfig",
) -> Tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations.

Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.

Returns:
- Compute budget for encoder execution, in unit of number of tokens
in the input sequence.
- Space budget for encoder cache size, in unit of number of tokens
in the input sequence.
"""

if not model_config.is_multimodal_model:
return 0, 0

# TODO: handle encoder-decoder models once we support them.
(
encoder_compute_budget,
encoder_cache_size,
) = _compute_encoder_budget_multimodal(model_config, scheduler_config)

return encoder_compute_budget, encoder_cache_size


def _compute_encoder_budget_multimodal(
model_config: "ModelConfig",
Review comment (Collaborator): QQ: Why do we need this separate function?

Reply (Member Author): Eventually we might need to use compute_encoder_budget for enc-dec models (and thus different encoder budget logic), so separating it out for now so that it's not tied to multimodal models.

scheduler_config: "SchedulerConfig",
) -> Tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations for a multimodal model.

Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.

Returns:
- Compute budget for encoder execution, in unit of number of tokens
in the input sequence.
- Space budget for encoder cache size, in unit of number of tokens
in the input sequence.
"""

max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501
model_config)

if not max_tokens_by_modality_dict:
logger.warning(
"All non-text modalities supported by the model have been "
"explicitly disabled via limit_mm_per_prompt. Encoder cache will "
"not be initialized.")
return 0, 0

_, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
key=lambda item: item[1])

encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens,
max_tokens_per_mm_item)
encoder_cache_size = max(scheduler_config.encoder_cache_size,
max_tokens_per_mm_item)

return encoder_compute_budget, encoder_cache_size
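To make the autotuning rule concrete, a small worked example of the max() logic above; all numbers are illustrative, not defaults taken from vLLM:

# Budgets initialized from SchedulerConfig (== max_num_batched_tokens).
max_num_encoder_input_tokens = 8192
encoder_cache_size = 8192
# Per-item token maxima for the enabled modalities (hypothetical values).
max_tokens_by_modality = {"image": 2048, "video": 16384}

# Largest possible single multimodal item, measured in encoder tokens.
_, max_tokens_per_mm_item = max(max_tokens_by_modality.items(),
                                key=lambda item: item[1])

# Raise both budgets so that at least one maximum-size item always fits.
encoder_compute_budget = max(max_num_encoder_input_tokens,
                             max_tokens_per_mm_item)
encoder_cache_size = max(encoder_cache_size, max_tokens_per_mm_item)

assert (encoder_compute_budget, encoder_cache_size) == (16384, 16384)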
26 changes: 18 additions & 8 deletions vllm/v1/core/scheduler.py
@@ -3,10 +3,11 @@
from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
Tuple, Union)

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
@@ -25,6 +26,7 @@ class Scheduler:
def __init__(
self,
scheduler_config: SchedulerConfig,
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
) -> None:
@@ -69,16 +71,24 @@ def __init__(
self.running_reqs_data: Dict[str, RunningRequestData] = {}

# Encoder-related.
# Calculate encoder cache size if applicable
# NOTE: For now we use the same budget for both compute and space.
# This can be changed when we make encoder cache for embedding caching
# across requests.
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config,
scheduler_config=scheduler_config,
)

# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if needed). Currently, we assume that the encoder also
# has the Transformer architecture (e.g., ViT).
self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens #noqa: E501
# NOTE(woosuk): For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized and used, regardless of
# the cache size. This is because the memory space for the encoder cache
# is preallocated in the profiling run.
self.max_num_encoder_input_tokens = encoder_compute_budget
# NOTE: For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized because cache size is 0
# for these models.
self.encoder_cache_manager = EncoderCacheManager(
cache_size=self.scheduler_config.encoder_cache_size)
cache_size=encoder_cache_size)

def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm:
9 changes: 6 additions & 3 deletions vllm/v1/engine/core.py
@@ -54,9 +54,12 @@ def __init__(
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

# Setup scheduler.
self.scheduler = Scheduler(vllm_config.scheduler_config,
vllm_config.cache_config,
vllm_config.lora_config)
self.scheduler = Scheduler(
scheduler_config=vllm_config.scheduler_config,
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
)

self.mm_input_mapper_server = MMInputMapperServer(
vllm_config.model_config)
60 changes: 32 additions & 28 deletions vllm/v1/worker/gpu_model_runner.py
@@ -20,6 +20,7 @@
is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
@@ -88,8 +89,12 @@ def __init__(
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
self.mm_input_mapper_profiling.use_cache = False

self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501
self.encoder_cache_size = self.scheduler_config.encoder_cache_size
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config,
scheduler_config=scheduler_config,
)
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size

# Lazy initialization
# self.model: nn.Module # Set after load_model
@@ -721,44 +726,30 @@ def profile_run(self) -> None:
]

# Profile with multimodal encoder & encoder cache.
if self.is_multimodal_model:

# Create dummy batch of multimodal inputs.
dummy_request_data = self.input_registry.dummy_data_for_profiling(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_registry=self.mm_registry,
)
dummy_mm_data = dummy_request_data.multi_modal_data
# TODO: handle encoder-decoder models once we support them.
if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
and self.encoder_cache_size > 0):

# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality( # noqa: E501
max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501
self.model_config)

dummy_data_modality, max_tokens_per_mm_item = max(
max_tokens_by_modality_dict.items(), key=lambda item: item[1])

# Check how many items of this modality can be supported by
# the encoder cache budget.
encoder_cache_budget = min(self.max_num_encoder_input_tokens,
self.encoder_cache_size)
max_num_mm_items_encoder_budget = encoder_cache_budget // \
max_tokens_per_mm_item

# TODO: Allow users to set encoder_cache_budget in case this
# happens.
assert max_num_mm_items_encoder_budget > 0, (
f"Encoder cache budget={encoder_cache_budget} is too small to "
f"support the maximum possible size of multimodal embeddings"
f"={max_tokens_per_mm_item}.")
# the encoder budget.
encoder_budget = min(self.max_num_encoder_input_tokens,
self.encoder_cache_size)

max_num_mm_items_encoder_budget = cdiv(encoder_budget,
max_tokens_per_mm_item)

# Check how many items of this modality can be supported by
# the decoder budget.
max_mm_items_per_req = max(
self.mm_registry.get_mm_limits_per_prompt(
self.model_config).values())
max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
self.model_config)[dummy_data_modality]

# NOTE: We do not consider max_num_batched_tokens on purpose
# because the multimodal embeddings can be generated in advance
Expand All @@ -769,6 +760,19 @@ def profile_run(self) -> None:
max_num_mm_items = min(max_num_mm_items_encoder_budget,
max_num_mm_items_decoder_budget)

logger.info(
"Encoder cache will be initialized with a budget of %s tokens,"
" and profiled with %s %s items of the maximum feature size.",
encoder_budget, max_num_mm_items, dummy_data_modality)

# Create dummy batch of multimodal inputs.
dummy_request_data = self.input_registry.dummy_data_for_profiling(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_registry=self.mm_registry,
)
dummy_mm_data = dummy_request_data.multi_modal_data
Review comment on lines +768 to +774 (Member Author, @ywang96, Jan 9, 2025): Note this is just a reordering for better readability.
# Dummy data definition in V0 may contain multiple multimodal items
# (e.g, multiple images) for a single request, therefore here we
# always replicate first item by max_num_mm_items times since in V1
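A standalone sketch of the profiling arithmetic above, with made-up numbers; cdiv (vLLM's ceiling division) is reproduced here with math.ceil, and max_num_reqs plus the per-prompt limit are assumptions for illustration:

from math import ceil

max_num_encoder_input_tokens = 16384   # encoder compute budget
encoder_cache_size = 16384             # encoder cache budget
max_tokens_per_mm_item = 2048          # largest single item across modalities
max_mm_items_per_req = 4               # limit_mm_per_prompt for that modality
max_num_reqs = 256                     # scheduler's max number of requests

# For now compute and cache budgets are the same, so take the tighter one.
encoder_budget = min(max_num_encoder_input_tokens, encoder_cache_size)

# How many maximum-size items the encoder budget can cover (rounded up so
# the profiling run never undershoots the budget).
max_num_mm_items_encoder_budget = ceil(encoder_budget / max_tokens_per_mm_item)

# The decoder side is bounded by per-request limits rather than
# max_num_batched_tokens, since embeddings can be generated ahead of decoding.
max_num_mm_items_decoder_budget = max_num_reqs * max_mm_items_per_req

max_num_mm_items = min(max_num_mm_items_encoder_budget,
                       max_num_mm_items_decoder_budget)
assert max_num_mm_items == 8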