diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index fb519b3c0cf92..2985e9c69ae34 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -24,7 +24,7 @@ "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), - "InternLMForCausalLM": ("internlm", "InternLMForCausalLM"), + "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), # For decapoda-research/llama-* diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py deleted file mode 100644 index 5d0b93793c89d..0000000000000 --- a/vllm/model_executor/models/internlm.py +++ /dev/null @@ -1,299 +0,0 @@ -# -*- coding: utf-8 -*- -from typing import Any, Dict, List, Optional, Tuple - -import torch -from torch import nn -from transformers import LlamaConfig - -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) -from vllm.sequence import SamplerOutput - -KVCache = Tuple[torch.Tensor, torch.Tensor] - - -class InternLMMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class InternLMAttention(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - bias: bool, - rope_theta: float = 10000, - max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, - rope_scaling: Optional[Dict[str, Any]] = None, - ): - super().__init__() - self.hidden_size = hidden_size - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) - self.total_num_heads = num_heads - assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) - self.head_dim = hidden_size // self.total_num_heads - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - bias=bias, - linear_method=linear_method, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=bias, - linear_method=linear_method, - ) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_theta, - rope_scaling=rope_scaling, - ) - self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.chunk(chunks=3, dim=-1) - q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - output, _ = self.o_proj(attn_output) - return output - - -class InternLMDecoderLayer(nn.Module): - - def __init__( - self, - config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - self.self_attn = InternLMAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - bias=config.bias, - rope_theta=rope_theta, - max_position_embeddings=max_position_embeddings, - linear_method=linear_method, - rope_scaling=getattr(config, "rope_scaling", None), - ) - self.mlp = InternLMMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - linear_method=linear_method, - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - input_metadata=input_metadata, - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - - -class InternLMModel(nn.Module): - - def __init__( - self, - config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.config = config - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.embed_tokens = VocabParallelEmbedding( - vocab_size, - config.hidden_size, - ) - self.layers = nn.ModuleList([ - InternLMDecoderLayer(config, linear_method) - for _ in range(config.num_hidden_layers) - ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i], - input_metadata, - residual, - ) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -class InternLMForCausalLM(nn.Module): - - def __init__( - self, - config, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.config = config - self.linear_method = linear_method - self.model = InternLMModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) - return hidden_states - - def sample( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) - return next_tokens - - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): - if "rotary_emb.inv_freq" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 860a8f267acf8..6202e81fffa7c 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -91,6 +91,7 @@ def __init__( rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, linear_method: Optional[LinearMethodBase] = None, + bias: bool = False, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -120,13 +121,13 @@ def __init__( self.head_dim, self.total_num_heads, self.total_num_kv_heads, - bias=False, + bias=bias, linear_method=linear_method, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, - bias=False, + bias=bias, linear_method=linear_method, ) @@ -179,6 +180,7 @@ def __init__( rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, linear_method=linear_method, + bias=getattr(config, "bias", False), ) self.mlp = LlamaMLP( hidden_size=self.hidden_size,