Skip to content

Commit

Permalink
Merge pull request #5 from r9y9/master
Browse files Browse the repository at this point in the history
Reverse PR
  • Loading branch information
engiecat authored Apr 27, 2018
2 parents 32cab90 + df8695c commit 6d8973a
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 26 deletions.
9 changes: 8 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,20 @@ before_install:
# Useful for debugging any issues with conda
- conda config --add channels pypi
- conda info -a
- deps='pip numpy scipy cython nose pytorch'
- deps='pip numpy scipy cython nose pytorch flake8'
- conda create -q -n test-environment "python=$TRAVIS_PYTHON_VERSION" $deps -c pytorch
- source activate test-environment

install:
- pip install -e ".[test]"
- python -c "import nltk; nltk.download('cmudict')"

before_script:
# stop the build if there are Python syntax errors or undefined names
- flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
- flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics

script:
- nosetests -v -w tests/ -a '!local_only'

2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

PyTorch implementation of convolutional networks-based text-to-speech synthesis models:

1. [arXiv:1710.07654](https://arxiv.org/abs/1710.07654): Deep Voice 3: 2000-Speaker Neural Text-to-Speech.
1. [arXiv:1710.07654](https://arxiv.org/abs/1710.07654): Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning.
2. [arXiv:1710.08969](https://arxiv.org/abs/1710.08969): Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention.

Audio samples are available at https://r9y9.github.io/deepvoice3_pytorch/.
Expand Down
2 changes: 1 addition & 1 deletion deepvoice3_pytorch/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def incremental_forward(self, input):
self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
# append next input
self.input_buffer[:, -1, :] = input[:, -1, :]
input = torch.autograd.Variable(self.input_buffer, volatile=True)
input = torch.Tensor(self.input_buffer)
if dilation > 1:
input = input[:, 0::dilation, :].contiguous()
output = F.linear(input.view(bsz, -1), weight, self.bias)
Expand Down
16 changes: 6 additions & 10 deletions deepvoice3_pytorch/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,18 @@ def forward(self, x, w=1.0):

if isscaler or w.size(0) == 1:
weight = sinusoidal_encode(self.weight, w)
return self._backend.Embedding.apply(
x, weight,
padding_idx, self.max_norm, self.norm_type,
self.scale_grad_by_freq, self.sparse
)
return F.embedding(
x, weight, padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)
else:
# TODO: cannot simply apply for batch
# better to implement efficient function
pe = []
for batch_idx, we in enumerate(w):
weight = sinusoidal_encode(self.weight, we)
pe.append(self._backend.Embedding.apply(
x[batch_idx], weight,
padding_idx, self.max_norm, self.norm_type,
self.scale_grad_by_freq, self.sparse
))
pe.append(F.embedding(
x[batch_idx], weight, padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse))
pe = torch.stack(pe)
return pe

Expand Down
2 changes: 1 addition & 1 deletion docs/content/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -392,5 +392,5 @@ Your browser does not support the audio element.

## References

- [Wei Ping, Kainan Peng, Andrew Gibiansky, et al, "Deep Voice 3: 2000-Speaker Neural Text-to-Speech", arXiv:1710.07654, Oct. 2017.](https://arxiv.org/abs/1710.07654)
- [Wei Ping, Kainan Peng, Andrew Gibiansky, et al, "Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning", arXiv:1710.07654, Oct. 2017.](https://arxiv.org/abs/1710.07654)
- [Hideyuki Tachibana, Katsuya Uenoyama, Shunsuke Aihara, "Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention". arXiv:1710.08969, Oct 2017.](https://arxiv.org/abs/1710.08969)
2 changes: 2 additions & 0 deletions nikl_m.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import re
from hparams import hparams

from hparams import hparams


def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x):
'''Preprocesses the LJ Speech dataset from a given input path into a given output directory.
Expand Down
2 changes: 2 additions & 0 deletions nikl_s.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import audio
import re

from hparams import hparams


def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x):
'''Preprocesses the LJ Speech dataset from a given input path into a given output directory.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_deepvoice3.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def test_incremental_forward():
max_input_len = np.max(input_lengths) + 10 # manual padding
seqs = np.array([_pad(x, max_input_len) for x in seqs], dtype=np.int)
input_lengths = torch.LongTensor(input_lengths)
input_lengths = input_lengths.cuda() if use_cuda else input_lenghts
input_lengths = input_lengths.cuda() if use_cuda else input_lengths
else:
input_lengths = None

Expand Down
22 changes: 11 additions & 11 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,30 +728,30 @@ def train(model, data_loader, optimizer, writer,
# Update
loss.backward()
if clip_thresh > 0:
grad_norm = torch.nn.utils.clip_grad_norm(
grad_norm = torch.nn.utils.clip_grad_norm_(
model.get_trainable_parameters(), clip_thresh)
optimizer.step()

# Logs
writer.add_scalar("loss", float(loss.data[0]), global_step)
writer.add_scalar("loss", float(loss.item()), global_step)
if train_seq2seq:
writer.add_scalar("done_loss", float(done_loss.data[0]), global_step)
writer.add_scalar("mel loss", float(mel_loss.data[0]), global_step)
writer.add_scalar("mel_l1_loss", float(mel_l1_loss.data[0]), global_step)
writer.add_scalar("mel_binary_div_loss", float(mel_binary_div.data[0]), global_step)
writer.add_scalar("done_loss", float(done_loss.item()), global_step)
writer.add_scalar("mel loss", float(mel_loss.item()), global_step)
writer.add_scalar("mel_l1_loss", float(mel_l1_loss.item()), global_step)
writer.add_scalar("mel_binary_div_loss", float(mel_binary_div.item()), global_step)
if train_postnet:
writer.add_scalar("linear_loss", float(linear_loss.data[0]), global_step)
writer.add_scalar("linear_l1_loss", float(linear_l1_loss.data[0]), global_step)
writer.add_scalar("linear_loss", float(linear_loss.item()), global_step)
writer.add_scalar("linear_l1_loss", float(linear_l1_loss.item()), global_step)
writer.add_scalar("linear_binary_div_loss", float(
linear_binary_div.data[0]), global_step)
linear_binary_div.item()), global_step)
if train_seq2seq and hparams.use_guided_attention:
writer.add_scalar("attn_loss", float(attn_loss.data[0]), global_step)
writer.add_scalar("attn_loss", float(attn_loss.item()), global_step)
if clip_thresh > 0:
writer.add_scalar("gradient norm", grad_norm, global_step)
writer.add_scalar("learning rate", current_lr, global_step)

global_step += 1
running_loss += loss.data[0]
running_loss += loss.item()

averaged_loss = running_loss / (len(data_loader))
writer.add_scalar("loss (per epoch)", averaged_loss, global_epoch)
Expand Down

0 comments on commit 6d8973a

Please sign in to comment.