Commit: fix PR comments

elad-c committed Apr 18, 2024
1 parent 168ff53 commit 0e1b806
Showing 17 changed files with 159 additions and 53 deletions.
@@ -451,7 +451,6 @@ def get_inferable_quantizers(self, node: BaseNode):
         Returns:
             weight_quantizers: A dictionary between a weight's name to its quantizer.
             activation_quantizers: A list of activations quantization, one for each layer output.
-            weight_values: A dictionary between a weight's name to its value. Relevant for positional weights only.
         """
@@ -76,7 +76,9 @@ def type(self):
     def is_match_type(self, _type: Type) -> bool:
         """
         Check if input type matches the node type, either in instance type or in type name. Checking the
-        name string is required because of function types changes that occurred in TF 2.15.
+        name string is required because of function type changes introduced in TF 2.15, which replace
+        the "function" attribute object (e.g. a different tf.add function that will fail the
+        equality check).

         Args:
             _type: other node type
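To make the dual check concrete, here is a minimal sketch of the logic this docstring describes; the helper names and structure are illustrative assumptions, not the exact MCT implementation:

```python
import tensorflow as tf

def _type_name(t) -> str:
    # Illustrative helper (assumption): derive a comparable name for a type/function.
    return getattr(t, '__name__', str(t))

def is_match_type_sketch(node_type, other_type) -> bool:
    # Direct object comparison works up to TF 2.14.
    if other_type == node_type:
        return True
    # TF 2.15 swaps the TFOpLambda "function" attribute for a different object
    # representing the same op (e.g. tf.add), so fall back to comparing names.
    return _type_name(other_type) == _type_name(node_type)

print(is_match_type_sketch(tf.add, tf.add))  # True
```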
@@ -410,7 +410,6 @@ def __init__(self, qc: QuantizationConfig,
         for attr in node_attrs_list:
             if isinstance(attr, int):
                 # this is a positional attribute, so it needs to be handled separately.
-                # we assume that a positional attribute is quantized with the default configuration provided in the TPC.
                 self.pos_attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(qc=qc,
                                                                                          weights_attr_cfg=op_cfg.default_weight_attr_config,
                                                                                          weights_channels_axis=weights_channels_axis)
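For readers unfamiliar with positional attributes: a positional weight is a constant passed to a functional op by argument position rather than by name, so its key in the node's weights mapping is an int (the call-argument index). A toy illustration (the dict layout is an assumption for demonstration, not MCT's internal structure):

```python
import numpy as np

# Named attributes are keyed by string; a constant captured from, e.g.,
# tf.add(x, const) is keyed by its call-argument index (here, 1).
node_weights = {'kernel': np.ones((3, 3, 16, 16), dtype=np.float32),
                1: np.ones((16,), dtype=np.float32)}

for attr in node_weights:
    if isinstance(attr, int):
        # int keys mark positional weights -> the TPC's default_weight_attr_config applies
        print(f'positional weight at call-arg index {attr}')
```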
@@ -594,7 +594,6 @@ def get_inferable_quantizers(self, node: BaseNode):
         Returns:
             weight_quantizers: A dictionary between a weight's name to its quantizer.
             activation_quantizers: A list of activations quantization, one for each layer output.
-            weight_values: A dictionary between a weight's name to its value. Relevant for positional weights only.
         """

         def _weight_name(w: Union[str, int]) -> Union[str, int]:
13 changes: 10 additions & 3 deletions model_compression_toolkit/core/keras/reader/node_builder.py
@@ -61,6 +61,13 @@ def get_kwargs2index(tfoplambda_layer: TFOpLambda) -> Dict[str, int]:
     """
     Positional weights are saved according to their index in the node's call arguments, so
     need to know the function arguments' names in case the weights are in the kwargs.
+
+    Note: the kwargs2index dictionary is initialized manually (and not with tf_inspect) so
+    it will only include the arguments that may contain constants. For example, we don't
+    want the transpose_a attribute of tf.matmul to be saved as a constant.
+
+    Every operation we add support for needs to be added here.
+
     Args:
         tfoplambda_layer: TFOpLambda layer.
@@ -75,13 +82,13 @@ def get_kwargs2index(tfoplambda_layer: TFOpLambda) -> Dict[str, int]:
                     tf.pow: {'x': 0, 'y': 1},
                     tf.matmul: {'a': 0, 'b': 1}}.get(tfoplambda_layer.function)
     if not kwargs2index:
+        # In TF 2.15 the function attribute is different and doesn't match the original
+        # operation object we use. Therefore, we extract kwargs2index with the symbol.
         kwargs2index = {'__operators__.add': {'x': 0, 'y': 1},
                         'math.add': {'x': 0, 'y': 1},
                         'math.multiply': {'x': 0, 'y': 1},
                         'linalg.matmul': {'a': 0, 'b': 1},
-                        'concat': {'values': 0}}.get(tfoplambda_layer.symbol)
-    if not kwargs2index:
-        kwargs2index = {}
+                        'concat': {'values': 0}}.get(tfoplambda_layer.symbol, {})

     return kwargs2index
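A hedged usage sketch of the two-stage lookup above, with a stand-in for the TFOpLambda layer (in MCT the real layer comes from the Keras model reader):

```python
import tensorflow as tf

class FakeTFOpLambda:
    """Stand-in for a TFOpLambda layer, for illustration only."""
    function = tf.add    # pre-2.15 lookup key
    symbol = 'math.add'  # TF 2.15 fallback key

layer = FakeTFOpLambda()
kwargs2index = {tf.add: {'x': 0, 'y': 1},
                tf.matmul: {'a': 0, 'b': 1}}.get(layer.function, {})
if not kwargs2index:
    # TF 2.15 fallback: resolve by the op's symbol string instead of the object.
    kwargs2index = {'math.add': {'x': 0, 'y': 1}}.get(layer.symbol, {})

print(kwargs2index)  # {'x': 0, 'y': 1}: a const passed as kwarg y lands at index 1
```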
@@ -552,7 +552,6 @@ def get_inferable_quantizers(self, node: BaseNode):
         Returns:
             weight_quantizers: A dictionary between a weight's name to its quantizer.
             activation_quantizers: A list of activations quantization, one for each layer output.
-            weight_values: A dictionary between a weight's name to its value. Relevant for positional weights only.
         """

         return get_inferable_quantizers(node,
@@ -19,7 +19,7 @@
 def get_inferable_quantizers(node: BaseNode,
                              get_weights_quantizer_for_node: Callable,
                              get_activations_quantizer_for_node: Callable,
-                             attributes_names: List[str] = []) -> Tuple[Dict, List, Dict]:
+                             attributes_names: List[str] = []) -> Tuple[Dict, List]:
     """
     Create quantizers to wrap a layer for its corresponding node.
@@ -32,25 +32,18 @@ def get_inferable_quantizers(node: BaseNode,
     Returns:
         weight_quantizers: A dictionary between a weight's name to its quantizer.
         activation_quantizers: A list of activations quantization, one for each layer output.
-        weight_values: A dictionary between a weight's name to its value. Relevant for positional weights only.
     """

     weight_quantizers = {}
     activation_quantizers = []
-    weight_values = None

     for attr in attributes_names:
         if node.is_weights_quantization_enabled(attr):
             weight_quantizer = get_weights_quantizer_for_node(node, attr)
             weight_quantizers[attr] = weight_quantizer
-            if isinstance(attr, int):  # for positional weights we need to extract the weight's value.
-                if weight_values is None:
-                    weight_values = {attr: node.get_weights_by_keys(attr)}
-                else:
-                    weight_values[attr] = node.get_weights_by_keys(attr)

     if node.is_activation_quantization_enabled():
         num_of_outputs = len(node.output_shape) if isinstance(node.output_shape, list) else 1
         activation_quantizers = [get_activations_quantizer_for_node(node)] * num_of_outputs

-    return weight_quantizers, activation_quantizers, weight_values
+    return weight_quantizers, activation_quantizers
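The net effect of this refactor is that the helper now returns a two-tuple and no longer fetches positional weight values; callers that need them do the lookup themselves. A stubbed sketch of the new contract (the stub names are assumptions):

```python
def get_inferable_quantizers_stub(attrs):
    # Returns only quantizers now; no third weight_values element.
    weight_quantizers = {a: f'quantizer({a!r})' for a in attrs}
    activation_quantizers = ['activation_quantizer']
    return weight_quantizers, activation_quantizers

wq, aq = get_inferable_quantizers_stub(['kernel', 1])
# Positional values become the caller's responsibility (see the wrapper diffs below):
positional = [a for a in wq if isinstance(a, int)]
print(positional)  # [1]
```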
@@ -40,8 +40,11 @@ def _get_wrapper(node: common.BaseNode,
     Returns: Wrapped layer with weights quantizers and activation quantizers
     """

-    weights_quantizers, _, weights_values = fw_impl.get_inferable_quantizers(node)
+    weights_quantizers, _ = fw_impl.get_inferable_quantizers(node)
     if len(weights_quantizers) > 0:
+        # for positional weights we need to extract the weight's value.
+        weights_values = {attr: node.get_weights_by_keys(attr)
+                          for attr in weights_quantizers if isinstance(attr, int)}
         return KerasQuantizationWrapper(layer,
                                         weights_quantizers,
                                         weights_values)
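Only positional (int-keyed) weights need their values handed to the wrapper; named attributes already live on the Keras layer itself. A small sketch of what the comprehension above produces (the data is made up; node.get_weights_by_keys is mimicked by a dict lookup):

```python
import numpy as np

weights_quantizers = {'kernel': 'kernel_quantizer',  # named: value stays on the layer
                      1: 'const_quantizer'}          # positional: value must be supplied
node_weights = {1: np.array([0.5, 2.0], dtype=np.float32)}

weights_values = {attr: node_weights[attr]
                  for attr in weights_quantizers if isinstance(attr, int)}
print(weights_values)  # {1: array([0.5, 2. ], dtype=float32)}
```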
@@ -58,7 +61,7 @@ def get_activation_quantizer_holder(node: common.BaseNode, fw_impl) -> Callable:
     Returns:
         A ActivationQuantizationHolder layer for the node activation quantization.
     """
-    _, activation_quantizers, _ = fw_impl.get_inferable_quantizers(node)
+    _, activation_quantizers = fw_impl.get_inferable_quantizers(node)

     # Holder by definition uses a single quantizer for the activation quantization
     # thus we make sure this is the only possible case (unless it's a node with no activation
@@ -41,13 +41,15 @@ def fully_quantized_wrapper(node: common.BaseNode,
     Returns: Wrapped layer
     """

-    weight_quantizers, _, weights_values = fw_impl.get_inferable_quantizers(node)
+    weight_quantizers, _ = fw_impl.get_inferable_quantizers(node)
     if len(weight_quantizers) > 0:
-        if weights_values is not None:
-            weights_values = {k: fw_impl.to_tensor(v) for k, v in weights_values.items()}
+        # for positional weights we need to extract the weight's value.
+        weights_values = {attr: fw_impl.to_tensor(node.get_weights_by_keys(attr))
+                          for attr in weight_quantizers if isinstance(attr, int)}
         return PytorchQuantizationWrapper(module, weight_quantizers, weights_values)
     return module


 def get_activation_quantizer_holder(node: BaseNode, fw_impl) -> Callable:
     """
     Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node.
@@ -58,7 +60,7 @@ def get_activation_quantizer_holder(node: BaseNode, fw_impl) -> Callable:
     Returns:
         A PytorchActivationQuantizationHolder module for the node's activation quantization.
     """
-    _, activation_quantizers, _ = fw_impl.get_inferable_quantizers(node)
+    _, activation_quantizers = fw_impl.get_inferable_quantizers(node)
     # Holder by definition uses a single quantizer for the activation quantization
     # thus we make sure this is the only possible case (unless it's a node with no activation
     # quantization, which in this case has an empty list).
@@ -68,6 +70,7 @@ def get_activation_quantizer_holder(node: BaseNode, fw_impl) -> Callable:
                          f'PytorchActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
                          f'were found for node {node}')

+
 def get_exportable_pytorch_model(graph: Graph):
     """
     Convert graph to fully quantized PyTorch model.
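The PyTorch path mirrors the Keras one but converts each positional value with fw_impl.to_tensor. A sketch under the assumption that to_tensor behaves like torch.as_tensor (the actual MCT helper may also move data to the active device):

```python
import numpy as np
import torch

def to_tensor(v):
    # Assumed stand-in for fw_impl.to_tensor.
    return torch.as_tensor(v)

weight_quantizers = {1: 'const_quantizer'}
node_weights = {1: np.array([1.0, 2.0], dtype=np.float32)}

weights_values = {attr: to_tensor(node_weights[attr])
                  for attr in weight_quantizers if isinstance(attr, int)}
print(weights_values[1].dtype)  # torch.float32
```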
@@ -56,7 +56,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza

     # We define a default quantization config for all non-specified weights attributes.
     default_weight_attr_config = AttributeQuantizationConfig(
-        weights_quantization_method=tp.QuantizationMethod.SYMMETRIC,
+        weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
         weights_n_bits=8,
         weights_per_channel_threshold=False,
         enable_weights_quantization=False,
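The switch from SYMMETRIC to POWER_OF_TWO restricts the quantization threshold to a power of two, which is cheaper in hardware. A standard-textbook sketch of per-tensor PoT quantization (not MCT's internal quantizer):

```python
import numpy as np

def pot_threshold(x: np.ndarray) -> float:
    # Smallest power-of-two threshold covering the tensor's max magnitude.
    return float(2.0 ** np.ceil(np.log2(np.max(np.abs(x)))))

def quantize_pot(x: np.ndarray, n_bits: int = 8) -> np.ndarray:
    t = pot_threshold(x)              # per-tensor: weights_per_channel_threshold=False
    scale = t / (2 ** (n_bits - 1))   # signed uniform grid over [-t, t)
    q = np.clip(np.round(x / scale), -2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1)
    return q * scale

x = np.array([0.3, -1.7, 5.1], dtype=np.float32)
print(pot_threshold(x))  # 8.0, the next power of two above 5.1
print(quantize_pot(x))
```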
@@ -152,12 +152,17 @@ def generate_tp_model(default_config: OpQuantizationConfig,
     # this configuration will be used for the operation quantization:
     default_configuration_options = tp.QuantizationConfigOptions([default_config])

-    # const configuration is similar to the default eight bit configuration except for PoT
+    # Create a QuantizationConfigOptions for quantizing constants in functional ops.
+    # Constant configuration is similar to the default eight bit configuration except for PoT
     # quantization method for the constant.
+    # Since the constants are not named attributes of the layer, we use the default_weight_attr_config to
+    # define the desired quantization properties for them.
     const_config = default_config.clone_and_edit(
         default_weight_attr_config=default_config.default_weight_attr_config.clone_and_edit(
-            weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
             enable_weights_quantization=True))
+    if not (const_config.default_weight_attr_config.weights_quantization_method == tp.QuantizationMethod.POWER_OF_TWO and
+            const_config.default_weight_attr_config.weights_per_channel_threshold is False):
+        mct.logger.Logger.error('Constant quantization config should be per-tensor PoT.')
     const_configuration_options = tp.QuantizationConfigOptions([const_config])

     # Create a TargetPlatformModel and set its default quantization config.
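Note that the added guard only validates an invariant rather than setting it: since default_weight_attr_config is now PoT and per-tensor, the clone inherits those properties and the check catches accidental edits. A condensed stand-in for the condition (SimpleNamespace replaces the real config object):

```python
from types import SimpleNamespace

# Stand-in for const_config.default_weight_attr_config after clone_and_edit:
attr_cfg = SimpleNamespace(weights_quantization_method='POWER_OF_TWO',
                           weights_per_channel_threshold=False)

is_per_tensor_pot = (attr_cfg.weights_quantization_method == 'POWER_OF_TWO'
                     and attr_cfg.weights_per_channel_threshold is False)
if not is_per_tensor_pot:
    raise ValueError('Constant quantization config should be per-tensor PoT.')
```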
@@ -58,7 +58,7 @@ def get_op_quantization_configs() -> \

     # We define a default quantization config for all non-specified weights attributes.
     default_weight_attr_config = AttributeQuantizationConfig(
-        weights_quantization_method=tp.QuantizationMethod.SYMMETRIC,
+        weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
         weights_n_bits=8,
         weights_per_channel_threshold=False,
         enable_weights_quantization=False,
@@ -151,12 +151,17 @@ def generate_tp_model(default_config: OpQuantizationConfig,
     # this configuration will be used for the operation quantization:
     default_configuration_options = tp.QuantizationConfigOptions([default_config])

-    # const configuration is similar to the default eight bit configuration except for PoT
+    # Create a QuantizationConfigOptions for quantizing constants in functional ops.
+    # Constant configuration is similar to the default eight bit configuration except for PoT
     # quantization method for the constant.
+    # Since the constants are not named attributes of the layer, we use the default_weight_attr_config to
+    # define the desired quantization properties for them.
     const_config = default_config.clone_and_edit(
         default_weight_attr_config=default_config.default_weight_attr_config.clone_and_edit(
-            weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
             enable_weights_quantization=True))
+    if not (const_config.default_weight_attr_config.weights_quantization_method == tp.QuantizationMethod.POWER_OF_TWO and
+            const_config.default_weight_attr_config.weights_per_channel_threshold is False):
+        mct.logger.Logger.error('Constant quantization config should be per-tensor PoT.')
     const_configuration_options = tp.QuantizationConfigOptions([const_config])

     # Create a TargetPlatformModel and set its default quantization config.
@@ -12,12 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from functools import partial
 import tensorflow as tf
 import numpy as np

 import model_compression_toolkit as mct
-from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model
-from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc
 from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest
 from tests.common_tests.helpers.tensors_compare import cosine_similarity
 from mct_quantizers import KerasQuantizationWrapper
@@ -32,14 +31,14 @@

 class ConstQuantizationTest(BaseKerasFeatureNetworkTest):

-    def __init__(self, unit_test, layer, const, is_list_input=False, input_reverse_order=False, use_kwrags=False,
+    def __init__(self, unit_test, layer, const, is_list_input=False, input_reverse_order=False, use_kwargs=False,
                  input_shape=(32, 32, 16)):
         super(ConstQuantizationTest, self).__init__(unit_test=unit_test, input_shape=input_shape)
         self.layer = layer
         self.const = const
         self.is_list_input = is_list_input
         self.input_reverse_order = input_reverse_order
-        self.use_kwrags = use_kwrags
+        self.use_kwargs = use_kwargs

     def generate_inputs(self):
         # need positive inputs so won't divide with zero or take root of negative number
@@ -58,12 +57,12 @@ def create_networks(self):
             x = self.layer([x, self.const])
         else:
             if self.input_reverse_order:
-                if self.use_kwrags:
+                if self.use_kwargs:
                     x = self.layer(x=self.const, y=x)
                 else:
                     x = self.layer(self.const, x)
             else:
-                if self.use_kwrags:
+                if self.use_kwargs:
                     x = self.layer(x=x, y=self.const)
                 else:
                     x = self.layer(x, self.const)
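A hypothetical instantiation of the test, exercising the kwargs path above (the surrounding unittest runner and the run_test entry point follow MCT's test-suite conventions and are assumptions here):

```python
import unittest
import numpy as np
import tensorflow as tf

class FeatureNetworksRunner(unittest.TestCase):  # assumed runner class
    def test_const_quantization_kwargs(self):
        ConstQuantizationTest(self, layer=tf.add,
                              const=np.random.random((16,)).astype(np.float32),
                              use_kwargs=True).run_test()
```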
@@ -80,3 +79,42 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         const_index = 0 if self.input_reverse_order else 1
         self.unit_test.assertTrue((quantized_model.layers[2].weight_values[const_index] == self.const).all(),
                                   msg='Constant value should not change')
+
+
+class AdvancedConstQuantizationTest(BaseKerasFeatureNetworkTest):
+
+    def __init__(self, unit_test, input_shape=(32, 32, 3)):
+        super(AdvancedConstQuantizationTest, self).__init__(unit_test=unit_test, input_shape=input_shape)
+        self.const = np.random.random((130,))
+
+    def get_ptq_facade(self):
+        gptq_config = mct.gptq.get_keras_gptq_config(30)
+        return partial(mct.gptq.keras_gradient_post_training_quantization,
+                       gptq_config=gptq_config)
+
+    def get_resource_utilization(self):
+        return mct.core.ResourceUtilization(9e3)
+
+    def generate_inputs(self):
+        # need positive inputs so won't divide with zero or take root of negative number
+        return [1 + np.random.random(in_shape) for in_shape in self.get_input_shapes()]
+
+    def get_tpc(self):
+        return mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, "v2")
+
+    def create_networks(self):
+        inputs = layers.Input(shape=self.get_input_shapes()[0][1:])
+        x = layers.Conv2D(130, 3)(inputs)
+        x = layers.ReLU()(x)
+        x = tf.add(x, self.const)
+        x = layers.Conv2D(16, 3)(x)
+        return tf.keras.models.Model(inputs=inputs, outputs=x)
+
+    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
+        y = float_model.predict(input_x)
+        y_hat = quantized_model.predict(input_x)
+        self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
+        self.unit_test.assertTrue(isinstance(quantized_model.layers[5], KerasQuantizationWrapper),
+                                  msg='TFOpLambda should be quantized')
+        self.unit_test.assertTrue((quantized_model.layers[5].weight_values[1] == self.const).all(),
+                                  msg='Constant value should not change')
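Why a (130,) constant is valid in create_networks above: tf.add broadcasts it across the channel axis of the Conv2D(130, 3) output. A NumPy sketch with an assumed batch/height/width:

```python
import numpy as np

x = np.zeros((1, 30, 30, 130), dtype=np.float32)     # Conv2D(130, 3) output (assumed dims)
const = np.random.random((130,)).astype(np.float32)
print((x + const).shape)  # (1, 30, 30, 130): broadcast over the last (channel) axis
```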
@@ -28,14 +28,14 @@

 class ConstRepresentationTest(BaseKerasFeatureNetworkTest):

-    def __init__(self, unit_test, layer, const, is_list_input=False, input_reverse_order=False, use_kwrags=False,
+    def __init__(self, unit_test, layer, const, is_list_input=False, input_reverse_order=False, use_kwargs=False,
                  input_shape=(32, 32, 16)):
         super(ConstRepresentationTest, self).__init__(unit_test=unit_test, input_shape=input_shape)
         self.layer = layer
         self.const = const
         self.is_list_input = is_list_input
         self.input_reverse_order = input_reverse_order
-        self.use_kwrags = use_kwrags
+        self.use_kwargs = use_kwargs

     def generate_inputs(self):
         # need positive inputs so won't divide with zero or take root of negative number
@@ -58,12 +58,12 @@ def create_networks(self):
             x = self.layer([x, self.const])
         else:
             if self.input_reverse_order:
-                if self.use_kwrags:
+                if self.use_kwargs:
                     x = self.layer(x=self.const, y=x)
                 else:
                     x = self.layer(self.const, x)
             else:
-                if self.use_kwrags:
+                if self.use_kwargs:
                     x = self.layer(x=x, y=self.const)
                 else:
                     x = self.layer(x, self.const)