forked from mqtlam/tsne-d3-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
126 lines (92 loc) · 2.95 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from sklearn.manifold import TSNE
import numpy as np
from matplotlib import cm
import matplotlib
import pandas as pd
import argparse
def get_colors(n, cmap_name='seismic'):
    """Return ``n`` hex color strings taken from a Matplotlib colormap.

    The colormap is resampled to ``n`` entries and each entry is converted
    to an ``'#rrggbb'`` string (alpha dropped).

    Args:
        n: Number of colors wanted.
        cmap_name: Name of a registered Matplotlib colormap; defaults to
            ``'seismic'`` as before.

    Returns:
        numpy array of ``n`` hex color strings, indexable by class index.
    """
    try:
        # Matplotlib >= 3.6: colormap registry + Colormap.resampled.
        # (cm.get_cmap was deprecated in 3.7 and removed in 3.9.)
        cmap = matplotlib.colormaps[cmap_name].resampled(n)
    except AttributeError:
        # Older Matplotlib without the registry or without `resampled`.
        cmap = cm.get_cmap(cmap_name, n)
    return np.array(
        [matplotlib.colors.rgb2hex(cmap(i)[:3]) for i in range(cmap.N)])
def read_input_csv(file_path, has_header, has_target):
    """Load the input CSV into a DataFrame of image names and feature rows.

    Column layout (as sliced below): column 0 is the image name, column 1
    is the target label when ``has_target`` is true, and every remaining
    column is a feature value.

    Args:
        file_path: Path or file-like object passed to ``pd.read_csv``.
        has_header: Whether the first row is a header row.
        has_target: Whether the second column holds a target label.

    Returns:
        DataFrame with columns ``image`` (name prefixed with ``'/'``),
        ``target`` (only when ``has_target``), and ``data`` (one list of
        feature values per row).
    """
    df = pd.read_csv(file_path, header=0 if has_header else None)
    initial_columns = ['image', 'target'] if has_target else ['image']
    n_meta = len(initial_columns)

    features = df.iloc[:, n_meta:].to_numpy()
    # .copy() detaches the slice so the column assignments below modify an
    # independent frame (avoids pandas' SettingWithCopyWarning).
    df = df.iloc[:, :n_meta].copy()
    df.columns = initial_columns
    df['data'] = [list(row) for row in features]
    # astype(str): headerless files with numeric image ids are parsed as
    # ints, which plain '/' + x would reject with a TypeError.
    df['image'] = '/' + df['image'].astype(str)
    return df
def calculate_tsne(data, perplexity=30.0, random_state=None):
    """Embed feature vectors into 2-D with scikit-learn's t-SNE.

    Args:
        data: Array-like of feature vectors, one row per sample.
        perplexity: Requested t-SNE perplexity (30.0 matches scikit-learn's
            default). It is clamped below the sample count, because
            scikit-learn rejects a perplexity >= n_samples.
        random_state: Forwarded to ``TSNE`` for reproducible embeddings;
            ``None`` keeps the previous unseeded behavior.

    Returns:
        Array of shape (n_samples, 2) with the embedded coordinates.
    """
    data = np.asarray(data)
    # Small inputs (e.g. a low --max_num_points) would otherwise fail with
    # "perplexity must be less than n_samples" in recent scikit-learn.
    perplexity = min(perplexity, max(len(data) - 1, 1))
    tsne = TSNE(n_components=2, perplexity=perplexity,
                random_state=random_state, verbose=1)
    return tsne.fit_transform(data)
def get_args(argv=None):
    """Parse command-line options for the t-SNE visualizer.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``input_data``, ``max_num_points``,
        ``has_target``, ``is_server`` and ``has_header``.
    """
    def positive_int(text):
        # A negative limit would silently drop trailing rows when used as a
        # slice bound, and 0 leaves nothing to embed; reject both early.
        value = int(text)
        if value <= 0:
            raise argparse.ArgumentTypeError(
                'expected a positive integer, got %r' % text)
        return value

    parser = argparse.ArgumentParser(description='T-SNE Data visualizer')
    parser.add_argument('input_data', type=str, help='Path to CSV data file')
    parser.add_argument(
        '--max_num_points',
        type=positive_int,
        default=1000,
        help='Max number of data to consider')
    parser.add_argument(
        '--targets',
        dest='has_target',
        action='store_true',
        help='CSV data file contains target information')
    parser.add_argument(
        '-s',
        '--server',
        dest='is_server',
        action='store_true',
        help='Start a server')
    parser.add_argument(
        '--header',
        dest='has_header',
        action='store_true',
        help='Flag indicating CSV data file contains header info')
    return parser.parse_args(argv)
def write_csv(path, color, image_name, x, y):
    """Write one CSV row per plotted point.

    Args:
        path: Destination passed to ``DataFrame.to_csv``.
        color: Sequence of color strings, one per point.
        image_name: Sequence of image names, one per point.
        x: Sequence of first embedding coordinates.
        y: Sequence of second embedding coordinates.

    The file has the header ``color,image_name,x,y`` and no index column.
    All sequences must have the same length (pandas raises ValueError
    otherwise).
    """
    # Build the frame in one step instead of growing an empty frame column
    # by column (which relied on pandas re-indexing on first assignment).
    table = pd.DataFrame({
        'color': color,
        'image_name': image_name,
        'x': x,
        'y': y,
    })
    table.to_csv(path, index=False)
def start_server(port=8000):
    """Serve the ``web`` directory next to this script over HTTP.

    Binds all interfaces on ``port`` and blocks in ``serve_forever`` until
    interrupted with Ctrl-C, after which the socket is closed.

    Args:
        port: TCP port to listen on; defaults to 8000 as before.
    """
    from http import server
    import functools
    import os
    import socketserver

    web_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'web')
    # Hand the directory to the handler instead of calling os.chdir(), so
    # the process-wide working directory is left unchanged.
    Handler = functools.partial(
        server.SimpleHTTPRequestHandler, directory=web_dir)

    class _ReusableTCPServer(socketserver.TCPServer):
        # Allow rebinding the port right after a restart, while the old
        # socket may still be in TIME_WAIT.
        allow_reuse_address = True

    with _ReusableTCPServer(("", port), Handler) as httpd:
        print("serving at port", port)
        try:
            httpd.serve_forever()
        except KeyboardInterrupt:
            # Ctrl-C is the intended way to stop; exit quietly and let the
            # `with` block close the listening socket.
            pass
def main():
    """Run the pipeline: load the CSV, subsample, embed with t-SNE, write CSV.

    The result is written to ``web/data.csv`` next to this script (the same
    directory ``start_server`` serves). With ``--server`` the HTTP server is
    started afterwards.
    """
    import os

    args = get_args()
    df = read_input_csv(args.input_data, args.has_header, args.has_target)

    # Shuffle whole rows so the subsample below is random. The previous
    # `df.apply(np.random.permutation)` discarded its result, and would have
    # permuted each column independently, desyncing images from features.
    # TODO(review): the original note "do using targets???" may mean the
    # subsample should be stratified by target; confirm intent.
    df = df.sample(frac=1).reset_index(drop=True)
    df = df.iloc[:args.max_num_points]

    rdata = calculate_tsne(np.array(df['data'].tolist()))
    x, y = rdata[:, 0], rdata[:, 1]

    if args.has_target:
        # Map labels to dense indices 0..k-1 before indexing the palette.
        # Indexing by the raw labels breaks for string labels, non-contiguous
        # ints, or classes absent from the subsample (IndexError or
        # mismatched colors).
        labels, codes = np.unique(df['target'].to_numpy(), return_inverse=True)
        colors = get_colors(len(labels))[codes]
    else:
        colors = ["red"] * len(df)

    # Resolve relative to this script, not the current working directory,
    # so the file lands where start_server() serves from.
    out_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'web', 'data.csv')
    write_csv(out_path, colors, df['image'].tolist(), x, y)
    if args.is_server:
        start_server()
# Run only when executed as a script, so importing this module (e.g. to
# reuse the helpers) does not parse argv or start the pipeline.
if __name__ == '__main__':
    main()