[New Model] support Baichuan-M1 model
Signed-off-by: dangshunya <[email protected]>
dangshunya committed Jan 22, 2025
1 parent 68ad4e3 commit 23b8eab
Showing 8 changed files with 894 additions and 9 deletions.
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.md
@@ -96,6 +96,11 @@ See [this page](#generative-models) for more information on how to use generativ
- `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.
- ✅︎
- ✅︎
* - `BaichuanM1ForCausalLM`
- Baichuan-M1
- `baichuan-inc/Baichuan-M1-14B-Instruct`, `baichuan-inc/Baichuan-M1-14B-Base`, etc.
- ✅︎
- ✅︎
* - `BloomForCausalLM`
- BLOOM, BLOOMZ, BLOOMChat
- `bigscience/bloom`, `bigscience/bloomz`, etc.
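For readers who want to try the newly listed checkpoints, the snippet below is a minimal offline-inference sketch against vLLM's public `LLM` API. The model id comes from the table above and `trust_remote_code=True` mirrors the registry entry in `tests/models/registry.py`; the prompt, sampling settings, and the assumption of a GPU large enough for a 14B model are illustrative only.

```python
# Minimal offline-inference sketch; model id from the supported-models
# table above, trust_remote_code as required by the Baichuan repos.
# Prompt and sampling settings are illustrative only.
from vllm import LLM, SamplingParams

llm = LLM(model="baichuan-inc/Baichuan-M1-14B-Instruct",
          trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.7, max_tokens=128)
outputs = llm.generate(["What is sliding-window attention?"], sampling_params)
print(outputs[0].outputs[0].text)
```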
2 changes: 2 additions & 0 deletions tests/models/registry.py
@@ -100,6 +100,8 @@ def check_available_online(
trust_remote_code=True),
"BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
trust_remote_code=True),
"BaichuanM1ForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-M1-14B-Instruct", # noqa: E501
trust_remote_code=True),
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"),
# ChatGLMModel supports multimodal
"CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
48 changes: 45 additions & 3 deletions vllm/config.py
@@ -313,9 +313,12 @@ def __init__(
self.enforce_eager = False

sliding_window = getattr(self.hf_text_config, "sliding_window", None)
sliding_window_layers = getattr(self.hf_text_config,
"sliding_window_layers", None)
has_interleaved_attention = (sliding_window is not None) and (
isinstance(sliding_window, list) or
(self.hf_text_config.model_type in ["gemma2", "cohere2"]))
(self.hf_text_config.model_type in ["gemma2", "cohere2"])
or sliding_window_layers is not None)

if (not self.disable_sliding_window and has_interleaved_attention):
if envs.VLLM_ATTENTION_BACKEND == "XFORMERS":
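As a self-contained illustration of the predicate this hunk extends, the sketch below restates the detection logic outside of `ModelConfig`. The mock config values (model type, window size, layer ids) are invented for the example and are not the real Baichuan-M1 configuration.

```python
from types import SimpleNamespace

def has_interleaved_attention(cfg) -> bool:
    # A model mixes global and local attention if it declares a sliding
    # window and either provides a per-layer list, is a known interleaved
    # architecture, or (new in this change) lists explicit SWA layer ids.
    sliding_window = getattr(cfg, "sliding_window", None)
    sliding_window_layers = getattr(cfg, "sliding_window_layers", None)
    return (sliding_window is not None) and (
        isinstance(sliding_window, list)
        or cfg.model_type in ["gemma2", "cohere2"]
        or sliding_window_layers is not None)

# Invented example values, not the real Baichuan-M1 config.
cfg = SimpleNamespace(model_type="baichuan_m1", sliding_window=2048,
                      sliding_window_layers=[1, 3, 5, 7])
assert has_interleaved_attention(cfg)
```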
@@ -721,6 +724,9 @@ def get_hf_config_sliding_window(
if (hasattr(self.hf_text_config, "use_sliding_window")
and not self.hf_text_config.use_sliding_window):
return None
if hasattr(self.hf_text_config, 'sliding_window_layers'):
return None

return getattr(self.hf_text_config, "sliding_window", None)

def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
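The early return added in this hunk changes the lookup order, which a small standalone sketch (not the actual vLLM method, just the same precedence) can make explicit: an explicit `use_sliding_window=False` and the presence of `sliding_window_layers` both suppress the single global window value.

```python
from types import SimpleNamespace

def hf_config_sliding_window(cfg):
    # An explicit use_sliding_window=False disables the window outright.
    if getattr(cfg, "use_sliding_window", True) is False:
        return None
    # Models with per-layer SWA handle the window per layer instead.
    if hasattr(cfg, "sliding_window_layers"):
        return None
    return getattr(cfg, "sliding_window", None)

print(hf_config_sliding_window(SimpleNamespace(sliding_window=2048)))   # 2048
print(hf_config_sliding_window(
    SimpleNamespace(sliding_window=2048, sliding_window_layers=[1, 3])))  # None
```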
@@ -732,6 +738,10 @@ def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
# Otherwise get the value from the hf config.
return self.get_hf_config_sliding_window()

def get_sliding_window_layers(self,
parallel_config) -> Optional[List[int]]:
return getattr(self.hf_text_config, "sliding_window_layers", [])

def get_vocab_size(self) -> int:
return self.hf_text_config.vocab_size

@@ -759,6 +769,12 @@ def get_head_size(self) -> int:
return (self.hf_text_config.hidden_size //
self.hf_text_config.num_attention_heads)

def get_head_size_swa(self) -> int:
if hasattr(self.hf_text_config, "num_swa_attention_heads"):
return (self.hf_text_config.hidden_size //
self.hf_text_config.num_swa_attention_heads)
return self.get_head_size()

def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
# For GPTBigCode & Falcon:
@@ -805,6 +821,22 @@ def get_total_num_kv_heads(self) -> int:
# equal to the number of attention heads.
return self.hf_text_config.num_attention_heads

def get_total_num_kv_heads_swa(self) -> int:
if hasattr(self.hf_text_config, "num_swa_key_value_heads"):
return self.hf_text_config.num_swa_key_value_heads
return self.get_total_num_kv_heads()

def get_num_swa_key_value_heads(self,
parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU."""
total_num_kv_heads_swa = self.get_total_num_kv_heads_swa()
# If tensor parallelism is used, we divide the number of KV heads by
# the tensor parallel size. We will replicate the KV heads in the
# case where the number of KV heads is smaller than the tensor
# parallel size so each GPU has at least one KV head.
return max(
1, total_num_kv_heads_swa // parallel_config.tensor_parallel_size)

def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU."""
total_num_kv_heads = self.get_total_num_kv_heads()
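To see how the new `*_swa` accessors interact with tensor parallelism, here is a hedged sketch that mirrors their fallback behaviour as free functions. The attribute names follow the methods above; the hidden size and head counts are placeholder values, not Baichuan-M1's real configuration.

```python
from types import SimpleNamespace

def head_size_swa(cfg) -> int:
    # Mirrors get_head_size_swa: prefer the SWA-specific head count.
    if hasattr(cfg, "num_swa_attention_heads"):
        return cfg.hidden_size // cfg.num_swa_attention_heads
    return cfg.hidden_size // cfg.num_attention_heads

def num_swa_kv_heads_per_gpu(cfg, tensor_parallel_size: int) -> int:
    # Mirrors get_num_swa_key_value_heads: divide by TP size, but keep at
    # least one KV head per GPU (heads are replicated when there are fewer
    # heads than TP ranks).
    total = getattr(cfg, "num_swa_key_value_heads",
                    getattr(cfg, "num_key_value_heads",
                            cfg.num_attention_heads))
    return max(1, total // tensor_parallel_size)

# Placeholder values, not the real Baichuan-M1 config.
cfg = SimpleNamespace(hidden_size=5120, num_attention_heads=40,
                      num_key_value_heads=8,
                      num_swa_attention_heads=20,
                      num_swa_key_value_heads=4)
print(head_size_swa(cfg))                 # 256 on SWA layers
print(num_swa_kv_heads_per_gpu(cfg, 8))   # 1 after replication across 8 ranks
```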
@@ -847,7 +879,18 @@ def get_num_layers_by_block_type(

if is_transformer:
# Handle the basic case first
return end - start if attn_block_type else 0
swa_layers = self.get_sliding_window_layers(parallel_config)
num_layers = 0
if not swa_layers:
num_layers = end - start if attn_block_type else 0
else:
for layer_id in range(start, end):
if (block_type == LayerBlockType.attention
and layer_id not in swa_layers) or (
block_type == LayerBlockType.swa
and layer_id in swa_layers):
num_layers += 1
return num_layers
elif self.is_attention_free:
# Attention free
# Note that this code assumes there
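The per-block-type counting added in this hunk can be exercised in isolation. The sketch below substitutes plain strings for `LayerBlockType` and uses made-up layer ids, but applies the same rule: a layer listed in `sliding_window_layers` counts as an SWA layer, and every other layer counts as regular attention.

```python
def count_layers(start: int, end: int, swa_layers, block_type: str) -> int:
    # With no explicit SWA layer list, every layer is a regular attention
    # layer; otherwise each layer id is classified by list membership.
    if not swa_layers:
        return end - start if block_type == "attention" else 0
    num_layers = 0
    for layer_id in range(start, end):
        is_swa = layer_id in swa_layers
        if (block_type == "attention" and not is_swa) or (
                block_type == "swa" and is_swa):
            num_layers += 1
    return num_layers

swa_layers = [1, 3, 5]
assert count_layers(0, 8, swa_layers, "attention") == 5
assert count_layers(0, 8, swa_layers, "swa") == 3
assert count_layers(0, 8, [], "attention") == 8
```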
@@ -2363,7 +2406,6 @@ def _get_and_verify_max_len(
max_len_key = key if max_len < derived_max_model_len \
else max_len_key
derived_max_model_len = min(derived_max_model_len, max_len)

# If sliding window is manually disabled, max_length should be less
# than the sliding window length in the model config.
if disable_sliding_window and sliding_window_len is not None:
13 changes: 10 additions & 3 deletions vllm/engine/llm_engine.py
@@ -341,6 +341,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
# of request outputs to asyncio queues
self.process_request_outputs_callback: Optional[Callable] = None

# Used to evict finished requests from the Mamba cache and from
# Baichuan-M1's per-request state. Finished request ids are kept here so
# they are not lost when the scheduler produces an empty step.
self.finished_requests_ids: List[str] = list()

# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
@@ -1315,6 +1320,7 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:

finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
self.finished_requests_ids.extend(finished_requests_ids)

# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
@@ -1327,8 +1333,6 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
self._cache_scheduler_outputs_for_multi_step(
virtual_engine, seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
else:
finished_requests_ids = list()

assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
@@ -1349,11 +1353,14 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids,
finished_requests_ids=self.finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)

# Clear finished_requests_ids list.
self.finished_requests_ids = list()

if allow_async_output_proc:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
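Taken together, the llm_engine.py changes amount to an accumulate-and-clear buffer for finished request ids. The toy class below sketches only that bookkeeping (it is not the `LLMEngine` API): ids reported by the scheduler are buffered across steps, handed over once when the model request is built, and then reset, so an empty scheduling step no longer drops them.

```python
from typing import List

class FinishedRequestBuffer:
    def __init__(self) -> None:
        self.finished_requests_ids: List[str] = []

    def on_schedule(self, newly_finished: List[str]) -> None:
        # Called every engine step with whatever the scheduler reports,
        # even when the scheduler produced no work.
        self.finished_requests_ids.extend(newly_finished)

    def take(self) -> List[str]:
        # Called when building the model request; returns the buffered ids
        # and resets the buffer.
        ids, self.finished_requests_ids = self.finished_requests_ids, []
        return ids

buf = FinishedRequestBuffer()
buf.on_schedule(["req-1"])
buf.on_schedule([])            # an empty scheduling step keeps the id
assert buf.take() == ["req-1"]
assert buf.take() == []
```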