diff --git a/src/python/py/models/README.md b/src/python/py/models/README.md index 12daf8422..61d055e33 100644 --- a/src/python/py/models/README.md +++ b/src/python/py/models/README.md @@ -10,6 +10,7 @@ This folder contains the model builder for quickly creating optimized and quanti - [Original PyTorch Model from Hugging Face](#original-pytorch-model-from-hugging-face) - [Original PyTorch Model from Disk](#original-pytorch-model-from-disk) - [Customized or Finetuned PyTorch Model](#customized-or-finetuned-pytorch-model) + - [Quantized PyTorch Model](#quantized-pytorch-model) - [GGUF Model](#gguf-model) - [Extra Options](#extra-options) - [Config Only](#config-only) @@ -82,6 +83,18 @@ python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o p python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files ``` +### Quantized PyTorch Model + +This scenario is where your PyTorch model is one of the currently supported model architectures, has already been quantized to INT4 precision, and can be loaded in the Hugging Face style via [AutoGPTQ](https://github.com/AutoGPTQ/AutoGPTQ) or [AutoAWQ](https://github.com/casper-hansen/AutoAWQ). + +``` +# From wheel: +python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p int4 -e execution_provider -c cache_dir_to_store_temp_files + +# From source: +python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p int4 -e execution_provider -c cache_dir_to_store_temp_files +``` + ### GGUF Model This scenario is where your float16/float32 GGUF model is already on disk. diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index bc507076a..8da1e8046 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -38,6 +38,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.model_type = config.architectures[0] self.io_dtype = io_dtype # {'fp16', 'fp32'} self.onnx_dtype = onnx_dtype # {"int4", "fp16", "fp32"} + self.quant_type = config.quantization_config["quant_method"] if hasattr(config, "quantization_config") else None self.cache_dir = cache_dir self.filename = extra_options["filename"] if "filename" in extra_options else "model.onnx" @@ -251,6 +252,11 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): "accuracy_level": int(extra_options["int4_accuracy_level"]) if "int4_accuracy_level" in extra_options else None, } } + if self.quant_type is not None: + # Create quantized attributes from quantization config + self.quant_attrs["bits"] = config.quantization_config["bits"] + self.quant_attrs["group_size"] = config.quantization_config["group_size"] + self.quant_attrs["use_g_idx"] = config.quantization_config["desc_act"] if "desc_act" in config.quantization_config else False def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): config = GenerationConfig.from_pretrained(model_name_or_path, use_auth_token=True, trust_remote_code=True, **extra_kwargs) @@ -351,7 +357,7 @@ def save_model(self, out_dir): # Quantize ONNX model to desired precision # TODO: Replace by quantizing the MatMuls as they are created - if self.onnx_dtype == "int4": + if self.onnx_dtype == "int4" and self.quant_type is None: model = self.to_int4(model) # Save ONNX model with only one external data file and delete any existing duplicate copies @@ -601,45 +607,135 @@ def make_transpose(self,
name, root_input, dtype, shape, perm): self.make_node("Transpose", inputs=[root_input], outputs=[output], name=name, perm=perm) self.make_value_info(output, dtype, shape=shape) - def make_matmul(self, matmul, name, root_input, **kwargs): - self.make_matmul_fp16_or_fp32(matmul, name, root_input, **kwargs) - - # TODO: add other dtypes - # if self.onnx_dtype in {"fp16", "fp32"}: - # self.make_matmul_fp16_or_fp32(matmul, name, root_input, **kwargs) - # elif self.onnx_dtype == "int8": - # pass - # elif self.onnx_dtype == "int4": - # int4_name = f"{name}NBits" - # self.make_matmul_int4(matmul, int4_name, root_input, **kwargs) + def make_matmul(self, matmul, basename, root_input, **kwargs): + if self.onnx_dtype in {"fp16", "fp32"}: + return self.make_matmul_fp16_or_fp32(matmul, basename, root_input, **kwargs) + elif self.onnx_dtype == "int4": + return self.make_matmul_int4(matmul, basename, root_input, **kwargs) + else: + raise NotImplementedError(f"The {self.onnx_dtype} precision is not currently supported.") def make_matmul_fp16_or_fp32(self, matmul, name, root_input, **kwargs): weight = name[1:].replace("/", ".") + ".weight" - self.make_external_tensor(matmul.transpose().astype(self.to_numpy_dtype[self.io_dtype]), weight) + self.make_external_tensor(matmul.weight.detach().numpy().transpose().astype(self.to_numpy_dtype[self.io_dtype]), weight) - last_dim = matmul.shape[0] + last_dim = matmul.weight.shape[0] output = "logits" if kwargs.get("logits", False) else f"{name}/output_0" self.make_node("MatMul", inputs=[root_input, weight], outputs=[output], name=name) self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', last_dim]) - # TODO: quantize weights, then save new MatMul numpy weights for onnx model - # def make_matmul_int4(self, matmul, name, root_input, **kwargs): - # weight = name[1:].replace("/", ".") + ".weight" - # scales = name[1:].replace("/", ".") + ".scales" + return name + + def make_matmul_int4(self, matmul, basename, root_input, **kwargs): + if not hasattr(matmul, "qweight"): + # TODO: quantize weights, then save new MatMul numpy weights for onnx model + # print(f"Quantizing to {self.onnx_dtype} on-the-fly is not currently supported.") + # print(f"Saving as {self.io_dtype} on-the-fly and quantizing to {self.onnx_dtype} at the end.") + return self.make_matmul_fp16_or_fp32(matmul, basename, root_input, **kwargs) + + name = f"{basename}NBits" + + # Input weights are quantized, save quantized MatMul numpy weights for onnx model + weight = name[1:].replace("/", ".") + ".qweight" + self.make_external_tensor(matmul.qweight.detach().numpy(), weight) + scales = name[1:].replace("/", ".") + ".scales" + self.make_external_tensor(matmul.scales.detach().numpy().astype(self.to_numpy_dtype[self.io_dtype]), scales) - # output = "logits" if kwargs.get("logits", False) else f"{name}/output_0" - # self.make_node("MatMulNBits", inputs=[root_input, weight, scales], outputs=[output], name=name) - # self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size]) + inputs = [root_input, weight, scales] - def make_packed_matmul(self, q_matmul, k_matmul, v_matmul, name, root_input, **kwargs): + if hasattr(matmul, "qzeros") and matmul.qzeros is not None: + zeros = name[1:].replace("/", ".") + ".qzeros" + self.make_external_tensor(matmul.qzeros.detach().numpy(), zeros) + inputs.append(zeros) + + if hasattr(matmul, "g_idx") and matmul.g_idx is not None: + g_idx = name[1:].replace("/", ".") + ".g_idx" + 
self.make_external_tensor(matmul.g_idx.detach().numpy().astype(np.int32), g_idx) + inputs.append(g_idx) + + output = "logits" if kwargs.get("logits", False) else f"{name}/output_0" + self.make_node( + "MatMulNBits", inputs=inputs, outputs=[output], name=name, domain="com.microsoft", + bits=matmul.bits, block_size=matmul.group_size, K=matmul.in_features, N=matmul.out_features, + ) + self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', matmul.out_features]) + + return name + + def make_packed_matmul(self, q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs): + if self.onnx_dtype in {"fp16", "fp32"}: + return self.make_packed_matmul_fp16_or_fp32(q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs) + elif self.onnx_dtype == "int4": + return self.make_packed_matmul_int4(q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs) + else: + raise NotImplementedError(f"The {self.onnx_dtype} precision is not currently supported.") + + def make_packed_matmul_fp16_or_fp32(self, q_matmul, k_matmul, v_matmul, name, root_input, **kwargs): # N_q = num_attention_heads * head_size, N_kv = num_key_value_heads * head_size, H = hidden_size # Combine 3 MatMuls of shape N_q x H, N_kv x H, N_kv x H into 1 packed MatMul of shape (N_q+N_kv+N_kv)xH # Note: Packed MatMul is of shape (N_q+N_kv+N_kv)xH instead of Hx(N_q+N_kv+N_kv) because `make_matmul` will # apply a transpose before saving - N_q, H = q_matmul.shape - N_kv, _ = k_matmul.shape - matmul = np.concatenate([q_matmul, k_matmul, v_matmul], axis=0).reshape(N_q + N_kv + N_kv, H) - self.make_matmul(matmul, name, root_input, **kwargs) + N_q, H = q_matmul.weight.shape + N_kv, _ = k_matmul.weight.shape + + # Create dummy PackedMatMul class + class PackedMatMul: + def __init__(self): + self.weight = torch.concatenate([q_matmul.weight.detach().cpu(), k_matmul.weight.detach().cpu(), v_matmul.weight.detach().cpu()], dim=0).reshape(N_q + N_kv + N_kv, H) + matmul = PackedMatMul() + new_name = self.make_matmul(matmul, name, root_input, **kwargs) + + return new_name + + def make_packed_matmul_int4(self, q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs): + if not hasattr(q_matmul, "qweight"): + # TODO: quantize weights, then save new MatMul numpy weights for onnx model + # print(f"Quantizing to {self.onnx_dtype} on-the-fly is not currently supported.") + # print(f"Saving as {self.io_dtype} on-the-fly and quantizing to {self.onnx_dtype} at the end.") + return self.make_packed_matmul_fp16_or_fp32(q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs) + + name = f"{basename}NBits" + + # Create dummy PackedMatMul class + class PackedMatMul: + def __init__(self): + self.qweight = torch.concatenate([q_matmul.qweight.detach().cpu(), k_matmul.qweight.detach().cpu(), v_matmul.qweight.detach().cpu()], dim=0) + self.scales = torch.concatenate([q_matmul.scales.detach().cpu(), k_matmul.scales.detach().cpu(), v_matmul.scales.detach().cpu()], dim=0) + self.qzeros = torch.concatenate([q_matmul.qzeros.detach().cpu(), k_matmul.qzeros.detach().cpu(), v_matmul.qzeros.detach().cpu()], dim=0) + self.g_idx = q_matmul.g_idx + + self.in_features = q_matmul.in_features + self.out_features = q_matmul.out_features + k_matmul.out_features + v_matmul.out_features + self.bits = q_matmul.bits + self.group_size = q_matmul.group_size + matmul = PackedMatMul() + + # Input weights are quantized, save quantized MatMul numpy weights for onnx model + weight = name[1:].replace("/", ".") + ".qweight" + 
self.make_external_tensor(matmul.qweight.detach().numpy(), weight) + scales = name[1:].replace("/", ".") + ".scales" + self.make_external_tensor(matmul.scales.detach().numpy().astype(self.to_numpy_dtype[self.io_dtype]), scales) + + inputs = [root_input, weight, scales] + + if hasattr(matmul, "qzeros") and matmul.qzeros is not None: + zeros = name[1:].replace("/", ".") + ".qzeros" + self.make_external_tensor(matmul.qzeros.detach().numpy(), zeros) + inputs.append(zeros) + + if hasattr(matmul, "g_idx") and matmul.g_idx is not None: + g_idx = name[1:].replace("/", ".") + ".g_idx" + self.make_external_tensor(matmul.g_idx.detach().numpy().astype(np.int32), g_idx) + inputs.append(g_idx) + + output = "logits" if kwargs.get("logits", False) else f"{name}/output_0" + self.make_node( + "MatMulNBits", inputs=inputs, outputs=[output], name=name, domain="com.microsoft", + bits=matmul.bits, block_size=matmul.group_size, K=matmul.in_features, N=matmul.out_features, + ) + self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', matmul.out_features]) + + return name def make_add_bias(self, add, name, root_input, **kwargs): bias = name[1:].replace("/", ".") + ".bias" @@ -737,7 +833,7 @@ def make_mscale_yarn(self, mscale): return 0.1 * np.log(mscale) + 1.0 def make_mscale(self, mscale): - if self.rotemb_attrs["mscale_policy"] == "su": + if self.rotemb_attrs["mscale_policy"] in {"su", "longrope"}: return self.make_mscale_su(mscale) elif self.rotemb_attrs["mscale_policy"] == "yarn": return self.make_mscale_yarn(mscale) @@ -1144,24 +1240,24 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): # Make MatMul nodes if self.attention_attrs["use_packed_matmul"]: # Combine 3 MatMuls into 1 packed MatMul - qkv_matmul_name = f"/model/layers.{layer_id}/attn/qkv_proj/MatMul" - self.make_packed_matmul(attention.q_proj.weight.detach().numpy(), attention.k_proj.weight.detach().numpy(), attention.v_proj.weight.detach().numpy(), qkv_matmul_name, root_input) + qkv_matmul_basename = f"/model/layers.{layer_id}/attn/qkv_proj/MatMul" + qkv_matmul_name = self.make_packed_matmul(attention.q_proj, attention.k_proj, attention.v_proj, qkv_matmul_basename, root_input) q_input_to_attention = f"{qkv_matmul_name}/output_0" else: - q_matmul_name = f"/model/layers.{layer_id}/attn/q_proj/MatMul" - self.make_matmul(attention.q_proj.weight.detach().numpy(), q_matmul_name, root_input) + q_matmul_basename = f"/model/layers.{layer_id}/attn/q_proj/MatMul" + q_matmul_name = self.make_matmul(attention.q_proj, q_matmul_basename, root_input) q_input_to_attention = f"{q_matmul_name}/output_0" - k_matmul_name = f"/model/layers.{layer_id}/attn/k_proj/MatMul" - self.make_matmul(attention.k_proj.weight.detach().numpy(), k_matmul_name, root_input) + k_matmul_basename = f"/model/layers.{layer_id}/attn/k_proj/MatMul" + k_matmul_name = self.make_matmul(attention.k_proj, k_matmul_basename, root_input) k_input_to_attention = f"{k_matmul_name}/output_0" - v_matmul_name = f"/model/layers.{layer_id}/attn/v_proj/MatMul" - self.make_matmul(attention.v_proj.weight.detach().numpy(), v_matmul_name, root_input) + v_matmul_basename = f"/model/layers.{layer_id}/attn/v_proj/MatMul" + v_matmul_name = self.make_matmul(attention.v_proj, v_matmul_basename, root_input) v_input_to_attention = f"{v_matmul_name}/output_0" # Make Add nodes (if bias exists) - q_bias_exists = attention.q_proj.bias is not None - k_bias_exists = attention.k_proj.bias is not None - v_bias_exists = attention.v_proj.bias is not None + q_bias_exists = 
attention.q_proj.bias is not None and torch.count_nonzero(attention.q_proj.bias) > 0 + k_bias_exists = attention.k_proj.bias is not None and torch.count_nonzero(attention.k_proj.bias) > 0 + v_bias_exists = attention.v_proj.bias is not None and torch.count_nonzero(attention.v_proj.bias) > 0 all_bias_exists = q_bias_exists and k_bias_exists and v_bias_exists if all_bias_exists and self.attention_attrs["use_packed_matmul"]: @@ -1215,9 +1311,9 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): # Make MatMul node (output projection weight node) o_proj = 'o_proj' if hasattr(attention, 'o_proj') else 'dense' - o_matmul_name = f"/model/layers.{layer_id}/attn/o_proj/MatMul" - o_weight = eval(f"attention.{o_proj}.weight.detach().numpy()") - self.make_matmul(o_weight, o_matmul_name, f"{attn_name}/output_0") + o_matmul_basename = f"/model/layers.{layer_id}/attn/o_proj/MatMul" + o_weight = eval(f"attention.{o_proj}") + o_matmul_name = self.make_matmul(o_weight, o_matmul_basename, f"{attn_name}/output_0") # Make Add node (output projection bias node if bias exists) o_bias_exists = eval(f"attention.{o_proj}.bias") is not None @@ -1263,6 +1359,16 @@ def make_mlp(self, layer_id, mlp, root_input): else: raise NotImplementedError(f"The MLP layer type is not set.") + def make_mlp_unpacked(self, layer_id, mlp, root_input): + mlp.gate_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size) + mlp.gate_proj.weight = torch.nn.Parameter(mlp.gate_up_proj.weight[ : self.intermediate_size, :]) + + mlp.up_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size) + mlp.up_proj.weight = torch.nn.Parameter(mlp.gate_up_proj.weight[self.intermediate_size :, :]) + + # Delete original packed weights + del mlp.gate_up_proj + def make_mlp_proj(self, layer_id, mlp, root_input): # Make nodes for the MLP subgraph # @@ -1277,10 +1383,10 @@ def make_mlp_proj(self, layer_id, mlp, root_input): # DownProjMatMul # Make MatMul nodes - gate_name = f"/model/layers.{layer_id}/mlp/gate_proj/MatMul" - self.make_matmul(mlp.gate_proj.weight.detach().numpy(), gate_name, root_input) - up_name = f"/model/layers.{layer_id}/mlp/up_proj/MatMul" - self.make_matmul(mlp.up_proj.weight.detach().numpy(), up_name, root_input) + gate_basename = f"/model/layers.{layer_id}/mlp/gate_proj/MatMul" + gate_name = self.make_matmul(mlp.gate_proj, gate_basename, root_input) + up_basename = f"/model/layers.{layer_id}/mlp/up_proj/MatMul" + up_name = self.make_matmul(mlp.up_proj, up_basename, root_input) # Make activation node(s) act_fn_name = self.make_activation(layer_id, root_input=f"{gate_name}/output_0") @@ -1291,8 +1397,8 @@ def make_mlp_proj(self, layer_id, mlp, root_input): self.make_mul(mul_name, mul_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) # Make output MatMul node - down_name = f"/model/layers.{layer_id}/mlp/down_proj/MatMul" - self.make_matmul(mlp.down_proj.weight.detach().numpy(), down_name, f"{mul_name}/output_0") + down_basename = f"/model/layers.{layer_id}/mlp/down_proj/MatMul" + down_name = self.make_matmul(mlp.down_proj, down_basename, f"{mul_name}/output_0") # Assign output 0 of previous MatMul as skip input to next SkipLayerNorm self.layernorm_attrs["skip_input"] = f"{down_name}/output_0" @@ -1313,8 +1419,8 @@ def make_mlp_fc(self, layer_id, mlp, root_input): # FC2_Add # Make first layer of fully connected nodes (FC1) - fc1_matmul_name = f"/model/layers.{layer_id}/mlp/fc1/MatMul" - 
self.make_matmul(mlp.fc1.weight.detach().numpy(), fc1_matmul_name, root_input) + fc1_matmul_basename = f"/model/layers.{layer_id}/mlp/fc1/MatMul" + fc1_matmul_name = self.make_matmul(mlp.fc1, fc1_matmul_basename, root_input) fc1_add_name = f"/model/layers.{layer_id}/mlp/fc1/Add" self.make_add_bias(mlp.fc1.bias.detach().numpy(), fc1_add_name, root_input=f"{fc1_matmul_name}/output_0") @@ -1322,8 +1428,8 @@ def make_mlp_fc(self, layer_id, mlp, root_input): act_fn_name = self.make_activation(layer_id, root_input=f"{fc1_add_name}/output_0") # Make second layer of fully connected nodes (FC2) - fc2_matmul_name = f"/model/layers.{layer_id}/mlp/fc2/MatMul" - self.make_matmul(mlp.fc2.weight.detach().numpy(), fc2_matmul_name, root_input=f"{act_fn_name}/output_0") + fc2_matmul_basename = f"/model/layers.{layer_id}/mlp/fc2/MatMul" + fc2_matmul_name = self.make_matmul(mlp.fc2, fc2_matmul_basename, root_input=f"{act_fn_name}/output_0") fc2_add_name = f"/model/layers.{layer_id}/mlp/fc2/Add" self.make_add_bias(mlp.fc2.bias.detach().numpy(), fc2_add_name, root_input=f"{fc2_matmul_name}/output_0") @@ -1380,9 +1486,9 @@ def make_lm_head(self, lm_head): scale_exists = self.lm_head_attrs["scale"] != 1 mask_exists = self.lm_head_attrs["mask"] is not None - matmul_name = "/lm_head/MatMul" + matmul_basename = "/lm_head/MatMul" root_input = self.layernorm_attrs["output_0"] - self.make_matmul(lm_head.weight.detach().numpy(), matmul_name, root_input, logits=not bias_exists and not scale_exists) + matmul_name = self.make_matmul(lm_head, matmul_basename, root_input, logits=not bias_exists and not scale_exists) if bias_exists: add_name = "/lm_head/Add" @@ -1432,6 +1538,12 @@ def make_model(self, input_path): from gguf_model import GGUFModel model = GGUFModel.from_pretrained(self.model_type, input_path, self.head_size, self.hidden_size, self.intermediate_size, self.num_attn_heads, self.num_kv_heads, self.vocab_size) self.layernorm_attrs["add_offset"] = 0 # add offset already done for GGUF models + elif self.quant_type is not None: + # Load quantized PyTorch model + from quantized_model import QuantModel + q_size = self.num_attn_heads * self.head_size + kv_size = self.num_kv_heads * self.head_size + model = QuantModel.from_pretrained(self.quant_type, input_path, self.quant_attrs["bits"], self.quant_attrs["group_size"], self.quant_attrs["use_g_idx"], q_size, kv_size, self.intermediate_size) else: # Load PyTorch model extra_kwargs = {} if os.path.exists(self.model_name_or_path) else {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {"cache_dir": self.cache_dir} @@ -1451,7 +1563,7 @@ def make_model(self, input_path): self.layernorm_attrs["root_input"] = "inputs_embeds" self.layernorm_attrs["skip_input"] = "inputs_embeds" - elif module.__class__.__name__.endswith("DecoderLayer"): + elif module.__class__.__name__.endswith("DecoderLayer") and self.layer_id < self.num_layers: # Each decoder layer of model print(f"Reading decoder layer {self.layer_id}") self.make_layer(self.layer_id, module) @@ -1978,17 +2090,13 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) def make_attention(self, layer_id, attention, root_input, **kwargs): - super().make_attention_unpacked(layer_id, attention, root_input, **kwargs) + if self.quant_type is None: + super().make_attention_unpacked(layer_id, attention, root_input, **kwargs) super().make_attention(layer_id, attention, root_input, **kwargs) def 
make_mlp_proj(self, layer_id, mlp, root_input): - mlp.gate_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size) - mlp.gate_proj.weight = torch.nn.Parameter(mlp.gate_up_proj.weight[ : self.intermediate_size, :]) - - mlp.up_proj = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size) - mlp.up_proj.weight = torch.nn.Parameter(mlp.gate_up_proj.weight[self.intermediate_size :, :]) - - del mlp.gate_up_proj + if self.quant_type is None: + super().make_mlp_unpacked(layer_id, mlp, root_input) super().make_mlp_proj(layer_id, mlp, root_input) diff --git a/src/python/py/models/quantized_model.py b/src/python/py/models/quantized_model.py new file mode 100644 index 000000000..48c4ec7bd --- /dev/null +++ b/src/python/py/models/quantized_model.py @@ -0,0 +1,658 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +""" +A set of Python classes to unpack the quantized weights and repack them in ONNX Runtime's +standard format. + +The goal is for `QuantModel` to repack the quantized weights into a standard format +so that the original Hugging Face --> ONNX code can store the quantized weights as +ONNX Runtime's format no matter where the quantized weights actually come from. +""" + +from safetensors.torch import load_file +import torch + +import os +import re + + +class QuantizedTensorModule: + def __init__(self, bits, group_size): + self.qweight = None + self.scales = None + self.qzeros = None + self.g_idx = None + self.bias = None + + self.in_features = 0 + self.out_features = 0 + self.bits = bits + self.group_size = group_size + + def __str__(self): + qweight = f"qweight = {self.qweight.shape}, {self.qweight}\n" + scales = f"scales = {self.scales.shape}, {self.scales}\n" + qzeros = "" if self.qzeros is None else f"qzeros = {self.qzeros.shape}, {self.qzeros}\n" + g_idx = "" if self.g_idx is None else f"g_idx = {self.g_idx.shape}, {self.g_idx}\n" + + in_feats = f"in_features = {self.in_features}, " + out_feats = f"out_features = {self.out_features}, " + bits = f"bits = {self.bits}, " + group_size = f"group_size = {self.group_size}, " + + return qweight + qzeros + scales + g_idx + in_feats + out_feats + bits + group_size + + +class TensorModule: + def __init__(self): + self.weight = None + self.bias = None + + +class QuantizedAttention: + def __init__(self, bits, group_size): + self.q_proj = QuantizedTensorModule(bits, group_size) + self.k_proj = QuantizedTensorModule(bits, group_size) + self.v_proj = QuantizedTensorModule(bits, group_size) + self.o_proj = QuantizedTensorModule(bits, group_size) + self.rotary_emb = TensorModule() + + +class QuantizedMLP: + def __init__(self, bits, group_size): + self.gate_proj = QuantizedTensorModule(bits, group_size) + self.up_proj = QuantizedTensorModule(bits, group_size) + self.down_proj = QuantizedTensorModule(bits, group_size) + self.fc1 = QuantizedTensorModule(bits, group_size) + self.fc2 = QuantizedTensorModule(bits, group_size) + + +class QuantizedDecoderLayer: + def __init__(self, layer_id, bits, group_size): + self.layer_id = layer_id + self.input_layernorm = TensorModule() + self.self_attn = QuantizedAttention(bits, group_size) + self.post_attention_layernorm = TensorModule() + self.mlp = QuantizedMLP(bits, group_size) + + def 
is_empty(self): + return self.input_layernorm.weight is None + + +class QuantizedModel: + def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size): + self.quant_type = quant_type + self.embedding = TensorModule() + self.final_norm = TensorModule() + self.lm_head = TensorModule() + self.layers = [] + + layer_id = 0 + for weight_file in os.listdir(input_path): + if weight_file.endswith(".safetensors"): + module = QuantizedDecoderLayer(layer_id, bits, group_size) + weights = load_file(os.path.join(input_path, weight_file)) + + # Map weights to modules + for name, tensor in weights.items(): + if tensor.dtype == torch.bfloat16: + # Cast bfloat16 to float32 since NumPy does not support bfloat16 + tensor = tensor.to(torch.float32) + + if name == "model.embed_tokens.weight": + self.embedding.weight = tensor + elif name == "model.norm.weight": + self.final_norm.weight = tensor + elif name == "model.norm.bias": + self.final_norm.bias = tensor + elif name == "lm_head.weight": + self.lm_head.weight = tensor + elif name == "lm_head.bias": + self.lm_head.bias = tensor + else: + curr_layer_id = int(name.split(".")[2]) + if curr_layer_id != layer_id: + # Add layer to list of modules + self.layers.append(module) + layer_id = curr_layer_id + module = QuantizedDecoderLayer(layer_id, bits, group_size) + + # Map weights and biases of norm, attention, and feed-forward network + # Graph order is input_layernorm --> q_proj/k_proj/v_proj --> o_proj --> post_attention_layernorm --> gate_proj/up_proj --> down_proj + if bool(re.match(r"^model.layers\.\d+\.input_layernorm\.weight$", name)): + # model.layers.layer_id.input_layernorm.weight + module.input_layernorm.weight = tensor + elif bool(re.match(r"^model.layers\.\d+\.input_layernorm\.bias$", name)): + # model.layers.layer_id.input_layernorm.bias + module.input_layernorm.bias = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.rotary_emb\.inv_freq$", name)): + # model.layers.layer_id.self_attn.rotary_emb.inv_freq + # Skip rotary embedding weights since they can be re-calculated when looping through the model + continue + elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.qweight$", name)): + # model.layers.layer_id.self_attn.q_proj.qweight + module.self_attn.q_proj.qweight = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.scales$", name)): + # model.layers.layer_id.self_attn.q_proj.scales + module.self_attn.q_proj.scales = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.qzeros$", name)): + # model.layers.layer_id.self_attn.q_proj.qzeros + module.self_attn.q_proj.qzeros = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.g_idx$", name)): + # model.layers.layer_id.self_attn.q_proj.g_idx + module.self_attn.q_proj.g_idx = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.bias$", name)): + # model.layers.layer_id.self_attn.q_proj.bias + module.self_attn.q_proj.bias = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.qweight$", name)): + # model.layers.layer_id.self_attn.k_proj.qweight + module.self_attn.k_proj.qweight = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.scales$", name)): + # model.layers.layer_id.self_attn.k_proj.scales + module.self_attn.k_proj.scales = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.qzeros$", name)): + # model.layers.layer_id.self_attn.k_proj.qzeros + module.self_attn.k_proj.qzeros = tensor + elif 
bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.g_idx$", name)): + # model.layers.layer_id.self_attn.k_proj.g_idx + module.self_attn.k_proj.g_idx = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.bias$", name)): + # model.layers.layer_id.self_attn.k_proj.bias + module.self_attn.k_proj.bias = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.qweight$", name)): + # model.layers.layer_id.self_attn.v_proj.qweight + module.self_attn.v_proj.qweight = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.scales$", name)): + # model.layers.layer_id.self_attn.v_proj.scales + module.self_attn.v_proj.scales = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.qzeros$", name)): + # model.layers.layer_id.self_attn.v_proj.qzeros + module.self_attn.v_proj.qzeros = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.g_idx$", name)): + # model.layers.layer_id.self_attn.v_proj.g_idx + module.self_attn.v_proj.g_idx = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.bias$", name)): + # model.layers.layer_id.self_attn.v_proj.bias + module.self_attn.v_proj.bias = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.o_proj\.qweight$", name)): + # model.layers.layer_id.self_attn.o_proj.qweight + module.self_attn.o_proj.qweight = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.o_proj\.scales$", name)): + # model.layers.layer_id.self_attn.o_proj.scales + module.self_attn.o_proj.scales = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.o_proj\.qzeros$", name)): + # model.layers.layer_id.self_attn.o_proj.qzeros + module.self_attn.o_proj.qzeros = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.o_proj\.g_idx$", name)): + # model.layers.layer_id.self_attn.o_proj.g_idx + module.self_attn.o_proj.g_idx = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.o_proj\.bias$", name)): + # model.layers.layer_id.self_attn.o_proj.bias + module.self_attn.o_proj.bias = tensor + elif bool(re.match(r"^model.layers\.\d+\.post_attention_layernorm\.weight$", name)): + # model.layers.layer_id.post_attention_layernorm.weight + module.post_attention_layernorm.weight = tensor + elif bool(re.match(r"^model.layers\.\d+\.post_attention_layernorm\.bias$", name)): + # model.layers.layer_id.post_attention_layernorm.bias + module.post_attention_layernorm.bias = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.qweight$", name)): + # model.layers.layer_id.mlp.gate_proj.qweight + module.mlp.gate_proj.qweight = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.scales$", name)): + # model.layers.layer_id.mlp.gate_proj.scales + module.mlp.gate_proj.scales = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.qzeros$", name)): + # model.layers.layer_id.mlp.gate_proj.qzeros + module.mlp.gate_proj.qzeros = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.g_idx$", name)): + # model.layers.layer_id.mlp.gate_proj.g_idx + module.mlp.gate_proj.g_idx = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.bias$", name)): + # model.layers.layer_id.mlp.gate_proj.bias + module.mlp.gate_proj.bias = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.qweight$", name)): + # model.layers.layer_id.mlp.up_proj.qweight + module.mlp.up_proj.qweight = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.scales$", name)): + # model.layers.layer_id.mlp.up_proj.scales + module.mlp.up_proj.scales = tensor + elif 
bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.qzeros$", name)): + # model.layers.layer_id.mlp.up_proj.qzeros + module.mlp.up_proj.qzeros = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.g_idx$", name)): + # model.layers.layer_id.mlp.up_proj.g_idx + module.mlp.up_proj.g_idx = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.bias$", name)): + # model.layers.layer_id.mlp.up_proj.bias + module.mlp.up_proj.bias = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.down_proj\.qweight$", name)): + # model.layers.layer_id.mlp.down_proj.qweight + module.mlp.down_proj.qweight = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.down_proj\.scales$", name)): + # model.layers.layer_id.mlp.down_proj.scales + module.mlp.down_proj.scales = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.down_proj\.qzeros$", name)): + # model.layers.layer_id.mlp.down_proj.qzeros + module.mlp.down_proj.qzeros = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.down_proj\.g_idx$", name)): + # model.layers.layer_id.mlp.down_proj.g_idx + module.mlp.down_proj.g_idx = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.down_proj\.bias$", name)): + # model.layers.layer_id.mlp.down_proj.bias + module.mlp.down_proj.bias = tensor + # Match against fused layers + elif bool(re.match(r"^model.layers\.\d+\.self_attn.qkv_proj\.qweight$", name)): + # model.layers.layer_id.self_attn.qkv_proj.qweight + q_dim = q_size // (32 // bits) if quant_type == "awq" else q_size + kv_dim = kv_size // (32 // bits) if quant_type == "awq" else kv_size + module.self_attn.q_proj.qweight = tensor[:, : q_dim] + module.self_attn.k_proj.qweight = tensor[:, q_dim : q_dim + kv_dim] + module.self_attn.v_proj.qweight = tensor[:, q_dim + kv_dim :] + elif bool(re.match(r"^model.layers\.\d+\.self_attn.qkv_proj\.scales$", name)): + # model.layers.layer_id.self_attn.qkv_proj.scales + module.self_attn.q_proj.scales = tensor[:, : q_size] + module.self_attn.k_proj.scales = tensor[:, q_size : q_size + kv_size] + module.self_attn.v_proj.scales = tensor[:, q_size + kv_size :] + elif bool(re.match(r"^model.layers\.\d+\.self_attn.qkv_proj\.qzeros$", name)): + # model.layers.layer_id.self_attn.qkv_proj.qzeros + q_dim = q_size // (32 // bits) if quant_type in {"awq", "gptq"} else q_size + kv_dim = kv_size // (32 // bits) if quant_type in {"awq", "gptq"} else kv_size + module.self_attn.q_proj.qzeros = tensor[:, : q_dim] + module.self_attn.k_proj.qzeros = tensor[:, q_dim : q_dim + kv_dim] + module.self_attn.v_proj.qzeros = tensor[:, q_dim + kv_dim :] + elif bool(re.match(r"^model.layers\.\d+\.self_attn.qkv_proj\.g_idx$", name)): + # model.layers.layer_id.self_attn.qkv_proj.g_ix + module.self_attn.q_proj.g_idx = tensor + module.self_attn.k_proj.g_idx = tensor + module.self_attn.v_proj.g_idx = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_up_proj\.qweight$", name)): + # model.layers.layer_id.mlp.gate_up_proj.qweight + intermediate_dim = intermediate_size // (32 // bits) if quant_type == "awq" else intermediate_size + module.mlp.gate_proj.qweight = tensor[:, : intermediate_dim] + module.mlp.up_proj.qweight = tensor[:, intermediate_dim :] + elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_up_proj\.scales$", name)): + # model.layers.layer_id.mlp.gate_up_proj.scales + module.mlp.gate_proj.scales = tensor[:, : intermediate_size] + module.mlp.up_proj.scales = tensor[:, intermediate_size :] + elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_up_proj\.qzeros$", name)): + # model.layers.layer_id.mlp.gate_up_proj.qzeros + 
intermediate_dim = intermediate_size // (32 // bits) if quant_type in {"awq", "gptq"} else intermediate_size + module.mlp.gate_proj.qzeros = tensor[:, : intermediate_dim] + module.mlp.up_proj.qzeros = tensor[:, intermediate_dim :] + elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_up_proj\.g_idx$", name)): + # model.layers.layer_id.mlp.gate_up_proj.g_idx + module.mlp.gate_proj.g_idx = tensor + module.mlp.up_proj.g_idx = tensor + else: + raise NotImplementedError(f"{name} in your quantized model is not recognized.") + + if not module.is_empty(): + # Append final layer to list of layers + self.layers.append(module) + + # Set LM head weights + biases if not already set + if self.lm_head.weight is None: + # Embedding and LM head share same weights + biases (lm_head.weight == embedding.weight and lm_head.bias == embedding.bias) + self.lm_head.weight = self.embedding.weight + if self.lm_head.bias is not None: + self.lm_head.bias = self.embedding.bias + + # Sort list of layers by layer id + self.layers.sort(key=lambda m: m.layer_id) + + # Set properties of each layer based on quantization type + self.set_properties() + + def set_properties(self): + """ + Set in_features, out_features, and g_idx based on quantization type + """ + for module in self.layers: + if self.quant_type == "awq": + # Set in_features and out_features + module.self_attn.q_proj.out_features = module.self_attn.q_proj.scales.shape[1] + module.self_attn.q_proj.in_features = module.self_attn.q_proj.qweight.shape[0] + module.self_attn.k_proj.out_features = module.self_attn.k_proj.scales.shape[1] + module.self_attn.k_proj.in_features = module.self_attn.k_proj.qweight.shape[0] + module.self_attn.v_proj.out_features = module.self_attn.v_proj.scales.shape[1] + module.self_attn.v_proj.in_features = module.self_attn.v_proj.qweight.shape[0] + module.self_attn.o_proj.out_features = module.self_attn.o_proj.scales.shape[1] + module.self_attn.o_proj.in_features = module.self_attn.o_proj.qweight.shape[0] + module.mlp.gate_proj.out_features = module.mlp.gate_proj.scales.shape[1] + module.mlp.gate_proj.in_features = module.mlp.gate_proj.qweight.shape[0] + module.mlp.up_proj.out_features = module.mlp.up_proj.scales.shape[1] + module.mlp.up_proj.in_features = module.mlp.up_proj.qweight.shape[0] + module.mlp.down_proj.out_features = module.mlp.down_proj.scales.shape[1] + module.mlp.down_proj.in_features = module.mlp.down_proj.qweight.shape[0] + + # Set g_idx if not already set + module.self_attn.q_proj.g_idx = module.self_attn.q_proj.g_idx if module.self_attn.q_proj.g_idx is not None else torch.tensor([i // module.self_attn.q_proj.group_size for i in range(module.self_attn.q_proj.in_features)], dtype=torch.int32) + module.self_attn.k_proj.g_idx = module.self_attn.k_proj.g_idx if module.self_attn.k_proj.g_idx is not None else torch.tensor([i // module.self_attn.k_proj.group_size for i in range(module.self_attn.k_proj.in_features)], dtype=torch.int32) + module.self_attn.v_proj.g_idx = module.self_attn.v_proj.g_idx if module.self_attn.v_proj.g_idx is not None else torch.tensor([i // module.self_attn.v_proj.group_size for i in range(module.self_attn.v_proj.in_features)], dtype=torch.int32) + module.self_attn.o_proj.g_idx = module.self_attn.o_proj.g_idx if module.self_attn.o_proj.g_idx is not None else torch.tensor([i // module.self_attn.o_proj.group_size for i in range(module.self_attn.o_proj.in_features)], dtype=torch.int32) + module.mlp.gate_proj.g_idx = module.mlp.gate_proj.g_idx if module.mlp.gate_proj.g_idx is not None else torch.tensor([i // 
module.mlp.gate_proj.group_size for i in range(module.mlp.gate_proj.in_features)], dtype=torch.int32) + module.mlp.up_proj.g_idx = module.mlp.up_proj.g_idx if module.mlp.up_proj.g_idx is not None else torch.tensor([i // module.mlp.up_proj.group_size for i in range(module.mlp.up_proj.in_features)], dtype=torch.int32) + module.mlp.down_proj.g_idx = module.mlp.down_proj.g_idx if module.mlp.down_proj.g_idx is not None else torch.tensor([i // module.mlp.down_proj.group_size for i in range(module.mlp.down_proj.in_features)], dtype=torch.int32) + + elif self.quant_type == "gptq": + # Set in_features and out_features + module.self_attn.q_proj.out_features = module.self_attn.q_proj.qweight.shape[1] + module.self_attn.q_proj.in_features = module.self_attn.q_proj.g_idx.shape[0] + module.self_attn.k_proj.out_features = module.self_attn.k_proj.qweight.shape[1] + module.self_attn.k_proj.in_features = module.self_attn.k_proj.g_idx.shape[0] + module.self_attn.v_proj.out_features = module.self_attn.v_proj.qweight.shape[1] + module.self_attn.v_proj.in_features = module.self_attn.v_proj.g_idx.shape[0] + module.self_attn.o_proj.out_features = module.self_attn.o_proj.qweight.shape[1] + module.self_attn.o_proj.in_features = module.self_attn.o_proj.g_idx.shape[0] + module.mlp.gate_proj.out_features = module.mlp.gate_proj.qweight.shape[1] + module.mlp.gate_proj.in_features = module.mlp.gate_proj.g_idx.shape[0] + module.mlp.up_proj.out_features = module.mlp.up_proj.qweight.shape[1] + module.mlp.up_proj.in_features = module.mlp.up_proj.g_idx.shape[0] + module.mlp.down_proj.out_features = module.mlp.down_proj.qweight.shape[1] + module.mlp.down_proj.in_features = module.mlp.down_proj.g_idx.shape[0] + + else: + raise NotImplementedError(f"The {self.quant_type} quantization method is not recognized.") + + def modules(self): + """ + Return list of modules in quantized model in order of appearance in the model + """ + return [self.embedding] + self.layers + [self.final_norm, self.lm_head] + + def unpack(self, module): + """ + Unpack `qzeros` and `qweight` to standard format + """ + self.unpack_qzeros(module) + self.unpack_qweight(module) + self.dequant_weight(module) + + def repack(self, module): + """ + Repack `scales`, `qzeros` and `qweight` to ORT format + """ + intweight = self.quant_weight(module) + self.pack_ort_format(module, intweight) + + def unpack_qzeros(self, module): + """ + Unpack `qzeros` to standard format + """ + if module.qzeros is None: + return + expected_shape = (module.in_features // module.group_size, module.out_features) + transpose = module.qzeros.shape[0] != expected_shape[0] + module.qzeros = self.unpack_on_row(module.qzeros, module.bits, transpose) + + def unpack_qweight(self, module): + """ + Unpack `qweight` to standard format + """ + expected_shape = (module.in_features, module.qweight.shape[1]) + transpose = module.qweight.shape[0] != expected_shape[0] + module.qweight = self.unpack_on_row(module.qweight, module.bits, transpose) + + def pack_qzeros(self, module): + """ + Pack `qzeros` to quantized format + """ + expected_shape = (module.in_features // module.group_size, module.out_features) + transpose = module.qzeros.shape[0] != expected_shape[0] + module.qzeros = self.pack_on_row(module.qzeros, module.bits, transpose) + + def unpack_on_row_for_2_4_8_bits(self, tensor, bits, transpose): + """ + Perform general-purpose unpacking on 2-bit, 4-bit, or 8-bit tensor + """ + pack_tensor = tensor.T if transpose else tensor + wf = torch.arange(0, 32, bits, 
device=pack_tensor.device).unsqueeze(0).unsqueeze(0) + out = torch.bitwise_right_shift(torch.unsqueeze(pack_tensor, 2), wf) + out = out.reshape(pack_tensor.shape[0], -1) + out = torch.bitwise_and(out, (2 ** bits) - 1) + return out.T if transpose else out + + def unpack_on_row(self, tensor, bits, transpose): + """ + Unpack tensor by row + """ + if bits in {2, 4, 8}: + return self.unpack_on_row_for_2_4_8_bits(tensor, bits, transpose) + else: + raise NotImplementedError(f"Unpacking for {bits}-bit quantization is not currently supported.") + + def pack_on_row_for_2_4_8_bits(self, tensor, bits, transpose): + """ + Perform general-purpose packing on 2-bit, 4-bit, or 8-bit tensor + """ + orig_tensor = tensor.T if transpose else tensor + wf = torch.arange(0, bits).view(1, 1, -1) + out = torch.bitwise_right_shift(orig_tensor.unsqueeze(-1), wf) + out = torch.bitwise_and(out, 1) + out = out.reshape(orig_tensor.shape[0], -1, 32) + wf1 = torch.arange(0, 32, 1).view(1, 1, -1) + out = torch.bitwise_left_shift(out, wf1) + out = out.sum(dim=-1).int() + return out.T if transpose else out + + def pack_on_row(self, tensor, bits, transpose): + """ + Pack tensor by row + """ + if bits in {2, 4, 8}: + return self.pack_on_row_for_2_4_8_bits(tensor, bits, transpose) + else: + raise NotImplementedError(f"Packing for {bits}-bit quantization is not currently supported.") + + def dequant_weight(self, module): + """ + De-quantize `qweight` to higher precision (float16) + """ + # Note: `qweight` and `qzeros` have already been unpacked and stored in those variables respectively + intweight = module.qweight + zeros = module.qzeros + scales = module.scales + g_idx = module.g_idx + + # De-quantize weight to higher precision + scale_zeros = zeros * scales + scale_mat = scales[g_idx] + scale_zeros_mat = scale_zeros[g_idx] + qdq_weight_T = intweight * scale_mat - scale_zeros_mat.half() + + # Store unpacked result in `qweight` + module.qweight = qdq_weight_T.T + + def quant_weight(self, module): + """ + Calculate integer weight to quantize `qweight` with + """ + weight = module.qweight.T + zeros = module.qzeros + scales = module.scales + g_idx = module.g_idx + + scale_zeros = zeros * scales + scale_mat = scales[g_idx] + scale_zeros_mat = scale_zeros[g_idx] + intweight_T = torch.round((weight + scale_zeros_mat) / scale_mat).to(torch.int) + + return intweight_T + + def pack_ort_format(self, module, intweight): + """ + Pack `scales`, `qzeros`, and `qweight` to ORT format + """ + if module.bits != 4: + raise NotImplementedError(f"{module.bits}-bit quantization in ORT is not currently supported by this tool.") + + intzeros_pt = module.qzeros.T if module.qzeros.dtype == module.scales.dtype else module.qzeros.T.byte() + intweight_pt = intweight.byte() + block_size = module.group_size + + rows, cols = intweight_pt.shape + blob_size = block_size // 2 + k_blocks = (rows + block_size - 1) // block_size + padded_rows = k_blocks * block_size + pad_len = padded_rows - rows + if pad_len > 0: + intweight_pt = torch.nn.functional.pad(intweight_pt, (0, 0, 0, pad_len), "constant", 0) + intzeros_pt = torch.nn.functional.pad(intzeros_pt, (0, intzeros_pt.shape[-1] & 1, 0, 0), "constant", 0) + + if module.qzeros.dtype != module.scales.dtype: + intzeros_pt = (intzeros_pt[:, 0::2]) | (intzeros_pt[:, 1::2] << 4) + intzeros_pt = intzeros_pt.reshape(-1) + + intweight_pt_T = intweight.T + intweight_pt_T = (intweight_pt_T[:, 0::2]) | (intweight_pt_T[:, 1::2] << 4) + intweight_pt_T = intweight_pt_T.reshape(cols, k_blocks, blob_size) + + scales_pt =
module.scales.T.reshape(-1) + + module.scales = scales_pt.contiguous() + module.qweight = intweight_pt_T.contiguous().byte() + if module.qzeros.dtype != module.scales.dtype: + module.qzeros = intzeros_pt.contiguous().byte() + else: + module.qzeros = intzeros_pt.contiguous() + + +class AWQModel(QuantizedModel): + def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size): + super().__init__(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size) + + # Unpack and repack all `QuantizedTensorModule` classes in model + for i, layer in enumerate(self.layers): + print(f"Unpacking and repacking layer {i}") + + # Unpack and repack all `QuantizedTensorModule` classes in attention + for name, q_tensors in layer.self_attn.__dict__.items(): + if isinstance(q_tensors, QuantizedTensorModule) and q_tensors.qweight is not None: + self.unpack(q_tensors) + self.repack(q_tensors) + + # Set `g_idx` to None since it's not used in `MatMulNBits` + q_tensors.g_idx = None + + # Unpack and repack all `Quantized TensorModule` classes in MLP + for name, q_tensors in layer.mlp.__dict__.items(): + if isinstance(q_tensors, QuantizedTensorModule) and q_tensors.qweight is not None: + self.unpack(q_tensors) + self.repack(q_tensors) + + # Set `g_idx` to None since it's not used in `MatMulNBits` + q_tensors.g_idx = None + + def unpack_qweight(self, module): + """ + Unpack `qweight` to standard format + """ + expected_shape = (module.qweight.shape[0], module.out_features) + transpose = module.qweight.shape != expected_shape + module.qweight = self.unpack_on_row(module.qweight.T, module.bits, transpose) + module.qweight = self.reverse_reorder_tensor(module.qweight.T, module.bits) + + def unpack_qzeros(self, module): + """ + Unpack `qzeros` to standard format + """ + super().unpack_qzeros(module) + module.qzeros = self.reverse_reorder_tensor(module.qzeros, module.bits) + + def reverse_reorder_tensor(self, tensor, bits): + """ + Re-arrange tensor data in a new order + """ + compress_ratio = 32 // bits + assert tensor.shape[-1] % compress_ratio == 0 + + if bits == 4: + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + else: + raise NotImplementedError(f"Unpacking for {bits}-bit quantization is not currently supported.") + + order_tensor = torch.tensor(order_map, dtype=torch.int32).reshape(1, -1) + order_tensor = order_tensor.repeat(tensor.shape[1] // compress_ratio, 1) + order_tensor = order_tensor + torch.arange(0, tensor.shape[1], compress_ratio, dtype=torch.int32).reshape(-1, 1) + order_tensor = order_tensor.reshape(-1) + + reverse_order_tensor = torch.arange(order_tensor.shape[0])[order_tensor] + reverse_order_tensor = reverse_order_tensor[order_tensor] + int_tensor = tensor[:, reverse_order_tensor] + return int_tensor + + +class GPTQModel(QuantizedModel): + def __init__(self, quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size): + super().__init__(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size) + + # Unpack and repack all `QuantizedTensorModule` classes in model + for i, layer in enumerate(self.layers): + print(f"Unpacking and repacking layer {i}") + # Unpack and repack all `QuantizedTensorModule` classes in attention + + for name, q_tensors in layer.self_attn.__dict__.items(): + if isinstance(q_tensors, QuantizedTensorModule) and q_tensors.qweight is not None: + self.handle_qzeros(q_tensors) + self.unpack(q_tensors) + self.repack(q_tensors) + + if not use_g_idx: + # Set `g_idx` to None since it's not used in 
`MatMulNBits` + q_tensors.g_idx = None + + # Unpack and repack all `QuantizedTensorModule` classes in MLP + for name, q_tensors in layer.mlp.__dict__.items(): + if isinstance(q_tensors, QuantizedTensorModule) and q_tensors.qweight is not None: + self.handle_qzeros(q_tensors) + self.unpack(q_tensors) + self.repack(q_tensors) + + if not use_g_idx: + # Set `g_idx` to None since it's not used in `MatMulNBits` + q_tensors.g_idx = None + + def handle_qzeros(self, module): + """ + Re-pack `qzeros` to handle extra `-1`s + """ + if module.qzeros is None or module.qzeros.numel() == 0: + return + + class TempModule: + def __init__(self, module): + self.in_features = module.in_features + self.out_features = module.out_features + self.group_size = module.group_size + self.bits = module.bits + self.qzeros = module.qzeros + + temp_module = TempModule(module) + self.unpack_qzeros(temp_module) + + temp_module.qzeros += 1 + temp_module.qzeros = torch.bitwise_and(temp_module.qzeros, (2 ** temp_module.bits) - 1) + + self.pack_qzeros(temp_module) + module.qzeros = temp_module.qzeros + + +class QuantModel: + @staticmethod + def from_pretrained(quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size): + """ + Unpack quantized weights in PyTorch models, store them in a standard format, and repack them + into ONNX Runtime's format. Also performs any pre-processing and post-processing when unpacking + the quantized weights. + """ + if quant_type == "awq": + model = AWQModel(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size) + elif quant_type == "gptq": + model = GPTQModel(quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size) + else: + raise NotImplementedError(f"The {quant_type} quantized model is not currently supported.") + + return model \ No newline at end of file
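The new `quant_type`/`quant_attrs` logic in `builder.py` reads its settings from the `quantization_config` section of the model's `config.json`. A minimal sketch of the fields it expects, assuming the AutoGPTQ/AutoAWQ naming used above (the folder path is a placeholder):

```python
import json, os

model_dir = "path_to_local_folder_on_disk"  # placeholder path to the quantized checkpoint
with open(os.path.join(model_dir, "config.json")) as f:
    config = json.load(f)

quant = config.get("quantization_config", {})
print(quant.get("quant_method"))     # "gptq" or "awq"  -> self.quant_type
print(quant.get("bits"))             # e.g. 4           -> quant_attrs["bits"]
print(quant.get("group_size"))       # e.g. 128         -> quant_attrs["group_size"]
print(quant.get("desc_act", False))  # GPTQ only        -> quant_attrs["use_g_idx"]
```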
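When a `quantization_config` is present, `make_model` hands loading off to `quantized_model.QuantModel` instead of `transformers`. A sketch of that call, run from the builder's folder, with hypothetical Llama-style sizes and a placeholder path (neither is part of the change):

```python
from quantized_model import QuantModel

num_attn_heads, num_kv_heads, head_size, intermediate_size = 32, 32, 128, 11008
q_size = num_attn_heads * head_size   # 4096
kv_size = num_kv_heads * head_size    # 4096

model = QuantModel.from_pretrained(
    "gptq",                            # quantization_config["quant_method"]
    "path_to_local_folder_on_disk",    # placeholder folder containing *.safetensors files
    bits=4, group_size=128, use_g_idx=False,
    q_size=q_size, kv_size=kv_size, intermediate_size=intermediate_size,
)
```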
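The unpacking helpers rely on a single bit-shift trick: each int32 in `qweight`/`qzeros` holds `32 // bits` values, which are recovered by right-shifting and masking. A tiny self-contained check of that trick (not the class method itself, which also handles the transposed case):

```python
import torch

bits = 4
packed = torch.tensor([[0x76543210]], dtype=torch.int32)         # nibbles 0..7, low to high
shifts = torch.arange(0, 32, bits).view(1, 1, -1)                # [0, 4, 8, ..., 28]
unpacked = (packed.unsqueeze(-1) >> shifts) & ((1 << bits) - 1)  # shift each nibble down, then mask
print(unpacked.reshape(packed.shape[0], -1))                     # tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
```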
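As written, `pack_ort_format` re-blocks the tensors into the layout `MatMulNBits` consumes: `qweight` becomes uint8 of shape `(N, ceil(K / block_size), block_size // 2)`, `scales` is flattened to length `N * ceil(K / block_size)`, and an integer `qzeros` packs two 4-bit zero-points per byte. A worked shape check under hypothetical sizes:

```python
K, N, block_size, bits = 4096, 4096, 128, 4    # hypothetical in/out features and group size

blob_size = block_size // 2                    # 64 bytes: two 4-bit weights per byte
k_blocks = (K + block_size - 1) // block_size  # 32 quantization groups along K

qweight_shape = (N, k_blocks, blob_size)       # (4096, 32, 64), uint8
scales_len = N * k_blocks                      # 131072, stored in io_dtype
qzeros_len = N * k_blocks // 2                 # 65536, uint8 (two zero-points per byte)

print(qweight_shape, scales_len, qzeros_len)
```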