Skip to content

Commit

Permalink
Fix nemo streaming transducer greedy search (#944)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored May 30, 2024
1 parent 3f472a9 commit 082f230
Show file tree
Hide file tree
Showing 18 changed files with 320 additions and 290 deletions.
39 changes: 39 additions & 0 deletions .github/scripts/test-online-transducer.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,45 @@ echo "PATH: $PATH"

which $EXE

# Smoke-test the streaming NeMo transducer (English) model with $EXE.
# NOTE(review): `log` is not defined in this section and `$EXE` is only used
# here (see `which $EXE` above); presumably both are set up earlier in the
# script -- confirm.
log "------------------------------------------------------------"
log "Run NeMo transducer (English)"
log "------------------------------------------------------------"
# Download the model archive, unpack it, then delete the archive itself.
repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
curl -SL -O $repo_url
tar xvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
rm sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
# Directory the commands below read the model files and test waves from
# (presumably the directory created by extracting the archive -- confirm).
repo=sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms

log "Start testing ${repo_url}"

# Test waves used by the per-file loop below.
waves=(
$repo/test_wavs/0.wav
$repo/test_wavs/1.wav
$repo/test_wavs/8k.wav
)

# Case 1: decode each wave in its own invocation of $EXE.
for wave in ${waves[@]}; do
time $EXE \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder.onnx \
--decoder=$repo/decoder.onnx \
--joiner=$repo/joiner.onnx \
--num-threads=2 \
$wave
done

# Case 2: pass all three waves to a single invocation of $EXE.
time $EXE \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder.onnx \
--decoder=$repo/decoder.onnx \
--joiner=$repo/joiner.onnx \
--num-threads=2 \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav

# Remove the model directory once this section has finished.
rm -rf $repo

log "------------------------------------------------------------"
log "Run LSTM transducer (English)"
log "------------------------------------------------------------"
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/aarch64-linux-gnu-shared.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
cd huggingface
git lfs pull
mkdir -p aarch64
cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/aarch64-linux-gnu-static.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
cd huggingface
git lfs pull
mkdir -p aarch64
cp -v ../sherpa-onnx-*-static.tar.bz2 ./aarch64
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/android.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
cd huggingface
git lfs pull
cp -v ../sherpa-onnx-*-android.tar.bz2 ./
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/arm-linux-gnueabihf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
cd huggingface
git lfs pull
mkdir -p arm32
cp -v ../sherpa-onnx-*.tar.bz2 ./arm32
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/build-xcframework.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
cd huggingface
git lfs pull
cp -v ../sherpa-onnx-*.tar.bz2 ./
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/riscv64-linux.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
cd huggingface
git lfs pull
mkdir -p riscv64
cp -v ../sherpa-onnx-*-shared.tar.bz2 ./riscv64
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/windows-x64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
cd huggingface
git lfs pull
mkdir -p win64
cp -v ../sherpa-onnx-*.tar.bz2 ./win64
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/windows-x86.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
cd huggingface
git lfs pull
mkdir -p win32
cp -v ../sherpa-onnx-*.tar.bz2 ./win32
Expand Down
20 changes: 10 additions & 10 deletions sherpa-onnx/csrc/online-recognizer-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,18 @@ namespace sherpa_onnx {

std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
const OnlineRecognizerConfig &config) {

if (!config.model_config.transducer.encoder.empty()) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);

auto decoder_model = ReadFile(config.model_config.transducer.decoder);
auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});

auto sess = std::make_unique<Ort::Session>(
env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});

size_t node_count = sess->GetOutputCount();

if (node_count == 1) {
return std::make_unique<OnlineRecognizerTransducerImpl>(config);
} else {
SHERPA_ONNX_LOGE("Running streaming Nemo transducer model");
return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(config);
}
}
Expand All @@ -50,12 +49,13 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
AAssetManager *mgr, const OnlineRecognizerConfig &config) {
if (!config.model_config.transducer.encoder.empty()) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);

auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder);
auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});

auto sess = std::make_unique<Ort::Session>(
env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});

size_t node_count = sess->GetOutputCount();

if (node_count == 1) {
return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
} else {
Expand Down
11 changes: 4 additions & 7 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,15 @@

namespace sherpa_onnx {

static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms,
int32_t subsampling_factor,
int32_t segment,
int32_t frames_since_start) {
OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms, int32_t subsampling_factor,
int32_t segment, int32_t frames_since_start) {
OnlineRecognizerResult r;
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.tokens.size());

for (auto i : src.tokens) {
if (i == -1) continue;
auto sym = sym_table[i];

r.text.append(sym);
Expand Down
94 changes: 37 additions & 57 deletions sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_

#include <algorithm>
#include <fstream>
#include <ios>
#include <memory>
Expand All @@ -32,23 +33,20 @@
namespace sherpa_onnx {

// defined in ./online-recognizer-transducer-impl.h
// TODO: decide whether this declaration should be static.
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms,
int32_t subsampling_factor,
int32_t segment,
int32_t frames_since_start);
OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms, int32_t subsampling_factor,
int32_t segment, int32_t frames_since_start);

class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
public:
public:
explicit OnlineRecognizerTransducerNeMoImpl(
const OnlineRecognizerConfig &config)
: config_(config),
symbol_table_(config.model_config.tokens),
endpoint_(config_.endpoint_config),
model_(std::make_unique<OnlineTransducerNeMoModel>(
config.model_config)) {
model_(
std::make_unique<OnlineTransducerNeMoModel>(config.model_config)) {
if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
model_.get(), config_.blank_penalty);
Expand All @@ -73,7 +71,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
model_.get(), config_.blank_penalty);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config.decoding_method.c_str());
config.decoding_method.c_str());
exit(-1);
}

Expand All @@ -83,7 +81,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {

std::unique_ptr<OnlineStream> CreateStream() const override {
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
stream->SetStates(model_->GetInitStates());
InitOnlineStream(stream.get());
return stream;
}
Expand All @@ -94,14 +91,12 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
}

OnlineRecognizerResult GetResult(OnlineStream *s) const override {
OnlineTransducerDecoderResult decoder_result = s->GetResult();
decoder_->StripLeadingBlanks(&decoder_result);

// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 8;
return Convert(decoder_result, symbol_table_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
int32_t subsampling_factor = model_->SubsamplingFactor();
return Convert(s->GetResult(), symbol_table_, frame_shift_ms,
subsampling_factor, s->GetCurrentSegment(),
s->GetNumFramesSinceStart());
}

bool IsEndpoint(OnlineStream *s) const override {
Expand All @@ -114,8 +109,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
// frame shift is 10 milliseconds
float frame_shift_in_seconds = 0.01;

// subsampling factor is 8
int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 8;
int32_t trailing_silence_frames =
s->GetResult().num_trailing_blanks * model_->SubsamplingFactor();

return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
frame_shift_in_seconds);
Expand All @@ -126,19 +121,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
// segment is incremented only when the last
// result is not empty
const auto &r = s->GetResult();
if (!r.tokens.empty() && r.tokens.back() != 0) {
if (!r.tokens.empty()) {
s->GetCurrentSegment() += 1;
}
}

// we keep the decoder_out
decoder_->UpdateDecoderOut(&s->GetResult());
Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
s->SetResult({});

s->SetStates(model_->GetEncoderInitStates());

auto r = decoder_->GetEmptyResult();

s->SetResult(r);
s->GetResult().decoder_out = std::move(decoder_out);
s->SetNeMoDecoderStates(model_->GetDecoderInitStates());

// Note: We only update counters. The underlying audio samples
// are not discarded.
Expand All @@ -151,10 +143,9 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {

int32_t feature_dim = ss[0]->FeatureDim();

std::vector<OnlineTransducerDecoderResult> result(n);
std::vector<float> features_vec(n * chunk_size * feature_dim);
std::vector<std::vector<Ort::Value>> encoder_states(n);

for (int32_t i = 0; i != n; ++i) {
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
std::vector<float> features =
Expand All @@ -166,9 +157,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
std::copy(features.begin(), features.end(),
features_vec.data() + i * chunk_size * feature_dim);

result[i] = std::move(ss[i]->GetResult());
encoder_states[i] = std::move(ss[i]->GetStates());

}

auto memory_info =
Expand All @@ -180,42 +169,35 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
features_vec.size(), x_shape.data(),
x_shape.size());

// Batch size is 1
auto states = std::move(encoder_states[0]);
int32_t num_states = states.size(); // num_states = 3
auto states = model_->StackStates(std::move(encoder_states));
int32_t num_states = states.size(); // num_states = 3
auto t = model_->RunEncoder(std::move(x), std::move(states));
// t[0] encoder_out, float tensor, (batch_size, dim, T)
// t[1] next states

std::vector<Ort::Value> out_states;
out_states.reserve(num_states);

for (int32_t k = 1; k != num_states + 1; ++k) {
out_states.push_back(std::move(t[k]));
}

auto unstacked_states = model_->UnStackStates(std::move(out_states));
for (int32_t i = 0; i != n; ++i) {
ss[i]->SetStates(std::move(unstacked_states[i]));
}

Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);

// defined in online-transducer-greedy-search-nemo-decoder.h
// get initial states of the decoder.
std::vector<Ort::Value> &decoder_states = ss[0]->GetNeMoDecoderStates();

// Subsequent decoder states (for each chunk) are updated inside the Decode method.
// This returns the decoder state from the LAST chunk. We probably don't need it, so we can discard it.
decoder_states = decoder_->Decode(std::move(encoder_out),
std::move(decoder_states),
&result, ss, n);

ss[0]->SetResult(result[0]);

ss[0]->SetStates(std::move(out_states));

decoder_->Decode(std::move(encoder_out), ss, n);
}

void InitOnlineStream(OnlineStream *stream) const {
auto r = decoder_->GetEmptyResult();
// set encoder states
stream->SetStates(model_->GetEncoderInitStates());

stream->SetResult(r);
stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1));
// set decoder states
stream->SetNeMoDecoderStates(model_->GetDecoderInitStates());
}

private:
Expand Down Expand Up @@ -250,7 +232,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
symbol_table_.NumSymbols(), vocab_size);
exit(-1);
}

}

private:
Expand All @@ -259,9 +240,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
std::unique_ptr<OnlineTransducerNeMoModel> model_;
std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_;
Endpoint endpoint_;

};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/online-stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
return impl_->GetStates();
}

void OnlineStream::SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states) {
void OnlineStream::SetNeMoDecoderStates(
std::vector<Ort::Value> decoder_states) {
return impl_->SetNeMoDecoderStates(std::move(decoder_states));
}

Expand Down
Loading

0 comments on commit 082f230

Please sign in to comment.