Migrate AquilaForCausalLM to LlamaForCausalLM #2867
Conversation
I tried out diffing the two implementations. Here is the diff:

27a28
> from transformers import LlamaConfig
31a33
> from vllm.model_executor.layers.layernorm import RMSNorm
39c41
< VocabParallelEmbedding, ParallelLMHead)
---
> VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
46c48
< from vllm.transformers_utils.configs.aquila import AquilaConfig
---
> from vllm.config import LoRAConfig
51c53
< class AquilaMLP(nn.Module):
---
> class LlamaMLP(nn.Module):
59c61
< ):
---
> ) -> None:
81c83
< class AquilaRMSNorm(nn.Module):
---
> class LlamaAttention(nn.Module):
83,102d84
< def __init__(self, hidden_size, eps=1e-6):
< """
< AquilaRMSNorm is equivalent to T5LayerNorm
< """
< super().__init__()
< self.weight = nn.Parameter(torch.ones(hidden_size))
< self.variance_epsilon = eps
<
< def forward(self, hidden_states):
< input_dtype = hidden_states.dtype
< variance = hidden_states.to(torch.float32).pow(2).mean(-1,
< keepdim=True)
< hidden_states = hidden_states * torch.rsqrt(variance +
< self.variance_epsilon)
<
< return (self.weight * hidden_states).to(input_dtype)
<
<
< class AquilaAttention(nn.Module):
<
109d90
< max_position_embeddings: int = 8192,
110a92
> max_position_embeddings: int = 8192,
112c94,95
< ):
---
> bias: bool = False,
> ) -> None:
120,121c103,111
< assert self.total_num_kv_heads % tp_size == 0
< self.num_kv_heads = self.total_num_kv_heads // tp_size
---
> if self.total_num_kv_heads >= tp_size:
> # Number of KV heads is greater than TP size, so we partition
> # the KV heads across multiple tensor parallel GPUs.
> assert self.total_num_kv_heads % tp_size == 0
> else:
> # Number of KV heads is less than TP size, so we replicate
> # the KV heads across multiple tensor parallel GPUs.
> assert tp_size % self.total_num_kv_heads == 0
> self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
134c124
< bias=False,
---
> bias=bias,
140c130
< bias=False,
---
> bias=bias,
142a133
>
146,147c137,138
< max_position=self.max_position_embeddings,
< base=self.rope_theta,
---
> max_position=max_position_embeddings,
> base=rope_theta,
171c162
< class AquilaDecoderLayer(nn.Module):
---
> class LlamaDecoderLayer(nn.Module):
175c166
< config: AquilaConfig,
---
> config: LlamaConfig,
177c168
< ):
---
> ) -> None:
184c175
< self.self_attn = AquilaAttention(
---
> self.self_attn = LlamaAttention(
187c178,179
< num_kv_heads=config.num_key_value_heads,
---
> num_kv_heads=getattr(config, "num_key_value_heads",
> config.num_attention_heads),
189d180
< max_position_embeddings=max_position_embeddings,
190a182
> max_position_embeddings=max_position_embeddings,
191a184
> bias=getattr(config, "bias", False),
193c186
< self.mlp = AquilaMLP(
---
> self.mlp = LlamaMLP(
199,202c192,195
< self.input_layernorm = AquilaRMSNorm(config.hidden_size,
< eps=config.rms_norm_eps)
< self.post_attention_layernorm = AquilaRMSNorm(config.hidden_size,
< eps=config.rms_norm_eps)
---
> self.input_layernorm = RMSNorm(config.hidden_size,
> eps=config.rms_norm_eps)
> self.post_attention_layernorm = RMSNorm(config.hidden_size,
> eps=config.rms_norm_eps)
210c203,204
< ) -> torch.Tensor:
---
> residual: Optional[torch.Tensor],
> ) -> Tuple[torch.Tensor, torch.Tensor]:
212,213c206,211
< residual = hidden_states
< hidden_states = self.input_layernorm(hidden_states)
---
> if residual is None:
> residual = hidden_states
> hidden_states = self.input_layernorm(hidden_states)
> else:
> hidden_states, residual = self.input_layernorm(
> hidden_states, residual)
220d217
< hidden_states = residual + hidden_states
223,224c220,221
< residual = hidden_states
< hidden_states = self.post_attention_layernorm(hidden_states)
---
> hidden_states, residual = self.post_attention_layernorm(
> hidden_states, residual)
226,227c223
< hidden_states = residual + hidden_states
< return hidden_states
---
> return hidden_states, residual
230c226
< class AquilaModel(nn.Module):
---
> class LlamaModel(nn.Module):
234c230
< config: AquilaConfig,
---
> config: LlamaConfig,
236c232,233
< ):
---
> lora_config: Optional[LoRAConfig] = None,
> ) -> None:
240c237,240
< self.vocab_size = config.vocab_size
---
> lora_vocab = (lora_config.lora_extra_vocab_size *
> (lora_config.max_loras or 1)) if lora_config else 0
> self.vocab_size = config.vocab_size + lora_vocab
> self.org_vocab_size = config.vocab_size
242c242
< config.vocab_size,
---
> self.vocab_size,
243a244
> org_num_embeddings=config.vocab_size,
246c247
< AquilaDecoderLayer(config, linear_method)
---
> LlamaDecoderLayer(config, linear_method)
249c250
< self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
---
> self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
258a260
> residual = None
261c263
< hidden_states = layer(
---
> hidden_states, residual = layer(
265a268
> residual,
267,268c270
< hidden_states = self.norm(hidden_states)
<
---
> hidden_states, _ = self.norm(hidden_states, residual)
272c274,285
< class AquilaForCausalLM(nn.Module):
---
> class LlamaForCausalLM(nn.Module):
> packed_modules_mapping = {
> "qkv_proj": [
> "q_proj",
> "k_proj",
> "v_proj",
> ],
> "gate_up_proj": [
> "gate_proj",
> "up_proj",
> ],
> }
273a287,301
> # LoRA specific attributes
> supported_lora_modules = [
> "qkv_proj",
> "o_proj",
> "gate_up_proj",
> "down_proj",
> "embed_tokens",
> "lm_head",
> ]
> embedding_modules = {
> "embed_tokens": "input_embeddings",
> "lm_head": "output_embeddings",
> }
> embedding_padding_modules = ["lm_head"]
>
276c304
< config,
---
> config: LlamaConfig,
278c306,307
< ):
---
> lora_config: Optional[LoRAConfig] = None,
> ) -> None:
282,284c311,324
< self.model = AquilaModel(config, linear_method)
< self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
< self.sampler = Sampler(config.vocab_size)
---
> self.model = LlamaModel(config, linear_method, lora_config=lora_config)
> self.unpadded_vocab_size = config.vocab_size
> if lora_config:
> self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
> self.lm_head = ParallelLMHead(
> self.unpadded_vocab_size,
> config.hidden_size,
> org_num_embeddings=config.vocab_size,
> padding_size=DEFAULT_VOCAB_PADDING_SIZE
> # We need bigger padding if using lora for kernel
> # compatibility
> if not lora_config else lora_config.lora_vocab_padding_size,
> )
> self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
323a364,368
> if ("rotary_emb.cos_cached" in name
> or "rotary_emb.sin_cached" in name):
> # Models trained using ColossalAI may include these tensors in
> # the checkpoint. Skip them.
>                 continue
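
The main structural change visible in the diff is that vLLM's `RMSNorm` fuses the residual addition into the normalization call, which is why `LlamaDecoderLayer.forward` now threads a `residual` tensor through and returns a tuple instead of a single tensor. The following is a minimal PyTorch sketch of that calling convention, not vLLM's actual fused kernel; the class name `FusedAddRMSNorm` is illustrative:

```python
import torch
import torch.nn as nn
from typing import Optional


class FusedAddRMSNorm(nn.Module):
    """Illustrative RMSNorm with an optional fused residual add."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Same math as the removed AquilaRMSNorm: normalize in float32,
        # then cast back to the input dtype.
        input_dtype = x.dtype
        hidden = x.to(torch.float32)
        variance = hidden.pow(2).mean(-1, keepdim=True)
        hidden = hidden * torch.rsqrt(variance + self.eps)
        return (self.weight * hidden).to(input_dtype)

    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
        if residual is None:
            # No residual yet: behave like a plain RMSNorm.
            return self._norm(x)
        # Add the incoming residual, normalize the sum, and hand the
        # pre-norm sum back to the caller as the next residual.
        x = x + residual
        return self._norm(x), x
```

This is also why the explicit `hidden_states = residual + hidden_states` additions disappear from the decoder layer in the diff: the add is folded into the norm call, and each layer simply returns `(hidden_states, residual)` for the next layer to consume.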
These implementations indeed look the same to me :)
LGTM! Thanks for the PR!
Can you post your test code? The results of my test look relatively poor. Thanks
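
No test code was posted in the thread. For reference, a minimal sanity check of the migrated path might look like the sketch below; the model name `BAAI/AquilaChat2-7B`, the prompts, and the sampling settings are placeholders, not taken from the PR:

```python
# Hypothetical sanity check -- not the PR author's test code.
# Assumes an Aquila checkpoint such as "BAAI/AquilaChat2-7B" is available.
from vllm import LLM, SamplingParams

prompts = [
    "Beijing is the capital of",
    "The quick brown fox",
]
# Greedy decoding so outputs can be compared before and after the migration.
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

# After this PR, the Aquila architecture is served by the Llama implementation,
# so loading should behave exactly as it did with the dedicated Aquila module.
llm = LLM(model="BAAI/AquilaChat2-7B", trust_remote_code=True)

for output in llm.generate(prompts, sampling_params):
    print(output.prompt, "->", output.outputs[0].text)
```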
There is no actual difference between AquilaForCausalLM and LlamaForCausalLM. This is a subsequent PR of #2637.
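
In practice, the migration amounts to pointing the Aquila architecture names at the Llama implementation instead of keeping a near-duplicate module. A sketch of what such a registry change could look like, assuming vLLM keeps an architecture-name to (module, class) mapping; the exact layout and entries here are an assumption about vLLM's internals, not copied from the PR:

```python
# Illustrative only: a model registry mapping HF architecture names to the
# vLLM module and class that implement them.
_MODELS = {
    # Before: Aquila had its own module with a near-identical implementation.
    # "AquilaForCausalLM": ("aquila", "AquilaForCausalLM"),
    # After: the Aquila architecture resolves to the Llama implementation.
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
}
```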