diff --git a/qkeras/__init__.py b/qkeras/__init__.py
index 67918c46..3cffdf7f 100644
--- a/qkeras/__init__.py
+++ b/qkeras/__init__.py
@@ -34,6 +34,7 @@
 #from .qtools.settings import cfg
 from .qconv2d_batchnorm import QConv2DBatchnorm
 from .qdepthwiseconv2d_batchnorm import QDepthwiseConv2DBatchnorm
+from .qdense_batchnorm import QDenseBatchnorm
 
 assert tf.executing_eagerly(), "QKeras requires TF with eager execution mode on"
 
diff --git a/qkeras/qdense_batchnorm.py b/qkeras/qdense_batchnorm.py
new file mode 100644
index 00000000..836f5d46
--- /dev/null
+++ b/qkeras/qdense_batchnorm.py
@@ -0,0 +1,327 @@
+# Copyright 2020 Google LLC
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Fold batchnormalization with previous QDense layers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+import warnings
+
+import numpy as np
+import six
+
+
+from qkeras.qlayers import QDense
+from qkeras.quantizers import *
+
+import tensorflow.compat.v2 as tf
+from tensorflow.keras import layers
+from tensorflow.python.framework import smart_cond as tf_utils
+from tensorflow.python.ops import math_ops
+
+tf.compat.v2.enable_v2_behavior()
+
+
+class QDenseBatchnorm(QDense):
+  """Implements a quantized Dense layer fused with Batchnorm."""
+
+  def __init__(
+      self,
+      units,
+      activation=None,
+      use_bias=True,
+      kernel_initializer="he_normal",
+      bias_initializer="zeros",
+      kernel_regularizer=None,
+      bias_regularizer=None,
+      activity_regularizer=None,
+      kernel_constraint=None,
+      bias_constraint=None,
+      kernel_quantizer=None,
+      bias_quantizer=None,
+      kernel_range=None,
+      bias_range=None,
+
+      # batchnorm params
+      axis=-1,
+      momentum=0.99,
+      epsilon=0.001,
+      center=True,
+      scale=True,
+      beta_initializer="zeros",
+      gamma_initializer="ones",
+      moving_mean_initializer="zeros",
+      moving_variance_initializer="ones",
+      beta_regularizer=None,
+      gamma_regularizer=None,
+      beta_constraint=None,
+      gamma_constraint=None,
+      renorm=False,
+      renorm_clipping=None,
+      renorm_momentum=0.99,
+      fused=None,
+      trainable=True,
+      virtual_batch_size=None,
+      adjustment=None,
+
+      # other params
+      ema_freeze_delay=None,
+      folding_mode="ema_stats_folding",
+      **kwargs):
+
+    super(QDenseBatchnorm, self).__init__(
+        units=units,
+        activation=activation,
+        use_bias=use_bias,
+        kernel_initializer=kernel_initializer,
+        bias_initializer=bias_initializer,
+        kernel_regularizer=kernel_regularizer,
+        bias_regularizer=bias_regularizer,
+        activity_regularizer=activity_regularizer,
+        kernel_constraint=kernel_constraint,
+        bias_constraint=bias_constraint,
+        kernel_quantizer=kernel_quantizer,
+        bias_quantizer=bias_quantizer,
+        kernel_range=kernel_range,
+        bias_range=bias_range,
+        **kwargs)
+
+    # initialization of batchnorm part of the composite layer
+    self.batchnorm = layers.BatchNormalization(
+        axis=axis, momentum=momentum, epsilon=epsilon, center=center,
+        scale=scale,
+        beta_initializer=beta_initializer,
+        gamma_initializer=gamma_initializer,
+        moving_mean_initializer=moving_mean_initializer,
+        moving_variance_initializer=moving_variance_initializer,
+        beta_regularizer=beta_regularizer,
+        gamma_regularizer=gamma_regularizer,
+        beta_constraint=beta_constraint, gamma_constraint=gamma_constraint,
+        renorm=renorm, renorm_clipping=renorm_clipping,
+        renorm_momentum=renorm_momentum, fused=fused, trainable=trainable,
+        virtual_batch_size=virtual_batch_size, adjustment=adjustment
+    )
+
+    self.ema_freeze_delay = ema_freeze_delay
+    assert folding_mode in ["ema_stats_folding", "batch_stats_folding"]
+    self.folding_mode = folding_mode
+
+  def build(self, input_shape):
+    super(QDenseBatchnorm, self).build(input_shape)
+
+    # self._iteration (i.e., training_steps) is initialized with -1. When
+    # loading a ckpt, it restores the number of training steps that have
+    # already been trained; when training from scratch it counts up from 0.
+    # TODO(lishanok): develop a way to count iterations outside layer
+    self._iteration = tf.Variable(-1, trainable=False, name="iteration",
+                                  dtype=tf.int64)
+
+  def call(self, inputs, training=None):
+
+    # resolve the training flag; it marks whether the layer is in training mode
+    training = self.batchnorm._get_training_value(training)  # pylint: disable=protected-access
+
+    # check whether to update batchnorm params
+    if (self.ema_freeze_delay is None) or (self.ema_freeze_delay < 0):
+      # if ema_freeze_delay is None or a negative value, do not freeze bn stats
+      bn_training = tf.cast(training, dtype=bool)
+    else:
+      bn_training = tf.math.logical_and(training, tf.math.less_equal(
+          self._iteration, self.ema_freeze_delay))
+
+    kernel = self.kernel
+
+    # compute the qdense output
+    qdense_outputs = tf.keras.backend.dot(
+        inputs,
+        kernel
+    )
+
+    if self.use_bias:
+      bias = self.bias
+      qdense_outputs = tf.keras.backend.bias_add(
+          qdense_outputs,
+          bias,
+          data_format="channels_last")
+    else:
+      bias = 0
+
+    # begin batchnorm
+    _ = self.batchnorm(qdense_outputs, training=bn_training)
+
+    self._iteration.assign_add(tf_utils.smart_cond(
+        training, lambda: tf.constant(1, tf.int64),
+        lambda: tf.constant(0, tf.int64)))
+
+    # calculate mean and variance from current batch
+    bn_shape = qdense_outputs.shape
+    ndims = len(bn_shape)
+    reduction_axes = [i for i in range(ndims) if i not in self.batchnorm.axis]
+    keep_dims = len(self.batchnorm.axis) > 1
+    mean, variance = self.batchnorm._moments(  # pylint: disable=protected-access
+        math_ops.cast(qdense_outputs, self.batchnorm._param_dtype),  # pylint: disable=protected-access
+        reduction_axes,
+        keep_dims=keep_dims)
+
+    # get batchnorm weights
+    gamma = self.batchnorm.gamma
+    beta = self.batchnorm.beta
+    moving_mean = self.batchnorm.moving_mean
+    moving_variance = self.batchnorm.moving_variance
+
+    if self.folding_mode == "batch_stats_folding":
+      # using batch mean and variance in the initial training stage
+      # after sufficient training, switch to moving mean and variance
+      new_mean = tf_utils.smart_cond(
+          bn_training, lambda: mean, lambda: moving_mean)
+      new_variance = tf_utils.smart_cond(
+          bn_training, lambda: variance, lambda: moving_variance)
+
+      # get the inversion factor so that we replace division by multiplication
+      inv = math_ops.rsqrt(new_variance + self.batchnorm.epsilon)
+      if gamma is not None:
+        inv *= gamma
+
+      # fold bias with bn stats
+      folded_bias = inv * (bias - new_mean) + beta
+
+    elif self.folding_mode == "ema_stats_folding":
+      # We always scale the weights with a correction factor to the long term
+      # statistics prior to quantization.
+      # This ensures that there is no jitter
+      # in the quantized weights due to batch to batch variation. During the
+      # initial phase of training, we undo the scaling of the weights so that
+      # outputs are identical to regular batch normalization. We also modify
+      # the bias terms correspondingly. After sufficient training, switch from
+      # using batch statistics to long term moving averages for batch
+      # normalization.
+
+      # use batch stats for calculating bias before bn freeze, and use moving
+      # stats after bn freeze
+      mv_inv = math_ops.rsqrt(moving_variance + self.batchnorm.epsilon)
+      batch_inv = math_ops.rsqrt(variance + self.batchnorm.epsilon)
+
+      if gamma is not None:
+        mv_inv *= gamma
+        batch_inv *= gamma
+      folded_bias = tf_utils.smart_cond(
+          bn_training,
+          lambda: batch_inv * (bias - mean) + beta,
+          lambda: mv_inv * (bias - moving_mean) + beta)
+      # moving stats is always used to fold kernel in tflite; before bn freeze
+      # an additional correction factor will be applied to the dense output
+      # end batchnorm
+      inv = mv_inv
+    else:
+      raise ValueError("Unknown folding mode: {}".format(self.folding_mode))
+
+    # wrap dense kernel with bn parameters
+    folded_kernel = inv * kernel
+    # quantize the folded kernel
+    if self.kernel_quantizer is not None:
+      q_folded_kernel = self.kernel_quantizer_internal(folded_kernel)
+    else:
+      q_folded_kernel = folded_kernel
+
+    # quantize the folded bias
+    if self.bias_quantizer_internal is not None:
+      q_folded_bias = self.bias_quantizer_internal(folded_bias)
+    else:
+      q_folded_bias = folded_bias
+
+    applied_kernel = q_folded_kernel
+    applied_bias = q_folded_bias
+
+    # calculate qdense output using the quantized folded kernel
+    folded_outputs = tf.keras.backend.dot(inputs, applied_kernel)
+
+    if training is True and self.folding_mode == "ema_stats_folding":
+      batch_inv = math_ops.rsqrt(variance + self.batchnorm.epsilon)
+      y_corr = tf_utils.smart_cond(
+          bn_training,
+          lambda: (math_ops.sqrt(moving_variance + self.batchnorm.epsilon) *
+                   math_ops.rsqrt(variance + self.batchnorm.epsilon)),
+          lambda: tf.constant(1.0, shape=moving_variance.shape))
+      folded_outputs = math_ops.mul(folded_outputs, y_corr)
+
+    folded_outputs = tf.keras.backend.bias_add(
+        folded_outputs,
+        applied_bias,
+        data_format="channels_last"
+    )
+
+    if self.activation is not None:
+      return self.activation(folded_outputs)
+
+    return folded_outputs
+
+  def get_config(self):
+    base_config = super().get_config()
+    bn_config = self.batchnorm.get_config()
+    config = {"ema_freeze_delay": self.ema_freeze_delay,
+              "folding_mode": self.folding_mode}
+    name = base_config["name"]
+    out_config = dict(
+        list(base_config.items()) +
+        list(bn_config.items()) + list(config.items()))
+
+    # names from different configs override each other; use the base layer name
+    # as this layer's config name
+    out_config["name"] = name
+    return out_config
+
+  def get_quantization_config(self):
+    return {
+        "kernel_quantizer": str(self.kernel_quantizer_internal),
+        "bias_quantizer": str(self.bias_quantizer_internal),
+    }
+
+  def get_quantizers(self):
+    return self.quantizers
+
+  # def get_prunable_weights(self):
+  #   return [self.kernel]
+
+  def get_folded_weights(self):
+    """Function to get the batchnorm folded weights.
+    This function converts the weights by folding batchnorm parameters into
+    the weight of QDense.
+    The high-level equation:
+    W_fold = gamma * W / sqrt(variance + epsilon)
+    bias_fold = gamma * (bias - moving_mean) / sqrt(variance + epsilon) + beta
+    """
+
+    kernel = self.kernel
+    if self.use_bias:
+      bias = self.bias
+    else:
+      bias = 0
+
+    # get batchnorm weights and moving stats
+    gamma = self.batchnorm.gamma
+    beta = self.batchnorm.beta
+    moving_mean = self.batchnorm.moving_mean
+    moving_variance = self.batchnorm.moving_variance
+
+    # get the inversion factor so that we replace division by multiplication
+    inv = math_ops.rsqrt(moving_variance + self.batchnorm.epsilon)
+    if gamma is not None:
+      inv *= gamma
+
+    # wrap dense kernel and bias with bn parameters
+    folded_kernel = inv * kernel
+    folded_bias = inv * (bias - moving_mean) + beta
+
+    return [folded_kernel, folded_bias]
\ No newline at end of file
diff --git a/qkeras/utils.py b/qkeras/utils.py
index 40ca10c2..80d3ba82 100644
--- a/qkeras/utils.py
+++ b/qkeras/utils.py
@@ -36,6 +36,7 @@
 from .qlayers import Clip
 from .qconv2d_batchnorm import QConv2DBatchnorm
 from .qdepthwiseconv2d_batchnorm import QDepthwiseConv2DBatchnorm
+from .qdense_batchnorm import QDenseBatchnorm
 from .qlayers import QActivation
 from .qlayers import QAdaptiveActivation
 from .qpooling import QAveragePooling2D
@@ -96,6 +97,7 @@
     "QDepthwiseConv2DBatchnorm",
     "QAveragePooling2D",
     "QGlobalAveragePooling2D",
+    "QDenseBatchnorm",
 ]
 
@@ -264,7 +266,7 @@ def model_save_quantized_weights(model, filename=None):
     hw_weights = []
 
     if any(isinstance(layer, t) for t in [
-        QConv2DBatchnorm, QDepthwiseConv2DBatchnorm]):
+        QConv2DBatchnorm, QDenseBatchnorm, QDepthwiseConv2DBatchnorm]):
       qs = layer.get_quantizers()
      ws = layer.get_folded_weights()
     elif any(isinstance(layer, t) for t in [QSimpleRNN, QLSTM, QGRU]):
@@ -380,7 +382,7 @@ def model_save_quantized_weights(model, filename=None):
     if has_scale:
       saved_weights[layer.name]["scales"] = scales
 
     if not any(isinstance(layer, t) for t in [
-        QConv2DBatchnorm, QDepthwiseConv2DBatchnorm]):
+        QConv2DBatchnorm, QDenseBatchnorm, QDepthwiseConv2DBatchnorm]):
       # Set layer weights in the format that software inference uses
       layer.set_weights(weights)
     else:
@@ -1056,6 +1058,8 @@ def _add_supported_quantized_objects(custom_objects):
   custom_objects["QConv2DBatchnorm"] = QConv2DBatchnorm
   custom_objects["QDepthwiseConv2DBatchnorm"] = QDepthwiseConv2DBatchnorm
+  custom_objects["QDenseBatchnorm"] = QDenseBatchnorm
+
   custom_objects["QAveragePooling2D"] = QAveragePooling2D
   custom_objects["QGlobalAveragePooling2D"] = QGlobalAveragePooling2D
   custom_objects["QScaleShift"] = QScaleShift
diff --git a/tests/autoqkeras_test.py b/tests/autoqkeras_test.py
index 8d0f5239..5d17e9bf 100644
--- a/tests/autoqkeras_test.py
+++ b/tests/autoqkeras_test.py
@@ -18,6 +18,7 @@
 import tempfile
 import numpy as np
 import pytest
+import random
 from sklearn.datasets import load_iris
 from sklearn.preprocessing import MinMaxScaler
 import tensorflow.compat.v2 as tf
@@ -35,6 +36,13 @@
 from qkeras.autoqkeras import AutoQKerasScheduler
 
 
+def get_adam_optimizer(learning_rate):
+  if hasattr(tf.keras.optimizers, "legacy"):
+    return tf.keras.optimizers.legacy.Adam(learning_rate)
+  else:
+    return tf.keras.optimizers.Adam(learning_rate)
+
+
 def dense_model():
   """Creates test dense model."""
 
@@ -57,8 +65,10 @@ def dense_model():
 
 def test_autoqkeras():
   """Tests AutoQKeras scheduler."""
 
-  np.random.seed(42)
-  tf.random.set_seed(42)
+  seed = 42
+  random.seed(seed)
+  np.random.seed(seed)
+  tf.random.set_seed(seed)
 
   x_train, y_train = load_iris(return_X_y=True)
 
@@ -104,7 +114,7 @@ def test_autoqkeras():
   model = dense_model()
   model.summary()
 
-  optimizer = Adam(lr=0.01)
+  optimizer = get_adam_optimizer(learning_rate=0.01)
 
   model.compile(optimizer=optimizer, loss="categorical_crossentropy",
                 metrics=["acc"])
 
@@ -136,14 +146,14 @@ def test_autoqkeras():
   }
 
   autoqk = AutoQKerasScheduler(model, metrics=["acc"], **run_config)
-  autoqk.fit(x_train, y_train, validation_split=0.1, batch_size=150, epochs=4)
+  autoqk.fit(x_train, y_train, validation_split=0.1, batch_size=150, epochs=8)
 
   qmodel = autoqk.get_best_model()
 
-  optimizer = Adam(lr=0.01)
+  optimizer = get_adam_optimizer(learning_rate=0.01)
   qmodel.compile(optimizer=optimizer, loss="categorical_crossentropy",
                  metrics=["acc"])
-  history = qmodel.fit(x_train, y_train, epochs=5, batch_size=150,
+  history = qmodel.fit(x_train, y_train, epochs=10, batch_size=150,
                        validation_split=0.1)
   quantized_acc = history.history["acc"][-1]
 
diff --git a/tests/bn_folding_test.py b/tests/bn_folding_test.py
index ef152da1..3b72e29c 100644
--- a/tests/bn_folding_test.py
+++ b/tests/bn_folding_test.py
@@ -29,9 +29,11 @@
 from tensorflow.keras.backend import clear_session
 from tensorflow.keras.utils import to_categorical
 from tensorflow.keras import metrics
+import pytest
 
 from qkeras import QConv2DBatchnorm
 from qkeras import QConv2D
+from qkeras import QDenseBatchnorm
 from qkeras import QDense
 from qkeras import QActivation
 from qkeras import QDepthwiseConv2D
@@ -110,7 +112,7 @@ def get_qconv2d_batchnorm_model(input_shape, kernel_size, folding_mode,
   return model
 
 
-def get_models_with_one_layer(kernel_quantizer, folding_mode, ema_freeze_delay):
+def get_conv2d_models_with_one_layer(kernel_quantizer, folding_mode, ema_freeze_delay):
 
   x_shape = (2, 2, 1)
   loss_fn = tf.keras.losses.MeanSquaredError()
@@ -164,6 +166,60 @@ def get_models_with_one_layer(kernel_quantizer, folding_mode, ema_freeze_delay):
 
   return (unfold_model, fold_model)
 
 
+def get_dense_models_with_one_layer(kernel_quantizer, folding_mode, ema_freeze_delay):
+
+  x_shape = (4,)
+  loss_fn = tf.keras.losses.MeanSquaredError()
+  optimizer = get_sgd_optimizer(learning_rate=1e-3)
+
+  # define a model with separate dense and bn layers
+  x = x_in = layers.Input(x_shape, name="input")
+  x = QDense(
+      2,
+      kernel_initializer="ones",
+      bias_initializer="zeros", use_bias=False,
+      kernel_quantizer=kernel_quantizer, bias_quantizer=None,
+      name="dense")(x)
+  x = layers.BatchNormalization(
+      axis=-1,
+      momentum=0.99,
+      epsilon=0.001,
+      center=True,
+      scale=True,
+      beta_initializer="zeros",
+      gamma_initializer="ones",
+      moving_mean_initializer="zeros",
+      moving_variance_initializer="ones",
+      beta_regularizer=None,
+      gamma_regularizer=None,
+      beta_constraint=None,
+      gamma_constraint=None,
+      renorm=False,
+      renorm_clipping=None,
+      renorm_momentum=0.99,
+      fused=None,
+      trainable=True,
+      virtual_batch_size=None,
+      adjustment=None,
+      name="bn")(x)
+  unfold_model = Model(inputs=[x_in], outputs=[x])
+  unfold_model.compile(loss=loss_fn, optimizer=optimizer, metrics="acc")
+
+  x = x_in = layers.Input(x_shape, name="input")
+  x = QDenseBatchnorm(
+      2,
+      kernel_initializer="ones", bias_initializer="zeros", use_bias=False,
+      kernel_quantizer=kernel_quantizer, beta_initializer="zeros",
+      gamma_initializer="ones", moving_mean_initializer="zeros",
+      moving_variance_initializer="ones", folding_mode=folding_mode,
+      ema_freeze_delay=ema_freeze_delay,
+      name="folddense")(x)
+  fold_model = Model(inputs=[x_in], outputs=[x])
+  fold_model.compile(loss=loss_fn, optimizer=optimizer, metrics="acc")
+
+  return (unfold_model, fold_model)
+
+
 
 def get_debug_model(model):
   layer_output_list = []
   for layer in model.layers:
@@ -181,10 +237,7 @@ def generate_dataset(train_size=10,
                      output_shape=None):
   """create tf.data.Dataset with shape: (N,) + input_shape."""
 
-  x_train = np.random.randint(
-      4, size=(train_size, input_shape[0], input_shape[1], input_shape[2]))
-  x_train = np.random.rand(
-      train_size, input_shape[0], input_shape[1], input_shape[2])
+  x_train = np.random.rand(*(train_size,) + input_shape)
 
   if output_shape:
     y_train = np.random.random_sample((train_size,) + output_shape)
@@ -399,7 +452,8 @@ def test_loading():
   assert_equal(weight1[1], weight2[1])
 
 
-def test_same_training_and_prediction():
+@pytest.mark.parametrize("model_name", ["conv2d", "dense"])
+def test_same_training_and_prediction(model_name):
   """test if fold/unfold layer has the same training and prediction output."""
 
   epochs = 5
@@ -407,8 +461,12 @@
   loss_metric = metrics.Mean()
   optimizer = get_sgd_optimizer(learning_rate=1e-3)
 
-  x_shape = (2, 2, 1)
-  kernel = np.array([[[[1., 1.]], [[1., 0.]]], [[[1., 1.]], [[0., 1.]]]])
+  if model_name == "conv2d":
+    x_shape = (2, 2, 1)
+    kernel = np.array([[[[1., 1.]], [[1., 0.]]], [[[1., 1.]], [[0., 1.]]]])
+  elif model_name == "dense":
+    x_shape = (4,)
+    kernel = np.array([[1., 1.], [1., 0.], [1., 1.], [0., 1.]])
   gamma = np.array([2., 1.])
   beta = np.array([0., 1.])
   moving_mean = np.array([1., 1.])
@@ -418,12 +476,20 @@
   train_ds = generate_dataset(train_size=10, batch_size=10,
                               input_shape=x_shape, num_class=2)
 
-  (unfold_model, fold_model_batch) = get_models_with_one_layer(
-      kernel_quantizer=None, folding_mode="batch_stats_folding",
-      ema_freeze_delay=10)
-  (_, fold_model_ema) = get_models_with_one_layer(
-      kernel_quantizer=None, folding_mode="ema_stats_folding",
-      ema_freeze_delay=10)
+  if model_name == "conv2d":
+    (unfold_model, fold_model_batch) = get_conv2d_models_with_one_layer(
+        kernel_quantizer=None, folding_mode="batch_stats_folding",
+        ema_freeze_delay=10)
+    (_, fold_model_ema) = get_conv2d_models_with_one_layer(
+        kernel_quantizer=None, folding_mode="ema_stats_folding",
+        ema_freeze_delay=10)
+  elif model_name == "dense":
+    (unfold_model, fold_model_batch) = get_dense_models_with_one_layer(
+        kernel_quantizer=None, folding_mode="batch_stats_folding",
+        ema_freeze_delay=10)
+    (_, fold_model_ema) = get_dense_models_with_one_layer(
+        kernel_quantizer=None, folding_mode="ema_stats_folding",
+        ema_freeze_delay=10)
 
   unfold_model.layers[1].set_weights([kernel])
   unfold_model.layers[2].set_weights(
       [gamma, beta, moving_mean, moving_variance])
@@ -457,12 +523,20 @@
   # models should be different, but the two folding modes should be the same
   epochs = 5
   iteration = np.array(8)
-  (unfold_model, fold_model_batch) = get_models_with_one_layer(
-      kernel_quantizer=None, folding_mode="batch_stats_folding",
-      ema_freeze_delay=10)
-  (_, fold_model_ema) = get_models_with_one_layer(
-      kernel_quantizer=None, folding_mode="ema_stats_folding",
-      ema_freeze_delay=10)
+  if model_name == "conv2d":
+    (unfold_model, fold_model_batch) = get_conv2d_models_with_one_layer(
+        kernel_quantizer=None, folding_mode="batch_stats_folding",
+        ema_freeze_delay=10)
+    (_, fold_model_ema) = get_conv2d_models_with_one_layer(
+        kernel_quantizer=None, folding_mode="ema_stats_folding",
+        ema_freeze_delay=10)
+  elif model_name == "dense":
+    (unfold_model, fold_model_batch) = get_dense_models_with_one_layer(
+        kernel_quantizer=None, folding_mode="batch_stats_folding",
+        ema_freeze_delay=10)
+    (_, fold_model_ema) = get_dense_models_with_one_layer(
+        kernel_quantizer=None, folding_mode="ema_stats_folding",
+        ema_freeze_delay=10)
 
   unfold_model.layers[1].set_weights([kernel])
   unfold_model.layers[2].set_weights(
       [gamma, beta, moving_mean, moving_variance])
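For reviewers: a minimal usage sketch of the new layer, not part of the diff above. The layer arguments (`folding_mode`, `ema_freeze_delay`), `get_folded_weights()`, and `model_save_quantized_weights()` come from this PR; the model shape, quantizer settings, and file name are illustrative assumptions.

```python
import numpy as np
from tensorflow.keras import layers, Model

from qkeras import QDenseBatchnorm, quantized_bits
from qkeras.utils import model_save_quantized_weights

# toy model: one folded QDense+BN block followed by a plain Dense head
x_in = layers.Input((16,))
x = QDenseBatchnorm(
    8,
    kernel_quantizer=quantized_bits(4, 0, 1),
    bias_quantizer=quantized_bits(8, 0, 1),
    folding_mode="ema_stats_folding",  # or "batch_stats_folding"
    ema_freeze_delay=100,              # freeze BN statistics after 100 steps
    activation="relu")(x_in)
x = layers.Dense(1)(x)
model = Model(x_in, x)
model.compile(optimizer="adam", loss="mse")

# a couple of epochs on random data so the BN moving statistics are populated
model.fit(np.random.rand(64, 16), np.random.rand(64, 1), epochs=2, verbose=0)

# folded kernel/bias as they would be deployed, plus quantized export
folded_kernel, folded_bias = model.layers[1].get_folded_weights()
model_save_quantized_weights(model, "quantized_weights.h5")
```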