zhinao merge pull 4097,3950
caokai committed Apr 17, 2024
1 parent bdbc416 commit 516c9be
Showing 1 changed file with 49 additions and 15 deletions.
vllm/model_executor/models/zhinao.py
@@ -1,6 +1,7 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2024 The Zhinao team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
@@ -20,15 +21,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple
"""Inference-only Zhinao model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -40,12 +43,11 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.utils import is_hip


class ZhinaoMLP(nn.Module):
@@ -115,6 +117,15 @@ def __init__(
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings

# This will be overwritten by model initialization if we are using it.
# N.B. currently we only support per tensor scalar scaling factors
# & only applicable to ROCm (AMD GPU).
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
self.kv_scale = 1.0

self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
@@ -153,7 +164,8 @@ def forward(
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
self.kv_scale)
output, _ = self.o_proj(attn_output)
return output

@@ -172,6 +184,10 @@ def __init__(
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
sliding_window = getattr(config, "sliding_window", None)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
self.self_attn = ZhinaoAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
Expand All @@ -181,7 +197,7 @@ def __init__(
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
bias=getattr(config, "bias", False),
bias=attention_bias,
sliding_window=sliding_window,
)
self.mlp = ZhinaoMLP(
@@ -355,11 +371,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
# ("qkv_proj", "q_proj", "q"),
@@ -372,8 +384,7 @@ def load_weights(self,
total_num_kv_heads = self.config.num_attention_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
@@ -423,3 +434,26 @@ def load_weights(self,
default_weight_loader)
weight_loader(param, loaded_weight)

# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type):
layer_self_attn = self.model.layers[layer_idx].self_attn

if is_hip():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.kv_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")

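The diff also migrates load_weights from a path-based signature (model_name_or_path, cache_dir, load_format, revision) to one that consumes an iterable of (name, tensor) pairs. The sketch below shows how such an iterable might be fed to the method, assuming a ZhinaoForCausalLM instance and a HuggingFace-style state dict are already available; in practice the vLLM model loader constructs this iterator itself.

from typing import Dict, Iterable, Tuple

import torch

def iter_weights(
        state_dict: Dict[str, torch.Tensor]
) -> Iterable[Tuple[str, torch.Tensor]]:
    # Yield (parameter_name, tensor) pairs, matching the
    # `for name, loaded_weight in weights:` loop in load_weights above.
    for name, tensor in state_dict.items():
        yield name, tensor

# Illustrative usage (assumed names), with `model` an already-constructed
# ZhinaoForCausalLM and "pytorch_model.bin" a HuggingFace-style checkpoint:
#     state_dict = torch.load("pytorch_model.bin", map_location="cpu")
#     model.load_weights(iter_weights(state_dict))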