correct model loading, customize model folder, and generation params
daswer123 committed Jan 2, 2024
1 parent 5c56c21 commit e1a458e
Showing 6 changed files with 157 additions and 38 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -7,3 +7,4 @@ models
xtts_api_server/models
*.pyc
xtts_api_server/RealtimeTTS/engines/coqui_engine_old.py
+xtts_models
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "xtts-api-server"
version = "0.7.6"
version = "0.8.0"
authors = [
{ name="daswer123", email="[email protected]" },
]
6 changes: 3 additions & 3 deletions xtts_api_server/__main__.py
@@ -6,9 +6,10 @@
parser.add_argument("-hs", "--host", default="localhost", help="Host to bind")
parser.add_argument("-p", "--port", default=8020, type=int, help="Port to bind")
parser.add_argument("-d", "--device", default="cuda", type=str, help="Device that will be used, you can choose cpu or cuda")
parser.add_argument("-sf", "--speaker_folder", default="speakers/", type=str, help="The folder where you get the samples for tts")
parser.add_argument("-sf", "--speaker-folder", default="speakers/", type=str, help="The folder where you get the samples for tts")
parser.add_argument("-o", "--output", default="output/", type=str, help="Output folder")
parser.add_argument("-t", "--tunnel", default="", type=str, help="URL of tunnel used (e.g: ngrok, localtunnel)")
parser.add_argument("-mf", "--model-folder", default="xtts_models/", type=str, help="The place where models for XTTS will be stored, finetuned models should be stored in this folder.")
parser.add_argument("-ms", "--model-source", default="local", choices=["api","apiManual", "local"],
help="Define the model source: 'api' for latest version from repository, apiManual for 2.0.2 model and api inference or 'local' for using local inference and model v2.0.2.")
parser.add_argument("-v", "--version", default="v2.0.2", type=str, help="You can specify which version of xtts to use or specify your own model, just upload model folder in models folder ,This version will be used everywhere in local and apiManual.")
@@ -28,6 +29,7 @@
os.environ['DEVICE'] = args.device # Set environment variable for device.
os.environ['OUTPUT'] = args.output # Set environment variable for output folder.
os.environ['SPEAKER'] = args.speaker_folder # Set environment variable for speaker folder.
+os.environ['MODEL'] = args.model_folder # Set environment variable for model folder.
os.environ['BASE_HOST'] = host_ip # Set environment variable for base host.
os.environ['BASE_PORT'] = str(args.port) # Set environment variable for base port.
os.environ['BASE_URL'] = "http://" + host_ip + ":" + str(args.port) # Set environment variable for base URL.
@@ -41,8 +43,6 @@
os.environ["STREAM_MODE_IMPROVE"] = str(args.streaming_mode_improve).lower() # Enable improved Streaming mode
os.environ["STREAM_PLAY_SYNC"] = str(args.stream_play_sync).lower() # Enable Streaming mode



from xtts_api_server.server import app

uvicorn.run(app, host=host_ip, port=args.port)
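For context, a minimal launch sketch using the new flag; this is an assumption-laden example, not part of the commit (it presumes the package is installed and importable as xtts_api_server, and the folder name is illustrative):

import subprocess

# Start the server; --model-folder ends up in the MODEL environment
# variable that server.py reads at import time (see the diff above).
subprocess.run([
    "python", "-m", "xtts_api_server",
    "--model-folder", "xtts_models/",  # new flag in this commit
    "--model-source", "local",
    "-v", "v2.0.2",
])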
5 changes: 3 additions & 2 deletions xtts_api_server/modeldownloader.py
@@ -119,14 +119,15 @@ def check_stream2sentence_version():

def download_model(this_dir,model_version):
# Define paths
-    base_path = this_dir / 'models'
+    base_path = this_dir
model_path = base_path / f'{model_version}'

# Define files and their corresponding URLs
files_to_download = {
"config.json": f"https://huggingface.co/coqui/XTTS-v2/raw/{model_version}/config.json",
"model.pth": f"https://huggingface.co/coqui/XTTS-v2/resolve/{model_version}/model.pth?download=true",
"vocab.json": f"https://huggingface.co/coqui/XTTS-v2/raw/{model_version}/vocab.json"
"vocab.json": f"https://huggingface.co/coqui/XTTS-v2/raw/{model_version}/vocab.json",
"speakers_xtts.pth": "https://huggingface.co/coqui/XTTS-v2/resolve/main/speakers_xtts.pth?download=true"
}

# Check and create directories
53 changes: 44 additions & 9 deletions xtts_api_server/server.py
@@ -14,14 +14,15 @@
from argparse import ArgumentParser
from pathlib import Path

-from xtts_api_server.tts_funcs import TTSWrapper,supported_languages
+from xtts_api_server.tts_funcs import TTSWrapper,supported_languages,InvalidSettingsError
from xtts_api_server.RealtimeTTS import TextToAudioStream, CoquiEngine
from xtts_api_server.modeldownloader import check_stream2sentence_version,install_deepspeed_based_on_python_version

# Default folders; they can be changed via the API
DEVICE = os.getenv('DEVICE',"cuda")
OUTPUT_FOLDER = os.getenv('OUTPUT', 'output')
SPEAKER_FOLDER = os.getenv('SPEAKER', 'speakers')
+MODEL_FOLDER = os.getenv('MODEL', 'models')
BASE_HOST = os.getenv('BASE_HOST', '127.0.0.1:8020')
BASE_URL = os.getenv('BASE_URL', '127.0.0.1:8020')
MODEL_SOURCE = os.getenv("MODEL_SOURCE", "local")
@@ -40,7 +41,7 @@

# Create an instance of the TTSWrapper class and server
app = FastAPI()
-XTTS = TTSWrapper(OUTPUT_FOLDER,SPEAKER_FOLDER,LOWVRAM_MODE,MODEL_SOURCE,MODEL_VERSION,DEVICE,DEEPSPEED,USE_CACHE)
+XTTS = TTSWrapper(OUTPUT_FOLDER,SPEAKER_FOLDER,MODEL_FOLDER,LOWVRAM_MODE,MODEL_SOURCE,MODEL_VERSION,DEVICE,DEEPSPEED,USE_CACHE)

# Check for old format model version
XTTS.model_version = XTTS.check_model_version_old_format(MODEL_VERSION)
@@ -63,12 +64,7 @@
if STREAM_MODE_IMPROVE:
logger.info("You launched an improved version of streaming, this version features an improved tokenizer and more context when processing sentences, which can be good for complex languages like Chinese")

-    this_dir = Path(__file__).parent.resolve()
-
-    if XTTS.isModelOfficial(MODEL_VERSION):
-        model_path = this_dir / "models"
-    else:
-        model_path = "models"
+    model_path = XTTS.model_folder

engine = CoquiEngine(specific_model=MODEL_VERSION,use_deepspeed=DEEPSPEED,local_models_path=str(model_path))
stream = TextToAudioStream(engine)
@@ -120,6 +116,18 @@ class OutputFolderRequest(BaseModel):
class SpeakerFolderRequest(BaseModel):
speaker_folder: str

+class ModelNameRequest(BaseModel):
+    model_name: str
+
+class TTSSettingsRequest(BaseModel):
+    temperature: float
+    speed: float
+    length_penalty: float
+    repetition_penalty: float
+    top_p: float
+    top_k: int
+    enable_text_splitting: bool

class SynthesisRequest(BaseModel):
text: str
speaker_wav: str
@@ -150,7 +158,16 @@ def get_languages():
def get_folders():
speaker_folder = XTTS.speaker_folder
output_folder = XTTS.output_folder
return {"speaker_folder": speaker_folder, "output_folder": output_folder}
model_folder = XTTS.model_folder
return {"speaker_folder": speaker_folder, "output_folder": output_folder,"model_folder":model_folder}

@app.get("/get_models_list")
def get_models_list():
return XTTS.get_models_list()

@app.get("/get_tts_settings")
def get_tts_settings():
return XTTS.tts_settings
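A quick sketch of calling the two new read-only endpoints with requests; the base URL assumes the default host and port from __main__.py, and the folder names in the sample output are hypothetical:

import requests

BASE = "http://localhost:8020"  # assumed default host/port

print(requests.get(f"{BASE}/get_models_list").json())   # e.g. ["v2.0.2", "my-finetune"]
print(requests.get(f"{BASE}/get_tts_settings").json())  # the current tts_settings dict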

@app.get("/sample/{file_name:path}")
def get_sample(file_name: str):
@@ -179,6 +196,24 @@ def set_speaker_folder(speaker_req: SpeakerFolderRequest):
logger.error(e)
raise HTTPException(status_code=400, detail=str(e))

@app.post("/switch_model")
def switch_model(modelReq: ModelNameRequest):
try:
XTTS.switch_model(modelReq.model_name)
return {"message": f"Model switched to {modelReq.model_name}"}
except InvalidSettingsError as e:
logger.error(e)
raise HTTPException(status_code=400, detail=str(e))

@app.post("/set_tts_settings")
def set_tts_settings_endpoint(tts_settings_req: TTSSettingsRequest):
try:
XTTS.set_tts_settings(**tts_settings_req.dict())
return {"message": "Settings successfully applied"}
except InvalidSettingsError as e:
logger.error(e)
raise HTTPException(status_code=400, detail=str(e))
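The two new POST endpoints can be exercised the same way; a sketch assuming the same base URL as above, where invalid values come back as HTTP 400 raised from InvalidSettingsError:

import requests

BASE = "http://localhost:8020"  # assumed default host/port

# Switch to a model name reported by /get_models_list
r = requests.post(f"{BASE}/switch_model", json={"model_name": "v2.0.2"})
print(r.status_code, r.json())

# Apply generation settings; TTSSettingsRequest requires all seven fields
settings = {
    "temperature": 0.7,
    "speed": 1.0,
    "length_penalty": 1.0,
    "repetition_penalty": 5.0,
    "top_p": 0.85,
    "top_k": 50,
    "enable_text_splitting": True,
}
r = requests.post(f"{BASE}/set_tts_settings", json=settings)
print(r.status_code, r.json())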

@app.get('/tts_stream')
async def tts_stream(request: Request, text: str = Query(), speaker_wav: str = Query(), language: str = Query()):
# Validate local model source.
128 changes: 105 additions & 23 deletions xtts_api_server/tts_funcs.py
@@ -22,6 +22,10 @@
import wave
import numpy as np

+# Exception raised when TTS settings fail validation
+class InvalidSettingsError(Exception):
+    pass

# List of supported language codes
supported_languages = {
"ar":"Arabic",
@@ -43,13 +47,23 @@
"hi":"Hindi"
}

+default_tts_settings = {
+    "temperature" : 0.75,
+    "length_penalty" : 1.0,
+    "repetition_penalty": 5.0,
+    "top_k" : 50,
+    "top_p" : 0.85,
+    "speed" : 1,
+    "enable_text_splitting": True
+}

official_model_list = ["v2.0.0","v2.0.1","v2.0.2","v2.0.3","main"]
official_model_list_v2 = ["2.0.0","2.0.1","2.0.2","2.0.3"]

reversed_supported_languages = {name: code for code, name in supported_languages.items()}

class TTSWrapper:
-    def __init__(self,output_folder = "./output", speaker_folder="./speakers",lowvram = False,model_source = "local",model_version = "2.0.2",device = "cuda",deepspeed = False,enable_cache_results = True):
+    def __init__(self,output_folder = "./output", speaker_folder="./speakers",model_folder="./xtts_folder",lowvram = False,model_source = "local",model_version = "2.0.2",device = "cuda",deepspeed = False,enable_cache_results = True):

        self.cuda = device # Device requested by the user; applied below when CUDA is available
self.device = 'cpu' if lowvram else (self.cuda if torch.cuda.is_available() else "cpu")
@@ -59,12 +73,13 @@ def __init__(self,output_folder = "./output", speaker_folder="./speakers",lowvra

self.model_source = model_source
self.model_version = model_version
+        self.tts_settings = default_tts_settings

self.deepspeed = deepspeed

self.speaker_folder = speaker_folder
self.output_folder = output_folder
-        self.custom_models_folder = "./models"
+        self.model_folder = model_folder

self.create_directories()
check_tts_version()
@@ -90,6 +105,14 @@ def check_model_version_old_format(self,model_version):
return "v"+model_version
return model_version

+    def get_models_list(self):
+        # Fetch all entries in the directory given by self.model_folder
+        entries = os.listdir(self.model_folder)
+
+        # Filter out and return only directories
+        return [name for name in entries if os.path.isdir(os.path.join(self.model_folder, name))]
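For illustration, a standalone sketch of the folder convention get_models_list relies on: every subdirectory of the model folder counts as one selectable model (the folder names are hypothetical):

from pathlib import Path

def list_models(model_folder: str) -> list[str]:
    # Same rule as get_models_list: subdirectories only, plain files are ignored.
    return [p.name for p in Path(model_folder).iterdir() if p.is_dir()]

print(list_models("xtts_models"))  # e.g. ['v2.0.2', 'my-finetune']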


def get_wav_header(self, channels:int=1, sample_rate:int=24000, width:int=2) -> bytes:
wav_buf = io.BytesIO()
with wave.open(wav_buf, "wb") as out:
@@ -147,12 +170,11 @@ def load_model(self,load=True):
self.model = TTS("tts_models/multilingual/multi-dataset/xtts_v2")

if self.model_source == "apiManual":
-            this_dir = Path(__file__).parent.resolve() / "models"
+            this_dir = Path(self.model_folder)

if self.isModelOfficial(self.model_version):
download_model(this_dir,self.model_version)
-            else:
-                this_dir = Path(self.custom_models_folder).resolve()


config_path = this_dir / f'{self.model_version}' / 'config.json'
checkpoint_dir = this_dir / f'{self.model_version}'

@@ -170,13 +192,11 @@
logger.info("Model successfully loaded ")

def load_local_model(self,load=True):
-        this_model_dir = Path(__file__).parent.resolve()
+        this_model_dir = Path(self.model_folder)

if self.isModelOfficial(self.model_version):
download_model(this_model_dir,self.model_version)
-            this_model_dir = this_model_dir / "models"
-        else:
-            this_model_dir = Path(self.custom_models_folder)
+        this_model_dir = this_model_dir

config = XttsConfig()
config_path = this_model_dir / f'{self.model_version}' / 'config.json'
@@ -188,6 +208,34 @@
self.model.load_checkpoint(config,use_deepspeed=self.deepspeed, checkpoint_dir=str(checkpoint_dir))
self.model.to(self.device)

+    def switch_model(self,model_name):
+        model_list = self.get_models_list()
+
+        # Check whether a model with this name is already loaded
+        if model_name == self.model_version:
+            raise InvalidSettingsError("The model with this name is already loaded in memory")
+
+        # Check that the model exists in the models folder
+        if model_name not in model_list:
+            raise InvalidSettingsError(f"A model named `{model_name}` is not in the models folder; currently available models: {model_list}")
+
+        # Clear the old model from GPU memory
+        self.model = ""
+        torch.cuda.empty_cache()
+        logger.info("Model successfully unloaded from memory")
+
+        # Start loading the new model
+        logger.info(f"Start loading {model_name} model")
+        self.model_version = model_name
+        if self.model_source == "local":
+            self.load_local_model()
+        else:
+            self.load_model()
+
+        logger.info("Model successfully loaded")

# LOWVRAM FUNCS
def switch_model_device(self):
# We check for lowram and the existence of cuda
@@ -222,7 +270,7 @@ def create_latents_for_all(self):

    # DIRECTORIES FUNCS
    def create_directories(self):
-        directories = [self.output_folder, self.speaker_folder,self.custom_models_folder]
+        directories = [self.output_folder, self.speaker_folder,self.model_folder]

        for sanctuary in directories:
            # Create each folder if it does not already exist
@@ -249,6 +297,50 @@ def set_out_folder(self, folder):
else:
raise ValueError("Provided path is not a valid directory")

+    def set_tts_settings(self, temperature, speed, length_penalty,
+                         repetition_penalty, top_p, top_k, enable_text_splitting):
+        # Validate each parameter and raise an exception if any checks fail.
+
+        # Check temperature
+        if not (0.01 <= temperature <= 1):
+            raise InvalidSettingsError("Temperature must be between 0.01 and 1.")
+
+        # Check speed
+        if not (0.2 <= speed <= 2):
+            raise InvalidSettingsError("Speed must be between 0.2 and 2.")
+
+        # Check length_penalty (no explicit range specified)
+        if not isinstance(length_penalty, float):
+            raise InvalidSettingsError("Length penalty must be a floating point number.")
+
+        # Check repetition_penalty
+        if not (0.1 <= repetition_penalty <= 10.0):
+            raise InvalidSettingsError("Repetition penalty must be between 0.1 and 10.0.")
+
+        # Check top_p
+        if not (0.01 <= top_p <= 1):
+            raise InvalidSettingsError("Top_p must be between 0.01 and 1 and must be a float.")
+
+        # Check top_k
+        if not (1 <= top_k <= 100):
+            raise InvalidSettingsError("Top_k must be an integer between 1 and 100.")
+
+        # Check enable_text_splitting
+        if not isinstance(enable_text_splitting, bool):
+            raise InvalidSettingsError("Enable text splitting must be either True or False.")
+
+        # All validations passed - proceed to apply settings.
+        self.tts_settings = {
+            "temperature": temperature,
+            "speed": speed,
+            "length_penalty": length_penalty,
+            "repetition_penalty": repetition_penalty,
+            "top_p": top_p,
+            "top_k": top_k,
+            "enable_text_splitting": enable_text_splitting,
+        }
+        print("Successfully updated TTS settings.")

# GET FUNCS
def get_wav_files(self, directory):
""" Finds all the wav files in a directory. """
@@ -361,12 +453,7 @@ async def stream_generation(self,text,speaker_name,speaker_wav,language,output_f
language,
speaker_embedding=speaker_embedding,
gpt_cond_latent=gpt_cond_latent,
-            temperature=0.75,
-            length_penalty=1.0,
-            repetition_penalty=5.0,
-            top_k=50,
-            top_p=0.85,
-            enable_text_splitting=True,
+            **self.tts_settings, # Unpack the stored TTS settings into the generation call
stream_chunk_size=100,
)

@@ -402,12 +489,7 @@ def local_generation(self,text,speaker_name,speaker_wav,language,output_file):
language,
gpt_cond_latent=gpt_cond_latent,
speaker_embedding=speaker_embedding,
-            temperature=0.75,
-            length_penalty=1.0,
-            repetition_penalty=5.0,
-            top_k=50,
-            top_p=0.85,
-            enable_text_splitting=True
+            **self.tts_settings, # Unpack the stored TTS settings into the generation call
)

torchaudio.save(output_file, torch.tensor(out["wav"]).unsqueeze(0), 24000)
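The **self.tts_settings expansion above is plain keyword-argument unpacking; a minimal illustrative sketch of the semantics (function and values are hypothetical):

def generate(text, temperature=0.75, top_k=50, **kwargs):
    # Defaults apply unless the unpacked dict overrides them
    return f"{text} @ T={temperature}, top_k={top_k}"

settings = {"temperature": 0.7, "top_k": 40}
print(generate("hello", **settings))  # hello @ T=0.7, top_k=40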
