From 64dacc0272156fb8e6e4f961d2e6b242eaf0c634 Mon Sep 17 00:00:00 2001 From: Elad Cohen <78862769+elad-c@users.noreply.github.com> Date: Thu, 18 Apr 2024 15:36:21 +0300 Subject: [PATCH] Const quantization (#1045) Add const quantization to "add", "sub", "mul" & "div" operations. Enabled in TPC imx500.v2. --- .../core/common/graph/base_graph.py | 4 +- .../core/common/graph/base_node.py | 33 +++-- .../core/common/graph/functional_node.py | 19 ++- .../common/network_editors/node_filters.py | 7 +- .../quantization/node_quantization_config.py | 5 - .../lut_kmeans_params.py | 7 +- .../core/common/similarity_analyzer.py | 4 +- .../back2framework/keras_model_builder.py | 5 +- .../substitutions/batchnorm_folding.py | 14 +- .../substitutions/linear_collapsing.py | 2 +- .../substitutions/residual_collapsing.py | 2 +- .../core/keras/keras_implementation.py | 20 +-- .../core/keras/keras_node_prior_info.py | 8 +- .../pruning/pruning_keras_implementation.py | 9 +- .../core/keras/reader/common.py | 4 +- .../core/keras/reader/node_builder.py | 37 +++-- .../core/keras/tf_tensor_numpy.py | 7 +- .../back2framework/pytorch_model_builder.py | 55 ++++--- .../substitutions/batchnorm_folding.py | 16 +-- .../substitutions/const_holder_conv.py | 4 +- .../substitutions/linear_collapsing.py | 2 +- .../substitutions/relu_bound_to_power_of_2.py | 8 +- .../substitutions/residual_collapsing.py | 2 +- .../pruning/pruning_pytorch_implementation.py | 16 +-- .../core/pytorch/pytorch_implementation.py | 9 +- .../core/pytorch/pytorch_node_prior_info.py | 4 +- .../builder/fully_quantized_model_builder.py | 6 +- .../builder/fully_quantized_model_builder.py | 11 +- .../tpc_models/imx500_tpc/v1_lut/tp_model.py | 2 +- .../tpc_models/imx500_tpc/v2/tp_model.py | 26 +++- .../tpc_models/imx500_tpc/v2/tpc_keras.py | 2 +- .../tpc_models/imx500_tpc/v2_lut/tp_model.py | 30 ++-- .../tpc_models/imx500_tpc/v2_lut/tpc_keras.py | 2 +- .../helpers/generate_test_tp_model.py | 3 +- .../const_quantization_test.py | 120 ++++++++++++++++ .../const_representation_test.py | 8 +- .../feature_networks/mixed_precision_tests.py | 3 +- .../test_features_runner.py | 22 ++- .../test_hessian_info_calculator.py | 4 +- ..._sensitivity_eval_non_suppoerted_output.py | 4 +- ...t_weights_activation_split_substitution.py | 2 +- .../layer_tests/base_pytorch_layer_test.py | 2 + .../feature_models/const_quantization_test.py | 136 ++++++++++++++++++ .../model_tests/test_feature_models_runner.py | 12 ++ 44 files changed, 546 insertions(+), 152 deletions(-) create mode 100644 tests/keras_tests/feature_networks_tests/feature_networks/const_quantization_test.py create mode 100644 tests/pytorch_tests/model_tests/feature_models/const_quantization_test.py diff --git a/model_compression_toolkit/core/common/graph/base_graph.py b/model_compression_toolkit/core/common/graph/base_graph.py index d33fab23a..892d12454 100644 --- a/model_compression_toolkit/core/common/graph/base_graph.py +++ b/model_compression_toolkit/core/common/graph/base_graph.py @@ -98,8 +98,8 @@ def set_tpc(self, tpc_layers = tpc.op_sets_to_layers.get_layers() tpc_filtered_layers = [layer for layer in tpc_layers if isinstance(layer, LayerFilterParams)] for n in self.nodes: - is_node_in_tpc = n.type in tpc_layers or any([n.is_match_filter_params(filtered_layer) - for filtered_layer in tpc_filtered_layers]) + is_node_in_tpc = any([n.is_match_type(_type) for _type in tpc_layers]) or \ + any([n.is_match_filter_params(filtered_layer) for filtered_layer in tpc_filtered_layers]) if n.is_custom: if not is_node_in_tpc: 
Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. ' diff --git a/model_compression_toolkit/core/common/graph/base_node.py b/model_compression_toolkit/core/common/graph/base_node.py index 193e01d08..9cc721702 100644 --- a/model_compression_toolkit/core/common/graph/base_node.py +++ b/model_compression_toolkit/core/common/graph/base_node.py @@ -151,7 +151,21 @@ def is_reused(self) -> bool: """ return self.reuse or self.reuse_group is not None - def get_weights_by_keys(self, name: str) -> np.ndarray: + def _get_weight_name(self, name: Union[str, int]) -> List[Union[str, int]]: + """ + Get weight names that match argument name (either string weights or integer for + positional weights). + Args: + name: weight name + + Returns: + A list of weight names that match input "name" + + """ + return [k for k in self.weights.keys() + if (isinstance(k, int) and name == k) or (isinstance(k, str) and name in k)] + + def get_weights_by_keys(self, name: Union[str, int]) -> np.ndarray: """ Get a node's weight by its name. Args: @@ -163,7 +177,7 @@ def get_weights_by_keys(self, name: str) -> np.ndarray: if name is None: return None - res = [k for k in self.weights.keys() if name in k] + res = self._get_weight_name(name) if len(res) == 1: # Make sure there are no duplicates return self.weights[res[0]] else: @@ -179,7 +193,7 @@ def set_weights_by_keys(self, name: str, tensor: np.ndarray): """ - res = [k for k in self.weights.keys() if name in k] + res = self._get_weight_name(name) if len(res) == 1: self.weights[res[0]] = tensor else: # Add if not exist @@ -552,14 +566,17 @@ def get_qco(self, tpc: TargetPlatformCapabilities) -> QuantizationConfigOptions: for fl, qco in tpc.filterlayer2qco.items(): if self.is_match_filter_params(fl): return qco - if self.type in tpc.layer2qco: - return tpc.layer2qco.get(self.type) + # Extract qco with is_match_type to overcome mismatch of function types in TF 2.15 + matching_qcos = [_qco for _type, _qco in tpc.layer2qco.items() if self.is_match_type(_type)] + if matching_qcos: + if len(matching_qcos) > 1: + Logger.error('Found duplicate qco types!') + return matching_qcos[0] return tpc.tp_model.default_qco def is_match_type(self, _type: Type) -> bool: """ - Check if input type matches the node type, either in instance type or in type name. Checking the - name string is required because of function types changes that occurred in TF 2.15. + Check if input type matches the node type, either in instance type or in type name. 
Args: _type: other node type @@ -567,7 +584,7 @@ def is_match_type(self, _type: Type) -> bool: Whether _type matches the self node type """ - return _type == self.type or _type.__name__ == self.type.__name__ + return _type == self.type def is_match_filter_params(self, layer_filter_params: LayerFilterParams) -> bool: """ diff --git a/model_compression_toolkit/core/common/graph/functional_node.py b/model_compression_toolkit/core/common/graph/functional_node.py index 85c0cc2f9..ccc6e4bc2 100644 --- a/model_compression_toolkit/core/common/graph/functional_node.py +++ b/model_compression_toolkit/core/common/graph/functional_node.py @@ -1,5 +1,6 @@ -from typing import Dict, Any, Tuple, List +from typing import Dict, Any, Tuple, Type +from model_compression_toolkit.constants import FOUND_TF from model_compression_toolkit.core.common.graph.base_node import BaseNode import numpy as np @@ -71,3 +72,19 @@ def type(self): :return: the node's functional_op """ return self.functional_op + + def is_match_type(self, _type: Type) -> bool: + """ + Check if input type matches the node type, either in instance type or in type name. Checking the + name string is required because of function types changes that occurred in TF 2.15, because it + changes the "function" attribute object (e.g. a different tf.add function that will fail the + equal operation). + + Args: + _type: other node type + Returns: + Whether _type matches the self node type + + """ + names_match = _type.__name__ == self.type.__name__ if FOUND_TF else False + return super().is_match_type(_type) or names_match diff --git a/model_compression_toolkit/core/common/network_editors/node_filters.py b/model_compression_toolkit/core/common/network_editors/node_filters.py index d5e1e1f1d..9d8bd28d6 100644 --- a/model_compression_toolkit/core/common/network_editors/node_filters.py +++ b/model_compression_toolkit/core/common/network_editors/node_filters.py @@ -15,6 +15,7 @@ from typing import Any from model_compression_toolkit.core.common.matchers.node_matcher import BaseNodeMatcher +from model_compression_toolkit.core.common.graph.base_node import BaseNode class NodeTypeFilter(BaseNodeMatcher): @@ -30,7 +31,7 @@ def __init__(self, node_type): """ self.node_type = node_type - def apply(self, input_object: Any) -> bool: + def apply(self, input_object: BaseNode) -> bool: """ Check if input_object is of the type that NodeTypeFilter contains. @@ -38,9 +39,9 @@ def apply(self, input_object: Any) -> bool: input_object: Node object to check for its type. Returns: - True if the node if of the type that was passed during the initialization of NodeTypeFilter. + True if the node is of the type that was passed during the initialization of NodeTypeFilter. 
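For intuition, here is a small standalone sketch (not the MCT implementation itself) of the name-based fallback that FunctionalNode.is_match_type adds above; it simply falls back to comparing __name__ when the function objects themselves differ, which is the TF 2.15 situation described in the docstring:

import tensorflow as tf

def loosely_match_type(node_type, other_type) -> bool:
    # Exact object equality first; fall back to comparing __name__ because the
    # function object behind an op (e.g. tf.add) can differ between TF versions.
    if node_type == other_type:
        return True
    name = getattr(node_type, '__name__', None)
    return name is not None and name == getattr(other_type, '__name__', None)

assert loosely_match_type(tf.add, tf.add)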
""" - if input_object.type == self.node_type: + if input_object.is_match_type(self.node_type): return True diff --git a/model_compression_toolkit/core/common/quantization/node_quantization_config.py b/model_compression_toolkit/core/common/quantization/node_quantization_config.py index 0dd2e62df..7f612ac12 100644 --- a/model_compression_toolkit/core/common/quantization/node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/node_quantization_config.py @@ -265,8 +265,6 @@ def __init__(self, self.enable_weights_quantization = weights_attr_cfg.enable_weights_quantization self.l_p_value = qc.l_p_value - - @property def weights_error_method(self) -> QuantizationErrorMethod: """ @@ -412,9 +410,6 @@ def __init__(self, qc: QuantizationConfig, for attr in node_attrs_list: if isinstance(attr, int): # this is a positional attribute, so it needs to be handled separately. - # we assume that a positional attribute is quantized with the default configuration provided in the TPC. - if op_cfg.default_weight_attr_config.enable_weights_quantization: - Logger.critical(f"Quantizing constant weights is not supported.") self.pos_attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(qc=qc, weights_attr_cfg=op_cfg.default_weight_attr_config, weights_channels_axis=weights_channels_axis) diff --git a/model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py b/model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py index f6d5f52ae..d7c3be072 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== +from typing import Dict import numpy as np from sklearn.cluster import KMeans @@ -38,10 +39,10 @@ def lut_kmeans_tensor(tensor_data: np.ndarray, n_iter: int = 10, min_threshold: float = MIN_THRESHOLD, quant_error_method: qc.QuantizationErrorMethod = None, - is_symmetric=False, + is_symmetric: bool = False, node=None, hessian_info_service: HessianInfoService = None, - num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> dict: + num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Dict: """ The quantizer first finds the closest max value per channel of tensor_data. Now, we divide tensor_data with the threshold vector per channel. In addition, we scale the result to the range @@ -101,7 +102,7 @@ def lut_kmeans_histogram(bins: np.ndarray, constrained: bool = True, n_iter: int = 20, min_threshold: float = MIN_THRESHOLD, - quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE) -> dict: + quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE) -> Dict: """ Finds quantization cluster points for non-uniform activation quantization. 
The quantizer first finds the closest power-of-two number to the max value of the given histogram, diff --git a/model_compression_toolkit/core/common/similarity_analyzer.py b/model_compression_toolkit/core/common/similarity_analyzer.py index 09988f354..59e218495 100644 --- a/model_compression_toolkit/core/common/similarity_analyzer.py +++ b/model_compression_toolkit/core/common/similarity_analyzer.py @@ -235,7 +235,7 @@ def compute_kl_divergence(float_tensor: np.ndarray, fxp_tensor: np.ndarray, batc axis: int = None) -> float: """ Compute the similarity between two tensor using KL-divergence. - The returned values is between 0 to 1: the smaller returned value, + The returned values is between 0 and 1: the smaller returned value, the greater similarity there is between the two tensors. Args: @@ -257,6 +257,6 @@ def compute_kl_divergence(float_tensor: np.ndarray, fxp_tensor: np.ndarray, batc non_zero_fxp_tensor[non_zero_fxp_tensor == 0] = EPS prob_distance = np.where(float_flat != 0, float_flat * np.log(float_flat / non_zero_fxp_tensor), 0) - # The sum is part of the KL-Divergance function. + # The sum is part of the KL-Divergence function. # The mean is to aggregate the distance between each output probability vectors. return np.mean(np.sum(prob_distance, axis=-1), axis=-1) diff --git a/model_compression_toolkit/core/keras/back2framework/keras_model_builder.py b/model_compression_toolkit/core/keras/back2framework/keras_model_builder.py index c9e8b9a12..1f089dff6 100644 --- a/model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +++ b/model_compression_toolkit/core/keras/back2framework/keras_model_builder.py @@ -39,6 +39,7 @@ from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX from model_compression_toolkit.core.keras.back2framework.instance_builder import OperationHandler from model_compression_toolkit.core.keras.reader.connectivity_handler import OutTensor +from mct_quantizers import KerasQuantizationWrapper # In tf2.3 fake quant node is implemented as TensorFlowOpLayer, while in tf2.4 as TFOpLambda. FQ_NODE_OP_V2_3 = 'FakeQuantWithMinMaxVars' @@ -270,7 +271,9 @@ def _run_operation(self, out_tensors_of_n_float) else: input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list] # flat list of lists - input_tensors = n.insert_positional_weights_to_input_list(input_tensors) + if not isinstance(op_func, KerasQuantizationWrapper): + # The KerasQuantizationWrapper will insert the quantized positional weights internally. 
+ input_tensors = n.insert_positional_weights_to_input_list(input_tensors) # Build a functional node using its args if isinstance(n, FunctionalNode): if n.inputs_as_list: # If the first argument should be a list of tensors: diff --git a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py index 637bfe210..c94020a8b 100644 --- a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +++ b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py @@ -70,9 +70,9 @@ def update_kernel_for_bn_folding_fn(conv_node: BaseNode, Returns: The modified convolution node's weight/kernel/ """ - if conv_node.type == DepthwiseConv2D: + if conv_node.is_match_type(DepthwiseConv2D): kernel = kernel * weights_scale.reshape((1, 1, kernel.shape[-2], kernel.shape[-1])) - elif conv_node.type == Conv2DTranspose: + elif conv_node.is_match_type(Conv2DTranspose): kernel = kernel * weights_scale.reshape((1, 1, -1, 1)) else: kernel = kernel * weights_scale.reshape((1, 1, 1, -1)) @@ -98,10 +98,10 @@ def update_weights_for_bn_forward_folding_fn(conv_node: BaseNode, Returns: The modified convolution node's weight/kernel/ """ - if conv_node.type == DepthwiseConv2D: + if conv_node.is_match_type(DepthwiseConv2D): bias_update = kernel * bias_factor.reshape((1, 1, -1, 1)) kernel = kernel * weights_scale.reshape((1, 1, -1, 1)) - elif conv_node.type == Conv2DTranspose: + elif conv_node.is_match_type(Conv2DTranspose): bias_update = (kernel * bias_factor.reshape((1, 1, 1, -1))).sum(3) kernel = kernel * weights_scale.reshape((1, 1, 1, -1)) else: @@ -133,7 +133,7 @@ def is_group_conv_fn(node: BaseNode) -> bool: Returns: True if the node is a group convolution, else False """ - return (node.type == Conv2D) and node.framework_attr[GROUPS] > 1 + return (node.is_match_type(Conv2D)) and node.framework_attr[GROUPS] > 1 def get_foldable_node_type_and_validity_fn(node: BaseNode) -> [bool, bool]: @@ -147,8 +147,8 @@ def get_foldable_node_type_and_validity_fn(node: BaseNode) -> [bool, bool]: is_bn: True if the node is a batch norm, else False is_dw_valid: True if the node is a dw-convolution valid for folding or a batch-norm node, else False """ - is_bn = node.type is BatchNormalization - is_dw = node.type is DepthwiseConv2D + is_bn = node.is_match_type(BatchNormalization) + is_dw = node.is_match_type(DepthwiseConv2D) is_dw_valid = is_dw and np.all(np.array(node.get_weights_by_keys(DEPTHWISE_KERNEL).shape[:2]) == 1) return is_bn, is_dw_valid diff --git a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py index 18e5c9ebf..9e1de30ec 100644 --- a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +++ b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py @@ -58,7 +58,7 @@ def conv2d_collapsing_fn(first_node: BaseNode, Returns: The modified layer node's weights: kernel, bias """ - if first_node.type == Conv2D and second_node.type == Conv2D: + if first_node.is_match_type(Conv2D) and second_node.is_match_type(Conv2D): # Get nodes attributes kernel1 = first_node.get_weights_by_keys(kernel_str) kernel2 = second_node.get_weights_by_keys(kernel_str) diff --git 
a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py index 2b00f247c..1c6501560 100644 --- a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +++ b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py @@ -49,7 +49,7 @@ def residual_collapsing_fn(first_node: BaseNode, Returns: The modified layer node's weights: kernel """ - if first_node.type == Conv2D: + if first_node.is_match_type(Conv2D): # Get nodes attributes kernel = first_node.get_weights_by_keys(kernel_str) (kH, kW, Cin, Cout) = kernel.shape diff --git a/model_compression_toolkit/core/keras/keras_implementation.py b/model_compression_toolkit/core/keras/keras_implementation.py index a39062ecc..f8333da00 100644 --- a/model_compression_toolkit/core/keras/keras_implementation.py +++ b/model_compression_toolkit/core/keras/keras_implementation.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== from functools import partial -from typing import List, Any, Tuple, Callable, Dict +from typing import List, Any, Tuple, Callable, Dict, Union import numpy as np import tensorflow as tf @@ -412,12 +412,13 @@ def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool Returns: True if the node should be considered an interest point, False otherwise. """ - if node.type == Activation: + if node.is_match_type(Activation): node_type_name = node.framework_attr[keras_constants.ACTIVATION] if node_type_name in [keras_constants.SOFTMAX, keras_constants.SIGMOID]: return True - elif node.type in [tf.nn.softmax, tf.keras.layers.Softmax, tf.nn.sigmoid, Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense, Concatenate, - tf.concat, Add, tf.add]: + elif any([node.is_match_type(_type) for _type in [tf.nn.softmax, tf.keras.layers.Softmax, tf.nn.sigmoid, Conv2D, + DepthwiseConv2D, Conv2DTranspose, Dense, Concatenate, tf.concat, + Add, tf.add]]): return True return False @@ -529,18 +530,18 @@ def get_node_mac_operations(self, kernel_shape = node.get_weights_by_keys(fw_info.get_kernel_op_attributes(node.type)[0]).shape output_channel_axis, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type) - if node.type is Conv2D or node.type is Conv2DTranspose: + if node.is_match_type(Conv2D) or node.is_match_type(Conv2DTranspose): # (C_out * W_out * H_out) * C_in * (W_kernel * H_kernel) return np.prod([x for x in output_shape if x is not None]) * \ kernel_shape[input_channel_axis] * \ (kernel_shape[0] * kernel_shape[1]) - elif node.type is DepthwiseConv2D: + elif node.is_match_type(DepthwiseConv2D): # Depth * (W_out * H_out) * C_in * (W_kernel * H_kernel) return node.framework_attr.get(DEPTH_MULTIPLIER) * \ np.prod([x for x in output_shape if x is not None]) / output_shape[output_channel_axis] * \ kernel_shape[input_channel_axis] * \ (kernel_shape[0] * kernel_shape[1]) - elif node.type is Dense: + elif node.is_match_type(Dense): # IN * OUT return kernel_shape[0] * kernel_shape[1] else: @@ -593,10 +594,9 @@ def get_inferable_quantizers(self, node: BaseNode): Returns: weight_quantizers: A dictionary between a weight's name to its quantizer. activation_quantizers: A list of activations quantization, one for each layer output. 
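To make the MAC formulas in get_node_mac_operations above concrete, a small worked example with made-up shapes (a sketch only, following the Conv2D and Dense expressions shown in the diff):

import numpy as np

# Conv2D: (C_out * W_out * H_out) * C_in * (W_kernel * H_kernel)
output_shape = (None, 16, 16, 32)      # N, H_out, W_out, C_out
kernel_shape = (3, 3, 8, 32)           # kH, kW, C_in, C_out
conv_macs = np.prod([x for x in output_shape if x is not None]) * \
            kernel_shape[2] * (kernel_shape[0] * kernel_shape[1])
# 16 * 16 * 32 * 8 * 9 = 589,824 MACs

# Dense: IN * OUT, i.e. the product of the kernel shape
dense_macs = 128 * 10                  # kernel of shape (128, 10) -> 1,280 MACs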
- """ - def _weight_name(w: str) -> str: + def _weight_name(w: Union[str, int]) -> Union[str, int]: """ Extracts the weight name from the full TensorFlow variable name. @@ -609,7 +609,7 @@ def _weight_name(w: str) -> str: Extracted weight name. """ - return w.split(':')[0].split('/')[-1] + return w.split(':')[0].split('/')[-1] if isinstance(w, str) else w attribute_names = [_weight_name(wn) for wn in node.get_node_weights_attributes() if node.is_weights_quantization_enabled(wn)] diff --git a/model_compression_toolkit/core/keras/keras_node_prior_info.py b/model_compression_toolkit/core/keras/keras_node_prior_info.py index 4c7747fca..1cbecbd50 100644 --- a/model_compression_toolkit/core/keras/keras_node_prior_info.py +++ b/model_compression_toolkit/core/keras/keras_node_prior_info.py @@ -56,13 +56,13 @@ def _get_min_max_outputs(node: BaseNode, """ min_output, max_output = None, None - if node.type == ReLU: + if node.is_match_type(ReLU): min_output = node.framework_attr[THRESHOLD] if node.framework_attr[NEGATIVE_SLOPE] == 0 else None elif fw_info.layers_has_min_max(node.type): min_output, max_output = fw_info.layer_min_max_mapping[node.type] - elif node.type == Activation and fw_info.activation_has_min_max(node.framework_attr[ACTIVATION]): + elif node.is_match_type(Activation) and fw_info.activation_has_min_max(node.framework_attr[ACTIVATION]): min_output, max_output = fw_info.activation_min_max_mapping[node.framework_attr[ACTIVATION]] return min_output, max_output @@ -82,7 +82,7 @@ def _get_mean_std_outputs(node: BaseNode, """ mean_output, std_output = None, None - if node.type == BatchNormalization: + if node.is_match_type(BatchNormalization): mean_output = node.get_weights_by_keys(BETA) if node.get_weights_by_keys(GAMMA) is None: std_output = 1.0 @@ -92,7 +92,7 @@ def _get_mean_std_outputs(node: BaseNode, mean_output = 0.0 else: next_node_list = graph.get_next_nodes(node) - bn_nodes = [bn_node for bn_node in next_node_list if bn_node.type == BatchNormalization] + bn_nodes = [bn_node for bn_node in next_node_list if bn_node.is_match_type(BatchNormalization)] if len(bn_nodes) != 0: bn_node = bn_nodes[0] moving_variance = bn_node.get_weights_by_keys(MOVING_VARIANCE) diff --git a/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py b/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py index a495ed160..91ccdb881 100644 --- a/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +++ b/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py @@ -209,10 +209,9 @@ def _is_keras_node_pruning_section_edge(node: BaseNode) -> bool: """ # Check if the node is a Conv2D or Conv2DTranspose layer with groups set to 1. - if node.type in [keras.layers.Conv2D, keras.layers.Conv2DTranspose]: + if node.is_match_type(keras.layers.Conv2D) or node.is_match_type(keras.layers.Conv2DTranspose): return node.framework_attr[GROUPS] == 1 - return node.type == keras.layers.Dense - + return node.is_match_type(keras.layers.Dense) def _prune_keras_edge_node(node: BaseNode, @@ -250,9 +249,9 @@ def _prune_keras_edge_node(node: BaseNode, if not is_exit_node: # Update 'filters' or 'units' attributes for entry node Conv2D/Conv2DTranspose layers. 
- if node.type in [keras.layers.Conv2D, keras.layers.Conv2DTranspose]: + if node.is_match_type(keras.layers.Conv2D) or node.is_match_type(keras.layers.Conv2DTranspose): node.framework_attr[FILTERS] = int(np.sum(mask)) - elif node.type == keras.layers.Dense: + elif node.is_match_type(keras.layers.Dense): node.framework_attr[UNITS] = int(np.sum(mask)) if is_exit_node: diff --git a/model_compression_toolkit/core/keras/reader/common.py b/model_compression_toolkit/core/keras/reader/common.py index 12427f21d..6f905fffd 100644 --- a/model_compression_toolkit/core/keras/reader/common.py +++ b/model_compression_toolkit/core/keras/reader/common.py @@ -43,7 +43,7 @@ def is_node_an_input_layer(node: BaseNode) -> bool: Whether the node represents an input layer or not. """ if isinstance(node, BaseNode): - return node.type == InputLayer + return node.is_match_type(InputLayer) elif isinstance(node, KerasNode): return isinstance(node.layer, InputLayer) else: @@ -60,7 +60,7 @@ def is_node_a_model(node: BaseNode) -> bool: Whether the node represents a Keras model or not. """ if isinstance(node, BaseNode): - return node.type in [Functional, Sequential] + return node.is_match_type(Functional) or node.is_match_type(Sequential) elif isinstance(node, KerasNode): return isinstance(node.layer, Functional) or isinstance(node.layer, Sequential) else: diff --git a/model_compression_toolkit/core/keras/reader/node_builder.py b/model_compression_toolkit/core/keras/reader/node_builder.py index 401a14f38..606c578cf 100644 --- a/model_compression_toolkit/core/keras/reader/node_builder.py +++ b/model_compression_toolkit/core/keras/reader/node_builder.py @@ -41,7 +41,7 @@ REUSED_IDENTIFIER = '_reused_' -is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray)) +is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray, float)) is_tensor = lambda x: isinstance(x, KerasTensor) @@ -61,18 +61,36 @@ def get_kwargs2index(tfoplambda_layer: TFOpLambda) -> Dict[str, int]: """ Positional weights are saved according to their index in the node's call arguments, so need to know the function arguments' names in case the weights are in the kwargs. + + Note: the kwargs2index dictionary is initialized manually (and not with tf_inspect) so + it will only include the arguments that may contain constants. For example, we don't + want the transpose_a attribute of tf.matmul to be saved as a constant. + + Every operation we add support to, needs to be added here. + Args: tfoplambda_layer: TFOpLambda layer. Returns: A dictionary with argument number and index: {arg_name: arg_index}. """ - if tfoplambda_layer.function in [tf.add, tf.subtract, tf.divide, tf.truediv, tf.multiply, tf.pow, - tf.matmul, tf.image.crop_and_resize, tf.image.combined_non_max_suppression] or \ - tfoplambda_layer.symbol in ['__operators__.add', 'math.add', 'math.multiply', 'linalg.matmul', 'concat']: - return {arg_name: i for i, arg_name in enumerate(tf_inspect.getfullargspec(tfoplambda_layer.function).args)} - else: - return {} + kwargs2index = {tf.add: {'x': 0, 'y': 1}, + tf.subtract: {'x': 0, 'y': 1}, + tf.divide: {'x': 0, 'y': 1}, + tf.truediv: {'x': 0, 'y': 1}, + tf.multiply: {'x': 0, 'y': 1}, + tf.pow: {'x': 0, 'y': 1}, + tf.matmul: {'a': 0, 'b': 1}}.get(tfoplambda_layer.function) + if not kwargs2index: + # In TF 2.15 the function attribute is different and doesn't match the original + # operation object we use. Therefore, we extract kwargs2index with the symbol. 
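As a rough illustration of how these argument tables are used downstream (a simplified sketch, not the reader code itself): a constant operand of a supported op is stored as a positional weight keyed by its call-argument index, e.g. the 2.0 in tf.add(x, 2.0) ends up under index 1:

import numpy as np

kwargs2index = {'x': 0, 'y': 1}            # table for tf.add, as above
op_call_kwargs = {'y': 2.0}                # tf.add was called as tf.add(x, y=2.0)

weights = {kwargs2index[k]: np.array([v])  # constant saved under its positional index
           for k, v in op_call_kwargs.items() if isinstance(v, (float, np.ndarray))}
print(weights)                             # {1: array([2.])}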
+ kwargs2index = {'__operators__.add': {'x': 0, 'y': 1}, + 'math.add': {'x': 0, 'y': 1}, + 'math.multiply': {'x': 0, 'y': 1}, + 'linalg.matmul': {'a': 0, 'b': 1}, + 'concat': {'values': 0}}.get(tfoplambda_layer.symbol, {}) + + return kwargs2index def build_node(node: KerasNode, @@ -154,8 +172,9 @@ def build_node(node: KerasNode, if is_const(v) or (keras_layer.function in [tf.add, tf.multiply, tf.subtract, tf.divide, tf.truediv, tf.pow, tf.matmul] and isinstance(v, (tuple, list))): - weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)}) - weight_keys.append(k) + if k in kwarg2index: + weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)}) + weight_keys.append(k) # remove weights and KerasTensors and weights from op_call_kwargs op_call_kwargs = {k: v for k, v in op_call_kwargs.items() if not (kwarg2index.get(k) in weights or is_tensor(v))} diff --git a/model_compression_toolkit/core/keras/tf_tensor_numpy.py b/model_compression_toolkit/core/keras/tf_tensor_numpy.py index 6156409c4..a16557d98 100644 --- a/model_compression_toolkit/core/keras/tf_tensor_numpy.py +++ b/model_compression_toolkit/core/keras/tf_tensor_numpy.py @@ -40,7 +40,7 @@ def to_tf_tensor(tensor): Logger.critical(f'Unsupported type for conversion to TF tensor: {type(tensor)}.') -def tf_tensor_to_numpy(tensor: Union[List, Tuple, np.ndarray, tf.Tensor], +def tf_tensor_to_numpy(tensor: Union[List, Tuple, np.ndarray, tf.Tensor, float], is_single_tensor=False) -> np.ndarray: """ Convert a TF tensor to a Numpy array. @@ -65,6 +65,9 @@ def tf_tensor_to_numpy(tensor: Union[List, Tuple, np.ndarray, tf.Tensor], else: return (tf_tensor_to_numpy(t) for t in tensor) elif isinstance(tensor, tf.Tensor): - return tensor.numpy() + np_tensor = tensor.numpy() + return np.array([np_tensor]) if np.isscalar(np_tensor) else np_tensor + elif isinstance(tensor, float): + return np.array([tensor]) else: Logger.critical(f'Unsupported type for conversion to Numpy array: {type(tensor)}.') diff --git a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py index 1e9aeb855..2e5a6b01f 100644 --- a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +++ b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py @@ -33,26 +33,31 @@ from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder from model_compression_toolkit.core.pytorch.utils import to_torch_tensor from mct_quantizers.common.constants import ACTIVATION_HOLDER_QUANTIZER +from mct_quantizers import PytorchQuantizationWrapper def _build_input_tensors_list(node: BaseNode, graph: Graph, inputs: Tuple[Any], - node_to_output_tensors_dict: Dict[BaseNode, List]) -> List[List]: + node_to_output_tensors_dict: Dict[BaseNode, List], + is_op_quantize_wrapper: bool) -> List[List]: """ - Given a node, build a list of input tensors the node gets. The list is built - based on the node's incoming edges and previous nodes' output tensors. + Given a node, build a list of input tensors the node gets. The list is built based on the + node's incoming edges, previous nodes' output tensors and the node's positional weights. + Positional weights aren't used if the node's op is PytorchQuantizationWrapper, since it's + positional weights are already in the wrapper. Args: node: Node to build its input tensors list. graph: Graph the node is in. 
- inputs: list of input tensors to model + inputs: list of input tensors to model. node_to_output_tensors_dict: A dictionary from a node to its output tensors. + is_op_quantize_wrapper: Whether the func_op is a PytorchQuantizationWrapper or not. Returns: A list of the node's input tensors. """ - if node.type == DummyPlaceHolder: + if node.is_match_type(DummyPlaceHolder): input_tensors = [inputs[graph.get_inputs().index(node)]] else: input_tensors = [] @@ -62,7 +67,8 @@ def _build_input_tensors_list(node: BaseNode, _input_tensors = node_to_output_tensors_dict[ie.source_node] input_tensors.append(_input_tensors) input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list] # flat list of lists - input_tensors = node.insert_positional_weights_to_input_list(input_tensors) + if not is_op_quantize_wrapper: + input_tensors = node.insert_positional_weights_to_input_list(input_tensors) # convert inputs from positional weights (numpy arrays) to tensors. Must handle each element in the # list separately, because in FX the tensors are FX objects and fail to_torch_tensor input_tensors = [to_torch_tensor(t) if isinstance(t, np.ndarray) else t @@ -70,22 +76,27 @@ def _build_input_tensors_list(node: BaseNode, return input_tensors -def _merge_inputs(_node, input_tensors: List, op_call_args: List) -> List: +def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List, + is_op_quantize_wrapper: bool) -> List: """ - Merge input tensors list with op_call_args, according to correct order + Merge input tensors list with op_call_args, according to correct order. Args: - _node: The node the inputs are for + _node: The node the inputs are for. input_tensors: activation input tensors to node. - op_call_args: framework node call args + op_call_args: framework node call args. + is_op_quantize_wrapper: Whether the func_op is a PytorchQuantizationWrapper or not. Returns: - Combined list of input_tensors and op_call_args + Combined list of input_tensors and op_call_args. """ if isinstance(_node, FunctionalNode) and _node.tensor_input_indices: - assert len(_node.tensor_input_indices) == len(input_tensors), 'Mismatch between input tensors and indices' _input_list = op_call_args.copy() - for i, t in zip(_node.tensor_input_indices, input_tensors): - _input_list.insert(i, t) + if is_op_quantize_wrapper: + _input_list = input_tensors + _input_list + else: + assert len(_node.tensor_input_indices) == len(input_tensors), 'Mismatch between input tensors and indices' + for i, t in zip(_node.tensor_input_indices, input_tensors): + _input_list.insert(i, t) else: _input_list = input_tensors + op_call_args @@ -118,7 +129,8 @@ def _run_operation(n: BaseNode, if isinstance(n, FunctionalNode) and n.inputs_as_list: out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs) else: - out_tensors_of_n_float = op_func(*_merge_inputs(n, input_tensors, op_call_args), **functional_kwargs) + merged_inputs = _merge_inputs(n, input_tensors, op_call_args, isinstance(op_func, PytorchQuantizationWrapper)) + out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs) # Add a fake quant node if the node has an activation threshold. 
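To show what re-inserting positional weights means at model-building time (a minimal sketch of the idea behind insert_positional_weights_to_input_list and _merge_inputs, not the actual methods): constants are spliced back into the activation input list at their saved argument indices, unless the op is already a quantization wrapper that holds them internally:

def merge_positional_weights(input_tensors, positional_weights):
    # positional_weights: {call-arg index -> constant}, e.g. {1: const} for op(x, const)
    merged = list(input_tensors)
    for idx in sorted(positional_weights):
        merged.insert(idx, positional_weights[idx])
    return merged

print(merge_positional_weights(['x'], {1: 2.0}))   # ['x', 2.0]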
out_tensors_of_n = out_tensors_of_n_float @@ -279,12 +291,12 @@ def forward(self, node_to_output_tensors_dict_float = dict() configurable_nodes = self.graph.get_configurable_sorted_nodes_names(DEFAULT_PYTORCH_INFO) for node in self.node_sort: + op_func = self._get_op_func(node, configurable_nodes) input_tensors = _build_input_tensors_list(node, self.graph, args, - node_to_output_tensors_dict) - - op_func = self._get_op_func(node, configurable_nodes) + node_to_output_tensors_dict, + isinstance(op_func, PytorchQuantizationWrapper)) use_activation_quantization, activation_quantization_fn = self._get_activation_quantization_fn(node) # Run node operation and fetch outputs @@ -326,15 +338,16 @@ def _get_op_func(self, """ return getattr(self, node.name) - def _get_activation_quantization_fn(self, node) -> Tuple[bool, bool, Callable]: + def _get_activation_quantization_fn(self, node) -> Tuple[bool, Callable]: """ Get activation quantization parameters for this node. Args: node: Node from which to extract the activation quantization parameters. - Returns: Flag to indicate if we quantize activations, flag to indicate if we quantize activations - using a quantization holder and a quantization function to use for the node's activations. + Returns: + Flag to indicate if we quantize activations using a quantization holder and a quantization + function to use for the node's activations. """ activation_quantization_holder = self.node_to_activation_quantization_holder.get(node.name) use_activation_quantization = node.is_activation_quantization_enabled() diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_folding.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_folding.py index b57cf2292..0d0a25a07 100644 --- a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_folding.py +++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_folding.py @@ -62,11 +62,11 @@ def update_kernel_for_bn_folding_fn(conv_node: BaseNode, Returns: The modified convolution node's weight/kernel/ """ - if conv_node.type == ConvTranspose2d: + if conv_node.is_match_type(ConvTranspose2d): _scale = weights_scale[None, :, None, None] else: _scale = weights_scale[:, None, None, None] - if conv_node.type == ConvTranspose2d and conv_node.framework_attr[GROUPS] > 1: + if conv_node.is_match_type(ConvTranspose2d) and conv_node.framework_attr[GROUPS] > 1: # PyTorch ConvTranspose2d kernel with groups stacks groups on in_channels axis, so need to reshape the kernel # so the groups are stacked on the out_channels axis to match the scale vector (then reshape back to original # shape) @@ -93,10 +93,10 @@ def update_weights_for_bn_forward_folding_fn(conv_node: BaseNode, Returns: The modified convolution node's weight/kernel/ """ - if conv_node.type == Conv2d and conv_node.framework_attr[GROUPS] > 1: + if conv_node.is_match_type(Conv2d) and conv_node.framework_attr[GROUPS] > 1: bias_update = (kernel * bias_factor[:, None, None, None]).flatten() _scale = weights_scale[:, None, None, None] - elif conv_node.type == ConvTranspose2d: + elif conv_node.is_match_type(ConvTranspose2d): bias_update = (kernel * bias_factor[:, None, None, None]).sum(axis=0).flatten() _scale = weights_scale[:, None, None, None] else: @@ -125,8 +125,8 @@ def is_group_conv_fn(node: BaseNode) -> bool: Returns: True if the node is a group convolution, else False """ - return node.type in [Conv2d, ConvTranspose2d] and \ - 
node.framework_attr[GROUPS] not in [node.framework_attr[IN_CHANNELS], 1] + return (node.is_match_type(Conv2d) or node.is_match_type(ConvTranspose2d)) and \ + node.framework_attr[GROUPS] not in [node.framework_attr[IN_CHANNELS], 1] def get_foldable_node_type_and_validity_fn(node: BaseNode) -> [bool, bool]: @@ -140,8 +140,8 @@ def get_foldable_node_type_and_validity_fn(node: BaseNode) -> [bool, bool]: is_bn: True if the node is a batch norm, else False is_dw_valid: True if the node is a dw-convolution valid for folding or a batch-norm node, else False """ - is_bn = node.type is BatchNorm2d - is_dw = node.type is Conv2d and node.framework_attr[GROUPS] == node.framework_attr[IN_CHANNELS] + is_bn = node.is_match_type(BatchNorm2d) + is_dw = node.is_match_type(Conv2d) and node.framework_attr[GROUPS] == node.framework_attr[IN_CHANNELS] is_dw_valid = is_dw and np.all(np.array(node.get_weights_by_keys(KERNEL).shape[2:]) == 1) return is_bn, is_dw_valid diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py index ec9ea141d..978ca7f49 100644 --- a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py @@ -48,9 +48,9 @@ def substitute(self, Graph after applying the substitution. """ # Set new layer - if func_node.type == conv2d: + if func_node.is_match_type(conv2d): new_layer = Conv2d - elif func_node.type == conv_transpose2d: + elif func_node.is_match_type(conv_transpose2d): new_layer = ConvTranspose2d else: Logger.critical(f'Substitution filter mismatch. Layer {func_node.type}. Must be {type(Conv2d)} or {type(ConvTranspose2d)}.') # pragma: no cover diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py index 2c1a41655..599113741 100644 --- a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py @@ -53,7 +53,7 @@ def conv2d_collapsing_fn(first_node: BaseNode, Returns: The modified layer node's weights: kernel, bias """ - if first_node.type == Conv2d and second_node.type == Conv2d: + if first_node.is_match_type(Conv2d) and second_node.is_match_type(Conv2d): # Get nodes attributes kernel1 = first_node.get_weights_by_keys(kernel_str) kernel2 = second_node.get_weights_by_keys(kernel_str) diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py index 634b4cf35..ee5fba9d1 100644 --- a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py @@ -76,17 +76,17 @@ def substitute(self, second_op2d_node = nodes_list[2] # only act on bound relu with not POT max value and 0 min value - if non_linear_node.type == ReLU6: + if non_linear_node.is_match_type(ReLU6): scale_factor = 6.0 / self.threshold non_linear_node.layer_class = Hardtanh non_linear_node.framework_attr[INPLACE] = False non_linear_node.framework_attr[HARDTANH_MIN_VAL] = 0.0 
non_linear_node.framework_attr[HARDTANH_MAX_VAL] = self.threshold - elif non_linear_node.type == relu6: + elif non_linear_node.is_match_type(relu6): scale_factor = 6.0 / self.threshold non_linear_node.functional_op = hardtanh non_linear_node.functional_op.__defaults__ = (0.0, self.threshold, False) - elif non_linear_node.type == Hardtanh: + elif non_linear_node.is_match_type(Hardtanh): if (non_linear_node.framework_attr[HARDTANH_MIN_VAL] == 0.0) and not \ (np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]).astype(int) - np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]) == 0): @@ -94,7 +94,7 @@ def substitute(self, non_linear_node.framework_attr[HARDTANH_MAX_VAL] = self.threshold else: return graph - elif non_linear_node.type == hardtanh: + elif non_linear_node.is_match_type(hardtanh): if (non_linear_node.framework_attr[HARDTANH_MIN_VAL] == 0.0) and not \ (np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]).astype(int) - np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]) == 0): diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py index a10c4cf68..8453517e7 100644 --- a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py @@ -46,7 +46,7 @@ def residual_collapsing_fn(first_node: BaseNode, Returns: The modified layer node's weights: kernel """ - if first_node.type == Conv2d: + if first_node.is_match_type(Conv2d): # Get nodes attributes kernel = first_node.get_weights_by_keys(kernel_str) (Cout, Cin, kH, kW) = kernel.shape diff --git a/model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py b/model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py index a1e32cbb5..5c680ae48 100644 --- a/model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +++ b/model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py @@ -76,9 +76,9 @@ def prune_intermediate_node(self, pruned_parameters = {} mask_bool = output_mask.astype(bool) node.weights = pruned_parameters - if node.type == torch.nn.BatchNorm2d: + if node.is_match_type(torch.nn.BatchNorm2d): node.framework_attr[NUM_FEATURES] = int(np.sum(input_mask)) - elif node.type == torch.nn.PReLU: + elif node.is_match_type(torch.nn.PReLU): if node.framework_attr[NUM_PARAMETERS] > 1: node.framework_attr[NUM_PARAMETERS] = int(np.sum(input_mask)) else: @@ -227,9 +227,9 @@ def _is_pytorch_node_pruning_section_edge(node: BaseNode) -> bool: """ # Check if the node is a Conv2D or Conv2DTranspose layer with groups set to 1. - if node.type in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]: + if node.is_match_type(torch.nn.Conv2d) or node.is_match_type(torch.nn.ConvTranspose2d): return node.framework_attr[GROUPS] == 1 - return node.type == torch.nn.Linear + return node.is_match_type(torch.nn.Linear) def _prune_pytorch_edge_node(node: BaseNode, @@ -268,18 +268,18 @@ def _prune_pytorch_edge_node(node: BaseNode, if not is_exit_node: # Update 'out_channels' or 'out_features' attributes for entry nodes # Conv2d,ConvTranspose2d / Linear layers. 
- if node.type in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]: + if node.is_match_type(torch.nn.Conv2d) or node.is_match_type(torch.nn.ConvTranspose2d): node.framework_attr[OUT_CHANNELS] = int(np.sum(mask)) - elif node.type == torch.nn.Linear: + elif node.is_match_type(torch.nn.Linear): node.framework_attr[OUT_FEATURES] = int(np.sum(mask)) else: Logger.critical(f"{node.type} is currently not supported" f"as an edge node in a pruning section") if is_exit_node: - if node.type in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]: + if node.is_match_type(torch.nn.Conv2d) or node.is_match_type(torch.nn.ConvTranspose2d): node.framework_attr[IN_CHANNELS] = int(np.sum(mask)) - elif node.type == torch.nn.Linear: + elif node.is_match_type(torch.nn.Linear): node.framework_attr[IN_FEATURES] = int(np.sum(mask)) else: Logger.critical(f"{node.type} is currently not supported" diff --git a/model_compression_toolkit/core/pytorch/pytorch_implementation.py b/model_compression_toolkit/core/pytorch/pytorch_implementation.py index bef158851..225b57d98 100644 --- a/model_compression_toolkit/core/pytorch/pytorch_implementation.py +++ b/model_compression_toolkit/core/pytorch/pytorch_implementation.py @@ -398,8 +398,8 @@ def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool Returns: True if the node should be considered an interest point, False otherwise. """ - if node.type in [Conv2d, Linear, ConvTranspose2d, Sigmoid, sigmoid, Softmax, softmax, operator.add, add, cat, - operator.concat]: + if any([node.is_match_type(_type) for _type in [Conv2d, Linear, ConvTranspose2d, Sigmoid, sigmoid, Softmax, + softmax, operator.add, add, cat, operator.concat]]): return True return False @@ -464,12 +464,12 @@ def get_node_mac_operations(self, kernel_shape = node.get_weights_by_keys(fw_info.get_kernel_op_attributes(node.type)[0]).shape output_channel_axis, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type) - if node.type is Conv2d or node.type is ConvTranspose2d: + if node.is_match_type(Conv2d) or node.is_match_type(ConvTranspose2d): # (C_out * W_out * H_out) * C_in * (W_kernel * H_kernel) return np.prod([x for x in output_shape if x is not None]) * \ kernel_shape[input_channel_axis] * \ (kernel_shape[0] * kernel_shape[1]) - elif node.type is Linear: + elif node.is_match_type(Linear): # IN * OUT return kernel_shape[0] * kernel_shape[1] else: @@ -552,7 +552,6 @@ def get_inferable_quantizers(self, node: BaseNode): Returns: weight_quantizers: A dictionary between a weight's name to its quantizer. activation_quantizers: A list of activations quantization, one for each layer output. 
- """ return get_inferable_quantizers(node, diff --git a/model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py b/model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py index e39d7b77e..82e3d6dfb 100644 --- a/model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +++ b/model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py @@ -62,7 +62,7 @@ def _get_mean_std_outputs(node: BaseNode, """ mean_output, std_output = None, None - if node.type == BatchNorm2d: + if node.is_match_type(BatchNorm2d): mean_output = node.get_weights_by_keys(BETA) if node.get_weights_by_keys(GAMMA) is None: std_output = 1.0 @@ -72,7 +72,7 @@ def _get_mean_std_outputs(node: BaseNode, mean_output = 0.0 else: next_node_list = graph.get_next_nodes(node) - bn_nodes = [bn_node for bn_node in next_node_list if bn_node.type == BatchNorm2d] + bn_nodes = [bn_node for bn_node in next_node_list if bn_node.is_match_type(BatchNorm2d)] if len(bn_nodes) != 0: bn_node = bn_nodes[0] moving_variance = bn_node.get_weights_by_keys(MOVING_VARIANCE) diff --git a/model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py b/model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py index ea9c02f18..9a310d7f1 100644 --- a/model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +++ b/model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py @@ -42,8 +42,12 @@ def _get_wrapper(node: common.BaseNode, """ weights_quantizers, _ = fw_impl.get_inferable_quantizers(node) if len(weights_quantizers) > 0: + # for positional weights we need to extract the weight's value. + weights_values = {attr: node.get_weights_by_keys(attr) + for attr in weights_quantizers if isinstance(attr, int)} return KerasQuantizationWrapper(layer, - weights_quantizers) + weights_quantizers, + weights_values) return layer diff --git a/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py b/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py index 72c143a6a..cbc3abeb9 100644 --- a/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +++ b/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py @@ -29,7 +29,7 @@ def fully_quantized_wrapper(node: common.BaseNode, module: torch.nn.Module, - fw_impl) -> Union[torch.nn.Module,PytorchQuantizationWrapper]: + fw_impl) -> Union[torch.nn.Module, PytorchQuantizationWrapper]: """ A function which takes a computational graph node and a pytorch module and perform the quantization wrapping @@ -37,20 +37,26 @@ def fully_quantized_wrapper(node: common.BaseNode, Args: node: A node of mct graph. module: A Pytorch module + fw_impl: FrameworkImplementation object with a specific framework methods implementation. Returns: Wrapped layer """ weight_quantizers, _ = fw_impl.get_inferable_quantizers(node) if len(weight_quantizers) > 0: - return PytorchQuantizationWrapper(module, weight_quantizers) + # for positional weights we need to extract the weight's value. 
+ weights_values = {attr: fw_impl.to_tensor(node.get_weights_by_keys(attr)) + for attr in weight_quantizers if isinstance(attr, int)} + return PytorchQuantizationWrapper(module, weight_quantizers, weights_values) return module + def get_activation_quantizer_holder(node: BaseNode, fw_impl) -> Callable: """ Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node. If the layer is not supposed to be wrapped with an activation quantizer - return None. Args: node: Node to attach a PytorchActivationQuantizationHolder to its output. + fw_impl: FrameworkImplementation object with a specific framework methods implementation. Returns: A PytorchActivationQuantizationHolder module for the node's activation quantization. """ @@ -64,6 +70,7 @@ def get_activation_quantizer_holder(node: BaseNode, fw_impl) -> Callable: f'PytorchActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers ' f'were found for node {node}') + def get_exportable_pytorch_model(graph: Graph): """ Convert graph to fully quantized PyTorch model. diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py index d5fe7dc21..a66b124f4 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py @@ -56,7 +56,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza # We define a default quantization config for all non-specified weights attributes. default_weight_attr_config = AttributeQuantizationConfig( - weights_quantization_method=tp.QuantizationMethod.SYMMETRIC, + weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, weights_n_bits=8, weights_per_channel_threshold=False, enable_weights_quantization=False, diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py index e75b38f9b..d6e642119 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py @@ -32,7 +32,7 @@ def get_tp_model() -> TargetPlatformModel: NOTE: in order to generate a target platform model with different configurations but with the same Operators Sets (for tests, experiments, etc.), use this method implementation as a test-case, i.e., override the 'get_op_quantization_configs' method and use its output to call 'generate_tp_model' with your configurations. - This version enables metadata by default + This version enables metadata by default. Returns: A TargetPlatformModel object. @@ -44,7 +44,8 @@ def get_tp_model() -> TargetPlatformModel: name='imx500_tp_model') -def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig], OpQuantizationConfig]: +def get_op_quantization_configs() -> \ + Tuple[OpQuantizationConfig, List[OpQuantizationConfig], OpQuantizationConfig]: """ Creates a default configuration object for 8-bit quantization, to be used to set a default TargetPlatformModel. 
In addition, creates a default configuration objects list (with 8, 4 and 2 bit quantization) to be used as @@ -151,6 +152,19 @@ def generate_tp_model(default_config: OpQuantizationConfig, # this configuration will be used for the operation quantization: default_configuration_options = tp.QuantizationConfigOptions([default_config]) + # Create a QuantizationConfigOptions for quantizing constants in functional ops. + # Constant configuration is similar to the default eight bit configuration except for PoT + # quantization method for the constant. + # Since the constants are not named attributes of the layer, we use the default_weight_attr_config to + # define the desired quantization properties for them. + const_config = default_config.clone_and_edit( + default_weight_attr_config=default_config.default_weight_attr_config.clone_and_edit( + enable_weights_quantization=True)) + if not (const_config.default_weight_attr_config.weights_quantization_method == tp.QuantizationMethod.POWER_OF_TWO and + const_config.default_weight_attr_config.weights_per_channel_threshold is False): + mct.logger.Logger.error('Constant quantization config should be per-tensor PoT.') + const_configuration_options = tp.QuantizationConfigOptions([const_config]) + # Create a TargetPlatformModel and set its default quantization config. # This default configuration will be used for all operations # unless specified otherwise (see OperatorsSet, for example): @@ -184,10 +198,10 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Define operations sets without quantization configuration # options (useful for creating fusing patterns, for example): any_relu = tp.OperatorsSet("AnyReLU") - add = tp.OperatorsSet("Add") - sub = tp.OperatorsSet("Sub") - mul = tp.OperatorsSet("Mul") - div = tp.OperatorsSet("Div") + add = tp.OperatorsSet("Add", const_configuration_options) + sub = tp.OperatorsSet("Sub", const_configuration_options) + mul = tp.OperatorsSet("Mul", const_configuration_options) + div = tp.OperatorsSet("Div", const_configuration_options) prelu = tp.OperatorsSet("PReLU") swish = tp.OperatorsSet("Swish") sigmoid = tp.OperatorsSet("Sigmoid") diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py index 02e7fb124..8f0ac8064 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py @@ -122,7 +122,7 @@ def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel): tp.OperationsSetToLayers("Add", [tf.add, Add]) tp.OperationsSetToLayers("Sub", [tf.subtract, Subtract]) tp.OperationsSetToLayers("Mul", [tf.math.multiply, Multiply]) - tp.OperationsSetToLayers("Div", [tf.math.divide]) + tp.OperationsSetToLayers("Div", [tf.math.divide, tf.math.truediv]) tp.OperationsSetToLayers("PReLU", [PReLU]) tp.OperationsSetToLayers("Swish", [tf.nn.swish, tp.LayerFilterParams(Activation, activation="swish")]) tp.OperationsSetToLayers("Sigmoid", [tf.nn.sigmoid, tp.LayerFilterParams(Activation, activation="sigmoid")]) diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py index 050d3c8a5..c9e556ead 100644 --- 
a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py @@ -33,7 +33,7 @@ def get_tp_model() -> TargetPlatformModel: NOTE: in order to generate a target platform model with different configurations but with the same Operators Sets (for tests, experiments, etc.), use this method implementation as a test-case, i.e., override the 'get_op_quantization_configs' method and use its output to call 'generate_tp_model' with your configurations. - This version enables metadata by default + This version enables metadata by default. Returns: A TargetPlatformModel object. @@ -45,7 +45,8 @@ def get_tp_model() -> TargetPlatformModel: name='imx500_lut_tp_model') -def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig], OpQuantizationConfig]: +def get_op_quantization_configs() -> \ + Tuple[OpQuantizationConfig, List[OpQuantizationConfig], OpQuantizationConfig]: """ Creates a default configuration object for 8-bit quantization, to be used to set a default TargetPlatformModel. In addition, creates a default configuration objects list (with 8, 4 and 2 bit quantization) to be used as @@ -57,13 +58,13 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza # We define a default quantization config for all non-specified weights attributes. default_weight_attr_config = AttributeQuantizationConfig( - weights_quantization_method=tp.QuantizationMethod.SYMMETRIC, + weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, weights_n_bits=8, weights_per_channel_threshold=False, enable_weights_quantization=False, lut_values_bitwidth=None) - # We define a quantization config to quantize the kernel (for layers where there is a kernel attribute). + # define a quantization config to quantize the kernel (for layers where there is a kernel attribute). kernel_base_config = AttributeQuantizationConfig( weights_quantization_method=tp.QuantizationMethod.SYMMETRIC, weights_n_bits=8, @@ -150,6 +151,19 @@ def generate_tp_model(default_config: OpQuantizationConfig, # this configuration will be used for the operation quantization: default_configuration_options = tp.QuantizationConfigOptions([default_config]) + # Create a QuantizationConfigOptions for quantizing constants in functional ops. + # Constant configuration is similar to the default eight bit configuration except for PoT + # quantization method for the constant. + # Since the constants are not named attributes of the layer, we use the default_weight_attr_config to + # define the desired quantization properties for them. + const_config = default_config.clone_and_edit( + default_weight_attr_config=default_config.default_weight_attr_config.clone_and_edit( + enable_weights_quantization=True)) + if not (const_config.default_weight_attr_config.weights_quantization_method == tp.QuantizationMethod.POWER_OF_TWO and + const_config.default_weight_attr_config.weights_per_channel_threshold is False): + mct.logger.Logger.error('Constant quantization config should be per-tensor PoT.') + const_configuration_options = tp.QuantizationConfigOptions([const_config]) + # Create a TargetPlatformModel and set its default quantization config. 
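Both tp_model variants guard that the constants' configuration is per-tensor power-of-two (PoT). For intuition, a rough numpy sketch of what a symmetric per-tensor PoT quantizer does (simplified; MCT's actual quantizers handle zero tensors, signedness and rounding details differently):

import numpy as np

def pot_quantize_per_tensor(x: np.ndarray, n_bits: int = 8) -> np.ndarray:
    # Per-tensor threshold: the smallest power of two covering max(|x|).
    threshold = 2.0 ** np.ceil(np.log2(np.max(np.abs(x))))
    scale = threshold / (2 ** (n_bits - 1))
    q = np.clip(np.round(x / scale), -2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1)
    return q * scale

const = np.array([0.3, -1.7, 2.45], dtype=np.float32)
print(pot_quantize_per_tensor(const))  # threshold is 4.0 for this tensor
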
# This default configuration will be used for all operations # unless specified otherwise (see OperatorsSet, for example): @@ -181,10 +195,10 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Define operations sets without quantization configuration # options (useful for creating fusing patterns, for example): any_relu = tp.OperatorsSet("AnyReLU") - add = tp.OperatorsSet("Add") - sub = tp.OperatorsSet("Sub") - mul = tp.OperatorsSet("Mul") - div = tp.OperatorsSet("Div") + add = tp.OperatorsSet("Add", const_configuration_options) + sub = tp.OperatorsSet("Sub", const_configuration_options) + mul = tp.OperatorsSet("Mul", const_configuration_options) + div = tp.OperatorsSet("Div", const_configuration_options) prelu = tp.OperatorsSet("PReLU") swish = tp.OperatorsSet("Swish") sigmoid = tp.OperatorsSet("Sigmoid") diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py index cee39d1fd..3259ecd6f 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py @@ -122,7 +122,7 @@ def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel): tp.OperationsSetToLayers("Add", [tf.add, Add]) tp.OperationsSetToLayers("Sub", [tf.subtract, Subtract]) tp.OperationsSetToLayers("Mul", [tf.math.multiply, Multiply]) - tp.OperationsSetToLayers("Div", [tf.math.divide]) + tp.OperationsSetToLayers("Div", [tf.math.divide, tf.math.truediv]) tp.OperationsSetToLayers("PReLU", [PReLU]) tp.OperationsSetToLayers("Swish", [tf.nn.swish, tp.LayerFilterParams(Activation, activation="swish")]) tp.OperationsSetToLayers("Sigmoid", [tf.nn.sigmoid, tp.LayerFilterParams(Activation, activation="sigmoid")]) diff --git a/tests/common_tests/helpers/generate_test_tp_model.py b/tests/common_tests/helpers/generate_test_tp_model.py index 6fa3c5aa1..dba3b05e3 100644 --- a/tests/common_tests/helpers/generate_test_tp_model.py +++ b/tests/common_tests/helpers/generate_test_tp_model.py @@ -74,7 +74,8 @@ def generate_mixed_precision_test_tp_model(base_cfg, default_config, mp_bitwidth name=name) -def generate_tp_model_with_activation_mp(base_cfg, default_config, mp_bitwidth_candidates_list, name="activation_mp_model"): +def generate_tp_model_with_activation_mp(base_cfg, default_config, mp_bitwidth_candidates_list, + name="activation_mp_model"): mp_op_cfg_list = [] for weights_n_bits, activation_n_bits in mp_bitwidth_candidates_list: diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/const_quantization_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/const_quantization_test.py new file mode 100644 index 000000000..a76732e82 --- /dev/null +++ b/tests/keras_tests/feature_networks_tests/feature_networks/const_quantization_test.py @@ -0,0 +1,120 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import partial +import tensorflow as tf +import numpy as np + +import model_compression_toolkit as mct +from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest +from tests.common_tests.helpers.tensors_compare import cosine_similarity +from mct_quantizers import KerasQuantizationWrapper + +from model_compression_toolkit.constants import TENSORFLOW +from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL + +keras = tf.keras +layers = keras.layers +tp = mct.target_platform + + +class ConstQuantizationTest(BaseKerasFeatureNetworkTest): + + def __init__(self, unit_test, layer, const, is_list_input=False, input_reverse_order=False, use_kwargs=False, + input_shape=(32, 32, 16)): + super(ConstQuantizationTest, self).__init__(unit_test=unit_test, input_shape=input_shape) + self.layer = layer + self.const = const + self.is_list_input = is_list_input + self.input_reverse_order = input_reverse_order + self.use_kwargs = use_kwargs + + def generate_inputs(self): + # need positive inputs so won't divide with zero or take root of negative number + return [1 + np.random.random(in_shape) for in_shape in self.get_input_shapes()] + + def get_tpc(self): + return mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, "v2") + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = inputs + if self.is_list_input: + if self.input_reverse_order: + x = self.layer([self.const, x]) + else: + x = self.layer([x, self.const]) + else: + if self.input_reverse_order: + if self.use_kwargs: + x = self.layer(x=self.const, y=x) + else: + x = self.layer(self.const, x) + else: + if self.use_kwargs: + x = self.layer(x=x, y=self.const) + else: + x = self.layer(x, self.const) + return tf.keras.models.Model(inputs=inputs, outputs=x) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + y = float_model.predict(input_x) + y_hat = quantized_model.predict(input_x) + self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!') + cs = cosine_similarity(y, y_hat) + self.unit_test.assertTrue(np.isclose(cs, 1, atol=0.001), msg=f'fail cosine similarity check:{cs}') + self.unit_test.assertTrue(isinstance(quantized_model.layers[2], KerasQuantizationWrapper), + msg='TFOpLambda should be quantized') + const_index = 0 if self.input_reverse_order else 1 + self.unit_test.assertTrue((quantized_model.layers[2].weight_values[const_index] == self.const).all(), + msg='Constant value should not change') + + +class AdvancedConstQuantizationTest(BaseKerasFeatureNetworkTest): + + def __init__(self, unit_test, input_shape=(32, 32, 3)): + super(AdvancedConstQuantizationTest, self).__init__(unit_test=unit_test, input_shape=input_shape) + self.const = np.random.random((130,)) + + def get_ptq_facade(self): + gptq_config = mct.gptq.get_keras_gptq_config(30) + return partial(mct.gptq.keras_gradient_post_training_quantization, + gptq_config=gptq_config) + + def get_resource_utilization(self): + return mct.core.ResourceUtilization(9e3) + + def generate_inputs(self): + # need positive inputs so won't divide with zero or take root of negative number + return [1 + np.random.random(in_shape) for in_shape in self.get_input_shapes()] + + def get_tpc(self): + return 
mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, "v2") + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = layers.Conv2D(130, 3)(inputs) + x = layers.ReLU()(x) + x = tf.add(x, self.const) + x = layers.Conv2D(16, 3)(x) + return tf.keras.models.Model(inputs=inputs, outputs=x) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + y = float_model.predict(input_x) + y_hat = quantized_model.predict(input_x) + self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!') + self.unit_test.assertTrue(isinstance(quantized_model.layers[5], KerasQuantizationWrapper), + msg='TFOpLambda should be quantized') + self.unit_test.assertTrue((quantized_model.layers[5].weight_values[1] == self.const).all(), + msg='Constant value should not change') diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/const_representation_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/const_representation_test.py index 3cfe2d5fb..afd83fc57 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/const_representation_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/const_representation_test.py @@ -28,14 +28,14 @@ class ConstRepresentationTest(BaseKerasFeatureNetworkTest): - def __init__(self, unit_test, layer, const, is_list_input=False, input_reverse_order=False, use_kwrags=False, + def __init__(self, unit_test, layer, const, is_list_input=False, input_reverse_order=False, use_kwargs=False, input_shape=(32, 32, 16)): super(ConstRepresentationTest, self).__init__(unit_test=unit_test, input_shape=input_shape) self.layer = layer self.const = const self.is_list_input = is_list_input self.input_reverse_order = input_reverse_order - self.use_kwrags = use_kwrags + self.use_kwargs = use_kwargs def generate_inputs(self): # need positive inputs so won't divide with zero or take root of negative number @@ -58,12 +58,12 @@ def create_networks(self): x = self.layer([x, self.const]) else: if self.input_reverse_order: - if self.use_kwrags: + if self.use_kwargs: x = self.layer(x=self.const, y=x) else: x = self.layer(self.const, x) else: - if self.use_kwrags: + if self.use_kwargs: x = self.layer(x=x, y=self.const) else: x = self.layer(x, self.const) diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py index 50e7ea79d..2a85f0337 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py @@ -51,8 +51,9 @@ def get_tpc(self): # sets all combinations of 2, 4, 8 bits for weights and activations mixed_precision_candidates_list = get_base_mp_nbits_candidates() + default_config = eight_bits.clone_and_edit(attr_weights_configs_mapping={}) return get_tpc_with_activation_mp_keras(base_config=eight_bits, - default_config=eight_bits.clone_and_edit(attr_weights_configs_mapping={}), + default_config=default_config, mp_bitwidth_candidates_list=mixed_precision_candidates_list, name="mixed_precision_activation_test") diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index 22308a900..6fdebdeb0 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ 
b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -134,6 +134,8 @@ from tests.keras_tests.feature_networks_tests.feature_networks.const_representation_test import ConstRepresentationTest, \ ConstRepresentationMultiInputTest, ConstRepresentationMatMulTest from tests.keras_tests.feature_networks_tests.feature_networks.concatination_threshold_update import ConcatThresholdtest +from tests.keras_tests.feature_networks_tests.feature_networks.const_quantization_test import ConstQuantizationTest, \ + AdvancedConstQuantizationTest from model_compression_toolkit.qat.common.qat_config import TrainingMethod layers = tf.keras.layers @@ -537,13 +539,25 @@ def test_linear_collapsing(self): SixConv2DCollapsingTest(self).run_test() Op2DAddConstCollapsingTest(self).run_test() + def test_const_quantization(self): + c = (np.ones((16,)) + np.random.random((16,))).astype(np.float32) + for func in [tf.add, tf.multiply, tf.subtract, tf.divide, tf.truediv]: + ConstQuantizationTest(self, func, c).run_test() + ConstQuantizationTest(self, func, c, input_reverse_order=True).run_test() + ConstQuantizationTest(self, func, c, input_reverse_order=True, use_kwargs=True).run_test() + ConstQuantizationTest(self, func, c, use_kwargs=True).run_test() + ConstQuantizationTest(self, func, 2.45).run_test() + ConstQuantizationTest(self, func, 5.1, input_reverse_order=True).run_test() + + AdvancedConstQuantizationTest(self).run_test() + def test_const_representation(self): c = (np.ones((16,)) + np.random.random((16,))).astype(np.float32) for func in [tf.add, tf.multiply, tf.subtract, tf.divide, tf.truediv, tf.pow]: ConstRepresentationTest(self, func, c).run_test() ConstRepresentationTest(self, func, c, input_reverse_order=True).run_test() - ConstRepresentationTest(self, func, c, input_reverse_order=True, use_kwrags=True).run_test() - ConstRepresentationTest(self, func, c, use_kwrags=True).run_test() + ConstRepresentationTest(self, func, c, input_reverse_order=True, use_kwargs=True).run_test() + ConstRepresentationTest(self, func, c, use_kwargs=True).run_test() ConstRepresentationTest(self, func, 2.45).run_test() ConstRepresentationTest(self, func, 5.1, input_reverse_order=True).run_test() @@ -554,8 +568,8 @@ def test_const_representation(self): for func in [layers.Add(), layers.Multiply(), layers.Subtract()]: ConstRepresentationTest(self, func, c, is_list_input=True).run_test() ConstRepresentationTest(self, func, c, input_reverse_order=True, is_list_input=True).run_test() - ConstRepresentationTest(self, func, c, input_reverse_order=True, use_kwrags=True, is_list_input=True).run_test() - ConstRepresentationTest(self, func, c, use_kwrags=True, is_list_input=True).run_test() + ConstRepresentationTest(self, func, c, input_reverse_order=True, use_kwargs=True, is_list_input=True).run_test() + ConstRepresentationTest(self, func, c, use_kwargs=True, is_list_input=True).run_test() ConstRepresentationMultiInputTest(self).run_test() diff --git a/tests/keras_tests/function_tests/test_hessian_info_calculator.py b/tests/keras_tests/function_tests/test_hessian_info_calculator.py index 7bdecaaa8..2360a35e9 100644 --- a/tests/keras_tests/function_tests/test_hessian_info_calculator.py +++ b/tests/keras_tests/function_tests/test_hessian_info_calculator.py @@ -237,7 +237,7 @@ def test_reused_layer(self): sorted_graph_nodes = graph.get_topo_sorted_nodes() # Two nodes representing the same reused layer - interest_points = [n for n in sorted_graph_nodes if n.type == Conv2D] + interest_points = [n for n in sorted_graph_nodes if 
n.is_match_type(Conv2D)] self.assertTrue(len(interest_points)==2, f"Expected to find 2 Conv2D nodes but found {len(interest_points)}") hessian_service = hessian_common.HessianInfoService(graph=graph, @@ -419,7 +419,7 @@ def test_reused_layer(self): sorted_graph_nodes = graph.get_topo_sorted_nodes() # Two nodes representing the same reused layer - interest_points = [n for n in sorted_graph_nodes if n.type == Conv2D] + interest_points = [n for n in sorted_graph_nodes if n.is_match_type(Conv2D)] self.assertTrue(len(interest_points)==2, f"Expected to find 2 Conv2D nodes but found {len(interest_points)}") hessian_service = hessian_common.HessianInfoService(graph=graph, diff --git a/tests/keras_tests/function_tests/test_sensitivity_eval_non_suppoerted_output.py b/tests/keras_tests/function_tests/test_sensitivity_eval_non_suppoerted_output.py index f8884348d..209b2ac3b 100644 --- a/tests/keras_tests/function_tests/test_sensitivity_eval_non_suppoerted_output.py +++ b/tests/keras_tests/function_tests/test_sensitivity_eval_non_suppoerted_output.py @@ -38,6 +38,7 @@ def argmax_output_model(input_shape): model = keras.Model(inputs=inputs, outputs=outputs) return model + def nms_output_model(input_shape): inputs = layers.Input(shape=input_shape) x = layers.Conv2D(1, 3, padding='same')(inputs) @@ -49,7 +50,7 @@ def nms_output_model(input_shape): y = tf.concat([x, x], -1) y = tf.concat([y, y], -1) scores = tf.concat([x, y], -1) # shape = (batch, detections, classes) - boxes, _ = tf.split(x, (4,12), -1) + boxes, _ = tf.split(x, (4, 12), -1) boxes = tf.expand_dims(boxes, 2) # shape = (batch, detections, 1, box coordinates) # NMS layer @@ -104,5 +105,6 @@ def test_not_supported_output_nms(self): self.verify_test_for_model(model) self.assertTrue("All graph outputs must support Hessian score computation" in str(e.exception)) + if __name__ == '__main__': unittest.main() diff --git a/tests/keras_tests/function_tests/test_weights_activation_split_substitution.py b/tests/keras_tests/function_tests/test_weights_activation_split_substitution.py index 6d181664c..93fa0b881 100644 --- a/tests/keras_tests/function_tests/test_weights_activation_split_substitution.py +++ b/tests/keras_tests/function_tests/test_weights_activation_split_substitution.py @@ -147,7 +147,7 @@ def test_all_weights_layers_split(self): graph, split_graph = test_setup(in_model, keras_impl, mixed_precision_candidates_list=_get_base_mp_nbits_candidates()) weights_node_types = [Conv2D, Conv2DTranspose, DepthwiseConv2D, Dense] - original_weights_nodes = [n for n in graph.get_topo_sorted_nodes() if n.type in weights_node_types] + original_weights_nodes = [n for n in graph.get_topo_sorted_nodes() if any([n.is_match_type(_type) for _type in weights_node_types])] self.assertTrue(len(split_graph.nodes) == len(graph.nodes) + len(original_weights_nodes)) diff --git a/tests/pytorch_tests/layer_tests/base_pytorch_layer_test.py b/tests/pytorch_tests/layer_tests/base_pytorch_layer_test.py index e077fc88d..78ff7c63f 100644 --- a/tests/pytorch_tests/layer_tests/base_pytorch_layer_test.py +++ b/tests/pytorch_tests/layer_tests/base_pytorch_layer_test.py @@ -95,6 +95,8 @@ def forward(self, x, y): def is_node_fake_quant(node): return node.target == torch.fake_quantize_per_tensor_affine + + def has_nested_attr(obj, attr): """ Check if an object `obj` has a nested attribute `attr`. 
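Taken together with the new TPC entries, the Keras tests above boil down to the following user-side flow. This is a minimal sketch, assuming the standard mct.ptq.keras_post_training_quantization entry point and the 'imx500'/'v2' identifiers used by the tests; it is not code from the patch:

import numpy as np
import tensorflow as tf
import model_compression_toolkit as mct
from mct_quantizers import KerasQuantizationWrapper

# Toy model: an add with a constant second input, as in the tests above.
const = (np.ones((16,)) + np.random.random((16,))).astype(np.float32)
inputs = tf.keras.layers.Input(shape=(32, 32, 16))
outputs = tf.add(inputs, const)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

def representative_data_gen():
    yield [np.random.random((1, 32, 32, 16)).astype(np.float32)]

tpc = mct.get_target_platform_capabilities('tensorflow', 'imx500', target_platform_version='v2')
q_model, _ = mct.ptq.keras_post_training_quantization(model, representative_data_gen,
                                                      target_platform_capabilities=tpc)

# With const quantization enabled in imx500 v2, the tf.add op should come back
# wrapped in KerasQuantizationWrapper, carrying the constant as a quantized weight.
print(any(isinstance(l, KerasQuantizationWrapper) for l in q_model.layers))
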
diff --git a/tests/pytorch_tests/model_tests/feature_models/const_quantization_test.py b/tests/pytorch_tests/model_tests/feature_models/const_quantization_test.py new file mode 100644 index 000000000..83b84d034 --- /dev/null +++ b/tests/pytorch_tests/model_tests/feature_models/const_quantization_test.py @@ -0,0 +1,136 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import partial +import torch +import torch.nn as nn +import numpy as np +import model_compression_toolkit as mct +from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model +from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest +from tests.common_tests.helpers.tensors_compare import cosine_similarity +from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL +from model_compression_toolkit.constants import PYTORCH +from mct_quantizers import PytorchQuantizationWrapper + +tp = mct.target_platform + + +class ConstQuantizationNet(nn.Module): + def __init__(self, layer, const): + super().__init__() + self.layer = layer + self.const = to_torch_tensor(const) if isinstance(const, np.ndarray) else const + + def forward(self, x): + return self.layer(x, self.const) + + +class ConstQuantizationReverseOrderNet(nn.Module): + def __init__(self, layer, const): + super().__init__() + self.layer = layer + self.const = to_torch_tensor(const) if isinstance(const, np.ndarray) else const + + def forward(self, x): + return self.layer(self.const, x) + + +class ConstQuantizationTest(BasePytorchFeatureNetworkTest): + + def __init__(self, unit_test, func, const, input_reverse_order=False): + super().__init__(unit_test=unit_test, input_shape=(16, 32, 32)) + self.func = func + self.const = const + self.input_reverse_order = input_reverse_order + + def generate_inputs(self): + return [np.random.random(in_shape)+1 for in_shape in self.get_input_shapes()] + + def get_tpc(self): + return mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, "v2") + + def create_networks(self): + if self.input_reverse_order: + return ConstQuantizationReverseOrderNet(self.func, self.const) + else: + return ConstQuantizationNet(self.func, self.const) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + in_torch_tensor = to_torch_tensor(input_x[0]) + set_model(float_model) + y = float_model(in_torch_tensor) + y_hat = quantized_model(in_torch_tensor) + self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!') + cs = cosine_similarity(torch_tensor_to_numpy(y), torch_tensor_to_numpy(y_hat)) + self.unit_test.assertTrue(np.isclose(cs, 1, atol=0.001), msg=f'fail cosine similarity check: {cs}') + for n, m in quantized_model.named_modules(): + if n == self.func.__name__: + self.unit_test.assertTrue(isinstance(m, 
PytorchQuantizationWrapper), + msg=f'Expected layer type to be "PytorchQuantizationWrapper" but got {type(m)}.') + self.unit_test.assertTrue((list(m.weight_values.values())[0].detach().cpu().numpy() == + self.const).all(), + msg=f'Expected PytorchQuantizationWrapper const value to match float const.') + + +class AdvancedConstQuantizationNet(nn.Module): + def __init__(self, const): + super().__init__() + self.conv1 = nn.Conv2d(3, 130, 3) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2d(130, 16, 3) + self.const = to_torch_tensor(const) if isinstance(const, np.ndarray) else const + + def forward(self, x): + x = self.relu(self.conv1(x)) + x = torch.add(x, self.const) + x = self.conv2(x) + return x + + +class AdvancedConstQuantizationTest(BasePytorchFeatureNetworkTest): + def __init__(self, unit_test): + super().__init__(unit_test=unit_test, input_shape=(3, 32, 32)) + self.const = (np.random.random((130, 1, 1))).astype(np.float32) + + def get_ptq_facade(self): + gptq_config = mct.gptq.get_pytorch_gptq_config(30) + return partial(mct.gptq.pytorch_gradient_post_training_quantization, + gptq_config=gptq_config) + + def get_resource_utilization(self): + return mct.core.ResourceUtilization(9e3) + + def generate_inputs(self): + return [np.random.random(in_shape)+1 for in_shape in self.get_input_shapes()] + + def get_tpc(self): + return mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, "v2") + + def create_networks(self): + return AdvancedConstQuantizationNet(self.const) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + in_torch_tensor = to_torch_tensor(input_x[0]) + set_model(float_model) + y = float_model(in_torch_tensor) + y_hat = quantized_model(in_torch_tensor) + self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!') + for n, m in quantized_model.named_modules(): + if n == torch.add.__name__: + self.unit_test.assertTrue(isinstance(m, PytorchQuantizationWrapper), + msg=f'Expected layer type to be "PytorchQuantizationWrapper" but got {type(m)}.') + self.unit_test.assertTrue((list(m.weight_values.values())[0].detach().cpu().numpy() == + self.const).all(), + msg=f'Expected PytorchQuantizationWrapper const value to match float const.') diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index fd74b1125..19504b3af 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -89,6 +89,8 @@ from tests.pytorch_tests.model_tests.feature_models.const_representation_test import ConstRepresentationTest, \ ConstRepresentationMultiInputTest from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod +from tests.pytorch_tests.model_tests.feature_models.const_quantization_test import ConstQuantizationTest, \ + AdvancedConstQuantizationTest from tests.pytorch_tests.model_tests.feature_models.remove_identity_test import RemoveIdentityTest @@ -230,6 +232,16 @@ def test_residual_collapsing(self): ResidualCollapsingTest1(self).run_test() ResidualCollapsingTest2(self).run_test() + def test_const_quantization(self): + c = (np.ones((32,)) + np.random.random((32,))).astype(np.float32) + for func in [torch.add, torch.sub, torch.mul, torch.div]: + ConstQuantizationTest(self, func, c).run_test() + ConstQuantizationTest(self, func, c, input_reverse_order=True).run_test() + ConstQuantizationTest(self, func, 2.45).run_test() + 
+            ConstQuantizationTest(self, func, 5, input_reverse_order=True).run_test()
+
+        AdvancedConstQuantizationTest(self).run_test()
+
     def test_const_representation(self):
         c = (np.ones((32,)) + np.random.random((32,))).astype(np.float32)
         for func in [torch.add, torch.sub, torch.mul, torch.div]: