Skip to content

Commit

Permalink
Fix kokoro batch issue (#128)
Browse files Browse the repository at this point in the history
* Fix kokoro batch issue

* code

* fix batch size

---------

Co-authored-by: Freddy Boulton <[email protected]>
  • Loading branch information
freddyaboulton and Freddy Boulton authored Mar 6, 2025
1 parent 6517a93 commit df0706e
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
13 changes: 13 additions & 0 deletions backend/fastrtc/text_to_speech/test_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from fastrtc.text_to_speech.tts import get_tts_model


def test_tts_long_prompt():
model = get_tts_model()
prompt = "It may be that this communication will be considered as a madman's freak but at any rate it must be admitted that in its clearness and frankness it left nothing to be desired The serious part of it was that the Federal Government had undertaken to treat a sale by auction as a valid concession of these undiscovered territories Opinions on the matter were many Some readers saw in it only one of those prodigious outbursts of American humbug which would exceed the limits of puffism if the depths of human credulity were not unfathomable"

for i, chunk in enumerate(model.stream_tts_sync(prompt)):
print(f"Chunk {i}: {chunk[1].shape}")


if __name__ == "__main__":
test_tts_long_prompt()
45 changes: 45 additions & 0 deletions backend/fastrtc/text_to_speech/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,48 @@ def get_tts_model(model: Literal["kokoro"] = "kokoro") -> TTSModel:
return m


class KokoroFixedBatchSize:
# Source: https://github.com/thewh1teagle/kokoro-onnx/issues/115#issuecomment-2676625392
def _split_phonemes(self, phonemes: str) -> list[str]:
MAX_PHONEME_LENGTH = 510
max_length = MAX_PHONEME_LENGTH - 1
batched_phonemes = []
while len(phonemes) > max_length:
# Find best split point within limit
split_idx = max_length

# Try to find the last period before max_length
period_idx = phonemes.rfind(".", 0, max_length)
if period_idx != -1:
split_idx = period_idx + 1 # Include period

else:
# Try other punctuation
match = re.search(
r"[!?;,]", phonemes[:max_length][::-1]
) # Search backwards
if match:
split_idx = max_length - match.start()

else:
# Try last space
space_idx = phonemes.rfind(" ", 0, max_length)
if space_idx != -1:
split_idx = space_idx

# If no good split point is found, force split at max_length
chunk = phonemes[:split_idx].strip()
batched_phonemes.append(chunk)

# Move to the next part
phonemes = phonemes[split_idx:].strip()

# Add remaining phonemes
if phonemes:
batched_phonemes.append(phonemes)
return batched_phonemes


class KokoroTTSModel(TTSModel):
def __init__(self):
from kokoro_onnx import Kokoro
Expand All @@ -48,6 +90,8 @@ def __init__(self):
voices_path=hf_hub_download("fastrtc/kokoro-onnx", "voices-v1.0.bin"),
)

self.model._split_phonemes = KokoroFixedBatchSize()._split_phonemes

def tts(
self, text: str, options: KokoroTTSOptions | None = None
) -> tuple[int, NDArray[np.float32]]:
Expand All @@ -74,6 +118,7 @@ async def stream_tts(
):
if s_idx != 0 and chunk_idx == 0:
yield chunk[1], np.zeros(chunk[1] // 7, dtype=np.float32)
chunk_idx += 1
yield chunk[1], chunk[0]

def stream_tts_sync(
Expand Down

0 comments on commit df0706e

Please sign in to comment.