-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathexperiment.py
80 lines (57 loc) · 2.63 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# coding: utf-8
# Use this file to try out the CNN Text Classifying Network on sentences outside the dataset.
import tensorflow as tf
import numpy as np
import os
import time
import datetime
import manage_data
from text_network import TextNetwork
from tensorflow.contrib import learn
# Command-line flags for evaluating a single sentence with a trained model.
tf.flags.DEFINE_string("sent", "", "The sentence to evaluate (default: '')")

# Eval Parameters
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")
tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run")
tf.flags.DEFINE_boolean("eval_train", False, "Evaluate on all training data")

# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")

FLAGS = tf.flags.FLAGS
# Older (gflags-based) TensorFlow releases need this explicit, private
# _parse_flags() call. Newer absl-based tf.flags no longer provide it, and
# calling it unconditionally crashes with AttributeError at startup.
# NOTE(review): on the absl-based wrapper, flags are believed to be parsed
# implicitly from sys.argv on first attribute access (which this hasattr probe
# triggers) -- confirm against the installed TF version.
if hasattr(FLAGS, "_parse_flags"):
    FLAGS._parse_flags()
# Evaluate one sentence (from --sent, or a built-in fallback) with the most
# recent checkpoint found in --checkpoint_dir, and print its predicted class.
test_sentence = FLAGS.sent or "This is a test sentence!"  # Our x data

# VocabularyProcessor.transform() iterates its argument as a sequence of
# documents. A bare string would be iterated character by character, so each
# character would be classified as its own "document". Wrap the sentence in a
# one-element list so it is treated as a single document.
x_raw = [test_sentence]

# Resolve the checkpoint first, so that a wrong or empty --checkpoint_dir fails
# with a clear message. Otherwise it shows up later as an obscure
# "None.meta" file-not-found error.
checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
if checkpoint_file is None:
    raise ValueError(
        "No checkpoint found in checkpoint_dir={!r}; pass --checkpoint_dir "
        "pointing at a training run's checkpoint directory.".format(FLAGS.checkpoint_dir))

# Map data into vocabulary. The vocab file is expected one level above
# checkpoint_dir.
vocab_path = os.path.join(FLAGS.checkpoint_dir, "..", "vocab")
vocab_processor = learn.preprocessing.VocabularyProcessor.restore(vocab_path)
x_test = np.array(list(vocab_processor.transform(x_raw)))

print("\nEvaluating your sentence...\n")

graph = tf.Graph()
with graph.as_default():
    session_conf = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement)
    # Used as a context manager, the session becomes the default session and
    # is closed (releasing its resources) when the block exits.
    with tf.Session(config=session_conf) as sess:
        # Load the saved meta graph and restore variables
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
        saver.restore(sess, checkpoint_file)

        # Get the placeholders from the graph by name
        input_x = graph.get_operation_by_name("input_x").outputs[0]
        dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]

        # Tensors we want to evaluate
        predictions = graph.get_operation_by_name("output/predictions").outputs[0]

        # x_test holds exactly one row (our sentence), so a single run is
        # enough. A keep probability of 1.0 disables dropout for inference.
        prediction = sess.run(predictions, {input_x: x_test, dropout_keep_prob: 1.0})
        print("Prediction was: ", prediction[0])
def run_network_on_sentence(sent):
    """Stub for running the network on one sentence; not implemented yet.

    Args:
        sent: the sentence to run (currently ignored).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def classification_report(sent):
    """Stub for producing a classification report for a sentence; not implemented yet.

    Args:
        sent: the sentence to report on (currently ignored).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()