diff --git a/python-api-examples/offline-tts.py b/python-api-examples/offline-tts.py index 246eedf44..265b09af2 100755 --- a/python-api-examples/offline-tts.py +++ b/python-api-examples/offline-tts.py @@ -161,6 +161,9 @@ def main(): ), rule_fsts=args.tts_rule_fsts, ) + if not tts_config.validate(): + raise ValueError("Please check your config") + tts = sherpa_onnx.OfflineTts(tts_config) start = time.time() diff --git a/sherpa-onnx/csrc/macros.h b/sherpa-onnx/csrc/macros.h index fae92dc60..39a2146cc 100644 --- a/sherpa-onnx/csrc/macros.h +++ b/sherpa-onnx/csrc/macros.h @@ -43,6 +43,21 @@ } \ } while (0) +#define SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(dst, src_key, default_value) \ + do { \ + auto value = \ + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ + if (!value) { \ + dst = default_value; \ + } else { \ + dst = atoi(value.get()); \ + if (dst < 0) { \ + SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \ + exit(-1); \ + } \ + } \ + } while (0) + // read a vector of integers #define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \ do { \ @@ -112,4 +127,20 @@ } \ } while (0) +#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \ + default_value) \ + do { \ + auto value = \ + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ + if (!value) { \ + dst = default_value; \ + } else { \ + dst = value.get(); \ + if (dst.empty()) { \ + SHERPA_ONNX_LOGE("Invalid value for %s\n", src_key); \ + exit(-1); \ + } \ + } \ + } while (0) + #endif // SHERPA_ONNX_CSRC_MACROS_H_ diff --git a/sherpa-onnx/csrc/offline-tts-vits-impl.h b/sherpa-onnx/csrc/offline-tts-vits-impl.h index 2ace0f57d..20d9c37ac 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-impl.h +++ b/sherpa-onnx/csrc/offline-tts-vits-impl.h @@ -147,10 +147,9 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { private: void InitLexicon() { - if (model_->IsPiper() && model_->Language() == "English" && - !config_.model.vits.data_dir.empty()) { - lexicon_ = - std::make_unique(config_.model.vits.data_dir); + if (model_->IsPiper() && !config_.model.vits.data_dir.empty()) { + lexicon_ = std::make_unique( + config_.model.vits.tokens, config_.model.vits.data_dir); } else { lexicon_ = std::make_unique( config_.model.vits.lexicon, config_.model.vits.tokens, diff --git a/sherpa-onnx/csrc/offline-tts-vits-model-config.cc b/sherpa-onnx/csrc/offline-tts-vits-model-config.cc index feb0b1e36..b9fce0f6b 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model-config.cc +++ b/sherpa-onnx/csrc/offline-tts-vits-model-config.cc @@ -15,7 +15,7 @@ void OfflineTtsVitsModelConfig::Register(ParseOptions *po) { po->Register("vits-tokens", &tokens, "Path to tokens.txt for VITS models"); po->Register("vits-data-dir", &data_dir, "Path to the directory containing dict for espeak-ng. If it is " - "given, --vits-lexicon and --vits-tokens are ignored."); + "given, --vits-lexicon is ignored."); po->Register("vits-noise-scale", &noise_scale, "noise_scale for VITS models"); po->Register("vits-noise-scale-w", &noise_scale_w, "noise_scale_w for VITS models"); @@ -34,6 +34,16 @@ bool OfflineTtsVitsModelConfig::Validate() const { return false; } + if (tokens.empty()) { + SHERPA_ONNX_LOGE("Please provide --vits-tokens"); + return false; + } + + if (!FileExists(tokens)) { + SHERPA_ONNX_LOGE("--vits-tokens: %s does not exist", tokens.c_str()); + return false; + } + if (data_dir.empty()) { if (lexicon.empty()) { SHERPA_ONNX_LOGE("Please provide --vits-lexicon"); @@ -45,15 +55,6 @@ bool OfflineTtsVitsModelConfig::Validate() const { return false; } - if (tokens.empty()) { - SHERPA_ONNX_LOGE("Please provide --vits-tokens"); - return false; - } - - if (!FileExists(tokens)) { - SHERPA_ONNX_LOGE("--vits-tokens: %s does not exist", tokens.c_str()); - return false; - } } else { if (!FileExists(data_dir + "/phontab")) { SHERPA_ONNX_LOGE("%s/phontab does not exist. Skipping test", diff --git a/sherpa-onnx/csrc/offline-tts-vits-model-config.h b/sherpa-onnx/csrc/offline-tts-vits-model-config.h index 99ee86b06..cde8b3920 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model-config.h +++ b/sherpa-onnx/csrc/offline-tts-vits-model-config.h @@ -16,7 +16,7 @@ struct OfflineTtsVitsModelConfig { std::string lexicon; std::string tokens; - // If data_dir is given, lexicon and tokens are ignored + // If data_dir is given, lexicon is ignored // data_dir is for piper-phonemize, which uses espeak-ng std::string data_dir; diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.cc b/sherpa-onnx/csrc/offline-tts-vits-model.cc index 2e8cfe766..31e3a7c31 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.cc +++ b/sherpa-onnx/csrc/offline-tts-vits-model.cc @@ -75,15 +75,12 @@ class OfflineTtsVitsModel::Impl { Ort::AllocatorWithDefaultOptions allocator; // used in the macro below SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate"); - SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank"); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(add_blank_, "add_blank", 0); SHERPA_ONNX_READ_META_DATA(num_speakers_, "n_speakers"); - SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation"); + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(punctuations_, "punctuation", + ""); SHERPA_ONNX_READ_META_DATA_STR(language_, "language"); - // SHERPA_ONNX_READ_META_DATA_STR(voice_, "voice"); - if (language_ == "English") { - // FIXME(fangjun): Read voice from the metadata - voice_ = "en-us"; - } + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(voice_, "voice", ""); std::string comment; SHERPA_ONNX_READ_META_DATA_STR(comment, "comment"); diff --git a/sherpa-onnx/csrc/piper-phonemize-lexicon.cc b/sherpa-onnx/csrc/piper-phonemize-lexicon.cc index a0335c18a..91d3eee80 100644 --- a/sherpa-onnx/csrc/piper-phonemize-lexicon.cc +++ b/sherpa-onnx/csrc/piper-phonemize-lexicon.cc @@ -4,8 +4,11 @@ #include "sherpa-onnx/csrc/piper-phonemize-lexicon.h" +#include +#include #include #include // NOLINT +#include #include "espeak-ng/speak_lib.h" #include "phoneme_ids.hpp" @@ -14,6 +17,80 @@ namespace sherpa_onnx { +static std::unordered_map ReadTokens(std::istream &is) { + std::wstring_convert, char32_t> conv; + std::unordered_map token2id; + + std::string line; + + std::string sym; + std::u32string s; + int32_t id; + while (std::getline(is, line)) { + std::istringstream iss(line); + iss >> sym; + if (iss.eof()) { + id = atoi(sym.c_str()); + sym = " "; + } else { + iss >> id; + } + + // eat the trailing \r\n on windows + iss >> std::ws; + if (!iss.eof()) { + SHERPA_ONNX_LOGE("Error when reading tokens: %s", line.c_str()); + exit(-1); + } + + s = conv.from_bytes(sym); + if (s.size() != 1) { + SHERPA_ONNX_LOGE("Error when reading tokens at Line %s. size: %d", + line.c_str(), static_cast(s.size())); + exit(-1); + } + char32_t c = s[0]; + + if (token2id.count(c)) { + SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d", + sym.c_str(), line.c_str(), token2id.at(c)); + exit(-1); + } + + token2id.insert({c, id}); + } + + return token2id; +} + +// see the function "phonemes_to_ids" from +// https://github.com/rhasspy/piper/blob/master/notebooks/piper_inference_(ONNX).ipynb +static std::vector PhonemesToIds( + const std::unordered_map &token2id, + const std::vector &phonemes) { + // see + // https://github.com/rhasspy/piper-phonemize/blob/master/src/phoneme_ids.hpp#L17 + int32_t pad = token2id.at(U'_'); + int32_t bos = token2id.at(U'^'); + int32_t eos = token2id.at(U'$'); + + std::vector ans; + ans.reserve(phonemes.size()); + + ans.push_back(bos); + for (auto p : phonemes) { + if (token2id.count(p)) { + ans.push_back(token2id.at(p)); + ans.push_back(pad); + } else { + SHERPA_ONNX_LOGE("Skip unkown phonemes. Unicode codepoint: \\U+%04x.", p); + } + } + ans.push_back(eos); + + return ans; +} + void InitEspeak(const std::string &data_dir) { static std::once_flag init_flag; std::call_once(init_flag, [data_dir]() { @@ -29,8 +106,14 @@ void InitEspeak(const std::string &data_dir) { }); } -PiperPhonemizeLexicon::PiperPhonemizeLexicon(const std::string &data_dir) +PiperPhonemizeLexicon::PiperPhonemizeLexicon(const std::string &tokens, + const std::string &data_dir) : data_dir_(data_dir) { + { + std::ifstream is(tokens); + token2id_ = ReadTokens(is); + } + InitEspeak(data_dir_); } @@ -45,15 +128,11 @@ std::vector PiperPhonemizeLexicon::ConvertTextToTokenIds( std::vector> phonemes; piper::phonemize_eSpeak(text, config, phonemes); - std::vector phoneme_ids; - std::map missing_phonemes; - std::vector ans; - piper::PhonemeIdConfig id_config; + + std::vector phoneme_ids; for (const auto &p : phonemes) { - phoneme_ids.clear(); - missing_phonemes.clear(); - phonemes_to_ids(p, id_config, phoneme_ids, missing_phonemes); + phoneme_ids = PhonemesToIds(token2id_, p); ans.insert(ans.end(), phoneme_ids.begin(), phoneme_ids.end()); } diff --git a/sherpa-onnx/csrc/piper-phonemize-lexicon.h b/sherpa-onnx/csrc/piper-phonemize-lexicon.h index 5f29addf3..627fc9397 100644 --- a/sherpa-onnx/csrc/piper-phonemize-lexicon.h +++ b/sherpa-onnx/csrc/piper-phonemize-lexicon.h @@ -5,20 +5,24 @@ #ifndef SHERPA_ONNX_CSRC_PIPER_PHONEMIZE_LEXICON_H_ #define SHERPA_ONNX_CSRC_PIPER_PHONEMIZE_LEXICON_H_ +#include + #include "sherpa-onnx/csrc/lexicon.h" namespace sherpa_onnx { class PiperPhonemizeLexicon : public Lexicon { public: - explicit PiperPhonemizeLexicon(const std::string &data_dir); + explicit PiperPhonemizeLexicon(const std::string &tokens, + const std::string &data_dir); std::vector ConvertTextToTokenIds( const std::string &text, const std::string &voice = "") const override; private: - std::string voice_; std::string data_dir_; + // map unicode codepoint to an integer ID + std::unordered_map token2id_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc b/sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc index d488a957c..6e016715d 100644 --- a/sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc +++ b/sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc @@ -28,7 +28,8 @@ void PybindOfflineTtsVitsModelConfig(py::module *m) { .def_readwrite("noise_scale", &PyClass::noise_scale) .def_readwrite("noise_scale_w", &PyClass::noise_scale_w) .def_readwrite("length_scale", &PyClass::length_scale) - .def("__str__", &PyClass::ToString); + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-tts.cc b/sherpa-onnx/python/csrc/offline-tts.cc index 538ceceed..39669a0e4 100644 --- a/sherpa-onnx/python/csrc/offline-tts.cc +++ b/sherpa-onnx/python/csrc/offline-tts.cc @@ -34,6 +34,7 @@ static void PybindOfflineTtsConfig(py::module *m) { py::arg("model"), py::arg("rule_fsts") = "") .def_readwrite("model", &PyClass::model) .def_readwrite("rule_fsts", &PyClass::rule_fsts) + .def("validate", &PyClass::Validate) .def("__str__", &PyClass::ToString); }