# Wrapper class to read light field CNN training / test data
#
# (c) Bastian Goldluecke 4/2017, University of Konstanz
# License: Creative Commons BY-SA 4.0
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import ctypes
import os
import h5py
import lf_tools
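
# Expected HDF5 layout, as read by the `dataset` class below: one top-level
# dataset per stream, with the example index on the LAST axis. A minimal
# sketch of writing a compatible file (the stream names, shapes and file name
# are illustrative assumptions, not prescribed by this module):
#
#   with h5py.File( 'patches.hdf5', 'w' ) as f:
#       f.create_dataset( 'stacks_v_HR', data=np.zeros( (9, 96, 96, 3, 1000), np.float32 ) )
#       f.create_dataset( 'stacks_h_HR', data=np.zeros( (9, 96, 96, 3, 1000), np.float32 ) )
#       f.create_dataset( 'disp', data=np.zeros( (48, 48, 1000), np.float32 ) )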
# dataset for light field angular patch stacks
class dataset:
    def __init__(self, filename, subsets=None, max_examples=1e20, min_example=0, random_shuffle=True ):
        """Construct a light field angular patch dataset.

        Initialized from an HDF5 file. Every top-level dataset in the file is
        treated as a data stream; examples are indexed along the LAST axis,
        so a stream of shape [d_1, ..., d_k, num_examples] yields per-example
        entries of shape [d_1, ..., d_k].

        `subsets` maps subset names to the fraction of examples they receive;
        it defaults to { 'validate': 0.0, 'train': 1.0 }.
        """
        # avoid a mutable default argument for the subset fractions
        if subsets is None:
            subsets = { 'validate': 0.0, 'train': 1.0 }
        print( 'reading HDF5 dataset ' + filename )
        self._filename = filename
        base = os.path.basename( filename )
        self._file_id = os.path.splitext( base )[0]
        self._file = h5py.File( self._filename, 'r' )
        # materialize the key view so it can be iterated repeatedly (Python 3)
        self.streams = list( self._file.keys() )
        self.stream = dict()
        self._num_examples = max_examples
        for stream in self.streams:
            print( ' found data stream ' + stream )
            self.stream[ stream ] = self._file[ stream ]
            sh = self.stream[ stream ].shape
            print( ' shape %s' % (sh,) )
            # examples are stored along the last axis of each stream
            self._num_examples = min( self._num_examples, sh[ -1 ] )
        print( ' total %i training examples used.' % (self._num_examples) )
        # derive a 16-bit RNG seed from the file name. Note that str hashing
        # is salted per process in Python 3, so set PYTHONHASHSEED if the
        # shuffle must be reproducible across runs.
        seed = ctypes.c_ushort( hash( filename ) ).value
        np.random.seed( seed )
        # global permutation of the usable example indices
        self._permutation = np.arange( min_example, self._num_examples )
        if random_shuffle:
            np.random.shuffle( self._permutation )
        self.subsets = dict()
        min_index = 0
        for s in subsets:
            subset = dict()
            subset[ 'id' ] = s
            # split off the index list for this subset from the global permutation
            p = subsets[ s ]
            n = int( p * self._num_examples )
            max_index = min( len( self._permutation ), n + min_index )
            subset[ 'indices' ] = self._permutation[ min_index : max_index ]
            if random_shuffle:
                np.random.shuffle( subset[ 'indices' ] )
            # current subset epoch
            subset[ 'epoch' ] = 0
            # current offset and minibatch index
            subset[ 'index' ] = 0
            subset[ 'minibatch_index' ] = 0
            min_index = max_index
            self.subsets[ s ] = subset
        # flag if set is to be shuffled
        self._shuffle = random_shuffle
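        # Worked example for the subset split above (illustrative numbers):
        # with subsets={ 'validate': 0.1, 'train': 0.9 } and 1000 usable
        # examples, 'validate' receives the first 100 entries of the shuffled
        # permutation and 'train' the remaining 900; fractions are consumed
        # in dict iteration order.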
    # this function pulls exactly batch_size training examples for
    # a certain subset of the training data.
    #
    # the result is None if the subset is smaller than the batch size.
    #
    def next_batch( self, ColorSpace, batch_size, subset_name='train' ):
        """Return the next `batch_size` examples from this data set."""
        subset = self.subsets[ subset_name ]
        subset[ 'minibatch_index' ] += 1
        # pull indices
        start = subset[ 'index' ]
        end = min( start + batch_size, len( subset[ 'indices' ] ))
        subset[ 'index' ] = end
        idx = subset[ 'indices' ][ start:end ]
        sz = len( idx )
        # if not enough examples were drawn, reshuffle the subset (if desired)
        # and pull the remaining ones.
        missing = batch_size - sz
        if missing > 0:
            subset[ 'epoch' ] += 1
            subset[ 'index' ] = 0
            subset[ 'minibatch_index' ] = 0
            if self._shuffle:
                np.random.shuffle( subset[ 'indices' ] )
            start = subset[ 'index' ]
            end = min( start + missing, len( subset[ 'indices' ] ) )
            subset[ 'index' ] = end
            idx = np.append( idx, subset[ 'indices' ][ start:end ] )
        # if there are still samples missing, the subset is too small
        sz = len( idx )
        if sz < batch_size:
            return None
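        # Illustrative walk-through of the wrap-around above (numbers are an
        # example, not from the code): with 10 subset indices and
        # batch_size=4, consecutive calls draw 4, then 4, then the 2 leftover
        # indices plus 2 fresh ones after the reshuffle; 'epoch' is
        # incremented and the offset counters are reset on that third draw.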
        # retrieve index set from permuted index array
        batch = dict()
        for stream in self.streams:
            # create array for the stream; per-example shape is the stream
            # shape without the trailing example axis
            sh = list( self.stream[ stream ].shape )
            batch[ stream ] = np.zeros( [sz] + sh[0:-1], np.float32 )
        # pre-allocate the additional scale levels (intended to be filled by
        # lf_tools.create_scales below). A 'stacks_v_HR' stream is assumed to
        # be present; its patch size (192 or 96) determines which levels exist.
        hr_size = batch[ 'stacks_v_HR' ].shape[-2]
        if hr_size == 192:
            batch[ 'stacks_v_s4' ] = np.zeros( [sz, 9, 192, 192, 3], np.float32 )
            batch[ 'stacks_h_s4' ] = np.zeros( [sz, 9, 192, 192, 3], np.float32 )
        if hr_size in ( 96, 192 ):
            batch[ 'stacks_v_s2' ] = np.zeros( [sz, 9, 96, 96, 3], np.float32 )
            batch[ 'stacks_h_s2' ] = np.zeros( [sz, 9, 96, 96, 3], np.float32 )
            for key in ( 'stacks_v', 'stacks_h', 'stacks_bicubic_v', 'stacks_bicubic_h' ):
                batch[ key ] = np.zeros( [sz, 9, 48, 48, 3], np.float32 )
        # copy the selected examples into the batch arrays, stream by stream
        for n, i in enumerate( idx ):
            for stream in self.streams:
                sh = self.stream[ stream ].shape
                nd = len( sh ) - 1
                batch_index = [ n ] + [ slice(0, None) ] * nd
                dataset_index = [ slice(0, None) ] * nd + [ i ]
                # index with tuples: indexing ndarrays with a list of slices
                # is deprecated in recent NumPy versions
                batch[ stream ][ tuple(batch_index) ] = self.stream[ stream ][ tuple(dataset_index) ]
                # disparity streams carry no color and are not converted
                if 'disp' not in stream:
                    if ColorSpace == 'YCBCR':
                        batch = lf_tools.convert2YCBCR( batch, stream, batch_index )
                    elif ColorSpace == 'LAB':
                        batch = lf_tools.convert2LAB( batch, stream, batch_index )
            # per-example augmentation and scale pyramid
            if ColorSpace == 'YCBCR':
                batch = lf_tools.augment_data_YCBCR( batch, batch_index )
            if ColorSpace == 'RGB':
                batch = lf_tools.augment_data( batch, batch_index )
            batch = lf_tools.create_scales( batch, batch_index )
        batch[ 'epoch' ] = subset[ 'epoch' ]
        batch[ 'minibatch_index' ] = subset[ 'minibatch_index' ]
        return batch
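
# A minimal usage sketch, assuming a file 'patches.hdf5' in the layout
# described at the top of this module, with streams that include
# 'stacks_v_HR' (the file name, subset fractions and batch size are
# illustrative assumptions):
if __name__ == '__main__':
    ds = dataset( 'patches.hdf5', subsets={ 'validate': 0.1, 'train': 0.9 } )
    batch = ds.next_batch( 'RGB', batch_size=16, subset_name='train' )
    if batch is None:
        print( 'subset is smaller than the requested batch size' )
    else:
        for name in ds.streams:
            print( name, batch[ name ].shape )
        print( 'epoch %i, minibatch %i' % ( batch[ 'epoch' ], batch[ 'minibatch_index' ] ) )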