Skip to content

Commit

Permalink
fix pytorch tp model tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ofir Gordon authored and Ofir Gordon committed Jan 22, 2024
1 parent b21bf80 commit d6c4d3a
Showing 1 changed file with 31 additions and 22 deletions.
53 changes: 31 additions & 22 deletions tests/pytorch_tests/function_tests/test_pytorch_tp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import LayerFilterParams
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import Greater, Smaller, Eq
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import DEFAULT_MIXEDPRECISION_CONFIG
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, TFLITE_TP_MODEL, QNNPACK_TP_MODEL
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, \
TFLITE_TP_MODEL, QNNPACK_TP_MODEL, KERNEL_ATTR, WEIGHTS_N_BITS, PYTORCH_KERNEL, BIAS_ATTR, BIAS
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
from tests.common_tests.helpers.generate_test_tp_model import generate_test_op_qc, generate_test_attr_configs
from tests.pytorch_tests.layer_tests.base_pytorch_layer_test import LayerTestModel
Expand Down Expand Up @@ -78,45 +79,53 @@ def test_pytorch_layers_with_params(self):

def test_qco_by_pytorch_layer(self):
default_qco = tp.QuantizationConfigOptions([TEST_QC])
hm = tp.TargetPlatformModel(default_qco, name='test')
with hm:
mixed_precision_configuration_options = tp.QuantizationConfigOptions([TEST_QC,
TEST_QC.clone_and_edit(
weights_n_bits=4),
TEST_QC.clone_and_edit(
weights_n_bits=2)],
base_config=TEST_QC)
default_qco = default_qco.clone_and_edit(attr_weights_configs_mapping={})
tpm = tp.TargetPlatformModel(default_qco, name='test')
with tpm:
mixed_precision_configuration_options = tp.QuantizationConfigOptions(
[TEST_QC,
TEST_QC.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 4}}),
TEST_QC.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 2}})],
base_config=TEST_QC)

tp.OperatorsSet("conv", mixed_precision_configuration_options)

sevenbit_qco = TEST_QCO.clone_and_edit(activation_n_bits=7)
sevenbit_qco = TEST_QCO.clone_and_edit(activation_n_bits=7,
attr_weights_configs_mapping={})
tp.OperatorsSet("tanh", sevenbit_qco)

sixbit_qco = TEST_QCO.clone_and_edit(activation_n_bits=6)
sixbit_qco = TEST_QCO.clone_and_edit(activation_n_bits=6,
attr_weights_configs_mapping={})
tp.OperatorsSet("avg_pool2d_kernel_2", sixbit_qco)

tp.OperatorsSet("avg_pool2d")

hm_pytorch = tp.TargetPlatformCapabilities(hm, name='fw_test')
with hm_pytorch:
tp.OperationsSetToLayers("conv", [torch.nn.Conv2d])
tpc_pytorch = tp.TargetPlatformCapabilities(tpm, name='fw_test')
with tpc_pytorch:
tp.OperationsSetToLayers("conv", [torch.nn.Conv2d], attr_mapping={KERNEL_ATTR: {
tuple([torch.nn.Conv2d]): PYTORCH_KERNEL}, BIAS_ATTR: {tuple(): BIAS}})
tp.OperationsSetToLayers("tanh", [torch.tanh])
tp.OperationsSetToLayers("avg_pool2d_kernel_2",
[LayerFilterParams(torch.nn.functional.avg_pool2d, kernel_size=2)])
[LayerFilterParams(torch.nn.functional.avg_pool2d, kernel_size=2)])
tp.OperationsSetToLayers("avg_pool2d",
[torch.nn.functional.avg_pool2d])
[torch.nn.functional.avg_pool2d])

conv_node = get_node(torch.nn.Conv2d(3, 3, (1, 1)))
tanh_node = get_node(torch.tanh)
avg_pool2d_k2 = get_node(partial(torch.nn.functional.avg_pool2d, kernel_size=2))
avg_pool2d = get_node(partial(torch.nn.functional.avg_pool2d, kernel_size=1))

conv_qco = conv_node.get_qco(hm_pytorch)
tanh_qco = tanh_node.get_qco(hm_pytorch)
avg_pool2d_k2_qco = avg_pool2d_k2.get_qco(hm_pytorch)
avg_pool2d_qco = avg_pool2d.get_qco(hm_pytorch)

self.assertEqual(conv_qco, mixed_precision_configuration_options)
conv_qco = conv_node.get_qco(tpc_pytorch)
tanh_qco = tanh_node.get_qco(tpc_pytorch)
avg_pool2d_k2_qco = avg_pool2d_k2.get_qco(tpc_pytorch)
avg_pool2d_qco = avg_pool2d.get_qco(tpc_pytorch)

self.assertEqual(len(conv_qco.quantization_config_list),
len(mixed_precision_configuration_options.quantization_config_list))
for i in range(len(conv_qco.quantization_config_list)):
self.assertEqual(conv_qco.quantization_config_list[i].attr_weights_configs_mapping[PYTORCH_KERNEL],
mixed_precision_configuration_options.quantization_config_list[
i].attr_weights_configs_mapping[KERNEL_ATTR])
self.assertEqual(tanh_qco, sevenbit_qco)
self.assertEqual(avg_pool2d_k2_qco, sixbit_qco)
self.assertEqual(avg_pool2d_qco, default_qco)
Expand Down

0 comments on commit d6c4d3a

Please sign in to comment.