Skip to content

Commit

Permalink
Support for Cohere R7B model
Browse files (browse the repository at this point in the history)
  • Loading branch information
janimo committed Dec 14, 2024
1 parent 9c3dadd commit 5c8027d
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ Text Generation (``--task generate``)
- ✅︎
* - :code:`CohereForCausalLM`
- Command-R
- :code:`CohereForAI/c4ai-command-r-v01`, etc.
- :code:`CohereForAI/c4ai-command-r-v01`, :code:`CohereForAI/c4ai-command-r7b-12-2024`, etc.
- ✅︎
- ✅︎
* - :code:`DbrxForCausalLM`
Expand Down
14 changes: 12 additions & 2 deletions vllm/model_executor/models/commandr.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

Expand Down Expand Up @@ -171,12 +171,21 @@ def __init__(
rope_scaling=self.rope_scaling,
is_neox_style=False,
)

layer_idx = extract_layer_index(prefix)
is_sliding = (
getattr(config, "sliding_window_pattern", False)
and (layer_idx + 1) % self.config.sliding_window_pattern != 0)

self.sliding_window = (getattr(config, "sliding_window", None)
if is_sliding else None)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=self.sliding_window,
prefix=f"{prefix}.attn")
if self.use_qk_norm:
self.q_norm = LayerNorm(param_shape=(self.num_heads,
Expand Down Expand Up @@ -206,7 +215,8 @@ def forward(
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
if self.sliding_window:
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
# ChatGLMModel supports multimodal
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
"Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
Expand Down

0 comments on commit 5c8027d

Please sign in to comment.