-
Notifications
You must be signed in to change notification settings - Fork 93
/
Copy pathextract_weights.py
82 lines (67 loc) · 3.48 KB
/
extract_weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
import tensorflow as tf
import numpy as np
import os
# --- Input: the TensorFlow checkpoint to export ------------------------------
# Directory that holds the checkpoint files.
PATH_TO_CKPT = 'tensorflow_checkpoint'
# Checkpoint prefix inside PATH_TO_CKPT (used without the '.meta' suffix).
MODEL_VERSION = 'model_epoch_0047_step_20591'
PATH_TO_MODEL = os.path.join(PATH_TO_CKPT, MODEL_VERSION)

# --- Output: one file per exported tensor ------------------------------------
PATH_TO_WEIGHTS = 'numpy_weights'


def _in_weights_dir(filename):
    """Return *filename* placed inside the PATH_TO_WEIGHTS output directory."""
    return os.path.join(PATH_TO_WEIGHTS, filename)


# NOTE(review): despite the '.npz' suffix, these files are written with
# np.save, i.e. in the single-array .npy format; np.load detects the format
# from the file header, so loading them still works.
PATH_TO_CONV1, PATH_TO_CONV1_BIAS = (
    _in_weights_dir('conv1.weights.npz'),
    _in_weights_dir('conv1.bias.npz'),
)
PATH_TO_PRIMARY_CAPS, PATH_TO_PRIMARY_CAPS_BIAS = (
    _in_weights_dir('primary_caps.weights.npz'),
    _in_weights_dir('primary_caps.bias.npz'),
)
PATH_TO_DIGIT_CAPS = _in_weights_dir('digit_caps.weights.npz')
(
    PATH_TO_FULLY_CONNECTED1,
    PATH_TO_FULLY_CONNECTED2,
    PATH_TO_FULLY_CONNECTED3,
) = (
    _in_weights_dir('fully_connected1.weights.npz'),
    _in_weights_dir('fully_connected2.weights.npz'),
    _in_weights_dir('fully_connected3.weights.npz'),
)
(
    PATH_TO_FULLY_CONNECTED1_BIAS,
    PATH_TO_FULLY_CONNECTED2_BIAS,
    PATH_TO_FULLY_CONNECTED3_BIAS,
) = (
    _in_weights_dir('fully_connected1.bias.npz'),
    _in_weights_dir('fully_connected2.bias.npz'),
    _in_weights_dir('fully_connected3.bias.npz'),
)
# List the variables stored in the checkpoint so the tensor names exported
# below can be checked against it.
print_tensors_in_checkpoint_file(file_name=PATH_TO_MODEL, tensor_name='', all_tensors=False)


def _save_array(array, path):
    """Serialize *array* to *path* with ``np.save``.

    ``np.save`` on an open file object writes the single-array ``.npy``
    format regardless of the file's suffix; ``np.load`` identifies the format
    from the file header, so the ``.npz``-suffixed outputs still load.
    """
    with open(path, 'wb') as outfile:
        np.save(outfile, array)


# Graph tensors to export, in export order, each paired with its output file.
# Each comment gives the variable's dtype and shape as noted by the author.
_TENSORS_TO_EXPORT = (
    # Conv1_layer/Conv/weights (DT_FLOAT) [9,9,1,256]
    ('Conv1_layer/Conv/weights:0', PATH_TO_CONV1),
    # Conv1_layer/Conv/biases (DT_FLOAT) [256]
    ('Conv1_layer/Conv/biases:0', PATH_TO_CONV1_BIAS),
    # PrimaryCaps_layer/Conv/weights (DT_FLOAT) [9,9,256,256]
    ('PrimaryCaps_layer/Conv/weights:0', PATH_TO_PRIMARY_CAPS),
    # PrimaryCaps_layer/Conv/biases (DT_FLOAT) [256]
    ('PrimaryCaps_layer/Conv/biases:0', PATH_TO_PRIMARY_CAPS_BIAS),
    # DigitCaps_layer/routing/Weight (DT_FLOAT) [1,1152,10,8,16]
    ('DigitCaps_layer/routing/Weight:0', PATH_TO_DIGIT_CAPS),
    # Decoder/fully_connected/weights (DT_FLOAT) [16,512]
    ('Decoder/fully_connected/weights:0', PATH_TO_FULLY_CONNECTED1),
    # Decoder/fully_connected_1/weights (DT_FLOAT) [512,1024]
    ('Decoder/fully_connected_1/weights:0', PATH_TO_FULLY_CONNECTED2),
    # Decoder/fully_connected_2/weights (DT_FLOAT) [1024,784]
    ('Decoder/fully_connected_2/weights:0', PATH_TO_FULLY_CONNECTED3),
    # Decoder/fully_connected/biases (DT_FLOAT) [512]
    ('Decoder/fully_connected/biases:0', PATH_TO_FULLY_CONNECTED1_BIAS),
    # Decoder/fully_connected_1/biases (DT_FLOAT) [1024]
    ('Decoder/fully_connected_1/biases:0', PATH_TO_FULLY_CONNECTED2_BIAS),
    # Decoder/fully_connected_2/biases (DT_FLOAT) [784]
    ('Decoder/fully_connected_2/biases:0', PATH_TO_FULLY_CONNECTED3_BIAS),
)

# open() does not create parent directories, so make sure the output
# directory exists before the first write.
if not os.path.isdir(PATH_TO_WEIGHTS):
    os.makedirs(PATH_TO_WEIGHTS)

# The context manager closes the session once every tensor has been written.
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph(PATH_TO_MODEL + '.meta')
    # Restore exactly MODEL_VERSION, the checkpoint whose graph was imported
    # above. tf.train.latest_checkpoint(PATH_TO_CKPT) would instead load
    # whichever checkpoint the directory's state file lists as most recent,
    # so the exported values could silently come from a different step.
    new_saver.restore(sess, PATH_TO_MODEL)
    for tensor_name, out_path in _TENSORS_TO_EXPORT:
        _save_array(sess.run(tensor_name), out_path)