Skip to content

Commit

Permalink
Refactor QuantizationConfigOptions (#1088)
Browse files Browse the repository at this point in the history
Refactor QuantizationConfigOptions to enforce 'base_config' to be a reference of an instance in 'quantization_config_list'
  • Loading branch information
elad-c authored May 29, 2024
1 parent 7763d4c commit 53b8e42
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def __eq__(self, other):
self.simd_size == other.simd_size


class QuantizationConfigOptions(object):
class QuantizationConfigOptions:
"""
Wrap a set of quantization configurations to consider during the quantization
Expand All @@ -215,19 +215,24 @@ def __init__(self,
"""

assert isinstance(quantization_config_list,
list), f'\'QuantizationConfigOptions\' options list must be a list, but received: {type(quantization_config_list)}.'
assert len(quantization_config_list) > 0, f'Options list can not be empty.'
list), f"'QuantizationConfigOptions' options list must be a list, but received: {type(quantization_config_list)}."
for cfg in quantization_config_list:
assert isinstance(cfg, OpQuantizationConfig), f'Each option must be an instance of \'OpQuantizationConfig\', but found an object of type: {type(cfg)}.'
assert isinstance(cfg, OpQuantizationConfig),\
f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}."
self.quantization_config_list = quantization_config_list
if len(quantization_config_list) > 1:
assert base_config is not None, f'For multiple configurations, a \'base_config\' is required for non-mixed-precision optimization.'
assert base_config in quantization_config_list, f"\'base_config\' must be included in the quantization config options list."
assert base_config is not None, \
f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization."
assert any([base_config is cfg for cfg in quantization_config_list]), \
f"'base_config' must be included in the quantization config options list."
# Enforce base_config to be a reference to an instance in quantization_config_list.
self.base_config = base_config
elif len(quantization_config_list) == 1:
assert base_config is None or base_config == quantization_config_list[0], "'base_config' should be included in 'quantization_config_list'"
# Set base_config to be a reference to the first instance in quantization_config_list.
self.base_config = quantization_config_list[0]
else:
Logger.critical("\'QuantizationConfigOptions\' requires at least one \'OpQuantizationConfig\'; the provided list is empty.")
raise AssertionError("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.")

def __eq__(self, other):
"""
Expand Down
13 changes: 10 additions & 3 deletions tests/common_tests/helpers/generate_test_tp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ def generate_mixed_precision_test_tp_model(base_cfg, default_config, mp_bitwidth
candidate_cfg = base_cfg.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: weights_n_bits}},
activation_n_bits=activation_n_bits)

mp_op_cfg_list.append(candidate_cfg)
if candidate_cfg == base_cfg:
# the base config must be a reference of an instance in the cfg_list, so we put it and not the clone in the list.
mp_op_cfg_list.append(base_cfg)
else:
mp_op_cfg_list.append(candidate_cfg)

return generate_tp_model(default_config=default_config,
base_config=base_cfg,
Expand All @@ -85,8 +89,11 @@ def generate_tp_model_with_activation_mp(base_cfg, default_config, mp_bitwidth_c
**{k: v for k, v in base_cfg.attr_weights_configs_mapping.items() if
k != KERNEL_ATTR}},
activation_n_bits=activation_n_bits)

mp_op_cfg_list.append(candidate_cfg)
if candidate_cfg == base_cfg:
# the base config must be a reference of an instance in the cfg_list, so we put it and not the clone in the list.
mp_op_cfg_list.append(base_cfg)
else:
mp_op_cfg_list.append(candidate_cfg)

base_tp_model = generate_tp_model(default_config=default_config,
base_config=base_cfg,
Expand Down
3 changes: 2 additions & 1 deletion tests/common_tests/test_tp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ class QCOptionsTest(unittest.TestCase):
def test_empty_qc_options(self):
with self.assertRaises(AssertionError) as e:
tp.QuantizationConfigOptions([])
self.assertEqual('Options list can not be empty.', str(e.exception))
self.assertEqual("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.",
str(e.exception))

def test_list_of_no_qc(self):
with self.assertRaises(AssertionError) as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,4 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
error = np.sum(error, axis=(0,1)).flatten()
bias = dw_layer.weights[2]
# Input mean is 1 so correction_term = quant_error * 1
# TODO:
# Increase atol due to a minor difference in Symmetric quantizer
self.unit_test.assertTrue(np.isclose(error, bias, atol=1e-7).all())
self.unit_test.assertTrue(np.isclose(error, bias.numpy(), atol=3e-7).all())

0 comments on commit 53b8e42

Please sign in to comment.