diff --git a/model_compression_toolkit/constants.py b/model_compression_toolkit/constants.py index 9bcf6b68f..632d1dc56 100644 --- a/model_compression_toolkit/constants.py +++ b/model_compression_toolkit/constants.py @@ -24,6 +24,10 @@ FOUND_ONNXRUNTIME = importlib.util.find_spec("onnxruntime") is not None FOUND_SONY_CUSTOM_LAYERS = importlib.util.find_spec('sony_custom_layers') is not None +# Metadata fields +MCT_VERSION = 'mct_version' +TPC_VERSION = 'tpc_version' + WEIGHTS_SIGNED = True # Minimal threshold to use for quantization ranges: MIN_THRESHOLD = (2 ** -16) diff --git a/model_compression_toolkit/core/common/graph/base_graph.py b/model_compression_toolkit/core/common/graph/base_graph.py index bd20c8277..d33fab23a 100644 --- a/model_compression_toolkit/core/common/graph/base_graph.py +++ b/model_compression_toolkit/core/common/graph/base_graph.py @@ -103,9 +103,10 @@ def set_tpc(self, if n.is_custom: if not is_node_in_tpc: Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. ' - f' Please add the custom layer to Target Platform Capabilities (TPC), or file a feature request or an issue if you believe this should be supported.') + ' Please add the custom layer to Target Platform Capabilities (TPC), or file a feature ' + 'request or an issue if you believe this should be supported.') # pragma: no cover if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(tpc).quantization_config_list]): - Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.') + Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.') # pragma: no cover self.tpc = tpc diff --git a/model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py b/model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py index 70ca1b853..2078d7e56 100644 --- a/model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +++ b/model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py @@ -100,7 +100,7 @@ def _unwrap_quantize_wrapper(layer: Layer): weights_list.append(layer.get_quantized_weights()['kernel']) else: Logger.critical(f'KerasQuantizationWrapper should wrap only DepthwiseConv2D, Conv2D, Dense' - f' and Conv2DTranspose layers but wrapped layer is {layer.layer}') + f' and Conv2DTranspose layers but wrapped layer is {layer.layer}') if layer.layer.bias is not None: weights_list.append(layer.layer.bias) @@ -121,6 +121,11 @@ def _unwrap_quantize_wrapper(layer: Layer): return layer + # Delete metadata layer if exists + if hasattr(self.model, 'metadata_layer'): + Logger.info('Metadata is not exported to FakelyQuant models.') + delattr(self.model, 'metadata_layer') + # clone each layer in the model and apply _unwrap_quantize_wrapper to layers wrapped with a QuantizeWrapper. self.exported_model = tf.keras.models.clone_model(self.model, input_tensors=None, diff --git a/model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_tflite_exporter.py b/model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_tflite_exporter.py index f3a65fdcb..7174751a4 100644 --- a/model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_tflite_exporter.py +++ b/model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_tflite_exporter.py @@ -56,6 +56,11 @@ def export(self): (namely, weights that are in fake-quant format) and fake-quant layers for the activations. """ + # Delete metadata layer if exists + if hasattr(self.model, 'metadata_layer'): + Logger.info('Metadata is not exported to TFLite models.') + delattr(self.model, 'metadata_layer') + # Use Keras exporter to quantize model's weights before converting it to TFLite. # Since exporter saves the model, we use a tmp path for saving, and then we delete it. handle, tmp_file = tempfile.mkstemp(DEFAULT_KERAS_EXPORT_EXTENTION) diff --git a/model_compression_toolkit/exporter/model_exporter/keras/int8_tflite_exporter.py b/model_compression_toolkit/exporter/model_exporter/keras/int8_tflite_exporter.py index 48f77a188..281cf0abb 100644 --- a/model_compression_toolkit/exporter/model_exporter/keras/int8_tflite_exporter.py +++ b/model_compression_toolkit/exporter/model_exporter/keras/int8_tflite_exporter.py @@ -166,6 +166,11 @@ def _substitute_model(layer_to_substitue: keras.layers.Layer) -> keras.layers.La return layer_to_substitue + # Delete metadata layer if exists + if hasattr(self.model, 'metadata_layer'): + Logger.info('Metadata is not exported to TFLite models.') + delattr(self.model, 'metadata_layer') + # Transform the model to a new model that can be converted to int8 models. # For example: replace dense layers with point-wise layers (to support per-channel quantization) self.transformed_model = clone_model(self.model, diff --git a/model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py b/model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py index 4ea17a576..436159fc0 100644 --- a/model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +++ b/model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py @@ -13,17 +13,21 @@ # limitations under the License. # ============================================================================== from typing import Callable +from io import BytesIO import torch.nn +import onnx from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.pytorch.utils import to_torch_tensor from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter from mct_quantizers import pytorch_quantizers +from mct_quantizers.pytorch.metadata import add_onnx_metadata DEFAULT_ONNX_OPSET_VERSION=15 + class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter): """ Exporter for fakely-quant PyTorch models. @@ -58,7 +62,6 @@ def __init__(self, self._use_onnx_custom_quantizer_ops = use_onnx_custom_quantizer_ops self._onnx_opset_version = onnx_opset_version - def export(self) -> None: """ Convert an exportable (fully-quantized) PyTorch model to a fakely-quant model @@ -74,7 +77,7 @@ def export(self) -> None: # If _use_onnx_custom_quantizer_ops is set to True, the quantizer forward function will use # the custom implementation when exporting the operator into onnx model. If not, it removes the # wraps and quantizes the ops in place (for weights, for activation torch quantization function is - # exported since it's used during forward. + # exported since it's used during forward). if self._use_onnx_custom_quantizer_ops: self._enable_onnx_custom_ops_export() else: @@ -87,15 +90,30 @@ def export(self) -> None: model_input = to_torch_tensor(next(self.repr_dataset())[0]) - torch.onnx.export(self.model, - model_input, - self.save_model_path, - opset_version=self._onnx_opset_version, - verbose=False, - input_names=['input'], - output_names=['output'], - dynamic_axes={'input': {0: 'batch_size'}, - 'output': {0: 'batch_size'}}) + if hasattr(self.model, 'metadata'): + onnx_bytes = BytesIO() + torch.onnx.export(self.model, + model_input, + onnx_bytes, + opset_version=self._onnx_opset_version, + verbose=False, + input_names=['input'], + output_names=['output'], + dynamic_axes={'input': {0: 'batch_size'}, + 'output': {0: 'batch_size'}}) + onnx_model = onnx.load_from_string(onnx_bytes.getvalue()) + onnx_model = add_onnx_metadata(onnx_model, self.model.metadata) + onnx.save_model(onnx_model, self.save_model_path) + else: + torch.onnx.export(self.model, + model_input, + self.save_model_path, + opset_version=self._onnx_opset_version, + verbose=False, + input_names=['input'], + output_names=['output'], + dynamic_axes={'input': {0: 'batch_size'}, + 'output': {0: 'batch_size'}}) def _enable_onnx_custom_ops_export(self): """ diff --git a/model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py b/model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py index 36676e610..e4cf5fa69 100644 --- a/model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +++ b/model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py @@ -40,7 +40,7 @@ def pytorch_export_model(model: torch.nn.Module, repr_dataset: Callable, is_layer_exportable_fn: Callable = is_pytorch_layer_exportable, serialization_format: PytorchExportSerializationFormat = PytorchExportSerializationFormat.ONNX, - quantization_format : QuantizationFormat = QuantizationFormat.MCTQ, + quantization_format: QuantizationFormat = QuantizationFormat.MCTQ, onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION) -> None: """ Export a PyTorch quantized model to a torchscript or onnx model. diff --git a/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py b/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py index c84f758c3..72c143a6a 100644 --- a/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +++ b/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py @@ -93,4 +93,4 @@ def get_exportable_pytorch_model(graph: Graph): else: def get_exportable_pytorch_model(*args, **kwargs): Logger.critical("PyTorch must be installed to use 'get_exportable_pytorch_model'. " - "The 'torch' package is missing.") # pragma: no cover \ No newline at end of file + "The 'torch' package is missing.") # pragma: no cover diff --git a/model_compression_toolkit/gptq/keras/quantization_facade.py b/model_compression_toolkit/gptq/keras/quantization_facade.py index a6bf52d31..47e81bb87 100644 --- a/model_compression_toolkit/gptq/keras/quantization_facade.py +++ b/model_compression_toolkit/gptq/keras/quantization_facade.py @@ -31,6 +31,7 @@ from model_compression_toolkit.gptq.runner import gptq_runner from model_compression_toolkit.core.analyzer import analyzer_model_quantization from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities +from model_compression_toolkit.metadata import get_versions_dict LR_DEFAULT = 0.15 LR_REST_DEFAULT = 1e-4 @@ -48,6 +49,7 @@ from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model from model_compression_toolkit import get_target_platform_capabilities + from mct_quantizers.keras.metadata import add_metadata # As from TF2.9 optimizers package is changed if version.parse(tf.__version__) < version.parse("2.9"): @@ -234,7 +236,10 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da fw_impl, DEFAULT_KERAS_INFO) - return get_exportable_keras_model(tg_gptq) + exportable_model, user_info = get_exportable_keras_model(tg_gptq) + if target_platform_capabilities.tp_model.add_metadata: + exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities)) + return exportable_model, user_info else: # If tensorflow is not installed, diff --git a/model_compression_toolkit/gptq/pytorch/quantization_facade.py b/model_compression_toolkit/gptq/pytorch/quantization_facade.py index 1baea72f6..678483514 100644 --- a/model_compression_toolkit/gptq/pytorch/quantization_facade.py +++ b/model_compression_toolkit/gptq/pytorch/quantization_facade.py @@ -31,6 +31,7 @@ from model_compression_toolkit.core import CoreConfig from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \ MixedPrecisionQuantizationConfig +from model_compression_toolkit.metadata import get_versions_dict LR_DEFAULT = 1e-4 LR_REST_DEFAULT = 1e-4 @@ -47,6 +48,7 @@ from torch.nn import Module from torch.optim import Adam, Optimizer from model_compression_toolkit import get_target_platform_capabilities + from mct_quantizers.pytorch.metadata import add_metadata DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL) def get_pytorch_gptq_config(n_epochs: int, @@ -202,7 +204,10 @@ def pytorch_gradient_post_training_quantization(model: Module, fw_impl, DEFAULT_PYTORCH_INFO) - return get_exportable_pytorch_model(graph_gptq) + exportable_model, user_info = get_exportable_pytorch_model(graph_gptq) + if target_platform_capabilities.tp_model.add_metadata: + exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities)) + return exportable_model, user_info else: diff --git a/model_compression_toolkit/metadata.py b/model_compression_toolkit/metadata.py new file mode 100644 index 000000000..6db6f9634 --- /dev/null +++ b/model_compression_toolkit/metadata.py @@ -0,0 +1,29 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Dict +from model_compression_toolkit.constants import MCT_VERSION, TPC_VERSION + + +def get_versions_dict(tpc) -> Dict: + """ + + Returns: A dictionary with TPC and MCT versions. + + """ + # imported inside to avoid circular import error + from model_compression_toolkit import __version__ as mct_version + tpc_version = f'{tpc.name}.{tpc.version}' + return {MCT_VERSION: mct_version, TPC_VERSION: tpc_version} diff --git a/model_compression_toolkit/ptq/keras/quantization_facade.py b/model_compression_toolkit/ptq/keras/quantization_facade.py index f84d6e41b..7a2d43a81 100644 --- a/model_compression_toolkit/ptq/keras/quantization_facade.py +++ b/model_compression_toolkit/ptq/keras/quantization_facade.py @@ -28,6 +28,7 @@ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities from model_compression_toolkit.core.runner import core_runner from model_compression_toolkit.ptq.runner import ptq_runner +from model_compression_toolkit.metadata import get_versions_dict if FOUND_TF: from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO @@ -38,6 +39,7 @@ from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model from model_compression_toolkit import get_target_platform_capabilities + from mct_quantizers.keras.metadata import add_metadata DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL) @@ -164,8 +166,10 @@ def keras_post_training_quantization(in_model: Model, fw_impl, fw_info) - return get_exportable_keras_model(graph_with_stats_correction) - + exportable_model, user_info = get_exportable_keras_model(graph_with_stats_correction) + if target_platform_capabilities.tp_model.add_metadata: + exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities)) + return exportable_model, user_info else: diff --git a/model_compression_toolkit/ptq/pytorch/quantization_facade.py b/model_compression_toolkit/ptq/pytorch/quantization_facade.py index 5dfc87ff4..fac72bb80 100644 --- a/model_compression_toolkit/ptq/pytorch/quantization_facade.py +++ b/model_compression_toolkit/ptq/pytorch/quantization_facade.py @@ -29,6 +29,7 @@ from model_compression_toolkit.ptq.runner import ptq_runner from model_compression_toolkit.core.analyzer import analyzer_model_quantization from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights +from model_compression_toolkit.metadata import get_versions_dict if FOUND_TORCH: @@ -38,6 +39,7 @@ from torch.nn import Module from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model from model_compression_toolkit import get_target_platform_capabilities + from mct_quantizers.pytorch.metadata import add_metadata DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL) @@ -139,7 +141,10 @@ def pytorch_post_training_quantization(in_module: Module, fw_impl, fw_info) - return get_exportable_pytorch_model(graph_with_stats_correction) + exportable_model, user_info = get_exportable_pytorch_model(graph_with_stats_correction) + if target_platform_capabilities.tp_model.add_metadata: + exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities)) + return exportable_model, user_info else: diff --git a/model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py b/model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py index c0bba5a2e..17aca54f2 100644 --- a/model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +++ b/model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py @@ -60,15 +60,18 @@ class TargetPlatformModel(ImmutableClass): def __init__(self, default_qco: QuantizationConfigOptions, + add_metadata: bool = False, name="default_tp_model"): """ Args: default_qco (QuantizationConfigOptions): Default QuantizationConfigOptions to use for operators that their QuantizationConfigOptions are not defined in the model. + add_metadata (bool): Whether to add metadata to the model or not. name (str): Name of the model. """ super().__init__() + self.add_metadata = add_metadata self.name = name self.operator_set = [] assert isinstance(default_qco, QuantizationConfigOptions) @@ -191,7 +194,7 @@ def __validate_model(self): """ opsets_names = [op.name for op in self.operator_set] - if (len(set(opsets_names)) != len(opsets_names)): + if len(set(opsets_names)) != len(opsets_names): Logger.critical(f'Operator Sets must have unique names.') def get_default_config(self) -> OpQuantizationConfig: diff --git a/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py index f832ca910..80385553b 100644 --- a/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +++ b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py @@ -29,6 +29,7 @@ from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSetBase from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import TargetPlatformModel from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc +from model_compression_toolkit.constants import MCT_VERSION, TPC_VERSION class TargetPlatformCapabilities(ImmutableClass): diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py index 86baa7536..27fc67b4d 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py @@ -25,11 +25,15 @@ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v1_lut from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_keras import get_keras_tpc as get_keras_tpc_v1_pot + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_keras import get_keras_tpc as get_keras_tpc_v2 + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v2_lut # Keras: TPC versioning keras_tpc_models_dict = {'v1': get_keras_tpc_v1(), 'v1_lut': get_keras_tpc_v1_lut(), 'v1_pot': get_keras_tpc_v1_pot(), + 'v2': get_keras_tpc_v2(), + 'v2_lut': get_keras_tpc_v2_lut(), LATEST: get_keras_tpc_latest()} ############################### @@ -42,13 +46,19 @@ get_pytorch_tpc as get_pytorch_tpc_v1 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_pytorch import \ get_pytorch_tpc as get_pytorch_tpc_v1_pot - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_pytorch import get_pytorch_tpc as get_pytorch_tpc_v1_lut - + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_pytorch import \ + get_pytorch_tpc as get_pytorch_tpc_v1_lut + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_pytorch import \ + get_pytorch_tpc as get_pytorch_tpc_v2 + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_pytorch import \ + get_pytorch_tpc as get_pytorch_tpc_v2_lut # Pytorch: TPC versioning pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1(), 'v1_lut': get_pytorch_tpc_v1_lut(), 'v1_pot': get_pytorch_tpc_v1_pot(), + 'v2': get_pytorch_tpc_v2(), + 'v2_lut': get_pytorch_tpc_v2_lut(), LATEST: get_pytorch_tpc_latest()} tpc_dict = {TENSORFLOW: keras_tpc_models_dict, diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/__init__.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/__init__.py new file mode 100644 index 000000000..087b9fbf1 --- /dev/null +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +__version__ = 'v2' diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py new file mode 100644 index 000000000..e75b38f9b --- /dev/null +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py @@ -0,0 +1,210 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import List, Tuple + +import model_compression_toolkit as mct +from model_compression_toolkit.constants import FLOAT_BITWIDTH +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS +from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \ + TargetPlatformModel +from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import \ + AttributeQuantizationConfig + +tp = mct.target_platform + + +def get_tp_model() -> TargetPlatformModel: + """ + A method that generates a default target platform model, with base 8-bit quantization configuration and 8, 4, 2 + bits configuration list for mixed-precision quantization. + NOTE: in order to generate a target platform model with different configurations but with the same Operators Sets + (for tests, experiments, etc.), use this method implementation as a test-case, i.e., override the + 'get_op_quantization_configs' method and use its output to call 'generate_tp_model' with your configurations. + This version enables metadata by default + + Returns: A TargetPlatformModel object. + + """ + base_config, mixed_precision_cfg_list, default_config = get_op_quantization_configs() + return generate_tp_model(default_config=default_config, + base_config=base_config, + mixed_precision_cfg_list=mixed_precision_cfg_list, + name='imx500_tp_model') + + +def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig], OpQuantizationConfig]: + """ + Creates a default configuration object for 8-bit quantization, to be used to set a default TargetPlatformModel. + In addition, creates a default configuration objects list (with 8, 4 and 2 bit quantization) to be used as + default configuration for mixed-precision quantization. + + Returns: An OpQuantizationConfig config object and a list of OpQuantizationConfig objects. + + """ + + # TODO: currently, we don't want to quantize any attribute but the kernel by default, + # to preserve the current behavior of MCT, so quantization is disabled for all other attributes. + # Other quantization parameters are set to what we eventually want to quantize by default + # when we enable multi-attributes quantization - THIS NEED TO BE MODIFIED IN ALL TP MODELS! + + # define a default quantization config for all non-specified weights attributes. + default_weight_attr_config = AttributeQuantizationConfig( + weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, + weights_n_bits=8, + weights_per_channel_threshold=False, + enable_weights_quantization=False, # TODO: this will changed to True once implementing multi-attributes quantization + lut_values_bitwidth=None) + + # define a quantization config to quantize the kernel (for layers where there is a kernel attribute). + kernel_base_config = AttributeQuantizationConfig( + weights_quantization_method=tp.QuantizationMethod.SYMMETRIC, + weights_n_bits=8, + weights_per_channel_threshold=True, + enable_weights_quantization=True, + lut_values_bitwidth=None) + + # define a quantization config to quantize the bias (for layers where there is a bias attribute). + bias_config = AttributeQuantizationConfig( + weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, + weights_n_bits=FLOAT_BITWIDTH, + weights_per_channel_threshold=False, + enable_weights_quantization=False, + lut_values_bitwidth=None) + + # Create a quantization config. + # A quantization configuration defines how an operator + # should be quantized on the modeled hardware: + + # We define a default config for operation without kernel attribute. + # This is the default config that should be used for non-linear operations. + eight_bits_default = tp.OpQuantizationConfig( + default_weight_attr_config=default_weight_attr_config, + attr_weights_configs_mapping={}, + activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, + activation_n_bits=8, + enable_activation_quantization=True, + quantization_preserving=False, + fixed_scale=None, + fixed_zero_point=None, + simd_size=32) + + # We define an 8-bit config for linear operations quantization, that include a kernel and bias attributes. + linear_eight_bits = tp.OpQuantizationConfig( + default_weight_attr_config=default_weight_attr_config, + attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config}, + activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, + activation_n_bits=8, + enable_activation_quantization=True, + quantization_preserving=False, + fixed_scale=None, + fixed_zero_point=None, + simd_size=32) + + # To quantize a model using mixed-precision, create + # a list with more than one OpQuantizationConfig. + # In this example, we quantize some operations' weights + # using 2, 4 or 8 bits, and when using 2 or 4 bits, it's possible + # to quantize the operations' activations using LUT. + four_bits = linear_eight_bits.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 4}}, + simd_size=linear_eight_bits.simd_size * 2) + two_bits = linear_eight_bits.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 2}}, + simd_size=linear_eight_bits.simd_size * 4) + + mixed_precision_cfg_list = [linear_eight_bits, four_bits, two_bits] + + return linear_eight_bits, mixed_precision_cfg_list, eight_bits_default + + +def generate_tp_model(default_config: OpQuantizationConfig, + base_config: OpQuantizationConfig, + mixed_precision_cfg_list: List[OpQuantizationConfig], + name: str) -> TargetPlatformModel: + """ + Generates TargetPlatformModel with default defined Operators Sets, based on the given base configuration and + mixed-precision configurations options list. + + Args + default_config: A default OpQuantizationConfig to set as the TP model default configuration. + base_config: An OpQuantizationConfig to set as the TargetPlatformModel base configuration for mixed-precision purposes only. + mixed_precision_cfg_list: A list of OpQuantizationConfig to be used as the TP model mixed-precision + quantization configuration options. + name: The name of the TargetPlatformModel. + + Returns: A TargetPlatformModel object. + + """ + # Create a QuantizationConfigOptions, which defines a set + # of possible configurations to consider when quantizing a set of operations (in mixed-precision, for example). + # If the QuantizationConfigOptions contains only one configuration, + # this configuration will be used for the operation quantization: + default_configuration_options = tp.QuantizationConfigOptions([default_config]) + + # Create a TargetPlatformModel and set its default quantization config. + # This default configuration will be used for all operations + # unless specified otherwise (see OperatorsSet, for example): + generated_tpm = tp.TargetPlatformModel(default_configuration_options, add_metadata=True, name=name) + + # To start defining the model's components (such as operator sets, and fusing patterns), + # use 'with' the TargetPlatformModel instance, and create them as below: + with generated_tpm: + # Create an OperatorsSet to represent a set of operations. + # Each OperatorsSet has a unique label. + # If a quantization configuration options is passed, these options will + # be used for operations that will be attached to this set's label. + # Otherwise, it will be a configure-less set (used in fusing): + + generated_tpm.set_simd_padding(is_simd_padding=True) + + # May suit for operations like: Dropout, Reshape, etc. + default_qco = tp.get_default_quantization_config_options() + tp.OperatorsSet("NoQuantization", + default_qco.clone_and_edit(enable_activation_quantization=False) + .clone_and_edit_weight_attribute(enable_weights_quantization=False)) + + # Create Mixed-Precision quantization configuration options from the given list of OpQuantizationConfig objects + mixed_precision_configuration_options = tp.QuantizationConfigOptions(mixed_precision_cfg_list, + base_config=base_config) + + # Define operator sets that use mixed_precision_configuration_options: + conv = tp.OperatorsSet("Conv", mixed_precision_configuration_options) + fc = tp.OperatorsSet("FullyConnected", mixed_precision_configuration_options) + + # Define operations sets without quantization configuration + # options (useful for creating fusing patterns, for example): + any_relu = tp.OperatorsSet("AnyReLU") + add = tp.OperatorsSet("Add") + sub = tp.OperatorsSet("Sub") + mul = tp.OperatorsSet("Mul") + div = tp.OperatorsSet("Div") + prelu = tp.OperatorsSet("PReLU") + swish = tp.OperatorsSet("Swish") + sigmoid = tp.OperatorsSet("Sigmoid") + tanh = tp.OperatorsSet("Tanh") + + # Combine multiple operators into a single operator to avoid quantization between + # them. To do this we define fusing patterns using the OperatorsSets that were created. + # To group multiple sets with regard to fusing, an OperatorSetConcat can be created + activations_after_conv_to_fuse = tp.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) + activations_after_fc_to_fuse = tp.OperatorSetConcat(any_relu, swish, sigmoid) + any_binary = tp.OperatorSetConcat(add, sub, mul, div) + + # ------------------- # + # Fusions + # ------------------- # + tp.Fusing([conv, activations_after_conv_to_fuse]) + tp.Fusing([fc, activations_after_fc_to_fuse]) + tp.Fusing([any_binary, any_relu]) + + return generated_tpm diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py new file mode 100644 index 000000000..96505f638 --- /dev/null +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py @@ -0,0 +1,129 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import tensorflow as tf +from packaging import version + +from model_compression_toolkit.defaultdict import DefaultDict +from model_compression_toolkit.constants import FOUND_SONY_CUSTOM_LAYERS +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, KERAS_DEPTHWISE_KERNEL, \ + KERAS_KERNEL, BIAS_ATTR, BIAS + +if FOUND_SONY_CUSTOM_LAYERS: + from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess + +if version.parse(tf.__version__) >= version.parse("2.13"): + from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \ + MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \ + Conv2DTranspose +else: + from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \ + MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \ + Conv2DTranspose + +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tp_model import get_tp_model +import model_compression_toolkit as mct +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2 import __version__ as TPC_VERSION + +tp = mct.target_platform + + +def get_keras_tpc() -> tp.TargetPlatformCapabilities: + """ + get a Keras TargetPlatformCapabilities object with default operation sets to layers mapping. + + Returns: a Keras TargetPlatformCapabilities object for the given TargetPlatformModel. + """ + imx500_tpc_tp_model = get_tp_model() + return generate_keras_tpc(name='imx500_tpc_keras_tpc', tp_model=imx500_tpc_tp_model) + + +def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel): + """ + Generates a TargetPlatformCapabilities object with default operation sets to layers mapping. + + Args: + name: Name of the TargetPlatformCapabilities. + tp_model: TargetPlatformModel object. + + Returns: a TargetPlatformCapabilities object for the given TargetPlatformModel. + """ + + keras_tpc = tp.TargetPlatformCapabilities(tp_model, name=name, version=TPC_VERSION) + + no_quant_list = [Reshape, + tf.reshape, + Permute, + tf.transpose, + Flatten, + Cropping2D, + ZeroPadding2D, + Dropout, + MaxPooling2D, + tf.split, + tf.quantization.fake_quant_with_min_max_vars, + tf.math.argmax, + tf.shape, + tf.math.equal, + tf.gather, + tf.cast, + tf.unstack, + tf.compat.v1.gather, + tf.nn.top_k, + tf.__operators__.getitem, + tf.image.combined_non_max_suppression, + tf.compat.v1.shape] + + if FOUND_SONY_CUSTOM_LAYERS: + no_quant_list.append(SSDPostProcess) + + with keras_tpc: + tp.OperationsSetToLayers("NoQuantization", no_quant_list) + tp.OperationsSetToLayers("Conv", + [Conv2D, + DepthwiseConv2D, + Conv2DTranspose, + tf.nn.conv2d, + tf.nn.depthwise_conv2d, + tf.nn.conv2d_transpose], + # we provide attributes mapping that maps each layer type in the operations set + # that has weights attributes with provided quantization config (in the tp model) to + # its framework-specific attribute name. + # note that a DefaultDict should be provided if not all the layer types in the + # operation set are provided separately in the mapping. + attr_mapping={ + KERNEL_ATTR: DefaultDict({ + DepthwiseConv2D: KERAS_DEPTHWISE_KERNEL, + tf.nn.depthwise_conv2d: KERAS_DEPTHWISE_KERNEL}, default_value=KERAS_KERNEL), + BIAS_ATTR: DefaultDict(default_value=BIAS)}) + tp.OperationsSetToLayers("FullyConnected", [Dense], + attr_mapping={KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL), + BIAS_ATTR: DefaultDict(default_value=BIAS)}) + tp.OperationsSetToLayers("AnyReLU", [tf.nn.relu, + tf.nn.relu6, + tf.nn.leaky_relu, + ReLU, + LeakyReLU, + tp.LayerFilterParams(Activation, activation="relu"), + tp.LayerFilterParams(Activation, activation="leaky_relu")]) + tp.OperationsSetToLayers("Add", [tf.add, Add]) + tp.OperationsSetToLayers("Sub", [tf.subtract, Subtract]) + tp.OperationsSetToLayers("Mul", [tf.math.multiply, Multiply]) + tp.OperationsSetToLayers("Div", [tf.math.divide]) + tp.OperationsSetToLayers("PReLU", [PReLU]) + tp.OperationsSetToLayers("Swish", [tf.nn.swish, tp.LayerFilterParams(Activation, activation="swish")]) + tp.OperationsSetToLayers("Sigmoid", [tf.nn.sigmoid, tp.LayerFilterParams(Activation, activation="sigmoid")]) + tp.OperationsSetToLayers("Tanh", [tf.nn.tanh, tp.LayerFilterParams(Activation, activation="tanh")]) + + return keras_tpc diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_pytorch.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_pytorch.py new file mode 100644 index 000000000..af057b710 --- /dev/null +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_pytorch.py @@ -0,0 +1,111 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import operator + +import torch +from torch import add, sub, mul, div, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, chunk, unbind, topk, \ + gather, equal, transpose, permute, argmax, squeeze +from torch.nn import Conv2d, Linear, BatchNorm2d, ConvTranspose2d +from torch.nn import Dropout, Flatten, Hardtanh +from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, LeakyReLU +from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, leaky_relu + +from model_compression_toolkit.defaultdict import DefaultDict +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, PYTORCH_KERNEL, \ + BIAS +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tp_model import get_tp_model +import model_compression_toolkit as mct +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2 import __version__ as TPC_VERSION + +tp = mct.target_platform + + +def get_pytorch_tpc() -> tp.TargetPlatformCapabilities: + """ + get a Pytorch TargetPlatformCapabilities object with default operation sets to layers mapping. + + Returns: a Pytorch TargetPlatformCapabilities object for the given TargetPlatformModel. + """ + imx500_tpc_tp_model = get_tp_model() + return generate_pytorch_tpc(name='imx500_tpc_pytorch_tpc', tp_model=imx500_tpc_tp_model) + + +def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel): + """ + Generates a TargetPlatformCapabilities object with default operation sets to layers mapping. + Args: + name: Name of the TargetPlatformModel. + tp_model: TargetPlatformModel object. + Returns: a TargetPlatformCapabilities object for the given TargetPlatformModel. + """ + + pytorch_tpc = tp.TargetPlatformCapabilities(tp_model, + name=name, + version=TPC_VERSION) + + # we provide attributes mapping that maps each layer type in the operations set + # that has weights attributes with provided quantization config (in the tp model) to + # its framework-specific attribute name. + # note that a DefaultDict should be provided if not all the layer types in the + # operation set are provided separately in the mapping. + pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL), + BIAS_ATTR: DefaultDict(default_value=BIAS)} + + with pytorch_tpc: + tp.OperationsSetToLayers("NoQuantization", [Dropout, + Flatten, + dropout, + flatten, + split, + operator.getitem, + reshape, + unsqueeze, + BatchNorm2d, + chunk, + unbind, + torch.Tensor.size, + permute, + transpose, + equal, + argmax, + gather, + topk, + squeeze]) + + tp.OperationsSetToLayers("Conv", [Conv2d, ConvTranspose2d], + attr_mapping=pytorch_linear_attr_mapping) + tp.OperationsSetToLayers("FullyConnected", [Linear], + attr_mapping=pytorch_linear_attr_mapping) + tp.OperationsSetToLayers("AnyReLU", [torch.relu, + ReLU, + ReLU6, + LeakyReLU, + relu, + relu6, + leaky_relu, + tp.LayerFilterParams(Hardtanh, min_val=0), + tp.LayerFilterParams(hardtanh, min_val=0)]) + + tp.OperationsSetToLayers("Add", [operator.add, add]) + tp.OperationsSetToLayers("Sub", [operator.sub, sub]) + tp.OperationsSetToLayers("Mul", [operator.mul, mul]) + tp.OperationsSetToLayers("Div", [operator.truediv, div]) + tp.OperationsSetToLayers("PReLU", [PReLU, prelu]) + tp.OperationsSetToLayers("Swish", [SiLU, silu, Hardswish, hardswish]) + tp.OperationsSetToLayers("Sigmoid", [Sigmoid, sigmoid]) + tp.OperationsSetToLayers("Tanh", [Tanh, tanh]) + + return pytorch_tpc diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/__init__.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/__init__.py new file mode 100644 index 000000000..8fb8b3a4b --- /dev/null +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +__version__ = 'v2_lut' diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py new file mode 100644 index 000000000..050d3c8a5 --- /dev/null +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py @@ -0,0 +1,207 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import List, Tuple + +import model_compression_toolkit as mct +from model_compression_toolkit.constants import FLOAT_BITWIDTH +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS, \ + WEIGHTS_QUANTIZATION_METHOD +from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \ + TargetPlatformModel +from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import \ + AttributeQuantizationConfig + +tp = mct.target_platform + + +def get_tp_model() -> TargetPlatformModel: + """ + A method that generates a default target platform model, with base 8-bit quantization configuration and 8, 4, 2 + bits configuration list for mixed-precision quantization. + NOTE: in order to generate a target platform model with different configurations but with the same Operators Sets + (for tests, experiments, etc.), use this method implementation as a test-case, i.e., override the + 'get_op_quantization_configs' method and use its output to call 'generate_tp_model' with your configurations. + This version enables metadata by default + + Returns: A TargetPlatformModel object. + + """ + base_config, mixed_precision_cfg_list, default_config = get_op_quantization_configs() + return generate_tp_model(default_config=default_config, + base_config=base_config, + mixed_precision_cfg_list=mixed_precision_cfg_list, + name='imx500_lut_tp_model') + + +def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig], OpQuantizationConfig]: + """ + Creates a default configuration object for 8-bit quantization, to be used to set a default TargetPlatformModel. + In addition, creates a default configuration objects list (with 8, 4 and 2 bit quantization) to be used as + default configuration for mixed-precision quantization with non-uniform quantizer for 2 and 4 bit candidates. + + Returns: An OpQuantizationConfig config object and a list of OpQuantizationConfig objects. + + """ + + # We define a default quantization config for all non-specified weights attributes. + default_weight_attr_config = AttributeQuantizationConfig( + weights_quantization_method=tp.QuantizationMethod.SYMMETRIC, + weights_n_bits=8, + weights_per_channel_threshold=False, + enable_weights_quantization=False, + lut_values_bitwidth=None) + + # We define a quantization config to quantize the kernel (for layers where there is a kernel attribute). + kernel_base_config = AttributeQuantizationConfig( + weights_quantization_method=tp.QuantizationMethod.SYMMETRIC, + weights_n_bits=8, + weights_per_channel_threshold=True, + enable_weights_quantization=True, + lut_values_bitwidth=None) + + # We define a quantization config to quantize the bias (for layers where there is a bias attribute). + bias_config = AttributeQuantizationConfig( + weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, + weights_n_bits=FLOAT_BITWIDTH, + weights_per_channel_threshold=False, + enable_weights_quantization=False, + lut_values_bitwidth=None) + + # Create a quantization config. + # A quantization configuration defines how an operator + # should be quantized on the modeled hardware: + + # We define a default config for operation without kernel attribute. + # This is the default config that should be used for non-linear operations. + eight_bits_default = tp.OpQuantizationConfig( + default_weight_attr_config=default_weight_attr_config, + attr_weights_configs_mapping={}, + activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, + activation_n_bits=8, + enable_activation_quantization=True, + quantization_preserving=False, + fixed_scale=None, + fixed_zero_point=None, + simd_size=32) + + # We define an 8-bit config for linear operations quantization, that include a kernel and bias attributes. + linear_eight_bits = tp.OpQuantizationConfig( + default_weight_attr_config=default_weight_attr_config, + attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config}, + activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, + activation_n_bits=8, + enable_activation_quantization=True, + quantization_preserving=False, + fixed_scale=None, + fixed_zero_point=None, + simd_size=32) + + # To quantize a model using mixed-precision, create + # a list with more than one OpQuantizationConfig. + # In this example, we quantize some operations' weights + # using 2, 4 or 8 bits, and when using 2 or 4 bits, it's possible + # to quantize the operations' activations using LUT. + four_bits_lut = linear_eight_bits.clone_and_edit( + attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 4, + WEIGHTS_QUANTIZATION_METHOD: tp.QuantizationMethod.LUT_SYM_QUANTIZER}}, + simd_size=linear_eight_bits.simd_size * 2) + two_bits_lut = linear_eight_bits.clone_and_edit( + attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 2, + WEIGHTS_QUANTIZATION_METHOD: tp.QuantizationMethod.LUT_SYM_QUANTIZER}}, + simd_size=linear_eight_bits.simd_size * 4) + mixed_precision_cfg_list = [linear_eight_bits, four_bits_lut, two_bits_lut] + + return linear_eight_bits, mixed_precision_cfg_list, eight_bits_default + + +def generate_tp_model(default_config: OpQuantizationConfig, + base_config: OpQuantizationConfig, + mixed_precision_cfg_list: List[OpQuantizationConfig], + name: str) -> TargetPlatformModel: + """ + Generates TargetPlatformModel with default defined Operators Sets, based on the given base configuration and + mixed-precision configurations options list. + + Args + default_config: A default OpQuantizationConfig to set as the TP model default configuration. + base_config: An OpQuantizationConfig to set as the TargetPlatformModel base configuration for mixed-precision purposes only. + mixed_precision_cfg_list: A list of OpQuantizationConfig to be used as the TP model mixed-precision + quantization configuration options. + name: The name of the TargetPlatformModel. + + Returns: A TargetPlatformModel object. + + """ + # Create a QuantizationConfigOptions, which defines a set + # of possible configurations to consider when quantizing a set of operations (in mixed-precision, for example). + # If the QuantizationConfigOptions contains only one configuration, + # this configuration will be used for the operation quantization: + default_configuration_options = tp.QuantizationConfigOptions([default_config]) + + # Create a TargetPlatformModel and set its default quantization config. + # This default configuration will be used for all operations + # unless specified otherwise (see OperatorsSet, for example): + generated_tpm = tp.TargetPlatformModel(default_configuration_options, add_metadata=True, name=name) + + # To start defining the model's components (such as operator sets, and fusing patterns), + # use 'with' the TargetPlatformModel instance, and create them as below: + with generated_tpm: + # Create an OperatorsSet to represent a set of operations. + # Each OperatorsSet has a unique label. + # If a quantization configuration options is passed, these options will + # be used for operations that will be attached to this set's label. + # Otherwise, it will be a configure-less set (used in fusing): + + # May suit for operations like: Dropout, Reshape, etc. + default_qco = tp.get_default_quantization_config_options() + tp.OperatorsSet("NoQuantization", + default_qco.clone_and_edit(enable_activation_quantization=False) + .clone_and_edit_weight_attribute(enable_weights_quantization=False)) + + # Create Mixed-Precision quantization configuration options from the given list of OpQuantizationConfig objects + mixed_precision_configuration_options = tp.QuantizationConfigOptions(mixed_precision_cfg_list, + base_config=base_config) + + # Define operator sets that use mixed_precision_configuration_options: + conv = tp.OperatorsSet("Conv", mixed_precision_configuration_options) + fc = tp.OperatorsSet("FullyConnected", mixed_precision_configuration_options) + + # Define operations sets without quantization configuration + # options (useful for creating fusing patterns, for example): + any_relu = tp.OperatorsSet("AnyReLU") + add = tp.OperatorsSet("Add") + sub = tp.OperatorsSet("Sub") + mul = tp.OperatorsSet("Mul") + div = tp.OperatorsSet("Div") + prelu = tp.OperatorsSet("PReLU") + swish = tp.OperatorsSet("Swish") + sigmoid = tp.OperatorsSet("Sigmoid") + tanh = tp.OperatorsSet("Tanh") + + # Combine multiple operators into a single operator to avoid quantization between + # them. To do this we define fusing patterns using the OperatorsSets that were created. + # To group multiple sets with regard to fusing, an OperatorSetConcat can be created + activations_after_conv_to_fuse = tp.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) + activations_after_fc_to_fuse = tp.OperatorSetConcat(any_relu, swish, sigmoid) + any_binary = tp.OperatorSetConcat(add, sub, mul, div) + + # ------------------- # + # Fusions + # ------------------- # + tp.Fusing([conv, activations_after_conv_to_fuse]) + tp.Fusing([fc, activations_after_fc_to_fuse]) + tp.Fusing([any_binary, any_relu]) + + return generated_tpm diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py new file mode 100644 index 000000000..8ef6b73a5 --- /dev/null +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py @@ -0,0 +1,129 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import tensorflow as tf +from packaging import version + +from model_compression_toolkit.defaultdict import DefaultDict +from model_compression_toolkit.constants import FOUND_SONY_CUSTOM_LAYERS +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, KERAS_KERNEL, BIAS_ATTR, \ + KERAS_DEPTHWISE_KERNEL, BIAS + +if FOUND_SONY_CUSTOM_LAYERS: + from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess + +if version.parse(tf.__version__) >= version.parse("2.13"): + from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \ + MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \ + Conv2DTranspose +else: + from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \ + MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \ + Conv2DTranspose + +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tp_model import get_tp_model +import model_compression_toolkit as mct +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut import __version__ as TPC_VERSION + +tp = mct.target_platform + + +def get_keras_tpc() -> tp.TargetPlatformCapabilities: + """ + get a Keras TargetPlatformCapabilities object with default operation sets to layers mapping. + Returns: a Keras TargetPlatformCapabilities object for the given TargetPlatformModel. + """ + imx500_tpc_tp_model = get_tp_model() + return generate_keras_tpc(name='imx500_tpc_keras_tpc', tp_model=imx500_tpc_tp_model) + + +def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel): + """ + Generates a TargetPlatformCapabilities object with default operation sets to layers mapping. + + Args: + name: Name of the TargetPlatformCapabilities. + tp_model: TargetPlatformModel object. + + Returns: a TargetPlatformCapabilities object for the given TargetPlatformModel. + """ + + keras_tpc = tp.TargetPlatformCapabilities(tp_model, name=name, version=TPC_VERSION) + + no_quant_list = [Reshape, + tf.reshape, + Permute, + tf.transpose, + Flatten, + Cropping2D, + ZeroPadding2D, + Dropout, + MaxPooling2D, + tf.split, + tf.quantization.fake_quant_with_min_max_vars, + tf.math.argmax, + tf.shape, + tf.math.equal, + tf.gather, + tf.cast, + tf.unstack, + tf.compat.v1.gather, + tf.nn.top_k, + tf.__operators__.getitem, + tf.image.combined_non_max_suppression, + tf.compat.v1.shape] + + if FOUND_SONY_CUSTOM_LAYERS: + no_quant_list.append(SSDPostProcess) + + with keras_tpc: + tp.OperationsSetToLayers("NoQuantization", no_quant_list) + + tp.OperationsSetToLayers("Conv", + [Conv2D, + DepthwiseConv2D, + Conv2DTranspose, + tf.nn.conv2d, + tf.nn.depthwise_conv2d, + tf.nn.conv2d_transpose], + # we provide attributes mapping that maps each layer type in the operations set + # that has weights attributes with provided quantization config (in the tp model) to + # its framework-specific attribute name. + # note that a DefaultDict should be provided if not all the layer types in the + # operation set are provided separately in the mapping. + attr_mapping={ + KERNEL_ATTR: DefaultDict({ + DepthwiseConv2D: KERAS_DEPTHWISE_KERNEL, + tf.nn.depthwise_conv2d: KERAS_DEPTHWISE_KERNEL}, default_value=KERAS_KERNEL), + BIAS_ATTR: DefaultDict(default_value=BIAS)}) + tp.OperationsSetToLayers("FullyConnected", [Dense], + attr_mapping={KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL), + BIAS_ATTR: DefaultDict(default_value=BIAS)}) + tp.OperationsSetToLayers("AnyReLU", [tf.nn.relu, + tf.nn.relu6, + tf.nn.leaky_relu, + ReLU, + LeakyReLU, + tp.LayerFilterParams(Activation, activation="relu"), + tp.LayerFilterParams(Activation, activation="leaky_relu")]) + tp.OperationsSetToLayers("Add", [tf.add, Add]) + tp.OperationsSetToLayers("Sub", [tf.subtract, Subtract]) + tp.OperationsSetToLayers("Mul", [tf.math.multiply, Multiply]) + tp.OperationsSetToLayers("Div", [tf.math.divide]) + tp.OperationsSetToLayers("PReLU", [PReLU]) + tp.OperationsSetToLayers("Swish", [tf.nn.swish, tp.LayerFilterParams(Activation, activation="swish")]) + tp.OperationsSetToLayers("Sigmoid", [tf.nn.sigmoid, tp.LayerFilterParams(Activation, activation="sigmoid")]) + tp.OperationsSetToLayers("Tanh", [tf.nn.tanh, tp.LayerFilterParams(Activation, activation="tanh")]) + + return keras_tpc diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_pytorch.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_pytorch.py new file mode 100644 index 000000000..364f1cfad --- /dev/null +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_pytorch.py @@ -0,0 +1,110 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import operator + +import torch +from torch import add, sub, mul, div, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, chunk, unbind, topk, \ + gather, equal, transpose, permute, argmax, squeeze +from torch.nn import Conv2d, Linear, BatchNorm2d, ConvTranspose2d +from torch.nn import Dropout, Flatten, Hardtanh +from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, LeakyReLU +from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, leaky_relu + +from model_compression_toolkit.defaultdict import DefaultDict +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, PYTORCH_KERNEL, BIAS_ATTR, \ + BIAS +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tp_model import get_tp_model +import model_compression_toolkit as mct +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut import __version__ as TPC_VERSION + +tp = mct.target_platform + + +def get_pytorch_tpc() -> tp.TargetPlatformCapabilities: + """ + get a Pytorch TargetPlatformCapabilities object with default operation sets to layers mapping. + Returns: a Pytorch TargetPlatformCapabilities object for the given TargetPlatformModel. + """ + imx500_tpc_tp_model = get_tp_model() + return generate_pytorch_tpc(name='imx500_tpc_pytorch_tpc', tp_model=imx500_tpc_tp_model) + + +def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel): + """ + Generates a TargetPlatformCapabilities object with default operation sets to layers mapping. + Args: + name: Name of the TargetPlatformModel. + tp_model: TargetPlatformModel object. + Returns: a TargetPlatformCapabilities object for the given TargetPlatformModel. + """ + + pytorch_tpc = tp.TargetPlatformCapabilities(tp_model, + name=name, + version=TPC_VERSION) + + # we provide attributes mapping that maps each layer type in the operations set + # that has weights attributes with provided quantization config (in the tp model) to + # its framework-specific attribute name. + # note that a DefaultDict should be provided if not all the layer types in the + # operation set are provided separately in the mapping. + pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL), + BIAS_ATTR: DefaultDict(default_value=BIAS)} + + with pytorch_tpc: + tp.OperationsSetToLayers("NoQuantization", [Dropout, + Flatten, + dropout, + flatten, + split, + operator.getitem, + reshape, + unsqueeze, + BatchNorm2d, + chunk, + unbind, + torch.Tensor.size, + permute, + transpose, + equal, + argmax, + gather, + topk, + squeeze]) + + tp.OperationsSetToLayers("Conv", [Conv2d, ConvTranspose2d], + attr_mapping=pytorch_linear_attr_mapping) + tp.OperationsSetToLayers("FullyConnected", [Linear], + attr_mapping=pytorch_linear_attr_mapping) + tp.OperationsSetToLayers("AnyReLU", [torch.relu, + ReLU, + ReLU6, + LeakyReLU, + relu, + relu6, + leaky_relu, + tp.LayerFilterParams(Hardtanh, min_val=0), + tp.LayerFilterParams(hardtanh, min_val=0)]) + + tp.OperationsSetToLayers("Add", [operator.add, add]) + tp.OperationsSetToLayers("Sub", [operator.sub, sub]) + tp.OperationsSetToLayers("Mul", [operator.mul, mul]) + tp.OperationsSetToLayers("Div", [operator.truediv, div]) + tp.OperationsSetToLayers("PReLU", [PReLU, prelu]) + tp.OperationsSetToLayers("Swish", [SiLU, silu, Hardswish, hardswish]) + tp.OperationsSetToLayers("Sigmoid", [Sigmoid, sigmoid]) + tp.OperationsSetToLayers("Tanh", [Tanh, tanh]) + + return pytorch_tpc diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/metadata_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/metadata_test.py new file mode 100644 index 000000000..4283e3002 --- /dev/null +++ b/tests/keras_tests/feature_networks_tests/feature_networks/metadata_test.py @@ -0,0 +1,43 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import tensorflow as tf +import numpy as np + +import model_compression_toolkit as mct +from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest +from mct_quantizers.keras.metadata import add_metadata, get_metadata + +from model_compression_toolkit.constants import TENSORFLOW +from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL + +keras = tf.keras +layers = keras.layers +tp = mct.target_platform + + +class MetadataTest(BaseKerasFeatureNetworkTest): + + def get_tpc(self): + return mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, "v2") + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + output = tf.add(inputs, inputs) + return tf.keras.models.Model(inputs=inputs, outputs=output) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + self.unit_test.assertTrue(len(get_metadata(quantized_model)) > 0, + msg='A model quantized with TPC IMX500.v2 should have a metadata.') diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index 88edee49e..8c32621a8 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -129,6 +129,7 @@ MixedPercisionSearchLastLayerDistanceTest, MixedPercisionSearchActivationNonConfNodesTest, \ MixedPercisionSearchTotalMemoryNonConfNodesTest, MixedPercisionSearchPartWeightsLayersTest, MixedPercisionCombinedNMSTest from tests.keras_tests.feature_networks_tests.feature_networks.matmul_substitution_test import MatmulToDenseSubstitutionTest +from tests.keras_tests.feature_networks_tests.feature_networks.metadata_test import MetadataTest from tests.keras_tests.feature_networks_tests.feature_networks.const_representation_test import ConstRepresentationTest, \ ConstRepresentationMultiInputTest, ConstRepresentationMatMulTest from tests.keras_tests.feature_networks_tests.feature_networks.concatination_threshold_update import ConcatThresholdtest @@ -744,6 +745,9 @@ def test_bn_attributes_quantization(self): def concat_threshold_test(self): ConcatThresholdtest(self).run_test() + def test_metadata(self): + MetadataTest(self).run_test() + if __name__ == '__main__': unittest.main() diff --git a/tests/keras_tests/function_tests/test_unsupported_custom_layer.py b/tests/keras_tests/function_tests/test_unsupported_custom_layer.py index 56e71f595..05972fef3 100644 --- a/tests/keras_tests/function_tests/test_unsupported_custom_layer.py +++ b/tests/keras_tests/function_tests/test_unsupported_custom_layer.py @@ -39,9 +39,8 @@ def test_raised_error_with_custom_layer(self): x = CustomIdentity()(inputs) model = keras.Model(inputs=inputs, outputs=x) - expected_error = f'MCT does not support optimizing Keras custom layers, but found layer of type . Please add the custom layer to TPC ' \ - f'or file a feature request or an issue if you believe this is an issue.' + expected_error = f"MCT does not support optimizing Keras custom layers. Found a layer of type . " \ + f" Please add the custom layer to Target Platform Capabilities (TPC), or file a feature request or an issue if you believe this should be supported." def rep_dataset(): yield [np.random.randn(1, 3, 3, 3)] diff --git a/tests/pytorch_tests/model_tests/feature_models/metadata_test.py b/tests/pytorch_tests/model_tests/feature_models/metadata_test.py new file mode 100644 index 000000000..4f57d1e3b --- /dev/null +++ b/tests/pytorch_tests/model_tests/feature_models/metadata_test.py @@ -0,0 +1,49 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import torch +import torch.nn as nn +import numpy as np +import model_compression_toolkit as mct +from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model +from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest +from tests.common_tests.helpers.tensors_compare import cosine_similarity +from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL +from model_compression_toolkit.constants import PYTORCH +from mct_quantizers import PytorchQuantizationWrapper +from mct_quantizers.pytorch.metadata import add_metadata, get_metadata, add_onnx_metadata, get_onnx_metadata + +tp = mct.target_platform + + +class DummyNet(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 13, 1) + + def forward(self, x): + return self.conv(x) + + +class MetadataTest(BasePytorchFeatureNetworkTest): + + def get_tpc(self): + return mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, "v2") + + def create_networks(self): + return DummyNet() + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + self.unit_test.assertTrue(len(get_metadata(quantized_model)) > 0, + msg='A model quantized with TPC IMX500.v2 should have a metadata.') diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index 75cb6bb56..07546c0ae 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -85,6 +85,7 @@ GPTQLearnRateZeroTest from tests.pytorch_tests.model_tests.feature_models.uniform_activation_test import \ UniformActivationTest +from tests.pytorch_tests.model_tests.feature_models.metadata_test import MetadataTest from tests.pytorch_tests.model_tests.feature_models.const_representation_test import ConstRepresentationTest, \ ConstRepresentationMultiInputTest from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod @@ -572,7 +573,6 @@ def test_qat(self): QuantizationAwareTrainingMixedPrecisionCfgTest(self).run_test() QuantizationAwareTrainingMixedPrecisionRUCfgTest(self).run_test() - def test_bn_attributes_quantization(self): """ This test checks the quantization of BatchNorm layer attributes. @@ -583,6 +583,9 @@ def test_bn_attributes_quantization(self): def test_concat_threshold_update(self): ConcatUpdateTest(self).run_test() + def test_metadata(self): + MetadataTest(self).run_test() + if __name__ == '__main__': unittest.main()