decoder-batch.cpp
// decoder-batch.cpp - Batched GPU (CUDA) Decoder Implementation

// local includes
#include "decoder.hpp"
#include "model.hpp"
#include "config.hpp"
#include "types.hpp"

// stl includes
#include <iostream>
#include <memory>
#include <string>

namespace kaldiserve {

#if HAVE_CUDA == 1

BatchDecoder::BatchDecoder(ChainModel *const model) : model_(model) {
    if (model_->wb_info != nullptr) options.enable_word_level = true;
    if (model_->rnnlm_info != nullptr) options.enable_rnnlm = true;

    // kaldi::CuDevice::RegisterDeviceOptions(&po); // only needed if using fp16 (can't access device_options_ directly)
    // kaldi::g_allocator_options // only needed to customize CUDA memory usage
    batched_decoder_config_.cuda_online_pipeline_opts.use_gpu_feature_extraction = false;
    batched_decoder_config_.cuda_online_pipeline_opts.determinize_lattice = false;

    // decoder options
    batched_decoder_config_.cuda_online_pipeline_opts.decoder_opts.default_beam = model_->model_spec.beam;
    batched_decoder_config_.cuda_online_pipeline_opts.decoder_opts.lattice_beam = model_->model_spec.lattice_beam;
    batched_decoder_config_.cuda_online_pipeline_opts.decoder_opts.max_active = model_->model_spec.max_active;

    // feature pipeline options: MFCC and i-vector extractor configs are
    // read from the model's `conf` directory
    batched_decoder_config_.cuda_online_pipeline_opts.feature_opts.feature_type = "mfcc";
    std::string model_dir = model_->model_spec.path;
    std::string conf_dir = join_path(model_dir, "conf");
    std::string mfcc_conf_filepath = join_path(conf_dir, "mfcc.conf");
    std::string ivector_conf_filepath = join_path(conf_dir, "ivector_extractor.conf");
    batched_decoder_config_.cuda_online_pipeline_opts.feature_opts.mfcc_config = mfcc_conf_filepath;
    batched_decoder_config_.cuda_online_pipeline_opts.feature_opts.ivector_extraction_config = ivector_conf_filepath;
    batched_decoder_config_.cuda_online_pipeline_opts.feature_opts.silence_weighting_config.silence_weight = model_->model_spec.silence_weight;

    // compute options
    batched_decoder_config_.cuda_online_pipeline_opts.compute_opts.acoustic_scale = model_->model_spec.acoustic_scale;
    batched_decoder_config_.cuda_online_pipeline_opts.compute_opts.frame_subsampling_factor = model_->model_spec.frame_subsampling_factor;

    cuda_pipeline_ = nullptr;
}

BatchDecoder::~BatchDecoder() {
    free_decoder();
}

void BatchDecoder::start_decoding() {
    // initialize the GPU device and allocator, then construct the
    // batched multi-threaded CUDA decoding pipeline
    kaldi::g_cuda_allocator.SetOptions(kaldi::g_allocator_options);
    kaldi::CuDevice::Instantiate().SelectGpuId("yes");
    kaldi::CuDevice::Instantiate().AllowMultithreading();

    cuda_pipeline_ = new kaldi::cuda_decoder::BatchedThreadedNnet3CudaPipeline2(
        batched_decoder_config_, *model_->decode_fst, model_->am_nnet, model_->trans_model);
}

void BatchDecoder::free_decoder() {
    num_tasks_submitted_ = 0;
    if (cuda_pipeline_) {
        delete cuda_pipeline_;
        cuda_pipeline_ = nullptr;
    }
}

void BatchDecoder::decode_with_callback(std::istream &wav_stream,
                                        const int &n_best,
                                        const bool &word_level,
                                        const std::string &key,
                                        std::function<void(const utterance_results_t &results)> &user_callback) {
    auto wave_data = std::make_shared<kaldi::WaveData>();
    wave_data->Read(wav_stream);

    // the pipeline invokes this callback asynchronously, possibly after this
    // method has returned, so the parameters are captured by value (capturing
    // them by reference would leave dangling references to stack arguments)
    cuda_pipeline_->DecodeWithCallback(wave_data, [n_best, word_level, user_callback, this](kaldi::CompactLattice &clat) {
        utterance_results_t results;
        find_alternatives(clat, n_best, results, word_level, this->model_, this->options);
        user_callback(results);
    });
    num_tasks_submitted_++;
}

void BatchDecoder::wait_for_tasks() {
    std::cout << "#tasks submitted: " << num_tasks_submitted_ << std::endl;
    // block until every submitted utterance has been decoded and its
    // callback has run, then synchronize the device
    cuda_pipeline_->WaitForAllTasks();
    cudaDeviceSynchronize();
}

#endif  // HAVE_CUDA

} // namespace kaldiserve
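
// --------------------------------------------------------------------------
// Usage sketch (illustrative, not part of the original file): one plausible
// way to drive BatchDecoder end to end. It assumes a `ChainModel chain_model`
// has already been constructed elsewhere from a model spec, and "utt1.wav"
// is a placeholder path; both are assumptions for illustration, not APIs
// confirmed by this file.
//
// #include <fstream>
// #include <functional>
// #include "decoder.hpp"
// #include "model.hpp"
//
// int main() {
//     // chain_model construction is model/config specific and omitted here
//     kaldiserve::BatchDecoder decoder(&chain_model);
//     decoder.start_decoding();  // selects the GPU and builds the pipeline
//
//     std::ifstream wav_stream("utt1.wav", std::ios::binary);
//     std::function<void(const kaldiserve::utterance_results_t &)> on_results =
//         [](const kaldiserve::utterance_results_t &results) {
//             // consume the n-best alternatives for this utterance
//         };
//     decoder.decode_with_callback(wav_stream, /*n_best=*/10,
//                                  /*word_level=*/true, "utt1", on_results);
//
//     decoder.wait_for_tasks();  // blocks until all callbacks have run
//     decoder.free_decoder();
//     return 0;
// }
//
// Note that decode_with_callback only enqueues work; results arrive on the
// pipeline's worker threads, so wait_for_tasks() is what guarantees all
// callbacks have completed before teardown.
// --------------------------------------------------------------------------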