-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
executable file
·154 lines (127 loc) · 4.49 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import tensorflow as tf
from network import Network
from tensorflow.examples.tutorials.mnist import input_data
# Command-line flags, registered on tf.app.flags in the order listed.
# Each entry is (definer, flag name, default value, help text).
for _define, _flag_name, _default, _help in (
        (tf.app.flags.DEFINE_float, "learning_rate", 0.01, "Learning rate."),
        (tf.app.flags.DEFINE_integer, "batch_size", 1,
         "Batch size to use during training."),
        (tf.app.flags.DEFINE_integer, "size", 500,
         "Size of each model layer."),
        (tf.app.flags.DEFINE_integer, "n_layers", 10,
         "Number of layers in the model."),
        (tf.app.flags.DEFINE_string, "train_dir", "/tmp",
         "Training directory."),
        (tf.app.flags.DEFINE_integer, "steps_per_checkpoint", 50,
         "How many training steps to do per checkpoint."),
        (tf.app.flags.DEFINE_boolean, "self_test", False,
         "Run a self-test if this is set to True."),
        (tf.app.flags.DEFINE_boolean, "train", False,
         "Run a train if this is set to True."),
        (tf.app.flags.DEFINE_integer, "n_epochs", 1,
         "Number of training iterations."),
        (tf.app.flags.DEFINE_string, "log_dir", "/tmp",
         "Tensorboard log directory."),
        (tf.app.flags.DEFINE_string, "data_dir", "/tmp",
         "training data directory."),
):
    _define(_flag_name, _default, _help)
# Drop the loop temporaries so they do not linger as module globals.
del _define, _flag_name, _default, _help

# Parsed flag values, read by the rest of the script.
FLAGS = tf.app.flags.FLAGS
def create_model(sess, n_input, n_output):
    """Build the Network graph, then restore or initialize its variables.

    Variables are restored from the latest checkpoint in FLAGS.train_dir when
    one exists; otherwise they are freshly initialized in `sess`.

    Args:
      sess: tf.Session used to restore or initialize the variables.
      n_input: number of features per input row.
      n_output: number of classes, i.e. width of each one-hot label row.

    Returns:
      Tuple (input, output, optimizer, cost, merged_summary_op, accuracy,
      saver, net):
        input, output: float placeholders shaped [FLAGS.batch_size, n_input]
          and [FLAGS.batch_size, n_output]. The leading dimension is fixed,
          so every feed must contain exactly FLAGS.batch_size rows.
        optimizer: Adam op minimizing `cost`.
        cost: mean softmax cross-entropy over the batch.
        merged_summary_op: merged "loss" and "accuracy" scalar summaries.
        accuracy: fraction of rows whose predicted class (arg-max of the
          logits) equals the labelled class (arg-max of `output`).
        saver: tf.train.Saver over all variables.
        net: the Network instance.
    """
    inputs = tf.placeholder("float", [FLAGS.batch_size, n_input])
    labels = tf.placeholder("float", [FLAGS.batch_size, n_output])
    net = Network(FLAGS.n_layers, FLAGS.size, FLAGS.size, n_input, n_output,
                  inputs, labels)

    with tf.name_scope('Loss'):
        # softmax_cross_entropy_with_logits needs logits shaped like `labels`
        # ([batch_size, n_output]); the transpose therefore implies net.pred
        # is laid out [n_output, batch_size]. NOTE(review): Network is not
        # visible here -- confirm that orientation.
        logits = tf.transpose(net.pred)
        cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
            logits=logits, labels=labels))

    # NOTE: the scope says 'SGD' although the optimizer is Adam. It is kept
    # because renaming it may change the names of variables the optimizer
    # creates and break restoring checkpoints written by earlier runs.
    with tf.name_scope('SGD'):
        optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.learning_rate).minimize(cost)

    with tf.name_scope('Accuracy'):
        # Use the same batch-major logits as the loss. Taking argmax over
        # axis 1 of the untransposed net.pred would reduce over the batch
        # axis and compare tensors of different lengths.
        correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Scalar summaries for TensorBoard, merged into a single op.
    tf.scalar_summary("loss", cost)
    tf.scalar_summary("accuracy", accuracy)
    merged_summary_op = tf.merge_all_summaries()

    saver = tf.train.Saver(tf.all_variables())
    ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
    if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
        print("Reading model parameters from %s" %
              ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh parameters.")
        sess.run(tf.initialize_all_variables())

    return (inputs, labels, optimizer, cost, merged_summary_op, accuracy,
            saver, net)
def _batched_accuracy(sess, accuracy, inputs, labels, images, label_rows,
                      batch_size):
    """Average `accuracy` over consecutive full batches of a dataset.

    The placeholders built by create_model have a leading dimension fixed to
    FLAGS.batch_size, so a whole dataset cannot be fed in one run. This feeds
    it `batch_size` rows at a time instead. Trailing rows that do not fill a
    complete batch are skipped.

    Args:
      sess: session holding the trained variables.
      accuracy: scalar accuracy tensor returned by create_model.
      inputs, labels: the input and label placeholders from create_model.
      images, label_rows: array-likes supporting len() and slicing, with
        matching row counts.
      batch_size: rows per feed; must equal the placeholders' leading size.

    Returns:
      (mean_accuracy, n_evaluated). mean_accuracy is None when there are
      fewer than `batch_size` rows.
    """
    n_batches = len(images) // batch_size
    if n_batches == 0:
        return None, 0
    total = 0.0
    for b in range(n_batches):
        rows = slice(b * batch_size, (b + 1) * batch_size)
        total += sess.run(accuracy, feed_dict={inputs: images[rows],
                                               labels: label_rows[rows]})
    # Every batch has the same size, so the plain mean of per-batch accuracy
    # equals the accuracy over all evaluated rows.
    return float(total) / n_batches, n_batches * batch_size


def train():
    """Train the model on MNIST, then report accuracy on the test set.

    Reads MNIST from FLAGS.data_dir with one-hot labels and runs
    FLAGS.n_epochs passes of FLAGS.batch_size-sized optimization steps.
    For every step it writes the loss/accuracy summaries to FLAGS.log_dir.
    It prints the loss every FLAGS.steps_per_checkpoint steps and the mean
    loss at the end of each epoch.
    """
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    n_input = 28 * 28  # each MNIST image is a flattened 28x28 grid
    n_output = 10      # one class per digit

    with tf.Session() as sess:
        (inputs, labels, optimizer, cost, merged_summary_op, accuracy,
         _saver, _net) = create_model(sess, n_input, n_output)
        summary_writer = tf.train.SummaryWriter(
            FLAGS.log_dir, graph=tf.get_default_graph())
        # Floor division: a trailing partial batch is never run.
        total_batch = mnist.train.num_examples // FLAGS.batch_size
        log_every = FLAGS.steps_per_checkpoint
        try:
            for epoch in range(FLAGS.n_epochs):
                avg_cost = 0.
                for i in range(total_batch):
                    batch_xs, batch_ys = mnist.train.next_batch(
                        FLAGS.batch_size)
                    # One optimization step; also fetch loss and summaries.
                    _, c, summary = sess.run(
                        [optimizer, cost, merged_summary_op],
                        feed_dict={inputs: batch_xs, labels: batch_ys})
                    summary_writer.add_summary(summary,
                                               epoch * total_batch + i)
                    avg_cost += c / total_batch
                    # Guard against a zero interval (would divide by zero).
                    if log_every > 0 and (i + 1) % log_every == 0:
                        # NOTE: this only logs; no checkpoint is written.
                        print("Epoch: %04d cost= %.9f iter= %d"
                              % (epoch + 1, c, i))
                print("Epoch: %04d avg_cost= %.9f" % (epoch + 1, avg_cost))
        finally:
            # Flush pending events to disk and release the event file.
            summary_writer.close()
        print("Optimization Finished!")

        test_acc, n_eval = _batched_accuracy(
            sess, accuracy, inputs, labels,
            mnist.test.images, mnist.test.labels, FLAGS.batch_size)
        if test_acc is None:
            print("Accuracy: n/a (test set is smaller than batch_size)")
        else:
            print("Accuracy: %s (%d test examples)" % (test_acc, n_eval))

    # Point at the directory the summaries were actually written to.
    print("Run the command line:\n"
          "--> tensorboard --logdir=%s \n"
          "Then open http://0.0.0.0:6006/ into your web browser"
          % FLAGS.log_dir)
def self_test():
    """Placeholder for the --self_test mode; currently performs no checks."""
def main(_):
    """Dispatch on the --self_test / --train flags.

    The single positional argument (what tf.app.run passes in) is unused.
    --self_test takes precedence over --train. When neither flag is set,
    print a usage hint instead of exiting silently.
    """
    if FLAGS.self_test:
        self_test()
    elif FLAGS.train:
        train()
    else:
        print("Nothing to do: pass --train to train the model "
              "or --self_test to run the self-test.")
if __name__ == '__main__':
    # Parse the command-line flags and hand control to main().
    tf.app.run(main=main)