-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
92 lines (75 loc) · 3.03 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import tensorflow as tf
import pandas as pd
import numpy as np
def plot_original_clean(model, X, Y, output_name):
    """Save a prediction-vs-target figure for every sample in ``X``.

    For each sample the model is run on a batch of one and the wall-clock
    inference time is printed. An image is then written to
    ``images/{index}_{output_name}.png``. Because of the final transpose, the
    prediction (flipped along its second axis) fills the top half of the image
    and the target fills the bottom half. Both are drawn with a ``seismic``
    colour map whose limits are symmetric around zero.

    Args:
        model: Object exposing a Keras-style ``predict`` method.
        X: Indexable, iterable collection of model inputs. ``X[i]`` is one
            sample without a batch axis.
        Y: Indexable collection of targets aligned with ``X``.
        output_name: Suffix used in the generated file names.

    Note:
        ``prediction.squeeze()`` and ``Y[i].squeeze()`` must be 2-D arrays with
        the same number of rows. This is required by ``np.flip(axis=1)``, the
        concatenation along axis 1, and ``matshow`` (assumption about callers,
        not checked here).
    """
    # makedirs(exist_ok=True) replaces a bare ``except: pass`` that also hid
    # permission errors and KeyboardInterrupt.
    os.makedirs('images', exist_ok=True)
    print('Plotting results for a few test samples...')
    for index, sample in enumerate(X):
        start = time.perf_counter()
        prediction = model.predict(sample[None, ...])  # add a batch axis of size 1
        elapsed = time.perf_counter() - start
        print(f'Inference time: {elapsed} s')
        target = Y[index]

        # Join the flipped prediction and the target along axis 1. After the
        # transpose below they become the top and bottom halves of one image.
        comparison = np.concatenate(
            (np.flip(prediction.squeeze(), axis=1), target.squeeze()), axis=1
        )
        # Symmetric colour limits keep zero at the centre of the diverging map.
        max_abs = np.amax(np.abs(comparison))
        plt.matshow(comparison.T, cmap='seismic', interpolation='none',
                    vmin=-max_abs, vmax=max_abs)
        plt.title(
            f'Denoised range: [{np.amin(prediction)},{np.amax(prediction)}]\n'
            f' Original range: [{np.amin(target)},{np.amax(target)}]'
        )
        # Labels placed in figure-fraction coordinates over the top/bottom halves.
        plt.annotate('Clean', (0.1, 0.7), xycoords='figure fraction', size=20, color='White')
        plt.annotate('Original', (0.1, 0.2), xycoords='figure fraction', size=20, color='White')
        plt.axis('off')
        plt.colorbar()
        plt.savefig(f'images/{index}_{output_name}.png', dpi=200, bbox_inches='tight')
        plt.close()
def plot_errors(model, X, Y, output_name):
    """Save a squared-error heat map for every sample in ``X``.

    For each sample the model is run on a batch of one and the wall-clock
    inference time is printed. The element-wise squared error between the
    squeezed prediction and target is written to
    ``errors/{index}_{output_name}.png`` using the ``hot`` colour map.

    Args:
        model: Object exposing a Keras-style ``predict`` method.
        X: Indexable, iterable collection of model inputs. ``X[i]`` is one
            sample without a batch axis.
        Y: Indexable collection of targets aligned with ``X``.
        output_name: Suffix used in the generated file names.

    Note:
        The squeezed prediction and target must have the same shape and be
        2-D. This is required by ``np.flip(axis=1)`` and ``matshow``
        (assumption about callers, not checked here).
    """
    # makedirs(exist_ok=True) replaces a bare ``except: pass`` that also hid
    # permission errors and KeyboardInterrupt.
    os.makedirs('errors', exist_ok=True)
    print('Plotting errors for a few test samples...')
    for index, sample in enumerate(X):
        start = time.perf_counter()
        prediction = model.predict(sample[None, ...])  # add a batch axis of size 1
        elapsed = time.perf_counter() - start
        print(f'Inference time: {elapsed} s')
        target = Y[index]

        error = (prediction.squeeze() - target.squeeze()) ** 2
        # A flipped copy followed by the unflipped copy, matching the
        # flipped-prediction / target layout drawn by plot_original_clean.
        comparison = np.concatenate((np.flip(error, axis=1), error), axis=1)
        plt.matshow(comparison.T, cmap='hot', interpolation='none')
        # NOTE(review): the title reports the prediction and target ranges,
        # not the error range; confirm this is intended.
        plt.title(
            f'Denoised range: [{np.amin(prediction)},{np.amax(prediction)}]\n'
            f' Original range: [{np.amin(target)},{np.amax(target)}]'
        )
        plt.axis('off')
        plt.colorbar()
        plt.savefig(f'errors/{index}_{output_name}.png', dpi=200, bbox_inches='tight')
        plt.close()
def plot_history(model, output_name):
    """Plot a saved training history with a logarithmic y axis.

    Loads ``./checkpoints/{output_name}/history.npy``, a single pickled object
    (presumably a dict of metric name -> per-epoch values; verify against the
    code that saves it). It is plotted via ``pandas.DataFrame.plot`` and
    written to ``losses/{output_name}.png``.

    Args:
        model: Unused; kept so existing callers keep working.
        output_name: Run name used both to locate the history file and to
            name the output image.
    """
    # makedirs(exist_ok=True) replaces a bare ``except: pass`` that also hid
    # permission errors and KeyboardInterrupt.
    os.makedirs('losses', exist_ok=True)
    # The history is stored as a pickled object inside a size-1 object array,
    # so allow_pickle must be enabled. Unpickling can execute arbitrary code:
    # only load checkpoint files you produced yourself.
    history = np.load(f'./checkpoints/{output_name}/history.npy', allow_pickle=True).item()
    ax = pd.DataFrame(history).plot(logy=True)
    ax.grid()
    ax.figure.savefig(f'losses/{output_name}.png', dpi=200, bbox_inches='tight')
    plt.close(ax.figure)
class plot_losses(tf.keras.callbacks.Callback):
    """Keras callback that redraws ``Losses.png`` at the end of every epoch.

    Training and validation losses are accumulated across epochs and plotted
    on a logarithmic y axis. ``Losses.png`` in the current working directory
    is overwritten each time.
    """

    def __init__(self):
        # Initialise the base Callback state (the original skipped this).
        super().__init__()
        self.training_loss = []  # per-epoch training loss
        self.valid_loss = []  # per-epoch validation loss; NaN when not reported

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's losses and rewrite ``Losses.png``.

        Args:
            epoch: Index of the finished epoch. Unused: the x axis is derived
                from the number of recorded epochs.
            logs: Metrics dict supplied by Keras. ``'loss'`` and
                ``'val_loss'`` are read. A missing value (e.g. training
                without validation data) is recorded as NaN instead of
                raising KeyError.
        """
        logs = logs or {}
        self.training_loss.append(logs.get('loss', np.nan))
        self.valid_loss.append(logs.get('val_loss', np.nan))
        epochs = np.arange(1, len(self.training_loss) + 1)

        # Use a dedicated figure so we never draw onto a figure that other
        # code currently has open.
        fig, ax = plt.subplots()
        ax.semilogy(epochs, self.training_loss, label='Training')
        if not np.all(np.isnan(self.valid_loss)):
            ax.semilogy(epochs, self.valid_loss, label='Validation')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        fig.savefig('Losses.png', dpi=200)
        plt.close(fig)