Migrate InternLMForCausalLM to LlamaForCausalLM #2860

Merged: 3 commits into vllm-project:main on Feb 14, 2024

Conversation

pcmoritz (Collaborator)

The difference between InternLM and Llama is very small: just the bias terms in the attention layers.

For maintainability and to make things like LoRA support more uniform, this PR merges the two models. There should be no user-visible change.

This was proposed by @esmeetu, a co-author of this PR, in #2637.
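
As context for why this is user-invisible: vLLM picks the model implementation from the architectures field of the HF config, so an InternLM checkpoint can simply resolve to the Llama class. A minimal sketch of that lookup (the registry name and exact entries are assumptions for illustration, not code copied from this PR):

from typing import Dict, Tuple

# Sketch: vLLM resolves the HF `architectures` string through a registry of
# model implementations, so the InternLM entry can point at the Llama class.
# Registry name and layout here are illustrative assumptions.
_MODELS: Dict[str, Tuple[str, str]] = {
    # "InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),  # before
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),          # after
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
}

def resolve_model_class(architecture: str) -> Tuple[str, str]:
    # Returns (module_name, class_name) for a given HF architecture string.
    return _MODELS[architecture]

assert resolve_model_class("InternLMForCausalLM") == ("llama", "LlamaForCausalLM")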

Here is the diff between the models:

1c1,23
< # -*- coding: utf-8 -*-
---
> # coding=utf-8
> # Adapted from
> # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
> # Copyright 2023 The vLLM team.
> # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
> #
> # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
> # and OPT implementations in this library. It has been modified from its
> # original forms to accommodate minor architectural differences compared
> # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
> #
> # Licensed under the Apache License, Version 2.0 (the "License");
> # you may not use this file except in compliance with the License.
> # You may obtain a copy of the License at
> #
> #     http://www.apache.org/licenses/LICENSE-2.0
> #
> # Unless required by applicable law or agreed to in writing, software
> # distributed under the License is distributed on an "AS IS" BASIS,
> # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
> # See the License for the specific language governing permissions and
> # limitations under the License.
> """Inference-only LLaMA model compatible with HuggingFace weights."""
19c41
<     VocabParallelEmbedding, ParallelLMHead)
---
>     VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
25a48
> from vllm.config import LoRAConfig
30c53
< class InternLMMLP(nn.Module):
---
> class LlamaMLP(nn.Module):
38c61
<     ):
---
>     ) -> None:
60c83
< class InternLMAttention(nn.Module):
---
> class LlamaAttention(nn.Module):
66c89
<         bias: bool,
---
>         num_kv_heads: int,
67a91
>         rope_scaling: Optional[Dict[str, Any]] = None,
70,71c94
<         rope_scaling: Optional[Dict[str, Any]] = None,
<     ):
---
>     ) -> None:
74,75c97
<         tensor_model_parallel_world_size = (
<             get_tensor_model_parallel_world_size())
---
>         tp_size = get_tensor_model_parallel_world_size()
77,79c99,110
<         assert self.total_num_heads % tensor_model_parallel_world_size == 0
<         self.num_heads = (self.total_num_heads //
<                           tensor_model_parallel_world_size)
---
>         assert self.total_num_heads % tp_size == 0
>         self.num_heads = self.total_num_heads // tp_size
>         self.total_num_kv_heads = num_kv_heads
>         if self.total_num_kv_heads >= tp_size:
>             # Number of KV heads is greater than TP size, so we partition
>             # the KV heads across multiple tensor parallel GPUs.
>             assert self.total_num_kv_heads % tp_size == 0
>         else:
>             # Number of KV heads is less than TP size, so we replicate
>             # the KV heads across multiple tensor parallel GPUs.
>             assert tp_size % self.total_num_kv_heads == 0
>         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
80a112,113
>         self.q_size = self.num_heads * self.head_dim
>         self.kv_size = self.num_kv_heads * self.head_dim
89c122,123
<             bias=bias,
---
>             self.total_num_kv_heads,
>             bias=False,
95c129
<             bias=bias,
---
>             bias=False,
97a132
> 
101,102c136,137
<             max_position=self.max_position_embeddings,
<             base=self.rope_theta,
---
>             max_position=max_position_embeddings,
>             base=rope_theta,
105c140,143
<         self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)
---
>         self.attn = PagedAttention(self.num_heads,
>                                    self.head_dim,
>                                    self.scaling,
>                                    num_kv_heads=self.num_kv_heads)
115c153
<         q, k, v = qkv.chunk(chunks=3, dim=-1)
---
>         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
123c161
< class InternLMDecoderLayer(nn.Module):
---
> class LlamaDecoderLayer(nn.Module):
129c167
<     ):
---
>     ) -> None:
132a171
>         rope_scaling = getattr(config, "rope_scaling", None)
135c174
<         self.self_attn = InternLMAttention(
---
>         self.self_attn = LlamaAttention(
138c177
<             bias=config.bias,
---
>             num_kv_heads=config.num_key_value_heads,
139a179
>             rope_scaling=rope_scaling,
142d181
<             rope_scaling=getattr(config, "rope_scaling", None),
144c183
<         self.mlp = InternLMMLP(
---
>         self.mlp = LlamaMLP(
184c223
< class InternLMModel(nn.Module):
---
> class LlamaModel(nn.Module):
190c229,230
<     ):
---
>         lora_config: Optional[LoRAConfig] = None,
>     ) -> None:
194,196c234,237
<         self.vocab_size = config.vocab_size
< 
<         vocab_size = ((config.vocab_size + 63) // 64) * 64
---
>         lora_vocab = (lora_config.lora_extra_vocab_size *
>                       (lora_config.max_loras or 1)) if lora_config else 0
>         self.vocab_size = config.vocab_size + lora_vocab
>         self.org_vocab_size = config.vocab_size
198c239
<             vocab_size,
---
>             self.vocab_size,
199a241
>             org_num_embeddings=config.vocab_size,
202c244
<             InternLMDecoderLayer(config, linear_method)
---
>             LlamaDecoderLayer(config, linear_method)
229c271,282
< class InternLMForCausalLM(nn.Module):
---
> class LlamaForCausalLM(nn.Module):
>     packed_modules_mapping = {
>         "qkv_proj": [
>             "q_proj",
>             "k_proj",
>             "v_proj",
>         ],
>         "gate_up_proj": [
>             "gate_proj",
>             "up_proj",
>         ],
>     }
230a284,298
>     # LoRA specific attributes
>     supported_lora_modules = [
>         "qkv_proj",
>         "o_proj",
>         "gate_up_proj",
>         "down_proj",
>         "embed_tokens",
>         "lm_head",
>     ]
>     embedding_modules = {
>         "embed_tokens": "input_embeddings",
>         "lm_head": "output_embeddings",
>     }
>     embedding_padding_modules = ["lm_head"]
> 
233c301
<         config,
---
>         config: LlamaConfig,
235c303,304
<     ):
---
>         lora_config: Optional[LoRAConfig] = None,
>     ) -> None:
239,241c308,321
<         self.model = InternLMModel(config, linear_method)
<         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
<         self.sampler = Sampler(config.vocab_size)
---
>         self.model = LlamaModel(config, linear_method, lora_config=lora_config)
>         self.unpadded_vocab_size = config.vocab_size
>         if lora_config:
>             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
>         self.lm_head = ParallelLMHead(
>             self.unpadded_vocab_size,
>             config.hidden_size,
>             org_num_embeddings=config.vocab_size,
>             padding_size=DEFAULT_VOCAB_PADDING_SIZE
>             # We need bigger padding if using lora for kernel
>             # compatibility
>             if not lora_config else lora_config.lora_vocab_padding_size,
>         )
>         self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
280a361,365
>             if ("rotary_emb.cos_cached" in name
>                     or "rotary_emb.sin_cached" in name):
>                 # Models trained using ColossalAI may include these tensors in
>                 # the checkpoint. Skip them.
>                 continue
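
To make the shape of the merge concrete, here is a toy sketch of the only behavioral difference and the flag that absorbs it. Plain nn.Linear stands in for vLLM's tensor-parallel linear layers and the sizes are illustrative; the real wiring is the bias=getattr(config, "bias", False) line visible in the review hunk further down.

import torch.nn as nn

# Toy sketch: InternLM's attention projections carry a bias term, Llama's do
# not, so folding InternLM into the Llama code path only needs a `bias` flag
# that defaults to False.
class AttentionProjections(nn.Module):
    def __init__(self, hidden_size: int, bias: bool = False) -> None:
        super().__init__()
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=bias)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

# Llama checkpoints: no bias, behavior unchanged.
llama_attn = AttentionProjections(hidden_size=4096, bias=False)
# InternLM checkpoints: the flag is read from the model config, mirroring the
# `bias=getattr(config, "bias", False)` call this PR adds to LlamaDecoderLayer.
internlm_attn = AttentionProjections(hidden_size=4096, bias=True)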

WoosukKwon self-requested a review on February 14, 2024 at 00:41

WoosukKwon (Collaborator) left a comment

LGTM! Thanks for the PR!

vllm/model_executor/models/llama.py: outdated review thread (resolved)
WoosukKwon merged commit 7eacffd into vllm-project:main on Feb 14, 2024
17 of 19 checks passed
WoosukKwon pushed a commit that referenced this pull request Feb 14, 2024
jvmncs pushed commits to jvmncs/vllm that referenced this pull request on Feb 14, 2024
@@ -179,6 +180,7 @@ def __init__(
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
bias=getattr(config, "bias", False),
masahi commented on Feb 14, 2024

Shouldn't it be "attention_bias"? (Note they use the term in a different sense compared to the conventional one)

https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/configuration_llama.py#L161
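
A small sketch of the reviewer's suggestion, preferring the HF attribute and falling back to the older name (illustrative only, not necessarily the change that ultimately landed):

from types import SimpleNamespace

def resolve_attention_bias(config) -> bool:
    # Prefer HF LlamaConfig's `attention_bias`, fall back to the older
    # InternLM-style `bias`, and default to False.
    return bool(getattr(config, "attention_bias",
                        getattr(config, "bias", False)))

assert resolve_attention_bias(SimpleNamespace()) is False
assert resolve_attention_bias(SimpleNamespace(bias=True)) is True             # InternLM-style config
assert resolve_attention_bias(SimpleNamespace(attention_bias=True)) is True   # HF LlamaConfig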

xjpang pushed commits to xjpang/vllm that referenced this pull request on Feb 20, Feb 22, and Mar 4, 2024
coolCatalyst added a commit to coolCatalyst/vllm that referenced this pull request Jun 1, 2024
shaojiewang pushed a commit to shaojiewang/vllm-rocm that referenced this pull request Jul 3, 2024