Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate InternLMForCausalLM to LlamaForCausalLM #2860

Merged
merged 3 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-*
Expand Down
299 changes: 0 additions & 299 deletions vllm/model_executor/models/internlm.py

This file was deleted.

6 changes: 4 additions & 2 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
bias: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
Expand Down Expand Up @@ -120,13 +121,13 @@ def __init__(
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
bias=bias,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
bias=bias,
linear_method=linear_method,
)

Expand Down Expand Up @@ -179,6 +180,7 @@ def __init__(
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
bias=getattr(config, "bias", False),
Copy link

@masahi masahi Feb 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't it be "attention_bias"? (Note they use the term in a different sense compared to the conventional one)

https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/configuration_llama.py#L161

)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
Expand Down
Loading