forked from adurukan/GNNs-with-MLOps
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate.py
33 lines (25 loc) · 946 Bytes
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from helper import get_data, test_data, report_training_accuracy
# from gat_net import GAT
# from torch_geometric.nn import GATConv
"""
Load a trained model and evaluate its accuracy on the test data.
"""
with open("logger.txt", "w", encoding="utf-8") as outfile:
    outfile.write("evaluate.py -> Imports are successful.")
data_list = get_data(test_data, "test_data")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# map_location lets a checkpoint that was saved on a GPU be loaded on a
# CPU-only machine; without it torch.load raises when CUDA is unavailable.
# NOTE(review): this unpickles a whole nn.Module, so the file must come from a
# trusted source. On torch >= 2.6 (weights_only defaults to True) this call
# likely also needs weights_only=False — confirm against the pinned torch.
model = torch.load("models/gat_300_2", map_location=device).to(device)
# Maps graph index (position in data_list) -> list returned by evaluate().
acc_graph = {}
def evaluate(data, net=None, target_device=None):
    """Compute the classification accuracy of a model on one graph.

    Args:
        data: Graph object exposing ``x``, ``edge_index`` and ``y``
            attributes (presumably a ``torch_geometric.data.Data`` produced
            by ``helper.get_data`` — confirm).
        net: Model to evaluate; called as ``net(x, edge_index)``. Defaults
            to the module-level ``model`` loaded from ``models/gat_300_2``.
        target_device: Device to run inference on. Defaults to the
            module-level ``device``.

    Returns:
        A one-element list holding the accuracy as a Python float: the
        number of positions where ``argmax(-1)`` of the model output equals
        ``data.y``, divided by ``data.y.shape[0]``.
    """
    if net is None:
        net = model
    if target_device is None:
        target_device = device
    net.eval()
    # Inference only: skip building the autograd graph to save memory/time.
    with torch.no_grad():
        # Put the inputs on the same device as the model; otherwise a CUDA
        # model fed CPU tensors fails with a device-mismatch error.
        x = data.x.to(target_device)
        edge_index = data.edge_index.to(target_device)
        y = data.y.to(target_device)
        out = net(x, edge_index)
        correct = (out.argmax(-1) == y).sum()
        acc = float(correct / y.shape[0])
    return [acc]
if __name__ == "__main__":
    # Evaluate every test graph and record its accuracy under its index.
    for i, data in enumerate(data_list):
        acc_graph[i] = evaluate(data)
    report_training_accuracy(acc_graph)