-
-
Notifications
You must be signed in to change notification settings - Fork 3.2k
/
Copy pathtfrecords.py
136 lines (110 loc) · 3.95 KB
/
tfrecords.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# -*- coding: utf-8 -*-
"""TFRecords.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1p-Nz6v3CyqKSc-QazX1FgvZkamt5T-uC
"""
import tensorflow as tf
from tensorflow import keras
import numpy as np
# Load the MNIST train/test splits.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Preprocessing: scale pixel values by 1/255. Dividing by a Python float
# changes the array dtype, which is why the dtype is recorded afterwards.
x_train = x_train / 255.0
x_test = x_test / 255.0

# Track the dtypes. `dataType` is read again by decode_record() so that the
# raw image bytes are decoded with the same dtype they were serialized with.
dataType = x_train.dtype
print(f"Data type: {dataType}")
labelType = y_test.dtype
# Use a distinct label so the two printed lines can be told apart.
print(f"Label type: {labelType}")
# Collect the first few training images for a quick visual sanity check.
n_samples_to_show = 16
im_list = list(x_train[:n_samples_to_show])

# Visualization
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

fig = plt.figure(figsize=(4., 4.))
# Ref: https://matplotlib.org/3.1.1/gallery/axes_grid1/simple_axesgrid.html
grid = ImageGrid(
    fig,
    111,  # similar to subplot(111)
    nrows_ncols=(4, 4),  # creates a 4x4 grid of axes (16 cells)
    axes_pad=0.1,  # pad between axes in inch.
)

# Show image grid; zip() stops at the shorter of the grid and the image list.
for ax, im in zip(grid, im_list):
    # Iterating over the grid returns the Axes.
    ax.imshow(im, cmap='gray')
plt.show()
# Convert values to compatible tf.Example types.
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``.

    Args:
        value: The payload to store. If it is an instance of the type that
            ``tf.constant`` returns, it is first converted via ``.numpy()``.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds the one value.
    """
    tensor_type = type(tf.constant(0))
    if isinstance(value, tensor_type):
        # BytesList cannot take the tensor itself; extract the underlying
        # value first.
        value = value.numpy()
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
def _float_feature(value):
    """Wrap a single float / double in a ``tf.train.Feature``.

    Args:
        value: The scalar to store.

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` holds the one value.
    """
    payload = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=payload)
def _int64_feature(value):
    """Wrap a single bool / enum / int / uint in a ``tf.train.Feature``.

    Args:
        value: The scalar to store.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds the one value.
    """
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)
# Create the features dictionary.
def image_example(image, label, dimension):
    """Build a ``tf.train.Example`` describing one labelled image.

    Args:
        image: Object exposing ``tobytes()``; its raw bytes are stored under
            the ``'image_raw'`` key.
        label: Integer label stored under the ``'label'`` key.
        dimension: Integer stored under the ``'dimension'`` key.

    Returns:
        A ``tf.train.Example`` carrying the three features above.
    """
    features = tf.train.Features(
        feature={
            'dimension': _int64_feature(dimension),
            'label': _int64_feature(label),
            'image_raw': _bytes_feature(image.tobytes()),
        }
    )
    return tf.train.Example(features=features)
record_file = 'mnistTrain.tfrecords'
n_samples = x_train.shape[0]
# Only one side length is stored per record; decode_record() reshapes to
# (dimension, dimension), so images are treated as square.
dimension = x_train.shape[1]

# Serialize every training image/label pair into the TFRecord file. The
# context manager closes the writer before the file is read back below.
with tf.io.TFRecordWriter(record_file) as writer:
    for image, label in zip(x_train, y_train):
        tf_example = image_example(image, label, dimension)
        writer.write(tf_example.SerializeToString())

# Create the dataset object from tfrecord file(s)
dataset = tf.data.TFRecordDataset(record_file)
# Decoding function
def parse_record(record):
    """Parse one serialized example into a dict of tensors.

    Args:
        record: A serialized ``tf.train.Example`` carrying the
            ``'dimension'``, ``'label'`` and ``'image_raw'`` features.

    Returns:
        The result of ``tf.io.parse_single_example`` for the schema below:
        scalar int64 ``'dimension'`` and ``'label'``, scalar string
        ``'image_raw'``.
    """
    scalar_int = tf.io.FixedLenFeature([], tf.int64)
    scalar_bytes = tf.io.FixedLenFeature([], tf.string)
    schema = {
        'dimension': scalar_int,
        'label': scalar_int,
        'image_raw': scalar_bytes,
    }
    return tf.io.parse_single_example(record, schema)
def decode_record(record):
    """Turn a parsed record into an ``(image, label)`` pair.

    Args:
        record: Mapping with ``'image_raw'``, ``'label'`` and ``'dimension'``
            entries, as returned by ``parse_record``.

    Returns:
        A tuple ``(image, label)`` where ``image`` has been reshaped to
        ``(dimension, dimension)``.
    """
    side = record['dimension']
    label = record['label']
    # The bytes are decoded with the module-level `dataType`, i.e. the dtype
    # captured from x_train after preprocessing.
    flat_pixels = tf.io.decode_raw(
        record['image_raw'], out_type=dataType, little_endian=True
    )
    image = tf.reshape(flat_pixels, (side, side))
    return image, label
# Read the first few records back from the TFRecord file and decode them, to
# check visually that the round trip preserved the images.
n_samples_to_show = 16
im_list = [
    decode_record(parse_record(record))[0]  # keep the image, drop the label
    for record in dataset.take(n_samples_to_show)
]

# Visualization
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

fig = plt.figure(figsize=(4., 4.))
# Ref: https://matplotlib.org/3.1.1/gallery/axes_grid1/simple_axesgrid.html
grid = ImageGrid(
    fig,
    111,  # similar to subplot(111)
    nrows_ncols=(4, 4),  # creates a 4x4 grid of axes (16 cells)
    axes_pad=0.1,  # pad between axes in inch.
)

# Show image grid; zip() stops at the shorter of the grid and the image list.
for ax, im in zip(grid, im_list):
    # Iterating over the grid returns the Axes.
    ax.imshow(im, cmap='gray')
plt.show()