Commit
Fix graph kernel attributes update after GPTQ run (#1035)
---------

Co-authored-by: Ofir Gordon <[email protected]>
ofirgo and Ofir Gordon authored Apr 8, 2024
1 parent afca3b3 commit 934ac82
Showing 3 changed files with 27 additions and 18 deletions.
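
In short: after GPTQ retraining, update_graph writes the kernel quantizer's updated quantization parameters back into the graph. The Keras and PyTorch trainers previously called set_quant_config_attr without an attr_name, so those parameters were applied to (or warned about on) the node-level weights config rather than the kernel attribute's own config; they now pass attr_name=kernel_attribute. The config setter's arguments are also renamed to config_parameter_name/config_parameter_value.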
@@ -41,24 +41,24 @@ class BaseNodeQuantizationConfig(object):
     Base class for node quantization configuration
     """

-    def set_quant_config_attr(self, parameter_name: str, parameter_value: Any,
+    def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any,
                               *args: List[Any], **kwargs: Dict[str, Any]):
         """
         Changes a BaseNodeQuantizationConfig's parameter.
         Note that arg and kwargs are only to allow clean override in the child classes.
         Args:
-            parameter_name: parameter name to change.
-            parameter_value: parameter value to change.
+            config_parameter_name: parameter name to change.
+            config_parameter_value: parameter value to change.
             args: A list of additional arguments.
             kwargs: A dictionary with additional key arguments.
         """

-        if hasattr(self, parameter_name):
-            setattr(self, parameter_name, parameter_value)
+        if hasattr(self, config_parameter_name):
+            setattr(self, config_parameter_name, config_parameter_value)
         else:
-            Logger.warning(f"Parameter {parameter_name} could not be found in the node quantization config and "
+            Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config and "
                            f"was not updated!")

     def __repr__(self) -> str:
@@ -521,34 +521,35 @@ def _extract_config_for_attributes_with_name(self, attr_name) -> Dict[str, Weigh
                          f"{list(attrs_with_name.keys())}.")
         return attrs_with_name

-    def set_quant_config_attr(self, parameter_name: str, parameter_value: Any, attr_name: str = None,
+    def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any, attr_name: str = None,
                               *args: List[Any], **kwargs: Dict[str, Any]):
         """
         This method overrides the parent class set_quant_config_attr to enable setting a specific weights
         attribute config parameter.
         Args:
             attr_name: attribute name to change.
-            parameter_name: parameter name to change.
-            parameter_value: parameter value to change.
+            config_parameter_name: parameter name to change.
+            config_parameter_value: parameter value to change.
             args: A list of additional arguments.
             kwargs: A dictionary with additional key arguments.
         """

         if attr_name is None:
-            super(NodeWeightsQuantizationConfig, self).set_quant_config_attr(parameter_name, parameter_value,
+            super(NodeWeightsQuantizationConfig, self).set_quant_config_attr(config_parameter_name,
+                                                                             config_parameter_value,
                                                                              *args, **kwargs)
         else:
             if self.has_attribute_config(attr_name):
                 attr_cfg = self.get_attr_config(attr_name)
-                if hasattr(attr_cfg, parameter_name):
-                    setattr(attr_cfg, parameter_name, parameter_value)
+                if hasattr(attr_cfg, config_parameter_name):
+                    setattr(attr_cfg, config_parameter_name, config_parameter_value)
                 else:
-                    Logger.warning(f"Parameter {parameter_name} could not be found in the node quantization config of "
+                    Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
                                    f"weights attribute {attr_name} and was not updated!")
             else:
-                Logger.error(f"Weights attribute {attr_name} could not be found to set parameter {parameter_name}.")
+                Logger.error(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")

     def __eq__(self, other: Any) -> bool:
         """
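
To make the new routing concrete, below is a minimal, self-contained sketch of the dispatch that NodeWeightsQuantizationConfig.set_quant_config_attr performs. The classes and the weights_n_bits parameter are simplified stand-ins for illustration only, not MCT's actual config classes.

# Simplified stand-ins that mirror the dispatch used by this commit; illustrative
# only, not MCT's actual implementation.
from typing import Any, Dict


class AttrQuantConfig:
    """Stand-in for a per-attribute weights quantization config (e.g. a layer's kernel)."""

    def __init__(self, weights_n_bits: int = 8):
        self.weights_n_bits = weights_n_bits


class WeightsQuantConfig:
    """Stand-in for the node-level weights config, holding one config per weights attribute."""

    def __init__(self, attr_configs: Dict[str, AttrQuantConfig]):
        self.attr_configs = attr_configs

    def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any,
                              attr_name: str = None):
        if attr_name is None:
            # No attribute given: fall back to the base behaviour and set the
            # parameter on this (node-level) config itself, if it exists here.
            if hasattr(self, config_parameter_name):
                setattr(self, config_parameter_name, config_parameter_value)
            else:
                print(f"Parameter {config_parameter_name} not found in node config; not updated.")
        elif attr_name in self.attr_configs:
            # Attribute given: route the parameter to that attribute's own config.
            attr_cfg = self.attr_configs[attr_name]
            if hasattr(attr_cfg, config_parameter_name):
                setattr(attr_cfg, config_parameter_name, config_parameter_value)
            else:
                print(f"Parameter {config_parameter_name} not found for attribute {attr_name}; not updated.")
        else:
            print(f"Weights attribute {attr_name} not found.")


cfg = WeightsQuantConfig({"kernel": AttrQuantConfig(weights_n_bits=8)})

# After the fix: GPTQ passes attr_name, so the kernel's own config is updated.
cfg.set_quant_config_attr("weights_n_bits", 4, attr_name="kernel")
assert cfg.attr_configs["kernel"].weights_n_bits == 4

# Before the fix: no attr_name, so the call falls through to the node-level
# config, which has no such parameter, and the kernel config stays untouched.
cfg.set_quant_config_attr("weights_n_bits", 2)
assert cfg.attr_configs["kernel"].weights_n_bits == 4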
8 changes: 6 additions & 2 deletions model_compression_toolkit/gptq/keras/gptq_training.py
@@ -337,12 +337,16 @@ def update_graph(self):
                 node = node[0]
                 kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
                                                                       fw_info=self.fw_info)
+                # TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
+                # To enable GPTQ for other attributes, this code needs to be modified.
                 weights, weight_quant_config, activation_quant_config = \
                     layer.weights_quantizers[kernel_attribute].update_layer_quantization_params(layer)
                 for weight_attr, weight in weights.items():
                     node.set_weights_by_keys(weight_attr, weight.numpy())
-                for config_attr, config_value in weight_quant_config.items():
-                    node.final_weights_quantization_cfg.set_quant_config_attr(config_attr, config_value)
+                for config_parameter_name, config_parameter_value in weight_quant_config.items():
+                    node.final_weights_quantization_cfg.set_quant_config_attr(config_parameter_name,
+                                                                              config_parameter_value,
+                                                                              attr_name=kernel_attribute)
                 for config_attr, config_value in activation_quant_config.items():
                     node.final_activation_quantization_cfg.set_quant_config_attr(config_attr, config_value)
                 if self.gptq_config.train_bias:
8 changes: 6 additions & 2 deletions model_compression_toolkit/gptq/pytorch/gptq_training.py
@@ -284,12 +284,16 @@ def update_graph(self) -> Graph:
                 node = node[0]
                 kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
                                                                       fw_info=self.fw_info)
+                # TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
+                # To enable GPTQ for other attributes, this code needs to be modified.
                 weights, weight_quant_config, activation_quant_config = \
                     layer.weights_quantizers[kernel_attribute].update_layer_quantization_params(layer)
                 for weight_attr, weight in weights.items():
                     node.set_weights_by_keys(weight_attr, self.fw_impl.to_numpy(weight))
-                for config_attr, config_value in weight_quant_config.items():
-                    node.final_weights_quantization_cfg.set_quant_config_attr(config_attr, config_value)
+                for config_parameter_name, config_parameter_value in weight_quant_config.items():
+                    node.final_weights_quantization_cfg.set_quant_config_attr(config_parameter_name,
+                                                                              config_parameter_value,
+                                                                              attr_name=kernel_attribute)
                 for config_attr, config_value in activation_quant_config.items():
                     node.final_activation_quantization_cfg.set_quant_config_attr(config_attr, config_value)
                 if self.gptq_config.train_bias and hasattr(layer.layer, BIAS):
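
Both trainers apply the same fix; the sketch below consolidates the shared flow. It is illustrative only: the function name and duck-typed arguments are stand-ins, not MCT's exact API. The only framework-specific piece is the tensor-to-NumPy conversion, which is weight.numpy() in the Keras trainer and self.fw_impl.to_numpy(weight) in the PyTorch trainer.

# Hedged sketch of the shared post-GPTQ update flow in the Keras and PyTorch
# trainers after this fix; names are illustrative stand-ins, not MCT's exact API.
from typing import Callable


def apply_gptq_kernel_update(node, layer, kernel_attribute: str, to_numpy: Callable):
    # The kernel quantizer reports its retrained weights and quantization parameters.
    weights, weight_quant_config, activation_quant_config = \
        layer.weights_quantizers[kernel_attribute].update_layer_quantization_params(layer)

    # Write the retrained weight tensors back into the graph node.
    for weight_attr, weight in weights.items():
        node.set_weights_by_keys(weight_attr, to_numpy(weight))

    # The fix: kernel quantization parameters are routed to the kernel attribute's
    # config via attr_name, instead of landing on the node-level weights config.
    for config_parameter_name, config_parameter_value in weight_quant_config.items():
        node.final_weights_quantization_cfg.set_quant_config_attr(config_parameter_name,
                                                                  config_parameter_value,
                                                                  attr_name=kernel_attribute)

    # Activation parameters are applied as before (unchanged by this commit).
    for config_attr, config_value in activation_quant_config.items():
        node.final_activation_quantization_cfg.set_quant_config_attr(config_attr, config_value)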
