Skip to content

Commit 5b1fe84

Browse files
chenmoneygithubcopybara-github
authored andcommitted
Code changes to get ready for an incoming Keras optimizer migration.
PiperOrigin-RevId: 473314021 Change-Id: I9d820aeb76fd57dbea1b66678b59f04292d14d49
1 parent b91d881 commit 5b1fe84

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

tests/bn_folding_test.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@
3939
from qkeras import utils as qkeras_utils
4040
from qkeras import bn_folding_utils
4141

42+
def get_sgd_optimizer(learning_rate):
43+
if hasattr(tf.keras.optimizers, "legacy"):
44+
return tf.keras.optimizers.legacy.SGD(learning_rate)
45+
else:
46+
return tf.keras.optimizers.SGD(learning_rate)
47+
4248

4349
def get_qconv2d_model(input_shape, kernel_size, kernel_quantizer=None):
4450
num_class = 2
@@ -108,7 +114,7 @@ def get_models_with_one_layer(kernel_quantizer, folding_mode, ema_freeze_delay):
108114

109115
x_shape = (2, 2, 1)
110116
loss_fn = tf.keras.losses.MeanSquaredError()
111-
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
117+
optimizer = get_sgd_optimizer(learning_rate=1e-3)
112118

113119
# define a model with seperate conv2d and bn layers
114120
x = x_in = layers.Input(x_shape, name="input")
@@ -349,7 +355,7 @@ def test_loading():
349355

350356
loss_fn = tf.keras.losses.MeanSquaredError()
351357
loss_metric = metrics.Mean()
352-
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
358+
optimizer = get_sgd_optimizer(learning_rate=1e-3)
353359
x_shape = (2, 2, 1)
354360

355361
custom_objects = {}
@@ -397,7 +403,7 @@ def test_same_training_and_prediction():
397403
epochs = 5
398404
loss_fn = tf.keras.losses.MeanSquaredError()
399405
loss_metric = metrics.Mean()
400-
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
406+
optimizer = get_sgd_optimizer(learning_rate=1e-3)
401407

402408
x_shape = (2, 2, 1)
403409
kernel = np.array([[[[1., 1.]], [[1., 0.]]], [[[1., 1.]], [[0., 1.]]]])
@@ -568,7 +574,7 @@ def _get_models(x_shape, num_class, depthwise_quantizer, folding_mode,
568574
epochs = 5
569575
loss_fn = tf.keras.losses.MeanSquaredError()
570576
loss_metric = metrics.Mean()
571-
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
577+
optimizer = get_sgd_optimizer(learning_rate=1e-3)
572578

573579
pred1 = run_training(
574580
model, epochs, loss_fn, loss_metric, optimizer, train_ds, do_print=False)

0 commit comments

Comments
 (0)