-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathget_batches_vkitti.py
executable file
·51 lines (46 loc) · 2.38 KB
/
get_batches_vkitti.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import tensorflow as tf
def read_and_decode(filename_queue, IMG_HEIGHT, IMG_WIDTH):
    """Read one serialized example from a TFRecord filename queue and parse it.

    Args:
        filename_queue: Queue of TFRecord filenames handed to the reader,
            e.g. one built with ``tf.train.string_input_producer``.
        IMG_HEIGHT: Height of every stored feature map.
        IMG_WIDTH: Width of every stored feature map.

    Returns:
        A tuple ``(img1, img2, flow, vmap)`` of float32 tensors shaped
        ``[IMG_HEIGHT, IMG_WIDTH, 3]``, ``[IMG_HEIGHT, IMG_WIDTH, 3]``,
        ``[IMG_HEIGHT, IMG_WIDTH, 2]`` and ``[IMG_HEIGHT, IMG_WIDTH, 1]``.
    """
    # (feature key, channel count), listed in the order the tensors are
    # returned. No default_value is given for any key, so every key is
    # required to be present in each record.
    layout = (
        ('img1_raw', 3),
        ('img2_raw', 3),
        ('flow_raw', 2),
        ('vmap_raw', 1),
    )

    record_reader = tf.TFRecordReader()
    _, record = record_reader.read(filename_queue)

    spec = {
        key: tf.FixedLenFeature([IMG_HEIGHT, IMG_WIDTH, depth], tf.float32)
        for key, depth in layout
    }
    parsed = tf.parse_single_example(record, features=spec)

    return tuple(parsed[key] for key, _ in layout)
def inputs(filename, batch_size, IMG_HEIGHT, IMG_WIDTH, num_epochs=None, capp=2000):
    """Build a shuffled, batched input pipeline over a single TFRecord file.

    Args:
        filename: Path of the TFRecord file to read.
        batch_size: Number of examples per returned batch.
        IMG_HEIGHT: Height of the stored feature maps.
        IMG_WIDTH: Width of the stored feature maps.
        num_epochs: Number of times to read the input data, or 0/None to
            read it forever.
        capp: Minimum number of examples left in the shuffle queue after a
            dequeue (``min_after_dequeue``); the queue capacity is
            ``capp + 3 * batch_size``.

    Returns:
        A tuple ``(img1s, img2s, flows, vmaps)`` of float32 tensors shaped
        ``[batch_size, IMG_HEIGHT, IMG_WIDTH, C]`` with C equal to 3, 3, 2
        and 1 respectively.

    Note that a ``tf.train.QueueRunner`` is added to the graph, which must be
    run using e.g. ``tf.train.start_queue_runners()``. Per the
    ``tf.train.string_input_producer`` documentation, a non-None
    ``num_epochs`` also creates a local epoch counter, so
    ``tf.local_variables_initializer()`` must be run in that case.
    """
    # string_input_producer raises ValueError for num_epochs <= 0, so map the
    # documented "0 means forever" onto None (unlimited epochs).
    if not num_epochs:
        num_epochs = None
    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)
        # Even when reading in multiple threads, share the filename queue.
        img1, img2, flow, vmap = read_and_decode(filename_queue, IMG_HEIGHT, IMG_WIDTH)
        # Shuffle the examples and collect them into batch_size batches.
        # (Internally uses a RandomShuffleQueue.)
        # We run this in two threads to avoid being a bottleneck.
        img1s, img2s, flows, vmaps = tf.train.shuffle_batch(
            [img1, img2, flow, vmap],
            batch_size=batch_size,
            num_threads=2,
            capacity=capp + 3 * batch_size,
            # Ensures a minimum amount of shuffling of examples.
            min_after_dequeue=capp)
        return img1s, img2s, flows, vmaps