# Copyright (c) 2022 ONERA, Magellium and IMT, Romain Thoreau, Laurent Risser, Véronique Achard, Béatrice Berthelot, Xavier Briottet.
# Script to compute quantitative metrics and to reproduce figure 6
# Part of the following code is under the following license:
# Copyright 2018 Ubisoft La Forge Authors. All rights reserved.
import numpy as np
from pyitlib import discrete_random_variable as drv
from sklearn.preprocessing import minmax_scale


def get_mutual_information(x, y, normalize=True):
    ''' Compute mutual information between two random variables
    :param x: random variable
    :param y: random variable
    '''
    if normalize:
        return drv.information_mutual_normalised(x, y, norm_factor='Y', cartesian_product=True)
    else:
        return drv.information_mutual(x, y, cartesian_product=True)
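
# A minimal usage sketch (illustrative, not part of the original script): both
# inputs are 1-D arrays of discrete symbols, e.g. the bin indexes produced by
# get_bin_index below.
#
#   x = np.array([0, 0, 1, 1, 2, 2])
#   y = np.array([1, 1, 0, 0, 2, 2])
#   get_mutual_information(x, y)                   # 1.0: y fully determines x
#   get_mutual_information(x, y, normalize=False)  # the same MI, in bits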


def jemmig(factors, codes, continuous_factors=True, nb_bins=10):
    ''' JEMMIG metric from K. Do and T. Tran,
    “Theory and evaluation metrics for learning disentangled representations,”
    in ICLR, 2020.

    :param factors: dataset of factors
                    each column is a factor and each line is a data point
    :param codes: latent codes associated to the dataset of factors
                  each column is a latent code and each line is a data point
    :param continuous_factors: True:  factors are described as continuous variables
                               False: factors are described as discrete variables
    :param nb_bins: number of bins to use for discretization
    '''
    # count the number of factors and latent codes
    nb_factors = factors.shape[1]
    nb_codes = codes.shape[1]

    # quantize factors if they are continuous
    if continuous_factors:
        factors = minmax_scale(factors)  # normalize all columns in [0, 1]
        factors = get_bin_index(factors, nb_bins)  # quantize values and get indexes

    # quantize latent codes
    codes = minmax_scale(codes)  # normalize all columns in [0, 1]
    codes = get_bin_index(codes, nb_bins)  # quantize values and get indexes

    # compute the mutual information matrix between factors and codes
    mi_matrix = np.zeros((nb_factors, nb_codes))
    for f in range(nb_factors):
        for c in range(nb_codes):
            mi_matrix[f, c] = get_mutual_information(factors[:, f], codes[:, c], normalize=False)

    # compute the joint entropy matrix
    je_matrix = np.zeros((nb_factors, nb_codes))
    for f in range(nb_factors):
        for c in range(nb_codes):
            X = np.stack((factors[:, f], codes[:, c]), 0)
            je_matrix[f, c] = drv.entropy_joint(X)

    # compute the JEMMIG score of each factor
    sum_score = 0
    jemmig_scores = []; je = []; gap = []
    for f in range(nb_factors):
        mi_f = np.sort(mi_matrix[f, :])
        je_idx = np.argsort(mi_matrix[f, :])[-1]; je.append(je_matrix[f, je_idx])
        gap.append(mi_f[-1] - mi_f[-2])
        jemmig_not_normalized = je_matrix[f, je_idx] - mi_f[-1] + mi_f[-2]
        # normalize by H(f) + log2(#bins) and flip so that higher is better
        jemmig_f = jemmig_not_normalized / (drv.entropy_joint(factors[:, f]) + np.log2(nb_bins))
        jemmig_f = 1 - jemmig_f
        jemmig_scores.append(jemmig_f)
        sum_score += jemmig_f

    # average the scores over all factors
    jemmig_score = sum_score / nb_factors

    return jemmig_score, jemmig_scores, je, gap
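
# In the notation of Do & Tran (2020), the per-factor score computed above is
#
#     JEMMIG(f) = 1 - [ H(f, c*) - I(f; c*) + I(f; c') ] / [ H(f) + log2(nb_bins) ]
#
# where c* is the code with the highest mutual information with factor f and
# c' is the runner-up. A score close to 1 means f is captured by a single code.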


def get_bin_index(x, nb_bins):
    ''' Discretize input variable
    :param x: input variable
    :param nb_bins: number of bins to use for discretization
    '''
    # get bins limits
    bins = np.linspace(0, 1, nb_bins + 1)
    # discretize input variable
    return np.digitize(x, bins[:-1], right=False).astype(int)
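
# A hedged, synthetic sanity check for jemmig (illustrative data, not from the
# paper): when the codes contain exact copies of the factors plus independent
# noise dimensions, the score should be close to 1.
#
#   rng = np.random.default_rng(0)
#   factors = rng.random((1000, 2))
#   codes = np.concatenate([factors, rng.random((1000, 2))], axis=1)
#   score, per_factor, je, gap = jemmig(factors, codes, nb_bins=10)
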
#==================== Following code is original ========================
import torch
import torch.nn.functional as F
from models.utils import sam_, one_hot
from data import SimulatedDataSet
from models.model_loader import load_model
from sklearn.metrics import f1_score, confusion_matrix, classification_report
import json
import matplotlib.pyplot as plt
import sys
import math


def build_data_under_diff_irradiance(dataset, class_id):
    ''' Simulate spectra of a given class under varying direct/diffuse irradiance,
    together with the generative factors [class_id, z_phi, omega, alpha, eta].
    '''
    spectra = []
    z_phi_list = []
    omega_list = []
    alpha_list = []
    eta_list = []
    rho = [torch.from_numpy(rho_).unsqueeze(0) for rho_ in dataset.classes[class_id]['spectrum']]
    if len(rho) > 1:
        z_phi = torch.linspace(0, 1, math.ceil(1e4/(10*len(rho)*10)))
    else:
        z_phi = torch.linspace(0, 1, math.ceil(1e4/(10*len(rho))))
    omega = torch.linspace(0.2, 1, 10)
    E_dir = torch.from_numpy(dataset.E_dir)
    E_dif = torch.from_numpy(dataset.E_dif)
    theta = torch.tensor([dataset.theta])
    if len(rho) > 1:
        for k in range(len(rho)):
            for alpha in torch.linspace(0, 1, 10):
                # intra-class mixture of two reference spectra
                rho_ = (1-alpha)*rho[k] + alpha*rho[(k+1) % len(rho)]
                for z1 in z_phi:
                    for O in omega:
                        cochise_correction = (z1*E_dir + O*E_dif)/(torch.cos(theta)*E_dir + E_dif)
                        sp = rho_*cochise_correction
                        spectra.append(sp)
                        z_phi_list.append(z1)
                        omega_list.append(O)
                        alpha_list.append(alpha)
                        eta_list.append(torch.tensor([k+(k+1) % len(rho)]))
    else:
        rho_ = rho[0]
        for z1 in z_phi:
            for O in omega:
                cochise_correction = (z1*E_dir + O*E_dif)/(torch.cos(theta)*E_dir + E_dif)
                sp = rho_*cochise_correction
                spectra.append(sp)
                z_phi_list.append(z1)
                omega_list.append(O)
                alpha_list.append(torch.ones(1))
                eta_list.append(torch.ones(1))
    spectra = torch.cat(spectra).float()
    z_phi = torch.cat([x.view(1, 1) for x in z_phi_list])
    omega = torch.cat([x.view(1, 1) for x in omega_list])
    alpha = torch.cat([x.view(1, 1) for x in alpha_list])
    eta = torch.cat([x.view(1, 1) for x in eta_list])  # eta_list (not alpha_list) holds the mixture indices
    factors = torch.cat((torch.ones((spectra.shape[0], 1))*class_id, z_phi, omega, alpha, eta), dim=-1)
    return spectra, factors
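
# Hedged usage sketch (SimulatedDataSet comes from this repository's data module):
#
#   dataset = SimulatedDataSet()
#   spectra, factors = build_data_under_diff_irradiance(dataset, class_id=1)
#   # factors columns: [class_id, z_phi, omega, alpha, eta]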


def plot_z_true_vs_z_pred(z_true, z_pred, confusion, z_std, class_id, fontsize=20,
                          colors=['#F2A65A', '#909CC2', '#89A7A7', '#6320EE', '#E94974']):
    x = np.linspace(0, 1, 100)
    z_pred = z_pred.numpy()
    z_true = z_true.numpy()
    z_std = torch.exp(10*z_std).numpy()  # marker size grows with the predicted standard deviation
    fig = plt.figure()
    plt.scatter(z_true[confusion==1], z_pred[confusion==1], alpha=0.15, color=colors[3], s=z_std[confusion==1])
    plt.scatter(z_true[confusion==0], z_pred[confusion==0], alpha=0.15, color=colors[0], s=z_std[confusion==0])
    plt.plot(x, x, lw=2, color=colors[4], label='y=x')
    plt.xlabel(r"$\delta_{dir} \cos \Theta$", fontsize=fontsize)
    plt.ylabel(r'$z_P$', fontsize=fontsize)
    plt.legend(loc=4, prop={'size': 20})
    plt.savefig('./results/simulation/{}/Figures/cos_{}.pdf'.format(config['model'], class_id),
                dpi=200, bbox_inches='tight', pad_inches=0.05)


def plot_omega_true_vs_omega_pred(omega_true, z_pred, confusion, fontsize=20,
                                  colors=['#F2A65A', '#909CC2', '#89A7A7', '#6320EE', '#E94974']):
    x = np.linspace(0, 1, 100)
    z_pred = z_pred.numpy()
    omega_true = omega_true.numpy()
    fig = plt.figure()
    plt.scatter(omega_true[confusion==1], z_pred[confusion==1]+0.2, alpha=0.25, color=colors[3], s=10)
    plt.scatter(omega_true[confusion==0], z_pred[confusion==0]+0.2, alpha=0.25, color=colors[0], s=10)
    plt.plot(x, x, lw=2, color=colors[4], label='y=x')
    plt.xlabel(r"$\Omega$", fontsize=fontsize)
    plt.ylabel(r'$\hat{\Omega}$', fontsize=fontsize)
    plt.legend(loc=4, prop={'size': 20})
    plt.show()


def plot_irradiance(z_true, omega_true, confusion, fontsize=20,
                    colors=['#F2A65A', '#909CC2', '#89A7A7', '#6320EE', '#E94974']):
    fig = plt.figure()
    # the original body referenced an undefined `z_pred`; the `z_true` argument is used instead
    plt.scatter(z_true[confusion==1], omega_true[confusion==1], alpha=0.25, color=colors[3], s=10)
    plt.scatter(z_true[confusion==0], omega_true[confusion==0], alpha=0.25, color=colors[0], s=10)
    plt.xlabel(r"$\delta_{dir} \cos \Theta$", fontsize=fontsize)
    plt.ylabel(r'$\Omega$', fontsize=fontsize)
    plt.show()


def plot_confusions(dataset, model, spectra, confusion, z_pred_phi, z_pred_eta,
                    z_true, omega_true, logits, true_class_id):
    # note: `true_class_id` is passed explicitly here; the original relied on it
    # being in the caller's scope
    confusion_spectra = spectra[confusion==0]
    z_phi = z_pred_phi[confusion==0]
    z_eta = z_pred_eta[confusion==0]
    omega = z_phi + 0.2
    z_true = z_true[confusion==0]
    omega_true = omega_true[confusion==0]
    logits = logits[confusion==0]
    # assumed definition: the original used `confusion_pred` without defining it
    confusion_pred = torch.argmax(logits, dim=-1)
    for class_id in np.unique(confusion_pred):
        rho = dataset.classes[class_id+1]['spectrum']
        fig, ax = plt.subplots(1, 4)
        plt.title(dataset.classes[class_id+1]['label'])
        sp = confusion_spectra[confusion_pred==class_id]
        z_phi_ = z_phi[confusion_pred==class_id]
        z_eta_ = z_eta[confusion_pred==class_id]
        z = torch.cat((z_phi_.unsqueeze(1), z_eta_), dim=-1)
        omega_ = omega[confusion_pred==class_id]
        y = one_hot(np.array([class_id]*z.shape[0]), model.y_dim)
        y_true = one_hot(np.array([true_class_id-1]*z.shape[0]), model.y_dim)
        z_true_ = z_true[confusion_pred==class_id]
        omega_true_ = omega_true[confusion_pred==class_id]
        logits_pred = logits[confusion_pred==class_id]
        with torch.no_grad():
            x = model.decoder(z, y)
            s = model.decoder(z, y_true)
            loss_x = (torch.mean(F.mse_loss(sp, x, reduction='none'), dim=-1) + config['lambda_sam']*sam_(sp, x, reduction='none')).mean()
            loss_s = (torch.mean(F.mse_loss(sp, s, reduction='none'), dim=-1) + config['lambda_sam']*sam_(sp, s, reduction='none')).mean()
        for i in range(sp.shape[0]):
            ax[0].plot(sp[i, :], alpha=0.2)  # test spectra
            ax[1].plot(s[i], alpha=0.2)      # reconstruction under the true class
            ax[2].plot(x[i], alpha=0.2)      # reconstruction under the wrongly predicted class
            for rho_ in rho:
                rho_ = torch.from_numpy(rho_).unsqueeze(0)
                cochise_correction = (z_phi_[i]*E_dir + omega_[i]*E_dif)/(torch.cos(theta)*E_dir + E_dif)
                other_class_sp = (rho_*cochise_correction).view(-1)
                ax[3].plot(other_class_sp, alpha=0.2)
        ax[0].set_ylim(0, 1)
        ax[0].set_title('Test spectra of {} confused with {}'.format(
            dataset.classes[true_class_id]['label'], dataset.classes[class_id+1]['label']), fontdict={'fontsize': 8})
        ax[1].set_ylim(0, 1)
        ax[1].set_title('Reconstructed spectra of {} under predicted irradiance - {:.2f}'.format(
            dataset.classes[true_class_id]['label'], loss_s), fontdict={'fontsize': 8})
        ax[2].set_ylim(0, 1)
        ax[2].set_title('Reconstructed spectra of {} under predicted irradiance - {:.2f}'.format(
            dataset.classes[class_id+1]['label'], loss_x), fontdict={'fontsize': 8})
        ax[3].set_ylim(0, 1)
        ax[3].set_title('Test spectra of {} under predicted irradiance'.format(
            dataset.classes[class_id+1]['label']), fontdict={'fontsize': 8})
        plt.show()
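
# Hedged call sketch (this helper is not invoked in __main__ below; the
# argument shapes follow those produced there for the p3VAE-style models):
#
#   plot_confusions(dataset, model, spectra, confusion, random_z_P, random_z_A,
#                   z_true.squeeze(1), omega_true.squeeze(1), logits,
#                   true_class_id)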


def write_confusions(pred, true_class_id, report):
    confusion = (pred == true_class_id-1).long()
    confusion_pred = pred[confusion==0]
    unique, counts = np.unique(confusion_pred, return_counts=True)
    counts = counts / len(confusion_pred)
    for i, class_id in enumerate(unique):
        report[true_class_id]['confusions'][int(class_id+1)] = counts[i]
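
# For instance, with a hypothetical report = {3: {'confusions': {}}} and
# predictions where 40% of the class-3 samples are assigned to class 1,
# write_confusions sets report[3]['confusions'][1] = 0.4.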


def plot_entropy_std(entropy, z_std, fontsize=20,
                     colors=['#F2A65A', '#909CC2', '#89A7A7', '#6320EE', '#E94974']):
    plt.scatter(entropy, z_std, alpha=0.25, color=colors[3], s=10)
    plt.xlabel("Entropy", fontsize=fontsize)
    plt.ylabel(r'$z_P$ standard deviation', fontsize=fontsize)
    plt.show()


if __name__ == "__main__":
    # usage: python metrics.py <results_dir_1> [<results_dir_2> ...]
    # each results directory must contain a config.json and a best_model.pth.tar
    global_report = []
    for k in range(1, len(sys.argv)):
        print('Model ', k)
        print(sys.argv[k])
        results_path = sys.argv[k]
        with open(results_path + '/config.json') as f:
            config = json.load(f)
        config['device'] = 'cpu'
        dataset = SimulatedDataSet()
        target_names = [dataset.classes[i]['label'] for i in range(len(dataset.classes))][1:]
        E_dir = torch.from_numpy(dataset.E_dir)
        E_dif = torch.from_numpy(dataset.E_dif)
        theta = torch.tensor([dataset.theta])
        model = load_model(dataset, config)
        checkpoint = torch.load(results_path + '/best_model.pth.tar', map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['state_dict'])

        pred_q, pred_p, labels = [], [], []
        factors, codes = [], []
        entropy, z_std_ = [], []
        report = {}
        report['jemmig_score'] = {}; report['gap'] = {}; report['je'] = {}
        report['avg_f1_score_p'] = 0
        report['avg_f1_score_q'] = 0
        for true_class_id in range(1, model.n_classes+1):
            report[true_class_id] = {}
            report[true_class_id]['confusions'] = {}
            spectra, factors_ = build_data_under_diff_irradiance(dataset, true_class_id)
            labels_, z_true, omega_true, alpha, eta = torch.split(factors_, 1, dim=1)
            labels.extend(labels_.long().numpy().reshape(-1))
            if config['model'] == 'ssInfoGAN':
                with torch.no_grad():
                    real_features, _ = model.netD(spectra)
                    logits = model.netQss(real_features)
                    pred = torch.argmax(logits, dim=-1)
                    z_mu, z_var = model.netQus(real_features)
                    z = z_mu + torch.randn(z_var.shape)*z_var**0.5
                    codes_ = torch.cat((pred.unsqueeze(1), z), dim=-1)
                write_confusions(pred, true_class_id, report)
                pred_p.extend(pred.numpy())
                pred_q.extend(pred.numpy())
                factors.append(factors_)
                codes.append(codes_)
            elif config['model'] in ['gaussian', 'guided', 'p3VAE', 'guided_no_gs', 'p3VAE_no_gs']:
                Lr, pred, z_P_std, random_z_P, random_z_A = model.argmax_p_y_x_batch(spectra, config)
                codes_ = torch.cat((pred.unsqueeze(1), random_z_P.unsqueeze(1), random_z_A), dim=-1)
                write_confusions(pred, true_class_id, report)
                pred_p.extend(pred.numpy())
                confusion = (pred == true_class_id-1).long()
                plot_z_true_vs_z_pred(z_true.squeeze(1), random_z_P, confusion, z_P_std, true_class_id)
                with torch.no_grad():
                    logits = model.q_y_x_batch(spectra)
                    pred = torch.argmax(logits, dim=-1)
                pred_q.extend(pred.numpy())
                factors.append(factors_)
                codes.append(codes_)
            elif config['model'] in ['CNN', 'CNN_full_annotations']:
                logits = model(spectra)
                pred = torch.argmax(logits, dim=-1)
                pred_q.extend(pred.cpu().numpy())
                pred_p.extend(pred.cpu().numpy())
        labels = np.array(labels) - 1
        pred_q = np.array(pred_q)
        pred_p = np.array(pred_p)
        f1_score_q = f1_score(labels, pred_q, average=None)
        f1_score_p = f1_score(labels, pred_p, average=None)
        for class_id in range(1, model.n_classes+1):
            report[class_id]['f1_score_q'] = f1_score_q[class_id-1]
            report[class_id]['f1_score_p'] = f1_score_p[class_id-1]
            report['avg_f1_score_q'] += f1_score_q[class_id-1]/len(f1_score_q)
            report['avg_f1_score_p'] += f1_score_p[class_id-1]/len(f1_score_p)
        if config['model'] in ['CNN', 'CNN_full_annotations']:
            global_report.append(report)
        else:
            factors = torch.cat(factors, dim=0).numpy()
            codes = torch.cat(codes, dim=0).numpy()
            jemmig_score, jemmig_scores, je, gap = jemmig(factors, codes, nb_bins=20)
            for i in range(len(jemmig_scores)):
                report['jemmig_score'][i] = jemmig_scores[i]; report['gap'][i] = gap[i]; report['je'][i] = je[i]
            report['jemmig_score']['avg_jemmig_score'] = jemmig_score
            report['gap']['avg_gap'] = sum(gap) / len(gap)
            report['je']['avg_je'] = sum(je) / len(je)
            print('Gap: ', gap)
            global_report.append(report)
    # average the per-model reports over all models passed on the command line
    avg_report = {}
    avg_report['jemmig_score'] = {}; avg_report['je'] = {}; avg_report['gap'] = {}
    for class_id in range(1, len(f1_score_q)+1):
        avg_report[class_id] = {}
        for metric in report[class_id]:
            avg_report[class_id][metric] = 0
    try:
        for i in range(len(jemmig_scores)):
            avg_report['jemmig_score'][i] = 0; avg_report['je'][i] = 0; avg_report['gap'][i] = 0
    except NameError:  # jemmig_scores is undefined for CNN-type models
        pass
    avg_report['jemmig_score']['avg_jemmig_score'] = 0
    avg_report['gap']['avg_gap'] = 0
    avg_report['je']['avg_je'] = 0
    avg_report['avg_f1_score_q'] = 0
    avg_report['avg_f1_score_p'] = 0
    avg_report['f1_score_p'] = []; avg_report['f1_score_q'] = []
    for report in global_report:
        for key in report:
            if key in ['avg_f1_score_q', 'avg_f1_score_p']:
                avg_report[key] += report[key]/len(global_report)
            elif key in ['jemmig_score', 'je', 'gap']:
                try:
                    if key == 'jemmig_score':
                        avg_report[key]['avg_jemmig_score'] += report[key]['avg_jemmig_score']/len(global_report)
                    elif key == 'gap':
                        avg_report[key]['avg_gap'] += report[key]['avg_gap']/len(global_report)
                    elif key == 'je':
                        avg_report[key]['avg_je'] += report[key]['avg_je']/len(global_report)
                    for i in range(len(jemmig_scores)):
                        avg_report[key][i] += report[key][i]/len(global_report)
                except (KeyError, NameError):  # CNN-type reports carry no disentanglement entries
                    pass
            else:
                for metric in report[key]:
                    if metric == 'confusions':
                        avg_report[key][metric] = report[key][metric]
                    else:
                        avg_report[key][metric] += report[key][metric]/len(global_report)
        avg_report['f1_score_p'].append(report['avg_f1_score_p'])
        avg_report['f1_score_q'].append(report['avg_f1_score_q'])

    with open('./results/simulation/{}/classification_report.json'.format(config['model']), 'w') as f:
        json.dump(avg_report, f, indent=4)