Skip to content

Commit

Permalink
Data utility
Browse files Browse the repository at this point in the history
  • Loading branch information
vinhkhuc committed Feb 13, 2017
1 parent 5b4b39e commit 6cc505c
Showing 1 changed file with 79 additions and 0 deletions.
79 changes: 79 additions & 0 deletions data_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import gzip
import os
from os import path
import urllib
import numpy as np

DATASET_DIR = 'datasets/'


def download_file(url, local_path):
dir_path = path.dirname(local_path)
if not path.exists(dir_path):
print("Creating the directory '%s' ..." % dir_path)
os.makedirs(dir_path)

print("Downloading from '%s' ..." % url)
urllib.URLopener().retrieve(url, local_path)


def download_mnist(local_path):
url_root = "http://yann.lecun.com/exdb/mnist/"
for f_name in ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz",
"t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]:
f_path = os.path.join(local_path, f_name)
if not path.exists(f_path):
download_file(url_root + f_name, f_path)


def one_hot(x, n):
if type(x) == list:
x = np.array(x)
x = x.flatten()
o_h = np.zeros((len(x), n))
o_h[np.arange(len(x)), x] = 1
return o_h


def load_mnist(ntrain=60000, ntest=10000, onehot=True):
data_dir = os.path.join(DATASET_DIR, 'mnist/')
if not path.exists(data_dir):
download_mnist(data_dir)

with gzip.open(os.path.join(data_dir, 'train-images-idx3-ubyte.gz')) as fd:
buf = fd.read()
loaded = np.frombuffer(buf, dtype=np.uint8)
trX = loaded[16:].reshape((60000, 28 * 28)).astype(float)

with gzip.open(os.path.join(data_dir, 'train-labels-idx1-ubyte.gz')) as fd:
buf = fd.read()
loaded = np.frombuffer(buf, dtype=np.uint8)
trY = loaded[8:].reshape((60000))

with gzip.open(os.path.join(data_dir, 't10k-images-idx3-ubyte.gz')) as fd:
buf = fd.read()
loaded = np.frombuffer(buf, dtype=np.uint8)
teX = loaded[16:].reshape((10000, 28 * 28)).astype(float)

with gzip.open(os.path.join(data_dir, 't10k-labels-idx1-ubyte.gz')) as fd:
buf = fd.read()
loaded = np.frombuffer(buf, dtype=np.uint8)
teY = loaded[8:].reshape((10000))

trX /= 255.
teX /= 255.

trX = trX[:ntrain]
trY = trY[:ntrain]

teX = teX[:ntest]
teY = teY[:ntest]

if onehot:
trY = one_hot(trY, 10)
teY = one_hot(teY, 10)
else:
trY = np.asarray(trY)
teY = np.asarray(teY)

return trX, teX, trY, teY

0 comments on commit 6cc505c

Please sign in to comment.