-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloader.py
152 lines (132 loc) · 4.66 KB
/
loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import tensorflow as tf
import os
from .. import dark
import numpy as np
from os.path import basename
class loader(object):
    """
    Common base for the .weights and .ckpt readers.

    A subclass's ``load(*args)`` (called from ``__init__``) fills two
    parallel lists: ``src_key`` holds lookup keys and ``vals`` holds the
    value stored for the key at the same position. Lookups consume what they
    find: a matched key/value pair is removed from both lists.
    """
    VAR_LAYER = ['convolutional', 'connected', 'local',
                 'select', 'conv-select',
                 'extract', 'conv-extract']

    def __init__(self, *args):
        self.src_key = []
        self.vals = []
        self.load(*args)

    def __call__(self, key):
        """Pop and return the value whose key best matches ``key``.

        Tries an exact match first, then progressively shorter tails of
        ``key`` (``key[1:]``, ``key[2:]``, ...). Returns None if no tail
        matches.
        """
        for start in range(len(key)):
            hit = self.find(key, start)
            if hit is not None:
                return hit
        return None

    def find(self, key, idx):
        """Pop the first of the leading four stored keys whose tail from
        ``idx`` equals ``key[idx:]``; return None when none of them does."""
        suffix = key[idx:]
        # Only the oldest four entries are considered a candidate window.
        for pos, candidate in enumerate(self.src_key[:4]):
            if candidate[idx:] == suffix:
                return self.yields(pos)
        return None

    def yields(self, idx):
        """Remove entry ``idx`` from both lists and return its value."""
        del self.src_key[idx]
        return self.vals.pop(idx)
class weights_loader(loader):
    """
    Loader for Darknet ``.weights`` binaries.

    For every layer in ``src_layers`` whose type is in ``VAR_LAYER`` it
    stores the key ``[layer]``. The matching value is a darkop built from
    the layer's signature and filled with parameters read from the file, or
    None once the file has been exhausted (or when no path was given).
    """

    # Order in which each layer type's parameters were flattened into the
    # .weights file.
    _W_ORDER = {
        'convolutional': [
            'biases', 'gamma', 'moving_mean', 'moving_variance', 'kernel'
        ],
        'connected': ['biases', 'weights'],
        'local': ['biases', 'kernels'],
    }

    def load(self, path, src_layers):
        """Walk ``path`` and attach its parameters to ``src_layers``.

        Args:
            path: the .weights file, or None to create placeholder (None)
                values for every variable layer.
            src_layers: parsed layers; each exposes ``type`` and
                ``signature``.
        """
        self.src_layers = src_layers
        walker = weights_walker(path)

        for layer in src_layers:
            if layer.type not in self.VAR_LAYER:
                continue
            self.src_key.append([layer])

            op = None if walker.eof else dark.darknet.create_darkop(
                *layer.signature)
            self.vals.append(op)
            if op is None:
                continue

            # Parameters are read back in exactly the order they were
            # written; those absent from this op's wshape are skipped.
            for param in self._W_ORDER[op.type]:
                if param in op.wshape:
                    op.w[param] = walker.walk(op.wsize[param])
            op.finalize(walker.transpose)

        if walker.path is not None:
            # Every byte of the file must have been consumed by the layers.
            assert walker.offset == walker.size, (
                'expect {} bytes, found {}'.format(
                    walker.offset, walker.size))
            print('Successfully identified {} bytes'.format(
                walker.offset))
class checkpoint_loader(loader):
    """
    Loader for TensorFlow checkpoints.

    For each global variable in the restored graph it stores the key
    ``[name, shape]``, where ``name`` is the variable name with the ``:N``
    output suffix removed and ``shape`` is a Python list. The value is the
    variable's content as restored from the checkpoint.
    """

    def load(self, ckpt, ignore):
        """Restore ``ckpt`` and collect a (key, value) pair per variable.

        Args:
            ckpt: checkpoint prefix; the graph is imported from
                ``ckpt + '.meta'``.
            ignore: unused. It is accepted so this loader shares the
                ``(path, cfg)`` call made by ``create_loader``.
        """
        meta = ckpt + '.meta'
        with tf.Graph().as_default() as graph:
            # Entering the Session itself (not just ``as_default()``) closes
            # it, and releases its resources, when the block exits.
            # ``Session.as_default()`` alone leaves the session open.
            with tf.Session(graph=graph) as sess:
                saver = tf.train.import_meta_graph(meta)
                saver.restore(sess, ckpt)
                variables = tf.global_variables()
                # A single run call fetches every value at once instead of
                # one session round-trip per variable.
                values = sess.run(variables)
            for var, value in zip(variables, values):
                name = var.name.split(':')[0]
                self.src_key.append([name, var.get_shape().as_list()])
                self.vals.append(value)
def create_loader(path, cfg=None):
    """Build the loader that understands ``path``.

    Args:
        path: None (no weights file), a path ending in ``.weights``, or
            anything else, which is treated as a TensorFlow checkpoint
            prefix.
        cfg: passed as the loader's second argument. It is the layer list
            for ``weights_loader``; ``checkpoint_loader`` ignores it.

    Returns:
        A ``weights_loader`` or ``checkpoint_loader`` instance.
    """
    # Decide on the file's own extension (as model_name does) rather than on
    # a '.weights' substring anywhere in the path. Otherwise a checkpoint
    # kept under a directory such as "yolo.weights_old/" would be misread as
    # a Darknet binary.
    if path is None or path.endswith('.weights'):
        load_type = weights_loader
    else:
        load_type = checkpoint_loader
    return load_type(path, cfg)
class weights_walker(object):
    """
    Incremental reader of a Darknet ``.weights`` file.

    The file begins with a header: major, minor and revision as int32,
    followed by a "seen" counter. After the header come float32 parameters,
    which ``walk`` hands out in consecutive chunks.
    """

    def __init__(self, path):
        self.eof = False  # True once every parameter byte has been read
        self.path = path  # file being read; None means "no weights file"
        if path is None:
            self.eof = True
            return
        self.size = os.path.getsize(path)  # total file size in bytes
        header = np.memmap(path, mode='r', offset=0, dtype='i4', shape=(4,))
        # Python ints avoid int32 overflow in the version arithmetic below.
        major, minor = int(header[0]), int(header[1])
        self.transpose = major > 1000 or minor > 1000
        # Mirror Darknet's own reader: for format version >= 0.2 (and
        # non-transposed files) "seen" is a 64-bit size_t, giving a 20-byte
        # header. Otherwise it is a 32-bit int, giving a 16-byte header.
        # A fixed offset misaligns every parameter for one of the two
        # layouts.
        wide_seen = (major * 10 + minor) >= 2 and major < 1000 and minor < 1000
        self.offset = 20 if wide_seen else 16

    def walk(self, size):
        """Return the next ``size`` float32 values as a read-only memmap.

        Returns None once the end of the file has been reached, or when the
        walker was created with ``path=None``. Raises AssertionError if
        fewer than ``size`` values remain.
        """
        if self.eof:
            return None
        end_point = self.offset + 4 * size
        assert end_point <= self.size, \
            'Over-read {}'.format(self.path)
        # Explicit 1-D shape + scalar dtype. This does not depend on how
        # NumPy parses a '(n)float32,' subarray dtype string.
        float32_1D_array = np.memmap(
            self.path, mode='r', offset=self.offset,
            dtype='float32', shape=(size,))
        self.offset = end_point
        if end_point == self.size:
            self.eof = True
        return float32_1D_array
def model_name(file_path):
    """Recover the model name from a weights or checkpoint path.

    * ``<dir>/<name>.weights`` -> ``<name>``
    * ``<dir>/<name>-<step>`` or ``<dir>/<name>-<step>.meta`` -> ``<name>``.
      ``<step>`` must parse with ``int()``, otherwise ValueError is raised.
    * any other extension -> None
    """
    stem, dot, extension = basename(file_path).rpartition('.')
    if not dot:
        # No '.' at all: rpartition put the whole name in the tail slot.
        stem, extension = extension, ''
    if extension in ('', 'meta'):
        prefix, _, step = stem.rpartition('-')
        int(step)  # rejects checkpoint names without a numeric step suffix
        return prefix
    if extension == 'weights':
        return stem
    return None