Update XTTS cloning #3207

Merged · 3 commits · Nov 13, 2023
Changes from 2 commits
9 changes: 6 additions & 3 deletions TTS/cs_api.py
@@ -82,7 +82,6 @@ class CS_API:
},
}


SUPPORTED_LANGUAGES = ["en", "es", "de", "fr", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja"]

def __init__(self, api_token=None, model="XTTS"):
@@ -308,7 +307,11 @@ def tts_to_file(
print(api.list_speakers_as_tts_models())

ts = time.time()
wav, sr = api.tts("It took me quite a long time to develop a voice.", language="en", speaker_name=api.speakers[0].name)
wav, sr = api.tts(
"It took me quite a long time to develop a voice.", language="en", speaker_name=api.speakers[0].name
)
print(f" [i] XTTS took {time.time() - ts:.2f}s")

filepath = api.tts_to_file(text="Hello world!", speaker_name=api.speakers[0].name, language="en", file_path="output.wav")
filepath = api.tts_to_file(
text="Hello world!", speaker_name=api.speakers[0].name, language="en", file_path="output.wav"
)
10 changes: 8 additions & 2 deletions TTS/tts/configs/xtts_config.py
@@ -43,7 +43,12 @@ class XttsConfig(BaseTTSConfig):
Defaults to `16`.

gpt_cond_len (int):
Secs audio to be used as conditioning for the autoregressive model. Defaults to `3`.
Secs audio to be used as conditioning for the autoregressive model. Defaults to `12`.

gpt_cond_chunk_len (int):
Audio chunk size in secs. Audio is split into chunks and latents are extracted for each chunk. Then the
latents are averaged. Chunking improves the stability. It must be <= gpt_cond_len.
If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to `4`.

max_ref_len (int):
Maximum number of seconds of audio to be used as conditioning for the decoder. Defaults to `10`.
@@ -95,6 +100,7 @@ class XttsConfig(BaseTTSConfig):
num_gpt_outputs: int = 1

# cloning
gpt_cond_len: int = 3
gpt_cond_len: int = 12
gpt_cond_chunk_len: int = 4
max_ref_len: int = 10
sound_norm_refs: bool = False
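
For illustration, a minimal sketch of setting the new cloning fields on the config. The import path and field names come from this diff; the values simply restate the new defaults, and everything else is illustrative.

```python
# Minimal sketch: the new chunked-conditioning fields with their new defaults.
from TTS.tts.configs.xtts_config import XttsConfig

config = XttsConfig()
config.gpt_cond_len = 12        # seconds of reference audio for GPT conditioning
config.gpt_cond_chunk_len = 4   # per-chunk length; must be <= gpt_cond_len
config.max_ref_len = 10         # max seconds of reference audio for the decoder
config.sound_norm_refs = False  # keep reference loudness as-is
```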
23 changes: 17 additions & 6 deletions TTS/tts/layers/tortoise/dpm_solver.py
@@ -562,15 +562,21 @@ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type
if order == 3:
K = steps // 3 + 1
if steps % 3 == 0:
orders = [3,] * (
orders = [
3,
] * (
K - 2
) + [2, 1]
elif steps % 3 == 1:
orders = [3,] * (
orders = [
3,
] * (
K - 1
) + [1]
else:
orders = [3,] * (
orders = [
3,
] * (
K - 1
) + [2]
elif order == 2:
@@ -581,7 +587,9 @@ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type
] * K
else:
K = steps // 2 + 1
orders = [2,] * (
orders = [
2,
] * (
K - 1
) + [1]
elif order == 1:
@@ -1440,7 +1448,10 @@ def sample(
model_prev_list[-1] = self.model_fn(x, t)
elif method in ["singlestep", "singlestep_fixed"]:
if method == "singlestep":
(timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver(
(
timesteps_outer,
orders,
) = self.get_orders_and_timesteps_for_singlestep_solver(
steps=steps,
order=order,
skip_type=skip_type,
@@ -1548,4 +1559,4 @@ def expand_dims(v, dims):
Returns:
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
"""
return v[(...,) + (None,) * (dims - 1)]
return v[(...,) + (None,) * (dims - 1)]
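
The reformatted branches above encode a simple splitting rule: `steps` model evaluations are divided into singlestep updates of order 3, 2, and 1 whose orders sum exactly to `steps`. Below is a sketch of the `order == 3` case plus a shape check for `expand_dims`; the helper name `orders_for_order3` is illustrative, not from the code.

```python
import torch

# Sketch of the order-3 splitting rule from
# get_orders_and_timesteps_for_singlestep_solver (helper name is made up).
def orders_for_order3(steps):
    K = steps // 3 + 1
    if steps % 3 == 0:
        return [3] * (K - 2) + [2, 1]
    elif steps % 3 == 1:
        return [3] * (K - 1) + [1]
    return [3] * (K - 1) + [2]

assert orders_for_order3(9) == [3, 3, 2, 1]  # orders sum to steps
assert orders_for_order3(8) == [3, 3, 2]

# expand_dims(v, dims) appends dims-1 trailing singleton axes:
v = torch.randn(8)
assert v[(...,) + (None,) * 3].shape == (8, 1, 1, 1)  # same as expand_dims(v, 4)
```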
11 changes: 7 additions & 4 deletions TTS/tts/layers/xtts/tokenizer.py
@@ -1,14 +1,14 @@
import json
import os
import re
from functools import cached_property

import pypinyin
import torch
from hangul_romanize import Transliter
from hangul_romanize.rule import academic
from num2words import num2words
from tokenizers import Tokenizer
from functools import cached_property

from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words

@@ -560,19 +560,22 @@ def __init__(self, vocab_file=None):
@cached_property
def katsu(self):
import cutlet

return cutlet.Cutlet()

def check_input_length(self, txt, lang):
limit = self.char_limits.get(lang, 250)
if len(txt) > limit:
print(f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio.")
print(
f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio."
)

def preprocess_text(self, txt, lang):
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "zh-cn"}:
txt = multilingual_cleaners(txt, lang)
if lang in {"zh", "zh-cn"}:
txt = chinese_transliterate(txt)
elif lang == "ja":
elif lang == "ja":
txt = japanese_cleaners(txt, self.katsu)
elif lang == "ko":
txt = korean_cleaners(txt)
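
The relocated `cached_property` import supports a lazy-import pattern: `cutlet`, a heavyweight Japanese romanizer, is only imported the first time `katsu` is accessed, and the result is cached afterwards. A minimal sketch of the same pattern, with a hypothetical class name:

```python
from functools import cached_property

class Romanizer:  # hypothetical stand-in for the tokenizer class
    @cached_property
    def katsu(self):
        import cutlet  # deferred: only paid on first Japanese input
        return cutlet.Cutlet()  # cached on the instance after first access
```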
1 change: 1 addition & 0 deletions TTS/tts/layers/xtts/trainer/dataset.py
@@ -5,6 +5,7 @@
import torch
import torch.nn.functional as F
import torch.utils.data

from TTS.tts.models.xtts import load_audio

torch.set_num_threads(1)
101 changes: 70 additions & 31 deletions TTS/tts/models/xtts.py
@@ -255,39 +255,57 @@ def device(self):
return next(self.parameters()).device

@torch.inference_mode()
def get_gpt_cond_latents(self, audio, sr, length: int = 3):
def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
Review thread on this line:

Contributor: I think the default value here and in the config should be the same.

Member (author): Why?

Contributor: I think that would be better: users may call this function on its own and get very different results. Keeping both defaults equal avoids issues like that.

Member (author): The config in the code, or the released model's config?

"""Compute the conditioning latents for the GPT model from the given audio.

Args:
audio (tensor): audio tensor.
sr (int): Sample rate of the audio.
length (int): Length of the audio in seconds. Defaults to 3.
length (int): Length of the audio in seconds. If < 0, use the whole audio. Defaults to 30.
chunk_length (int): Length of the audio chunks in seconds. When `length == chunk_length`, the whole audio
is used without chunking. It must be <= `length`. Defaults to 6.
"""
if sr != 22050:
audio = torchaudio.functional.resample(audio, sr, 22050)
audio = audio[:, : 22050 * length]
if length > 0:
audio = audio[:, : 22050 * length]
if self.args.gpt_use_perceiver_resampler:
n_fft = 2048
hop_length = 256
win_length = 1024
style_embs = []
for i in range(0, audio.shape[1], 22050 * chunk_length):
audio_chunk = audio[:, i : i + 22050 * chunk_length]
mel_chunk = wav_to_mel_cloning(
audio_chunk,
mel_norms=self.mel_stats.cpu(),
n_fft=2048,
hop_length=256,
win_length=1024,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80,
)
style_emb = self.gpt.get_style_emb(mel_chunk.to(self.device), None)
style_embs.append(style_emb)

# mean style embedding
cond_latent = torch.stack(style_embs).mean(dim=0)
else:
n_fft = 4096
hop_length = 1024
win_length = 4096
mel = wav_to_mel_cloning(
audio,
mel_norms=self.mel_stats.cpu(),
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80,
)
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
mel = wav_to_mel_cloning(
audio,
mel_norms=self.mel_stats.cpu(),
n_fft=4096,
hop_length=1024,
win_length=4096,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80,
)
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
return cond_latent.transpose(1, 2)

@torch.inference_mode()
@@ -323,12 +341,24 @@ def get_speaker_embedding(self, audio, sr):
def get_conditioning_latents(
self,
audio_path,
max_ref_length=30,
gpt_cond_len=6,
max_ref_length=10,
gpt_cond_chunk_len=6,
librosa_trim_db=None,
sound_norm_refs=False,
load_sr=24000,
load_sr=22050,
):
"""Get the conditioning latents for the GPT model from the given audio.

Args:
audio_path (str or List[str]): Path to reference audio file(s).
max_ref_length (int): Maximum length of each reference audio in seconds. Defaults to 30.
gpt_cond_len (int): Length of the audio used for gpt latents. Defaults to 6.
gpt_cond_chunk_len (int): Chunk length used for gpt latents. It must be <= gpt_cond_len. Defaults to 6.
librosa_trim_db (int, optional): Trim the audio using this value. If None, not trimming. Defaults to None.
sound_norm_refs (bool, optional): Whether to normalize the audio. Defaults to False.
load_sr (int, optional): Sample rate to load the audio. Defaults to 22050.
"""
# deal with multiple references
if not isinstance(audio_path, list):
audio_paths = [audio_path]
@@ -349,14 +379,17 @@ def get_conditioning_latents(
if librosa_trim_db is not None:
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

# compute latents for the decoder
speaker_embedding = self.get_speaker_embedding(audio, load_sr)
speaker_embeddings.append(speaker_embedding)

audios.append(audio)

# use a merge of all references for gpt cond latents
# merge all the audios and compute the latents for the gpt
full_audio = torch.cat(audios, dim=-1)
gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len) # [1, 1024, T]
gpt_cond_latents = self.get_gpt_cond_latents(
full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len
) # [1, 1024, T]

if speaker_embeddings:
speaker_embedding = torch.stack(speaker_embeddings)
@@ -397,6 +430,7 @@ def inference_with_config(self, text, config, ref_audio_path, language, **kwargs
"top_k": config.top_k,
"top_p": config.top_p,
"gpt_cond_len": config.gpt_cond_len,
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
"max_ref_len": config.max_ref_len,
"sound_norm_refs": config.sound_norm_refs,
}
@@ -417,7 +451,8 @@ def full_inference(
top_p=0.85,
do_sample=True,
# Cloning
gpt_cond_len=6,
gpt_cond_len=30,
gpt_cond_chunk_len=6,
max_ref_len=10,
sound_norm_refs=False,
**hf_generate_kwargs,
@@ -448,7 +483,10 @@
(aka boring) outputs. Defaults to 0.8.

gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.
else the first `gpt_cond_len` secs is used. Defaults to 30 seconds.

gpt_cond_chunk_len: (int) Chunk length used for cloning. It must be <= `gpt_cond_len`.
If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to 6 seconds.

hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
@@ -461,6 +499,7 @@
(gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents(
audio_path=ref_audio_path,
gpt_cond_len=gpt_cond_len,
gpt_cond_chunk_len=gpt_cond_chunk_len,
max_ref_length=max_ref_len,
sound_norm_refs=sound_norm_refs,
)
@@ -566,7 +605,7 @@ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
if overlap_len > len(wav_chunk):
# wav_chunk is smaller than overlap_len, pass on last wav_gen
if wav_gen_prev is not None:
wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len):]
wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :]
else:
# not expecting to hit here, as the problem happens on the last chunk
wav_chunk = wav_gen[-overlap_len:]
@@ -576,7 +615,7 @@ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
wav_chunk[:overlap_len] += crossfade_wav

wav_overlap = wav_gen[-overlap_len:]
wav_gen_prev = wav_gen
return wav_chunk, wav_gen_prev, wav_overlap
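
The core change in this file is the chunked conditioning: when the perceiver resampler is used, the reference audio is split into `chunk_length`-second windows, a style embedding is computed per window, and the embeddings are averaged. A standalone sketch of that logic, where `embed` stands in for the model's `gpt.get_style_emb` call and 22050 Hz mirrors the sample rate in the diff:

```python
import torch

def chunked_cond_latent(audio, chunk_length, embed, sr=22050):
    """Sketch of the chunk-and-average trick from get_gpt_cond_latents.

    audio: [1, T] waveform at `sr`; embed: callable mapping a chunk to a
    style embedding (stand-in for gpt.get_style_emb on the chunk's mel).
    """
    chunks = [
        audio[:, i : i + sr * chunk_length]
        for i in range(0, audio.shape[1], sr * chunk_length)
    ]
    embs = [embed(c) for c in chunks]
    # Averaging the per-chunk embeddings stabilizes the speaker latent.
    return torch.stack(embs).mean(dim=0)
```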
4 changes: 3 additions & 1 deletion tests/xtts_tests/test_xtts_gpt_train.py
@@ -60,7 +60,9 @@


# Training sentences generations
SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"] # speaker reference to be used in training test sentences
SPEAKER_REFERENCE = [
"tests/data/ljspeech/wavs/LJ001-0002.wav"
] # speaker reference to be used in training test sentences
LANGUAGE = config_dataset.language


4 changes: 3 additions & 1 deletion tests/xtts_tests/test_xtts_v2-0_gpt_train.py
@@ -58,7 +58,9 @@


# Training sentences generations
SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"] # speaker reference to be used in training test sentences
SPEAKER_REFERENCE = [
"tests/data/ljspeech/wavs/LJ001-0002.wav"
] # speaker reference to be used in training test sentences
LANGUAGE = config_dataset.language

