From a16360af859732eedcb6c0faaa1a57081c33c9be Mon Sep 17 00:00:00 2001
From: Eren Gölge
Date: Mon, 13 Nov 2023 13:00:08 +0100
Subject: [PATCH 1/3] Implement chunking gpt_cond

---
 TTS/tts/configs/xtts_config.py |  10 +++-
 TTS/tts/models/xtts.py         | 101 +++++++++++++++++++++++----------
 2 files changed, 78 insertions(+), 33 deletions(-)

diff --git a/TTS/tts/configs/xtts_config.py b/TTS/tts/configs/xtts_config.py
index 2d3edaf43a..e8ab07da70 100644
--- a/TTS/tts/configs/xtts_config.py
+++ b/TTS/tts/configs/xtts_config.py
@@ -43,7 +43,12 @@ class XttsConfig(BaseTTSConfig):
             Defaults to `16`.
 
         gpt_cond_len (int):
-            Secs audio to be used as conditioning for the autoregressive model. Defaults to `3`.
+            Secs audio to be used as conditioning for the autoregressive model. Defaults to `12`.
+
+        gpt_cond_chunk_len (int):
+            Audio chunk size in secs. Audio is split into chunks and latents are extracted for each chunk. Then the
+            latents are averaged. Chunking improves stability. It must be <= gpt_cond_len.
+            If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to `4`.
 
         max_ref_len (int):
             Maximum number of seconds of audio to be used as conditioning for the decoder. Defaults to `10`.
@@ -95,6 +100,7 @@ class XttsConfig(BaseTTSConfig):
     num_gpt_outputs: int = 1
 
     # cloning
-    gpt_cond_len: int = 3
+    gpt_cond_len: int = 12
+    gpt_cond_chunk_len: int = 4
     max_ref_len: int = 10
     sound_norm_refs: bool = False
 
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index f41bcfb944..0f79ad6912 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -255,39 +255,57 @@ def device(self):
         return next(self.parameters()).device
 
     @torch.inference_mode()
-    def get_gpt_cond_latents(self, audio, sr, length: int = 3):
+    def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
         """Compute the conditioning latents for the GPT model from the given audio.
 
         Args:
            audio (tensor): audio tensor.
            sr (int): Sample rate of the audio.
-           length (int): Length of the audio in seconds. Defaults to 3.
+           length (int): Length of the audio in seconds. If < 0, use the whole audio. Defaults to 30.
+           chunk_length (int): Length of the audio chunks in seconds. When `length == chunk_length`, the whole audio
+               is used without chunking. It must be <= `length`. Defaults to 6.
""" if sr != 22050: audio = torchaudio.functional.resample(audio, sr, 22050) - audio = audio[:, : 22050 * length] + if length > 0: + audio = audio[:, : 22050 * length] if self.args.gpt_use_perceiver_resampler: - n_fft = 2048 - hop_length = 256 - win_length = 1024 + style_embs = [] + for i in range(0, audio.shape[1], 22050 * chunk_length): + audio_chunk = audio[:, i : i + 22050 * chunk_length] + mel_chunk = wav_to_mel_cloning( + audio_chunk, + mel_norms=self.mel_stats.cpu(), + n_fft=2048, + hop_length=256, + win_length=1024, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, + ) + style_emb = self.gpt.get_style_emb(mel_chunk.to(self.device), None) + style_embs.append(style_emb) + + # mean style embedding + cond_latent = torch.stack(style_embs).mean(dim=0) else: - n_fft = 4096 - hop_length = 1024 - win_length = 4096 - mel = wav_to_mel_cloning( - audio, - mel_norms=self.mel_stats.cpu(), - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - power=2, - normalized=False, - sample_rate=22050, - f_min=0, - f_max=8000, - n_mels=80, - ) - cond_latent = self.gpt.get_style_emb(mel.to(self.device)) + mel = wav_to_mel_cloning( + audio, + mel_norms=self.mel_stats.cpu(), + n_fft=4096, + hop_length=1024, + win_length=4096, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, + ) + cond_latent = self.gpt.get_style_emb(mel.to(self.device)) return cond_latent.transpose(1, 2) @torch.inference_mode() @@ -323,12 +341,24 @@ def get_speaker_embedding(self, audio, sr): def get_conditioning_latents( self, audio_path, + max_ref_length=30, gpt_cond_len=6, - max_ref_length=10, + gpt_cond_chunk_len=6, librosa_trim_db=None, sound_norm_refs=False, - load_sr=24000, + load_sr=22050, ): + """Get the conditioning latents for the GPT model from the given audio. + + Args: + audio_path (str or List[str]): Path to reference audio file(s). + max_ref_length (int): Maximum length of each reference audio in seconds. Defaults to 30. + gpt_cond_len (int): Length of the audio used for gpt latents. Defaults to 6. + gpt_cond_chunk_len (int): Chunk length used for gpt latents. It must be <= gpt_conf_len. Defaults to 6. + librosa_trim_db (int, optional): Trim the audio using this value. If None, not trimming. Defaults to None. + sound_norm_refs (bool, optional): Whether to normalize the audio. Defaults to False. + load_sr (int, optional): Sample rate to load the audio. Defaults to 24000. 
+ """ # deal with multiples references if not isinstance(audio_path, list): audio_paths = [audio_path] @@ -349,14 +379,17 @@ def get_conditioning_latents( if librosa_trim_db is not None: audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0] + # compute latents for the decoder speaker_embedding = self.get_speaker_embedding(audio, load_sr) speaker_embeddings.append(speaker_embedding) audios.append(audio) - # use a merge of all references for gpt cond latents + # merge all the audios and compute the latents for the gpt full_audio = torch.cat(audios, dim=-1) - gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len) # [1, 1024, T] + gpt_cond_latents = self.get_gpt_cond_latents( + full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len + ) # [1, 1024, T] if speaker_embeddings: speaker_embedding = torch.stack(speaker_embeddings) @@ -397,6 +430,7 @@ def inference_with_config(self, text, config, ref_audio_path, language, **kwargs "top_k": config.top_k, "top_p": config.top_p, "gpt_cond_len": config.gpt_cond_len, + "gpt_cond_chunk_len": config.gpt_cond_chunk_len, "max_ref_len": config.max_ref_len, "sound_norm_refs": config.sound_norm_refs, } @@ -417,7 +451,8 @@ def full_inference( top_p=0.85, do_sample=True, # Cloning - gpt_cond_len=6, + gpt_cond_len=30, + gpt_cond_chunk_len=6, max_ref_len=10, sound_norm_refs=False, **hf_generate_kwargs, @@ -448,7 +483,10 @@ def full_inference( (aka boring) outputs. Defaults to 0.8. gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used - else the first `gpt_cond_len` secs is used. Defaults to 6 seconds. + else the first `gpt_cond_len` secs is used. Defaults to 30 seconds. + + gpt_cond_chunk_len: (int) Chunk length used for cloning. It must be <= `gpt_cond_len`. + If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to 6 seconds. hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive transformer. Extra keyword args fed to this function get forwarded directly to that API. 
@@ -461,6 +499,7 @@
         (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents(
             audio_path=ref_audio_path,
             gpt_cond_len=gpt_cond_len,
+            gpt_cond_chunk_len=gpt_cond_chunk_len,
             max_ref_length=max_ref_len,
             sound_norm_refs=sound_norm_refs,
         )
@@ -566,7 +605,7 @@ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
         if overlap_len > len(wav_chunk):
             # wav_chunk is smaller than overlap_len, pass on last wav_gen
             if wav_gen_prev is not None:
-                wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len):]
+                wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :]
             else:
                 # not expecting will hit here as problem happens on last chunk
                 wav_chunk = wav_gen[-overlap_len:]
@@ -576,7 +615,7 @@ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
             crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
             wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
             wav_chunk[:overlap_len] += crossfade_wav
-
+
             wav_overlap = wav_gen[-overlap_len:]
             wav_gen_prev = wav_gen
         return wav_chunk, wav_gen_prev, wav_overlap

From b2682d39c5dd584bb30f6650e3a5b18e27cccf5b Mon Sep 17 00:00:00 2001
From: Eren Gölge
Date: Mon, 13 Nov 2023 13:01:01 +0100
Subject: [PATCH 2/3] Make style

---
 TTS/cs_api.py                                |  9 +++++---
 TTS/tts/layers/tortoise/dpm_solver.py        | 23 +++++++++++++++-----
 TTS/tts/layers/xtts/tokenizer.py             | 11 ++++++----
 TTS/tts/layers/xtts/trainer/dataset.py       |  1 +
 tests/xtts_tests/test_xtts_gpt_train.py      |  4 +++-
 tests/xtts_tests/test_xtts_v2-0_gpt_train.py |  4 +++-
 6 files changed, 37 insertions(+), 15 deletions(-)

diff --git a/TTS/cs_api.py b/TTS/cs_api.py
index c45f9d08d5..9dc6c30dd4 100644
--- a/TTS/cs_api.py
+++ b/TTS/cs_api.py
@@ -82,7 +82,6 @@ class CS_API:
         },
     }
 
-
     SUPPORTED_LANGUAGES = ["en", "es", "de", "fr", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja"]
 
     def __init__(self, api_token=None, model="XTTS"):
@@ -308,7 +307,11 @@ def tts_to_file(
     print(api.list_speakers_as_tts_models())
 
     ts = time.time()
-    wav, sr = api.tts("It took me quite a long time to develop a voice.", language="en", speaker_name=api.speakers[0].name)
+    wav, sr = api.tts(
+        "It took me quite a long time to develop a voice.", language="en", speaker_name=api.speakers[0].name
+    )
     print(f" [i] XTTS took {time.time() - ts:.2f}s")
 
-    filepath = api.tts_to_file(text="Hello world!", speaker_name=api.speakers[0].name, language="en", file_path="output.wav")
+    filepath = api.tts_to_file(
+        text="Hello world!", speaker_name=api.speakers[0].name, language="en", file_path="output.wav"
+    )
diff --git a/TTS/tts/layers/tortoise/dpm_solver.py b/TTS/tts/layers/tortoise/dpm_solver.py
index 2166eebb3c..c70888df42 100644
--- a/TTS/tts/layers/tortoise/dpm_solver.py
+++ b/TTS/tts/layers/tortoise/dpm_solver.py
@@ -562,15 +562,21 @@ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type
         if order == 3:
             K = steps // 3 + 1
             if steps % 3 == 0:
-                orders = [3,] * (
+                orders = [
+                    3,
+                ] * (
                     K - 2
                 ) + [2, 1]
             elif steps % 3 == 1:
-                orders = [3,] * (
+                orders = [
+                    3,
+                ] * (
                     K - 1
                 ) + [1]
             else:
-                orders = [3,] * (
+                orders = [
+                    3,
+                ] * (
                     K - 1
                 ) + [2]
         elif order == 2:
@@ -581,7 +587,9 @@
             if steps % 2 == 0:
                 K = steps // 2
                 orders = [
                     2,
                 ] * K
             else:
                 K = steps // 2 + 1
-                orders = [2,] * (
+                orders = [
+                    2,
+                ] * (
                     K - 1
                 ) + [1]
         elif order == 1:
@@ -1440,7 +1448,10 @@ def sample(
                     model_prev_list[-1] = self.model_fn(x, t)
         elif method in ["singlestep", "singlestep_fixed"]:
"singlestep_fixed"]: if method == "singlestep": - (timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver( + ( + timesteps_outer, + orders, + ) = self.get_orders_and_timesteps_for_singlestep_solver( steps=steps, order=order, skip_type=skip_type, @@ -1548,4 +1559,4 @@ def expand_dims(v, dims): Returns: a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. """ - return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file + return v[(...,) + (None,) * (dims - 1)] diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index edb0904277..211d0a93d9 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -1,6 +1,7 @@ import json import os import re +from functools import cached_property import pypinyin import torch @@ -8,7 +9,6 @@ from hangul_romanize.rule import academic from num2words import num2words from tokenizers import Tokenizer -from functools import cached_property from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words @@ -560,19 +560,22 @@ def __init__(self, vocab_file=None): @cached_property def katsu(self): import cutlet + return cutlet.Cutlet() - + def check_input_length(self, txt, lang): limit = self.char_limits.get(lang, 250) if len(txt) > limit: - print(f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio.") + print( + f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio." + ) def preprocess_text(self, txt, lang): if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "zh-cn"}: txt = multilingual_cleaners(txt, lang) if lang in {"zh", "zh-cn"}: txt = chinese_transliterate(txt) - elif lang == "ja": + elif lang == "ja": txt = japanese_cleaners(txt, self.katsu) elif lang == "ko": txt = korean_cleaners(txt) diff --git a/TTS/tts/layers/xtts/trainer/dataset.py b/TTS/tts/layers/xtts/trainer/dataset.py index 8cb90ad0f8..2f958cb5a5 100644 --- a/TTS/tts/layers/xtts/trainer/dataset.py +++ b/TTS/tts/layers/xtts/trainer/dataset.py @@ -5,6 +5,7 @@ import torch import torch.nn.functional as F import torch.utils.data + from TTS.tts.models.xtts import load_audio torch.set_num_threads(1) diff --git a/tests/xtts_tests/test_xtts_gpt_train.py b/tests/xtts_tests/test_xtts_gpt_train.py index 12c547d684..b8b9a4e388 100644 --- a/tests/xtts_tests/test_xtts_gpt_train.py +++ b/tests/xtts_tests/test_xtts_gpt_train.py @@ -60,7 +60,9 @@ # Training sentences generations -SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"] # speaker reference to be used in training test sentences +SPEAKER_REFERENCE = [ + "tests/data/ljspeech/wavs/LJ001-0002.wav" +] # speaker reference to be used in training test sentences LANGUAGE = config_dataset.language diff --git a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py index b19b7210d8..6663433c12 100644 --- a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py +++ b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py @@ -58,7 +58,9 @@ # Training sentences generations -SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"] # speaker reference to be used in training test sentences +SPEAKER_REFERENCE = [ + "tests/data/ljspeech/wavs/LJ001-0002.wav" +] # speaker reference to be used in training test sentences LANGUAGE = config_dataset.language From 92fa988aecc2937ac11927e7f0758bc94ee79ded Mon Sep 17 00:00:00 2001 From: Eren 
Date: Mon, 13 Nov 2023 13:44:06 +0100
Subject: [PATCH 3/3] Fixup

---
 TTS/tts/models/xtts.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 0f79ad6912..b277c3ac72 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -369,11 +369,8 @@ def get_conditioning_latents(
         audios = []
         speaker_embedding = None
         for file_path in audio_paths:
-            # load the audio in 24khz to avoid issued with multiple sr references
             audio = load_audio(file_path, load_sr)
             audio = audio[:, : load_sr * max_ref_length].to(self.device)
-            if audio.shape[0] > 1:
-                audio = audio.mean(0, keepdim=True)
             if sound_norm_refs:
                 audio = (audio / torch.abs(audio).max()) * 0.75
             if librosa_trim_db is not None:
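
Note on PATCH 1/3: the core idea is to split the reference audio into fixed-size
chunks, extract a style embedding per chunk, and average the embeddings, rather
than embedding one long clip in a single pass. A minimal, self-contained Python
sketch of that scheme follows. It is illustrative only: embed_chunk is a
hypothetical stand-in for the model's wav_to_mel_cloning + gpt.get_style_emb
step, and it is assumed to return a fixed-size embedding for any chunk length,
as the perceiver-resampler path does.

    import torch

    def mean_chunked_latent(audio, sr, embed_chunk, chunk_secs=4):
        # audio: [1, T] waveform; embed_chunk: callable returning a fixed-size
        # embedding for a [1, <=sr*chunk_secs] chunk (the last chunk may be short).
        step = sr * chunk_secs
        embs = [embed_chunk(audio[:, i : i + step]) for i in range(0, audio.shape[1], step)]
        # Averaging the per-chunk embeddings is what stabilizes cloning on long references.
        return torch.stack(embs).mean(dim=0)

    # Toy embedder for demonstration: mean-pool each chunk into an 8-dim vector.
    def toy_embed(chunk):
        return chunk.mean(dim=1, keepdim=True).expand(1, 8)

    latent = mean_chunked_latent(torch.randn(1, 22050 * 12), 22050, toy_embed)
    print(latent.shape)  # torch.Size([1, 8])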
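
And a usage sketch for the new knobs introduced by this series. The checkpoint
and reference paths are placeholders, and the loading flow shown is the usual
XTTS one from this repository; only gpt_cond_len and gpt_cond_chunk_len are new
here:

    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts

    config = XttsConfig()
    config.load_json("/path/to/xtts/config.json")  # placeholder path

    model = Xtts.init_from_config(config)
    model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)  # placeholder path

    out = model.full_inference(
        "It took me quite a long time to develop a voice.",
        ref_audio_path="reference.wav",  # placeholder reference clip
        language="en",
        gpt_cond_len=12,  # seconds of reference audio used for the GPT latents
        gpt_cond_chunk_len=4,  # chunk size; must be <= gpt_cond_len (equal means no chunking)
    )
    # out["wav"] holds the generated waveform.

Setting gpt_cond_chunk_len equal to gpt_cond_len reproduces the old single-pass
behavior, which makes the chunked path easy to A/B against the previous default.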