|
39 | 39 | from qkeras import utils as qkeras_utils
|
40 | 40 | from qkeras import bn_folding_utils
|
41 | 41 |
|
def get_sgd_optimizer(learning_rate):
  """Build an SGD optimizer that works across Keras optimizer API versions.

  Newer TF releases (Keras >= 2.11) moved the original optimizer
  implementations under ``tf.keras.optimizers.legacy``; older releases
  expose them directly on ``tf.keras.optimizers``. Dispatch to whichever
  namespace this TF installation provides.

  Args:
    learning_rate: learning rate passed straight through to the SGD
      constructor.

  Returns:
    An SGD optimizer instance from the legacy namespace when available,
    otherwise from ``tf.keras.optimizers``.
  """
  optimizers = tf.keras.optimizers
  # getattr falls back to the top-level namespace when `legacy` is absent,
  # which is exactly the hasattr branch in the original.
  sgd_namespace = getattr(optimizers, "legacy", optimizers)
  return sgd_namespace.SGD(learning_rate)
42 | 48 |
|
43 | 49 | def get_qconv2d_model(input_shape, kernel_size, kernel_quantizer=None):
|
44 | 50 | num_class = 2
|
@@ -108,7 +114,7 @@ def get_models_with_one_layer(kernel_quantizer, folding_mode, ema_freeze_delay):
|
108 | 114 |
|
109 | 115 | x_shape = (2, 2, 1)
|
110 | 116 | loss_fn = tf.keras.losses.MeanSquaredError()
|
111 |
| - optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3) |
| 117 | + optimizer = get_sgd_optimizer(learning_rate=1e-3) |
112 | 118 |
|
113 | 119 | # define a model with seperate conv2d and bn layers
|
114 | 120 | x = x_in = layers.Input(x_shape, name="input")
|
@@ -349,7 +355,7 @@ def test_loading():
|
349 | 355 |
|
350 | 356 | loss_fn = tf.keras.losses.MeanSquaredError()
|
351 | 357 | loss_metric = metrics.Mean()
|
352 |
| - optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3) |
| 358 | + optimizer = get_sgd_optimizer(learning_rate=1e-3) |
353 | 359 | x_shape = (2, 2, 1)
|
354 | 360 |
|
355 | 361 | custom_objects = {}
|
@@ -397,7 +403,7 @@ def test_same_training_and_prediction():
|
397 | 403 | epochs = 5
|
398 | 404 | loss_fn = tf.keras.losses.MeanSquaredError()
|
399 | 405 | loss_metric = metrics.Mean()
|
400 |
| - optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3) |
| 406 | + optimizer = get_sgd_optimizer(learning_rate=1e-3) |
401 | 407 |
|
402 | 408 | x_shape = (2, 2, 1)
|
403 | 409 | kernel = np.array([[[[1., 1.]], [[1., 0.]]], [[[1., 1.]], [[0., 1.]]]])
|
@@ -568,7 +574,7 @@ def _get_models(x_shape, num_class, depthwise_quantizer, folding_mode,
|
568 | 574 | epochs = 5
|
569 | 575 | loss_fn = tf.keras.losses.MeanSquaredError()
|
570 | 576 | loss_metric = metrics.Mean()
|
571 |
| - optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3) |
| 577 | + optimizer = get_sgd_optimizer(learning_rate=1e-3) |
572 | 578 |
|
573 | 579 | pred1 = run_training(
|
574 | 580 | model, epochs, loss_fn, loss_metric, optimizer, train_ds, do_print=False)
|
|
0 commit comments