Add metadata to quantized models in PTQ & GPTQ facades. (#1034)
Add metadata to quantized models in PTQ & GPTQ facades. Enabled from the TPC.
Added imx500.v2 TPCs that enable metadata.
elad-c authored Apr 11, 2024
1 parent 5bc899f commit e295990
Showing 29 changed files with 1,154 additions and 28 deletions.
4 changes: 4 additions & 0 deletions model_compression_toolkit/constants.py
@@ -24,6 +24,10 @@
FOUND_ONNXRUNTIME = importlib.util.find_spec("onnxruntime") is not None
FOUND_SONY_CUSTOM_LAYERS = importlib.util.find_spec('sony_custom_layers') is not None

# Metadata fields
MCT_VERSION = 'mct_version'
TPC_VERSION = 'tpc_version'

WEIGHTS_SIGNED = True
# Minimal threshold to use for quantization ranges:
MIN_THRESHOLD = (2 ** -16)
5 changes: 3 additions & 2 deletions model_compression_toolkit/core/common/graph/base_graph.py
@@ -103,9 +103,10 @@ def set_tpc(self,
if n.is_custom:
if not is_node_in_tpc:
Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. '
f' Please add the custom layer to Target Platform Capabilities (TPC), or file a feature request or an issue if you believe this should be supported.')
' Please add the custom layer to Target Platform Capabilities (TPC), or file a feature '
'request or an issue if you believe this should be supported.') # pragma: no cover
if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(tpc).quantization_config_list]):
Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.')
Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.') # pragma: no cover

self.tpc = tpc

@@ -100,7 +100,7 @@ def _unwrap_quantize_wrapper(layer: Layer):
weights_list.append(layer.get_quantized_weights()['kernel'])
else:
Logger.critical(f'KerasQuantizationWrapper should wrap only DepthwiseConv2D, Conv2D, Dense'
f' and Conv2DTranspose layers but wrapped layer is {layer.layer}')
f' and Conv2DTranspose layers but wrapped layer is {layer.layer}')

if layer.layer.bias is not None:
weights_list.append(layer.layer.bias)
@@ -121,6 +121,11 @@ def _unwrap_quantize_wrapper(layer: Layer):

return layer

# Delete metadata layer if exists
if hasattr(self.model, 'metadata_layer'):
Logger.info('Metadata is not exported to FakelyQuant models.')
delattr(self.model, 'metadata_layer')

# clone each layer in the model and apply _unwrap_quantize_wrapper to layers wrapped with a QuantizeWrapper.
self.exported_model = tf.keras.models.clone_model(self.model,
input_tensors=None,
@@ -56,6 +56,11 @@ def export(self):
(namely, weights that are in fake-quant format) and fake-quant layers for the activations.
"""
# Delete metadata layer if exists
if hasattr(self.model, 'metadata_layer'):
Logger.info('Metadata is not exported to TFLite models.')
delattr(self.model, 'metadata_layer')

# Use Keras exporter to quantize model's weights before converting it to TFLite.
# Since exporter saves the model, we use a tmp path for saving, and then we delete it.
handle, tmp_file = tempfile.mkstemp(DEFAULT_KERAS_EXPORT_EXTENTION)
@@ -166,6 +166,11 @@ def _substitute_model(layer_to_substitue: keras.layers.Layer) -> keras.layers.La

return layer_to_substitue

# Delete metadata layer if exists
if hasattr(self.model, 'metadata_layer'):
Logger.info('Metadata is not exported to TFLite models.')
delattr(self.model, 'metadata_layer')

# Transform the model to a new model that can be converted to int8 models.
# For example: replace dense layers with point-wise layers (to support per-channel quantization)
self.transformed_model = clone_model(self.model,
@@ -13,17 +13,21 @@
# limitations under the License.
# ==============================================================================
from typing import Callable
from io import BytesIO

import torch.nn
import onnx

from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter
from mct_quantizers import pytorch_quantizers
from mct_quantizers.pytorch.metadata import add_onnx_metadata

DEFAULT_ONNX_OPSET_VERSION=15


class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
"""
Exporter for fakely-quant PyTorch models.
@@ -58,7 +62,6 @@ def __init__(self,
self._use_onnx_custom_quantizer_ops = use_onnx_custom_quantizer_ops
self._onnx_opset_version = onnx_opset_version


def export(self) -> None:
"""
Convert an exportable (fully-quantized) PyTorch model to a fakely-quant model
@@ -74,7 +77,7 @@ def export(self) -> None:
# If _use_onnx_custom_quantizer_ops is set to True, the quantizer forward function will use
# the custom implementation when exporting the operator into onnx model. If not, it removes the
# wraps and quantizes the ops in place (for weights, for activation torch quantization function is
# exported since it's used during forward.
# exported since it's used during forward).
if self._use_onnx_custom_quantizer_ops:
self._enable_onnx_custom_ops_export()
else:
@@ -87,15 +90,30 @@

model_input = to_torch_tensor(next(self.repr_dataset())[0])

torch.onnx.export(self.model,
model_input,
self.save_model_path,
opset_version=self._onnx_opset_version,
verbose=False,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'},
'output': {0: 'batch_size'}})
if hasattr(self.model, 'metadata'):
onnx_bytes = BytesIO()
torch.onnx.export(self.model,
model_input,
onnx_bytes,
opset_version=self._onnx_opset_version,
verbose=False,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'},
'output': {0: 'batch_size'}})
onnx_model = onnx.load_from_string(onnx_bytes.getvalue())
onnx_model = add_onnx_metadata(onnx_model, self.model.metadata)
onnx.save_model(onnx_model, self.save_model_path)
else:
torch.onnx.export(self.model,
model_input,
self.save_model_path,
opset_version=self._onnx_opset_version,
verbose=False,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'},
'output': {0: 'batch_size'}})
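
Once saved, the embedded properties can be read back with the standard onnx API. A minimal sanity-check sketch (the output file name below is a placeholder):

import onnx

onnx_model = onnx.load('quantized_model.onnx')  # placeholder path
for prop in onnx_model.metadata_props:  # repeated StringStringEntryProto entries
    print(f'{prop.key}: {prop.value}')  # expect mct_version / tpc_version keys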

def _enable_onnx_custom_ops_export(self):
"""
@@ -40,7 +40,7 @@ def pytorch_export_model(model: torch.nn.Module,
repr_dataset: Callable,
is_layer_exportable_fn: Callable = is_pytorch_layer_exportable,
serialization_format: PytorchExportSerializationFormat = PytorchExportSerializationFormat.ONNX,
quantization_format : QuantizationFormat = QuantizationFormat.MCTQ,
quantization_format: QuantizationFormat = QuantizationFormat.MCTQ,
onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION) -> None:
"""
Export a PyTorch quantized model to a torchscript or onnx model.
@@ -93,4 +93,4 @@ def get_exportable_pytorch_model(graph: Graph):
else:
def get_exportable_pytorch_model(*args, **kwargs):
Logger.critical("PyTorch must be installed to use 'get_exportable_pytorch_model'. "
"The 'torch' package is missing.") # pragma: no cover
"The 'torch' package is missing.") # pragma: no cover
7 changes: 6 additions & 1 deletion model_compression_toolkit/gptq/keras/quantization_facade.py
@@ -31,6 +31,7 @@
from model_compression_toolkit.gptq.runner import gptq_runner
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
from model_compression_toolkit.metadata import get_versions_dict

LR_DEFAULT = 0.15
LR_REST_DEFAULT = 1e-4
@@ -48,6 +49,7 @@
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
from model_compression_toolkit import get_target_platform_capabilities
from mct_quantizers.keras.metadata import add_metadata

# As from TF2.9 optimizers package is changed
if version.parse(tf.__version__) < version.parse("2.9"):
@@ -234,7 +236,10 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da
fw_impl,
DEFAULT_KERAS_INFO)

return get_exportable_keras_model(tg_gptq)
exportable_model, user_info = get_exportable_keras_model(tg_gptq)
if target_platform_capabilities.tp_model.add_metadata:
exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities))
return exportable_model, user_info
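
For context, a hedged usage sketch of the new opt-in path: metadata is attached only when the TPC's tp_model sets add_metadata, e.g. the new imx500 v2 TPC (float_model and representative_data_gen are placeholders, and the mct.gptq namespace is assumed to re-export this facade):

import model_compression_toolkit as mct

tpc_v2 = mct.get_target_platform_capabilities('tensorflow', 'imx500', target_platform_version='v2')
quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(
    float_model, representative_data_gen,
    gptq_config=mct.gptq.get_keras_gptq_config(n_epochs=5),
    target_platform_capabilities=tpc_v2)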

else:
# If tensorflow is not installed,
@@ -31,6 +31,7 @@
from model_compression_toolkit.core import CoreConfig
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
MixedPrecisionQuantizationConfig
from model_compression_toolkit.metadata import get_versions_dict

LR_DEFAULT = 1e-4
LR_REST_DEFAULT = 1e-4
@@ -47,6 +48,7 @@
from torch.nn import Module
from torch.optim import Adam, Optimizer
from model_compression_toolkit import get_target_platform_capabilities
from mct_quantizers.pytorch.metadata import add_metadata
DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)

def get_pytorch_gptq_config(n_epochs: int,
@@ -202,7 +204,10 @@ def pytorch_gradient_post_training_quantization(model: Module,
fw_impl,
DEFAULT_PYTORCH_INFO)

return get_exportable_pytorch_model(graph_gptq)
exportable_model, user_info = get_exportable_pytorch_model(graph_gptq)
if target_platform_capabilities.tp_model.add_metadata:
exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities))
return exportable_model, user_info
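
The PyTorch facade mirrors the Keras one; a hedged sketch with the same placeholders (get_pytorch_gptq_config is the helper defined earlier in this file):

import model_compression_toolkit as mct

tpc_v2 = mct.get_target_platform_capabilities('pytorch', 'imx500', target_platform_version='v2')
quantized_model, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization(
    float_model, representative_data_gen,
    gptq_config=mct.gptq.get_pytorch_gptq_config(n_epochs=5),
    target_platform_capabilities=tpc_v2)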


else:
29 changes: 29 additions & 0 deletions model_compression_toolkit/metadata.py
@@ -0,0 +1,29 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Dict
from model_compression_toolkit.constants import MCT_VERSION, TPC_VERSION


def get_versions_dict(tpc) -> Dict:
"""
Returns: A dictionary with TPC and MCT versions.
"""
# imported inside to avoid circular import error
from model_compression_toolkit import __version__ as mct_version
tpc_version = f'{tpc.name}.{tpc.version}'
return {MCT_VERSION: mct_version, TPC_VERSION: tpc_version}
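
Illustratively, for the imx500 v2 TPC this is assumed to produce (the MCT version number is a placeholder):

get_versions_dict(tpc)
# -> {'mct_version': '2.0.0', 'tpc_version': 'imx500.v2'}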
8 changes: 6 additions & 2 deletions model_compression_toolkit/ptq/keras/quantization_facade.py
@@ -28,6 +28,7 @@
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
from model_compression_toolkit.core.runner import core_runner
from model_compression_toolkit.ptq.runner import ptq_runner
from model_compression_toolkit.metadata import get_versions_dict

if FOUND_TF:
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
@@ -38,6 +39,7 @@
from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model

from model_compression_toolkit import get_target_platform_capabilities
from mct_quantizers.keras.metadata import add_metadata
DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)


@@ -164,8 +166,10 @@ def keras_post_training_quantization(in_model: Model,
fw_impl,
fw_info)

return get_exportable_keras_model(graph_with_stats_correction)

exportable_model, user_info = get_exportable_keras_model(graph_with_stats_correction)
if target_platform_capabilities.tp_model.add_metadata:
exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities))
return exportable_model, user_info
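
End to end, the attached metadata should be retrievable from the returned model; a hedged sketch assuming mct_quantizers ships a get_metadata reader next to add_metadata:

import model_compression_toolkit as mct
from mct_quantizers.keras.metadata import get_metadata  # assumed counterpart to add_metadata

tpc_v2 = mct.get_target_platform_capabilities('tensorflow', 'imx500', target_platform_version='v2')
quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(
    float_model, representative_data_gen, target_platform_capabilities=tpc_v2)
print(get_metadata(quantized_model))  # e.g. {'mct_version': ..., 'tpc_version': 'imx500.v2'}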


else:
7 changes: 6 additions & 1 deletion model_compression_toolkit/ptq/pytorch/quantization_facade.py
@@ -29,6 +29,7 @@
from model_compression_toolkit.ptq.runner import ptq_runner
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
from model_compression_toolkit.metadata import get_versions_dict


if FOUND_TORCH:
@@ -38,6 +39,7 @@
from torch.nn import Module
from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
from model_compression_toolkit import get_target_platform_capabilities
from mct_quantizers.pytorch.metadata import add_metadata

DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)

@@ -139,7 +141,10 @@ def pytorch_post_training_quantization(in_module: Module,
fw_impl,
fw_info)

return get_exportable_pytorch_model(graph_with_stats_correction)
exportable_model, user_info = get_exportable_pytorch_model(graph_with_stats_correction)
if target_platform_capabilities.tp_model.add_metadata:
exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities))
return exportable_model, user_info
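
On the PyTorch side, add_metadata is assumed to attach the dict as a plain attribute on the returned module (the ONNX exporter above gates on hasattr(model, 'metadata')), so it can be inspected directly:

import model_compression_toolkit as mct

tpc_v2 = mct.get_target_platform_capabilities('pytorch', 'imx500', target_platform_version='v2')
quantized_model, quantization_info = mct.ptq.pytorch_post_training_quantization(
    float_model, representative_data_gen, target_platform_capabilities=tpc_v2)
if hasattr(quantized_model, 'metadata'):
    print(quantized_model.metadata)  # {'mct_version': ..., 'tpc_version': 'imx500.v2'}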


else:
@@ -60,15 +60,18 @@ class TargetPlatformModel(ImmutableClass):

def __init__(self,
default_qco: QuantizationConfigOptions,
add_metadata: bool = False,
name="default_tp_model"):
"""
Args:
default_qco (QuantizationConfigOptions): Default QuantizationConfigOptions to use for operators that their QuantizationConfigOptions are not defined in the model.
add_metadata (bool): Whether to add metadata to the model or not.
name (str): Name of the model.
"""

super().__init__()
self.add_metadata = add_metadata
self.name = name
self.operator_set = []
assert isinstance(default_qco, QuantizationConfigOptions)
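
A hedged sketch of how a TPC author opts in (default_config_options stands in for a QuantizationConfigOptions built elsewhere):

import model_compression_toolkit as mct
tp = mct.target_platform

tp_model = tp.TargetPlatformModel(default_config_options,  # assumed pre-built QuantizationConfigOptions
                                  add_metadata=True,       # new flag: facades will stamp version metadata
                                  name='my_tp_model')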
@@ -191,7 +194,7 @@ def __validate_model(self):
"""
opsets_names = [op.name for op in self.operator_set]
if (len(set(opsets_names)) != len(opsets_names)):
if len(set(opsets_names)) != len(opsets_names):
Logger.critical(f'Operator Sets must have unique names.')

def get_default_config(self) -> OpQuantizationConfig:
@@ -29,6 +29,7 @@
from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSetBase
from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import TargetPlatformModel
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
from model_compression_toolkit.constants import MCT_VERSION, TPC_VERSION


class TargetPlatformCapabilities(ImmutableClass):
@@ -25,11 +25,15 @@
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v1_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_keras import get_keras_tpc as get_keras_tpc_v1_pot
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_keras import get_keras_tpc as get_keras_tpc_v2
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v2_lut

# Keras: TPC versioning
keras_tpc_models_dict = {'v1': get_keras_tpc_v1(),
'v1_lut': get_keras_tpc_v1_lut(),
'v1_pot': get_keras_tpc_v1_pot(),
'v2': get_keras_tpc_v2(),
'v2_lut': get_keras_tpc_v2_lut(),
LATEST: get_keras_tpc_latest()}

###############################
@@ -42,13 +46,19 @@
get_pytorch_tpc as get_pytorch_tpc_v1
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v1_pot
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_pytorch import get_pytorch_tpc as get_pytorch_tpc_v1_lut

from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v1_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v2
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v2_lut

# Pytorch: TPC versioning
pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1(),
'v1_lut': get_pytorch_tpc_v1_lut(),
'v1_pot': get_pytorch_tpc_v1_pot(),
'v2': get_pytorch_tpc_v2(),
'v2_lut': get_pytorch_tpc_v2_lut(),
LATEST: get_pytorch_tpc_latest()}
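
With the v2 entries registered, the metadata-enabled TPCs are assumed to be reachable through the public lookup by version string:

import model_compression_toolkit as mct

keras_tpc_v2 = mct.get_target_platform_capabilities('tensorflow', 'imx500', target_platform_version='v2')
pytorch_tpc_v2 = mct.get_target_platform_capabilities('pytorch', 'imx500', target_platform_version='v2')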

tpc_dict = {TENSORFLOW: keras_tpc_models_dict,
@@ -0,0 +1,16 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

__version__ = 'v2'
