Skip to content

Commit

Permalink
Add metadata tests
Browse files Browse the repository at this point in the history
  • Loading branch information
elad-c committed Apr 11, 2024
1 parent a05a0c3 commit e4feaeb
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
5 changes: 3 additions & 2 deletions model_compression_toolkit/core/common/graph/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,10 @@ def set_tpc(self,
if n.is_custom:
if not is_node_in_tpc:
Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. '
f' Please add the custom layer to Target Platform Capabilities (TPC), or file a feature request or an issue if you believe this should be supported.')
' Please add the custom layer to Target Platform Capabilities (TPC), or file a feature '
'request or an issue if you believe this should be supported.') # pragma: no cover
if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(tpc).quantization_config_list]):
Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.')
Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.') # pragma: no cover

self.tpc = tpc

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
MixedPercisionSearchLastLayerDistanceTest, MixedPercisionSearchActivationNonConfNodesTest, \
MixedPercisionSearchTotalMemoryNonConfNodesTest, MixedPercisionSearchPartWeightsLayersTest, MixedPercisionCombinedNMSTest
from tests.keras_tests.feature_networks_tests.feature_networks.matmul_substitution_test import MatmulToDenseSubstitutionTest
from tests.keras_tests.feature_networks_tests.feature_networks.metadata_test import MetadataTest
from tests.keras_tests.feature_networks_tests.feature_networks.const_representation_test import ConstRepresentationTest, \
ConstRepresentationMultiInputTest, ConstRepresentationMatMulTest
from tests.keras_tests.feature_networks_tests.feature_networks.concatination_threshold_update import ConcatThresholdtest
Expand Down Expand Up @@ -744,6 +745,9 @@ def test_bn_attributes_quantization(self):
def concat_threshold_test(self):
ConcatThresholdtest(self).run_test()

def test_metadata(self):
MetadataTest(self).run_test()


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
GPTQLearnRateZeroTest
from tests.pytorch_tests.model_tests.feature_models.uniform_activation_test import \
UniformActivationTest
from tests.pytorch_tests.model_tests.feature_models.metadata_test import MetadataTest
from tests.pytorch_tests.model_tests.feature_models.const_representation_test import ConstRepresentationTest, \
ConstRepresentationMultiInputTest
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
Expand Down Expand Up @@ -572,7 +573,6 @@ def test_qat(self):
QuantizationAwareTrainingMixedPrecisionCfgTest(self).run_test()
QuantizationAwareTrainingMixedPrecisionRUCfgTest(self).run_test()


def test_bn_attributes_quantization(self):
"""
This test checks the quantization of BatchNorm layer attributes.
Expand All @@ -583,6 +583,9 @@ def test_bn_attributes_quantization(self):
def test_concat_threshold_update(self):
ConcatUpdateTest(self).run_test()

def test_metadata(self):
MetadataTest(self).run_test()


if __name__ == '__main__':
unittest.main()

0 comments on commit e4feaeb

Please sign in to comment.