From ad83b072ac7c3eed033b66d2bc6badbb014cc52f Mon Sep 17 00:00:00 2001
From: Ofir Gordon
Date: Thu, 25 Jan 2024 14:59:59 +0200
Subject: [PATCH] Modify DefaultDict constructor to initialize an empty dict
 if None is passed (#929)

Co-authored-by: Ofir Gordon
---
 model_compression_toolkit/core/common/defaultdict.py        | 6 +++---
 .../qparams_weights_computation.py                          | 2 +-
 .../gptq/keras/quantizer/ste_rounding/symmetric_ste.py      | 2 +-
 .../gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py    | 2 +-
 .../feature_networks/gptq/gptq_test.py                      | 2 +-
 tests/keras_tests/function_tests/test_get_gptq_config.py    | 2 +-
 tests/pytorch_tests/function_tests/get_gptq_config_test.py  | 2 +-
 tests/pytorch_tests/model_tests/feature_models/gptq_test.py | 2 +-
 8 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/model_compression_toolkit/core/common/defaultdict.py b/model_compression_toolkit/core/common/defaultdict.py
index 3d6200798..6cf0d6e94 100644
--- a/model_compression_toolkit/core/common/defaultdict.py
+++ b/model_compression_toolkit/core/common/defaultdict.py
@@ -26,17 +26,17 @@ class DefaultDict:
     """
 
     def __init__(self,
-                 known_dict: Dict[Any, Any],
+                 known_dict: Dict[Any, Any] = None,
                  default_value: Any = None):
         """
 
         Args:
-            known_dict: Dictionary to wrap.
+            known_dict: Dictionary to wrap. If None is provided, initializes an empty dictionary.
             default_value: default value when requested key is not in known_dict.
         """
 
         self.default_value = default_value
-        self.known_dict = known_dict
+        self.known_dict = known_dict if known_dict is not None else {}
 
     def get(self, key: Any) -> Any:
         """
diff --git a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py
index eaf329421..99c4b798a 100644
--- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py
+++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py
@@ -24,7 +24,7 @@
 
 # If the quantization config does not contain kernel channel mapping or the weights
 # quantization is not per-channel, we use a dummy channel mapping.
-dummy_channel_mapping = DefaultDict({}, (None, None))
+dummy_channel_mapping = DefaultDict(default_value=(None, None))
 
 
 def get_weights_qparams(kernel: np.ndarray,
diff --git a/model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py b/model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py
index e704a8a01..b54f11d49 100644
--- a/model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py
+++ b/model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py
@@ -77,7 +77,7 @@ class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
 
     def __init__(self,
                  quantization_config: TrainableQuantizerWeightsConfig,
-                 max_lsbs_change_map: dict = DefaultDict({}, 1)):
+                 max_lsbs_change_map: dict = DefaultDict(default_value=1)):
         """
         Initialize a STEWeightGPTQQuantizer object with parameters to use for the quantization.
 
diff --git a/model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py b/model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py
index bacf6074d..a4a34a230 100644
--- a/model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py
+++ b/model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py
@@ -84,7 +84,7 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
 
     def __init__(self,
                  quantization_config: TrainableQuantizerWeightsConfig,
-                 max_lsbs_change_map: dict = DefaultDict({}, 1)):
+                 max_lsbs_change_map: dict = DefaultDict(default_value=1)):
         """
         Construct a Pytorch model that utilize a fake weight quantizer of STE (Straight Through Estimator) for
         symmetric quantizer.
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py
index b98807b88..b0b5ec02e 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py
@@ -72,7 +72,7 @@ def __init__(self, unit_test, quant_method=QuantizationMethod.SYMMETRIC, roundin
         if rounding_type == RoundingType.SoftQuantizer:
             self.override_params = {QUANT_PARAM_LEARNING_STR: quantization_parameter_learning}
         elif rounding_type == RoundingType.STE:
-            self.override_params = {MAX_LSB_STR: DefaultDict({}, 1)}
+            self.override_params = {MAX_LSB_STR: DefaultDict(default_value=1)}
         else:
             self.override_params = None
 
diff --git a/tests/keras_tests/function_tests/test_get_gptq_config.py b/tests/keras_tests/function_tests/test_get_gptq_config.py
index cc70fbda5..e8ed3764d 100644
--- a/tests/keras_tests/function_tests/test_get_gptq_config.py
+++ b/tests/keras_tests/function_tests/test_get_gptq_config.py
@@ -106,7 +106,7 @@ def setUp(self):
                                               train_bias=True,
                                               loss=multiple_tensors_mse_loss,
                                               rounding_type=RoundingType.STE,
-                                              gptq_quantizer_params_override={MAX_LSB_STR: DefaultDict({}, 1)}),
+                                              gptq_quantizer_params_override={MAX_LSB_STR: DefaultDict(default_value=1)}),
                               get_keras_gptq_config(n_epochs=1,
                                                     optimizer=tf.keras.optimizers.Adam()),
                               get_keras_gptq_config(n_epochs=1,
diff --git a/tests/pytorch_tests/function_tests/get_gptq_config_test.py b/tests/pytorch_tests/function_tests/get_gptq_config_test.py
index 71fb8bc84..ed20a13f3 100644
--- a/tests/pytorch_tests/function_tests/get_gptq_config_test.py
+++ b/tests/pytorch_tests/function_tests/get_gptq_config_test.py
@@ -80,7 +80,7 @@ def run_test(self):
                 {QUANT_PARAM_LEARNING_STR: self.quantization_parameters_learning}
             elif self.rounding_type == RoundingType.STE:
                 gptqv2_config.gptq_quantizer_params_override = \
-                    {MAX_LSB_STR: DefaultDict({}, 1)}
+                    {MAX_LSB_STR: DefaultDict(default_value=1)}
             else:
                 gptqv2_config.gptq_quantizer_params_override = None
 
diff --git a/tests/pytorch_tests/model_tests/feature_models/gptq_test.py b/tests/pytorch_tests/model_tests/feature_models/gptq_test.py
index ae116fa93..029c4934e 100644
--- a/tests/pytorch_tests/model_tests/feature_models/gptq_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/gptq_test.py
@@ -63,7 +63,7 @@ def __init__(self, unit_test, experimental_exporter=True, weights_bits=8, weight
         self.log_norm_weights = log_norm_weights
         self.scaled_log_norm = scaled_log_norm
         self.override_params = {QUANT_PARAM_LEARNING_STR: params_learning} if \
-            rounding_type == RoundingType.SoftQuantizer else {MAX_LSB_STR: DefaultDict({}, 1)} \
+            rounding_type == RoundingType.SoftQuantizer else {MAX_LSB_STR: DefaultDict(default_value=1)} \
             if rounding_type == RoundingType.STE else None
 
     def get_quantization_config(self):
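
For context, a minimal usage sketch of the updated constructor (not part of the patch; variable names are illustrative, and it assumes get() falls back to default_value for keys missing from known_dict, as the docstring states):

    from model_compression_toolkit.core.common.defaultdict import DefaultDict

    # With this change, omitting known_dict initializes an empty internal dict,
    # so call sites that previously passed DefaultDict({}, 1) can now be written as:
    max_lsbs_change_map = DefaultDict(default_value=1)
    assert max_lsbs_change_map.get(8) == 1  # missing key -> default_value

    # Wrapping an explicit dict still behaves as before:
    per_layer_map = DefaultDict({8: 2}, default_value=1)
    assert per_layer_map.get(8) == 2  # known key -> mapped value
    assert per_layer_map.get(4) == 1  # missing key -> default_value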