-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdata.py
78 lines (56 loc) · 2.29 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
This file mimics the download function of keras.datasets.
For parallel and distributed training, we need to account
for multiple processes (one per GPU) per agent.
For more information on data in Determined, read the document for preparing data.
"""
import gzip
import tempfile
import numpy as np
from tensorflow.python.keras.utils.data_utils import get_file
def load_training_data():
    """Load the Fashion-MNIST training split as NumPy arrays.

    The gzip-compressed label and image archives are fetched with
    ``get_file`` into a temporary directory created for this call only, so
    that each process downloads into its own location (see the module
    docstring on multi-process training). The directory and the downloaded
    archives are deleted once the data has been read into memory.

    Returns:
        Tuple of Numpy arrays: `(x_train, y_train)`. `y_train` is a flat
        uint8 array of labels and `x_train` is a uint8 array reshaped to
        `(len(y_train), 28, 28)`.

    License:
        The copyright for Fashion-MNIST is held by Zalando SE.
        Fashion-MNIST is licensed under the [MIT license](
        https://github.com/zalandoresearch/fashion-mnist/blob/master/LICENSE).
    """
    base = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
    files = [
        "train-labels-idx1-ubyte.gz",
        "train-images-idx3-ubyte.gz",
    ]

    # The directory is private to this call and its path is never returned,
    # so nothing can reuse the archives later; the context manager removes
    # them instead of leaving one orphaned directory behind per call.
    with tempfile.TemporaryDirectory() as download_directory:
        paths = [
            get_file(fname, origin=base + fname, cache_subdir=download_directory)
            for fname in files
        ]

        # The arrays are built from the fully decompressed bytes held in
        # memory, so they stay valid after the files are deleted.
        # offset=8 skips the header that precedes the label bytes.
        with gzip.open(paths[0], "rb") as lbpath:
            y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

        # offset=16 skips the header that precedes the pixel bytes.
        with gzip.open(paths[1], "rb") as imgpath:
            x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(
                len(y_train), 28, 28
            )

    return x_train, y_train
def load_validation_data():
    """Load the Fashion-MNIST validation (t10k) split as NumPy arrays.

    The gzip-compressed label and image archives are fetched with
    ``get_file`` into a temporary directory created for this call only, so
    that each process downloads into its own location (see the module
    docstring on multi-process training). The directory and the downloaded
    archives are deleted once the data has been read into memory.

    Returns:
        Tuple of Numpy arrays: `(x_test, y_test)`. `y_test` is a flat uint8
        array of labels and `x_test` is a uint8 array reshaped to
        `(len(y_test), 28, 28)`.

    License:
        The copyright for Fashion-MNIST is held by Zalando SE.
        Fashion-MNIST is licensed under the [MIT license](
        https://github.com/zalandoresearch/fashion-mnist/blob/master/LICENSE).
    """
    base = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
    files = [
        "t10k-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
    ]

    # The directory is private to this call and its path is never returned,
    # so nothing can reuse the archives later; the context manager removes
    # them instead of leaving one orphaned directory behind per call.
    with tempfile.TemporaryDirectory() as download_directory:
        paths = [
            get_file(fname, origin=base + fname, cache_subdir=download_directory)
            for fname in files
        ]

        # The arrays are built from the fully decompressed bytes held in
        # memory, so they stay valid after the files are deleted.
        # offset=8 skips the header that precedes the label bytes.
        with gzip.open(paths[0], "rb") as lbpath:
            y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)

        # offset=16 skips the header that precedes the pixel bytes.
        with gzip.open(paths[1], "rb") as imgpath:
            x_test = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(
                len(y_test), 28, 28
            )

    return x_test, y_test