-
Notifications
You must be signed in to change notification settings - Fork 125
/
Copy pathquantization_analysis_util.py
328 lines (262 loc) · 11.5 KB
/
quantization_analysis_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
import os
from matplotlib import pyplot as plt
from typing import List
import torch
import torch.nn as nn
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel
from .train_util import get_logger, get_module_device
# Module-level logger used by the report/plot helpers in this file.
# get_logger comes from the project's .train_util; the 'INFO' argument presumably
# sets the logging level — verify against train_util.get_logger.
log = get_logger(__name__, 'INFO')
def sqnr(x: torch.Tensor, y: torch.Tensor) -> float:
    """Return the signal-to-quantization-noise ratio of ``y`` w.r.t. ``x`` in dB.

    Computed as ``20 * log10(||x|| / ||x - y||)`` using the L2 norm over all
    elements. Higher is better; when ``x`` and ``y`` are identical (no noise at
    all) the result is ``+inf``.

    Args:
        x: Reference (floating point) tensor.
        y: Tensor to compare with ``x`` (e.g. its fake-quantized version); must
            have the same shape as ``x``.

    Returns:
        The SQNR in decibels as a Python float.

    Raises:
        ValueError: If ``x`` and ``y`` have different shapes.
    """
    if x.shape != y.shape:
        raise ValueError(f'Can not compute loss for tensors with different shape. ({x.shape} and {y.shape})')
    # Promote to at least float32: torch.norm rejects integer tensors and
    # low-precision floats can overflow while squaring inside the norm.
    dtype = torch.promote_types(torch.promote_types(x.dtype, y.dtype), torch.float32)
    x = x.to(dtype)
    y = y.to(dtype)
    signal_norm = torch.norm(x)
    noise_norm = torch.norm(x - y)
    if noise_norm == 0:
        # No quantization noise: report +inf instead of letting 0/0 produce nan,
        # which would break the ascending sort of reported errors.
        return float('inf')
    return (20 * torch.log10(signal_norm / noise_norm)).item()
def cosine(x: torch.Tensor, y: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
    """Calculate the cosine similarity between ``x`` and ``y``.

    Tensors with two or more dimensions are treated as a batch along dim 0:
    every sample is flattened and compared separately, then the per-sample
    similarities are reduced according to ``reduction``. 0-d and 1-d tensors
    are treated as a single sample.

    Args:
        x: Reference tensor.
        y: Tensor to compare with ``x``; must have the same shape.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'``
            (case-insensitive).

    Returns:
        A 0-d tensor for ``'mean'`` / ``'sum'``; a 1-d tensor of per-sample
        similarities for ``'none'``.

    Raises:
        ValueError: On a shape mismatch or an unsupported ``reduction``.
    """
    if x.shape != y.shape:
        raise ValueError(f'Can not compute loss for tensors with different shape. ({x.shape} and {y.shape})')
    reduction = str(reduction).lower()
    reducers = {
        'mean': torch.mean,
        'sum': torch.sum,
        'none': lambda sim: sim,
    }
    # Validate before doing any tensor work.
    if reduction not in reducers:
        raise ValueError(f'Cosine similarity do not supported {reduction} method.')
    if x.ndim <= 1:
        # Scalars and vectors become a batch of one sample; reshape (unlike
        # unsqueeze) also handles 0-d tensors.
        x = x.reshape(1, -1)
        y = y.reshape(1, -1)
    x = x.flatten(start_dim=1).float()
    y = y.flatten(start_dim=1).float()
    return reducers[reduction](torch.cosine_similarity(x, y, dim=-1))
# Registry of error metrics, looked up by the `metric` argument of the
# analysis functions in this module. Lower values mean a larger error for both.
METRIC_DICT = dict(
    cosine=cosine,
    sqnr=sqnr,
)
def _format_qparams(mod) -> str:
    """Format a FakeQuantize module's ``scale`` / ``zero_point`` for the report.

    Per-tensor parameters are printed as scalars. Per-channel parameters have more
    than one element, so ``.item()`` would raise; they are summarised by their
    [min, max] range instead.
    """
    scale = mod.scale
    zero_point = mod.zero_point
    if scale.numel() == 1 and zero_point.numel() == 1:
        return f'scale: {scale.item():.4f}, zero_point: {zero_point.item()}'
    return (
        f'scale: [{scale.min().item():.4f}, {scale.max().item():.4f}] ({scale.numel()} channels), '
        f'zero_point: [{zero_point.min().item()}, {zero_point.max().item()}]'
    )


def error_print(metric, q_errors_activ, q_errors_weight, sort_num):
    """Log a quantization error report at WARNING level.

    Args:
        metric: Name of the metric, used in the report text.
        q_errors_activ: ``(layer_name, fake_quant_module, error)`` tuples for
            activation FakeQuantize modules.
        q_errors_weight: ``(layer_name, fake_quant_module, error)`` tuples for
            weight FakeQuantize modules.
        sort_num: Value shown in the section headers (callers in this file pass
            the report size, or ``''`` for the unsorted cumulative report).
    """
    logs = ['Quantization error report:']
    if q_errors_weight:
        logs.append('')
        logs.append(f'Weights ({metric} sorted {sort_num}):')
        for n, m, e in q_errors_weight:
            logs.append(f'{n:40} {metric}: {e:.4f}, {_format_qparams(m)}')
    if q_errors_activ:
        logs.append('')
        logs.append(f'Activations ({metric} sorted {sort_num}):')
        for n, m, e in q_errors_activ:
            logs.append(f'{n:50} {metric}: {e:.4f}, {_format_qparams(m)}')
    if not q_errors_weight and not q_errors_activ:
        logs.append('')
        logs.append('All good!')
    logs.append('')
    log.warning('\n'.join(logs))
def layer_error_analysis(q_model: nn.Module, dummy_input, metric: str = 'cosine', sort_num: int = 20):
    """Generate a layer-wise quantization error report; ``q_model`` must be quant-prepared.

    Each ``FakeQuantize`` module is evaluated in isolation, so errors do not
    accumulate across layers:

    1. The model is run once with fake quantization and observers disabled, and
       the floating point input of every FakeQuantize module is recorded.
    2. The original ``fake_quant_enabled`` state is restored (observers stay
       disabled so their statistics are not updated). Every FakeQuantize module is
       applied to its recorded input, and its output is compared with that input
       using ``metric``.

    The ``sort_num`` lowest-scoring weight and activation entries are logged via
    ``error_print``. Forward hooks, fake-quant/observer flags and training mode are
    restored even if the forward pass or the report raises.

    Args:
        q_model: The quant-prepared model. For ``DataParallel`` /
            ``DistributedDataParallel`` the wrapped ``.module`` is analysed.
        dummy_input: A tensor, or a tuple/list of positional model inputs.
            Tensors are moved to the model's device when needed.
        metric: Key of ``METRIC_DICT``: ``'cosine'`` (default) or ``'sqnr'``.
            Lower values are treated as worse.
        sort_num: Number of worst entries reported per category. Defaults to 20;
            ``None`` reports all of them.

    Raises:
        KeyError: If ``metric`` is not in ``METRIC_DICT``.
        TypeError: If ``dummy_input`` is not a tensor, tuple or list.
    """
    model = q_model.module if isinstance(q_model, (DataParallel, DistributedDataParallel)) else q_model
    metric_fn = METRIC_DICT[metric]

    # Validate the input before any module state is modified.
    if isinstance(dummy_input, torch.Tensor):
        actual_input = [dummy_input]
    elif isinstance(dummy_input, (tuple, list)):
        actual_input = list(dummy_input)
    else:
        log.error(f'Unsupported type {type(dummy_input)} for dummy input')
        raise TypeError(f'Unsupported type {type(dummy_input)} for dummy input')
    device = get_module_device(model)
    actual_input = [
        inp.to(device) if isinstance(inp, torch.Tensor) and inp.device != device else inp
        for inp in actual_input
    ]

    train_flag = model.training
    modules_list = {}  # FakeQuantize name -> module
    names_list = {}  # FakeQuantize module -> name
    float_results = {}  # FakeQuantize name -> positional inputs seen in the float pass
    fake_quant_enabled_dict = {}
    observer_enabled_dict = {}
    hooks = []

    def forward_hook(module, input, output):
        float_results[names_list[module]] = input

    model.eval()
    try:
        for n, m in model.named_modules():
            if isinstance(m, torch.quantization.FakeQuantize):
                names_list[m] = n
                modules_list[n] = m
                fake_quant_enabled_dict[m] = m.fake_quant_enabled.clone()
                observer_enabled_dict[m] = m.observer_enabled.clone()
                hooks.append(m.register_forward_hook(forward_hook))
        if not modules_list:
            log.warning('No FakeQuantize modules found. Are you sure you had prepared your model?')

        with torch.no_grad():
            # Float pass: record every FakeQuantize input without quantizing.
            model.apply(torch.quantization.disable_fake_quant)
            model.apply(torch.quantization.disable_observer)
            model(*actual_input)
            for h in hooks:
                h.remove()
            hooks.clear()

            # Re-enable fake quantization; observers stay off so the statistics are untouched.
            for m, v in fake_quant_enabled_dict.items():
                m.fake_quant_enabled = v

            q_errors_weight = []
            q_errors_activ = []
            while float_results:
                n, f = float_results.popitem()
                mod = modules_list[n]
                loss = metric_fn(f[0], mod(*f))
                # Strip the FakeQuantize attribute name to get the owning layer's name.
                actual_n = n.rpartition('.')[0]
                if n.endswith('.weight_fake_quant'):
                    q_errors_weight.append((actual_n, mod, loss))
                else:
                    q_errors_activ.append((actual_n, mod, loss))

        limit = None if sort_num is None else int(sort_num)
        # Ascending: the lowest metric values (largest errors) come first.
        q_errors_weight = sorted(q_errors_weight, key=lambda x: x[2])[:limit]
        q_errors_activ = sorted(q_errors_activ, key=lambda x: x[2])[:limit]
        error_print(metric, q_errors_activ, q_errors_weight, sort_num)
    finally:
        # Restore everything this function touched, even on failure.
        for h in hooks:
            h.remove()
        for m, v in fake_quant_enabled_dict.items():
            m.fake_quant_enabled = v
        for m, v in observer_enabled_dict.items():
            m.observer_enabled = v
        if train_flag:
            model.train()
def graph_error_analysis(q_model: nn.Module, dummy_input, metric: str = 'cosine'):
    """Generate a cumulative quantization error report; ``q_model`` must be quant-prepared.

    The model is run twice on ``dummy_input``:

    1. With all FakeQuantize modules and observers disabled (floating point reference).
    2. With the original ``fake_quant_enabled`` state restored (observers still
       disabled, so their statistics are not updated).

    The input arriving at each activation FakeQuantize module in the two runs is
    compared with ``metric``. The reported error therefore includes the error
    accumulated from all preceding quantized layers. Entries are reported in
    first-execution order of the reference run (not sorted). Weight FakeQuantize
    modules are excluded. Forward hooks, fake-quant/observer flags and training mode
    are restored even if a forward pass or the report raises.

    Args:
        q_model: The quant-prepared model. For ``DataParallel`` /
            ``DistributedDataParallel`` the wrapped ``.module`` is analysed.
        dummy_input: A tensor, or a tuple/list of positional model inputs.
            Tensors are moved to the model's device when needed.
        metric: Key of ``METRIC_DICT``: ``'cosine'`` (default) or ``'sqnr'``.

    Raises:
        KeyError: If ``metric`` is not in ``METRIC_DICT``.
        TypeError: If ``dummy_input`` is not a tensor, tuple or list.
    """
    model = q_model.module if isinstance(q_model, (DataParallel, DistributedDataParallel)) else q_model
    metric_fn = METRIC_DICT[metric]

    # Validate the input before any module state is modified.
    if isinstance(dummy_input, torch.Tensor):
        actual_input = [dummy_input]
    elif isinstance(dummy_input, (tuple, list)):
        actual_input = list(dummy_input)
    else:
        log.error(f'Unsupported type {type(dummy_input)} for dummy input')
        raise TypeError(f'Unsupported type {type(dummy_input)} for dummy input')
    device = get_module_device(model)
    actual_input = [
        inp.to(device) if isinstance(inp, torch.Tensor) and inp.device != device else inp
        for inp in actual_input
    ]

    train_flag = model.training
    modules_list = {}  # FakeQuantize name -> module
    names_list = {}  # FakeQuantize module -> name
    float_results = {}  # inputs recorded during the float pass
    quant_results = {}  # inputs recorded during the fake-quantized pass
    # The hook writes into whichever dict is currently active; switched between passes.
    active = {'results': float_results}
    fake_quant_enabled_dict = {}
    observer_enabled_dict = {}
    hooks = []

    def forward_hook(module, input, output):
        active['results'][names_list[module]] = input

    model.eval()
    try:
        for n, m in model.named_modules():
            if isinstance(m, torch.quantization.FakeQuantize):
                names_list[m] = n
                modules_list[n] = m
                fake_quant_enabled_dict[m] = m.fake_quant_enabled.clone()
                observer_enabled_dict[m] = m.observer_enabled.clone()
                hooks.append(m.register_forward_hook(forward_hook))
        if not modules_list:
            log.warning('No FakeQuantize modules found. Are you sure you had prepared your model?')

        with torch.no_grad():
            # Float pass.
            model.apply(torch.quantization.disable_fake_quant)
            model.apply(torch.quantization.disable_observer)
            model(*actual_input)

            # Restore fake-quantize and record activations with quantization error.
            for m, v in fake_quant_enabled_dict.items():
                m.fake_quant_enabled = v
            active['results'] = quant_results
            model(*actual_input)
            for h in hooks:
                h.remove()
            hooks.clear()

            q_errors_activ = []
            for name, f_tensor in float_results.items():
                assert name in quant_results, f'{name} not in results'
                if name.endswith('.weight_fake_quant'):
                    continue  # weights are not part of the cumulative report
                loss = metric_fn(f_tensor[0], quant_results[name][0])
                q_errors_activ.append((name.rpartition('.')[0], modules_list[name], loss))

        error_print(metric, q_errors_activ, [], '')
    finally:
        # Restore everything this function touched, even on failure.
        for h in hooks:
            h.remove()
        for m, v in fake_quant_enabled_dict.items():
            m.fake_quant_enabled = v
        for m, v in observer_enabled_dict.items():
            m.observer_enabled = v
        if train_flag:
            model.train()
def get_weight_dis(
    model: nn.Module,
    unique_name_list: List[str] = None,
    nbins=256,
    save_path: str = 'out',
    threshold=20,
    fig_size=(7, 7),
):
    """Draw the weight distribution of each layer of ``model`` and save it as a JPG.

    For intrinsic QAT modules that carry a ``bn`` submodule, the weight is first
    scaled per output channel by ``bn.weight / sqrt(bn.running_var + bn.eps)``,
    i.e. the BN-folded weight is plotted. Layers whose weight range
    (max - min) exceeds ``threshold`` are listed in a warning at the end.

    Args:
        model: We recommend a ptq-prepared model to draw the fused weight distribution.
        unique_name_list: Names of the layers to plot; defaults to every layer of
            ``model`` that has a weight tensor (``nn.BatchNorm2d`` is skipped).
        nbins: Number of histogram bins, default 256.
        save_path: Figures are saved under "[save_path]/weight_distribution".
        threshold: Weight-range length above which a layer is reported as anomalous.
        fig_size: Figure size passed to ``plt.subplots``.
    """
    with torch.no_grad():
        save_dir = os.path.join(save_path, 'weight_distribution')
        log.info(f"jpgs will saved at {save_dir}")
        os.makedirs(save_dir, exist_ok=True)
        warning_layer = dict()
        for name, mod in model.named_modules():
            weight = getattr(mod, 'weight', None)
            # Skip modules without a weight tensor: weight=None (e.g. non-affine
            # norms) or a `weight` that is a method rather than a tensor.
            if not isinstance(weight, torch.Tensor) or isinstance(mod, nn.BatchNorm2d):
                continue
            if unique_name_list is not None and name not in unique_name_list:
                continue
            op_type = type(mod).__name__
            x = weight
            if op_type in dir(torch.nn.intrinsic.qat) and hasattr(mod, 'bn'):
                # Same per-channel folding as torch.nn.utils.fusion.fuse_conv_bn_weights.
                bn_var_rsqrt = torch.rsqrt(mod.bn.running_var + mod.bn.eps)
                x = weight * (mod.bn.weight * bn_var_rsqrt).reshape([-1] + [1] * (len(weight.shape) - 1))
            x = x.detach().cpu()
            x_min = float(torch.min(x))
            x_max = float(torch.max(x))
            if x_max - x_min > threshold:
                warning_layer[name] = (op_type, x_min, x_max)
            lo, hi = x_min, x_max
            if lo == hi:
                # Widen a degenerate range so the bins have non-zero width.
                lo, hi = lo - 1.0, hi + 1.0
            # Pass the range explicitly so the bin centres below match the histogram.
            counts = torch.histc(x, bins=nbins, min=lo, max=hi)
            bin_width = (hi - lo) / nbins
            centers = [lo + (idx + 0.5) * bin_width for idx in range(nbins)]
            fig, ax = plt.subplots(figsize=fig_size)
            try:
                ax.set_yscale('log')
                ax.plot(centers, counts.numpy())
                ax.set_title(f'Op_uname: {name}[{op_type}]')
                ax.set_xlabel(f'Range:[{x_min:.4f},{x_max:.4f}]')
                ax.set_ylabel('Count')
                fig.savefig(os.path.join(save_dir, f'{name}.jpg'))
            finally:
                # Close the figure; plt.cla() alone leaves one open figure per layer.
                plt.close(fig)
        if warning_layer:
            log_str = f'\n---------the layer weight range length greater than {threshold}---------\n'
            for k, v in warning_layer.items():
                log_str += f'{k}, {v}\n'
            log_str += '---------------------------------------------------------------'
            log.warning(log_str)