Modify DefaultDict constructor to initialize an empty dict if None is passed (#929)

Co-authored-by: Ofir Gordon <[email protected]>
ofirgo and Ofir Gordon authored Jan 25, 2024
1 parent f46add1 commit ad83b07
Showing 8 changed files with 10 additions and 10 deletions.
model_compression_toolkit/core/common/defaultdict.py (3 additions, 3 deletions)
@@ -26,17 +26,17 @@ class DefaultDict:
     """
 
     def __init__(self,
-                 known_dict: Dict[Any, Any],
+                 known_dict: Dict[Any, Any] = None,
                  default_value: Any = None):
         """
 
         Args:
-            known_dict: Dictionary to wrap.
+            known_dict: Dictionary to wrap. If None is provided, initializes an empty dictionary.
             default_value: default value when requested key is not in known_dict.
         """
 
         self.default_value = default_value
-        self.known_dict = known_dict
+        self.known_dict = known_dict if known_dict is not None else {}
 
     def get(self, key: Any) -> Any:
         """
@@ -24,7 +24,7 @@
 
 # If the quantization config does not contain kernel channel mapping or the weights
 # quantization is not per-channel, we use a dummy channel mapping.
-dummy_channel_mapping = DefaultDict({}, (None, None))
+dummy_channel_mapping = DefaultDict(default_value=(None, None))
 
 
 def get_weights_qparams(kernel: np.ndarray,
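Call sites change accordingly; a small before/after sketch (illustrative only, with 'Conv2D' as a placeholder key and the same fallback assumption as above):

    from model_compression_toolkit.core.common.defaultdict import DefaultDict

    # Before this commit: an explicit empty dict had to be passed positionally.
    old_style = DefaultDict({}, (None, None))
    # After this commit: only the default value is needed.
    new_style = DefaultDict(default_value=(None, None))

    # Both return the dummy (None, None) mapping for any requested key.
    assert old_style.get('Conv2D') == new_style.get('Conv2D') == (None, None)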
@@ -77,7 +77,7 @@ class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
 
     def __init__(self,
                  quantization_config: TrainableQuantizerWeightsConfig,
-                 max_lsbs_change_map: dict = DefaultDict({}, 1)):
+                 max_lsbs_change_map: dict = DefaultDict(default_value=1)):
         """
         Initialize a STEWeightGPTQQuantizer object with parameters to use for the quantization.
@@ -84,7 +84,7 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
 
     def __init__(self,
                  quantization_config: TrainableQuantizerWeightsConfig,
-                 max_lsbs_change_map: dict = DefaultDict({}, 1)):
+                 max_lsbs_change_map: dict = DefaultDict(default_value=1)):
         """
         Construct a Pytorch model that utilize a fake weight quantizer of STE (Straight Through Estimator) for symmetric quantizer.
@@ -72,7 +72,7 @@ def __init__(self, unit_test, quant_method=QuantizationMethod.SYMMETRIC, roundin
         if rounding_type == RoundingType.SoftQuantizer:
             self.override_params = {QUANT_PARAM_LEARNING_STR: quantization_parameter_learning}
         elif rounding_type == RoundingType.STE:
-            self.override_params = {MAX_LSB_STR: DefaultDict({}, 1)}
+            self.override_params = {MAX_LSB_STR: DefaultDict(default_value=1)}
         else:
             self.override_params = None
 
tests/keras_tests/function_tests/test_get_gptq_config.py (1 addition, 1 deletion)
@@ -106,7 +106,7 @@ def setUp(self):
                                   train_bias=True,
                                   loss=multiple_tensors_mse_loss,
                                   rounding_type=RoundingType.STE,
-                                  gptq_quantizer_params_override={MAX_LSB_STR: DefaultDict({}, 1)}),
+                                  gptq_quantizer_params_override={MAX_LSB_STR: DefaultDict(default_value=1)}),
             get_keras_gptq_config(n_epochs=1,
                                   optimizer=tf.keras.optimizers.Adam()),
             get_keras_gptq_config(n_epochs=1,
tests/pytorch_tests/function_tests/get_gptq_config_test.py (1 addition, 1 deletion)
@@ -80,7 +80,7 @@ def run_test(self):
                 {QUANT_PARAM_LEARNING_STR: self.quantization_parameters_learning}
         elif self.rounding_type == RoundingType.STE:
             gptqv2_config.gptq_quantizer_params_override = \
-                {MAX_LSB_STR: DefaultDict({}, 1)}
+                {MAX_LSB_STR: DefaultDict(default_value=1)}
         else:
             gptqv2_config.gptq_quantizer_params_override = None
 
@@ -63,7 +63,7 @@ def __init__(self, unit_test, experimental_exporter=True, weights_bits=8, weight
         self.log_norm_weights = log_norm_weights
         self.scaled_log_norm = scaled_log_norm
         self.override_params = {QUANT_PARAM_LEARNING_STR: params_learning} if \
-            rounding_type == RoundingType.SoftQuantizer else {MAX_LSB_STR: DefaultDict({}, 1)} \
+            rounding_type == RoundingType.SoftQuantizer else {MAX_LSB_STR: DefaultDict(default_value=1)} \
             if rounding_type == RoundingType.STE else None
 
     def get_quantization_config(self):