Merge pull request #51 from idiap/update-trainer
Update to coqui-tts-trainer 0.1.4
eginhard authored Jul 2, 2024
2 parents ff2cd5c + 8cab2e3 commit c1a929b
Showing 33 changed files with 63 additions and 166 deletions.
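
Most of the changes below swap I/O and filesystem helpers that used to live in TTS.utils.io and TTS.utils.generic_utils over to their counterparts in the coqui-tts-trainer package. A condensed sketch of the pattern, with the old and new imports grouped onto single lines for illustration only:

# Before this commit: helpers bundled inside the TTS package.
# from TTS.utils.io import load_fsspec, load_checkpoint
# from TTS.utils.generic_utils import get_user_data_dir, get_git_branch

# After this commit: the same helpers come from coqui-tts-trainer (0.1.4).
from trainer.io import copy_model_files, get_user_data_dir, load_checkpoint, load_fsspec
from trainer.generic_utils import get_experiment_folder_path, get_git_branch
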
7 changes: 5 additions & 2 deletions .github/workflows/tests.yml
@@ -45,8 +45,11 @@ jobs:
           sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
       - name: Install TTS
         run: |
-          python3 -m uv pip install --system "coqui-tts[dev,server,languages] @ ."
-          python3 setup.py egg_info
+          resolution=highest
+          if [ "${{ matrix.python-version }}" == "3.9" ]; then
+            resolution=lowest-direct
+          fi
+          python3 -m uv pip install --resolution=$resolution --system "coqui-tts[dev,server,languages] @ ."
       - name: Unit tests
         run: make ${{ matrix.subset }}
       - name: Upload coverage data

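The workflow change makes uv resolve direct dependencies to their lowest declared versions on Python 3.9 and to the highest versions otherwise, so the oldest supported interpreter also exercises the minimum dependency bounds. A rough local equivalent of that step, sketched in Python under the assumption that uv is installed and the command is run from the repository root:

# Sketch of the CI install step from tests.yml; it mirrors the workflow command,
# but running it locally (and having uv available) is an assumption.
import subprocess
import sys

# uv's --resolution flag: "highest" is the default; "lowest-direct" pins direct
# dependencies to the lowest versions allowed by the project metadata.
resolution = "lowest-direct" if sys.version_info[:2] == (3, 9) else "highest"
subprocess.run(
    [sys.executable, "-m", "uv", "pip", "install", f"--resolution={resolution}",
     "--system", "coqui-tts[dev,server,languages] @ ."],
    check=True,
)
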
2 changes: 1 addition & 1 deletion TTS/bin/compute_attention_masks.py
@@ -8,14 +8,14 @@
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
+from trainer.io import load_checkpoint

 from TTS.config import load_config
 from TTS.tts.datasets.TTSDataset import TTSDataset
 from TTS.tts.models import setup_model
 from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
-from TTS.utils.io import load_checkpoint

 if __name__ == "__main__":
     setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

2 changes: 1 addition & 1 deletion TTS/encoder/models/base_encoder.py
@@ -5,10 +5,10 @@
 import torchaudio
 from coqpit import Coqpit
 from torch import nn
+from trainer.io import load_fsspec

 from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.utils.generic_utils import set_init_dict
-from TTS.utils.io import load_fsspec

 logger = logging.getLogger(__name__)

5 changes: 2 additions & 3 deletions TTS/encoder/utils/training.py
@@ -3,14 +3,13 @@

 from coqpit import Coqpit
 from trainer import TrainerArgs, get_last_checkpoint
-from trainer.generic_utils import get_experiment_folder_path
+from trainer.generic_utils import get_experiment_folder_path, get_git_branch
 from trainer.io import copy_model_files
 from trainer.logging import logger_factory
 from trainer.logging.console_logger import ConsoleLogger

 from TTS.config import load_config, register_config
 from TTS.tts.utils.text.characters import parse_symbols
-from TTS.utils.generic_utils import get_git_branch


 @dataclass
@@ -30,7 +29,7 @@ def process_args(args, config=None):
         args (argparse.Namespace or dict like): Parsed input arguments.
         config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
     Returns:
-        c (TTS.utils.io.AttrDict): Config paramaters.
+        c (Coqpit): Config paramaters.
         out_path (str): Path to save models and logging.
         audio_path (str): Path to save generated test audios.
         c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does

3 changes: 2 additions & 1 deletion TTS/model.py
@@ -60,6 +60,7 @@ def load_checkpoint(
             checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
             strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
-            cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
+            cache (bool, optional): If True, cache the file locally for subsequent calls.
+                It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False.
         """
         ...

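The docstring now points at the helper's new location in trainer.io rather than the removed TTS copy. A small sketch of where cache=True would place files, assuming get_user_data_dir returns a pathlib.Path (as the deleted TTS implementation did) and that the "tts" app name is used; both are assumptions, not spelled out in this hunk:

# Hedged sketch: resolve the cache directory the updated docstring refers to.
from trainer.io import get_user_data_dir

cache_dir = get_user_data_dir("tts") / "tts_cache"  # "tts" appname is assumed
print(f"cache=True stores downloaded checkpoints under {cache_dir}")
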
3 changes: 2 additions & 1 deletion TTS/tts/configs/bark_config.py
@@ -2,11 +2,12 @@
 from dataclasses import dataclass, field
 from typing import Dict

+from trainer.io import get_user_data_dir
+
 from TTS.tts.configs.shared_configs import BaseTTSConfig
 from TTS.tts.layers.bark.model import GPTConfig
 from TTS.tts.layers.bark.model_fine import FineGPTConfig
 from TTS.tts.models.bark import BarkAudioConfig
-from TTS.utils.generic_utils import get_user_data_dir


 @dataclass

2 changes: 1 addition & 1 deletion TTS/tts/layers/xtts/hifigan_decoder.py
@@ -7,8 +7,8 @@
 from torch.nn import functional as F
 from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
+from trainer.io import load_fsspec

-from TTS.utils.io import load_fsspec
 from TTS.vocoder.models.hifigan_generator import get_padding

 logger = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -7,6 +7,7 @@
 import torchaudio
 from coqpit import Coqpit
 from torch.utils.data import DataLoader
+from trainer.io import load_fsspec
 from trainer.torch import DistributedSampler
 from trainer.trainer_utils import get_optimizer, get_scheduler

@@ -18,7 +19,6 @@
 from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
-from TTS.utils.io import load_fsspec

 logger = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion TTS/tts/models/align_tts.py
@@ -4,6 +4,7 @@
 import torch
 from coqpit import Coqpit
 from torch import nn
+from trainer.io import load_fsspec

 from TTS.tts.layers.align_tts.mdn import MDNBlock
 from TTS.tts.layers.feed_forward.decoder import Decoder
@@ -15,7 +16,6 @@
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.io import load_fsspec


 @dataclass

5 changes: 3 additions & 2 deletions TTS/tts/models/base_tacotron.py
@@ -6,6 +6,7 @@
 import torch
 from coqpit import Coqpit
 from torch import nn
+from trainer.io import load_fsspec

 from TTS.tts.layers.losses import TacotronLoss
 from TTS.tts.models.base_tts import BaseTTS
@@ -15,7 +16,6 @@
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.generic_utils import format_aux_input
-from TTS.utils.io import load_fsspec
 from TTS.utils.training import gradual_training_scheduler

 logger = logging.getLogger(__name__)
@@ -103,7 +103,8 @@ def load_checkpoint(
             config (Coqpi): model configuration.
             checkpoint_path (str): path to checkpoint file.
             eval (bool, optional): whether to load model for evaluation.
-            cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
+            cache (bool, optional): If True, cache the file locally for subsequent calls.
+                It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False.
         """
         state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])

2 changes: 1 addition & 1 deletion TTS/tts/models/delightful_tts.py
@@ -16,6 +16,7 @@
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
+from trainer.io import load_fsspec
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper
 from trainer.trainer_utils import get_optimizer, get_scheduler

@@ -32,7 +33,6 @@
 from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
 from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
 from TTS.utils.audio.processor import AudioProcessor
-from TTS.utils.io import load_fsspec
 from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results

2 changes: 1 addition & 1 deletion TTS/tts/models/forward_tts.py
@@ -6,6 +6,7 @@
 from coqpit import Coqpit
 from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
+from trainer.io import load_fsspec

 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.encoder import Encoder
@@ -17,7 +18,6 @@
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
-from TTS.utils.io import load_fsspec

 logger = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion TTS/tts/models/glow_tts.py
@@ -7,6 +7,7 @@
 from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F
+from trainer.io import load_fsspec

 from TTS.tts.configs.glow_tts_config import GlowTTSConfig
 from TTS.tts.layers.glow_tts.decoder import Decoder
@@ -17,7 +18,6 @@
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.io import load_fsspec

 logger = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion TTS/tts/models/neuralhmm_tts.py
@@ -5,6 +5,7 @@
 import torch
 from coqpit import Coqpit
 from torch import nn
+from trainer.io import load_fsspec
 from trainer.logging.tensorboard_logger import TensorboardLogger

 from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
@@ -18,7 +19,6 @@
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.generic_utils import format_aux_input
-from TTS.utils.io import load_fsspec

 logger = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion TTS/tts/models/overflow.py
@@ -5,6 +5,7 @@
 import torch
 from coqpit import Coqpit
 from torch import nn
+from trainer.io import load_fsspec
 from trainer.logging.tensorboard_logger import TensorboardLogger

 from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
@@ -19,7 +20,6 @@
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.generic_utils import format_aux_input
-from TTS.utils.io import load_fsspec

 logger = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion TTS/tts/models/vits.py
@@ -16,6 +16,7 @@
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
+from trainer.io import load_fsspec
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper
 from trainer.trainer_utils import get_optimizer, get_scheduler

@@ -34,7 +35,6 @@
 from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
-from TTS.utils.io import load_fsspec
 from TTS.utils.samplers import BucketBatchSampler
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results

2 changes: 1 addition & 1 deletion TTS/tts/models/xtts.py
@@ -7,14 +7,14 @@
 import torch.nn.functional as F
 import torchaudio
 from coqpit import Coqpit
+from trainer.io import load_fsspec

 from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
 from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
 from TTS.tts.models.base_tts import BaseTTS
-from TTS.utils.io import load_fsspec

 logger = logging.getLogger(__name__)

38 changes: 0 additions & 38 deletions TTS/utils/generic_utils.py
@@ -2,29 +2,13 @@
 import datetime
 import importlib
 import logging
-import os
 import re
-import subprocess
-import sys
-from pathlib import Path
 from typing import Dict, Optional

 logger = logging.getLogger(__name__)


-# TODO: This method is duplicated in Trainer but out of date there
-def get_git_branch():
-    try:
-        out = subprocess.check_output(["git", "branch"]).decode("utf8")
-        current = next(line for line in out.split("\n") if line.startswith("*"))
-        current.replace("* ", "")
-    except subprocess.CalledProcessError:
-        current = "inside_docker"
-    except (FileNotFoundError, StopIteration) as e:
-        current = "unknown"
-    return current
-

 def to_camel(text):
     text = text.capitalize()
     text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
@@ -67,28 +51,6 @@ def get_import_path(obj: object) -> str:
     return ".".join([type(obj).__module__, type(obj).__name__])


-def get_user_data_dir(appname):
-    TTS_HOME = os.environ.get("TTS_HOME")
-    XDG_DATA_HOME = os.environ.get("XDG_DATA_HOME")
-    if TTS_HOME is not None:
-        ans = Path(TTS_HOME).expanduser().resolve(strict=False)
-    elif XDG_DATA_HOME is not None:
-        ans = Path(XDG_DATA_HOME).expanduser().resolve(strict=False)
-    elif sys.platform == "win32":
-        import winreg  # pylint: disable=import-outside-toplevel
-
-        key = winreg.OpenKey(
-            winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
-        )
-        dir_, _ = winreg.QueryValueEx(key, "Local AppData")
-        ans = Path(dir_).resolve(strict=False)
-    elif sys.platform == "darwin":
-        ans = Path("~/Library/Application Support/").expanduser()
-    else:
-        ans = Path.home().joinpath(".local/share")
-    return ans.joinpath(appname)
-
-
 def set_init_dict(model_dict, checkpoint_state, c):
     # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
     for k, v in checkpoint_state.items():

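With the local copies of get_git_branch and get_user_data_dir removed here, the implementations maintained in coqui-tts-trainer are the remaining ones, as the encoder training script above now reflects. A tiny sketch of the replacement call; the exact fallback value outside a git checkout depends on the trainer implementation:

# Hedged sketch: the branch helper deleted from TTS.utils.generic_utils is now
# used via trainer.generic_utils, matching TTS/encoder/utils/training.py above.
from trainer.generic_utils import get_git_branch

print("current git branch:", get_git_branch())
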
70 changes: 0 additions & 70 deletions TTS/utils/io.py

This file was deleted.
