Skip to content

Commit

Permalink
feat(core): batch vocos decode
Browse files Browse the repository at this point in the history
feat(cmd): support stream infer
  • Loading branch information
fumiama committed Jul 9, 2024
1 parent 0f47a87 commit 6e18575
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 30 deletions.
74 changes: 51 additions & 23 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ class InferCodeParams(RefineTextParams):
repetition_penalty: float = 1.05
max_new_token: int = 2048
stream_batch: int = 24
stream_speed: int = 12000
pass_first_n_batches: int = 2

def infer(
self,
Expand Down Expand Up @@ -374,7 +376,9 @@ def _infer(
yield text
return

length = np.zeros(len(text), dtype=np.uint16)
if stream:
length = np.zeros(len(text), dtype=np.uint32)
pass_batch_count = 0
for result in self._infer_code(
text,
stream,
Expand All @@ -383,43 +387,67 @@ def _infer(
params_infer_code,
):
wavs = self._decode_to_wavs(
result,
length,
result.hiddens if use_decoder else result.ids,
use_decoder,
)
result.destroy()
if stream:
pass_batch_count += 1
if pass_batch_count <= params_infer_code.pass_first_n_batches:
continue
a = length
b = a + params_infer_code.stream_speed
new_wavs = np.zeros((wavs.shape[0], params_infer_code.stream_speed))
for i in range(wavs.shape[0]):
if b[i] > len(wavs[i]):
b[i] = len(wavs[i])
new_wavs[i, :b[i]-a[i]] = wavs[i, a[i]:b[i]]
length = b
yield new_wavs
else:
yield wavs
if stream:
for i in range(wavs.shape[0]):
a = length[i]
b = len(wavs[i])
wavs[i, :b-a] = wavs[i, a:]
wavs[i, b-a:] = 0
yield wavs

@torch.inference_mode()
def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
if "mps" in str(self.device):
return self.vocos.decode(spec.cpu()).squeeze_(0).cpu().numpy()
return self.vocos.decode(spec.cpu()).cpu().numpy()
else:
return self.vocos.decode(spec).squeeze_(0).cpu().numpy()
return self.vocos.decode(spec).cpu().numpy()

@torch.inference_mode()
def _decode_to_wavs(
    self,
    result_list: List[torch.Tensor],
    use_decoder: bool,
):
    """Decode a list of per-text result tensors into waveforms as one batch.

    Pads every sequence to the longest one so a single batched decoder
    call (and a single Vocos call) can handle all texts at once.

    Args:
        result_list: one tensor per input text — assumed shape
            (seq_len, feat_dim); TODO confirm against _infer_code output.
        use_decoder: select ``self.decoder`` (hidden states) when True,
            otherwise ``self.dvae`` (token ids).

    Returns:
        np.ndarray: decoded waveforms, one row per input text; an empty
        float32 array when ``result_list`` is empty.
    """
    decoder = self.decoder if use_decoder else self.dvae
    if len(result_list) == 0:
        # nothing to decode — keep the numpy return type consistent
        return np.array([], dtype=np.float32)
    # longest sequence length determines the padded batch width
    max_x_len = -1
    for result in result_list:
        if result.size(0) > max_x_len:
            max_x_len = result.size(0)
    # zero-padded batch buffer: (batch, feat_dim, max_seq_len)
    batch_result = torch.zeros(
        (len(result_list), result_list[0].size(1), max_x_len),
        dtype=result_list[0].dtype,
        device=result_list[0].device,
    )
    for i in range(len(result_list)):
        src = result_list[i]
        # (seq, feat) -> (feat, seq), left-aligned into the padded slot
        batch_result[i].narrow(1, 0, src.size(0)).copy_(src.permute(1, 0))
        del src
    del_all(result_list)
    mel_specs = decoder(batch_result)
    del batch_result
    wavs = self._vocos_decode(mel_specs)
    del mel_specs
    return wavs

@staticmethod
Expand Down
1 change: 0 additions & 1 deletion ChatTTS/model/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def __init__(
tokenizer: BertTokenizerFast = torch.load(
tokenizer_path, map_location=device, mmap=True
)
tokenizer.padding_side = "left"
self._tokenizer = tokenizer

self.len = len(tokenizer)
Expand Down
29 changes: 23 additions & 6 deletions examples/cmd/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import argparse
from typing import Optional, List

import numpy as np

import ChatTTS

from tools.audio import wav_arr_to_mp3_view
Expand All @@ -25,7 +27,7 @@ def save_mp3_file(wav, index):
logger.info(f"Audio saved to {mp3_filename}")


def main(texts: List[str], spk: Optional[str] = None):
def main(texts: List[str], spk: Optional[str] = None, stream=False):
logger.info("Text input: %s", str(texts))

chat = ChatTTS.Chat(get_logger("ChatTTS"))
Expand All @@ -43,23 +45,33 @@ def main(texts: List[str], spk: Optional[str] = None):

logger.info("Start inference.")
wavs = chat.infer(
texts,
texts, stream,
params_infer_code=ChatTTS.Chat.InferCodeParams(
spk_emb=spk,
),
)
logger.info("Inference completed.")
# Save each generated wav file to a local file
if stream:
wavs_list = []
for index, wav in enumerate(wavs):
save_mp3_file(wav, index)
if stream:
for i, w in enumerate(wav):
save_mp3_file(w, (i+1)*1000+index)
wavs_list.append(wav)
else:
save_mp3_file(wav, index)
if stream:
for index, wav in enumerate(np.concatenate(wavs_list, axis=1)):
save_mp3_file(wav, index)
logger.info("Audio generation successful.")


if __name__ == "__main__":
logger.info("Starting ChatTTS commandline demo...")
parser = argparse.ArgumentParser(
description="ChatTTS Command",
usage='[--spk xxx] "Your text 1." " Your text 2."',
usage='[--spk xxx] [--stream] "Your text 1." " Your text 2."',
)
parser.add_argument(
"--spk",
Expand All @@ -68,8 +80,13 @@ def main(texts: List[str], spk: Optional[str] = None):
default=None,
)
parser.add_argument(
"texts", help="Original text", default="YOUR TEXT HERE", nargs="*"
"--stream",
help="Use stream mode",
action='store_true',
)
parser.add_argument(
"texts", help="Original text", default=["YOUR TEXT HERE"], nargs=argparse.REMAINDER,
)
args = parser.parse_args()
main(args.texts, args.spk)
main(args.texts, args.spk, args.stream)
logger.info("ChatTTS process finished.")

0 comments on commit 6e18575

Please sign in to comment.