Skip to content

Commit

Permalink
fix QuantizationMethod import in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ofirgo committed Jan 13, 2025
1 parent 36014bf commit 3e2a701
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tests/pytorch_tests/model_tests/test_feature_models_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import torch
from torch import nn

import model_compression_toolkit as mct
from mct_quantizers import QuantizationMethod
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
from model_compression_toolkit.core.common.network_editors import NodeTypeFilter, NodeNameFilter
Expand Down Expand Up @@ -695,15 +694,15 @@ def test_gptq(self):
GPTQWeightsUpdateTest(self, rounding_type=RoundingType.SoftQuantizer).run_test()
GPTQLearnRateZeroTest(self, rounding_type=RoundingType.SoftQuantizer).run_test()
GPTQAccuracyTest(self, rounding_type=RoundingType.SoftQuantizer,
weights_quant_method=mct.QuantizationMethod.UNIFORM).run_test()
weights_quant_method=QuantizationMethod.UNIFORM).run_test()
GPTQAccuracyTest(self, rounding_type=RoundingType.SoftQuantizer,
weights_quant_method=mct.QuantizationMethod.UNIFORM, per_channel=False,
weights_quant_method=QuantizationMethod.UNIFORM, per_channel=False,
params_learning=False).run_test()
GPTQAccuracyTest(self, rounding_type=RoundingType.SoftQuantizer,
weights_quant_method=mct.QuantizationMethod.UNIFORM,
weights_quant_method=QuantizationMethod.UNIFORM,
per_channel=True, hessian_weights=True, log_norm_weights=True, scaled_log_norm=True).run_test()
GPTQWeightsUpdateTest(self, rounding_type=RoundingType.SoftQuantizer,
weights_quant_method=mct.QuantizationMethod.UNIFORM,
weights_quant_method=QuantizationMethod.UNIFORM,
params_learning=False).run_test() # TODO: When params learning is True, the uniform quantizer gets a min value > max value

def test_gptq_with_gradual_activation(self):
Expand Down

0 comments on commit 3e2a701

Please sign in to comment.