Preparing TPC for weights per-attribute quantization (#925)
* The weights configuration in OpQuantizationConfig is extracted to a new class named AttributeQuantizationConfig, which holds the weights quantization configuration per attribute.
* Each OpQuantizationConfig now includes a default_weight_attr_config and an attr_weights_configs_mapping that maps an attribute name to that attribute's specific quantization configuration. The default config is used to quantize any weight attribute that is not listed in the mapping.
* By default, we add kernel and bias attributes to the base op config of all our TP models. The kernel is quantized the same way weights have been quantized so far; bias quantization is disabled.
* To enable attribute quantization with a specific config per attribute, we created a mapping mechanism from a generic attribute name (e.g., "KERNEL_ATTR") to that attribute's name in each framework (a sketch of the resulting configuration follows the change summary below).

---------

Co-authored-by: Ofir Gordon <[email protected]>
2 people authored and liord committed Feb 25, 2024
1 parent 8af86f8 commit 3f1dc0a
Showing 58 changed files with 1,208 additions and 469 deletions.
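
Before the per-file diffs, here is a minimal sketch of the per-attribute scheme described above. The stand-in classes below are simplified for illustration and are not MCT's real classes: only the field names (default_weight_attr_config, attr_weights_configs_mapping, weights_n_bits, enable_weights_quantization, and so on) mirror what the diffs reference; the constructor arguments and defaults are assumptions.

from dataclasses import dataclass, field
from typing import Dict

@dataclass
class AttributeQuantizationConfig:
    # Per-attribute weights quantization settings. Field names follow the
    # diffs below; the defaults here are assumed for illustration only.
    weights_quantization_method: str = 'POWER_OF_TWO'  # stands in for a QuantizationMethod enum value
    weights_n_bits: int = 8
    weights_per_channel_threshold: bool = True
    enable_weights_quantization: bool = False

@dataclass
class OpQuantizationConfig:
    # Attributes missing from the mapping fall back to the default config,
    # which (per the commit message) leaves them unquantized.
    default_weight_attr_config: AttributeQuantizationConfig = field(default_factory=AttributeQuantizationConfig)
    attr_weights_configs_mapping: Dict[str, AttributeQuantizationConfig] = field(default_factory=dict)

KERNEL_ATTR = 'kernel_attr'  # generic attribute names, as added to
BIAS_ATTR = 'bias_attr'      # target_platform_capabilities/constants.py below

# Base op config as the commit message describes it: the kernel is quantized
# the way weights were so far; bias quantization is disabled.
base_op_cfg = OpQuantizationConfig(
    attr_weights_configs_mapping={
        KERNEL_ATTR: AttributeQuantizationConfig(weights_n_bits=8, enable_weights_quantization=True),
        BIAS_ATTR: AttributeQuantizationConfig(enable_weights_quantization=False),
    })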
1 change: 1 addition & 0 deletions model_compression_toolkit/constants.py
@@ -114,6 +114,7 @@
ACTIVATION_QUANT_PARAMS_FN = 'activation_quantization_params_fn'
WEIGHTS_QUANT_PARAMS_FN = 'weights_quantization_params_fn'
WEIGHTS_CHANNELS_AXIS = 'weights_channels_axis'
+WEIGHTS_CFG = 'weights_cfg'

# Memory graph constants
DUMMY_NODE = 'dummy_node'
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/common/graph/base_graph.py
@@ -104,7 +104,7 @@ def set_tpc(self,
if not is_node_in_tpc:
Logger.error(f'MCT does not support optimizing Keras custom layers, but found layer of type {n.type}. '
f'Please add the custom layer to TPC or file a feature request or an issue if you believe this is an issue.')
-if any([qc.enable_weights_quantization for qc in n.get_qco(tpc).quantization_config_list]):
+if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(tpc).quantization_config_list]):
Logger.error(f'MCT does not support optimizing Keras custom layers with weights quantization. Layer: {n.type}')

self.tpc = tpc
@@ -14,9 +14,10 @@
# ==============================================================================
from model_compression_toolkit.constants import ACTIVATION_QUANTIZATION_CFG, WEIGHTS_QUANTIZATION_CFG, QC, \
OP_CFG, ACTIVATION_QUANTIZATION_FN, WEIGHTS_QUANTIZATION_FN, ACTIVATION_QUANT_PARAMS_FN, WEIGHTS_QUANT_PARAMS_FN, \
-WEIGHTS_CHANNELS_AXIS
+WEIGHTS_CHANNELS_AXIS, WEIGHTS_CFG
from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
+from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR


##########################################
@@ -47,4 +48,5 @@ def __init__(self, **kwargs):
kwargs.get(OP_CFG),
kwargs.get(WEIGHTS_QUANTIZATION_FN),
kwargs.get(WEIGHTS_QUANT_PARAMS_FN),
-kwargs.get(WEIGHTS_CHANNELS_AXIS))
+kwargs.get(WEIGHTS_CHANNELS_AXIS),
+kwargs.get(WEIGHTS_CFG))
@@ -24,7 +24,8 @@

from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
QuantizationErrorMethod
-from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig
+from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \
+AttributeQuantizationConfig


##########################################
@@ -236,7 +237,8 @@ def __init__(self,
op_cfg: OpQuantizationConfig,
weights_quantization_fn: Callable,
weights_quantization_params_fn: Callable,
-weights_channels_axis: int):
+weights_channels_axis: int,
+weights_cfg: AttributeQuantizationConfig):
"""
Args:
@@ -245,19 +247,22 @@
weights_quantization_fn: Function to use when quantizing the node's weights.
weights_quantization_params_fn: Function to use when computing the threshold for quantizing a node's weights.
weights_channels_axis: Axis to quantize a node's kernel when quantizing per-channel.
+weights_cfg: Weights attribute quantization config.
"""

+# TODO: after refactoring to enable attributes quantization, all weights quantization arguments
+# should be taken per attribute, and not from the weights config
self.weights_quantization_fn = weights_quantization_fn
self.weights_quantization_params_fn = weights_quantization_params_fn
self.weights_channels_axis = weights_channels_axis
self.weights_quantization_params = {}
-self.weights_quantization_method = op_cfg.weights_quantization_method
+self.weights_quantization_method = weights_cfg.weights_quantization_method
self.weights_error_method = qc.weights_error_method
-self.weights_n_bits = op_cfg.weights_n_bits
+self.weights_n_bits = weights_cfg.weights_n_bits
self.weights_bias_correction = qc.weights_bias_correction
self.weights_second_moment_correction = qc.weights_second_moment_correction
-self.weights_per_channel_threshold = op_cfg.weights_per_channel_threshold
-self.enable_weights_quantization = op_cfg.enable_weights_quantization
+self.weights_per_channel_threshold = weights_cfg.weights_per_channel_threshold
+self.enable_weights_quantization = weights_cfg.enable_weights_quantization
self.min_threshold = qc.min_threshold
self.l_p_value = qc.l_p_value
self.simd_size = op_cfg.simd_size
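
The constructor change above splits each field's source: per-attribute settings are now read from the AttributeQuantizationConfig (weights_cfg), while user-level settings still come from the QuantizationConfig (qc) and op-level settings from the OpQuantizationConfig (op_cfg). A condensed, illustrative restatement of that split (not the real class, which takes more arguments and holds more fields):

class NodeWeightsQuantizationConfigSketch:
    # Illustrative stand-in for NodeWeightsQuantizationConfig after this change.
    def __init__(self, qc, op_cfg, weights_cfg):
        # Per-attribute settings: now taken from the AttributeQuantizationConfig.
        self.weights_quantization_method = weights_cfg.weights_quantization_method
        self.weights_n_bits = weights_cfg.weights_n_bits
        self.weights_per_channel_threshold = weights_cfg.weights_per_channel_threshold
        self.enable_weights_quantization = weights_cfg.enable_weights_quantization
        # User-level settings: unchanged, still from the core QuantizationConfig.
        self.weights_error_method = qc.weights_error_method
        self.weights_bias_correction = qc.weights_bias_correction
        # Op-level settings: unchanged, still from the OpQuantizationConfig.
        self.simd_size = op_cfg.simd_size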
@@ -81,6 +81,7 @@ def set_quantization_configs_to_node(node: BaseNode,
fw_info,
weight_channel_axis,
node_qc_options,
+node.type,
mixed_precision_enable=mixed_precision_enable)

for candidate_qc in node.candidates_quantization_cfg:
@@ -118,10 +119,11 @@ def create_node_activation_qc(qc: QuantizationConfig,
activation_quantization_params_fn)


-def create_node_qc_candidate(qc: QuantizationConfig,
-fw_info: FrameworkInfo,
-weight_channel_axis: int,
-op_cfg: OpQuantizationConfig) -> CandidateNodeQuantizationConfig:
+def _create_node_single_candidate_qc(qc: QuantizationConfig,
+fw_info: FrameworkInfo,
+weight_channel_axis: int,
+op_cfg: OpQuantizationConfig,
+kernel_attr: str) -> CandidateNodeQuantizationConfig:
"""
Create quantization configuration candidate from a QuantizationConfig object.
Creates both weights and activation quantization configurations
Expand All @@ -133,18 +135,26 @@ def create_node_qc_candidate(qc: QuantizationConfig,
weights/activations should be quantized)
weight_channel_axis: Output channel index of the node's kernel.
op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
+kernel_attr: The name of the kernel attribute of the node.
+TODO: kernel_attr should be removed once enabling attributes quantization (because this function would create
+a candidate for all attributes, not specifically for the kernel).
Returns: a CandidateNodeQuantizationConfig object with both weights and activation quantization config objects.
"""

-# get attributes for weights quantization
-weights_quantization_fn = get_weights_quantization_fn(op_cfg.weights_quantization_method)
+# get attributes for weights quantization.
+# if the node doesn't have a specified kernel config, we use the default attribute config for quantization.
+# TODO: This should be the behavior for all attributes that are not specified in the attribute config mapping,
+# which currently disables the quantization of the weights attribute.
+weights_cfg = op_cfg.attr_weights_configs_mapping.get(kernel_attr, op_cfg.default_weight_attr_config)

+weights_quantization_fn = get_weights_quantization_fn(weights_cfg.weights_quantization_method)

if weights_quantization_fn is None:
-Logger.critical('Unknown quantization method for weights') # pragma: no cover
+Logger.critical(f'Unknown quantization method for weights for quantizing attribute: {kernel_attr}') # pragma: no cover

-weights_quantization_params_fn = get_weights_quantization_params_fn(op_cfg.weights_quantization_method)
+weights_quantization_params_fn = get_weights_quantization_params_fn(weights_cfg.weights_quantization_method)

# get attributes for activation quantization
activation_quantization_fn = fw_info.activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
@@ -159,13 +169,15 @@
activation_quantization_params_fn=activation_quantization_params_fn,
weights_quantization_fn=weights_quantization_fn,
weights_quantization_params_fn=weights_quantization_params_fn,
-weight_channel_axis=weight_channel_axis)
+weight_channel_axis=weight_channel_axis,
+weights_cfg=weights_cfg)


def _create_node_candidates_qc(qc: QuantizationConfig,
fw_info: FrameworkInfo,
weight_channel_axis: int,
node_qc_options: QuantizationConfigOptions,
+node_type: type,
mixed_precision_enable: bool = False) -> List[CandidateNodeQuantizationConfig]:
"""
Create a list of candidates of weights and activation quantization configurations for a node.
@@ -175,28 +187,39 @@ def _create_node_candidates_qc(qc: QuantizationConfig,
fw_info: Framework information (e.g., which layers should have their kernels' quantized).
weight_channel_axis: Output channel index of the node's kernel.
node_qc_options: QuantizationConfigOptions for the node with quantization candidates information.
+node_type: The type of the layer that the node represents.
mixed_precision_enable: is mixed precision enabled
Returns:
List of candidates of weights quantization configurations to set for a node.
"""

candidates = []

+# TODO: Currently, we are using fw_info to get the kernel attribute, but this will change once we enable multi-
+# attribute quantization via the AttributeQuantizationConfig class (needs to be implemented)

+kernel_attr = fw_info.get_kernel_op_attributes(node_type)
+assert len(kernel_attr) == 1
+kernel_attr = kernel_attr[0]

if mixed_precision_enable:
for op_cfg in node_qc_options.quantization_config_list:
candidate_nbits_qc = copy.deepcopy(qc)
-candidates.append(create_node_qc_candidate(candidate_nbits_qc,
-fw_info,
-weight_channel_axis,
-op_cfg))
+candidates.append(_create_node_single_candidate_qc(candidate_nbits_qc,
+fw_info,
+weight_channel_axis,
+op_cfg,
+kernel_attr))
# sorting the candidates by weights number of bits first and then by activation number of bits
# (in reversed order)
candidates.sort(key=lambda c: (c.weights_quantization_cfg.weights_n_bits,
c.activation_quantization_cfg.activation_n_bits), reverse=True)
else:
-candidates.append(create_node_qc_candidate(qc,
-fw_info,
-weight_channel_axis,
-node_qc_options.base_config))
+candidates.append(_create_node_single_candidate_qc(qc,
+fw_info,
+weight_channel_axis,
+node_qc_options.base_config,
+kernel_attr))

return candidates
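
The heart of _create_node_single_candidate_qc is the dict .get() fallback shown in the diff above: a kernel attribute present in the mapping gets its specific config, and anything else falls back to the default attribute config. Reusing the stand-in classes and base_op_cfg from the first sketch (the unmapped attribute name here is illustrative):

# Mapped attribute: the kernel gets its specific config.
kernel_cfg = base_op_cfg.attr_weights_configs_mapping.get(
    KERNEL_ATTR, base_op_cfg.default_weight_attr_config)
assert kernel_cfg.enable_weights_quantization  # KERNEL_ATTR was mapped explicitly

# Unmapped attribute: falls back to the default config, which (currently)
# disables that attribute's quantization.
other_cfg = base_op_cfg.attr_weights_configs_mapping.get(
    'some_unmapped_attr', base_op_cfg.default_weight_attr_config)
assert not other_cfg.enable_weights_quantization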
@@ -24,4 +24,21 @@
DEFAULT_TP_MODEL = 'default'
IMX500_TP_MODEL = 'imx500'
TFLITE_TP_MODEL = 'tflite'
-QNNPACK_TP_MODEL = 'qnnpack'
\ No newline at end of file
+QNNPACK_TP_MODEL = 'qnnpack'

+# TP Attributes
+KERNEL_ATTR = "kernel_attr"
+BIAS_ATTR = "bias_attr"

+# TODO: this is duplicated from the core frameworks' constants files, because the original consts can't be used here
+# due to circular dependency. It might be best to extract the constants from the core file and put them here (in a
+# separate changeset, because it affects the entire code)
+KERAS_KERNEL = "kernel"
+KERAS_DEPTHWISE_KERNEL = "depthwise_kernel"
+BIAS = "bias"
+PYTORCH_KERNEL = "weight"

+# Configuration attribute names

+WEIGHTS_N_BITS = 'weights_n_bits'
+WEIGHTS_QUANTIZATION_METHOD = 'weights_quantization_method'
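
The constants above serve the mapping mechanism from the commit message: the TP model refers to attributes by generic names (KERNEL_ATTR, BIAS_ATTR), and each framework resolves them to its own attribute names. The wiring of that resolution is not part of this excerpt; the tables and helper below are hypothetical and only illustrate the idea, using the constant values defined above:

# Hypothetical per-framework resolution tables (not MCT code).
KERAS_ATTR_MAPPING = {'kernel_attr': 'kernel', 'bias_attr': 'bias'}    # KERNEL_ATTR -> KERAS_KERNEL, BIAS_ATTR -> BIAS
PYTORCH_ATTR_MAPPING = {'kernel_attr': 'weight', 'bias_attr': 'bias'}  # KERNEL_ATTR -> PYTORCH_KERNEL, BIAS_ATTR -> BIAS

def resolve_fw_attr(generic_name: str, fw_mapping: dict) -> str:
    # Translate a generic TP attribute name to the framework's attribute name;
    # unknown names pass through unchanged (an assumption, not MCT behavior).
    return fw_mapping.get(generic_name, generic_name)

assert resolve_fw_attr('kernel_attr', PYTORCH_ATTR_MAPPING) == 'weight'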
@@ -21,7 +21,7 @@
get_default_quantization_config_options, TargetPlatformModel

from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import OpQuantizationConfig, \
-QuantizationConfigOptions
+QuantizationConfigOptions, AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSet, OperatorSetConcat

from mct_quantizers import QuantizationMethod
(Diffs for the remaining changed files are not shown.)