-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathutils.py
64 lines (55 loc) · 2.45 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
import torch
import torch.nn.functional as F
from Params import args
import random
def cal_loss_r(preds, labels, mask):
loss = torch.sum(torch.square(preds - labels) * mask) / torch.sum(mask)
return loss
def cal_metrics_r(preds, labels, mask):
loss = np.sum(np.square(preds - labels) * mask) / np.sum(mask)
sqLoss = np.sum(np.sum(np.square(preds - labels) * mask, axis=0), axis=0)
absLoss = np.sum(np.sum(np.abs(preds - labels) * mask, axis=0), axis=0)
tstNums = np.sum(np.sum(mask, axis=0), axis=0)
posMask = mask * np.greater(labels, 0.5)
apeLoss = np.sum(np.sum(np.abs(preds - labels) / (labels + 1e-8) * posMask, axis=0), axis=0)
posNums = np.sum(np.sum(posMask, axis=0), axis=0)
return loss, sqLoss, absLoss, tstNums, apeLoss, posNums
def cal_metrics_r_mask(preds, labels, mask, mask_sparsity):
loss = np.sum(np.square(preds - labels) * mask) / np.sum(mask)
sqLoss = np.sum(np.sum(np.square(preds - labels) * mask * mask_sparsity, axis=0), axis=0)
absLoss = np.sum(np.sum(np.abs(preds - labels) * mask * mask_sparsity, axis=0), axis=0)
tstNums = np.sum(np.sum(mask * mask_sparsity, axis=0), axis=0)
posMask = mask * mask_sparsity * np.greater(labels, 0.5)
apeLoss = np.sum(np.sum(np.abs(preds - labels) / (labels + 1e-8) * posMask, axis=0), axis=0)
posNums = np.sum(np.sum(posMask, axis=0), axis=0)
return loss, sqLoss, absLoss, tstNums, apeLoss, posNums
def Informax_loss(DGI_pred, DGI_labels):
BCE_loss = torch.nn.BCEWithLogitsLoss()
loss = BCE_loss(DGI_pred, DGI_labels)
return loss
def infoNCEloss(q, k):
T = args.t
q = q.expand_as(k)
q = q.permute(0, 3, 4, 2, 1)
k = k.permute(0, 3, 4, 2, 1)
q = F.normalize(q, dim=-1)
k = F.normalize(k, dim=-1)
pos_sim = torch.sum(torch.mul(q, k), dim=-1)
neg_sim = torch.matmul(q, k.transpose(-1, -2))
pos = torch.exp(torch.div(pos_sim, T))
neg = torch.sum(torch.exp(torch.div(neg_sim, T)), dim=-1)
denominator = neg + pos
return torch.mean(-torch.log(torch.div(pos, denominator)))
def seed_torch(seed=523):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def makePrint(name, ep, reses):
ret = 'Epoch %d/%d, %s: ' % (ep, args.epoch, name)
for metric in reses:
val = reses[metric]
ret += '%s = %.4f, ' % (metric, val)
ret = ret[:-2] + ' '
return ret