losses.py
import torch
import torch.distributions as td


def VAE_loss(rec_o, o, mu, log_var, beta=1.0):
    # Reconstruction term (summed binary cross-entropy) and the analytical Gaussian
    # KL term, returned separately so the caller can weight the KL by beta.
    BCE = torch.nn.functional.binary_cross_entropy(rec_o, o, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE, beta * KLD


def kl_divergence(mu_1, var_1, mu_2, var_2, dim=1):
    # KL(p || q) between diagonal Gaussians given means and variances;
    # td.Normal expects a standard deviation, hence the sqrt of the variances.
    p = td.Independent(td.Normal(mu_1, torch.sqrt(var_1)), dim)
    q = td.Independent(td.Normal(mu_2, torch.sqrt(var_2)), dim)
    div = td.kl_divergence(p, q)
    # Free-nats style floor: the divergence is clamped from below at 3 nats.
    div = torch.max(div, div.new_full(div.size(), 3))
    return torch.mean(div)


def kl_divergence_balance(mu_1, var_1, mu_2, var_2, alpha=0.8, dim=1):
    # KL balancing: mix a KL term whose gradients flow only to q with one whose
    # gradients flow only to p, weighted by alpha and (1 - alpha).
    p = td.Independent(td.Normal(mu_1, torch.sqrt(var_1)), dim)
    p_stop_grad = td.Independent(td.Normal(mu_1.detach(), torch.sqrt(var_1.detach())), dim)
    q = td.Independent(td.Normal(mu_2, torch.sqrt(var_2)), dim)
    q_stop_grad = td.Independent(td.Normal(mu_2.detach(), torch.sqrt(var_2.detach())), dim)
    div = alpha * td.kl_divergence(p_stop_grad, q) + (1 - alpha) * td.kl_divergence(p, q_stop_grad)
    # Same free-nats style floor as in kl_divergence.
    div = torch.max(div, div.new_full(div.size(), 3))
    return torch.mean(div)


def loss_loglikelihood(mu, target, var, dim):
    # Mean Gaussian log-likelihood of target under the predicted distribution.
    # Note: td.Normal takes a standard deviation, so `var` is used as the scale here.
    normal_dist = td.Independent(td.Normal(mu, var), dim)
    return torch.mean(normal_dist.log_prob(target))


def loglikelihood_analitical_loss(mu, target, std):
    # Analytical negative Gaussian log-likelihood, dropping the constant 0.5 * log(2 * pi):
    # 0.5 * ((target - mu) / std)^2 + log(std).
    loss = 0.5 * ((mu - target) / std).pow(2) + torch.log(std)
    return torch.mean(loss)


def contrastive_loss(z_next, z_neg, hinge=0.5):
    # Hinge loss on negative pairs: penalize negatives that are closer than `hinge`.
    dist = torch.nn.functional.mse_loss(z_next, z_neg)
    zeros = torch.zeros_like(dist)
    neg_loss = torch.max(zeros, hinge - dist)
    return torch.mean(neg_loss)


def mse_loss(x, x_target):
    # Mean squared error (already reduced to a scalar by mse_loss).
    dist = torch.nn.functional.mse_loss(x, x_target)
    return torch.mean(dist)


def bce_loss(x, x_target):
    # Summed binary cross-entropy between predictions and targets.
    bce = torch.nn.functional.binary_cross_entropy(x, x_target, reduction='sum')
    return bce


def temp_coherence(sd, a, alpha=2.0):
    # Temporal coherence term: exp(-alpha * ||sd|| * ||a||), averaged over the batch,
    # where sd is the state change and a the corresponding action.
    loss = torch.exp(-alpha * torch.norm(sd, p=2, dim=1) * torch.norm(a, p=2, dim=1))
    return torch.mean(loss)


def causality(s1, s2, a1, a2, beta=10.0):
    # Causality prior: exp(-||s1 - s2||^2) * exp(-beta * ||a1 - a2||^2), which is
    # large when both the states and the actions of the two samples are close.
    loss = torch.exp(-torch.norm(s1 - s2, p=2, dim=1) ** 2) * torch.exp(-beta * torch.norm(a1 - a2, p=2, dim=1) ** 2)
    return torch.mean(loss)


def proportionality(sd1, sd2, a1, a2, beta=10.0):
    # Proportionality prior: penalize differing state-change magnitudes when the
    # two actions are similar.
    loss = (torch.norm(sd2, p=2, dim=1) - torch.norm(sd1, p=2, dim=1)) ** 2 * torch.exp(-beta * torch.norm(a1 - a2, p=2, dim=1) ** 2)
    return torch.mean(loss)


def repeatability(s1, s2, sd1, sd2, a1, a2, beta=10.0):
    # Repeatability prior: penalize differing state changes when both the states
    # and the actions of the two samples are similar.
    loss = torch.norm(sd2 - sd1, p=2, dim=1) ** 2 * torch.exp(-torch.norm(s1 - s2, p=2, dim=1) ** 2) * torch.exp(-beta * torch.norm(a1 - a2, p=2, dim=1) ** 2)
    return torch.mean(loss)
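

# Note (illustrative, not from the original file): these prior terms are often
# combined into a single weighted objective; the helper and weights below are
# hypothetical.
def robotic_priors_loss(s1, s2, sd1, sd2, a1, a2,
                        w_temp=1.0, w_caus=1.0, w_prop=1.0, w_rep=1.0):
    return (w_temp * temp_coherence(sd1, a1)
            + w_caus * causality(s1, s2, a1, a2)
            + w_prop * proportionality(sd1, sd2, a1, a2)
            + w_rep * repeatability(s1, s2, sd1, sd2, a1, a2))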


def bisimulation_loss(z1, z2, r1, r2, mu1, std1, mu2, std2):
    # Squared error between the latent distance and the bisimulation target:
    # reward distance plus discounted 2-Wasserstein distance between the
    # predicted next-state Gaussians.
    a = torch.nn.functional.smooth_l1_loss(z1, z2, reduction='none').mean(dim=1).view(-1, 1)
    b = torch.nn.functional.smooth_l1_loss(r1, r2, reduction='none')
    c = Wasserstein2(mu1, std1, mu2, std2).mean(dim=1).view(-1, 1)
    loss = torch.square(a - b - c)
    return torch.mean(loss)


def Wasserstein2(mu1, std1, mu2, std2, gamma=0.99):
    # Per-dimension 2-Wasserstein distance between diagonal Gaussians,
    # sqrt((mu1 - mu2)^2 + (std1 - std2)^2), scaled by the discount gamma.
    loss = gamma * torch.sqrt((mu1 - mu2).pow(2) + (std1 - std2).pow(2))
    return loss
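

# Usage sketch (not part of the original module): a quick sanity check of a few
# of the losses above on random tensors. All shapes and hyperparameters here are
# illustrative assumptions.
if __name__ == "__main__":
    batch, latent_dim = 8, 16
    mu_1, mu_2 = torch.randn(batch, latent_dim), torch.randn(batch, latent_dim)
    var_1, var_2 = torch.rand(batch, latent_dim) + 1e-3, torch.rand(batch, latent_dim) + 1e-3

    print("KL:", kl_divergence(mu_1, var_1, mu_2, var_2).item())
    print("Balanced KL:", kl_divergence_balance(mu_1, var_1, mu_2, var_2).item())

    # VAE terms on a dummy image batch (values in [0, 1] for binary cross-entropy).
    rec_o, o = torch.rand(batch, 3, 32, 32), torch.rand(batch, 3, 32, 32)
    log_var = torch.randn(batch, latent_dim)
    bce, kld = VAE_loss(rec_o, o, mu_1, log_var, beta=1.0)
    print("VAE loss:", (bce + kld).item())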