-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathdata.py
71 lines (59 loc) · 2.29 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gzip
import os
import numpy as np
import six
from six.moves.urllib import request
parent = 'http://yann.lecun.com/exdb/mnist'
train_images = 'train-images-idx3-ubyte.gz'
train_labels = 'train-labels-idx1-ubyte.gz'
test_images = 't10k-images-idx3-ubyte.gz'
test_labels = 't10k-labels-idx1-ubyte.gz'
num_train = 60000
num_test = 10000
dim = 784
def load_mnist(images, labels, num):
data = np.zeros(num * dim, dtype=np.uint8).reshape((num, dim))
target = np.zeros(num, dtype=np.uint8).reshape((num, ))
with gzip.open(images, 'rb') as f_images,\
gzip.open(labels, 'rb') as f_labels:
f_images.read(16)
f_labels.read(8)
for i in six.moves.range(num):
target[i] = ord(f_labels.read(1))
for j in six.moves.range(dim):
data[i, j] = ord(f_images.read(1))
return data, target
def download_mnist_data():
print('Downloading {:s}...'.format(train_images))
request.urlretrieve('{:s}/{:s}'.format(parent, train_images), train_images)
print('Done')
print('Downloading {:s}...'.format(train_labels))
request.urlretrieve('{:s}/{:s}'.format(parent, train_labels), train_labels)
print('Done')
print('Downloading {:s}...'.format(test_images))
request.urlretrieve('{:s}/{:s}'.format(parent, test_images), test_images)
print('Done')
print('Downloading {:s}...'.format(test_labels))
request.urlretrieve('{:s}/{:s}'.format(parent, test_labels), test_labels)
print('Done')
print('Converting training data...')
data_train, target_train = load_mnist(train_images, train_labels,
num_train)
print('Done')
print('Converting test data...')
data_test, target_test = load_mnist(test_images, test_labels, num_test)
mnist = {}
mnist['data'] = np.append(data_train, data_test, axis=0)
mnist['target'] = np.append(target_train, target_test, axis=0)
print('Done')
print('Save output...')
with open('mnist.pkl', 'wb') as output:
six.moves.cPickle.dump(mnist, output, -1)
print('Done')
print('Convert completed')
def load_mnist_data():
if not os.path.exists('mnist.pkl'):
download_mnist_data()
with open('mnist.pkl', 'rb') as mnist_pickle:
mnist = six.moves.cPickle.load(mnist_pickle)
return mnist