Added Support for Vanilla and Quantized ChatGLM3 Models to Model Builder (#921)

We have integrated support for both the vanilla and quantized versions of
the [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) models into the
Model Builder. We have also run parity checks for both model types to
confirm that the exported models produce consistent outputs across
configurations.

---------

Co-authored-by: Bowen Bao <[email protected]>
amd-sudo-sh and BowenBao authored Oct 18, 2024
1 parent 9e6f4ca commit f8f8c12
Showing 2 changed files with 128 additions and 49 deletions.
87 changes: 67 additions & 20 deletions src/python/py/models/builder.py
@@ -3,6 +3,7 @@
 # Licensed under the MIT License. See License.txt in the project root for
 # license information.
 # --------------------------------------------------------------------------
+# Modifications Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved
 """
 Run this script to create the desired ONNX model.
 """
@@ -21,16 +22,16 @@
 
 
 class Model:
     def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
-        self.context_length = config.max_position_embeddings
-        self.original_context_length = config.original_max_position_embeddings if hasattr(config, "original_max_position_embeddings") else config.rope_scaling["original_max_position_embeddings"] if hasattr(config, "rope_scaling") and hasattr(config.rope_scaling, "original_max_position_embeddings") else config.max_position_embeddings
+        self.context_length = config.seq_length if hasattr(config, "seq_length") else config.max_position_embeddings
+        self.original_context_length = config.original_max_position_embeddings if hasattr(config, "original_max_position_embeddings") else config.rope_scaling["original_max_position_embeddings"] if hasattr(config, "rope_scaling") and hasattr(config.rope_scaling, "original_max_position_embeddings") else self.context_length
         self.window_size = config.sliding_window if hasattr(config, "sliding_window") else -1 # default is -1 in GroupQueryAttention kernel
-        self.intermediate_size = config.intermediate_size
+        self.intermediate_size = config.ffn_hidden_size if hasattr(config, "ffn_hidden_size") else config.intermediate_size
         self.hidden_size = config.hidden_size
-        self.num_kv_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.multi_query_group_num if hasattr(config, "multi_query_group_num") else config.num_attention_heads
         self.num_attn_heads = config.num_attention_heads
         self.head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
-        self.num_layers = int(extra_options["num_hidden_layers"]) if "num_hidden_layers" in extra_options else config.num_hidden_layers
+        self.num_layers = int(extra_options["num_hidden_layers"]) if "num_hidden_layers" in extra_options else config.num_hidden_layers if hasattr(config, "num_hidden_layers") else config.num_layers
         self.vocab_size = config.vocab_size
         self.activation = config.hidden_activation if hasattr(config, "hidden_activation") and config.hidden_activation is not None else config.hidden_act

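To make the new config-name fallbacks concrete, here is a small self-contained sketch. Plain namespaces stand in for Hugging Face config objects, and the ChatGLM3 numbers mirror the published chatglm3-6b config (assumed values, not something this commit asserts):

from types import SimpleNamespace

# LLaMA-style config vs. ChatGLM3-style config (field names differ, meanings match)
llama_cfg = SimpleNamespace(max_position_embeddings=4096, intermediate_size=11008,
                            num_key_value_heads=32, num_attention_heads=32,
                            hidden_size=4096, num_hidden_layers=32)
glm_cfg = SimpleNamespace(seq_length=8192, ffn_hidden_size=13696,
                          multi_query_group_num=2, num_attention_heads=32,
                          hidden_size=4096, num_layers=28)

def context_length(cfg):
    # Same hasattr-based fallback the builder now applies
    return cfg.seq_length if hasattr(cfg, "seq_length") else cfg.max_position_embeddings

def num_kv_heads(cfg):
    return (cfg.num_key_value_heads if hasattr(cfg, "num_key_value_heads")
            else cfg.multi_query_group_num if hasattr(cfg, "multi_query_group_num")
            else cfg.num_attention_heads)

print(context_length(llama_cfg), num_kv_heads(llama_cfg))  # 4096 32
print(context_length(glm_cfg), num_kv_heads(glm_cfg))      # 8192 2
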
@@ -1504,14 +1505,15 @@ def make_mlp(self, layer_id, mlp, root_input):
             raise NotImplementedError(f"The MLP layer type is not set.")
 
     def make_mlp_unpacked(self, layer_id, mlp, root_input):
+        packed_proj = getattr(mlp, "gate_up_proj", None) or getattr(mlp, "dense_h_to_4h", None)
         mlp.gate_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size)
-        mlp.gate_proj.weight = torch.nn.Parameter(mlp.gate_up_proj.weight[ : self.intermediate_size, :])
+        mlp.gate_proj.weight = torch.nn.Parameter(packed_proj.weight[: self.intermediate_size, :])
 
         mlp.up_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size)
-        mlp.up_proj.weight = torch.nn.Parameter(mlp.gate_up_proj.weight[self.intermediate_size :, :])
+        mlp.up_proj.weight = torch.nn.Parameter(packed_proj.weight[self.intermediate_size :, :])
 
         # Delete original packed weights
-        del mlp.gate_up_proj
+        del packed_proj
 
     def make_mlp_proj(self, layer_id, mlp, root_input):
         # Make nodes for the MLP subgraph
@@ -1541,8 +1543,9 @@ def make_mlp_proj(self, layer_id, mlp, root_input):
         self.make_mul(mul_name, mul_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size])
 
         # Make output MatMul node
+        down_proj = getattr(mlp, "down_proj", None) or getattr(mlp, "dense_4h_to_h", None)
         down_basename = f"/model/layers.{layer_id}/mlp/down_proj/MatMul"
-        down_name = self.make_matmul(mlp.down_proj, down_basename, f"{mul_name}/output_0")
+        down_name = self.make_matmul(down_proj, down_basename, f"{mul_name}/output_0")
 
         # Assign output 0 of previous MatMul as skip input to next SkipLayerNorm
         self.layernorm_attrs["skip_input"] = f"{down_name}/output_0"
@@ -1752,7 +1755,7 @@ def make_relu_squared(self, layer_id, root_input, activation):
         return pow_name
 
     def make_activation(self, layer_id, root_input):
-        if self.activation in {"silu", "swish"}:
+        if self.activation in {"silu", "swish", "swiglu"}:
             output_name = self.make_activation_with_mul(layer_id, root_input, activation="Sigmoid", domain=None)
         elif self.activation in {"gelu_new", "gelu_fast", "gelu_pytorch_tanh"}:
             output_name = self.make_gelu(layer_id, root_input, activation="FastGelu")
@@ -1836,7 +1839,17 @@ def make_model(self, input_path):
             from onnxruntime_genai.models.quantized_model import QuantModel
             q_size = self.num_attn_heads * self.head_size
             kv_size = self.num_kv_heads * self.head_size
-            model = QuantModel.from_pretrained(self.quant_type, input_path, self.quant_attrs["bits"], self.quant_attrs["group_size"], self.quant_attrs["use_g_idx"], q_size, kv_size, self.intermediate_size, self.num_layers)
+            model = QuantModel.from_pretrained(
+                self.quant_type,
+                input_path,
+                self.quant_attrs["bits"],
+                self.quant_attrs["group_size"],
+                self.quant_attrs["use_g_idx"],
+                q_size,
+                kv_size,
+                self.intermediate_size,
+                self.num_layers,
+            )
         else:
             # Load PyTorch model
             extra_kwargs = {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {}
@@ -1845,6 +1858,7 @@ def make_model(self, input_path):
         # Loop through model and map each module to ONNX/ORT ops
         self.layer_id = 0
         for module in model.modules():
+
             if isinstance(module, torch.nn.Embedding) or (hasattr(model, "embedding") and module == model.embedding):
                 # Checks (Hugging Face logic) or (GGUF logic)
                 if not self.exclude_embeds:
@@ -1856,7 +1870,7 @@ def make_model(self, input_path):
                     self.layernorm_attrs["root_input"] = "inputs_embeds"
                     self.layernorm_attrs["skip_input"] = "inputs_embeds"
 
-            elif module.__class__.__name__.endswith("DecoderLayer") and self.layer_id < self.num_layers:
+            elif (module.__class__.__name__.endswith("DecoderLayer") or module.__class__.__name__.endswith("GLMBlock")) and self.layer_id < self.num_layers:
                 # Each decoder layer of model
                 print(f"Reading decoder layer {self.layer_id}")
                 self.make_layer(self.layer_id, module)
@@ -1866,7 +1880,7 @@ def make_model(self, input_path):
                 # SkipLayerNorm after last decoder layer (MatMul --> SkipLayerNorm)
                 print("Reading final norm")
                 self.make_layernorm(self.layer_id, module, skip=True, simple=self.layernorm_attrs["simple"], location="final_norm")
 
             elif (isinstance(module, torch.nn.Linear) and module.out_features == self.vocab_size) or (hasattr(model, "lm_head") and module == model.lm_head):
                 # Checks (Hugging Face logic) or (GGUF logic)
                 if not self.exclude_lm_head:
@@ -1877,12 +1891,13 @@ def make_model(self, input_path):
         del model
 
     def has_final_norm(self, module, model):
-        # Hugging Face names
-        hf_norm = hasattr(model, "model") and hasattr(model.model, "norm") and module == model.model.norm
-        hf_final_layernorm = hasattr(model, "model") and hasattr(model.model, "final_layernorm") and module == model.model.final_layernorm
-        # GGUF names
-        gguf_final_norm = hasattr(model, "final_norm") and module == model.final_norm
-        return hf_norm or hf_final_layernorm or gguf_final_norm
+        # Hugging Face names
+        hf_norm = hasattr(model, "model") and hasattr(model.model, "norm") and module == model.model.norm
+        hf_final_layernorm = hasattr(model, "model") and hasattr(model.model, "final_layernorm") and module == model.model.final_layernorm
+        hf_transformer_final_layernorm = hasattr(model, "transformer") and hasattr(model.transformer, "encoder") and hasattr(model.transformer.encoder, "final_layernorm") and module == model.transformer.encoder.final_layernorm
+        # GGUF names
+        gguf_final_norm = hasattr(model, "final_norm") and module == model.final_norm
+        return hf_norm or hf_final_layernorm or hf_transformer_final_layernorm or gguf_final_norm
 
     def make_preprocessing_nodes(self):
         self.make_attention_mask_reformatting()
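
The new hf_transformer_final_layernorm branch reflects that ChatGLM3 keeps its final norm at model.transformer.encoder.final_layernorm rather than model.model.norm. A condensed, stand-alone version of the check (namespaces standing in for torch modules) behaves as follows:

from types import SimpleNamespace

final_ln = object()  # stand-in for the final LayerNorm module
llama_like = SimpleNamespace(model=SimpleNamespace(norm=final_ln))
chatglm_like = SimpleNamespace(
    transformer=SimpleNamespace(encoder=SimpleNamespace(final_layernorm=final_ln)))

def has_final_norm(module, model):
    hf_norm = hasattr(model, "model") and hasattr(model.model, "norm") and module == model.model.norm
    hf_transformer_final_layernorm = (
        hasattr(model, "transformer")
        and hasattr(model.transformer, "encoder")
        and hasattr(model.transformer.encoder, "final_layernorm")
        and module == model.transformer.encoder.final_layernorm
    )
    return hf_norm or hf_transformer_final_layernorm

print(has_final_norm(final_ln, llama_like))    # True via model.model.norm
print(has_final_norm(final_ln, chatglm_like))  # True via model.transformer.encoder.final_layernorm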
@@ -2806,6 +2821,34 @@ def make_layer(self, layer_id, layer):
             self.layernorm_attrs["last_layernorm"] = True
 
 
+class ChatGLMModel(Model):
+    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
+        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
+        self.rotemb_attrs["num_heads"] = self.num_attn_heads
+        self.rotemb_attrs["partial_rotary_factor"] = 0.5  # See the self.rotary_pos_emb declaration (around line 755 of modeling_chatglm.py)
+        self.rotemb_attrs["rotary_embedding_dim"] = int(self.head_size * self.rotemb_attrs["partial_rotary_factor"])
+        self.rotemb_attrs["interleaved"] = 1
+
+    def make_rotary_embedding(self, rotemb, name, root_input, **kwargs):
+        super().make_rotary_embedding(rotemb, name, root_input, num_heads=self.rotemb_attrs["num_heads"], rotary_embedding_dim=self.rotemb_attrs["rotary_embedding_dim"], **kwargs)
+
+    def make_attention(self, layer_id, attention, root_input, **kwargs):
+        if self.quant_type is None:
+            super().make_attention_unpacked(layer_id, attention, root_input, **kwargs)
+        # Add a dummy rotary_emb attribute so the base class's attention logic can run
+        attention.rotary_emb = type("RotaryEmbedding", (object,), {'content': {}})()
+        return super().make_attention(layer_id, attention, root_input, **kwargs)
+
+
+    def make_mlp_proj(self, layer_id, mlp, root_input):
+        if self.quant_type is None:
+            super().make_mlp_unpacked(layer_id, mlp, root_input)
+        super().make_mlp_proj(layer_id, mlp, root_input)
+
+    def make_layer(self, layer_id, layer):
+        layer.self_attn = layer.self_attn if hasattr(layer, 'self_attn') else layer.self_attention
+        super().make_layer(layer_id, layer)
+
 def check_extra_options(kv_pairs):
     if "use_8bits_moe" in kv_pairs:
         assert(kv_pairs["use_8bits_moe"] == "1" or kv_pairs["use_8bits_moe"] == "0"), "use_8bits_moe must be 0 or 1."
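
As a quick sanity check on the rotary settings in ChatGLMModel.__init__: with the usual chatglm3-6b dimensions (hidden_size 4096 and 32 attention heads, assumed here rather than taken from this diff), the half-rotary factor gives a 64-dimensional rotary embedding per head:

# chatglm3-6b dimensions are assumed; the formula itself mirrors ChatGLMModel.__init__ above
hidden_size, num_attn_heads = 4096, 32
head_size = hidden_size // num_attn_heads                      # 128
partial_rotary_factor = 0.5                                    # RoPE applied to half of each head
rotary_embedding_dim = int(head_size * partial_rotary_factor)  # 64
print(head_size, rotary_embedding_dim)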
Expand Down Expand Up @@ -2895,6 +2938,10 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid
onnx_model = QwenModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "NemotronForCausalLM":
onnx_model = NemotronModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "ChatGLMForConditionalGeneration" or config.architectures[0] == "ChatGLMModel":
# Quantized ChatGLM model has ChatGLMForConditionalGeneration as architecture whereas HF model as the latter
config.hidden_act = "swiglu"
onnx_model = ChatGLMModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
else:
raise NotImplementedError(f"The {hf_name} model is not currently supported.")

