Skip to content

Commit

Permalink
fix feature runners
Browse files Browse the repository at this point in the history
  • Loading branch information
Ofir Gordon authored and Ofir Gordon committed Mar 11, 2024
1 parent ca10d58 commit b4086fc
Show file tree
Hide file tree
Showing 3 changed files with 358 additions and 351 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,20 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=


class MixedPercisionSearchTest(MixedPercisionBaseTest):
def __init__(self, unit_test):
def __init__(self, unit_test, distance_metric=MpDistanceWeighting.AVG):
super().__init__(unit_test, val_batch_size=2)

self.distance_metric = distance_metric

def get_kpi(self):
# kpi is infinity -> should give best model - 8bits
return KPI(np.inf)

def get_mixed_precision_config(self):
return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1,
distance_weighting_method=self.distance_metric,
target_kpi=self.get_kpi())

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D)
assert (quantization_info.mixed_precision_cfg == [0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@
from tests.keras_tests.feature_networks_tests.feature_networks.uniform_range_selection_activation_test import \
UniformRangeSelectionActivationTest, UniformRangeSelectionBoundedActivationTest
from tests.keras_tests.feature_networks_tests.feature_networks.weights_mixed_precision_tests import \
MixedPrecisionSearchTest, MixedPercisionDepthwiseTest, \
MixedPercisionSearchTest, MixedPercisionDepthwiseTest, \
MixedPercisionSearchKPI4BitsAvgTest, MixedPercisionSearchKPI2BitsAvgTest, MixedPrecisionActivationDisabled, \
MixedPercisionSearchLastLayerDistanceTest, MixedPercisionSearchActivationKPINonConfNodesTest, \
MixedPercisionSearchTotalKPINonConfNodesTest, MixedPercisionSearchPartWeightsLayersTest, MixedPercisionCombinedNMSTest
Expand Down Expand Up @@ -204,8 +204,8 @@ def test_mixed_precision_search_kpi_4bits_avg_nms(self):
MixedPercisionCombinedNMSTest(self).run_test()

def test_mixed_precision_search(self):
MixedPrecisionSearchTest(self, distance_metric=MpDistanceWeighting.AVG).run_test()
MixedPrecisionSearchTest(self, distance_metric=MpDistanceWeighting.LAST_LAYER).run_test()
MixedPercisionSearchTest(self, distance_metric=MpDistanceWeighting.AVG).run_test()
MixedPercisionSearchTest(self, distance_metric=MpDistanceWeighting.LAST_LAYER).run_test()

def test_mixed_precision_for_part_weights_layers(self):
MixedPercisionSearchPartWeightsLayersTest(self).run_test()
Expand Down
Loading

0 comments on commit b4086fc

Please sign in to comment.