-
Notifications
You must be signed in to change notification settings - Fork 278
/
Copy path evaluate.py
executable file
·192 lines (155 loc) · 5.87 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import json
import sys
from multiprocessing import cpu_count
import progressbar
import tensorflow.compat.v1 as tfv1
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from six.moves import zip
import tensorflow as tf
from .util.augmentations import NormalizeSampleRate
from .util.checkpoints import load_graph_for_evaluation
from .util.config import (
Config,
create_progressbar,
initialize_globals_from_cli,
log_error,
log_progress,
)
from .util.evaluate_tools import calculate_and_print_report, save_samples_json
from .util.feeding import create_dataset
from .util.helpers import check_ctcdecoder_version
check_ctcdecoder_version()
def sparse_tensor_value_to_texts(value, alphabet):
r"""
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
representing its values, converting tokens to strings using ``alphabet``.
"""
return sparse_tuple_to_texts(
(value.indices, value.values, value.dense_shape), alphabet
)
def sparse_tuple_to_texts(sp_tuple, alphabet):
indices = sp_tuple[0]
values = sp_tuple[1]
results = [[] for _ in range(sp_tuple[2][0])]
for i, index in enumerate(indices):
results[index[0]].append(values[i])
# List of strings
return [alphabet.Decode(res) for res in results]
def evaluate(test_csvs, create_model):
    """Run the model over one or more test CSVs and report the results.

    Builds one dataset per CSV, constructs the inference graph, restores the
    model checkpoint, decodes every batch with a CTC beam-search decoder
    (optionally scored by an external language-model scorer), and prints a
    report per dataset via ``calculate_and_print_report``.

    :param test_csvs: list of test-set CSV file paths; each gets its own
        report.
    :param create_model: factory that builds the acoustic model graph; called
        with ``batch_x``, ``seq_length`` and ``dropout`` keyword arguments.
    :return: list of per-sample result records accumulated over all CSVs.
    """
    # Optional external scorer (language model) for the beam-search decoder.
    if Config.scorer_path:
        scorer = Scorer(
            Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
        )
    else:
        scorer = None
    # One dataset per CSV so each input file is evaluated (and reported)
    # separately.
    test_sets = [
        create_dataset(
            [csv],
            batch_size=Config.test_batch_size,
            train_phase=False,
            augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
            reverse=Config.reverse_test,
            limit=Config.limit_test,
        )
        for csv in test_csvs
    ]
    # A single reinitializable iterator shared by all test sets; the active
    # dataset is selected by running the matching init op in run_test below.
    iterator = tfv1.data.Iterator.from_structure(
        tfv1.data.get_output_types(test_sets[0]),
        tfv1.data.get_output_shapes(test_sets[0]),
        output_classes=tfv1.data.get_output_classes(test_sets[0]),
    )
    test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
    # One rate per layer
    # NOTE(review): assumes the model takes six dropout slots — confirm
    # against the create_model implementation in train.py.
    no_dropout = [None] * 6
    logits, _ = create_model(
        batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout
    )
    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
    loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_x_len)
    # Create the global step variable — presumably needed so checkpoint
    # restoration in load_graph_for_evaluation can find it; verify there.
    tfv1.train.get_or_create_global_step()
    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1
    with tfv1.Session(config=Config.session_config) as session:
        load_graph_for_evaluation(session)

        def run_test(init_op, dataset):
            # Evaluate one dataset: decode every batch, then compute, print
            # and return its per-sample report.
            wav_filenames = []
            losses = []
            predictions = []
            ground_truths = []
            bar = create_progressbar(
                prefix="Test epoch | ",
                widgets=["Steps: ", progressbar.Counter(), " | ", progressbar.Timer()],
            ).start()
            log_progress("Test epoch...")
            step_count = 0
            # Initialize iterator to the appropriate dataset
            session.run(init_op)
            # First pass, compute losses and transposed logits for decoding
            while True:
                try:
                    (
                        batch_wav_filenames,
                        batch_logits,
                        batch_loss,
                        batch_lengths,
                        batch_transcripts,
                    ) = session.run(
                        [batch_wav_filename, transposed, loss, batch_x_len, batch_y]
                    )
                except tf.errors.OutOfRangeError:
                    # Dataset exhausted — leave the fetch loop.
                    break
                # CPU beam-search decode of the batch-major softmax outputs.
                decoded = ctc_beam_search_decoder_batch(
                    batch_logits,
                    batch_lengths,
                    Config.alphabet,
                    Config.beam_width,
                    num_processes=num_processes,
                    scorer=scorer,
                    cutoff_prob=Config.cutoff_prob,
                    cutoff_top_n=Config.cutoff_top_n,
                )
                # d[0] is the top-ranked beam; d[0][1] its transcript text.
                predictions.extend(d[0][1] for d in decoded)
                ground_truths.extend(
                    sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet)
                )
                wav_filenames.extend(
                    wav_filename.decode("UTF-8") for wav_filename in batch_wav_filenames
                )
                losses.extend(batch_loss)
                step_count += 1
                bar.update(step_count)
            bar.finish()
            # Print test summary
            test_samples = calculate_and_print_report(
                wav_filenames, ground_truths, predictions, losses, dataset
            )
            return test_samples

        samples = []
        for csv, init_op in zip(test_csvs, test_init_ops):
            print("Testing model on {}".format(csv))
            samples.extend(run_test(init_op, dataset=csv))
        return samples
def main():
    """CLI entry point: load configuration and evaluate on the test files."""
    initialize_globals_from_cli()

    if Config.test_files:
        # Imported lazily to break the circular dependency on the training
        # module.
        from .train import (  # pylint: disable=cyclic-import,import-outside-toplevel
            create_model,
        )

        samples = evaluate(Config.test_files, create_model)
        if Config.test_output_file:
            save_samples_json(samples, Config.test_output_file)
    else:
        log_error(
            "You need to specify what files to use for evaluation via "
            "the --test_files flag."
        )
        sys.exit(1)


if __name__ == "__main__":
    main()