Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move MP target_kpi from facade to MixedPrecisionQuantizationConfig #990

Merged
merged 8 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
from typing import List, Callable

from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI


class MixedPrecisionQuantizationConfig:

def __init__(self,
target_kpi: KPI = None,
compute_distance_fn: Callable = None,
distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG,
num_of_images: int = 32,
Expand All @@ -34,6 +36,7 @@ def __init__(self,
Class with mixed precision parameters to quantize the input model.

Args:
target_kpi (KPI): KPI to constrain the search of the mixed-precision configuration for the model.
compute_distance_fn (Callable): Function to compute a distance between two tensors. If None, using pre-defined distance methods based on the layer type for each layer.
distance_weighting_method (MpDistanceWeighting): MpDistanceWeighting enum value that provides a function to use when weighting the distances among different layers when computing the sensitivity metric.
num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model.
Expand All @@ -46,6 +49,7 @@ def __init__(self,

"""

self.target_kpi = target_kpi
self.compute_distance_fn = compute_distance_fn
self.distance_weighting_method = distance_weighting_method
self.num_of_images = num_of_images
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class BitWidthSearchMethod(Enum):
def search_bit_width(graph_to_search_cfg: Graph,
fw_info: FrameworkInfo,
fw_impl: FrameworkImplementation,
target_kpi: KPI,
mp_config: MixedPrecisionQuantizationConfig,
representative_data_gen: Callable,
search_method: BitWidthSearchMethod = BitWidthSearchMethod.INTEGER_PROGRAMMING,
Expand All @@ -64,7 +63,6 @@ def search_bit_width(graph_to_search_cfg: Graph,
graph_to_search_cfg: Graph to search a MP configuration for.
fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
fw_impl: FrameworkImplementation object with specific framework methods implementation.
target_kpi: Target KPI to bound our feasible solution space s.t the configuration does not violate it.
mp_config: Mixed-precision quantization configuration.
representative_data_gen: Dataset to use for retrieving images for the models inputs.
search_method: BitWidthSearchMethod to define which searching method to use.
Expand All @@ -76,6 +74,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
bit-width index on the node).

"""
target_kpi = mp_config.target_kpi

# target_kpi has to be passed. If it was not passed, the facade is not supposed to get here by now.
if target_kpi is None:
Expand Down
10 changes: 4 additions & 6 deletions model_compression_toolkit/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def core_runner(in_model: Any,
fw_info: FrameworkInfo,
fw_impl: FrameworkImplementation,
tpc: TargetPlatformCapabilities,
target_kpi: KPI = None,
tb_w: TensorboardWriter = None):
"""
Quantize a trained model using post-training quantization.
Expand All @@ -67,7 +66,6 @@ def core_runner(in_model: Any,
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
tpc: TargetPlatformCapabilities object that models the inference target platform and
the attached framework operator's information.
target_kpi: KPI to constraint the search of the mixed-precision configuration for the model.
tb_w: TensorboardWriter object for logging

Returns:
Expand Down Expand Up @@ -106,14 +104,14 @@ def core_runner(in_model: Any,
######################################
# Finalize bit widths
######################################
if target_kpi is not None:
assert core_config.mixed_precision_enable
if core_config.mixed_precision_enable:
if core_config.mixed_precision_config.target_kpi is None:
Logger.critical(f"Trying to run Mixed Precision quantization without providing a valid target KPI.")
if core_config.mixed_precision_config.configuration_overwrite is None:

bit_widths_config = search_bit_width(tg,
fw_info,
fw_impl,
target_kpi,
core_config.mixed_precision_config,
representative_data_gen,
hessian_info_service=hessian_info_service)
Expand All @@ -139,7 +137,7 @@ def core_runner(in_model: Any,
fw_info=fw_info,
fw_impl=fw_impl)

if target_kpi is not None:
if core_config.mixed_precision_enable:
# Retrieve lists of tuples (node, node's final weights/activation bitwidth)
weights_conf_nodes_bitwidth = tg.get_final_weights_config(fw_info)
activation_conf_nodes_bitwidth = tg.get_final_activation_config()
Expand Down
17 changes: 7 additions & 10 deletions model_compression_toolkit/gptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def keras_gradient_post_training_quantization(in_model: Model,
representative_data_gen: Callable,
gptq_config: GradientPTQConfig,
gptq_representative_data_gen: Callable = None,
target_kpi: KPI = None,
core_config: CoreConfig = CoreConfig(),
fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
Expand All @@ -142,7 +141,6 @@ def keras_gradient_post_training_quantization(in_model: Model,
representative_data_gen (Callable): Dataset used for calibration.
gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
Expand Down Expand Up @@ -171,26 +169,26 @@ def keras_gradient_post_training_quantization(in_model: Model,

>>> config = mct.core.CoreConfig()

If mixed precision is desired, create an MCT core config with a mixed-precision configuration, to quantize a model
with different bitwidths for different layers.
The candidates bitwidth for quantization should be defined in the target platform model:

>>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))

For mixed-precision set a target KPI object:
Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
while the bias will not):

>>> kpi = mct.core.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.

If mixed precision is desired, create an MCT core config with a mixed-precision configuration, to quantize a model
with different bitwidths for different layers.
The candidates bitwidth for quantization should be defined in the target platform model:

>>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1, target_kpi=kpi))

Create GPTQ config:

>>> gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=1)

Pass the model with the representative dataset generator to get a quantized model:

>>> quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(model, repr_datagen, gptq_config, target_kpi=kpi, core_config=config)
>>> quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(model, repr_datagen, gptq_config, core_config=config)

"""
KerasModelValidation(model=in_model,
Expand All @@ -212,7 +210,6 @@ def keras_gradient_post_training_quantization(in_model: Model,
fw_info=fw_info,
fw_impl=fw_impl,
tpc=target_platform_capabilities,
target_kpi=target_kpi,
tb_w=tb_w)

tg_gptq = gptq_runner(tg,
Expand Down
3 changes: 0 additions & 3 deletions model_compression_toolkit/gptq/pytorch/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def get_pytorch_gptq_config(n_epochs: int,

def pytorch_gradient_post_training_quantization(model: Module,
representative_data_gen: Callable,
target_kpi: KPI = None,
core_config: CoreConfig = CoreConfig(),
gptq_config: GradientPTQConfig = None,
gptq_representative_data_gen: Callable = None,
Expand All @@ -118,7 +117,6 @@ def pytorch_gradient_post_training_quantization(model: Module,
Args:
model (Module): Pytorch model to quantize.
representative_data_gen (Callable): Dataset used for calibration.
target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
Expand Down Expand Up @@ -176,7 +174,6 @@ def pytorch_gradient_post_training_quantization(model: Module,
fw_info=DEFAULT_PYTORCH_INFO,
fw_impl=fw_impl,
tpc=target_platform_capabilities,
target_kpi=target_kpi,
tb_w=tb_w)

# ---------------------- #
Expand Down
3 changes: 0 additions & 3 deletions model_compression_toolkit/ptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

def keras_post_training_quantization(in_model: Model,
representative_data_gen: Callable,
target_kpi: KPI = None,
core_config: CoreConfig = CoreConfig(),
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
"""
Expand All @@ -61,7 +60,6 @@ def keras_post_training_quantization(in_model: Model,
Args:
in_model (Model): Keras model to quantize.
representative_data_gen (Callable): Dataset used for calibration.
target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.

Expand Down Expand Up @@ -137,7 +135,6 @@ def keras_post_training_quantization(in_model: Model,
fw_info=fw_info,
fw_impl=fw_impl,
tpc=target_platform_capabilities,
target_kpi=target_kpi,
tb_w=tb_w)

tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
Expand Down
3 changes: 0 additions & 3 deletions model_compression_toolkit/ptq/pytorch/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@

def pytorch_post_training_quantization(in_module: Module,
representative_data_gen: Callable,
target_kpi: KPI = None,
core_config: CoreConfig = CoreConfig(),
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
"""
Expand All @@ -60,7 +59,6 @@ def pytorch_post_training_quantization(in_module: Module,
Args:
in_module (Module): Pytorch module to quantize.
representative_data_gen (Callable): Dataset used for calibration.
target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.

Expand Down Expand Up @@ -109,7 +107,6 @@ def pytorch_post_training_quantization(in_module: Module,
fw_info=DEFAULT_PYTORCH_INFO,
fw_impl=fw_impl,
tpc=target_platform_capabilities,
target_kpi=target_kpi,
tb_w=tb_w)

tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_PYTORCH_INFO, fw_impl, tb_w)
Expand Down
3 changes: 0 additions & 3 deletions model_compression_toolkit/qat/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def qat_wrapper(n: common.BaseNode,

def keras_quantization_aware_training_init_experimental(in_model: Model,
representative_data_gen: Callable,
target_kpi: KPI = None,
core_config: CoreConfig = CoreConfig(),
qat_config: QATConfig = QATConfig(),
fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
Expand All @@ -110,7 +109,6 @@ def keras_quantization_aware_training_init_experimental(in_model: Model,
Args:
in_model (Model): Keras model to quantize.
representative_data_gen (Callable): Dataset used for initial calibration.
target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
qat_config (QATConfig): QAT configuration
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
Expand Down Expand Up @@ -195,7 +193,6 @@ def keras_quantization_aware_training_init_experimental(in_model: Model,
fw_info=fw_info,
fw_impl=fw_impl,
tpc=target_platform_capabilities,
target_kpi=target_kpi,
tb_w=tb_w)

tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
Expand Down
3 changes: 0 additions & 3 deletions model_compression_toolkit/qat/pytorch/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def qat_wrapper(n: common.BaseNode,

def pytorch_quantization_aware_training_init_experimental(in_model: Module,
representative_data_gen: Callable,
target_kpi: KPI = None,
core_config: CoreConfig = CoreConfig(),
qat_config: QATConfig = QATConfig(),
fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
Expand All @@ -98,7 +97,6 @@ def pytorch_quantization_aware_training_init_experimental(in_model: Module,
Args:
in_model (Model): Pytorch model to quantize.
representative_data_gen (Callable): Dataset used for initial calibration.
target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
qat_config (QATConfig): QAT configuration
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Pytorch info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/pytorch/default_framework_info.py>`_
Expand Down Expand Up @@ -162,7 +160,6 @@ def pytorch_quantization_aware_training_init_experimental(in_model: Module,
fw_info=DEFAULT_PYTORCH_INFO,
fw_impl=fw_impl,
tpc=target_platform_capabilities,
target_kpi=target_kpi,
tb_w=tb_w)

tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
Expand Down
1 change: 0 additions & 1 deletion tests/common_tests/base_feature_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def run_test(self):
core_config = self.get_core_config()
ptq_model, quantization_info = self.get_ptq_facade()(model_float,
self.representative_data_gen_experimental,
target_kpi=self.get_kpi(),
core_config=core_config,
target_platform_capabilities=self.get_tpc()
)
Expand Down
Loading
Loading