-
Notifications
You must be signed in to change notification settings - Fork 561
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add C++ and Python API for FireRedASR AED models (#1867)
- Loading branch information
1 parent
2337169
commit 316424b
Showing
20 changed files
with
1,019 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
#!/usr/bin/env python3 | ||
|
||
""" | ||
This file shows how to use a non-streaming FireRedAsr AED model from | ||
https://github.com/FireRedTeam/FireRedASR | ||
to decode files. | ||
Please download model files from | ||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models | ||
For instance, | ||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2 | ||
tar xvf sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2 | ||
rm sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2 | ||
""" | ||
|
||
from pathlib import Path | ||
|
||
import sherpa_onnx | ||
import soundfile as sf | ||
|
||
|
||
def create_recognizer():
    """Create a non-streaming FireRedAsr recognizer and pick a test wave file.

    All paths are relative to the current working directory and point into
    the extracted ``sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16`` archive.

    Returns:
      A tuple ``(recognizer, test_wav)``. ``recognizer`` is whatever
      ``sherpa_onnx.OfflineRecognizer.from_fire_red_asr`` returns and
      ``test_wav`` is the path of the wave file to decode.

    Raises:
      ValueError: If the encoder, decoder, tokens or test wave file does
        not exist.
    """
    encoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx"
    decoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/decoder.int8.onnx"
    tokens = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/tokens.txt"
    test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/0.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/1.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/2.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/8k.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3-sichuan.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/4-tianjin.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/5-henan.wav"

    # Every file handed to the recognizer (including tokens.txt) must exist;
    # otherwise fail early with the download hint instead of a later error.
    required_files = (encoder, decoder, tokens, test_wav)
    if not all(Path(f).is_file() for f in required_files):
        raise ValueError(
            """Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
        )

    recognizer = sherpa_onnx.OfflineRecognizer.from_fire_red_asr(
        encoder=encoder,
        decoder=decoder,
        tokens=tokens,
        debug=True,
    )
    return recognizer, test_wav
|
||
|
||
def main():
    """Decode the selected test wave file and print its name and the result."""
    recognizer, wave_filename = create_recognizer()

    # always_2d=True gives a 2-D array even for mono files, so the channel
    # selection below works for any number of channels.
    samples, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    first_channel = samples[:, 0]  # only use the first channel

    # first_channel is a 1-D float32 numpy array normalized to the range [-1, 1]
    # sample_rate does not need to be 16000 Hz

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, first_channel)
    recognizer.decode_stream(stream)

    print(wave_filename)
    print(stream.result)


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// sherpa-onnx/csrc/offline-fire-red-asr-decoder.h | ||
// | ||
// Copyright (c) 2025 Xiaomi Corporation | ||
|
||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_ | ||
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_ | ||
|
||
#include <cstdint> | ||
#include <vector> | ||
|
||
#include "onnxruntime_cxx_api.h" // NOLINT | ||
|
||
namespace sherpa_onnx { | ||
|
||
/// Decoding output for a single utterance.
struct OfflineFireRedAsrDecoderResult {
  /// Token IDs produced by the decoder, in the order they were emitted.
  std::vector<int32_t> tokens;
};
|
||
class OfflineFireRedAsrDecoder { | ||
public: | ||
virtual ~OfflineFireRedAsrDecoder() = default; | ||
|
||
/** Run beam search given the output from the FireRedAsr encoder model. | ||
* | ||
* @param n_layer_cross_k A 4-D tensor of shape | ||
* (num_decoder_layers, N, T, d_model). | ||
* @param n_layer_cross_v A 4-D tensor of shape | ||
* (num_decoder_layers, N, T, d_model). | ||
* | ||
* @return Return a vector of size `N` containing the decoded results. | ||
*/ | ||
virtual std::vector<OfflineFireRedAsrDecoderResult> Decode( | ||
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_ |
87 changes: 87 additions & 0 deletions
87
sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc | ||
// | ||
// Copyright (c) 2025 Xiaomi Corporation | ||
|
||
#include "sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h" | ||
|
||
#include <algorithm>
#include <array>
#include <iterator>
#include <tuple>
#include <utility>
|
||
#include "sherpa-onnx/csrc/macros.h" | ||
#include "sherpa-onnx/csrc/onnx-utils.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
// Note: this functions works only for batch size == 1 at present | ||
std::vector<OfflineFireRedAsrDecoderResult> | ||
OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k, | ||
Ort::Value cross_v) { | ||
const auto &meta_data = model_->GetModelMetadata(); | ||
|
||
auto memory_info = | ||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
|
||
// For multilingual models, initial_tokens contains [sot, language, task] | ||
// - language is English by default | ||
// - task is transcribe by default | ||
// | ||
// For non-multilingual models, initial_tokens contains [sot] | ||
std::array<int64_t, 2> token_shape = {1, 1}; | ||
int64_t token = meta_data.sos_id; | ||
|
||
int32_t batch_size = 1; | ||
|
||
Ort::Value tokens = Ort::Value::CreateTensor( | ||
memory_info, &token, 1, token_shape.data(), token_shape.size()); | ||
|
||
std::array<int64_t, 1> offset_shape{1}; | ||
Ort::Value offset = Ort::Value::CreateTensor<int64_t>( | ||
model_->Allocator(), offset_shape.data(), offset_shape.size()); | ||
*(offset.GetTensorMutableData<int64_t>()) = 0; | ||
|
||
std::vector<OfflineFireRedAsrDecoderResult> ans(1); | ||
|
||
auto self_kv_cache = model_->GetInitialSelfKVCache(); | ||
|
||
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value, | ||
Ort::Value> | ||
decoder_out = {Ort::Value{nullptr}, | ||
std::move(self_kv_cache.first), | ||
std::move(self_kv_cache.second), | ||
std::move(cross_k), | ||
std::move(cross_v), | ||
std::move(offset)}; | ||
|
||
for (int32_t i = 0; i < meta_data.max_len; ++i) { | ||
decoder_out = model_->ForwardDecoder(View(&tokens), | ||
std::move(std::get<1>(decoder_out)), | ||
std::move(std::get<2>(decoder_out)), | ||
std::move(std::get<3>(decoder_out)), | ||
std::move(std::get<4>(decoder_out)), | ||
std::move(std::get<5>(decoder_out))); | ||
|
||
const auto &logits = std::get<0>(decoder_out); | ||
const float *p_logits = logits.GetTensorData<float>(); | ||
|
||
auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape(); | ||
int32_t vocab_size = logits_shape[2]; | ||
|
||
int32_t max_token_id = static_cast<int32_t>(std::distance( | ||
p_logits, std::max_element(p_logits, p_logits + vocab_size))); | ||
if (max_token_id == meta_data.eos_id) { | ||
break; | ||
} | ||
|
||
ans[0].tokens.push_back(max_token_id); | ||
|
||
token = max_token_id; | ||
|
||
// increment offset | ||
*(std::get<5>(decoder_out).GetTensorMutableData<int64_t>()) += 1; | ||
} | ||
|
||
return ans; | ||
} | ||
|
||
} // namespace sherpa_onnx |
29 changes: 29 additions & 0 deletions
29
sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h | ||
// | ||
// Copyright (c) 2025 Xiaomi Corporation | ||
|
||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_ | ||
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_ | ||
|
||
#include <vector> | ||
|
||
#include "sherpa-onnx/csrc/offline-fire-red-asr-decoder.h" | ||
#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
class OfflineFireRedAsrGreedySearchDecoder : public OfflineFireRedAsrDecoder { | ||
public: | ||
explicit OfflineFireRedAsrGreedySearchDecoder(OfflineFireRedAsrModel *model) | ||
: model_(model) {} | ||
|
||
std::vector<OfflineFireRedAsrDecoderResult> Decode( | ||
Ort::Value cross_k, Ort::Value cross_v) override; | ||
|
||
private: | ||
OfflineFireRedAsrModel *model_; // not owned | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
// sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
|
||
#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h" | ||
|
||
#include "sherpa-onnx/csrc/file-utils.h" | ||
#include "sherpa-onnx/csrc/macros.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
// Register the command line options --fire-red-asr-encoder and
// --fire-red-asr-decoder, binding them to the `encoder` and `decoder`
// members of this config.
void OfflineFireRedAsrModelConfig::Register(ParseOptions *po) {
  po->Register(
      "fire-red-asr-encoder", &encoder, "Path to onnx encoder of FireRedAsr");

  po->Register(
      "fire-red-asr-decoder", &decoder, "Path to onnx decoder of FireRedAsr");
}
|
||
bool OfflineFireRedAsrModelConfig::Validate() const { | ||
if (encoder.empty()) { | ||
SHERPA_ONNX_LOGE("Please provide --fire-red-asr-encoder"); | ||
return false; | ||
} | ||
|
||
if (!FileExists(encoder)) { | ||
SHERPA_ONNX_LOGE("FireRedAsr encoder file '%s' does not exist", | ||
encoder.c_str()); | ||
return false; | ||
} | ||
|
||
if (decoder.empty()) { | ||
SHERPA_ONNX_LOGE("Please provide --fire-red-asr-decoder"); | ||
return false; | ||
} | ||
|
||
if (!FileExists(decoder)) { | ||
SHERPA_ONNX_LOGE("FireRedAsr decoder file '%s' does not exist", | ||
decoder.c_str()); | ||
return false; | ||
} | ||
|
||
return true; | ||
} | ||
|
||
// Return a human readable description of this config in the form
//   OfflineFireRedAsrModelConfig(encoder="<encoder>", decoder="<decoder>")
std::string OfflineFireRedAsrModelConfig::ToString() const {
  std::string description = "OfflineFireRedAsrModelConfig(";

  description += "encoder=\"";
  description += encoder;
  description += "\", ";

  description += "decoder=\"";
  description += decoder;
  description += "\")";

  return description;
}
|
||
} // namespace sherpa_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
// sherpa-onnx/csrc/offline-fire-red-asr-model-config.h | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_ | ||
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_ | ||
|
||
#include <string> | ||
|
||
#include "sherpa-onnx/csrc/parse-options.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
// see https://github.com/FireRedTeam/FireRedASR | ||
struct OfflineFireRedAsrModelConfig { | ||
std::string encoder; | ||
std::string decoder; | ||
|
||
OfflineFireRedAsrModelConfig() = default; | ||
OfflineFireRedAsrModelConfig(const std::string &encoder, | ||
const std::string &decoder) | ||
: encoder(encoder), decoder(decoder) {} | ||
|
||
void Register(ParseOptions *po); | ||
bool Validate() const; | ||
|
||
std::string ToString() const; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h | ||
// | ||
// Copyright (c) 2025 Xiaomi Corporation | ||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_ | ||
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_ | ||
|
||
#include <string> | ||
#include <unordered_map> | ||
#include <vector> | ||
|
||
namespace sherpa_onnx { | ||
|
||
/// Meta data describing a FireRedAsr model.
///
/// Every scalar member has a default initializer so that a
/// default-constructed object never holds indeterminate values. The struct
/// is still an aggregate in C++14 and later.
struct OfflineFireRedAsrModelMetaData {
  /// ID of the start-of-sentence token
  int32_t sos_id = 0;

  /// ID of the end-of-sentence token
  int32_t eos_id = 0;

  /// Maximum number of decoding steps
  int32_t max_len = 0;

  int32_t num_decoder_layers = 0;
  int32_t num_head = 0;
  int32_t head_dim = 0;

  // NOTE(review): presumably feature normalization statistics (mean and
  // inverse standard deviation) - confirm against the model loading code.
  std::vector<float> mean;
  std::vector<float> inv_stddev;
};
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_ |
Oops, something went wrong.