diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 3e3d2e3f5c53d..c2d1c5769619b 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -307,7 +307,6 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): assert len(attn_metadata.slot_mapping) == len(input_tokens) assert len(input_positions) == len(input_tokens) - assert attn_metadata.kv_cache_dtype == "auto" assert attn_metadata.num_prefills == prefill_batch_size if enforce_eager: assert attn_metadata.num_decode_tokens == decode_batch_size diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 7636b34a16fed..088f48def7668 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -5,9 +5,9 @@ from vllm.attention.selector import get_attn_backend __all__ = [ + "Attention", "AttentionBackend", "AttentionMetadata", - "Attention", - "get_attn_backend", "AttentionMetadataPerStage", + "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 64ccb309a0480..98d70fcab1a18 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -94,8 +94,6 @@ class AttentionMetadata(Generic[T]): # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor - # The kv cache's data type. - kv_cache_dtype: str def __post_init__(self): if self.num_prefill_tokens > 0: @@ -116,6 +114,7 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: raise NotImplementedError @@ -127,6 +126,6 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 4bad226512b69..f59715bd76ede 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -140,16 +140,18 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -167,7 +169,7 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata[FlashAttentionMetadata], - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. 
@@ -196,8 +198,7 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -264,7 +265,7 @@ def forward( decode_meta.block_tables, decode_meta.seq_lens_tensor, decode_meta.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 36e162671f944..92d0fe0487516 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -149,20 +149,33 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: - if sliding_window is not None: - raise ValueError("Sliding window is not supported in FlashInfer.") - self.sliding_window = (-1, -1) - self.alibi_slopes = alibi_slopes - self.scale = scale self.num_heads = num_heads self.head_size = head_size + self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is not None: + raise ValueError("Sliding window is not supported in FlashInfer.") + self.sliding_window = (-1, -1) + self.kv_cache_dtype = kv_cache_dtype - def forward(self, query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[FlashInferMetadata], - kv_scale: float): + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata[FlashInferMetadata], + kv_scale: float = 1.0, + ) -> torch.Tensor: + assert kv_scale == 1.0 num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) @@ -183,7 +196,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, kv_cache[:, 0], kv_cache[:, 1], attn_metadata.slot_mapping.flatten(), - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, ) if prefill_meta := attn_metadata.prefill_metadata: diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 8fc1af1aa1e1c..539585b46c7aa 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -138,25 +138,27 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % 
self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Supported head sizes are: {supported_head_sizes}.") self.use_naive_attn = False # NOTE: Allow for switching between Triton and CK. Defaulting to triton. @@ -229,7 +231,7 @@ def forward( key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, kv_scale, ) @@ -323,7 +325,7 @@ def forward( decode_meta.block_tables, decode_meta.seq_lens_tensor, decode_meta.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index c29218dfd0cfc..2dd72a00c6e30 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -83,26 +83,32 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = sliding_window if alibi_slopes is not None: - assert len(alibi_slopes) == num_heads alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes - self.need_mask = (self.alibi_slopes is not None - or self.sliding_window is not None) + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Supported head sizes are: {supported_head_sizes}.") + if kv_cache_dtype != "auto": + raise NotImplementedError( + "Torch SDPA backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") def forward( self, @@ -111,7 +117,7 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: TorchSDPAMetadata, # type: ignore - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -124,6 +130,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + assert kv_scale == 1.0 num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. 
query = query.view(-1, self.num_heads, self.head_size) @@ -136,8 +143,7 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + self.kv_cache_dtype, kv_scale) if attn_metadata.is_prompt: assert attn_metadata.seq_lens is not None @@ -195,7 +201,7 @@ def forward( attn_metadata.block_tables, attn_metadata.seq_lens_tensor, attn_metadata.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 2a9150dea5875..cb2028553461f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -149,15 +149,17 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = sliding_window if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -175,7 +177,7 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata[XFormersMetadata], - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -188,7 +190,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) @@ -203,8 +204,7 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -262,7 +262,7 @@ def forward( decode_meta.block_tables, decode_meta.seq_lens_tensor, decode_meta.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ee7be26c0876c..8a872dba8c877 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -7,6 +7,7 @@ from vllm.attention.backends.abstract import (AttentionMetadata, AttentionMetadataPerStage) from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig class Attention(nn.Module): @@ -29,10 +30,24 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() - self.backend = get_attn_backend(torch.get_default_dtype()) - impl_cls = self.backend.get_impl_cls() + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + if num_kv_heads is None: + num_kv_heads = num_heads + # During model initialization, the default dtype is set 
as the model + weight and activation dtype. + dtype = torch.get_default_dtype() + attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads, + sliding_window, dtype, kv_cache_dtype, + block_size) + impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window) + alibi_slopes, sliding_window, kv_cache_dtype) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index f4446bac6b8d2..06f99718a4dee 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,6 +1,6 @@ import enum from functools import lru_cache -from typing import Type +from typing import Optional, Type import torch @@ -21,8 +21,18 @@ class _Backend(enum.Enum): @lru_cache(maxsize=None) -def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: - backend = _which_attn_to_use(dtype) +def get_attn_backend( + num_heads: int, + head_size: int, + num_kv_heads: int, + sliding_window: Optional[int], + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, +) -> Type[AttentionBackend]: + backend = _which_attn_to_use(num_heads, head_size, num_kv_heads, + sliding_window, dtype, kv_cache_dtype, + block_size) if backend == _Backend.FLASH_ATTN: logger.info("Using FlashAttention-2 backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 @@ -44,14 +54,22 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: return TorchSDPABackend elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") - logger.warning("Eager mode is enforced for the Flashinfer backend. ") + logger.warning("Eager mode is enforced for the Flashinfer backend.") from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend else: raise ValueError("Invalid attention backend.") -def _which_attn_to_use(dtype: torch.dtype) -> _Backend: +def _which_attn_to_use( + num_heads: int, + head_size: int, + num_kv_heads: int, + sliding_window: Optional[int], + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, +) -> _Backend: """Returns which flash attention backend to use.""" if is_cpu(): return _Backend.TORCH_SDPA diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 6f90e49994fb2..e3e32d61ab04d 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -2,26 +2,29 @@ from torch import nn -from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) from vllm.model_executor.model_loader.loader import (BaseModelLoader, get_model_loader) from vllm.model_executor.model_loader.utils import ( get_architecture_class_name, get_model_architecture) -def get_model( - *, model_config: ModelConfig, load_config: LoadConfig, - device_config: DeviceConfig, parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: +def get_model(*, model_config: ModelConfig, load_config: LoadConfig, + device_config: DeviceConfig, parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig) -> nn.Module: loader = get_model_loader(load_config) return
loader.load_model(model_config=model_config, device_config=device_config, lora_config=lora_config, vision_language_config=vision_language_config, parallel_config=parallel_config, - scheduler_config=scheduler_config) + scheduler_config=scheduler_config, + cache_config=cache_config) __all__ = [ diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bafa2de62e5df..fc9c8aa0af44b 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -9,9 +9,9 @@ import torch from torch import nn -from vllm.config import (DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, + LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, VisionLanguageConfig) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( @@ -77,15 +77,16 @@ def _get_model_initialization_kwargs( return extra_kwargs -def _initialize_model( - model_config: ModelConfig, load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: +def _initialize_model(model_config: ModelConfig, load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig) -> nn.Module: """Initialize a model with the given configurations.""" model_class = get_model_architecture(model_config)[0] quant_config = _get_quantization_config(model_config, load_config) return model_class(config=model_config.hf_config, + cache_config=cache_config, quant_config=quant_config, **_get_model_initialization_kwargs( model_class, lora_config, vision_language_config)) @@ -103,7 +104,8 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: """Load a model with the given configurations.""" ... 
@@ -216,11 +218,13 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, - lora_config, vision_language_config) + lora_config, vision_language_config, + cache_config) model.load_weights( self._get_weights_iterator(model_config.model, model_config.revision, @@ -253,11 +257,13 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, - lora_config, vision_language_config) + lora_config, vision_language_config, + cache_config) # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. initialize_dummy_weights(model) @@ -286,9 +292,12 @@ def _get_weights_iterator( return tensorizer_weights_iterator(tensorizer_args) def _load_model_unserialized( - self, model_config: ModelConfig, device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig] + self, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig, ) -> nn.Module: """Load an unserialized model with tensorizer. @@ -299,15 +308,19 @@ def _load_model_unserialized( with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, - lora_config, vision_language_config) + lora_config, vision_language_config, + cache_config) model.load_weights(self._get_weights_iterator()) return model.eval() def _load_model_serialized( - self, model_config: ModelConfig, device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig] + self, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig, ) -> nn.Module: """Load a serialized model with tensorizer. 
@@ -321,6 +334,7 @@ def _load_model_serialized( extra_kwargs = _get_model_initialization_kwargs( model_class, lora_config, vision_language_config) extra_kwargs["quant_config"] = quant_config + extra_kwargs["cache_config"] = cache_config tensorizer_config = copy.copy(self.tensorizer_config) tensorizer_config.model_class = model_class @@ -335,16 +349,19 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: self._verify_config(model_config, parallel_config) if is_vllm_serialized_tensorizer(self.tensorizer_config): return self._load_model_serialized(model_config, device_config, lora_config, - vision_language_config) + vision_language_config, + cache_config) return self._load_model_unserialized(model_config, device_config, lora_config, - vision_language_config) + vision_language_config, + cache_config) def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 796cef7c4a735..cb99939cbb17a 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -5,6 +5,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -215,6 +216,7 @@ def __init__( self, config: ArcticConfig, layer_idx: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -265,7 +267,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -288,6 +291,7 @@ def __init__( self, config: ArcticConfig, layer_idx: int, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -297,6 +301,7 @@ def __init__( self.use_residual = config.use_residual and is_moe_layer self.self_attn = ArcticAttention(config, layer_idx, + cache_config, quant_config=quant_config) self.block_sparse_moe = ArcticMoE( config, @@ -356,6 +361,7 @@ class ArcticModel(nn.Module): def __init__( self, config: ArcticConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -366,7 +372,10 @@ def __init__( config.hidden_size, org_num_embeddings=self.vocab_size) self.layers = nn.ModuleList([ - ArcticDecoderLayer(config, layer_idx, quant_config=quant_config) + ArcticDecoderLayer(config, + layer_idx, + cache_config, + quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self._attn_implementation = config._attn_implementation @@ -392,11 +401,12 @@ class ArcticForCausalLM(nn.Module): def __init__(self, config: ArcticConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, **kwargs) -> None: super().__init__() self.config = config - self.model = ArcticModel(config, quant_config) + self.model = ArcticModel(config, cache_config, quant_config) self.vocab_size = config.vocab_size self.lm_head = ParallelLMHead( self.vocab_size, diff --git 
a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 186cee2584369..58b3405d319d1 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -26,7 +26,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -111,6 +111,7 @@ def __init__( position_embedding: str, rope_theta: float = 10000, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -162,7 +163,10 @@ def __init__( base=self.rope_theta, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, self.head_dim, self.scaling) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config) def forward( self, @@ -185,6 +189,7 @@ class BaiChuanDecoderLayer(nn.Module): def __init__(self, config: PretrainedConfig, position_embedding: str, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size @@ -197,6 +202,7 @@ def __init__(self, position_embedding=position_embedding, rope_theta=rope_theta, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) self.mlp = BaiChuanMLP( @@ -244,6 +250,7 @@ class BaiChuanModel(nn.Module): def __init__(self, config: PretrainedConfig, position_embedding: str, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -255,7 +262,8 @@ def __init__(self, config.hidden_size, ) self.layers = nn.ModuleList([ - BaiChuanDecoderLayer(config, position_embedding, quant_config) + BaiChuanDecoderLayer(config, position_embedding, cache_config, + quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -304,13 +312,15 @@ def __init__( self, config, position_embedding: str, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.model = BaiChuanModel(config, position_embedding, quant_config) + self.model = BaiChuanModel(config, position_embedding, cache_config, + quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -389,13 +399,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): if config.hidden_size == 4096: # baichuan2 7b - super().__init__(config, "ROPE", quant_config, lora_config) + super().__init__(config, "ROPE", cache_config, quant_config, + lora_config) else: # baichuan 13b, baichuan2 13b - super().__init__(config, "ALIBI", quant_config, lora_config) + super().__init__(config, "ALIBI", cache_config, quant_config, + lora_config) class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): @@ -404,7 +417,9 @@ class 
BaiChuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - super().__init__(config, "ROPE", quant_config, lora_config) + super().__init__(config, "ROPE", cache_config, quant_config, + lora_config) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 1d7e5d2517c72..fe2de87b20dc9 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -24,6 +24,7 @@ from transformers import BloomConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn @@ -71,6 +72,7 @@ class BloomAttention(nn.Module): def __init__( self, config: BloomConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -108,7 +110,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scaling, - alibi_slopes=alibi_slopes) + alibi_slopes=alibi_slopes, + cache_config=cache_config) def forward( self, @@ -158,6 +161,7 @@ class BloomBlock(nn.Module): def __init__( self, config: BloomConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -165,7 +169,8 @@ def __init__( self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.self_attention = BloomAttention(config, quant_config) + self.self_attention = BloomAttention(config, cache_config, + quant_config) self.post_attention_layernorm = nn.LayerNorm( hidden_size, eps=config.layer_norm_epsilon) self.mlp = BloomMLP(config, quant_config) @@ -214,6 +219,7 @@ class BloomModel(nn.Module): def __init__( self, config: BloomConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -229,7 +235,7 @@ def __init__( # Transformer blocks self.h = nn.ModuleList([ - BloomBlock(config, quant_config) + BloomBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -262,12 +268,13 @@ class BloomForCausalLM(nn.Module): def __init__( self, config: BloomConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = BloomModel(config, quant_config) + self.transformer = BloomModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.word_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index e116af2ed080d..29c76682109c6 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -9,7 +9,7 @@ from torch.nn import LayerNorm from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -34,6 +34,7 @@ class GLMAttention(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = 
None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -90,6 +91,7 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + cache_config=cache_config, ) def forward( @@ -167,6 +169,7 @@ class GLMBlock(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -181,7 +184,7 @@ def __init__( eps=config.layernorm_epsilon) # Self attention. - self.self_attention = GLMAttention(config, quant_config) + self.self_attention = GLMAttention(config, cache_config, quant_config) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output @@ -237,6 +240,7 @@ class GLMTransformer(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -246,8 +250,10 @@ def __init__( self.num_layers = config.num_layers # Transformer layers. - self.layers = nn.ModuleList( - [GLMBlock(config, quant_config) for i in range(self.num_layers)]) + self.layers = nn.ModuleList([ + GLMBlock(config, cache_config, quant_config) + for i in range(self.num_layers) + ]) if self.post_layer_norm: layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm @@ -282,6 +288,7 @@ class ChatGLMModel(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -292,7 +299,7 @@ def __init__( self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels - self.encoder = GLMTransformer(config, quant_config) + self.encoder = GLMTransformer(config, cache_config, quant_config) self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size) @@ -334,13 +341,14 @@ class ChatGLMForCausalLM(nn.Module): def __init__( self, config: ChatGLMConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config: ChatGLMConfig = config self.quant_config = quant_config - self.transformer = ChatGLMModel(config, quant_config) + self.transformer = ChatGLMModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.output_layer.weight self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 17c2f1223d96b..7354d11f98b15 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -29,6 +29,7 @@ from transformers import CohereConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -124,6 +125,7 @@ class CohereAttention(nn.Module): def __init__( self, config: CohereConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -180,6 +182,7 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + cache_config=cache_config, ) if self.use_qk_norm: self.q_norm = LayerNorm(param_shape=(self.num_heads, @@ -219,11 +222,14 @@ class CohereDecoderLayer(nn.Module): def __init__(self, config: CohereConfig, + cache_config: 
Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = CohereAttention(config, quant_config=quant_config) + self.self_attn = CohereAttention(config, + cache_config, + quant_config=quant_config) self.mlp = CohereMLP(config, quant_config=quant_config) self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), @@ -258,6 +264,7 @@ class CohereModel(nn.Module): def __init__( self, config: CohereConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -266,7 +273,7 @@ def __init__( self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - CohereDecoderLayer(config, quant_config=quant_config) + CohereDecoderLayer(config, cache_config, quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = LayerNorm(param_shape=(config.hidden_size), @@ -299,6 +306,7 @@ class CohereForCausalLM(nn.Module): def __init__( self, config: CohereConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -306,7 +314,7 @@ def __init__( self.quant_config = quant_config self.logits_processor = LogitsProcessor(config.vocab_size, scale=config.logit_scale) - self.model = CohereModel(config, quant_config) + self.model = CohereModel(config, cache_config, quant_config) self.sampler = Sampler() @torch.no_grad() diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index a4a0ae50c645e..083ddf0159f71 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -5,6 +5,7 @@ import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -166,6 +167,7 @@ class DbrxAttention(nn.Module): def __init__( self, config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -221,6 +223,7 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + cache_config=cache_config, ) def forward( @@ -279,10 +282,12 @@ class DbrxBlock(nn.Module): def __init__( self, config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config) + self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config, + quant_config) self.ffn = DbrxExperts(config, quant_config) def forward( @@ -308,6 +313,7 @@ class DbrxModel(nn.Module): def __init__( self, config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -315,8 +321,10 @@ def __init__( config.vocab_size, config.d_model, ) - self.blocks = nn.ModuleList( - [DbrxBlock(config, quant_config) for _ in range(config.n_layers)]) + self.blocks = nn.ModuleList([ + DbrxBlock(config, cache_config, quant_config) + for _ in range(config.n_layers) + ]) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) for module in self.modules(): if hasattr(module, "bias") and isinstance(module.bias, @@ -349,13 +357,14 @@ class DbrxForCausalLM(nn.Module): def __init__( self, config: DbrxConfig, + cache_config: Optional[CacheConfig] = 
None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config self.unpadded_vocab_size = config.vocab_size - self.transformer = DbrxModel(config, quant_config) + self.transformer = DbrxModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.d_model, diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index be9a6b6813f8f..e293ee491908d 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -28,7 +28,7 @@ import torch from transformers import PretrainedConfig -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -56,12 +56,14 @@ class DeciLMForCausalLM(LlamaForCausalLM): def __init__( self, config: Optional[PretrainedConfig] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: config.num_key_value_heads = max(config.num_key_value_heads_per_layer) delattr(config, "num_key_value_heads_per_layer") super().__init__(config=config, + cache_config=cache_config, quant_config=quant_config, lora_config=lora_config) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index e5f7ba086a35d..62e04f9649915 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -28,6 +28,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -178,6 +179,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -229,7 +231,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -252,6 +255,7 @@ def __init__( self, config: PretrainedConfig, layer_idx: int, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -267,6 +271,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) if (config.n_routed_experts is not None @@ -321,6 +326,7 @@ class DeepseekModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -332,7 +338,10 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config) + DeepseekDecoderLayer(config, + layer_idx, + cache_config, + quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -360,12 +369,13 @@ class DeepseekForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + 
cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = DeepseekModel(config, quant_config) + self.model = DeepseekModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 08dd69923dc6d..ab9e1994be426 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -27,6 +27,7 @@ from transformers import FalconConfig as HF_FalconConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -77,6 +78,7 @@ class FalconAttention(nn.Module): def __init__( self, config: FalconConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -168,7 +170,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.inv_norm_factor, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -229,12 +232,14 @@ class FalconDecoderLayer(nn.Module): def __init__( self, config: FalconConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.self_attention = FalconAttention(config, quant_config) + self.self_attention = FalconAttention(config, cache_config, + quant_config) self.mlp = FalconMLP(config, quant_config) self.config = config @@ -311,6 +316,7 @@ class FalconModel(nn.Module): def __init__( self, config: FalconConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -327,7 +333,7 @@ def __init__( # Transformer blocks self.h = nn.ModuleList([ - FalconDecoderLayer(config, quant_config) + FalconDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -359,12 +365,13 @@ class FalconForCausalLM(nn.Module): def __init__( self, config: FalconConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = FalconModel(config, quant_config) + self.transformer = FalconModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.word_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index bb73ff4d206da..d1502b718a773 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -22,7 +22,7 @@ from transformers import GemmaConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul @@ -107,6 +107,7 @@ def __init__(self, head_dim: int, max_position_embeddings: int = 8192, 
rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.hidden_size = hidden_size @@ -155,7 +156,8 @@ def __init__(self, self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -177,6 +179,7 @@ class GemmaDecoderLayer(nn.Module): def __init__( self, config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -188,6 +191,7 @@ def __init__( head_dim=config.head_dim, max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta, + cache_config=cache_config, quant_config=quant_config, ) self.mlp = GemmaMLP( @@ -236,6 +240,7 @@ class GemmaModel(nn.Module): def __init__( self, config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -246,7 +251,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - GemmaDecoderLayer(config, quant_config) + GemmaDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -309,6 +314,7 @@ class GemmaForCausalLM(nn.Module): def __init__( self, config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -316,7 +322,7 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.model = GemmaModel(config, quant_config) + self.model = GemmaModel(config, cache_config, quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 75eaebf0dbd15..0deaa58ed9eb5 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -24,6 +24,7 @@ from transformers import GPT2Config from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -45,6 +46,7 @@ class GPT2Attention(nn.Module): def __init__( self, config: GPT2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -70,7 +72,10 @@ def __init__( bias=True, quant_config=quant_config, ) - self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + cache_config=cache_config) def forward( self, @@ -122,6 +127,7 @@ class GPT2Block(nn.Module): def __init__( self, config: GPT2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -130,7 +136,7 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2Attention(config, quant_config) + self.attn = GPT2Attention(config, cache_config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPT2MLP(inner_dim, config, quant_config) @@ -163,6 +169,7 @@ class GPT2Model(nn.Module): def 
__init__( self, config: GPT2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -174,7 +181,7 @@ def __init__( self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPT2Block(config, quant_config) + GPT2Block(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -203,12 +210,13 @@ class GPT2LMHeadModel(nn.Module): def __init__( self, config: GPT2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = GPT2Model(config, quant_config) + self.transformer = GPT2Model(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index d057fd928fdb5..c20fb3230c394 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -25,6 +25,7 @@ from transformers import GPTBigCodeConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -46,6 +47,7 @@ class GPTBigCodeAttention(nn.Module): def __init__( self, config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -85,7 +87,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -143,6 +146,7 @@ class GPTBigCodeBlock(nn.Module): def __init__( self, config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -151,7 +155,7 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(config, quant_config) + self.attn = GPTBigCodeAttention(config, cache_config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPTBigMLP(inner_dim, config, quant_config) @@ -184,6 +188,7 @@ class GPTBigCodeModel(nn.Module): def __init__( self, config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -195,7 +200,7 @@ def __init__( self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPTBigCodeBlock(config, quant_config) + GPTBigCodeBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -224,12 +229,13 @@ class GPTBigCodeForCausalLM(nn.Module): def __init__( self, config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config 
= config self.quant_config = quant_config - self.transformer = GPTBigCodeModel(config, quant_config) + self.transformer = GPTBigCodeModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 8d7fe8a5beef7..5f4d8ec3d3a7a 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -23,6 +23,7 @@ from transformers import GPTJConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -45,6 +46,7 @@ class GPTJAttention(nn.Module): def __init__( self, config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -83,7 +85,10 @@ def __init__( base=rope_theta, is_neox_style=False, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config) def forward( self, @@ -135,13 +140,14 @@ class GPTJBlock(nn.Module): def __init__( self, config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() inner_dim = (4 * config.n_embd if config.n_inner is None else config.n_inner) self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.attn = GPTJAttention(config, quant_config) + self.attn = GPTJAttention(config, cache_config, quant_config) self.mlp = GPTJMLP(inner_dim, config, quant_config) def forward( @@ -169,6 +175,7 @@ class GPTJModel(nn.Module): def __init__( self, config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -178,8 +185,10 @@ def __init__( config.vocab_size, self.embed_dim, ) - self.h = nn.ModuleList( - [GPTJBlock(config, quant_config) for _ in range(config.n_layer)]) + self.h = nn.ModuleList([ + GPTJBlock(config, cache_config, quant_config) + for _ in range(config.n_layer) + ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -207,13 +216,14 @@ class GPTJForCausalLM(nn.Module): def __init__( self, config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config assert not config.tie_word_embeddings - self.transformer = GPTJModel(config, quant_config) + self.transformer = GPTJModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.n_embd, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index bab563b9c5a39..dcb52ff666c95 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -23,6 +23,7 @@ from transformers import GPTNeoXConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -45,6 +46,7 @@ class GPTNeoXAttention(nn.Module): 
def __init__( self, config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -84,7 +86,10 @@ def __init__( max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config) def forward( self, @@ -134,6 +139,7 @@ class GPTNeoXLayer(nn.Module): def __init__( self, config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -142,7 +148,7 @@ def __init__( eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = GPTNeoXAttention(config, quant_config) + self.attention = GPTNeoXAttention(config, cache_config, quant_config) self.mlp = GPTNeoXMLP(config, quant_config) def forward( @@ -182,6 +188,7 @@ class GPTNeoXModel(nn.Module): def __init__( self, config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -192,7 +199,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - GPTNeoXLayer(config, quant_config) + GPTNeoXLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.final_layer_norm = nn.LayerNorm(config.hidden_size, @@ -223,12 +230,13 @@ class GPTNeoXForCausalLM(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.gpt_neox = GPTNeoXModel(config, quant_config) + self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config) self.embed_out = ParallelLMHead( config.vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 5811cae83bf8b..65f7ddb8b082c 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -6,6 +6,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -64,6 +65,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -114,7 +116,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -136,6 +139,7 @@ class InternLMDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -151,6 +155,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) self.feed_forward = InternLM2MLP( @@ -196,6 +201,7 @@ class InternLM2Model(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: 
Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -207,7 +213,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - InternLMDecoderLayer(config, quant_config) + InternLMDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -239,12 +245,13 @@ class InternLM2ForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = InternLM2Model(config, quant_config) + self.model = InternLM2Model(config, cache_config, quant_config) self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index bd6a180ec8dfc..df30fd1ba0a37 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -26,6 +26,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -69,6 +70,7 @@ class JAISAttention(nn.Module): def __init__( self, config: JAISConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -108,6 +110,7 @@ def __init__( self.head_dim, scale=self.scale, alibi_slopes=alibi_slopes, + cache_config=cache_config, ) def forward( @@ -170,6 +173,7 @@ class JAISBlock(nn.Module): def __init__( self, config: JAISConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -178,7 +182,7 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = JAISAttention(config, quant_config) + self.attn = JAISAttention(config, cache_config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = JAISMLP(inner_dim, config, quant_config) @@ -211,6 +215,7 @@ class JAISModel(nn.Module): def __init__( self, config: JAISConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -228,7 +233,7 @@ def __init__( else: self.embeddings_scale = config.mup_embeddings_scale self.h = nn.ModuleList([ - JAISBlock(config, quant_config) + JAISBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -262,12 +267,13 @@ class JAISLMHeadModel(nn.Module): def __init__( self, config: JAISConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = JAISModel(config, quant_config) + self.transformer = JAISModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 127e4612b2e40..ebdc64e0e220e 100644 
--- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -28,7 +28,7 @@ from transformers import LlamaConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -94,6 +94,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, bias: bool = False, sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -153,7 +154,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + sliding_window=sliding_window, + cache_config=cache_config) def forward( self, @@ -176,6 +178,7 @@ class LlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -204,6 +207,7 @@ def __init__( quant_config=quant_config, bias=attention_bias, sliding_window=sliding_window, + cache_config=cache_config, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -251,6 +255,7 @@ class LlamaModel(nn.Module): def __init__( self, config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -267,7 +272,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, quant_config) + LlamaDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -332,12 +337,16 @@ class LlamaForCausalLM(nn.Module): def __init__( self, config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config - self.model = LlamaModel(config, quant_config, lora_config=lora_config) + self.model = LlamaModel(config, + cache_config, + quant_config, + lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index dcde4dfa0795e..3b99b337a2765 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -7,7 +7,7 @@ from transformers import CLIPVisionModel, LlavaConfig from vllm.attention import AttentionMetadata -from vllm.config import VisionLanguageConfig +from vllm.config import CacheConfig, VisionLanguageConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( @@ -62,6 +62,7 @@ class LlavaForConditionalGeneration(nn.Module): def __init__(self, config: "LlavaConfig", vision_language_config: VisionLanguageConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional["QuantizationConfig"] = None) -> None: super().__init__() self.config = config @@ -85,7 +86,8 @@ def __init__(self, projector_hidden_act=config.projector_hidden_act) self.quant_config = quant_config - self.language_model = LlamaModel(config.text_config, 
quant_config) + self.language_model = LlamaModel(config.text_config, cache_config, + quant_config) self.unpadded_vocab_size = config.text_config.vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index c90bcfbfc4707..0b85cf1c94795 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -28,7 +28,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -181,6 +181,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -234,7 +235,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -259,6 +261,7 @@ class MiniCPMDecoderLayer(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -275,6 +278,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) self.num_experts = getattr(self.config, "num_experts", 0) @@ -330,6 +334,7 @@ class MiniCPMModel(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -346,7 +351,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - MiniCPMDecoderLayer(config, quant_config) + MiniCPMDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -413,6 +418,7 @@ class MiniCPMForCausalLM(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -421,6 +427,7 @@ def __init__( self.num_experts = getattr(self.config, "num_experts", 0) self.quant_config = quant_config self.model = MiniCPMModel(config, + cache_config, quant_config, lora_config=lora_config) unpadded_vocab_size = config.vocab_size diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index efa4de7516212..113abbaa6036d 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -29,7 +29,7 @@ from vllm import _custom_ops as ops from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -252,6 +252,7 @@ def __init__(self, num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, sliding_window: Optional[int] = None) -> None: 
super().__init__() @@ -313,6 +314,7 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=self.sliding_window, + cache_config=cache_config, ) def forward( @@ -335,6 +337,7 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -348,6 +351,7 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, sliding_window=config.sliding_window, + cache_config=cache_config, quant_config=quant_config) self.block_sparse_moe = MixtralMoE( num_experts=config.num_local_experts, @@ -394,6 +398,7 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -410,7 +415,9 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, quant_config=quant_config) + MixtralDecoderLayer(config, + cache_config, + quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -460,12 +467,14 @@ class MixtralForCausalLM(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.model = MixtralModel(config, + cache_config, quant_config, lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 38c62afced28a..ee2626b1c1aa2 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -30,6 +30,7 @@ from transformers import MixtralConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -157,14 +158,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class MixtralAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - quant_config: Optional[QuantizationConfig] = None, - sliding_window: Optional[int] = None) -> None: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + quant_config: Optional[QuantizationConfig] = None, + sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -215,6 +219,7 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=self.sliding_window, + cache_config=cache_config, ) def forward( @@ -237,6 +242,7 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -250,6 +256,7 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, sliding_window=config.sliding_window, + cache_config=cache_config, 
quant_config=quant_config) self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config) @@ -292,6 +299,7 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -303,7 +311,9 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, quant_config=quant_config) + MixtralDecoderLayer(config, + cache_config, + quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -332,12 +342,13 @@ class MixtralForCausalLM(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = MixtralModel(config, quant_config) + self.model = MixtralModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 6fa5c5bd3014a..716ac51cde94d 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -7,6 +7,7 @@ import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn @@ -43,6 +44,7 @@ class MPTAttention(nn.Module): def __init__( self, config: MPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -107,7 +109,8 @@ def __init__( self.head_dim, scaling, alibi_slopes=alibi_slopes, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -166,12 +169,13 @@ class MPTBlock(nn.Module): def __init__( self, config: MPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.d_model self.norm_1 = nn.LayerNorm(hidden_size) - self.attn = MPTAttention(config, quant_config) + self.attn = MPTAttention(config, cache_config, quant_config) self.norm_2 = nn.LayerNorm(hidden_size) self.ffn = MPTMLP(config, quant_config) @@ -201,6 +205,7 @@ class MPTModel(nn.Module): def __init__( self, config: MPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -211,8 +216,10 @@ def __init__( config.vocab_size, config.d_model, ) - self.blocks = nn.ModuleList( - [MPTBlock(config, quant_config) for _ in range(config.n_layers)]) + self.blocks = nn.ModuleList([ + MPTBlock(config, cache_config, quant_config) + for _ in range(config.n_layers) + ]) self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: for module in self.modules(): @@ -246,6 +253,7 @@ class MPTForCausalLM(nn.Module): def __init__( self, config: MPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -253,7 +261,7 @@ def __init__( assert config.tie_word_embeddings self.quant_config = quant_config - self.transformer = MPTModel(config, quant_config) + self.transformer = 
MPTModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index f212ea2166e1d..69f23bbfb5d0a 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -28,6 +28,7 @@ from transformers import OlmoConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -55,6 +56,7 @@ class OlmoAttention(nn.Module): def __init__( self, config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -93,7 +95,8 @@ def __init__( self.scaling = self.head_dim**-0.5 self.attn = Attention(self.num_heads, self.head_dim, - scale=self.scaling) + scale=self.scaling, + cache_config=cache_config) # Attention output projection. self.o_proj = RowParallelLinear( @@ -175,10 +178,11 @@ class OlmoDecoderLayer(nn.Module): def __init__(self, config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() # Attention block. - self.self_attn = OlmoAttention(config, quant_config) + self.self_attn = OlmoAttention(config, cache_config, quant_config) # MLP block. self.mlp = OlmoMLP(config, quant_config) @@ -217,6 +221,7 @@ class OlmoModel(nn.Module): def __init__(self, config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -224,7 +229,7 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - OlmoDecoderLayer(config, quant_config) + OlmoDecoderLayer(config, cache_config, quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, @@ -271,10 +276,11 @@ class OlmoForCausalLM(nn.Module): def __init__(self, config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.model = OlmoModel(config, quant_config) + self.model = OlmoModel(config, cache_config, quant_config) if config.tie_word_embeddings: self.lm_head_weight = self.model.embed_tokens.weight else: diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 336f765ababaa..d241756e50f4a 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -24,6 +24,7 @@ from transformers import OPTConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -61,6 +62,7 @@ def __init__( embed_dim: int, num_heads: int, bias: bool = True, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -88,7 +90,8 @@ def __init__( ) self.attn = Attention(self.num_heads, self.head_dim, - scale=self.scaling) + scale=self.scaling, + cache_config=cache_config) def forward( self, @@ 
-108,6 +111,7 @@ class OPTDecoderLayer(nn.Module): def __init__( self, config: OPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -117,6 +121,7 @@ def __init__( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, bias=config.enable_bias, + cache_config=cache_config, quant_config=quant_config, ) self.do_layer_norm_before = config.do_layer_norm_before @@ -181,6 +186,7 @@ class OPTDecoder(nn.Module): def __init__( self, config: OPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -226,7 +232,7 @@ def __init__( self.final_layer_norm = None self.layers = nn.ModuleList([ - OPTDecoderLayer(config, quant_config) + OPTDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -259,10 +265,11 @@ class OPTModel(nn.Module): def __init__( self, config: OPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.decoder = OPTDecoder(config, quant_config) + self.decoder = OPTDecoder(config, cache_config, quant_config) def forward( self, @@ -279,12 +286,13 @@ class OPTForCausalLM(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.model = OPTModel(config, quant_config) + self.model = OPTModel(config, cache_config, quant_config) self.lm_head_weight = self.model.decoder.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 9ab5dfb97c19a..59cd42e31b374 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -11,6 +11,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -68,6 +69,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -118,7 +120,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -140,6 +143,7 @@ class OrionDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -155,6 +159,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) self.mlp = OrionMLP( @@ -202,6 +207,7 @@ class OrionModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -213,7 +219,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - 
OrionDecoderLayer(config, quant_config) + OrionDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -245,12 +251,13 @@ class OrionForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = OrionModel(config, quant_config) + self.model = OrionModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 4a45879201af3..ed25a232f4208 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -42,6 +42,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -63,6 +64,7 @@ class PhiAttention(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.total_num_heads = config.num_attention_heads @@ -105,7 +107,10 @@ def __init__(self, max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config) def forward( self, @@ -155,11 +160,12 @@ class PhiLayer(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self_attn = PhiAttention(config, quant_config) + self.self_attn = PhiAttention(config, cache_config, quant_config) self.mlp = PhiMLP(config, quant_config) def forward( @@ -186,6 +192,7 @@ class PhiModel(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -193,7 +200,7 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - PhiLayer(config, quant_config) + PhiLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.final_layernorm = nn.LayerNorm(config.hidden_size, @@ -225,12 +232,13 @@ class PhiForCausalLM(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config self.quant_config = quant_config - self.model = PhiModel(config, quant_config) + self.model = PhiModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index e5e0028888c88..d158846a3a1f5 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -11,6 +11,7 @@ 
from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -68,6 +69,7 @@ def __init__( max_position_embeddings: int, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -101,7 +103,10 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, self.head_dim, self.scaling) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config) def forward( self, @@ -123,6 +128,7 @@ class QWenBlock(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -135,6 +141,7 @@ def __init__( config.max_position_embeddings, rope_theta=rope_theta, rope_scaling=rope_scaling, + cache_config=cache_config, quant_config=quant_config) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -175,6 +182,7 @@ class QWenModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -186,7 +194,7 @@ def __init__( config.hidden_size, ) self.h = nn.ModuleList([ - QWenBlock(config, quant_config) + QWenBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -218,12 +226,13 @@ class QWenLMHeadModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = QWenModel(config, quant_config) + self.transformer = QWenModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 62bc7fe22c367..31ba6441f9f7a 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -29,7 +29,7 @@ from transformers import Qwen2Config from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -87,6 +87,7 @@ def __init__(self, max_position: int = 4096 * 32, rope_theta: float = 10000, use_sliding_window: bool = False, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, sliding_window: Optional[int] = None) -> None: super().__init__() @@ -137,7 +138,8 @@ def __init__(self, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window) + sliding_window=self.sliding_window, + cache_config=cache_config) def forward( self, @@ -160,6 +162,7 @@ def __init__( self, config: Qwen2Config, layer_idx: int, + cache_config: 
Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -175,6 +178,7 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, use_sliding_window=use_sliding_window, + cache_config=cache_config, quant_config=quant_config, sliding_window=config.sliding_window) self.mlp = Qwen2MLP( @@ -222,6 +226,7 @@ class Qwen2Model(nn.Module): def __init__( self, config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -234,7 +239,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2DecoderLayer(config, layer_idx, quant_config) + Qwen2DecoderLayer(config, layer_idx, cache_config, quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -287,6 +292,7 @@ class Qwen2ForCausalLM(nn.Module): def __init__( self, config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -294,7 +300,7 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.model = Qwen2Model(config, quant_config) + self.model = Qwen2Model(config, cache_config, quant_config) if config.tie_word_embeddings: self.lm_head_weight = self.model.embed_tokens.weight diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 8da89a2b7ba6c..2a3b0173adf8b 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -30,6 +30,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -187,6 +188,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -238,7 +240,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -261,6 +264,7 @@ def __init__( self, config: PretrainedConfig, layer_idx: int, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -276,6 +280,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) if (config.num_experts is not None @@ -328,6 +333,7 @@ class Qwen2MoeModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -339,7 +345,10 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config) + Qwen2MoeDecoderLayer(config, + layer_idx, + cache_config, + quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -369,12 +378,13 @@ class Qwen2MoeForCausalLM(nn.Module): def __init__( 
self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = Qwen2MoeModel(config, quant_config) + self.model = Qwen2MoeModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 3d4f4f700f867..8b4a5507feade 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -26,6 +26,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -72,6 +73,7 @@ class StablelmAttention(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.config = config @@ -124,7 +126,8 @@ def __init__(self, self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_key_value_heads) + num_kv_heads=self.num_key_value_heads, + cache_config=cache_config) def forward( self, @@ -146,10 +149,11 @@ class StablelmDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - self.self_attn = StablelmAttention(config) + self.self_attn = StablelmAttention(config, cache_config, quant_config) self.mlp = StablelmMLP(config, quant_config) norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) @@ -188,6 +192,7 @@ class StableLMEpochModel(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.embed_tokens = VocabParallelEmbedding( @@ -195,7 +200,7 @@ def __init__(self, config.hidden_size, ) self.layers = nn.ModuleList([ - StablelmDecoderLayer(config, quant_config) + StablelmDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) norm_eps = getattr(config, "norm_eps", @@ -227,12 +232,13 @@ class StablelmForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = StableLMEpochModel(config, quant_config) + self.model = StableLMEpochModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 33998e2aad5c5..3c19d63276a77 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -25,6 +25,7 @@ from transformers import Starcoder2Config from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import 
get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -46,6 +47,7 @@ class Starcoder2Attention(nn.Module): def __init__(self, config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -101,6 +103,7 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=self.sliding_window, + cache_config=cache_config, ) def forward( @@ -150,10 +153,13 @@ class Starcoder2DecoderLayer(nn.Module): def __init__(self, config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Starcoder2Attention(config, quant_config=quant_config) + self.self_attn = Starcoder2Attention(config, + cache_config, + quant_config=quant_config) self.mlp = Starcoder2MLP(config, quant_config=quant_config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -191,6 +197,7 @@ class Starcoder2Model(nn.Module): def __init__(self, config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -201,7 +208,9 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - Starcoder2DecoderLayer(config, quant_config=quant_config) + Starcoder2DecoderLayer(config, + cache_config, + quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -226,10 +235,13 @@ class Starcoder2ForCausalLM(nn.Module): def __init__(self, config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.model = Starcoder2Model(config, quant_config=quant_config) + self.model = Starcoder2Model(config, + cache_config, + quant_config=quant_config) self.vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size if config.tie_word_embeddings: diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 0fb2662b2f715..6ef230a8ebbca 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -27,7 +27,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -89,6 +89,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, bias: bool = False, sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -133,7 +134,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + sliding_window=sliding_window, + cache_config=cache_config) def forward( self, @@ -155,6 +157,7 @@ class XverseDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> 
None: super().__init__() @@ -175,6 +178,7 @@ def __init__( quant_config=quant_config, bias=getattr(config, "bias", False), sliding_window=sliding_window, + cache_config=cache_config, ) self.mlp = XverseMLP( hidden_size=self.hidden_size, @@ -221,6 +225,7 @@ class XverseModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -237,7 +242,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - XverseDecoderLayer(config, quant_config) + XverseDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -295,13 +300,14 @@ class XverseForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config=None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = XverseModel(config, quant_config) + self.model = XverseModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 1fb63a3e47921..07d51dca226bd 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -31,7 +31,7 @@ def __init__( self.head_size = model_config.get_head_size() self.num_layers = model_config.get_num_layers(parallel_config) - self.num_heads = model_config.get_num_kv_heads(parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.block_size = cache_config.block_size self.num_gpu_blocks = cache_config.num_gpu_blocks @@ -43,7 +43,15 @@ def __init__( self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend(model_config.dtype) + self.attn_backend = get_attn_backend( + model_config.get_num_attention_heads(parallel_config), + self.head_size, + self.num_kv_heads, + model_config.get_sliding_window(), + model_config.dtype, + cache_config.cache_dtype, + self.block_size, + ) # Initialize the cache. 
self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda") @@ -56,7 +64,7 @@ def _allocate_kv_cache( ) -> List[torch.Tensor]: """Allocates KV cache on the specified device.""" kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_heads, self.head_size) + num_blocks, self.block_size, self.num_kv_heads, self.head_size) pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] for _ in range(self.num_layers): diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 6c8b1685dadcf..0a0b0d70cfe21 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -53,7 +53,15 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size - self.attn_backend = get_attn_backend(self.model_config.dtype) + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) # Lazy initialization. self.model: nn.Module # Set after init_Model @@ -66,7 +74,8 @@ def load_model(self) -> None: vision_language_config=self.vision_language_config, lora_config=self.lora_config, parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) def _prepare_prompt( self, @@ -158,7 +167,6 @@ def _prepare_prompt( decode_metadata=None, block_tables=torch.tensor([]), slot_mapping=slot_mapping, - kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_input) @@ -242,7 +250,6 @@ def _prepare_decode( prefill_metadata=None, decode_metadata=None, block_tables=block_tables, - kv_cache_dtype=self.kv_cache_dtype, ) return ( input_tokens, diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 5e4ae564cb57e..3ee394f9912e9 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -53,7 +53,15 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend(model_config.dtype) + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + cache_config.cache_dtype, + self.block_size, + ) # Initialize the cache. 
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 2d3f160c60dc1..d04bebbdc31b6 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -235,7 +235,6 @@ def prepare_input_tensors( num_decode_tokens=num_decode_tokens, prefill_metadata=prefill_attn_metadata, decode_metadata=decode_attn_metadata, - kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, attn_metadata, pooling_metadata, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f46b475bdc2db..b5e1991717b13 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -141,10 +141,18 @@ def __init__( self.graph_block_tables = np.zeros( (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), dtype=np.int32) - self.attn_backend = get_attn_backend(self.model_config.dtype) + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) # Lazy initialization - self.model: torch.nn.Module # Set after load_model + self.model: nn.Module # Set after load_model # Set if the backend is flashinfer. self.flashinfer_workspace_buffer: torch.Tensor # Set after load_model. @@ -160,6 +168,7 @@ def load_model(self) -> None: vision_language_config=self.vision_language_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, + cache_config=self.cache_config, ) self.model_memory_usage = m.consumed_memory @@ -753,7 +762,6 @@ def prepare_input_tensors( num_decode_tokens=num_decode_tokens, prefill_metadata=prefill_attn_metadata, decode_metadata=decode_attn_metadata, - kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, attn_metadata, @@ -965,7 +973,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: slot_mapping=slot_mapping[:batch_size], prefill_metadata=None, decode_metadata=decode_metadata, - kv_cache_dtype=self.kv_cache_dtype, ) if self.lora_config:
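Taken together, the model-side hunks above all apply one pattern: every `*ForCausalLM` constructor gains an optional `CacheConfig` and threads it, unchanged, through the intermediate `*Model` and decoder-layer classes down to the `Attention` layer, which is the only place it is consumed (the KV-cache dtype no longer travels through `AttentionMetadata`). Below is a minimal sketch of that wiring using hypothetical `ToyAttention`/`ToyModel`/`ToyForCausalLM` classes, not any model in this diff; only `Attention`, `CacheConfig`, and `QuantizationConfig` are real vLLM symbols, and a `config` object with `hidden_size`, `num_attention_heads`, and `num_hidden_layers` is assumed.

# Sketch only: Toy* classes are hypothetical illustrations of the pattern
# used by GPT2, GPTBigCode, Llama, Qwen, etc. in this diff.
from typing import Optional

import torch.nn as nn

from vllm.attention import Attention
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class ToyAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        head_dim = hidden_size // num_heads
        # The leaf layer is where cache_config is actually consumed
        # (quant_config would feed the linear projections in a real model).
        self.attn = Attention(num_heads,
                              head_dim,
                              scale=head_dim**-0.5,
                              cache_config=cache_config)


class ToyModel(nn.Module):

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # Intermediate modules simply forward cache_config, mirroring
        # GPT2Model, LlamaModel, QWenModel and friends above.
        self.layers = nn.ModuleList([
            ToyAttention(config.hidden_size, config.num_attention_heads,
                         cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])


class ToyForCausalLM(nn.Module):

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.model = ToyModel(config, cache_config, quant_config)

Keeping `cache_config` optional with a `None` default means existing call sites that only pass `config` and `quant_config` keep working while the model runners start passing the cache config explicitly.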
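On the worker side, the old one-argument call `get_attn_backend(model_config.dtype)` is replaced by a call that carries the full head layout and cache layout, so the selector can pick a backend per KV-cache configuration. A small sketch of that call, mirroring the arguments used by `CacheEngine`, `CPUModelRunner`, `CPUWorker`, and `ModelRunner` in this diff; `model_config`, `parallel_config`, and `cache_config` are assumed to be the usual vLLM `ModelConfig`, `ParallelConfig`, and `CacheConfig` instances.

# Sketch of the updated backend selection, mirroring the call sites above.
from vllm.attention import get_attn_backend


def select_backend(model_config, parallel_config, cache_config):
    return get_attn_backend(
        model_config.get_num_attention_heads(parallel_config),
        model_config.get_head_size(),
        model_config.get_num_kv_heads(parallel_config),
        model_config.get_sliding_window(),
        model_config.dtype,
        cache_config.cache_dtype,  # cache dtype string, e.g. "auto"
        cache_config.block_size,
    )

All four call sites in this diff pass the same seven arguments in this order, which is why each of them needs both the model/parallel configs and the cache config at construction time.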
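The `CacheEngine` rename from `num_heads` to `num_kv_heads` is cosmetic (the value was already taken from `get_num_kv_heads`), but it makes the `_allocate_kv_cache` hunk read correctly: the KV-cache shape is a function of the KV-head count, not the query-head count. A minimal sketch of that allocation path, assuming `attn_backend` is the object returned by `get_attn_backend()` and that plain `torch.empty` allocation is sufficient (the real `CacheEngine` also handles CPU pinned memory).

# Sketch of per-layer KV-cache allocation following the _allocate_kv_cache
# change above; all parameters are assumed inputs for illustration.
from typing import List

import torch


def allocate_kv_cache(attn_backend, num_blocks: int, block_size: int,
                      num_kv_heads: int, head_size: int, num_layers: int,
                      dtype: torch.dtype,
                      device: str = "cuda") -> List[torch.Tensor]:
    kv_cache_shape = attn_backend.get_kv_cache_shape(num_blocks, block_size,
                                                     num_kv_heads, head_size)
    # One tensor per decoder layer, shaped by the selected backend.
    return [
        torch.empty(kv_cache_shape, dtype=dtype, device=device)
        for _ in range(num_layers)
    ]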