utils.py
# Copyright 2019 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
import json
import tempfile
import types
import numpy as np
import os
import six
import re
import networkx as nx
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.models import model_from_json
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
from .qlayers import Clip
from .qconv2d_batchnorm import QConv2DBatchnorm
from .qdepthwiseconv2d_batchnorm import QDepthwiseConv2DBatchnorm
from .qdense_batchnorm import QDenseBatchnorm
from .qlayers import QActivation
from .qlayers import QAdaptiveActivation
from .qpooling import QAveragePooling2D
from .qlayers import QDense
from .qlayers import QInitializer
from .qconvolutional import QConv1D
from .qconvolutional import QConv2D
from .qconvolutional import QConv2DTranspose
from .qrecurrent import QSimpleRNN
from .qrecurrent import QSimpleRNNCell
from .qrecurrent import QLSTM
from .qrecurrent import QLSTMCell
from .qrecurrent import QGRU
from .qrecurrent import QGRUCell
from .qrecurrent import QBidirectional
from .qconvolutional import QSeparableConv1D
from .qconvolutional import QSeparableConv2D
from .qconvolutional import QDepthwiseConv2D
from .qnormalization import QBatchNormalization
from .qpooling import QGlobalAveragePooling2D
from .qtools import qgraph
from .quantizers import binary
from .quantizers import bernoulli
from .quantizers import get_weight_scale
from .quantizers import quantized_bits
from .quantizers import quantized_relu
from .quantizers import quantized_ulaw
from .quantizers import quantized_tanh
from .quantizers import quantized_sigmoid
from .quantizers import quantized_po2
from .quantizers import quantized_relu_po2
from .quantizers import stochastic_binary
from .quantizers import stochastic_ternary
from .quantizers import ternary
# from .google_internals.experimental_quantizers import quantized_bits_learnable_scale
# from .google_internals.experimental_quantizers import parametric_quantizer_d_xmax
from .safe_eval import safe_eval
from tensorflow.python.ops import math_ops
from .qmac import QScaleShift
REGISTERED_LAYERS = [
"QActivation",
"QAdaptiveActivation",
"QDense",
"QConv1D",
"QConv2D",
"QSeparableConv1D",
"QSeparableConv2D",
"QDepthwiseConv2D",
"QConv2DTranspose",
"QSimpleRNN",
"QLSTM",
"QGRU",
"QBidirectional",
"QBatchNormalization",
"QConv2DBatchnorm",
"QDepthwiseConv2DBatchnorm",
"QAveragePooling2D",
"QGlobalAveragePooling2D",
"QDenseBatchnorm",
]
def find_bn_fusing_layer_pair(model):
"""Finds layers that can be fused with the following batchnorm layers.
Args:
model: input model
Returns:
A tuple of (1) a dict mapping each layer name to the name of the batchnorm
layer it should be fused with, and (2) a set of those batchnorm layer names
to skip during weight export.
Note: supports both sequential and non-sequential models.
"""
fold_model = clone_model(model)
(graph, _) = qgraph.GenerateGraphFromModel(
fold_model, "quantized_bits(8, 0, 1)", "quantized_bits(8, 0, 1)")
qgraph.GraphAddSingleSourceSingleSink(graph)
qgraph.GraphRemoveNodeWithNodeType(graph, "InputLayer")
qgraph.GraphPropagateActivationsToEdges(graph)
# Finds the Batchnorm nodes and marks them.
layers_followed_by_bn = {}
bn_layers_to_skip = set()
for node_id in nx.topological_sort(graph):
node = graph.nodes[node_id]
layer = node["layer"][0]
if layer:
successor_ids = list(graph.successors(node_id))
is_single = len(successor_ids) == 1
successor_layer = graph.nodes[successor_ids[0]]["layer"][0]
followed_by_bn = (successor_layer.__class__.__name__ ==
"QBatchNormalization")
# TODO(lishanok): extend to QDense types
enable_bn_fusing = layer.__class__.__name__ in [
"QConv2D", "QDepthwiseConv2D"
] and is_single and followed_by_bn
if enable_bn_fusing:
layers_followed_by_bn[layer.name] = successor_layer.name
bn_layers_to_skip.add(successor_layer.name)
return (layers_followed_by_bn, bn_layers_to_skip)
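

# Illustrative usage sketch (an editorial addition, not part of the original
# file): how one might inspect the fusing pairs returned above. The model
# passed in is assumed to be a QKeras model in which some QConv2D /
# QDepthwiseConv2D layers are each followed by a single QBatchNormalization
# layer; _print_bn_fusing_pairs is a hypothetical helper name.
def _print_bn_fusing_pairs(model):
  fusing_pairs, bn_layers_to_skip = find_bn_fusing_layer_pair(model)
  for conv_name, bn_name in fusing_pairs.items():
    print(f"{conv_name} will be fused with {bn_name} for hardware inference")
  print("Batchnorm layers skipped during weight export:",
        sorted(bn_layers_to_skip))
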
def add_bn_fusing_weights(prev_layer, bn_layer, saved_weights):
"""Adds additional fusing weights to saved_weights.
In hardware inference, we need to fuse the previous layer's output with
the following batchnorm op.
z[i] = bn(y[i]) = inv[i] * y'[i] * scale[i] + fused_bias[i] is the final
output of the previous layer and bn layer, with:
inv[i] = gamma[i] * rsqrt(variance[i] + epsilon) computed from the
bn layer weights,
y'[i] the i-th channel output from the previous layer (before scale),
scale[i] the i-th channel kernel quantizer scale, and
fused_bias[i] = inv[i] * bias[i] + beta[i] - inv[i] * mean[i], where bias is
the bias term from the previous layer, and beta and mean are the bn
layer weights.
Args:
prev_layer: QKeras layer, could be QConv2D/QDepthwiseConv2D/QDense.
bn_layer: The following QBatchNormalization layer that needs to be
fused with the previous layer.
saved_weights: Dict. The centralized weights dictionary that exports
relevant weights and parameters for hardware inference.
"""
bn_qs = bn_layer.quantizers
bn_ws = bn_layer.get_weights()
if bn_qs[4] is not None:
assert bn_qs[0] is None and bn_qs[3] is None, (
"If using the inverse quantizer, the gamma and variance quantizers "
"should not be used in order to avoid quantizing a value twice.")
def apply_quantizer(quantizer, input_weight):
if quantizer:
weight = tf.constant(input_weight)
weight = tf.keras.backend.eval(quantizer(weight))
else:
weight = input_weight
return weight
# Quantize respective bn layer weights
gamma = 1.0
beta = 0
idx = 0
if bn_layer.scale:
gamma = apply_quantizer(bn_layer.gamma_quantizer_internal, bn_ws[idx])
idx += 1
if bn_layer.center:
beta = apply_quantizer(bn_layer.beta_quantizer_internal, bn_ws[idx])
idx += 1
mean = apply_quantizer(bn_layer.mean_quantizer_internal, bn_ws[idx])
idx += 1
variance = apply_quantizer(bn_layer.variance_quantizer_internal, bn_ws[idx])
# Compute inv[i]
inv = gamma * math_ops.rsqrt(variance + bn_layer.epsilon)
inv = inv.numpy()
if bn_layer.inverse_quantizer_internal is not None:
quantizer = bn_layer.inverse_quantizer_internal
inv = tf.keras.backend.eval(quantizer(inv))
# Compute fused_bias[i]
if prev_layer.use_bias:
cur_weights = prev_layer.get_weights()
assert len(cur_weights) == 2, ("Weights should have length of 2. Found "
f"{len(cur_weights)} instead.")
prev_bias = cur_weights[-1]
else:
prev_bias = 0
b_prime = inv * prev_bias + beta - inv * mean
saved_weights[prev_layer.name]["enable_bn_fusing"] = True
saved_weights[prev_layer.name]["fused_bn_layer_name"] = bn_layer.name
saved_weights[prev_layer.name]["bn_inv"] = inv
saved_weights[prev_layer.name]["fused_bias"] = b_prime
# Model utilities: before saving the weights, we want to apply the quantizers
def model_save_quantized_weights(model, filename=None):
"""Quantizes model for inference and save it.
Takes a model with weights, apply quantization function to weights and
returns a dictionary with quantized weights.
User should be aware that "po2" quantization functions cannot really
be quantized in meaningful way in Keras. So, in order to preserve
compatibility with inference flow in Keras, we do not covert "po2"
weights and biases to exponents + signs (in case of quantize_po2), but
return instead (-1)**sign*(2**round(log2(x))). In the returned dictionary,
we will return the pair (sign, round(log2(x))).
Special care needs to be given to quantized_bits(alpha="auto_po2") as well.
Since in this quantizer, hardware needs the integer weights and scale for
hardware inference, this function will return the pair (scale,
integer_weights) in the returned dictionary.
Arguments:
model: model with weights to be quantized.
filename: if specified, we will save the hdf5 containing the quantized
weights so that we can use them for inference later on.
Returns:
dictionary containing layer name and quantized weights that can be used
by a hardware generator.
"""
saved_weights = {}
# Find the conv/dense layers followed by Batchnorm layers
(fusing_layer_pair_dict, bn_layers_to_skip) = find_bn_fusing_layer_pair(model)
print("... quantizing model")
for layer in model.layers:
if hasattr(layer, "get_quantizers"):
# weights for software inference
weights = []
signs = []
scales = []
# weights for hardware inference
hw_weights = []
if any(isinstance(layer, t) for t in [
QConv2DBatchnorm, QDenseBatchnorm, QDepthwiseConv2DBatchnorm]):
qs = layer.get_quantizers()
ws = layer.get_folded_weights()
elif any(isinstance(layer, t) for t in [QSimpleRNN, QLSTM, QGRU]):
qs = layer.get_quantizers()[:-1]
ws = layer.get_weights()
else:
qs = layer.get_quantizers()
ws = layer.get_weights()
has_sign = False
has_scale = False
enable_bn_fusing = False
if (isinstance(layer, QBatchNormalization) and
layer.name in bn_layers_to_skip):
# Mark current bn layer to be fused with the previous layer
enable_bn_fusing = True
for quantizer, weight in zip(qs, ws):
if quantizer:
weight = tf.constant(weight)
weight = tf.keras.backend.eval(quantizer(weight))
# If quantizer is power-of-2 (quantized_po2 or quantized_relu_po2),
# we would like to process it here.
#
# However, we cannot, because we will lose sign information as
# quantized_po2 will be represented by the tuple (sign, log2(abs(w))).
#
# In addition, we will not be able to use the weights on the model
# any longer.
#
# So, instead of "saving" the weights in the model, we will return
# a dictionary so that the proper values can be propagated.
# Weights store the weight in the format that software inference uses.
weights.append(weight)
has_sign = False
has_scale = False
if quantizer:
if isinstance(quantizer, six.string_types):
q_name = quantizer
elif hasattr(quantizer, "__name__"):
q_name = quantizer.__name__
elif hasattr(quantizer, "name"):
q_name = quantizer.name
elif hasattr(quantizer, "__class__"):
q_name = quantizer.__class__.__name__
else:
q_name = ""
if quantizer and ("_po2" in q_name):
# Quantized_relu_po2 does not have a sign.
if isinstance(quantizer, quantized_po2):
has_sign = True
sign = np.sign(weight)
# Makes sure values are -1 or +1 only
sign += (1.0 - np.abs(sign))
# hw_weight stores the weight in the format that hardware inference
# uses.
hw_weight = np.round(np.log2(np.abs(weight)))
signs.append(sign)
scales.append([])
elif (isinstance(quantizer, quantized_bits) and
quantizer.alpha == "auto_po2"):
unsigned_bits = quantizer.bits - quantizer.keep_negative
m = K.cast_to_floatx(pow(2, unsigned_bits))
m_i = K.cast_to_floatx(K.pow(2, quantizer.integer))
assert hasattr(quantizer.scale, "numpy"), (
"The auto_po2 quantizer has to be called first in order to know "
"the values of scale.")
scale = K.cast_to_floatx(quantizer.scale.numpy())
# Makes sure the scale values are powers of 2.
log2val = np.log2(scale)
diff = np.round(log2val) - log2val
assert np.all(diff == 0), "scale values must be powers of 2!"
# Converts the fixed-point weight to an integer weight.
hw_weight = weight * m / m_i
# Because hw_weight holds integer weights, set scale = scale * m_i / m
# so that multiplying scale with the integer weight during hardware
# inference recovers the fixed-point weight.
scale = scale * m_i / m
has_scale = True
scales.append(scale)
else:
hw_weight = weight
signs.append([])
scales.append([])
hw_weights.append(hw_weight)
# Save the weights in the format that hardware inference uses
saved_weights[layer.name] = {"weights": hw_weights,
"enable_bn_fusing": enable_bn_fusing}
if (isinstance(layer, QAveragePooling2D) or
isinstance(layer, QGlobalAveragePooling2D)):
if isinstance(layer, QAveragePooling2D):
pool_area = layer.pool_size
if isinstance(layer.pool_size, int):
pool_area = layer.pool_size * layer.pool_size
else:
pool_area = np.prod(layer.pool_size)
else:
pool_area = layer.compute_pooling_area(input_shape=layer.input_shape)
saved_weights[
layer.name]["q_mult_factor"] = layer.average_quantizer_internal(
1.0 / pool_area).numpy()
saved_weights[layer.name]["mult_factor"] = 1.0 / pool_area
saved_weights[layer.name]["pool_area"] = pool_area
if has_sign:
saved_weights[layer.name]["signs"] = signs
if has_scale:
saved_weights[layer.name]["scales"] = scales
if not any(isinstance(layer, t) for t in [
QConv2DBatchnorm, QDenseBatchnorm, QDepthwiseConv2DBatchnorm]):
# Set layer weights in the format that software inference uses
layer.set_weights(weights)
else:
print(layer.name, " conv and batchnorm weights cannot be seperately"
" quantized because they will be folded before quantization.")
# adjust weights for bn fusing if necessary
if layer.name in fusing_layer_pair_dict.keys():
print(f"Fuse {layer.name} output with "
f"{fusing_layer_pair_dict[layer.name]} for hardware inference.")
add_bn_fusing_weights(
prev_layer=layer,
bn_layer=model.get_layer(fusing_layer_pair_dict[layer.name]),
saved_weights=saved_weights)
else:
if layer.get_weights():
print(" ", layer.name, "has not been quantized")
if filename:
model.save_weights(filename)
return saved_weights
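

# Hedged usage sketch (an editorial addition): quantizing a trained QKeras
# model and inspecting the exported dictionary. The default file name below
# is an assumption (any writable path works), and _export_for_hardware is a
# hypothetical helper name.
def _export_for_hardware(qmodel, filename="quantized_weights.h5"):
  saved = model_save_quantized_weights(qmodel, filename=filename)
  for layer_name, entry in saved.items():
    print(layer_name,
          "bn_fused:", entry["enable_bn_fusing"],
          "num_weight_tensors:", len(entry["weights"]))
  return saved
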
def quantize_activation(layer_config, activation_bits):
"""Replaces activation by quantized activation functions."""
str_act_bits = str(activation_bits)
# relu -> quantized_relu(bits)
# tanh -> quantized_tanh(bits)
# sigmoid -> quantized_sigmoid(bits)
# more to come later
if layer_config.get("activation", None) is None:
return
if isinstance(layer_config["activation"], six.string_types):
a_name = layer_config["activation"]
elif isinstance(layer_config["activation"], types.FunctionType):
a_name = layer_config["activation"].__name__
else:
a_name = layer_config["activation"].__class__.__name__
if a_name == "linear":
return
if a_name == "relu":
layer_config["activation"] = "quantized_relu(" + str_act_bits + ")"
elif a_name == "tanh":
layer_config["activation"] = "quantized_tanh(" + str_act_bits + ")"
elif a_name == "sigmoid":
layer_config["activation"] = "quantized_sigmoid(" + str_act_bits + ")"
def get_config(quantizer_config, layer, layer_class, parameter=None):
"""Returns search of quantizer on quantizer_config."""
quantizer = quantizer_config.get(layer["config"]["name"],
quantizer_config.get(layer_class, None))
if quantizer is not None and parameter is not None:
quantizer = quantizer.get(parameter, None)
return quantizer
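

# Sketch of the lookup precedence (an editorial addition): a per-layer-name
# entry in quantizer_config wins over the per-class entry. The layer name
# "conv_0" and the quantizer strings below are assumptions for demonstration.
def _example_get_config_precedence():
  quantizer_config = {
      "QConv2D": {"kernel_quantizer": "quantized_bits(4)"},
      "conv_0": {"kernel_quantizer": "quantized_bits(8)"},
  }
  layer = {"config": {"name": "conv_0"}}
  # Returns "quantized_bits(8)" because the layer-name entry takes priority.
  return get_config(quantizer_config, layer, "QConv2D", "kernel_quantizer")
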
def is_TFOpLambda_layer(layer):
return layer.__class__.__name__ == "TFOpLambda"
def get_y_from_TFOpLambda(model_cfg, layer):
"""Get the value of "y" from the TFOpLambda layer's configuration.
Args:
model_cfg: dictionary type, model.get_config() output
layer: a given layer instance
Returns:
Value of "y" for a TFOpLambda layer. "y" here corresponds to how TensorFlow
stores a TFOpLambda layer parameter in serialization. For example, for
TFOpLambda(func), where func is tf.multiply(input_tensor, 3), "y" would be
the value 3.
"""
for layer_config in model_cfg["layers"]:
op_name = layer_config["config"]["name"]
class_name = layer_config["class_name"]
# TODO(lishanok): Extend support for other TFOpLambda types when needed
if op_name == layer.name and class_name == "TFOpLambda":
assert ("tf.__operators__.add" in op_name or "tf.math.multiply"
in op_name), "TFOpLambda layer {} not supported!".format(op_name)
return layer_config["inbound_nodes"][-1][-1]["y"]
return None
def convert_to_folded_model(model):
"""Find conv/dense layers followed by bn layers and fold them.
Args:
model: input model
Returns:
new model without bn layers
list of layers being folded
Note: supports sequential and non-sequential model
"""
fold_model = clone_model(model)
model_cfg = model.get_config()
(graph, _) = qgraph.GenerateGraphFromModel(
fold_model, "quantized_bits(8, 0, 1)", "quantized_bits(8, 0, 1)")
qgraph.GraphAddSingleSourceSingleSink(graph)
qgraph.GraphRemoveNodeWithNodeType(graph, "InputLayer")
qgraph.GraphPropagateActivationsToEdges(graph)
# Finds the Batchnorm nodes to be deleted and marks them.
bn_nodes_to_delete = []
layers_to_fold = []
for node_id in nx.topological_sort(graph):
layer_input_tensors = []
node = graph.nodes[node_id]
layer = node["layer"][0]
if layer:
successor_ids = list(graph.successors(node_id))
is_single = len(successor_ids) == 1
successor_layer = graph.nodes[successor_ids[0]]["layer"][0]
followed_by_bn = (successor_layer.__class__.__name__ ==
"BatchNormalization")
# TODO(lishanok): extend to QDense types
is_foldable = layer.__class__.__name__ in [
"Conv2D", "DepthwiseConv2D"
] and is_single and followed_by_bn
if is_foldable:
# Removes the batchnorm node from the graph.
bn_nodes_to_delete.append(successor_ids[0])
layers_to_fold.append(layer.name)
# Deletes the marked nodes.
for node_id in bn_nodes_to_delete:
qgraph.GraphRemoveNode(graph, node_id)
# Modifies model according to the graph.
model_outputs = []
x = model_inputs = fold_model.inputs
for node_id in nx.topological_sort(graph):
layer_input_tensors = []
node = graph.nodes[node_id]
layer = node["layer"][0]
if layer:
# Gets layer input tensors from graph edge.
for parent_node_id in graph.predecessors(node_id):
edge = graph.edges[(parent_node_id, node_id)]
input_tensor = edge["tensor"]
layer_input_tensors.append(input_tensor)
# We call the layer to get output tensor.
if len(layer_input_tensors) == 1:
layer_input_tensors = layer_input_tensors[0].deref()
else:
layer_input_tensors = [t.deref() for t in layer_input_tensors]
if is_TFOpLambda_layer(layer):
# TFOpLambda layer requires one extra input: "y"
y = get_y_from_TFOpLambda(model_cfg, layer)
x = layer(layer_input_tensors, y)
else:
x = layer(layer_input_tensors)
# Replaces edge tensors between the predecessor and successor
for u, v in graph.edges(node_id):
# u is current layer node, v is successor layer node
# graph[u][v] is the edge between the two nodes
# Replace the tensor on this edge so that the input tensor for the
# successor layer can be updated accordingly.
graph[u][v]["tensor"] = x.ref()
if v == -2 and x not in model_outputs:
# When it is output layer, add the output tensor of this layer
# into model outputs.
model_outputs.append(x)
new_model = Model(inputs=model_inputs, outputs=model_outputs)
return new_model, layers_to_fold
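

# Hedged usage sketch (an editorial addition): folding a plain float Keras
# model before quantization. The returned model has the foldable
# BatchNormalization layers removed; layers_to_fold lists the Conv2D /
# DepthwiseConv2D layers whose following batchnorm was absorbed.
def _example_convert_to_folded_model(float_model):
  folded_model, layers_to_fold = convert_to_folded_model(float_model)
  print("Layers to be folded with their batchnorm:", layers_to_fold)
  return folded_model
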
def model_quantize(model,
quantizer_config,
activation_bits,
custom_objects=None,
transfer_weights=False,
prefer_qadaptiveactivation=False,
enable_bn_folding=False):
"""Creates a quantized model from non-quantized model.
The quantized model translation is based on json interface of Keras,
which requires a custom_objects dictionary for "string" types.
Because of the way json works, we pass "string" objects for the
quantization mechanisms and we perform an eval("string") which
technically is not safe, but it will do the job.
The quantizer_config is a dictionary with the following form.
{
Dense_layer_name: {
"kernel_quantizer": "quantizer string",
"bias_quantizer": "quantizer_string"
},
Conv2D_layer_name: {
"kernel_quantizer": "quantizer string",
"bias_quantizer": "quantizer_string"
},
Activation_layer_name: "quantizer string",
"QActivation": { "relu": "quantizer_string" },
"QConv2D": {
"kernel_quantizer": "quantizer string",
"bias_quantizer": "quantizer_string"
},
"QBatchNormalization": {}
}
In the case of "QBidirectional", we can follow the same form as above.
The specified configuration will be used for both the forward and backward
layers.
{
"Bidirectional" : {
"kernel_quantizer" : "quantizer string",
"bias_quantizer" : "quantizer string",
"recurrent_quantizer" : "quantizer string"
}
}
In the case of "QActivation", we can modify only certain types of
activations, for example, a "relu". In this case we represent the
activation name by a dictionary, or we can modify all activations,
without representhing as a set.
We right now require a default case in case we cannot find layer name.
This simplifies the dictionary because the simplest case, we can just
say:
{
"default": {
"kernel": "quantized_bits(4)",
"bias": "quantized_bits(4)"
}
}
and this will quantize all layers' weights and biases with 4 bits.
Arguments:
model: model to be quantized
quantizer_config: dictionary (as above) with quantized parameters
activation_bits: number of bits for quantized_relu, quantized_tanh,
quantized_sigmoid
custom_objects: dictionary following keras recommendations for json
translation.
transfer_weights: if true, weights are transferred from model to
qmodel.
prefer_qadaptiveactivation: Bool. If true, try to use QAdaptiveActivation
over QActivation whenever possible.
enable_bn_folding: Bool. If true, fold conv/dense layers with the
following batch normalization layers whenever possible, using e.g.
QConv2DBatchnorm to replace a Conv2D layer followed by batchnorm.
Returns:
qmodel with quantized operations and custom_objects.
"""
if enable_bn_folding:
# Removes bn layers from the model and finds the list of layers to fold.
model, layers_to_fold = convert_to_folded_model(model)
if len(layers_to_fold) == 0:
# If no layers to fold, no need to perform folding.
enable_bn_folding = False
if not custom_objects:
custom_objects = {}
# Let's make a deep copy to make sure our objects are not shared elsewhere.
jm = copy.deepcopy(json.loads(model.to_json()))
custom_objects = copy.deepcopy(custom_objects)
config = jm["config"]
layers = config["layers"]
def quantize_rnn(layer, quantizer_config):
q_name = "Q" + layer["class_name"]
# Needs to add kernel, recurrent, and bias quantizers.
kernel_quantizer = get_config(
quantizer_config, layer, q_name, "kernel_quantizer")
recurrent_quantizer = get_config(
quantizer_config, layer, q_name, "recurrent_quantizer")
if layer["config"]['use_bias']:
bias_quantizer = get_config(
quantizer_config, layer, q_name, "bias_quantizer")
else:
bias_quantizer = None
state_quantizer = get_config(
quantizer_config, layer, q_name, "state_quantizer")
# This is to avoid unwanted transformations.
if kernel_quantizer is None:
return
layer["config"]["kernel_quantizer"] = kernel_quantizer
layer["config"]["recurrent_quantizer"] = recurrent_quantizer
layer["config"]["bias_quantizer"] = bias_quantizer
layer["config"]["state_quantizer"] = state_quantizer
# If activation is present, add activation here.
activation = get_config(
quantizer_config, layer, q_name, "activation_quantizer")
if activation:
layer["config"]["activation"] = activation
else:
quantize_activation(layer["config"], activation_bits)
# If recurrent activation is present, add activation here.
if layer["class_name"] in ["LSTM", "GRU"]:
recurrent_activation = get_config(
quantizer_config, layer, q_name, "recurrent_activation_quantizer")
if recurrent_activation:
layer["config"]["recurrent_activation"] = recurrent_activation
layer["class_name"] = q_name
registered_name = layer.pop("registered_name", None)
if registered_name:
layer["registered_name"] = q_name
for layer in layers:
layer_config = layer["config"]
# Dense becomes QDense, Conv1D becomes QConv1D etc
# Activation converts activation functions.
if layer["class_name"] in [
"Dense", "Conv1D", "Conv2D", "Conv2DTranspose",
"SeparableConv1D", "SeparableConv2D"
]:
if (layer["class_name"] in ["Dense", "Conv2D"] and enable_bn_folding and
layer["name"] in layers_to_fold):
# Only fold if current layer is followed by BN layer.
q_name = "Q" + layer["class_name"] + "Batchnorm"
layer_config["use_bias"] = True # Folded layers require a bias
# Sets ema_freeze_delay and folding_mode specific to the folded
# (QConv2DBatchnorm/QDenseBatchnorm) layer config.
folding_mode = get_config(
quantizer_config, layer, q_name, "folding_mode")
layer_config["folding_mode"] = (
folding_mode if folding_mode else "ema_stats_folding")
ema_freeze_delay = get_config(
quantizer_config, layer, q_name, "ema_freeze_delay")
layer_config["ema_freeze_delay"] = (
ema_freeze_delay if ema_freeze_delay else None)
else:
q_name = "Q" + layer["class_name"]
# Needs to add kernel/bias quantizers.
kernel_quantizer = get_config(
quantizer_config, layer, q_name, "kernel_quantizer")
if layer_config["use_bias"]:
bias_quantizer = get_config(
quantizer_config, layer, q_name, "bias_quantizer")
else:
bias_quantizer = None
if (kernel_quantizer is None and
q_name == "Q" + layer["class_name"] + "Batchnorm"):
# Tries the non-folded layer quantizer as a backup.
kernel_quantizer = get_config(
quantizer_config, layer, "Q" + layer["class_name"],
"kernel_quantizer")
bias_quantizer = get_config(
quantizer_config, layer, "Q" + layer["class_name"],
"bias_quantizer")
# This is to avoid unwanted transformations.
if kernel_quantizer is None:
continue
layer["class_name"] = q_name
layer_config["kernel_quantizer"] = kernel_quantizer
layer_config["bias_quantizer"] = bias_quantizer
# If activation is present, add activation here.
quantizer = get_config(
quantizer_config, layer, q_name, "activation_quantizer")
if quantizer:
layer_config["activation"] = quantizer
else:
quantize_activation(layer_config, activation_bits)
elif layer["class_name"] == "DepthwiseConv2D":
if enable_bn_folding and layer["name"] in layers_to_fold:
q_name = "QDepthwiseConv2DBatchnorm"
layer_config["use_bias"] = True # Folded layers require a bias
# Sets ema_freeze_delay and folding_mode specific to
# QDepthwiseConv2DBatchnorm layers.
folding_mode = get_config(
quantizer_config, layer, q_name, "folding_mode")
layer_config["folding_mode"] = (
folding_mode if folding_mode else "ema_stats_folding")
ema_freeze_delay = get_config(
quantizer_config, layer, q_name, "ema_freeze_delay")
layer_config["ema_freeze_delay"] = (
ema_freeze_delay if ema_freeze_delay else None)
else:
q_name = "QDepthwiseConv2D"
# Needs to add kernel/bias quantizers.
depthwise_quantizer = get_config(quantizer_config, layer, q_name,
"depthwise_quantizer")
if layer_config["use_bias"]:
bias_quantizer = get_config(quantizer_config, layer, q_name,
"bias_quantizer")
else:
bias_quantizer = None
if depthwise_quantizer is None and q_name == "QDepthwiseConv2DBatchnorm":
# Tries the non-folded layer quantizer as a backup.
depthwise_quantizer = get_config(
quantizer_config, layer, "QDepthwiseConv2D", "depthwise_quantizer")
bias_quantizer = get_config(
quantizer_config, layer, "QDepthwiseConv2D", "bias_quantizer")
# This is to avoid unwanted transformations.
if depthwise_quantizer is None:
continue
layer["class_name"] = q_name
layer_config["depthwise_quantizer"] = depthwise_quantizer
layer_config["bias_quantizer"] = bias_quantizer
# If activation is present, add activation here.
quantizer = get_config(quantizer_config, layer, q_name,
"activation_quantizer",)
if quantizer:
layer_config["activation"] = quantizer
else:
quantize_activation(layer_config, activation_bits)
elif layer["class_name"] in ["SimpleRNN", "LSTM", "GRU"]:
quantize_rnn(layer, quantizer_config)
elif layer["class_name"] == "Bidirectional":
forward_layer_quantizer_config = {
layer_config["layer"]["config"]["name"]:
get_config(quantizer_config, layer, "QBidirectional")
}
quantize_rnn(layer["config"]["layer"], forward_layer_quantizer_config)
if "backward_layer" in layer_config:
backward_layer_quantizer_config = {
layer_config["backward_layer"]["config"]["name"]:
get_config(quantizer_config, layer, "QBidirectional")
}
quantize_rnn(layer["config"]["backward_layer"],
backward_layer_quantizer_config)
layer["class_name"] = "QBidirectional"
elif layer["class_name"] == "Activation":
if prefer_qadaptiveactivation: # Try to find QAdaptiveActivation first
quantizer = get_config(quantizer_config, layer, "QAdaptiveActivation")
is_qadaptiveactivation = True
if quantizer is None: # Try QActivation as a backup
quantizer = get_config(quantizer_config, layer, "QActivation")
is_qadaptiveactivation = False
else: # Tries to find QActivation first.
quantizer = get_config(quantizer_config, layer, "QActivation")
is_qadaptiveactivation = False
if quantizer is None: # Try QAdaptiveActivation as a backup
quantizer = get_config(quantizer_config, layer, "QAdaptiveActivation")
is_qadaptiveactivation = True
# This is to prevent softmax from being quantized in autoq.
if quantizer is None:
continue
# If quantizer exists in dictionary related to this name,
# use it, otherwise, use normal transformations.
if not isinstance(quantizer, dict) or quantizer.get(
layer_config["activation"], None):
# Only change activation layer if we will use a quantized activation.
layer["class_name"] = ("QAdaptiveActivation" if is_qadaptiveactivation
else "QActivation")
if isinstance(quantizer, dict):
quantizer = quantizer[layer_config["activation"]]
if quantizer:
if is_qadaptiveactivation:
assert quantizer.find(",") < 0, \
"Only integer bits should be defined for QAdaptiveActivation"
layer_config["total_bits"] = int(re.sub(r"[^\d]", "", quantizer))
quantizer = re.sub(r"\(.*", "", quantizer) # remove params
layer_config["activation"] = quantizer
else:
quantize_activation(layer_config, activation_bits)
# We have to do this because of other instances of ReLU.
elif layer["class_name"] in ["ReLU", "relu", "LeakyReLU"]:
quantizer = get_config(quantizer_config, layer, "QActivation")
# This is to avoid unwanted transformations.
if quantizer is None:
continue
if layer["class_name"] == "LeakyReLU":
negative_slope = layer["config"]["alpha"]
elif layer["class_name"] == "relu":
max_value = layer["config"]["max_value"]
negative_slope = layer["config"]["alpha"]
threshold = layer["config"]["threshold"]
else: # ReLU from mobilenet
max_value = layer["config"]["max_value"]
negative_slope = layer["config"]["negative_slope"]
threshold = layer["config"]["threshold"]
if negative_slope > 0:
q_name = "leakyrelu"
else:
q_name = "relu"
# If quantizer exists in dictionary related to this name,
# use it, otherwise, use normal transformations.
if not isinstance(quantizer, dict) or quantizer.get(q_name, None):
# Only change the activation layer if we will use a quantized activation.
# Remove relu-specific configurations before overwriting the class name;
# remember that quantized relus are always upper bounded.
if layer["class_name"] == "LeakyReLU":
del layer["config"]["alpha"]
elif layer["class_name"] == "relu":
del layer["config"]["max_value"]
del layer["config"]["alpha"]
del layer["config"]["threshold"]
else: # ReLU from mobilenet
del layer["config"]["max_value"]
del layer["config"]["negative_slope"]
del layer["config"]["threshold"]
layer["class_name"] = "QActivation"
if isinstance(quantizer, dict):
quantizer = quantizer[q_name]
if quantizer:
layer["config"]["activation"] = quantizer
else:
quantize_activation(layer["config"], activation_bits)
elif layer["class_name"] == "BatchNormalization":
# We will assume at least QBatchNormalization or
# layer name is in dictionary to enable conversion
# otherwise we will just skip it.
if (
layer_config["name"] not in quantizer_config and
"QBatchNormalization" not in quantizer_config
):
continue
layer["class_name"] = "QBatchNormalization"
# Needs to add kernel/bias quantizers.
gamma_quantizer = get_config(
quantizer_config, layer, "QBatchNormalization",
"gamma_quantizer")
beta_quantizer = get_config(
quantizer_config, layer, "QBatchNormalization",
"beta_quantizer")
mean_quantizer = get_config(
quantizer_config, layer, "QBatchNormalization",
"mean_quantizer")
variance_quantizer = get_config(
quantizer_config, layer, "QBatchNormalization",
"variance_quantizer")
layer_config["gamma_quantizer"] = gamma_quantizer
layer_config["beta_quantizer"] = beta_quantizer
layer_config["mean_quantizer"] = mean_quantizer
layer_config["variance_quantizer"] = variance_quantizer
elif layer["class_name"] in ["AveragePooling2D", "GlobalAveragePooling2D"]:
q_name = "Q" + layer["class_name"]
# Adds the average quantizer to config.
average_quantizer = get_config(
quantizer_config, layer, q_name, "average_quantizer")
# This is to avoid unwanted transformations.
if average_quantizer is None:
continue
layer["class_name"] = q_name
layer_config["average_quantizer"] = average_quantizer
# Adds activation to config.
quantizer = get_config(
quantizer_config, layer, q_name, "activation_quantizer")
if quantizer:
layer_config["activation"] = quantizer