Skip to content

Commit

Permalink
Migrate InternLMForCausalLM to LlamaForCausalLM (vllm-project#2860)
Browse files Browse the repository at this point in the history
Co-authored-by: Roy <[email protected]>
  • Loading branch information
2 people authored and jimpang committed Mar 4, 2024
1 parent 99c2320 commit 92d45f8
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 302 deletions.
2 changes: 1 addition & 1 deletion vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-*
Expand Down
299 changes: 0 additions & 299 deletions vllm/model_executor/models/internlm.py

This file was deleted.

6 changes: 4 additions & 2 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
bias: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
Expand Down Expand Up @@ -120,13 +121,13 @@ def __init__(
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
bias=bias,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
bias=bias,
linear_method=linear_method,
)

Expand Down Expand Up @@ -179,6 +180,7 @@ def __init__(
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
bias=getattr(config, "bias", False),
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
Expand Down

0 comments on commit 92d45f8

Please sign in to comment.