Skip to content

Commit

Permalink
fix missing target_kpi moves in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ofir Gordon authored and Ofir Gordon committed Mar 11, 2024
1 parent 9e21f7e commit ca10d58
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 8 deletions.
5 changes: 2 additions & 3 deletions tests/common_tests/helpers/prep_graph_for_func_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def prepare_graph_with_quantization_parameters(in_model,
def prepare_graph_set_bit_widths(in_model,
fw_impl,
representative_data_gen,
target_kpi,
n_iter,
quant_config,
fw_info,
Expand Down Expand Up @@ -133,8 +132,8 @@ def _representative_data_gen():
######################################
# Finalize bit widths
######################################
if target_kpi is not None:
assert core_config.mixed_precision_enable
if core_config.mixed_precision_enable:
assert core_config.mixed_precision_config.target_kpi is not None
if core_config.mixed_precision_config.configuration_overwrite is None:

bit_widths_config = search_bit_width(tg,
Expand Down
4 changes: 2 additions & 2 deletions tests/keras_tests/non_parallel_tests/test_keras_tp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,11 @@ def rep_data():
rep_data,
target_platform_capabilities=tpc)
core_config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=2,
use_hessian_based_scores=False))
use_hessian_based_scores=False,
target_kpi=mct.core.KPI(np.inf)))
quantized_model, _ = mct.ptq.keras_post_training_quantization(model,
rep_data,
core_config=core_config,
target_kpi=mct.core.KPI(np.inf),
target_platform_capabilities=tpc)

def test_get_keras_supported_version(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def plot_tensor_sizes(self):
cfg = mct.core.DEFAULTCONFIG
mp_cfg = mct.core.MixedPrecisionQuantizationConfig(compute_distance_fn=compute_mse,
distance_weighting_method=MpDistanceWeighting.AVG,
use_hessian_based_scores=False)
use_hessian_based_scores=False,
target_kpi=mct.core.KPI(np.inf))

# compare max tensor size with plotted max tensor size
tg = prepare_graph_set_bit_widths(in_model=model,
Expand All @@ -120,7 +121,6 @@ def plot_tensor_sizes(self):
tpc=tpc,
network_editor=[],
quant_config=cfg,
target_kpi=mct.core.KPI(),
n_iter=1,
analyze_similarity=True,
mp_cfg=mp_cfg)
Expand Down Expand Up @@ -162,7 +162,6 @@ def rep_data():
self.model = MultipleOutputsNet()
quantized_model, _ = mct.ptq.keras_post_training_quantization(self.model,
rep_data,
target_kpi=mct.core.KPI(np.inf),
core_config=core_config,
target_platform_capabilities=tpc)

Expand Down

0 comments on commit ca10d58

Please sign in to comment.