diff --git a/ChatTTS/core.py b/ChatTTS/core.py index e72ef0c7a..f6ce83f09 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -203,6 +203,8 @@ class InferCodeParams(RefineTextParams): repetition_penalty: float = 1.05 max_new_token: int = 2048 stream_batch: int = 24 + stream_speed: int = 12000 + pass_first_n_batches: int = 2 def infer( self, @@ -374,7 +376,9 @@ def _infer( yield text return - length = np.zeros(len(text), dtype=np.uint16) + if stream: + length = np.zeros(len(text), dtype=np.uint32) + pass_batch_count = 0 for result in self._infer_code( text, stream, @@ -383,43 +387,67 @@ def _infer( params_infer_code, ): wavs = self._decode_to_wavs( - result, - length, + result.hiddens if use_decoder else result.ids, use_decoder, ) result.destroy() + if stream: + pass_batch_count += 1 + if pass_batch_count <= params_infer_code.pass_first_n_batches: + continue + a = length + b = a + params_infer_code.stream_speed + new_wavs = np.zeros((wavs.shape[0], params_infer_code.stream_speed)) + for i in range(wavs.shape[0]): + if b[i] > len(wavs[i]): + b[i] = len(wavs[i]) + new_wavs[i, :b[i]-a[i]] = wavs[i, a[i]:b[i]] + length = b + yield new_wavs + else: + yield wavs + if stream: + for i in range(wavs.shape[0]): + a = length[i] + b = len(wavs[i]) + wavs[i, :b-a] = wavs[i, a:] + wavs[i, b-a:] = 0 yield wavs @torch.inference_mode() def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray: if "mps" in str(self.device): - return self.vocos.decode(spec.cpu()).squeeze_(0).cpu().numpy() + return self.vocos.decode(spec.cpu()).cpu().numpy() else: - return self.vocos.decode(spec).squeeze_(0).cpu().numpy() + return self.vocos.decode(spec).cpu().numpy() @torch.inference_mode() def _decode_to_wavs( self, - result: GPT.GenerationOutputs, - start_seeks: np.ndarray, + result_list: List[torch.Tensor], use_decoder: bool, ): - x = result.hiddens if use_decoder else result.ids - wavs: List[Optional[np.ndarray]] = [] - for i, chunk_data in enumerate(x): - start_seek: int = start_seeks[i] - length = len(chunk_data) - if length <= start_seek: - wavs.append(None) - continue - start_seeks[i] = length - chunk_data = chunk_data[start_seek:].to(self.device) - decoder = self.decoder if use_decoder else self.dvae - mel_spec = decoder(chunk_data.unsqueeze_(0).permute(0, 2, 1)) - del chunk_data - wavs.append(self._vocos_decode(mel_spec)) - del_all(mel_spec) - del_all(x) + decoder = self.decoder if use_decoder else self.dvae + max_x_len = -1 + if len(result_list) == 0: + return np.array([], dtype=np.float32) + for result in result_list: + if result.size(0) > max_x_len: + max_x_len = result.size(0) + batch_result = torch.zeros( + (len(result_list), result_list[0].size(1), max_x_len), + dtype=result_list[0].dtype, + device=result_list[0].device, + ) + for i in range(len(result_list)): + src = result_list[i] + batch_result[i].narrow(1, 0, src.size(0)).copy_(src.permute(1, 0)) + del src + del_all(result_list) + mel_specs = decoder(batch_result) + del batch_result + wavs = self._vocos_decode(mel_specs) + del mel_specs return wavs @staticmethod diff --git a/ChatTTS/model/tokenizer.py b/ChatTTS/model/tokenizer.py index 5fa430152..b32964353 100644 --- a/ChatTTS/model/tokenizer.py +++ b/ChatTTS/model/tokenizer.py @@ -13,7 +13,6 @@ def __init__( tokenizer: BertTokenizerFast = torch.load( tokenizer_path, map_location=device, mmap=True ) - tokenizer.padding_side = "left" self._tokenizer = tokenizer self.len = len(tokenizer) diff --git a/examples/cmd/run.py b/examples/cmd/run.py index 7acb2796a..77af250a9 100644 --- a/examples/cmd/run.py +++ b/examples/cmd/run.py @@ -9,6 +9,8 @@ import argparse from typing import Optional, List +import numpy as np + import ChatTTS from tools.audio import wav_arr_to_mp3_view @@ -25,7 +27,7 @@ def save_mp3_file(wav, index): logger.info(f"Audio saved to {mp3_filename}") -def main(texts: List[str], spk: Optional[str] = None): +def main(texts: List[str], spk: Optional[str] = None, stream=False): logger.info("Text input: %s", str(texts)) chat = ChatTTS.Chat(get_logger("ChatTTS")) @@ -43,15 +45,25 @@ def main(texts: List[str], spk: Optional[str] = None): logger.info("Start inference.") wavs = chat.infer( - texts, + texts, stream, params_infer_code=ChatTTS.Chat.InferCodeParams( spk_emb=spk, ), ) logger.info("Inference completed.") # Save each generated wav file to a local file + if stream: + wavs_list = [] for index, wav in enumerate(wavs): - save_mp3_file(wav, index) + if stream: + for i, w in enumerate(wav): + save_mp3_file(w, (i+1)*1000+index) + wavs_list.append(wav) + else: + save_mp3_file(wav, index) + if stream: + for index, wav in enumerate(np.concatenate(wavs_list, axis=1)): + save_mp3_file(wav, index) logger.info("Audio generation successful.") @@ -59,7 +71,7 @@ def main(texts: List[str], spk: Optional[str] = None): logger.info("Starting ChatTTS commandline demo...") parser = argparse.ArgumentParser( description="ChatTTS Command", - usage='[--spk xxx] "Your text 1." " Your text 2."', + usage='[--spk xxx] [--stream] "Your text 1." " Your text 2."', ) parser.add_argument( "--spk", @@ -68,8 +80,13 @@ def main(texts: List[str], spk: Optional[str] = None): default=None, ) parser.add_argument( - "texts", help="Original text", default="YOUR TEXT HERE", nargs="*" + "--stream", + help="Use stream mode", + action='store_true', + ) + parser.add_argument( + "texts", help="Original text", default=["YOUR TEXT HERE"], nargs=argparse.REMAINDER, ) args = parser.parse_args() - main(args.texts, args.spk) + main(args.texts, args.spk, args.stream) logger.info("ChatTTS process finished.")