-
-
Notifications
You must be signed in to change notification settings - Fork 199
/
Copy pathcapsgnn.py
333 lines (295 loc) · 13.5 KB
/
capsgnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
"""CapsGNN Trainer."""
import glob
import json
import os
import random

import numpy as np
import pandas as pd
import torch
from torch_geometric.nn import GCNConv
from tqdm import tqdm, trange

from layers import ListModule, PrimaryCapsuleLayer, Attention, SecondaryCapsuleLayer
from layers import margin_loss
from utils import create_numeric_mapping
class CapsGNN(torch.nn.Module):
    """
    An implementation of the model described in the following paper:
    https://openreview.net/forum?id=Byl8BnRcYm
    """

    def __init__(self, args, number_of_features, number_of_targets):
        """
        :param args: Arguments object. The attributes read by this class are
            gcn_filters, gcn_layers, capsule_dimensions,
            inner_attention_dimension and number_of_capsules.
        :param number_of_features: Number of vertex features.
        :param number_of_targets: Number of classes.
        """
        super(CapsGNN, self).__init__()
        self.args = args
        self.number_of_features = number_of_features
        self.number_of_targets = number_of_targets
        self._setup_layers()

    def _setup_base_layers(self):
        """
        Creating GCN layers.
        The first layer maps the input features to gcn_filters channels,
        the remaining gcn_layers-1 layers keep that width.
        """
        self.base_layers = [GCNConv(self.number_of_features, self.args.gcn_filters)]
        for _ in range(self.args.gcn_layers-1):
            self.base_layers.append(GCNConv(self.args.gcn_filters, self.args.gcn_filters))
        # Wrapped in ListModule so the layers are held by a module object
        # rather than a plain Python list.
        self.base_layers = ListModule(*self.base_layers)

    def _setup_primary_capsules(self):
        """
        Creating primary capsules.
        """
        self.first_capsule = PrimaryCapsuleLayer(in_units=self.args.gcn_filters,
                                                 in_channels=self.args.gcn_layers,
                                                 num_units=self.args.gcn_layers,
                                                 capsule_dimensions=self.args.capsule_dimensions)

    def _setup_attention(self):
        """
        Creating attention layer.
        """
        self.attention = Attention(self.args.gcn_layers*self.args.capsule_dimensions,
                                   self.args.inner_attention_dimension)

    def _setup_graph_capsules(self):
        """
        Creating graph capsules.
        """
        self.graph_capsule = SecondaryCapsuleLayer(self.args.gcn_layers,
                                                   self.args.capsule_dimensions,
                                                   self.args.number_of_capsules,
                                                   self.args.capsule_dimensions)

    def _setup_class_capsule(self):
        """
        Creating class capsules.
        """
        self.class_capsule = SecondaryCapsuleLayer(self.args.capsule_dimensions,
                                                   self.args.number_of_capsules,
                                                   self.number_of_targets,
                                                   self.args.capsule_dimensions)

    def _setup_reconstruction_layers(self):
        """
        Creating histogram reconstruction layers.
        Three linear layers mapping the flattened class capsules
        (number_of_targets*capsule_dimensions values) to number_of_features values.
        """
        first_width = int((self.number_of_features*2)/3)
        second_width = int((self.number_of_features*3)/2)
        self.reconstruction_layer_1 = torch.nn.Linear(self.number_of_targets*self.args.capsule_dimensions,
                                                      first_width)
        self.reconstruction_layer_2 = torch.nn.Linear(first_width, second_width)
        self.reconstruction_layer_3 = torch.nn.Linear(second_width, self.number_of_features)

    def _setup_layers(self):
        """
        Creating layers of model.
        1. GCN layers.
        2. Primary capsules.
        3. Attention
        4. Graph capsules.
        5. Class capsules.
        6. Reconstruction layers.
        """
        self._setup_base_layers()
        self._setup_primary_capsules()
        self._setup_attention()
        self._setup_graph_capsules()
        self._setup_class_capsule()
        self._setup_reconstruction_layers()

    def calculate_reconstruction_loss(self, capsule_input, features):
        """
        Calculating the reconstruction loss of the model.
        The class capsule with the largest norm is kept, the others are zeroed,
        and the masked capsules are decoded into a distribution over features.
        That distribution is compared with the normalized feature histogram.
        :param capsule_input: Output of class capsule (one row per class).
        :param features: Feature matrix (one row per node).
        :return reconstruction_loss: Loss of reconstruction (scalar tensor).
        """
        v_mag = torch.sqrt((capsule_input**2).sum(dim=1))
        _, v_max_index = v_mag.max(dim=0)

        # zeros_like keeps the mask on the same device/dtype as the capsules.
        capsule_masked = torch.zeros_like(capsule_input)
        capsule_masked[v_max_index, :] = capsule_input[v_max_index, :]
        capsule_masked = capsule_masked.view(1, -1)

        # Normalized histogram of the node features: the reconstruction target.
        feature_counts = features.sum(dim=0)
        feature_counts = feature_counts/feature_counts.sum()
        feature_counts = feature_counts.view(1, self.number_of_features)

        reconstruction_output = torch.nn.functional.relu(self.reconstruction_layer_1(capsule_masked))
        reconstruction_output = torch.nn.functional.relu(self.reconstruction_layer_2(reconstruction_output))
        reconstruction_output = torch.softmax(self.reconstruction_layer_3(reconstruction_output), dim=1)
        reconstruction_output = reconstruction_output.view(1, self.number_of_features)

        # Compare against the histogram, not the raw per-node matrix: the decoder
        # emits a single distribution, so using `features` here would broadcast
        # it against every node row and make the loss scale with graph size.
        reconstruction_loss = torch.sum((feature_counts-reconstruction_output)**2)
        return reconstruction_loss

    def forward(self, data):
        """
        Forward propagation pass.
        :param data: Dictionary of tensors with features and edges.
        :return class_capsule_output: Class capsule outputs, viewed as
            (1, number_of_targets, capsule_dimensions).
        :return reconstruction_loss: Loss of reconstruction for this graph.
        """
        features = data["features"]
        edges = data["edges"]
        hidden_representations = []

        # Keep the output of every GCN layer, not only the last one.
        for layer in self.base_layers:
            features = torch.nn.functional.relu(layer(features, edges))
            hidden_representations.append(features)

        # NOTE(review): the layer outputs are concatenated along dim 0 and then
        # viewed as (1, gcn_layers, gcn_filters, -1); confirm this memory layout
        # is what PrimaryCapsuleLayer expects.
        hidden_representations = torch.cat(tuple(hidden_representations))
        hidden_representations = hidden_representations.view(1, self.args.gcn_layers, self.args.gcn_filters, -1)

        first_capsule_output = self.first_capsule(hidden_representations)
        first_capsule_output = first_capsule_output.view(-1, self.args.gcn_layers*self.args.capsule_dimensions)

        rescaled_capsule_output = self.attention(first_capsule_output)
        rescaled_first_capsule_output = rescaled_capsule_output.view(-1, self.args.gcn_layers,
                                                                     self.args.capsule_dimensions)

        graph_capsule_output = self.graph_capsule(rescaled_first_capsule_output)
        reshaped_graph_capsule_output = graph_capsule_output.view(-1, self.args.capsule_dimensions,
                                                                  self.args.number_of_capsules)

        class_capsule_output = self.class_capsule(reshaped_graph_capsule_output)
        class_capsule_output = class_capsule_output.view(-1, self.number_of_targets*self.args.capsule_dimensions)
        # Average over the leading dimension to get one capsule set per graph.
        class_capsule_output = torch.mean(class_capsule_output, dim=0).view(1,
                                                                            self.number_of_targets,
                                                                            self.args.capsule_dimensions)

        # The reconstruction target is built from the raw input features.
        recon = class_capsule_output.view(self.number_of_targets, self.args.capsule_dimensions)
        reconstruction_loss = self.calculate_reconstruction_loss(recon, data["features"])
        return class_capsule_output, reconstruction_loss
class CapsGNNTrainer(object):
    """
    CapsGNN training and scoring.
    """

    def __init__(self, args):
        """
        :param args: Arguments object.
        """
        self.args = args
        self.setup_model()

    @staticmethod
    def _read_graph(path):
        """
        Reading a graph JSON and closing the file handle afterwards.
        :param path: Path to the graph JSON.
        :return data: Parsed JSON content.
        """
        with open(path) as handle:
            return json.load(handle)

    def enumerate_unique_labels_and_targets(self):
        """
        Enumerating the features and targets in order to setup weights later.
        Scans every train and test graph and stores the path lists, the
        target/feature mappings and their sizes on the trainer.
        """
        print("\nEnumerating feature and target values.\n")
        ending = "*.json"

        self.train_graph_paths = glob.glob(self.args.train_graph_folder+ending)
        self.test_graph_paths = glob.glob(self.args.test_graph_folder+ending)
        graph_paths = self.train_graph_paths + self.test_graph_paths

        targets = set()
        features = set()
        for path in tqdm(graph_paths):
            data = self._read_graph(path)
            targets.add(data["target"])
            # The vocabulary must hold the label VALUES: create_features looks
            # up data["labels"].values(), while iterating the dict itself
            # would collect the node identifiers (keys) instead.
            features.update(data["labels"].values())

        self.target_map = create_numeric_mapping(targets)
        self.feature_map = create_numeric_mapping(features)

        self.number_of_features = len(self.feature_map)
        self.number_of_targets = len(self.target_map)

    def setup_model(self):
        """
        Enumerating labels and initializing a CapsGNN.
        """
        self.enumerate_unique_labels_and_targets()
        self.model = CapsGNN(self.args, self.number_of_features, self.number_of_targets)

    def create_batches(self):
        """
        Batching the graphs for training.
        Splits the training paths into consecutive chunks of batch_size,
        the last chunk may be shorter.
        """
        self.batches = [self.train_graph_paths[i:i+self.args.batch_size]
                        for i in range(0, len(self.train_graph_paths), self.args.batch_size)]

    def create_data_dictionary(self, target, edges, features):
        """
        Creating a data dictionary.
        :param target: Target vector.
        :param edges: Edge list tensor.
        :param features: Feature tensor.
        :return to_pass_forward: Dictionary with target, edges and features keys.
        """
        to_pass_forward = dict()
        to_pass_forward["target"] = target
        to_pass_forward["edges"] = edges
        to_pass_forward["features"] = features
        return to_pass_forward

    def create_target(self, data):
        """
        Target creation based on data dictionary.
        :param data: Data dictionary.
        :return : One-hot target vector of length number_of_targets.
        """
        # Go through target_map so that targets which are not already the
        # integers 0..number_of_targets-1 still land on the right class index,
        # the same way feature_map is used for node labels.
        target_index = self.target_map[data["target"]]
        target = torch.zeros(self.number_of_targets, dtype=torch.float)
        target[target_index] = 1.0
        return target

    def create_edges(self, data):
        """
        Create an edge matrix.
        Every edge is added in both directions.
        :param data: Data dictionary.
        :return : Edge matrix with two rows; shape (2, 0) when there are no edges.
        """
        edges = [[edge[0], edge[1]] for edge in data["edges"]]
        edges = edges + [[edge[1], edge[0]] for edge in data["edges"]]
        # reshape(-1, 2) keeps the two-row layout for an edgeless graph,
        # where the plain tensor would be one-dimensional and empty.
        return torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t()

    def create_features(self, data):
        """
        Create feature matrix.
        Row i is the one-hot encoding of the i-th label in data["labels"].
        :param data: Data dictionary.
        :return features: Matrix of features.
        """
        # NOTE(review): rows follow the iteration order of data["labels"];
        # this presumably matches the node indices used by the edges - confirm.
        labels = list(data["labels"].values())
        features = np.zeros((len(labels), self.number_of_features))
        node_indices = np.arange(len(labels))
        feature_indices = np.array([self.feature_map[label] for label in labels], dtype=np.int64)
        features[node_indices, feature_indices] = 1.0
        features = torch.FloatTensor(features)
        return features

    def create_input_data(self, path):
        """
        Creating tensors and a data dictionary with Torch tensors.
        :param path: path to the data JSON.
        :return to_pass_forward: Data dictionary.
        """
        data = self._read_graph(path)
        target = self.create_target(data)
        edges = self.create_edges(data)
        features = self.create_features(data)
        to_pass_forward = self.create_data_dictionary(target, edges, features)
        return to_pass_forward

    def fit(self):
        """
        Training a model on the training set.
        Each epoch reshuffles the training graphs, and every batch takes one
        optimizer step on the mean of the per-graph losses.
        """
        print("\nTraining started.\n")
        self.model.train()
        optimizer = torch.optim.Adam(self.model.parameters(),
                                     lr=self.args.learning_rate,
                                     weight_decay=self.args.weight_decay)

        for _ in tqdm(range(self.args.epochs), desc="Epochs: ", leave=True):
            random.shuffle(self.train_graph_paths)
            self.create_batches()
            losses = 0
            self.steps = trange(len(self.batches), desc="Loss")
            for step in self.steps:
                accumulated_losses = 0
                optimizer.zero_grad()
                batch = self.batches[step]
                for path in batch:
                    data = self.create_input_data(path)
                    prediction, reconstruction_loss = self.model(data)
                    loss = margin_loss(prediction,
                                       data["target"],
                                       self.args.lambd)
                    # theta weighs the reconstruction term against the margin loss.
                    loss = loss+self.args.theta*reconstruction_loss
                    accumulated_losses = accumulated_losses + loss
                accumulated_losses = accumulated_losses/len(batch)
                accumulated_losses.backward()
                optimizer.step()
                losses = losses + accumulated_losses.item()
                # Running mean of the batch losses seen so far in this epoch.
                average_loss = losses/(step + 1)
                self.steps.set_description("CapsGNN (Loss=%g)" % round(average_loss, 4))

    def score(self):
        """
        Scoring on the test set.
        Stores the predicted class index of every test graph in self.predictions,
        the per-graph correctness flags in self.hits, and prints the accuracy.
        """
        print("\n\nScoring.\n")
        self.model.eval()
        self.predictions = []
        self.hits = []
        # Gradients are not needed for evaluation.
        with torch.no_grad():
            for path in tqdm(self.test_graph_paths):
                data = self.create_input_data(path)
                prediction, _ = self.model(data)
                # The predicted class is the class capsule with the largest norm.
                prediction_mag = torch.sqrt((prediction**2).sum(dim=2))
                _, prediction_max_index = prediction_mag.max(dim=1)
                predicted_class = prediction_max_index.view(-1).item()
                self.predictions.append(predicted_class)
                # Store plain booleans so the mean below works on Python values.
                self.hits.append(data["target"][predicted_class].item() == 1.0)
        accuracy = np.mean(self.hits) if self.hits else float("nan")
        print("\nAccuracy: " + str(round(accuracy, 4)))

    def save_predictions(self):
        """
        Saving the test set predictions.
        Writes a CSV with an id column (file name without extension) and a
        predictions column to args.prediction_path.
        """
        # str.strip(".json") would remove any of the characters ".", "j", "s",
        # "o", "n" from both ends of the name, so the extension is split off instead.
        identifiers = [os.path.splitext(os.path.basename(path))[0] for path in self.test_graph_paths]
        out = pd.DataFrame()
        out["id"] = identifiers
        out["predictions"] = self.predictions
        out.to_csv(self.args.prediction_path, index=False)