From dc815fdd2e241ab63323f0811087c67ea596210d Mon Sep 17 00:00:00 2001 From: ain-soph Date: Sun, 11 Aug 2024 11:36:13 -0400 Subject: [PATCH] add experimental argument --- ChatTTS/core.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 79ef9bc8d..9f730bdf4 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -127,6 +127,7 @@ def load( coef: Optional[torch.Tensor] = None, use_flash_attn=False, use_vllm=False, + experimental: bool = False, ) -> bool: download_path = self.download_models(source, force_redownload, custom_path) if download_path is None: @@ -137,6 +138,7 @@ def load( coef=coef, use_flash_attn=use_flash_attn, use_vllm=use_vllm, + experimental=experimental, **{ k: os.path.join(download_path, v) for k, v in asdict(self.config.path).items() @@ -233,9 +235,10 @@ def _load( coef: Optional[str] = None, use_flash_attn=False, use_vllm=False, + experimental: bool = False, ): if device is None: - device = select_device() + device = select_device(experimental=experimental) self.logger.info("use device %s", str(device)) self.device = device self.device_gpt = device if "mps" not in str(device) else torch.device("cpu") @@ -287,7 +290,7 @@ def _load( logger=self.logger, ).eval() assert gpt_ckpt_path, "gpt_ckpt_path should not be None" - gpt.from_pretrained(gpt_ckpt_path) + gpt.from_pretrained(gpt_ckpt_path, experimental=experimental) gpt.prepare(compile=compile and "cuda" in str(device)) self.gpt = gpt