From 495f669f8154237f0a89ac6325928f38e392df54 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Thu, 9 Jan 2025 12:19:12 +0000 Subject: [PATCH 01/13] initial Signed-off-by: Roger Wang --- vllm/config.py | 8 ---- vllm/v1/core/encoder_cache_manager.py | 64 ++++++++++++++++++++++++++- vllm/v1/core/scheduler.py | 24 ++++++---- vllm/v1/engine/core.py | 9 ++-- vllm/v1/worker/gpu_model_runner.py | 38 ++++------------ 5 files changed, 93 insertions(+), 50 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 19609085cc960..883cad05a0323 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1379,14 +1379,6 @@ class SchedulerConfig: is_multimodal_model: bool = False - # FIXME(woosuk & ywang96): Below are placeholder values. We need to - # calculate the actual values from the configurations. - # Multimodal encoder run compute budget, only used in V1 - max_num_encoder_input_tokens = 16384 - - # Multimodal encoder cache size, only used in V1 - encoder_cache_size = 16384 - # Whether to perform preemption by swapping or # recomputation. If not specified, we determine the mode as follows: # We use recomputation by default since it incurs lower overhead than diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 845bd5ea05e3c..1ddc179d8bf96 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -1,7 +1,15 @@ -from typing import Dict, List, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Set, Tuple +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.utils import cdiv from vllm.v1.request import Request +if TYPE_CHECKING: + from vllm.config import ModelConfig, SchedulerConfig + +logger = init_logger(__name__) + class EncoderCacheManager: @@ -46,3 +54,57 @@ def get_freed_ids(self) -> List[Tuple[str, int]]: freed = self.freed self.freed = [] return freed + + +def compute_encoder_cache_budget( + model_config: "ModelConfig", + scheduler_config: "SchedulerConfig", +) -> int: + """Compute the encoder cache budget based on the model and scheduler configurations.""" + + encoder_cache_budget = 0 + if not model_config.is_multimodal_model: + return encoder_cache_budget + + max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_modality( # noqa: E501 + model_config) + + modality, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(), + key=lambda item: item[1]) + + max_num_batched_tokens = scheduler_config.max_num_batched_tokens + max_num_seqs = scheduler_config.max_num_seqs + + # In case that the biggest possible multimodal item takes space more + # than the batch size, then it needs to be cached and chunk prefilled. + if max_tokens_per_mm_item > max_num_batched_tokens: + num_items = 1 + + # In case that the biggest possible multimodal item takes space less + # the batch size, then all items will be full prefilled except one. + else: + num_items = cdiv(max_num_batched_tokens, max_tokens_per_mm_item) + + # NOTE: We need the encoder cache to be able to compute & hold ONE + # ADDITIONAL multimodal item, and is required only when: + # - Two requests in the current batch share the same prefix with such item + # as part of the prefix. + # - AND the prefix length is divisible by the block size, triggering the + # recomputation of the last block. + # - AND the part of the embeddings of the item is in this last block. + + # This can be improved when we have a global encoder cache that does + # not associate items to request id only. + num_items += 1 + + # Number of items needed cannot be bigger than max number of running + # sequences. + num_items = min(num_items, max_num_seqs) + + encoder_cache_budget = num_items * max_tokens_per_mm_item + logger.info( + "Encoder cache will be initialized with a budget of %s tokens, and " + "profiled with %s %s items of the maximum feature size.", encoder_cache_budget, num_items, + modality) + + return encoder_cache_budget diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index b26716f5c02e6..51ca9aa924404 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -3,10 +3,11 @@ from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set, Tuple, Union) -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.v1.core.encoder_cache_manager import EncoderCacheManager +from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, + compute_encoder_cache_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.engine import EngineCoreOutput from vllm.v1.outputs import ModelRunnerOutput @@ -24,6 +25,7 @@ class Scheduler: def __init__( self, scheduler_config: SchedulerConfig, + model_config: ModelConfig, cache_config: CacheConfig, lora_config: Optional[LoRAConfig], ) -> None: @@ -68,16 +70,22 @@ def __init__( self.running_reqs_data: Dict[str, RunningRequestData] = {} # Encoder-related. + # Calculate encoder cache size if applicable + # NOTE: For now we use the same budget for both compute and space. + # This can be changed when we make encoder cache for embedding caching + # across requests. + encoder_cache_budget = compute_encoder_cache_budget( + model_config, scheduler_config) + # NOTE(woosuk): Here, "encoder" includes the vision encoder (and # projector if needed). Currently, we assume that the encoder also # has the Transformer architecture (e.g., ViT). - self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens #noqa: E501 - # NOTE(woosuk): For the models without encoder (e.g., text-only models), - # the encoder cache will not be initialized and used, regardless of - # the cache size. This is because the memory space for the encoder cache - # is preallocated in the profiling run. + self.max_num_encoder_input_tokens = encoder_cache_budget + # NOTE: For the models without encoder (e.g., text-only models), + # the encoder cache will not be initialized because cache size is 0 + # for these models. self.encoder_cache_manager = EncoderCacheManager( - cache_size=self.scheduler_config.encoder_cache_size) + cache_size=encoder_cache_budget) def schedule(self) -> "SchedulerOutput": # NOTE(woosuk) on the scheduling algorithm: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 975ce11fe8aff..a002e95b6d0ce 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -58,9 +58,12 @@ def __init__( vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks # Setup scheduler. - self.scheduler = Scheduler(vllm_config.scheduler_config, - vllm_config.cache_config, - vllm_config.lora_config) + self.scheduler = Scheduler( + scheduler_config=vllm_config.scheduler_config, + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + lora_config=vllm_config.lora_config, + ) self._last_logging_time = time.time() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a1d4f9b135789..5ee81fe3b3d4c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -19,6 +19,7 @@ LayerBlockType, cdiv, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) +from vllm.v1.core.encoder_cache_manager import compute_encoder_cache_budget from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata @@ -87,8 +88,8 @@ def __init__( self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config) self.mm_input_mapper_profiling.use_cache = False - self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501 - self.encoder_cache_size = self.scheduler_config.encoder_cache_size + self.encoder_cache_budget = compute_encoder_cache_budget( + self.model_config, self.scheduler_config) # Lazy initialization # self.model: nn.Module # Set after load_model @@ -721,6 +722,10 @@ def profile_run(self) -> None: # Profile with multimodal encoder & encoder cache. if self.is_multimodal_model: + + # Encoder cache budget should be set to the model and scheduler + # configurations accordingly. + assert self.encoder_cache_budget > 0 # Create dummy batch of multimodal inputs. dummy_request_data = self.input_registry.dummy_data_for_profiling( @@ -739,34 +744,7 @@ def profile_run(self) -> None: dummy_data_modality, max_tokens_per_mm_item = max( max_tokens_by_modality_dict.items(), key=lambda item: item[1]) - # Check how many items of this modality can be supported by - # the encoder cache budget. - encoder_cache_budget = min(self.max_num_encoder_input_tokens, - self.encoder_cache_size) - max_num_mm_items_encoder_budget = encoder_cache_budget // \ - max_tokens_per_mm_item - - # TODO: Allow users to set encoder_cache_budget in case this - # happens. - assert max_num_mm_items_encoder_budget > 0, ( - f"Encoder cache budget={encoder_cache_budget} is too small to " - f"support the maximum possible size of multimodal embeddings" - f"={max_tokens_per_mm_item}.") - - # Check how many items of this modality can be supported by - # the decoder budget. - max_mm_items_per_req = max( - self.mm_registry.get_mm_limits_per_prompt( - self.model_config).values()) - - # NOTE: We do not consider max_num_batched_tokens on purpose - # because the multimodal embeddings can be generated in advance - # and chunked prefilled. - max_num_mm_items_decoder_budget = self.max_num_reqs * \ - max_mm_items_per_req - - max_num_mm_items = min(max_num_mm_items_encoder_budget, - max_num_mm_items_decoder_budget) + max_num_mm_items = self.encoder_cache_budget // max_tokens_per_mm_item # noqa: E501 # Dummy data definition in V0 may contain multiple multimodal items # (e.g, multiple images) for a single request, therefore here we From 8c67ecdf1cdcc2aecafcc2acb556e27489481bdb Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Thu, 9 Jan 2025 12:30:25 +0000 Subject: [PATCH 02/13] format Signed-off-by: Roger Wang --- vllm/v1/core/encoder_cache_manager.py | 10 ++++++---- vllm/v1/worker/gpu_model_runner.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 1ddc179d8bf96..1f169c67413d2 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -60,7 +60,9 @@ def compute_encoder_cache_budget( model_config: "ModelConfig", scheduler_config: "SchedulerConfig", ) -> int: - """Compute the encoder cache budget based on the model and scheduler configurations.""" + """Compute the encoder cache budget based on the model and scheduler + configurations. + """ encoder_cache_budget = 0 if not model_config.is_multimodal_model: @@ -89,7 +91,7 @@ def compute_encoder_cache_budget( # ADDITIONAL multimodal item, and is required only when: # - Two requests in the current batch share the same prefix with such item # as part of the prefix. - # - AND the prefix length is divisible by the block size, triggering the + # - AND the prefix length is divisible by the block size, triggering the # recomputation of the last block. # - AND the part of the embeddings of the item is in this last block. @@ -104,7 +106,7 @@ def compute_encoder_cache_budget( encoder_cache_budget = num_items * max_tokens_per_mm_item logger.info( "Encoder cache will be initialized with a budget of %s tokens, and " - "profiled with %s %s items of the maximum feature size.", encoder_cache_budget, num_items, - modality) + "profiled with %s %s items of the maximum feature size.", + encoder_cache_budget, num_items, modality) return encoder_cache_budget diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5ee81fe3b3d4c..8e91128abe77b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -722,7 +722,7 @@ def profile_run(self) -> None: # Profile with multimodal encoder & encoder cache. if self.is_multimodal_model: - + # Encoder cache budget should be set to the model and scheduler # configurations accordingly. assert self.encoder_cache_budget > 0 From 5938a1fb591e3ab27b11ac1b6369dbc39955e09c Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Thu, 9 Jan 2025 12:33:21 +0000 Subject: [PATCH 03/13] reword Signed-off-by: Roger Wang --- vllm/v1/core/encoder_cache_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 1f169c67413d2..c66d7d6f228dc 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -75,7 +75,7 @@ def compute_encoder_cache_budget( key=lambda item: item[1]) max_num_batched_tokens = scheduler_config.max_num_batched_tokens - max_num_seqs = scheduler_config.max_num_seqs + max_num_reqs = scheduler_config.max_num_seqs # In case that the biggest possible multimodal item takes space more # than the batch size, then it needs to be cached and chunk prefilled. @@ -100,8 +100,8 @@ def compute_encoder_cache_budget( num_items += 1 # Number of items needed cannot be bigger than max number of running - # sequences. - num_items = min(num_items, max_num_seqs) + # requests. + num_items = min(num_items, max_num_reqs) encoder_cache_budget = num_items * max_tokens_per_mm_item logger.info( From 0e4ab3c341923e99903a181b9b3f2ccaf638a87f Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Thu, 9 Jan 2025 12:38:00 +0000 Subject: [PATCH 04/13] update Signed-off-by: Roger Wang --- vllm/v1/core/encoder_cache_manager.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index c66d7d6f228dc..9b364234656a4 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -100,8 +100,10 @@ def compute_encoder_cache_budget( num_items += 1 # Number of items needed cannot be bigger than max number of running - # requests. - num_items = min(num_items, max_num_reqs) + # requests * max number of multimodal items per request. + max_mm_items_per_req = max( + MULTIMODAL_REGISTRY.get_mm_limits_per_prompt(model_config).values()) + num_items = min(num_items, max_num_reqs * max_mm_items_per_req) encoder_cache_budget = num_items * max_tokens_per_mm_item logger.info( From bd1ccf1617de9f73e47b76af49b55fe894ef10ac Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Thu, 9 Jan 2025 23:18:06 +0000 Subject: [PATCH 05/13] address comments Signed-off-by: Roger Wang --- vllm/multimodal/registry.py | 29 ++++++++++++++++++++++----- vllm/v1/core/encoder_cache_manager.py | 26 ++++++++++++++++++++---- vllm/v1/worker/gpu_model_runner.py | 25 ++++++++++------------- 3 files changed, 57 insertions(+), 23 deletions(-) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 9eceefb08c93f..7a47e70472ad9 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -253,11 +253,8 @@ def get_max_tokens_per_item_by_modality( model_config: "ModelConfig", ) -> Mapping[str, int]: """ - Get the maximum number of tokens per data item from each modality - for profiling the memory usage of a model. - - Note: - This is currently directly used only in V1. + Get the maximum number of tokens per data item from each modality based + on underlying model configuration. """ if self.has_processor(model_config): tokenizer = cached_get_tokenizer(model_config.tokenizer) @@ -270,6 +267,28 @@ def get_max_tokens_per_item_by_modality( for key, plugin in self._plugins.items() } + def get_max_tokens_per_item_by_nonzero_modality( + self, + model_config: "ModelConfig", + ) -> Mapping[str, int]: + """ + Get the maximum number of tokens per data item from each modality based + on underlying model configuration, excluding modalities that user + explicitly disabled via `limit_mm_per_prompt`. + + Note: + This is currently directly used only in V1 for profiling the memory + usage of a model. + """ + limits_per_plugin = self._limits_by_model[model_config] + + return { + key: max_tokens_per_mm_item + for key, max_tokens_per_mm_item in + self.get_max_tokens_per_item_by_modality(model_config).items() + if limits_per_plugin[key] > 0 + } + def get_max_tokens_by_modality( self, model_config: "ModelConfig", diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 9b364234656a4..5b56aa9262047 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -62,15 +62,32 @@ def compute_encoder_cache_budget( ) -> int: """Compute the encoder cache budget based on the model and scheduler configurations. + + Args: + model_config: Model configuration. + scheduler_config: Scheduler configuration. + + Returns: + The encoder cache budget, in unit of number of tokens + in the input sequence. """ encoder_cache_budget = 0 + + # TODO: handle encoder-decoder models once we support them. if not model_config.is_multimodal_model: return encoder_cache_budget - max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_modality( # noqa: E501 + max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501 model_config) + if not max_tokens_by_modality_dict: + logger.warning( + "All non-text modalities supported by the model have been " + "explicitly disabled via limit_mm_per_prompt. Encoder cache will " + "not be initialized.") + return encoder_cache_budget + modality, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(), key=lambda item: item[1]) @@ -103,12 +120,13 @@ def compute_encoder_cache_budget( # requests * max number of multimodal items per request. max_mm_items_per_req = max( MULTIMODAL_REGISTRY.get_mm_limits_per_prompt(model_config).values()) - num_items = min(num_items, max_num_reqs * max_mm_items_per_req) + num_items = min(num_items, max_num_reqs * max_mm_items_per_req) encoder_cache_budget = num_items * max_tokens_per_mm_item + logger.info( - "Encoder cache will be initialized with a budget of %s tokens, and " - "profiled with %s %s items of the maximum feature size.", + "Encoder cache will be initialized with a budget of %s tokens," + " and profiled with %s %s items of the maximum feature size.", encoder_cache_budget, num_items, modality) return encoder_cache_budget diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8e91128abe77b..f4ab7a08c0fe1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -721,24 +721,13 @@ def profile_run(self) -> None: ] # Profile with multimodal encoder & encoder cache. - if self.is_multimodal_model: - - # Encoder cache budget should be set to the model and scheduler - # configurations accordingly. - assert self.encoder_cache_budget > 0 - - # Create dummy batch of multimodal inputs. - dummy_request_data = self.input_registry.dummy_data_for_profiling( - model_config=self.model_config, - seq_len=self.max_num_tokens, - mm_registry=self.mm_registry, - ) - dummy_mm_data = dummy_request_data.multi_modal_data + # TODO: handle encoder-decoder models once we support them. + if self.is_multimodal_model and self.encoder_cache_budget > 0: # NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when # it supports multiple. - max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality( # noqa: E501 + max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501 self.model_config) dummy_data_modality, max_tokens_per_mm_item = max( @@ -746,6 +735,14 @@ def profile_run(self) -> None: max_num_mm_items = self.encoder_cache_budget // max_tokens_per_mm_item # noqa: E501 + # Create dummy batch of multimodal inputs. + dummy_request_data = self.input_registry.dummy_data_for_profiling( + model_config=self.model_config, + seq_len=self.max_num_tokens, + mm_registry=self.mm_registry, + ) + dummy_mm_data = dummy_request_data.multi_modal_data + # Dummy data definition in V0 may contain multiple multimodal items # (e.g, multiple images) for a single request, therefore here we # always replicate first item by max_num_mm_items times since in V1 From 2a4b1d5c5e87c2a200283a9a7b80d1a1c4f3fd37 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Thu, 9 Jan 2025 23:28:52 +0000 Subject: [PATCH 06/13] clarify comment Signed-off-by: Roger Wang --- vllm/v1/core/encoder_cache_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 5b56aa9262047..da3fefb55a541 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -94,13 +94,13 @@ def compute_encoder_cache_budget( max_num_batched_tokens = scheduler_config.max_num_batched_tokens max_num_reqs = scheduler_config.max_num_seqs - # In case that the biggest possible multimodal item takes space more - # than the batch size, then it needs to be cached and chunk prefilled. + # The biggest possible multimodal item cannot be fully prefilled in a + # batch, so every batch can partially prefill at most one of such item. if max_tokens_per_mm_item > max_num_batched_tokens: num_items = 1 - # In case that the biggest possible multimodal item takes space less - # the batch size, then all items will be full prefilled except one. + # A batch can fully cover multiple biggest possible multimodal items, and + # one that will be partially prefilled. else: num_items = cdiv(max_num_batched_tokens, max_tokens_per_mm_item) From 9ee3f3d9ad9771edbdc0e179dcc5fe95efade8c5 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Thu, 9 Jan 2025 23:32:17 +0000 Subject: [PATCH 07/13] clarify Signed-off-by: Roger Wang --- vllm/v1/core/encoder_cache_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index da3fefb55a541..45e89837ea12f 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -112,8 +112,8 @@ def compute_encoder_cache_budget( # recomputation of the last block. # - AND the part of the embeddings of the item is in this last block. - # This can be improved when we have a global encoder cache that does - # not associate items to request id only. + # This issue can be fundamentally resolved by supporting num_new_tokens=0 + # on the model runner. num_items += 1 # Number of items needed cannot be bigger than max number of running From 761488808569f9ceabcc120cb3f2df0ccc2590a9 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Thu, 9 Jan 2025 23:37:41 +0000 Subject: [PATCH 08/13] format Signed-off-by: Roger Wang --- vllm/v1/core/encoder_cache_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 45e89837ea12f..f2bd288708b48 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -112,7 +112,7 @@ def compute_encoder_cache_budget( # recomputation of the last block. # - AND the part of the embeddings of the item is in this last block. - # This issue can be fundamentally resolved by supporting num_new_tokens=0 + # This issue can be fundamentally resolved by supporting num_new_tokens=0 # on the model runner. num_items += 1 From aaf3cefb03a7d8b5243ed48047c48512b7b0c90b Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Fri, 10 Jan 2025 06:41:44 +0000 Subject: [PATCH 09/13] separate Signed-off-by: Roger Wang --- vllm/v1/core/encoder_cache_manager.py | 34 +++++++++++++++++++++++---- vllm/v1/worker/gpu_model_runner.py | 14 +++++------ 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index f2bd288708b48..7dedc034574e3 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Dict, List, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY @@ -74,10 +74,36 @@ def compute_encoder_cache_budget( encoder_cache_budget = 0 - # TODO: handle encoder-decoder models once we support them. if not model_config.is_multimodal_model: return encoder_cache_budget + # TODO: handle encoder-decoder models once we support them. + encoder_cache_budget, _, _ = compute_encoder_cache_budget_multimodal( + model_config, scheduler_config) + + return encoder_cache_budget + + +def compute_encoder_cache_budget_multimodal( + model_config: "ModelConfig", + scheduler_config: "SchedulerConfig", +) -> tuple[int, Optional[str], int]: + """Compute the encoder cache budget based on the model and scheduler + configurations for a multimodal model. + + Args: + model_config: Model configuration. + scheduler_config: Scheduler configuration. + + Returns: + - The encoder cache budget, in unit of number of tokens in the + input sequence. + - The modality of the multimodal item that requires the most tokens. + - The number of multimodal items used to compute the encoder cache + budget. + """ + + encoder_cache_budget = 0 max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501 model_config) @@ -86,7 +112,7 @@ def compute_encoder_cache_budget( "All non-text modalities supported by the model have been " "explicitly disabled via limit_mm_per_prompt. Encoder cache will " "not be initialized.") - return encoder_cache_budget + return encoder_cache_budget, None, 0 modality, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(), key=lambda item: item[1]) @@ -129,4 +155,4 @@ def compute_encoder_cache_budget( " and profiled with %s %s items of the maximum feature size.", encoder_cache_budget, num_items, modality) - return encoder_cache_budget + return encoder_cache_budget, modality, num_items diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f4ab7a08c0fe1..9637e3d6a82f2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -19,7 +19,8 @@ LayerBlockType, cdiv, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) -from vllm.v1.core.encoder_cache_manager import compute_encoder_cache_budget +from vllm.v1.core.encoder_cache_manager import ( + compute_encoder_cache_budget, compute_encoder_cache_budget_multimodal) from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata @@ -727,13 +728,10 @@ def profile_run(self) -> None: # NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when # it supports multiple. - max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501 - self.model_config) - - dummy_data_modality, max_tokens_per_mm_item = max( - max_tokens_by_modality_dict.items(), key=lambda item: item[1]) - - max_num_mm_items = self.encoder_cache_budget // max_tokens_per_mm_item # noqa: E501 + _, dummy_data_modality, max_num_mm_items = compute_encoder_cache_budget_multimodal( # noqa: E501 + self.model_config, + self.scheduler_config, + ) # Create dummy batch of multimodal inputs. dummy_request_data = self.input_registry.dummy_data_for_profiling( From 767b0d6d85c24fd6aee0a03d2c4c45db58daecb5 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 15 Jan 2025 05:55:29 +0000 Subject: [PATCH 10/13] update Signed-off-by: Roger Wang --- vllm/config.py | 13 ++++ vllm/v1/core/encoder_cache_manager.py | 94 +++++++++------------------ vllm/v1/core/scheduler.py | 12 ++-- vllm/v1/worker/gpu_model_runner.py | 50 +++++++++++--- 4 files changed, 91 insertions(+), 78 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index e270a31875cac..24fb4992814b3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1379,6 +1379,16 @@ class SchedulerConfig: is_multimodal_model: bool = False + # NOTE: The following multimodal encoder budget will be initialized to + # max_num_batched_tokens and overridden in case max multimodal embedding + # size is larger. + # TODO (ywang96): Make these configurable. + # Multimodal encoder compute budget, only used in V1 + max_num_encoder_input_tokens: int = field(default=None) # type: ignore + + # Multimodal encoder cache size, only used in V1 + encoder_cache_size: int = field(default=None) # type: ignore + # Whether to perform preemption by swapping or # recomputation. If not specified, we determine the mode as follows: # We use recomputation by default since it incurs lower overhead than @@ -1451,6 +1461,9 @@ def __post_init__(self) -> None: _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, ) + self.max_num_encoder_input_tokens = self.max_num_batched_tokens + self.encoder_cache_size = self.max_num_batched_tokens + if self.enable_chunked_prefill: logger.info( "Chunked prefill is enabled with max_num_batched_tokens=%d.", diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 7dedc034574e3..36a374996204a 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -1,8 +1,7 @@ -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Set, Tuple from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.utils import cdiv from vllm.v1.request import Request if TYPE_CHECKING: @@ -56,10 +55,10 @@ def get_freed_ids(self) -> List[Tuple[str, int]]: return freed -def compute_encoder_cache_budget( +def compute_encoder_budget( model_config: "ModelConfig", scheduler_config: "SchedulerConfig", -) -> int: +) -> tuple[int, int]: """Compute the encoder cache budget based on the model and scheduler configurations. @@ -68,26 +67,28 @@ def compute_encoder_cache_budget( scheduler_config: Scheduler configuration. Returns: - The encoder cache budget, in unit of number of tokens - in the input sequence. + - Compute budget for encoder execution, in unit of number of tokens + in the input sequence. + - Space budget for encoder cache size, in unit of number of tokens + in the input sequence. """ - encoder_cache_budget = 0 - if not model_config.is_multimodal_model: - return encoder_cache_budget + return 0, 0 # TODO: handle encoder-decoder models once we support them. - encoder_cache_budget, _, _ = compute_encoder_cache_budget_multimodal( - model_config, scheduler_config) + ( + encoder_compute_budget, + encoder_cache_size, + ) = _compute_encoder_budget_multimodal(model_config, scheduler_config) - return encoder_cache_budget + return encoder_compute_budget, encoder_cache_size -def compute_encoder_cache_budget_multimodal( +def _compute_encoder_budget_multimodal( model_config: "ModelConfig", scheduler_config: "SchedulerConfig", -) -> tuple[int, Optional[str], int]: +) -> tuple[int, int]: """Compute the encoder cache budget based on the model and scheduler configurations for a multimodal model. @@ -96,14 +97,12 @@ def compute_encoder_cache_budget_multimodal( scheduler_config: Scheduler configuration. Returns: - - The encoder cache budget, in unit of number of tokens in the - input sequence. - - The modality of the multimodal item that requires the most tokens. - - The number of multimodal items used to compute the encoder cache - budget. + - Compute budget for encoder execution, in unit of number of tokens + in the input sequence. + - Space budget for encoder cache size, in unit of number of tokens + in the input sequence. """ - encoder_cache_budget = 0 max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501 model_config) @@ -112,47 +111,14 @@ def compute_encoder_cache_budget_multimodal( "All non-text modalities supported by the model have been " "explicitly disabled via limit_mm_per_prompt. Encoder cache will " "not be initialized.") - return encoder_cache_budget, None, 0 - - modality, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(), - key=lambda item: item[1]) - - max_num_batched_tokens = scheduler_config.max_num_batched_tokens - max_num_reqs = scheduler_config.max_num_seqs - - # The biggest possible multimodal item cannot be fully prefilled in a - # batch, so every batch can partially prefill at most one of such item. - if max_tokens_per_mm_item > max_num_batched_tokens: - num_items = 1 - - # A batch can fully cover multiple biggest possible multimodal items, and - # one that will be partially prefilled. - else: - num_items = cdiv(max_num_batched_tokens, max_tokens_per_mm_item) - - # NOTE: We need the encoder cache to be able to compute & hold ONE - # ADDITIONAL multimodal item, and is required only when: - # - Two requests in the current batch share the same prefix with such item - # as part of the prefix. - # - AND the prefix length is divisible by the block size, triggering the - # recomputation of the last block. - # - AND the part of the embeddings of the item is in this last block. - - # This issue can be fundamentally resolved by supporting num_new_tokens=0 - # on the model runner. - num_items += 1 - - # Number of items needed cannot be bigger than max number of running - # requests * max number of multimodal items per request. - max_mm_items_per_req = max( - MULTIMODAL_REGISTRY.get_mm_limits_per_prompt(model_config).values()) - - num_items = min(num_items, max_num_reqs * max_mm_items_per_req) - encoder_cache_budget = num_items * max_tokens_per_mm_item - - logger.info( - "Encoder cache will be initialized with a budget of %s tokens," - " and profiled with %s %s items of the maximum feature size.", - encoder_cache_budget, num_items, modality) - - return encoder_cache_budget, modality, num_items + return 0, 0 + + _, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(), + key=lambda item: item[1]) + + encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens, + max_tokens_per_mm_item) + encoder_cache_size = max(scheduler_config.encoder_cache_size, + max_tokens_per_mm_item) + + return encoder_compute_budget, encoder_cache_size diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 51ca9aa924404..63dd1c9b20222 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -7,7 +7,7 @@ from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, - compute_encoder_cache_budget) + compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.engine import EngineCoreOutput from vllm.v1.outputs import ModelRunnerOutput @@ -74,18 +74,20 @@ def __init__( # NOTE: For now we use the same budget for both compute and space. # This can be changed when we make encoder cache for embedding caching # across requests. - encoder_cache_budget = compute_encoder_cache_budget( - model_config, scheduler_config) + encoder_compute_budget, encoder_cache_size = compute_encoder_budget( + model_config=model_config, + scheduler_config=scheduler_config, + ) # NOTE(woosuk): Here, "encoder" includes the vision encoder (and # projector if needed). Currently, we assume that the encoder also # has the Transformer architecture (e.g., ViT). - self.max_num_encoder_input_tokens = encoder_cache_budget + self.max_num_encoder_input_tokens = encoder_compute_budget # NOTE: For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized because cache size is 0 # for these models. self.encoder_cache_manager = EncoderCacheManager( - cache_size=encoder_cache_budget) + cache_size=encoder_cache_size) def schedule(self) -> "SchedulerOutput": # NOTE(woosuk) on the scheduling algorithm: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7dbd809167271..6687014bfb2cf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -20,8 +20,7 @@ is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) -from vllm.v1.core.encoder_cache_manager import ( - compute_encoder_cache_budget, compute_encoder_cache_budget_multimodal) +from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata @@ -90,8 +89,12 @@ def __init__( self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config) self.mm_input_mapper_profiling.use_cache = False - self.encoder_cache_budget = compute_encoder_cache_budget( - self.model_config, self.scheduler_config) + encoder_compute_budget, encoder_cache_size = compute_encoder_budget( + model_config=model_config, + scheduler_config=scheduler_config, + ) + self.max_num_encoder_input_tokens = encoder_compute_budget + self.encoder_cache_size = encoder_cache_size # Lazy initialization # self.model: nn.Module # Set after load_model @@ -724,15 +727,44 @@ def profile_run(self) -> None: # Profile with multimodal encoder & encoder cache. # TODO: handle encoder-decoder models once we support them. - if self.is_multimodal_model and self.encoder_cache_budget > 0: + if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 + and self.encoder_cache_size > 0): # NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when # it supports multiple. - _, dummy_data_modality, max_num_mm_items = compute_encoder_cache_budget_multimodal( # noqa: E501 - self.model_config, - self.scheduler_config, - ) + max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501 + self.model_config) + dummy_data_modality, max_tokens_per_mm_item = max( + max_tokens_by_modality_dict.items(), key=lambda item: item[1]) + + # Check how many items of this modality can be supported by + # the encoder budget. + encoder_cache_budget = min(self.max_num_encoder_input_tokens, + self.encoder_cache_size) + + max_num_mm_items_encoder_budget = cdiv(encoder_cache_budget, + max_tokens_per_mm_item) + + # Check how many items of this modality can be supported by + # the decoder budget. + max_mm_items_per_req = max( + self.mm_registry.get_mm_limits_per_prompt( + self.model_config).values()) + + # NOTE: We do not consider max_num_batched_tokens on purpose + # because the multimodal embeddings can be generated in advance + # and chunked prefilled. + max_num_mm_items_decoder_budget = self.max_num_reqs * \ + max_mm_items_per_req + + max_num_mm_items = min(max_num_mm_items_encoder_budget, + max_num_mm_items_decoder_budget) + + logger.info( + "Encoder cache will be initialized with a budget of %s tokens," + " and profiled with %s %s items of the maximum feature size.", + encoder_cache_budget, max_num_mm_items, dummy_data_modality) # Create dummy batch of multimodal inputs. dummy_request_data = self.input_registry.dummy_data_for_profiling( From eb125b5ce6740b37f7a626bcf835f05a54c39743 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 15 Jan 2025 06:02:40 +0000 Subject: [PATCH 11/13] rename Signed-off-by: Roger Wang --- vllm/v1/worker/gpu_model_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6687014bfb2cf..4c41330ec105b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -740,10 +740,10 @@ def profile_run(self) -> None: # Check how many items of this modality can be supported by # the encoder budget. - encoder_cache_budget = min(self.max_num_encoder_input_tokens, + encoder_budget = min(self.max_num_encoder_input_tokens, self.encoder_cache_size) - max_num_mm_items_encoder_budget = cdiv(encoder_cache_budget, + max_num_mm_items_encoder_budget = cdiv(encoder_budget, max_tokens_per_mm_item) # Check how many items of this modality can be supported by @@ -764,7 +764,7 @@ def profile_run(self) -> None: logger.info( "Encoder cache will be initialized with a budget of %s tokens," " and profiled with %s %s items of the maximum feature size.", - encoder_cache_budget, max_num_mm_items, dummy_data_modality) + encoder_budget, max_num_mm_items, dummy_data_modality) # Create dummy batch of multimodal inputs. dummy_request_data = self.input_registry.dummy_data_for_profiling( From f53947057a2600a1d7a28699e527dac6e1426ad3 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 15 Jan 2025 06:14:44 +0000 Subject: [PATCH 12/13] respect modality Signed-off-by: Roger Wang --- vllm/v1/worker/gpu_model_runner.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4c41330ec105b..de83640b27cd6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -741,16 +741,15 @@ def profile_run(self) -> None: # Check how many items of this modality can be supported by # the encoder budget. encoder_budget = min(self.max_num_encoder_input_tokens, - self.encoder_cache_size) + self.encoder_cache_size) max_num_mm_items_encoder_budget = cdiv(encoder_budget, max_tokens_per_mm_item) # Check how many items of this modality can be supported by # the decoder budget. - max_mm_items_per_req = max( - self.mm_registry.get_mm_limits_per_prompt( - self.model_config).values()) + max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt( + self.model_config)[dummy_data_modality] # NOTE: We do not consider max_num_batched_tokens on purpose # because the multimodal embeddings can be generated in advance From 29ad3590247738fc73a7db999bd8a1b100671087 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 15 Jan 2025 08:00:04 +0000 Subject: [PATCH 13/13] use typing.Tuple Signed-off-by: Roger Wang --- vllm/v1/core/encoder_cache_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 36a374996204a..0cd8c806a3e47 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -58,7 +58,7 @@ def get_freed_ids(self) -> List[Tuple[str, int]]: def compute_encoder_budget( model_config: "ModelConfig", scheduler_config: "SchedulerConfig", -) -> tuple[int, int]: +) -> Tuple[int, int]: """Compute the encoder cache budget based on the model and scheduler configurations. @@ -88,7 +88,7 @@ def compute_encoder_budget( def _compute_encoder_budget_multimodal( model_config: "ModelConfig", scheduler_config: "SchedulerConfig", -) -> tuple[int, int]: +) -> Tuple[int, int]: """Compute the encoder cache budget based on the model and scheduler configurations for a multimodal model.