
Preparing TPC for weights per-attribute quantization #925

Merged: 36 commits into main from tpc-attr-cfg, Jan 31, 2024
Changes shown from 19 of the 36 commits

Commits
49916d9 WIP - Implementing attribute config in TPC and adjusting all relevant… (Dec 14, 2023)
58c02b8 * modify keras pot TPC (Dec 21, 2023)
67ce1a6 adjust all existing TPCs (Dec 26, 2023)
f996254 Modifications to all TPCs and fix tests and issues (Jan 3, 2024)
b3110d1 Fix almost all keras feature network tests (Jan 4, 2024)
06a4a30 Fix keras function tests (Jan 15, 2024)
05d23b1 fix some remaining keras tests (Jan 15, 2024)
129ab2a fix pytorch tests (Jan 16, 2024)
a2f6c61 documentation, typehints and functions and variables cleaning (except… (Jan 18, 2024)
637e73f Another test fixes (Jan 21, 2024)
973654f merge main (Jan 21, 2024)
40f8f8c minor fix (Jan 21, 2024)
a7bd349 fix common test (Jan 21, 2024)
e5426f9 fix exporter tests (Jan 21, 2024)
b21bf80 fix pytorch old api tests (Jan 21, 2024)
d6c4d3a fix pytorch tp model tests (Jan 22, 2024)
61d57f9 fix constant imports (Jan 22, 2024)
4e9d7b2 Merge branch 'main' into tpc-attr-cfg (Jan 22, 2024)
f0a5ef5 fix some merging issues (Jan 22, 2024)
01af709 Fix simple PR comments (Jan 24, 2024)
d03848e change attribute mapping to use default dict (changes in TPC, need to… (Jan 25, 2024)
377d3ba Merge branch 'main' into tpc-attr-cfg (Jan 25, 2024)
a2955cf Adjust TPCs to work with default dict in attrs mapping (WIP) (Jan 25, 2024)
8b827e1 Minor PR comments fixes (Jan 25, 2024)
5b460a7 Fix tests after modifying attributes mapping (Jan 28, 2024)
891b8ff add comment in TPCs regarding attributes mapping (Jan 28, 2024)
e3a0b08 minor PR fixes (Jan 28, 2024)
5ce8870 try to solve circular (Jan 28, 2024)
a3b114d another attempt to solve circular (Jan 29, 2024)
e5276f1 revert unnecessary changes of imports (Jan 29, 2024)
ca861a5 revert (Jan 29, 2024)
14dde1e Merge branch 'main' into tpc-attr-cfg (Jan 30, 2024)
3dc9e09 fix default dict import after extraction (Jan 30, 2024)
0ff82cd fix key mapping loop (Jan 31, 2024)
090e609 merge main and resolved conflicts (Jan 31, 2024)
67add83 fix import (Jan 31, 2024)
1 change: 1 addition & 0 deletions model_compression_toolkit/constants.py
@@ -114,6 +114,7 @@
 ACTIVATION_QUANT_PARAMS_FN = 'activation_quantization_params_fn'
 WEIGHTS_QUANT_PARAMS_FN = 'weights_quantization_params_fn'
 WEIGHTS_CHANNELS_AXIS = 'weights_channels_axis'
+WEIGHTS_CFG = 'weights_cfg'

 # Memory graph constants
 DUMMY_NODE = 'dummy_node'
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/common/graph/base_graph.py
@@ -104,7 +104,7 @@ def set_tpc(self,
             if not is_node_in_tpc:
                 Logger.error(f'MCT does not support optimizing Keras custom layers, but found layer of type {n.type}. '
                              f'Please add the custom layer to TPC or file a feature request or an issue if you believe this is an issue.')
-            if any([qc.enable_weights_quantization for qc in n.get_qco(tpc).quantization_config_list]):
+            if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(tpc).quantization_config_list]):
                 Logger.error(f'MCT does not support optimizing Keras custom layers with weights quantization. Layer: {n.type}')

         self.tpc = tpc
model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py
@@ -14,9 +14,10 @@
 # ==============================================================================
 from model_compression_toolkit.constants import ACTIVATION_QUANTIZATION_CFG, WEIGHTS_QUANTIZATION_CFG, QC, \
     OP_CFG, ACTIVATION_QUANTIZATION_FN, WEIGHTS_QUANTIZATION_FN, ACTIVATION_QUANT_PARAMS_FN, WEIGHTS_QUANT_PARAMS_FN, \
-    WEIGHTS_CHANNELS_AXIS
+    WEIGHTS_CHANNELS_AXIS, WEIGHTS_CFG
 from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
     NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
+from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR


 ##########################################
@@ -47,4 +48,5 @@ def __init__(self, **kwargs):
                 kwargs.get(OP_CFG),
                 kwargs.get(WEIGHTS_QUANTIZATION_FN),
                 kwargs.get(WEIGHTS_QUANT_PARAMS_FN),
-                kwargs.get(WEIGHTS_CHANNELS_AXIS))
+                kwargs.get(WEIGHTS_CHANNELS_AXIS),
+                kwargs.get(WEIGHTS_CFG))
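For context, the candidate config is assembled from **kwargs keyed by the string constants added above, so the new WEIGHTS_CFG key is how the weights attribute config reaches NodeWeightsQuantizationConfig. A minimal sketch of the pattern, using simplified stand-ins for the MCT classes (the real constructors take more arguments):

```python
# Minimal sketch of the kwargs-key pattern; classes are simplified stand-ins.
WEIGHTS_CHANNELS_AXIS = 'weights_channels_axis'
WEIGHTS_CFG = 'weights_cfg'

class NodeWeightsQuantizationConfig:
    def __init__(self, weights_channels_axis, weights_cfg):
        self.weights_channels_axis = weights_channels_axis
        # Per-attribute settings are read from weights_cfg (see the next file in this diff).
        self.weights_cfg = weights_cfg

class CandidateNodeQuantizationConfig:
    def __init__(self, **kwargs):
        # Missing keys resolve to None, matching the kwargs.get(...) calls above.
        self.weights_quantization_cfg = NodeWeightsQuantizationConfig(
            kwargs.get(WEIGHTS_CHANNELS_AXIS),
            kwargs.get(WEIGHTS_CFG))

candidate = CandidateNodeQuantizationConfig(weights_channels_axis=3, weights_cfg=None)
```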
model_compression_toolkit/core/common/quantization/node_quantization_config.py
@@ -236,7 +236,8 @@ def __init__(self,
                  op_cfg: OpQuantizationConfig,
                  weights_quantization_fn: Callable,
                  weights_quantization_params_fn: Callable,
-                 weights_channels_axis: int):
+                 weights_channels_axis: int,
+                 weights_cfg: Any):
         """

         Args:
@@ -245,19 +246,22 @@
             weights_quantization_fn: Function to use when quantizing the node's weights.
             weights_quantization_params_fn: Function to use when computing the threshold for quantizing a node's weights.
             weights_channels_axis: Axis to quantize a node's kernel when quantizing per-channel.
+            weights_cfg: Weights attribute quantization config.
         """

+        # TODO: after refactoring to enable attributes quantization, all weights quantization arguments
+        #  should be taken per attribute, and not from the weights config
         self.weights_quantization_fn = weights_quantization_fn
         self.weights_quantization_params_fn = weights_quantization_params_fn
         self.weights_channels_axis = weights_channels_axis
         self.weights_quantization_params = {}
-        self.weights_quantization_method = op_cfg.weights_quantization_method
+        self.weights_quantization_method = weights_cfg.weights_quantization_method
         self.weights_error_method = qc.weights_error_method
-        self.weights_n_bits = op_cfg.weights_n_bits
+        self.weights_n_bits = weights_cfg.weights_n_bits
         self.weights_bias_correction = qc.weights_bias_correction
         self.weights_second_moment_correction = qc.weights_second_moment_correction
-        self.weights_per_channel_threshold = op_cfg.weights_per_channel_threshold
-        self.enable_weights_quantization = op_cfg.enable_weights_quantization
+        self.weights_per_channel_threshold = weights_cfg.weights_per_channel_threshold
+        self.enable_weights_quantization = weights_cfg.enable_weights_quantization
         self.min_threshold = qc.min_threshold
         self.l_p_value = qc.l_p_value
         self.simd_size = op_cfg.simd_size
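The fields read from weights_cfg above outline the shape of the per-attribute config object. A minimal sketch, assuming a simplified dataclass: the field names are exactly the ones this diff reads, but the class bodies here are illustrative, not MCT's actual definitions:

```python
from dataclasses import dataclass, field
from typing import Dict

@dataclass
class AttributeQuantizationConfig:
    # Field names mirror what NodeWeightsQuantizationConfig reads above;
    # default values are illustrative placeholders.
    weights_quantization_method: str = 'POWER_OF_TWO'
    weights_n_bits: int = 8
    weights_per_channel_threshold: bool = True
    enable_weights_quantization: bool = False

@dataclass
class OpQuantizationConfig:
    # Per-attribute configs keyed by attribute name, plus a default used as a
    # fallback (see attr_weights_configs_mapping / default_weight_attr_config
    # elsewhere in this PR).
    default_weight_attr_config: AttributeQuantizationConfig = field(default_factory=AttributeQuantizationConfig)
    attr_weights_configs_mapping: Dict[str, AttributeQuantizationConfig] = field(default_factory=dict)

op_cfg = OpQuantizationConfig(
    attr_weights_configs_mapping={'kernel': AttributeQuantizationConfig(weights_n_bits=4,
                                                                        enable_weights_quantization=True)})
assert op_cfg.attr_weights_configs_mapping['kernel'].weights_n_bits == 4
```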
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
@@ -81,6 +81,7 @@ def set_quantization_configs_to_node(node: BaseNode,
                                                                   fw_info,
                                                                   weight_channel_axis,
                                                                   node_qc_options,
+                                                                  node.type,
                                                                   mixed_precision_enable=mixed_precision_enable)

     for candidate_qc in node.candidates_quantization_cfg:
@@ -118,10 +119,11 @@ def create_node_activation_qc(qc: QuantizationConfig,
                                           activation_quantization_params_fn)


-def create_node_qc_candidate(qc: QuantizationConfig,
-                             fw_info: FrameworkInfo,
-                             weight_channel_axis: int,
-                             op_cfg: OpQuantizationConfig) -> CandidateNodeQuantizationConfig:
+def _create_node_single_candidate_qc(qc: QuantizationConfig,
+                                     fw_info: FrameworkInfo,
+                                     weight_channel_axis: int,
+                                     op_cfg: OpQuantizationConfig,
+                                     kernel_attr: str) -> CandidateNodeQuantizationConfig:
     """
     Create quantization configuration candidate from a QuantizationConfig object.
     Creates both weights and activation quantization configurations
@@ -133,18 +135,28 @@ def create_node_qc_candidate(qc: QuantizationConfig,
             weights/activations should be quantized)
         weight_channel_axis: Output channel index of the node's kernel.
         op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
+        kernel_attr: The name of the kernel attribute of the node.
+            TODO: kernel_attr should be removed once attributes quantization is enabled (because this function
+             would then create a candidate for all attributes, not specifically for the kernel).

     Returns: a CandidateNodeQuantizationConfig object with both weights and activation quantization config objects.

     """

     # get attributes for weights quantization
-    weights_quantization_fn = get_weights_quantization_fn(op_cfg.weights_quantization_method)
+    weights_cfg = op_cfg.attr_weights_configs_mapping.get(kernel_attr)
+    if weights_cfg is None:
+        # The node doesn't have a specified kernel config. Using the default attribute config for quantization.
+        # TODO: This should be the behavior for all attributes that are not specified in the attribute config mapping,
+        #  which currently disables the quantization of the weights attribute.
+        weights_cfg = op_cfg.default_weight_attr_config
+
+    weights_quantization_fn = get_weights_quantization_fn(weights_cfg.weights_quantization_method)

     if weights_quantization_fn is None:
         Logger.critical('Unknown quantization method for weights')  # pragma: no cover

-    weights_quantization_params_fn = get_weights_quantization_params_fn(op_cfg.weights_quantization_method)
+    weights_quantization_params_fn = get_weights_quantization_params_fn(weights_cfg.weights_quantization_method)

     # get attributes for activation quantization
     activation_quantization_fn = fw_info.activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
@@ -159,13 +171,15 @@
                                              activation_quantization_params_fn=activation_quantization_params_fn,
                                              weights_quantization_fn=weights_quantization_fn,
                                              weights_quantization_params_fn=weights_quantization_params_fn,
-                                             weight_channel_axis=weight_channel_axis)
+                                             weight_channel_axis=weight_channel_axis,
+                                             weights_cfg=weights_cfg)


 def _create_node_candidates_qc(qc: QuantizationConfig,
                                fw_info: FrameworkInfo,
                                weight_channel_axis: int,
                                node_qc_options: QuantizationConfigOptions,
+                               node_type: type,
                                mixed_precision_enable: bool = False) -> List[CandidateNodeQuantizationConfig]:
     """
     Create a list of candidates of weights and activation quantization configurations for a node.
@@ -175,28 +189,39 @@ def _create_node_candidates_qc(qc: QuantizationConfig,
         fw_info: Framework information (e.g., which layers should have their kernels' quantized).
         weight_channel_axis: Output channel index of the node's kernel.
         node_qc_options: QuantizationConfigOptions for the node with quantization candidates information.
+        node_type: The type of the layer that the node represents.
         mixed_precision_enable: is mixed precision enabled

     Returns:
         List of candidates of weights quantization configurations to set for a node.
     """

     candidates = []

+    # TODO: Currently, we are using fw_info to get the kernel attribute, but this would change once we enable
+    #  multi-attribute quantization via the AttributeQuantizationConfig class (needs to be implemented)
+
+    kernel_attr = fw_info.get_kernel_op_attributes(node_type)
+    assert len(kernel_attr) == 1
+    kernel_attr = kernel_attr[0]
+
     if mixed_precision_enable:
         for op_cfg in node_qc_options.quantization_config_list:
             candidate_nbits_qc = copy.deepcopy(qc)
-            candidates.append(create_node_qc_candidate(candidate_nbits_qc,
-                                                       fw_info,
-                                                       weight_channel_axis,
-                                                       op_cfg))
+            candidates.append(_create_node_single_candidate_qc(candidate_nbits_qc,
+                                                               fw_info,
+                                                               weight_channel_axis,
+                                                               op_cfg,
+                                                               kernel_attr))
         # sorting the candidates by weights number of bits first and then by activation number of bits
         # (in reversed order)
         candidates.sort(key=lambda c: (c.weights_quantization_cfg.weights_n_bits,
                                        c.activation_quantization_cfg.activation_n_bits), reverse=True)
     else:
-        candidates.append(create_node_qc_candidate(qc,
-                                                   fw_info,
-                                                   weight_channel_axis,
-                                                   node_qc_options.base_config))
+        candidates.append(_create_node_single_candidate_qc(qc,
+                                                           fw_info,
+                                                           weight_channel_axis,
+                                                           node_qc_options.base_config,
+                                                           kernel_attr))

     return candidates
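The key behavioral change above is the lookup-with-fallback for the kernel's attribute config. A self-contained sketch of that resolution logic, using SimpleNamespace as a stand-in for the MCT config classes:

```python
# Self-contained sketch of the kernel-attribute resolution added above;
# SimpleNamespace stands in for OpQuantizationConfig and the attribute configs.
from types import SimpleNamespace

def resolve_weights_cfg(op_cfg, kernel_attr):
    """Return the attribute config for kernel_attr, falling back to the default."""
    weights_cfg = op_cfg.attr_weights_configs_mapping.get(kernel_attr)
    if weights_cfg is None:
        # No kernel-specific entry in the mapping: use the op's default
        # attribute config (which may leave weights quantization disabled).
        weights_cfg = op_cfg.default_weight_attr_config
    return weights_cfg

op_cfg = SimpleNamespace(
    default_weight_attr_config=SimpleNamespace(weights_n_bits=8, enable_weights_quantization=False),
    attr_weights_configs_mapping={'kernel': SimpleNamespace(weights_n_bits=4, enable_weights_quantization=True)})

assert resolve_weights_cfg(op_cfg, 'kernel').weights_n_bits == 4             # explicit kernel entry
assert resolve_weights_cfg(op_cfg, 'depthwise_kernel').weights_n_bits == 8   # fallback to default
```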
model_compression_toolkit/target_platform_capabilities/constants.py
@@ -24,4 +24,21 @@
 DEFAULT_TP_MODEL = 'default'
 IMX500_TP_MODEL = 'imx500'
 TFLITE_TP_MODEL = 'tflite'
-QNNPACK_TP_MODEL = 'qnnpack'
+QNNPACK_TP_MODEL = 'qnnpack'
+
+# TP Attributes
+KERNEL_ATTR = "kernel_attr"
+BIAS_ATTR = "bias_attr"
+
+# TODO: this is duplicated from the core frameworks constants files, because the original consts can't be used here
+#  due to circular dependency. It might be best to extract the constants from the core file and put them here (in a
+#  separate changeset, because it affects the entire code)
+KERAS_KERNEL = "kernel"
+KERAS_DEPTHWISE_KERNEL = "depthwise_kernel"
+BIAS = "bias"
+PYTORCH_KERNEL = "weight"
+
+# Configuration attributes names
+
+WEIGHTS_N_BITS = 'weights_n_bits'
+WEIGHTS_QUANTIZATION_METHOD = 'weights_quantization_method'
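Why two sets of names: KERNEL_ATTR and BIAS_ATTR are framework-agnostic TPC keys, while KERAS_KERNEL, PYTORCH_KERNEL, etc. are the frameworks' concrete weight attribute names. A small sketch of how such a translation might look; the mapping dicts and helper below are illustrative assumptions, not code from this PR:

```python
# Illustrative (not from the PR): translate framework-agnostic TPC attribute
# keys to a framework's concrete attribute names.
KERNEL_ATTR = "kernel_attr"
BIAS_ATTR = "bias_attr"
KERAS_KERNEL = "kernel"
PYTORCH_KERNEL = "weight"
BIAS = "bias"

# Hypothetical per-framework mappings from TPC keys to attribute names.
KERAS_ATTR_MAPPING = {KERNEL_ATTR: KERAS_KERNEL, BIAS_ATTR: BIAS}
PYTORCH_ATTR_MAPPING = {KERNEL_ATTR: PYTORCH_KERNEL, BIAS_ATTR: BIAS}

def to_framework_attr(tpc_attr: str, fw_mapping: dict) -> str:
    # Resolve a TPC-level key like KERNEL_ATTR to e.g. 'kernel' or 'weight'.
    return fw_mapping[tpc_attr]

assert to_framework_attr(KERNEL_ATTR, PYTORCH_ATTR_MAPPING) == "weight"
```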
model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py
@@ -21,7 +21,7 @@
     get_default_quantization_config_options, TargetPlatformModel

 from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import OpQuantizationConfig, \
-    QuantizationConfigOptions
+    QuantizationConfigOptions, AttributeQuantizationConfig
 from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSet, OperatorSetConcat

 from mct_quantizers import QuantizationMethod
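Re-exporting AttributeQuantizationConfig makes it available to user-defined TP models. A hypothetical usage sketch: the constructor keyword names are assumed from the attribute fields this PR reads (weights_n_bits, weights_quantization_method, etc.) and are not shown in this diff:

```python
# Hypothetical usage (keyword names assumed, not confirmed by this diff):
from mct_quantizers import QuantizationMethod
from model_compression_toolkit.target_platform_capabilities.target_platform import AttributeQuantizationConfig

kernel_attr_cfg = AttributeQuantizationConfig(
    weights_quantization_method=QuantizationMethod.SYMMETRIC,
    weights_n_bits=8,
    weights_per_channel_threshold=True,
    enable_weights_quantization=True)
```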