-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathmain.py
73 lines (49 loc) · 2.11 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import matplotlib.pyplot as plt
import numpy as np
import scipy
import tensorflow as tf
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score
import utils
from model import DSFANet
def main(X, Y, GT, diff, train_num=2000, max_iters=2000, lr=1e-4):
    """Train DSFANet on the lowest-``diff`` samples and write change maps.

    The ``train_num`` samples with the smallest ``diff`` values are used to
    fit the network with plain gradient descent. The trained network then
    transforms every sample of ``X`` and ``Y``; the outputs are passed
    through ``utils.SFA``, their per-feature standardized difference is
    saved as an intensity image ('DSFAdiff.png') and clustered into two
    classes with k-means ('DSFACD.png'). Accuracy against ``GT`` is printed
    for both possible assignments of the two cluster labels.

    Args:
        X: 2-D array indexed as (sample, feature) for the first input.
        Y: 2-D array indexed as (sample, feature) for the second input.
        GT: Reference array. Its shape is used to lay per-sample results
            out as an image, and its values are divided by 255 before
            scoring (presumably a 0/255 change mask -- confirm against
            ``utils.load_dataset``).
        diff: One score per sample; samples are ranked by ascending value
            (presumably a pre-computed change magnitude -- confirm against
            the caller).
        train_num: Maximum number of lowest-``diff`` samples to train on.
        max_iters: Number of gradient-descent steps.
        lr: Learning rate passed to the optimizer.

    Returns:
        True once both images are written and both accuracies are printed.
    """
    # Rank samples by ascending diff and keep the first train_num of them.
    index = np.argsort(diff)
    train_index = index[0:train_num]
    XData = X[train_index, :]
    YData = Y[train_index, :]

    inputX = tf.placeholder(dtype=tf.float32, shape=[None, X.shape[-1]])
    inputY = tf.placeholder(dtype=tf.float32, shape=[None, Y.shape[-1]])

    # Pass the number of rows actually fed; the slice above yields fewer
    # than train_num rows when the dataset is smaller than train_num.
    # NOTE(review): assumes DSFANet's `num` is the training-sample count --
    # confirm in model.py.
    model = DSFANet(num=len(train_index))
    loss = model.forward(X=inputX, Y=inputY)
    train_op = tf.train.GradientDescentOptimizer(lr).minimize(loss)
    init = tf.global_variables_initializer()

    # allow_growth makes TensorFlow claim GPU memory on demand.
    gpu_options = tf.GPUOptions(allow_growth=True)
    conf = tf.ConfigProto(gpu_options=gpu_options)

    # The context manager closes the session even if training raises.
    with tf.Session(config=conf) as sess:
        sess.run(init)
        train_feed = {inputX: XData, inputY: YData}
        for k in range(max_iters):
            _, loss_value = sess.run([train_op, loss], feed_dict=train_feed)
            if k % 100 == 0:
                print('iter %4d, loss is %.4f' % (k, loss_value))
        # Project every sample (not only the training subset) through the net.
        XTest, YTest = sess.run([model.X_, model.Y_], feed_dict={inputX: X, inputY: Y})

    X_trans, Y_trans = utils.SFA(XTest, YTest)
    feature_diff = X_trans - Y_trans
    # Scale each feature column of the difference by its standard deviation.
    feature_diff = feature_diff / np.std(feature_diff, axis=0)

    # Per-sample sum of squared standardized differences, computed once and
    # reused for both the intensity image and the clustering input.
    change_intensity = (feature_diff ** 2).sum(axis=1)
    plt.imsave('DSFAdiff.png', change_intensity.reshape(GT.shape), cmap='gray')

    labels = KMeans(n_clusters=2).fit(change_intensity.reshape(-1, 1)).labels_
    plt.imsave('DSFACD.png', labels.reshape(GT.shape), cmap='gray')

    # k-means numbers its two clusters arbitrarily, so score both mappings.
    gt_labels = GT.reshape(-1, 1) / 255
    print(accuracy_score(gt_labels, labels))
    print(accuracy_score(gt_labels, 1 - labels))
    return True
if __name__ == '__main__':
    # Script entry point: the dataset helper's result is unpacked into the
    # two inputs and the reference array.
    dataset = utils.load_dataset()
    X, Y, GT = dataset

    # Baseline difference from the project's CVA helper; it is saved as an
    # image and also passed to main() as its `diff` argument.
    diff = utils.cva(X=X, Y=Y)
    cva_image = np.reshape(diff, GT.shape)
    plt.imsave('CVAdiff.png', cva_image, cmap='gray')

    main(X, Y, GT, diff)