-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathVisualizations.py
56 lines (44 loc) · 1.9 KB
/
Visualizations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import random
import matplotlib.pyplot as plt
import torch
from sklearn.metrics import confusion_matrix
import seaborn as sns
def plot_samples(split_set,preds=None):
idx_to_class = {split_set.class_to_idx[k]: k for k in split_set.class_to_idx}
randomlist = random.sample(range(0, len(split_set)), 9)
#print(randomlist)
k=0
figure, ax = plt.subplots(3, 3,constrained_layout = True)
for i in range(3):
for j in range(3):
idx=randomlist[k]
k+=1
ax[i, j].imshow(split_set[idx][0].permute(1,2,0))
ax[i,j].get_xaxis().set_visible(False)
ax[i,j].get_yaxis().set_visible(False)
if ( (preds is None) ):
ax[i, j].set_title(f"{split_set[idx][1]}-{idx_to_class[split_set[idx][1]]}", fontsize=12)
else:
if split_set[idx][1]==preds[idx]:
check="Correctly_Classified"
else:
check="Wrongly_Classified"
ax[i, j].set_title(f"{check} \n Actual:{idx_to_class[split_set[idx][1]]} \n predicted:{idx_to_class[preds[idx]]}"
, fontsize=10)
def plot_confusion(labels, preds,name,num_classes=2,Normalize=False):
if Normalize:
conf=confusion_matrix(labels, preds,normalize='true')
else:
conf=confusion_matrix(labels, preds)
if num_classes==2:
labels_name=['Normal', 'Pneumonia']
elif num_classes==3:
labels_name=['Covid','Normal', 'Pneumonia']
fig, ax = plt.subplots(figsize=(10,6))
ax = sns.heatmap(conf, annot=True,xticklabels=labels_name,yticklabels=labels_name,fmt='.3f',
cmap=sns.cubehelix_palette(as_cmap=True))
font1 = {'family': 'sans-serif','weight': 'bold','color':"sienna",'size': 16}
ax.set_xlabel('Predicted Label',fontdict=font1)
ax.set_ylabel('Actual Label',fontdict=font1)
plt.show()
fig.savefig(f"./Visualizations/{name}.png", dpi=300)