diff --git a/tests/keras_tests/function_tests/test_sensitivity_metric_interest_points.py b/tests/keras_tests/function_tests/test_sensitivity_metric_interest_points.py index 681805672..5feb09f2a 100644 --- a/tests/keras_tests/function_tests/test_sensitivity_metric_interest_points.py +++ b/tests/keras_tests/function_tests/test_sensitivity_metric_interest_points.py @@ -20,20 +20,15 @@ if tf.__version__ >= "2.13": from keras.src.engine.input_layer import InputLayer + from keras.src.layers.core import TFOpLambda else: from keras.engine.input_layer import InputLayer - -from packaging import version + from keras.layers.core import TFOpLambda from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \ AttachTpcToKeras -if version.parse(tf.__version__) >= version.parse("2.13"): - from keras.src.layers.core import TFOpLambda -else: - from keras.layers.core import TFOpLambda - from model_compression_toolkit.constants import AXIS from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \ diff --git a/tests/keras_tests/function_tests/test_set_layer_to_bitwidth.py b/tests/keras_tests/function_tests/test_set_layer_to_bitwidth.py index 2e6bedd00..3c8ced191 100644 --- a/tests/keras_tests/function_tests/test_set_layer_to_bitwidth.py +++ b/tests/keras_tests/function_tests/test_set_layer_to_bitwidth.py @@ -14,6 +14,7 @@ # ============================================================================== import unittest import keras +import tensorflow as tf import numpy as np from keras import Input