diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py
index b7e07589c5..6a1528565b 100644
--- a/TTS/tts/layers/xtts/stream_generator.py
+++ b/TTS/tts/layers/xtts/stream_generator.py
@@ -151,18 +151,7 @@ def generate(  # noqa: PLR0911
         # 2. Set generation parameters if not already defined
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
-        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
-            if model_kwargs.get("attention_mask", None) is None:
-                logger.warning(
-                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
-                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
-                )
-            eos_token_id = generation_config.eos_token_id
-            if isinstance(eos_token_id, list):
-                eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
-            generation_config.pad_token_id = eos_token_id
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

         # 3. Define model inputs
         # inputs_tensor has to be defined
@@ -174,6 +163,9 @@ def generate(  # noqa: PLR0911
         )
         batch_size = inputs_tensor.shape[0]

+        device = inputs_tensor.device
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
         # 4. Define other model kwargs
         model_kwargs["output_attentions"] = generation_config.output_attentions
         model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
@@ -182,7 +174,7 @@ def generate(  # noqa: PLR0911
         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs

-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
+        if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                 inputs_tensor,
                 generation_config.pad_token_id,
diff --git a/pyproject.toml b/pyproject.toml
index ff2ff32dd4..36172fa366 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,7 +69,7 @@ dependencies = [
     "gruut[de,es,fr]==2.2.3",
     # Tortoise
     "einops>=0.6.0",
-    "transformers>=4.33.0,<4.41.0",
+    "transformers>=4.41.1",
     # Bark
     "encodec>=0.1.1",
     # XTTS