-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path check_sgd_deconv_gmm.py
99 lines (80 loc) · 2.39 KB
/
check_sgd_deconv_gmm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from deconv.gmm.plotting import plot_covariance
from deconv.gmm.sgd_deconv_gmm import SGDDeconvGMM
from deconv.gmm.data import DeconvDataset
from data import generate_data
def _make_deconv_dataset(X, noise_covars, D):
    """Wrap data points and their noise covariances in a ``DeconvDataset``.

    ``X`` is reshaped to ``(-1, D)`` and ``noise_covars`` to ``(-1, D, D)``;
    both are cast to float32 before being converted to torch tensors.
    """
    return DeconvDataset(
        torch.Tensor(X.reshape(-1, D).astype(np.float32)),
        torch.Tensor(noise_covars.reshape(-1, D, D).astype(np.float32))
    )


def _plot_fit(gmm, X_train, means, covars, K):
    """Show the loss curves and the true vs. fitted components.

    Only the first two data dimensions are plotted. ``X_train`` is indexed as
    ``X_train[:, component, dim]`` (assumed layout from ``generate_data`` --
    confirm against ``data.generate_data``).
    """
    # Figure 1: optimisation curves recorded by the model during ``fit``.
    _, ax = plt.subplots()
    ax.plot(gmm.train_loss_curve, label='Training Loss')
    ax.plot(gmm.val_loss_curve, label='Validation Loss')
    # Without this the line labels above are never displayed.
    ax.legend()

    # Figure 2: training points per true component with the true
    # covariance ellipses, then the fitted means and covariances.
    _, ax = plt.subplots()
    for i in range(K):
        sc = ax.scatter(
            X_train[:, i, 0],
            X_train[:, i, 1],
            alpha=0.2,
            marker='x',
            label='Cluster {}'.format(i)
        )
        # Reuse the scatter's colour so each ellipse matches its cluster.
        plot_covariance(
            means[i, :],
            covars[i, :, :],
            ax,
            color=sc.get_facecolor()[0]
        )

    sc = ax.scatter(
        gmm.means[:, 0],
        gmm.means[:, 1],
        marker='+',
        label='Fitted Gaussians'
    )
    fitted_color = sc.get_facecolor()[0]
    for i in range(K):
        plot_covariance(
            gmm.means[i, :],
            gmm.covars[i, :, :],
            ax,
            color=fitted_color
        )
    ax.legend()
    plt.show()


def check_sgd_deconv_gmm(D, K, N, plot=False, verbose=False, device=None,
                         batch_size=250, epochs=200, restarts=1, lr=1e-1):
    """Fit an ``SGDDeconvGMM`` to synthetic data and report train/test scores.

    Data and ground-truth component parameters come from ``generate_data``.
    The model is fit on the training split, with the test split passed as
    validation data, and then both splits are scored with
    ``gmm.score_batch``. The scores are printed and returned.

    Args:
        D: Data dimensionality (points are reshaped to ``(-1, D)``).
        K: Number of mixture components to fit.
        N: Forwarded to ``generate_data`` (presumably the number of points
            per component -- confirm against ``data.generate_data``).
        plot: If true, show loss curves and a scatter of the true and fitted
            components. Plotting uses the first two dimensions, so it
            assumes ``D >= 2``.
        verbose: Forwarded to ``gmm.fit``.
        device: Torch device for the model; ``None`` selects the CPU.
        batch_size, epochs, restarts, lr: Forwarded to ``SGDDeconvGMM``.
            Defaults equal the values this check previously hard-coded.

    Returns:
        Tuple ``(train_score, test_score)`` as returned by
        ``gmm.score_batch`` on each split.
    """
    # Compare against None explicitly: a valid device such as the CUDA
    # index 0 is falsy and must not be silently replaced by the CPU.
    if device is None:
        device = torch.device('cpu')

    data, params = generate_data(D, K, N)
    X_train, nc_train, X_test, nc_test = data
    means, covars = params

    train_data = _make_deconv_dataset(X_train, nc_train, D)
    test_data = _make_deconv_dataset(X_test, nc_test, D)

    gmm = SGDDeconvGMM(
        K,
        D,
        device=device,
        batch_size=batch_size,
        epochs=epochs,
        restarts=restarts,
        lr=lr
    )
    gmm.fit(train_data, val_data=test_data, verbose=verbose)

    train_score = gmm.score_batch(train_data)
    test_score = gmm.score_batch(test_data)
    print('Training score: {}'.format(train_score))
    print('Test score: {}'.format(test_score))

    if plot:
        _plot_fit(gmm, X_train, means, covars, K)

    return train_score, test_score
if __name__ == '__main__':
    # Apply seaborn's default theme to every figure the check produces,
    # then run a small 2-D, 3-component check with plots and progress output.
    sns.set()
    D, K, N = 2, 3, 500
    check_sgd_deconv_gmm(D, K, N, plot=True, verbose=True)