feat: Add ChatTTS vLLM Wrapper
ylzz1997 committed Jul 20, 2024
1 parent 51ec0c7 commit 8e6184e
Showing 16 changed files with 5,208 additions and 124 deletions.
297 changes: 173 additions & 124 deletions ChatTTS/core.py
@@ -6,15 +6,18 @@
from json import load
from pathlib import Path
import lzma

+import pathlib
+from ChatTTS.vllm_engine.post_model import Post_model
+from safetensors.torch import save_file, safe_open
import numpy as np
import torch
from vocos import Vocos
from vocos.pretrained import instantiate_class
from huggingface_hub import snapshot_download
import pybase16384 as b14

from .config import Config
+from ChatTTS.vllm_engine.llm import LLM
+from ChatTTS.vllm_engine.sampling_params import SamplingParams
+import yaml
from .model import DVAE, GPT, gen_logits, Tokenizer
from .utils import (
    check_all_assets,
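The new imports above pull this fork's `ChatTTS/vllm_engine` package (its `LLM` wrapper, `SamplingParams`, and `Post_model`) into the core loader, plus `safetensors` for weight export. A minimal usage sketch of the resulting public API, assuming `Chat.load()` and `Chat.infer()` keep their upstream signatures and the vLLM path is picked up automatically once the assets are exported:

```python
# Hedged end-to-end sketch: ChatTTS.Chat, load() and infer() exist upstream;
# the automatic export to asset/vllm_model on first load is this fork's change.
import ChatTTS

chat = ChatTTS.Chat()
chat.load(compile=False)  # first run exports asset/vllm_model/* (see _load below)
wavs = chat.infer(["Hello from the vLLM-backed ChatTTS."])
```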
@@ -167,7 +170,7 @@ def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str:

    @torch.no_grad()
    def _sample_random_speaker(self) -> torch.Tensor:
-        dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
+        dim: int = self.hidden_size
        spk = (
            torch.randn(dim, device=self.std.device, dtype=self.std.dtype)
            .mul_(self.std)
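With `self.gpt` now a vLLM `LLM` wrapper rather than the HF-style `GPT` module, the hidden width can no longer be probed via `gpt.gpt.layers[0]`, so it is cached as `self.hidden_size` during `_load` (see below). The sampling itself is unchanged; a toy sketch with placeholder statistics:

```python
# Toy illustration only: std and mean stand in for the two halves of
# spk_stat.pt that _load() obtains via chunk(2); hidden_size is hypothetical.
import torch

hidden_size = 768
std, mean = torch.ones(hidden_size), torch.zeros(hidden_size)
spk = torch.randn(hidden_size).mul_(std).add_(mean)  # speaker vector ~ N(mean, std^2)
```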
@@ -266,56 +269,64 @@ def _load(
                if "mps" in str(device)
                else device
            )
            .eval()
        )
        assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
        vocos.load_state_dict(torch.load(vocos_ckpt_path, weights_only=True, mmap=True))
        self.vocos = vocos
        self.logger.log(logging.INFO, "vocos loaded.")

-        dvae = (
-            DVAE(
-                decoder_config=asdict(self.config.dvae.decoder),
-                encoder_config=asdict(self.config.dvae.encoder),
-                vq_config=asdict(self.config.dvae.vq),
-                dim=self.config.dvae.decoder.idim,
-                coef=coef,
-            )
-            .to(device)
-            .eval()
-        )
-        coef = str(dvae)
-        assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
-        dvae.load_state_dict(torch.load(dvae_ckpt_path, weights_only=True, mmap=True))
        self.dvae = dvae
        self.logger.log(logging.INFO, "dvae loaded.")

-        gpt = GPT(
-            gpt_config=asdict(self.config.gpt),
-            use_flash_attn=use_flash_attn,
-            device=device,
-            logger=self.logger,
-        ).eval()
-        assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
-        gpt.from_pretrained(gpt_ckpt_path)
-        gpt.prepare(compile=compile and "cuda" in str(device))
-        self.gpt = gpt
-        spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
-        assert os.path.exists(spk_stat_path), f"Missing spk_stat.pt: {spk_stat_path}"
-        spk_stat: torch.Tensor = torch.load(
-            spk_stat_path,
-            weights_only=True,
-            mmap=True,
-            map_location=device,
-        )
-        self.std, self.mean = spk_stat.requires_grad_(False).chunk(2)
-        self.logger.log(logging.INFO, "gpt loaded.")
-
-        decoder = (
-            DVAE(
-                decoder_config=asdict(self.config.decoder),
-                dim=self.config.decoder.idim,
-                coef=coef,
-            )
-            .to(device)
-            .eval()
-        )
+        if gpt_config_path:
+            cfg = OmegaConf.load(gpt_config_path)
+            self.num_vq = 4
+            if not os.path.exists("asset/vllm_model"):
+                gpt = GPT(
+                    **cfg, use_flash_attn=use_flash_attn, device=device, logger=self.logger
+                ).eval()
+                assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
+                gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True))
+                gpt.prepare(compile=compile and "cuda" in str(device))
+                self.gpt = gpt
+                pathlib.Path("asset/vllm_model").mkdir(parents=True, exist_ok=True)
+                self.gpt.gpt.save_pretrained("asset/vllm_model/gpt")
+                self.post_model = Post_model(
+                    cfg.gpt_config.hidden_size,
+                    cfg.num_audio_tokens,
+                    cfg.num_text_tokens,
+                    device = device
+                ).to(device).eval()
+
+                self.post_model.emb_code = self.gpt.emb_code
+                self.post_model.emb_text = self.gpt.emb_text
+                self.post_model.head_text = self.gpt.head_text
+                self.post_model.head_code = self.gpt.head_code
+                save_file(self.post_model.state_dict(), "asset/vllm_model/post_model.safetensors")
+
+            self.num_audio_tokens = cfg.num_audio_tokens
+            spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
+            assert os.path.exists(
+                spk_stat_path
+            ), f"Missing spk_stat.pt: {spk_stat_path}"
+            spk_stat: torch.Tensor = torch.load(
+                spk_stat_path,
+                weights_only=True,
+                mmap=True,
+                map_location=device,
+            )
+            self.std, self.mean = spk_stat.requires_grad_(False).chunk(2)
+            self.logger.log(logging.INFO, "gpt loaded.")
+
+            self.hidden_size = cfg.gpt_config.hidden_size
+            self.gpt = LLM(
+                model="asset/vllm_model/gpt",
+                num_audio_tokens = cfg.num_audio_tokens,
+                num_text_tokens = cfg.num_text_tokens,
+                post_model_path="asset/vllm_model/post_model.safetensors",
+            )
+
+        if decoder_config_path:
+            cfg = OmegaConf.load(decoder_config_path)
+            decoder = DVAE(**cfg, coef=coef).to(device).eval()
+            coef = str(decoder)
+            assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
+            decoder.load_state_dict(
+                torch.load(decoder_ckpt_path, weights_only=True, mmap=True)
+            )
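The rewritten `_load` does a one-time export on first run: the transformer weights are written in HF format to `asset/vllm_model/gpt` via `save_pretrained`, the embedding and head layers are repacked into `post_model.safetensors`, and every later run skips straight to constructing the `LLM` wrapper from those files. A condensed sketch of that control flow, under the signatures shown above:

```python
# Condensed sketch of the export-once-then-serve flow, not a drop-in _load
# replacement; ensure_vllm_assets is a hypothetical helper name.
import os
import pathlib

def ensure_vllm_assets(gpt) -> tuple:
    # Export HF weights and the post model once; return the paths LLM() expects.
    if not os.path.exists("asset/vllm_model"):
        pathlib.Path("asset/vllm_model").mkdir(parents=True, exist_ok=True)
        gpt.gpt.save_pretrained("asset/vllm_model/gpt")  # HF weights for vLLM
        # ...repack emb_code/emb_text/head_* into post_model.safetensors (see diff)
    return "asset/vllm_model/gpt", "asset/vllm_model/post_model.safetensors"
```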
@@ -335,7 +346,7 @@ def _load(
        self.coef = coef

        return self.has_loaded()

    def _infer(
        self,
        text,
@@ -451,6 +462,55 @@ def _decode_to_wavs(
        del mel_specs
        return wavs

+    @staticmethod
+    def _decode_spk_emb(spk_emb: str) -> np.ndarray:
+        return np.frombuffer(
+            lzma.decompress(
+                b14.decode_from_string(spk_emb),
+                format=lzma.FORMAT_RAW,
+                filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
+            ),
+            dtype=np.float16,
+        ).copy()

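`_decode_spk_emb` reverses a base16384-over-raw-LZMA2 encoding of a float16 vector. For reference, a hypothetical encoder that mirrors it (the real one lives upstream in the speaker utilities):

```python
# Hypothetical inverse of _decode_spk_emb, for illustration only; the filter
# chain must match the decoder's exactly or decompression will fail.
import lzma

import numpy as np
import pybase16384 as b14

def encode_spk_emb(emb: np.ndarray) -> str:
    raw = lzma.compress(
        emb.astype(np.float16).tobytes(),
        format=lzma.FORMAT_RAW,
        filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
    )
    return b14.encode_to_string(raw)
```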
+    @torch.no_grad()
+    def _apply_spk_emb(
+        self,
+        emb: torch.Tensor,
+        spk_emb: str,
+        input_ids: torch.Tensor,
+    ):
+        n = (
+            F.normalize(
+                torch.from_numpy(
+                    self._decode_spk_emb(spk_emb),
+                ),
+                p=2.0,
+                dim=0,
+                eps=1e-12,
+            )
+            .to(self.gpt.device_gpt)
+            .unsqueeze_(0)
+            .expand(emb.size(0), -1)
+            .unsqueeze_(1)
+            .expand(emb.shape)
+        )
+        cond = (
+            input_ids.narrow(-1, 0, 1).eq(self.tokenizer.spk_emb_ids).expand(emb.shape)
+        )
+        torch.where(cond, n, emb, out=emb)
+        del cond, n
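`_apply_spk_emb` L2-normalizes the decoded speaker vector, broadcasts it to the embedding tensor's shape, and overwrites only the positions whose first-codebook token is the speaker-embedding placeholder id. A toy of that masked overwrite:

```python
# Toy shapes (batch=1, seq=3, num_vq=1, hidden=4); token id 5 is a
# hypothetical stand-in for tokenizer.spk_emb_ids.
import torch

emb = torch.zeros(1, 3, 4)
n = torch.ones(1, 1, 4).expand_as(emb)       # stand-in normalized speaker vector
input_ids = torch.tensor([[[5], [9], [9]]])  # position 0 holds the placeholder
cond = input_ids.narrow(-1, 0, 1).eq(5).expand(emb.shape)
torch.where(cond, n, emb, out=emb)           # only position 0 gets the vector
```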
+    @dataclass(repr=False, eq=False)
+    class GenerationOutputs:
+        ids: List[torch.Tensor]
+        # attentions: List[Optional[Tuple[torch.FloatTensor, ...]]]
+        hiddens: List[torch.Tensor]
+
+        def destroy(self):
+            del_all(self.ids)
+            # del_all(self.attentions)
+            # del_all(self.hiddens)

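`GenerationOutputs` is a thin container for the per-sample token ids and hidden states returned by the vLLM engine; `destroy()` releases the id tensors through `del_all` from `.utils`. A toy instantiation:

```python
# Toy use of the nested dataclass above; 768 is a hypothetical hidden size.
import torch
from ChatTTS import Chat

out = Chat.GenerationOutputs(
    ids=[torch.zeros(4, dtype=torch.long)],
    hiddens=[torch.zeros(4, 768)],
)
out.destroy()  # drop the id tensors once decoding is finished
```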
    @torch.no_grad()
    def _infer_code(
        self,
@@ -461,15 +521,15 @@ def _infer_code(
        params: InferCodeParams,
    ):

-        gpt = self.gpt
+        gpt: LLM = self.gpt

        if not isinstance(text, list):
            text = [text]

        assert len(text), "text should not be empty"

        if not isinstance(params.temperature, list):
-            temperature = [params.temperature] * gpt.num_vq
+            temperature = [params.temperature] * self.num_vq
        else:
            temperature = params.temperature
@@ -494,54 +554,46 @@ def _infer_code(
        text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text]

        input_ids, attention_mask, text_mask = self.tokenizer.encode(
-            text,
-            self.gpt.num_vq,
-            prompt_str=params.spk_smp,
-            device=gpt.device_gpt,
+            text, self.num_vq, self.device
        )

-        emb = gpt(input_ids, text_mask)
-
-        del text_mask
-
-        if params.spk_emb is not None:
-            self.tokenizer.apply_spk_emb(
-                emb, params.spk_emb, input_ids, self.gpt.device_gpt
-            )
-
-        num_code = int(gpt.emb_code[0].num_embeddings - 1)
+        start_idx = input_ids.shape[-2]
+
+        num_code = self.num_audio_tokens - 1

        logits_warpers, logits_processors = gen_logits(
            num_code=num_code,
            top_P=params.top_P,
            top_K=params.top_K,
            repetition_penalty=params.repetition_penalty,
        )

-        result = gpt.generate(
-            emb,
-            input_ids,
-            temperature=torch.tensor(temperature, device=device),
-            eos_token=num_code,
-            attention_mask=attention_mask,
-            max_new_token=params.max_new_token,
-            min_new_token=params.min_new_token,
-            logits_warpers=logits_warpers,
-            logits_processors=logits_processors,
-            infer_text=False,
-            return_hidden=return_hidden,
-            stream=stream,
-            show_tqdm=params.show_tqdm,
-            ensure_non_empty=params.ensure_non_empty,
-            stream_batch=params.stream_batch,
-            context=self.context,
-        )
-
-        del emb, input_ids
-        del_all(logits_warpers)
-        del_all(logits_processors)
-
-        return result
+        sample_params = SamplingParams(
+            temperature=temperature,
+            max_new_token=params.max_new_token,
+            max_tokens = 8192,
+            min_new_token=params.min_new_token,
+            logits_processors=(logits_warpers, logits_processors),
+            eos_token = num_code,
+            infer_text=False,
+            start_idx=start_idx
+        )
+
+        input_ids = [i.tolist() for i in input_ids]
+
+        result = gpt.generate(
+            None,
+            sample_params,
+            input_ids,
+        )
+
+        token_ids = []
+        hidden_states = []
+        for i in result:
+            token_ids.append(torch.tensor(i.outputs[0].token_ids))
+            hidden_states.append(i.outputs[0].hidden_states.to(torch.float32).to(self.device))
+        return [self.GenerationOutputs(
+            ids=token_ids,
+            hiddens=hidden_states
+        ),]

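The conversion pattern above recurs in `_refine_text` below: sampling knobs move off the old `gpt.generate` keyword list into a `SamplingParams` object, inputs become plain Python token-id lists, and vLLM's outputs are re-wrapped into `GenerationOutputs`. A hedged sketch of driving the wrapper directly (both classes are this fork's `vllm_engine` versions, not upstream vllm):

```python
# Sketch under the signatures shown in this diff; the token counts, prompt ids
# and start_idx are hypothetical placeholders for values _load/encode provide.
from ChatTTS.model import gen_logits
from ChatTTS.vllm_engine.llm import LLM
from ChatTTS.vllm_engine.sampling_params import SamplingParams

llm = LLM(
    model="asset/vllm_model/gpt",
    num_audio_tokens=626,   # placeholder; _load reads cfg.num_audio_tokens
    num_text_tokens=21178,  # placeholder; _load reads cfg.num_text_tokens
    post_model_path="asset/vllm_model/post_model.safetensors",
)
logits_warpers, logits_processors = gen_logits(
    num_code=625, top_P=0.7, top_K=20, repetition_penalty=1.05
)
params = SamplingParams(
    temperature=[0.3] * 4,  # one temperature per codebook (num_vq = 4)
    max_new_token=2048,
    max_tokens=8192,
    min_new_token=0,
    logits_processors=(logits_warpers, logits_processors),
    eos_token=625,
    infer_text=False,
    start_idx=3,            # length of the encoded prompt below
)
results = llm.generate(None, params, [[101, 102, 103]])  # placeholder prompt ids
```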
    @torch.no_grad()
    def _refine_text(
@@ -551,51 +603,48 @@ def _refine_text(
        params: RefineTextParams,
    ):

-        gpt = self.gpt
+        gpt:LLM = self.gpt

        if not isinstance(text, list):
            text = [text]

        text = [f"[Sbreak]{i}[Pbreak]{params.prompt}" for i in text]

        input_ids, attention_mask, text_mask = self.tokenizer.encode(
-            text,
-            self.gpt.num_vq,
-            device=gpt.device_gpt,
+            text, self.num_vq, self.device
        )

+        start_idx = input_ids.shape[-2]
+        # print(start_idx)
        logits_warpers, logits_processors = gen_logits(
            num_code=self.tokenizer.len,
            top_P=params.top_P,
            top_K=params.top_K,
            repetition_penalty=params.repetition_penalty,
        )

-        emb = gpt(input_ids, text_mask)
-
-        del text_mask
-
-        result = next(
-            gpt.generate(
-                emb,
-                input_ids,
-                temperature=torch.tensor([params.temperature], device=device),
-                eos_token=self.tokenizer.eos_token,
-                attention_mask=attention_mask,
-                max_new_token=params.max_new_token,
-                min_new_token=params.min_new_token,
-                logits_warpers=logits_warpers,
-                logits_processors=logits_processors,
-                infer_text=True,
-                stream=False,
-                show_tqdm=params.show_tqdm,
-                ensure_non_empty=params.ensure_non_empty,
-                context=self.context,
-            )
-        )
-
-        del emb, input_ids
-        del_all(logits_warpers)
-        del_all(logits_processors)
-
-        return result
+        sample_params = SamplingParams(
+            temperature=params.temperature,
+            max_new_token=params.max_new_token,
+            max_tokens = 8192,
+            min_new_token=params.min_new_token,
+            logits_processors=(logits_warpers, logits_processors),
+            eos_token = self.tokenizer.eos_token,
+            infer_text=True,
+            start_idx=start_idx
+        )
+        input_ids = [i.tolist() for i in input_ids]
+
+        result = gpt.generate(
+            None,
+            sample_params,
+            input_ids
+        )
+        token_ids = []
+        hidden_states = []
+        for i in result:
+            token_ids.append(torch.tensor(i.outputs[0].token_ids))
+            hidden_states.append(i.outputs[0].hidden_states)
+        return self.GenerationOutputs(
+            ids=token_ids,
+            hiddens=hidden_states
+        )
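`_refine_text` applies the same conversion but with `infer_text=True`, a scalar temperature, and the tokenizer's text EOS token. Continuing the first sketch, refinement stays reachable through the public API, assuming `RefineTextParams` and `infer(..., refine_text_only=True)` keep their upstream shape:

```python
# Hedged usage sketch; refine_text_only and RefineTextParams exist upstream,
# and here they route through the new vLLM-backed _refine_text.
params = ChatTTS.Chat.RefineTextParams(prompt="[oral_2][laugh_0][break_4]")
refined = chat.infer(
    ["your raw sentence goes here"],
    refine_text_only=True,
    params_refine_text=params,
)
```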
Empty file added: ChatTTS/vllm_engine/__init__.py