Skip to content

Commit

Permalink
Add C++ and Python API for FireRedASR AED models (#1867)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Feb 16, 2025
1 parent 2337169 commit 316424b
Show file tree
Hide file tree
Showing 20 changed files with 1,019 additions and 26 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,4 @@ lexicon.txt
us_gold.json
us_silver.json
kokoro-multi-lang-v1_0
sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16
75 changes: 75 additions & 0 deletions python-api-examples/offline-fire-red-asr-decode-files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/usr/bin/env python3

"""
This file shows how to use a non-streaming FireRedAsr AED model from
https://github.com/FireRedTeam/FireRedASR
to decode files.
Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
For instance,
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
tar xvf sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
rm sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
"""

from pathlib import Path

import sherpa_onnx
import soundfile as sf


def create_recognizer():
    """Build an offline FireRedAsr recognizer together with a test wave path.

    All paths point into ./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16,
    i.e. the directory obtained by extracting the release tarball into the
    current working directory.

    Returns:
      A tuple ``(recognizer, test_wav)`` where ``recognizer`` is the object
      returned by ``sherpa_onnx.OfflineRecognizer.from_fire_red_asr`` and
      ``test_wav`` is the path of the wave file selected below.

    Raises:
      ValueError: If any of the encoder, decoder, tokens, or test wave files
        is missing.
    """
    encoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx"
    decoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/decoder.int8.onnx"
    tokens = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/tokens.txt"
    test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/0.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/1.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/2.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/8k.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3-sichuan.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/4-tianjin.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/5-henan.wav"

    # tokens.txt is checked as well: it is handed to the recognizer below, so
    # a missing file should produce the same download hint as the others.
    required_files = (encoder, decoder, tokens, test_wav)
    if not all(Path(f).is_file() for f in required_files):
        raise ValueError(
            """Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
        )

    recognizer = sherpa_onnx.OfflineRecognizer.from_fire_red_asr(
        encoder=encoder,
        decoder=decoder,
        tokens=tokens,
        debug=True,
    )
    return recognizer, test_wav


def main():
    """Decode the selected test wave and print its path and the result."""
    recognizer, wave_filename = create_recognizer()

    # always_2d=True yields a (num_samples, num_channels) array even for mono
    # files, so the first channel can be selected uniformly.
    samples, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    first_channel = samples[:, 0]

    # first_channel is a 1-D float32 numpy array normalized to the range
    # [-1, 1]. sample_rate does not need to be 16000 Hz.

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, first_channel)
    recognizer.decode_stream(stream)

    print(wave_filename)
    print(stream.result)


# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ set(sources
offline-ctc-fst-decoder.cc
offline-ctc-greedy-search-decoder.cc
offline-ctc-model.cc
offline-fire-red-asr-greedy-search-decoder.cc
offline-fire-red-asr-model-config.cc
offline-fire-red-asr-model.cc
offline-lm-config.cc
offline-lm.cc
offline-model-config.cc
Expand Down
39 changes: 39 additions & 0 deletions sherpa-onnx/csrc/offline-fire-red-asr-decoder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// sherpa-onnx/csrc/offline-fire-red-asr-decoder.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_

#include <cstdint>
#include <vector>

#include "onnxruntime_cxx_api.h" // NOLINT

namespace sherpa_onnx {

/// One decoding result produced by a FireRedAsr decoder.
struct OfflineFireRedAsrDecoderResult {
  std::vector<int32_t> tokens;  ///< IDs of the decoded tokens
};

/// Interface of the decoders used with FireRedAsr models.
///
/// The search strategy is left to the concrete subclass.
class OfflineFireRedAsrDecoder {
 public:
  virtual ~OfflineFireRedAsrDecoder() = default;

  /** Decode the output of the FireRedAsr encoder model.
   *
   * @param n_layer_cross_k  A 4-D tensor of shape
   *                         (num_decoder_layers, N, T, d_model).
   * @param n_layer_cross_v  A 4-D tensor of shape
   *                         (num_decoder_layers, N, T, d_model).
   *
   * @return A vector of size `N` containing the decoded results.
   */
  virtual std::vector<OfflineFireRedAsrDecoderResult> Decode(
      Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_
87 changes: 87 additions & 0 deletions sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc
//
// Copyright (c) 2025 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h"

#include <algorithm>
#include <array>
#include <cstdint>
#include <iterator>
#include <tuple>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

namespace sherpa_onnx {

// Greedy search decoding for FireRedAsr.
//
// Starting from the sos token, the decoder is run one step at a time and the
// token with the largest logit of each step is appended to the result.
// Decoding stops as soon as the eos token is selected (it is not added to the
// result) or after `max_len` steps, whichever comes first.
//
// Note: this function works only for batch size == 1 at present
std::vector<OfflineFireRedAsrDecoderResult>
OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k,
                                             Ort::Value cross_v) {
  const auto &meta_data = model_->GetModelMetadata();

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  // `tokens` is a (1, 1) tensor created on top of the local variable `token`,
  // so assigning a new value to `token` at the end of each step is how the
  // next decoder input is set. The first input is the sos token.
  std::array<int64_t, 2> token_shape = {1, 1};
  int64_t token = meta_data.sos_id;

  Ort::Value tokens = Ort::Value::CreateTensor(
      memory_info, &token, 1, token_shape.data(), token_shape.size());

  // A single int64 value that starts at 0 and is incremented after every
  // decoding step.
  std::array<int64_t, 1> offset_shape{1};
  Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
      model_->Allocator(), offset_shape.data(), offset_shape.size());
  *(offset.GetTensorMutableData<int64_t>()) = 0;

  std::vector<OfflineFireRedAsrDecoderResult> ans(1);

  auto self_kv_cache = model_->GetInitialSelfKVCache();

  // Element 0 holds the logits of the latest step (empty before the first
  // step). Elements 1..5 are the remaining inputs of ForwardDecoder(); the
  // values it returns in these positions are passed back in on the next step.
  std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
             Ort::Value>
      decoder_out = {Ort::Value{nullptr},
                     std::move(self_kv_cache.first),
                     std::move(self_kv_cache.second),
                     std::move(cross_k),
                     std::move(cross_v),
                     std::move(offset)};

  for (int32_t i = 0; i < meta_data.max_len; ++i) {
    decoder_out = model_->ForwardDecoder(View(&tokens),
                                         std::move(std::get<1>(decoder_out)),
                                         std::move(std::get<2>(decoder_out)),
                                         std::move(std::get<3>(decoder_out)),
                                         std::move(std::get<4>(decoder_out)),
                                         std::move(std::get<5>(decoder_out)));

    const auto &logits = std::get<0>(decoder_out);
    const float *p_logits = logits.GetTensorData<float>();

    // The third dimension of the logits is used as the vocabulary size.
    auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape();
    auto vocab_size = static_cast<int32_t>(logits_shape[2]);

    auto max_token_id = static_cast<int32_t>(std::distance(
        p_logits, std::max_element(p_logits, p_logits + vocab_size)));
    if (max_token_id == meta_data.eos_id) {
      break;
    }

    ans[0].tokens.push_back(max_token_id);

    // feed the selected token to the next step
    token = max_token_id;

    // increment offset
    *(std::get<5>(decoder_out).GetTensorMutableData<int64_t>()) += 1;
  }

  return ans;
}

} // namespace sherpa_onnx
29 changes: 29 additions & 0 deletions sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_

#include <vector>

#include "sherpa-onnx/csrc/offline-fire-red-asr-decoder.h"
#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h"

namespace sherpa_onnx {

// Greedy search implementation of OfflineFireRedAsrDecoder.
class OfflineFireRedAsrGreedySearchDecoder : public OfflineFireRedAsrDecoder {
 public:
  // @param model  The model used for decoding. Only the pointer is stored;
  //               this class does not take ownership of it.
  explicit OfflineFireRedAsrGreedySearchDecoder(OfflineFireRedAsrModel *model)
      : model_(model) {}

  // See the base class for the meaning of the arguments and the return value.
  std::vector<OfflineFireRedAsrDecoderResult> Decode(
      Ort::Value cross_k, Ort::Value cross_v) override;

 private:
  OfflineFireRedAsrModel *model_;  // not owned
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_
56 changes: 56 additions & 0 deletions sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

// Registers the command line options --fire-red-asr-encoder and
// --fire-red-asr-decoder with `po`, bound to the `encoder` and `decoder`
// members of this config.
void OfflineFireRedAsrModelConfig::Register(ParseOptions *po) {
  po->Register("fire-red-asr-encoder", &encoder,
               "Path to onnx encoder of FireRedAsr");
  po->Register("fire-red-asr-decoder", &decoder,
               "Path to onnx decoder of FireRedAsr");
}

// Returns true if both the encoder and the decoder paths are non-empty and
// refer to existing files. Otherwise, logs the first problem found (encoder
// is checked before decoder) and returns false.
bool OfflineFireRedAsrModelConfig::Validate() const {
  if (encoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --fire-red-asr-encoder");
  } else if (!FileExists(encoder)) {
    SHERPA_ONNX_LOGE("FireRedAsr encoder file '%s' does not exist",
                     encoder.c_str());
  } else if (decoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --fire-red-asr-decoder");
  } else if (!FileExists(decoder)) {
    SHERPA_ONNX_LOGE("FireRedAsr decoder file '%s' does not exist",
                     decoder.c_str());
  } else {
    return true;
  }

  // one of the checks above failed and has already been logged
  return false;
}

// Returns a description of this config in the form
//   OfflineFireRedAsrModelConfig(encoder="...", decoder="...")
std::string OfflineFireRedAsrModelConfig::ToString() const {
  std::ostringstream oss;

  oss << "OfflineFireRedAsrModelConfig("
      << "encoder=\"" << encoder << "\", "
      << "decoder=\"" << decoder << "\")";

  return oss.str();
}

} // namespace sherpa_onnx
31 changes: 31 additions & 0 deletions sherpa-onnx/csrc/offline-fire-red-asr-model-config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// sherpa-onnx/csrc/offline-fire-red-asr-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_

#include <string>

#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

// Configuration of a FireRedAsr model: the encoder and decoder model files.
//
// see https://github.com/FireRedTeam/FireRedASR
struct OfflineFireRedAsrModelConfig {
  std::string encoder;
  std::string decoder;

  OfflineFireRedAsrModelConfig() = default;

  OfflineFireRedAsrModelConfig(const std::string &encoder,
                               const std::string &decoder)
      : encoder(encoder), decoder(decoder) {}

  // Register the options of this config with the given parser.
  void Register(ParseOptions *po);

  // Return true if this config is valid.
  bool Validate() const;

  // Return a string representation of this config.
  std::string ToString() const;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
28 changes: 28 additions & 0 deletions sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_

#include <string>
#include <unordered_map>
#include <vector>

namespace sherpa_onnx {

// Meta data of a FireRedAsr model.
//
// Every scalar member has a default member initializer so that a
// default-constructed object never holds indeterminate values. The zeros are
// placeholders that are expected to be overwritten with the real values of
// the model.
struct OfflineFireRedAsrModelMetaData {
  int32_t sos_id = 0;   // ID of the token that starts decoding
  int32_t eos_id = 0;   // ID of the token that ends decoding
  int32_t max_len = 0;  // upper bound on the number of decoding steps

  int32_t num_decoder_layers = 0;
  int32_t num_head = 0;
  int32_t head_dim = 0;

  // NOTE(review): presumably per-dimension feature normalization statistics;
  // confirm against the code that fills and uses them.
  std::vector<float> mean;
  std::vector<float> inv_stddev;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_
Loading

0 comments on commit 316424b

Please sign in to comment.