Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR for Version Up in upstream repos #4

Merged
merged 9 commits into from
Apr 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ PyTorch implementation of convolutional networks-based text-to-speech synthesis

Audio samples are available at https://r9y9.github.io/deepvoice3_pytorch/.

## Online TTS demo

A notebook supposed to be executed on https://colab.research.google.com is available:

- [DeepVoice3: Multi-speaker text-to-speech demo](https://colab.research.google.com/github/r9y9/Colaboratory/blob/master/DeepVoice3_multi_speaker_TTS_en_demo.ipynb)

## Highlights

- Convolutional sequence-to-sequence model with attention for text-to-speech synthesis
Expand Down Expand Up @@ -53,9 +59,7 @@ See "Synthesize from a checkpoint" section in the README for how to generate spe

- Python 3
- CUDA >= 8.0
- PyTorch >= v0.3
- TensorFlow >= v1.3
- [tensorboard-pytorch](https://github.com/lanpa/tensorboard-pytorch) (master)
- [nnmnkwii](https://github.com/r9y9/nnmnkwii) >= v0.0.11
- [MeCab](http://taku910.github.io/mecab/) (Japanese only)

Expand Down
2 changes: 2 additions & 0 deletions deepvoice3_pytorch/frontend/text/numbers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-

import inflect
import re

Expand Down
2 changes: 1 addition & 1 deletion hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# Convenient model builder
# [deepvoice3, deepvoice3_multispeaker, nyanko]
# Definitions can be found at deepvoice3_pytorch/builder.py
# deepvoice3: DeepVoice3 https://arxiv.org/abs/1710.07654
# deepvoice3: DeepVoice3 https://arxiv.org/abs/1710.07654
# deepvoice3_multispeaker: Multi-speaker version of DeepVoice3
# nyanko: https://arxiv.org/abs/1710.08969
builder="deepvoice3",
Expand Down
4 changes: 2 additions & 2 deletions presets/deepvoice3_vctk.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"name": "deepvoice3",
"frontend": "en",
"replace_pronunciation_prob": 0.5,
"builder": "deepvoice3",
"builder": "deepvoice3_multispeaker",
"n_speakers": 108,
"speaker_embed_dim": 16,
"num_mels": 80,
Expand Down Expand Up @@ -62,4 +62,4 @@
"window_ahead": 3,
"window_backward": 1,
"power": 1.4
}
}
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import subprocess
from os.path import exists

version = '0.0.4'
version = '0.0.5'

# Adapted from https://github.com/pytorch/pytorch
cwd = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -79,6 +79,7 @@ def create_readme_rst():
install_requires=[
"numpy",
"scipy",
"torch >= 0.3.0",
"unidecode",
"inflect",
"librosa",
Expand Down
15 changes: 12 additions & 3 deletions synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ def tts(model, text, p=0, speaker_id=None, fast=False):
return waveform, alignment, spectrogram, mel


def _load(checkpoint_path):
if use_cuda:
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(checkpoint_path,
map_location=lambda storage, loc: storage)
return checkpoint


if __name__ == "__main__":
args = docopt(__doc__)
print("Command line args:\n", args)
Expand Down Expand Up @@ -113,13 +122,13 @@ def tts(model, text, p=0, speaker_id=None, fast=False):

# Load checkpoints separately
if checkpoint_postnet_path is not None and checkpoint_seq2seq_path is not None:
checkpoint = torch.load(checkpoint_seq2seq_path)
checkpoint = _load(checkpoint_seq2seq_path)
model.seq2seq.load_state_dict(checkpoint["state_dict"])
checkpoint = torch.load(checkpoint_postnet_path)
checkpoint = _load(checkpoint_postnet_path)
model.postnet.load_state_dict(checkpoint["state_dict"])
checkpoint_name = splitext(basename(checkpoint_seq2seq_path))[0]
else:
checkpoint = torch.load(checkpoint_path)
checkpoint = _load(checkpoint_path)
model.load_state_dict(checkpoint["state_dict"])
checkpoint_name = splitext(basename(checkpoint_path))[0]

Expand Down
38 changes: 31 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,12 +815,21 @@ def build_model():
return model


def _load(checkpoint_path):
if use_cuda:
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(checkpoint_path,
map_location=lambda storage, loc: storage)
return checkpoint


def load_checkpoint(path, model, optimizer, reset_optimizer):
global global_step
global global_epoch

print("Load checkpoint from: {}".format(path))
checkpoint = torch.load(path)
checkpoint = _load(path)
model.load_state_dict(checkpoint["state_dict"])
if not reset_optimizer:
optimizer_state = checkpoint["optimizer"]
Expand All @@ -834,19 +843,33 @@ def load_checkpoint(path, model, optimizer, reset_optimizer):


def _load_embedding(path, model):
state = torch.load(path)["state_dict"]
state = _load(path)["state_dict"]
key = "seq2seq.encoder.embed_tokens.weight"
model.seq2seq.encoder.embed_tokens.weight.data = state[key]


# https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3


def restore_parts(path, model):
print("Restore part of the model from: {}".format(path))
state = torch.load(path)["state_dict"]
state = _load(path)["state_dict"]
model_dict = model.state_dict()
valid_state_dict = {k: v for k, v in state.items() if k in model_dict}
model_dict.update(valid_state_dict)
model.load_state_dict(model_dict)

try:
model_dict.update(valid_state_dict)
model.load_state_dict(model_dict)
except RuntimeError as e:
# there should be invalid size of weight(s), so load them per parameter
print(str(e))
model_dict = model.state_dict()
for k, v in valid_state_dict.items():
model_dict[k] = v
try:
model.load_state_dict(model_dict)
except RuntimeError as e:
print(str(e))
warn("{}: may contain invalid size of weight. skipping...".format(k))


if __name__ == "__main__":
Expand Down Expand Up @@ -951,7 +974,8 @@ def restore_parts(path, model):
# Setup summary writer for tensorboard
if log_event_path is None:
if platform.system() == "Windows":
log_event_path = "log/run-test" + str(datetime.now()).replace(" ", "_").replace(":","_")
log_event_path = "log/run-test" + \
str(datetime.now()).replace(" ", "_").replace(":", "_")
else:
log_event_path = "log/run-test" + str(datetime.now()).replace(" ", "_")
print("Los event path: {}".format(log_event_path))
Expand Down