-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtesting.py
127 lines (107 loc) · 4.32 KB
/
testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def train_val_changes(train_op_fn, n_steps=5, config=None):
  """Determine which trainable variables change during training, and which don't.

  Builds the train op inside a fresh, private graph, snapshots every trainable
  variable, runs the op `n_steps` times, and compares values before and after.

  Args:
    train_op_fn: function mapping no inputs to a train operation.
    n_steps: number of times to run the resulting `train_op`.
    config: optional session configuration (`tf.ConfigProto`) forwarded to
      `tf.Session` — TODO confirm callers only pass ConfigProto here.

  Returns:
    unchanged_names: list of names of tensors which did not change value.
    changed_names: list of names of tensors which did change value.

  If `unchanged_names` is not empty, there are likely unused variables which
  could possibly be removed, or your gradients are not flowing correctly.
  """
  # NOTE: the previous version called tf.reset_default_graph() first. That was
  # both unnecessary (all construction happens inside the explicit `graph`
  # below) and an unwanted destructive side effect on the caller's default
  # graph, so it has been removed.
  graph = tf.Graph()
  with graph.as_default():
    train_op = train_op_fn()
    trainable_vars = tf.trainable_variables()
    with tf.Session(graph=graph, config=config) as sess:
      sess.run(tf.global_variables_initializer())
      vals = sess.run(trainable_vars)
      for _ in range(n_steps):
        sess.run(train_op)
      new_vals = sess.run(trainable_vars)
  names = [v.name for v in trainable_vars]
  unchanged_names = []
  changed_names = []
  for name, val, new_val in zip(names, vals, new_vals):
    # np.array_equal is shape-safe, unlike np.all(val == new_val), which can
    # silently broadcast if a variable's shape were ever to differ.
    if np.array_equal(val, new_val):
      unchanged_names.append(name)
    else:
      changed_names.append(name)
  return unchanged_names, changed_names
def report_train_val_changes(train_op_fn, steps=5, config=None):
  """Wrapper around `train_val_changes` that prints a human-readable report.

  Args:
    train_op_fn: function mapping no inputs to a train operation.
    steps: number of times to run the resulting train op; forwarded as
      `n_steps` to `train_val_changes`.
    config: optional session configuration forwarded to `train_val_changes`.
  """
  # Pass by keyword so a future signature change in train_val_changes cannot
  # silently misroute this argument.
  unchanged_names, changed_names = train_val_changes(
      train_op_fn, n_steps=steps, config=config)
  n_unchanged = len(unchanged_names)
  n_changed = len(changed_names)
  n_total = n_unchanged + n_changed
  # Reuse the count computed above instead of re-evaluating len(...).
  if n_unchanged == 0:
    print('All trainable variables changed :)')
  else:
    print('%d / %d training variables unchanged'
          % (n_unchanged, n_total))
    print('Changed vars:')
    for name in changed_names:
      print(name)
    print('Unchanged vars:')
    for name in unchanged_names:
      print(name)
def do_update_ops_run(train_op_fn, config=None):
  """Determine whether all update ops are running by default.

  This is helpful to determine whether other automatically created update_ops
  will automatically be run, e.g. moving average updates in batch
  normalization.

  Implemented by creating an update op, calling `train_op_fn`, then running
  the resulting `train_op` and checking if the initially created update op
  is run.

  Args:
    train_op_fn: function mapping no inputs to a train op.
    config: optional session configuration (`tf.ConfigProto`) forwarded to
      `tf.Session`.

  Returns:
    None if no update ops created by `train_op_fn`
    True if the initially created update_op is run
    False if the initially created update_op is not run.

  If the initially created update_op is not run (indicated by this function
  returning `False`), consider wrapping your `train_op` with
  `tf.control_dependencies`.

  ```
  def new_train_op_fn():
    train_op = old_train_op_fn()
    ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    ops.append(train_op)
    with tf.control_dependencies(ops):
      fixed_train_op = tf.no_op()
    return fixed_train_op
  ```
  """
  graph = tf.Graph()
  with graph.as_default():
    # Sentinel op: increments `step` once each time UPDATE_OPS actually run.
    step = tf.Variable(
        initial_value=0, dtype=tf.int32, name='test_step', trainable=False)
    update_step = tf.assign_add(step, 1)
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_step)
    train_op = train_op_fn()
    ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Check the length BEFORE indexing: a train_op_fn that clears the
    # UPDATE_OPS collection would otherwise raise IndexError on ops[0].
    # Identity (`is`) is the correct comparison for graph ops.
    if len(ops) == 1 and ops[0] is update_step:
      # Only our sentinel is present -> train_op_fn created no update ops.
      return None
  with tf.Session(graph=graph, config=config) as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)
    s = sess.run(step)
  # step advanced exactly once iff the sentinel update op was run.
  return s == 1
def report_update_ops_run(train_op_fn, config=None):
  """Run `do_update_ops_run` and print its outcome in human-readable form."""
  outcome = do_update_ops_run(train_op_fn, config=config)
  # Map the tri-state result (None / True / False) onto its message.
  if outcome is None:
    message = 'No UPDATE_OPS created by `train_op_fn`'
  elif outcome:
    message = 'UPDATE_OPS run successfully :)'
  else:
    message = 'UPDATE_OPS not automatically run :('
  print(message)