-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathvisualizations.py
69 lines (60 loc) · 1.74 KB
/
visualizations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
import matplotlib.pyplot as plt
import os
# Matplotlib "tab:" palette names. show_clusters_centroids indexes this list
# by cluster position, so cluster i and centroid i are drawn in the same colour.
COLORS = [
    'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
    'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan',
]
def show_clusters_centroids(clusters,
                            centroids,
                            title,
                            x_var_indx=0,
                            y_var_indx=1,
                            x_var_name='Variable 1',
                            y_var_name="Variable 2",
                            keep=False):
    """
    Plot the current clustering, save the plot as a PNG and display it.

    The figure is written to ./images/kmeans_out/kmeans_<title>.png (the
    directory is created if it is missing) before it is shown.

    Input:
        clusters (list of lists of lists): A List of Clusters. Each cluster
            is also a list of points in the cluster. Empty clusters are
            skipped. SEE: k_means.get_clusters()
        centroids (list of lists): A list with the current centroids
        title (string): The title for the plot. It is also inserted into
            the name of the saved image file.
        x_var_indx (int): Index of the point coordinate drawn on the x axis.
        y_var_indx (int): Index of the point coordinate drawn on the y axis.
        x_var_name (string): Label for the x axis.
        y_var_name (string): Label for the y axis.
        keep (bool): If False (default), show the plot without blocking,
            pause for 1 second and close it. If True, show it with a plain
            plt.show() call and do not close it from this function.
    """
    # Cycle through the palette so more clusters than colours does not
    # raise an IndexError; colours simply repeat.
    n_colors = len(COLORS)

    for i, cluster in enumerate(clusters):
        cluster = np.array(cluster)
        # An empty cluster converts to a 1-D array, which cannot be indexed
        # by column below; there is nothing to draw for it anyway.
        if cluster.size == 0:
            continue
        plt.scatter(
            cluster[:, x_var_indx],
            cluster[:, y_var_indx],
            c=COLORS[i % n_colors],
            label="Cluster {}".format(i)
        )

    # Centroid i reuses the colour of cluster i and is drawn as a large 'x'.
    for i, centroid in enumerate(centroids):
        plt.scatter(
            centroid[x_var_indx],
            centroid[y_var_indx],
            c=COLORS[i % n_colors],
            marker='x',
            s=100
        )

    plt.title(title)
    plt.xlabel(x_var_name)
    plt.ylabel(y_var_name)
    plt.legend()

    # makedirs also creates a missing ./images parent, and exist_ok avoids
    # the check-then-create race of isdir() followed by mkdir().
    os.makedirs('./images/kmeans_out', exist_ok=True)
    plt.savefig("./images/kmeans_out/kmeans_{}.png".format(title))

    if not keep:
        plt.show(block=False)
        plt.pause(1)
        plt.close()
    else:
        plt.show()