-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtrain_MLP.py
63 lines (51 loc) · 2.33 KB
/
train_MLP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from config import args
import tensorflow as tf
import time
from utils import *
from models import MLP
from metrics import *
# Settings
# Load the citation-style dataset split: adjacency, node features, one-hot
# labels and boolean masks selecting the train/val/test nodes.
# (load_data / sparse_to_tuple / preprocess_features come from the project's
# utils module — exact return formats not visible here; TODO confirm.)
adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_data(args.dataset)
# Convert the scipy sparse adjacency to COO triplets (indices, values, shape)
# and wrap it as a tf.SparseTensor for the model.
tuple_adj = sparse_to_tuple(adj.tocoo())
adj_tensor = tf.SparseTensor(*tuple_adj)
# Row-normalize (presumably) the feature matrix — verify in preprocess_features.
features = preprocess_features(features)
# MLP takes the (nodes x features) input and produces per-class logits per node.
model = MLP(input_dim=features.shape[1], output_dim=y_train.shape[1], adj=adj_tensor)
optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)
# Materialize numpy arrays as TF tensors once, outside the training loop.
features_tensor = tf.convert_to_tensor(features,dtype=tf.float32)
y_train_tensor = tf.convert_to_tensor(y_train,dtype=tf.float32)
train_mask_tensor = tf.convert_to_tensor(train_mask)
y_test_tensor = tf.convert_to_tensor(y_test,dtype=tf.float32)
test_mask_tensor = tf.convert_to_tensor(test_mask)
y_val_tensor = tf.convert_to_tensor(y_val,dtype=tf.float32)
val_mask_tensor = tf.convert_to_tensor(val_mask)
# Early-stopping bookkeeping: best metrics seen so far and the number of
# consecutive epochs without a validation-accuracy improvement.
best_test_acc = 0
best_val_acc = 0
best_val_loss = 10000
curr_step = 0
# Train for up to args.epochs epochs, early-stopping when validation accuracy
# has not improved for more than args.early_stop consecutive epochs.
# Reports the test accuracy achieved at the best-validation checkpoint.
for epoch in range(args.epochs):
    # Forward pass and loss inside the tape so gradients can be computed.
    with tf.GradientTape() as tape:
        output = model.call((features_tensor), training=True)
        cross_loss = masked_softmax_cross_entropy(output, y_train_tensor, train_mask_tensor)
        # L2 penalty over every trainable weight, scaled by args.weight_decay.
        lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in model.trainable_variables])
        loss = cross_loss + args.weight_decay * lossL2
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # Re-run in inference mode to evaluate all three splits on the same output.
    output = model.call((features_tensor), training=False)
    train_acc = masked_accuracy(output, y_train_tensor, train_mask_tensor)
    val_acc = masked_accuracy(output, y_val_tensor, val_mask_tensor)
    val_loss = masked_softmax_cross_entropy(output, y_val_tensor, val_mask_tensor)
    test_acc = masked_accuracy(output, y_test_tensor, test_mask_tensor)

    # Model-selection on validation accuracy: remember the test accuracy at
    # the best validation point; otherwise count toward early stopping.
    if val_acc > best_val_acc:
        curr_step = 0
        best_test_acc = test_acc
        best_val_acc = val_acc
        best_val_loss = val_loss
    else:
        curr_step += 1
        if curr_step > args.early_stop:
            print("Early stopping...")
            break

    # Fix: "train_acc=" previously printed val_acc; report the actual train_acc.
    print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(cross_loss),
          "val_loss=", "{:.5f}".format(val_loss),
          "train_acc=", "{:.5f}".format(train_acc), "val_acc=", "{:.5f}".format(val_acc),
          "test_acc=", "{:.5f}".format(best_test_acc))