Skip to content

Commit

Permalink
Post PR commit - Also fixed #5
Browse files Browse the repository at this point in the history
r9y9#53 (comment) issue solved in PyTorch 0.4
  • Loading branch information
engiecat committed May 5, 2018
1 parent 25e95d8 commit 92a94b8
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 26 deletions.
7 changes: 1 addition & 6 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ generated
data
text
datasets
testout

# Created by https://www.gitignore.io

Expand Down Expand Up @@ -200,9 +201,3 @@ Temporary Items

# Linux trash folder which might appear on any partition or disk
.Trash-*
vctk_preprocess/WorkingHowToUseThis.txt
GoTBook1.01.txt
presets/deepvoice3_got.json
presets/deepvoice3_gotOnly.json
presets/deepvoice3_stest.json
presets/deepvoice3_test.json
2 changes: 1 addition & 1 deletion deepvoice3_pytorch/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def incremental_forward(self, input):
self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
# append next input
self.input_buffer[:, -1, :] = input[:, -1, :]
input = torch.Tensor(self.input_buffer)
input = self.input_buffer.clone()
if dilation > 1:
input = input[:, 0::dilation, :].contiguous()
output = F.linear(input.view(bsz, -1), weight, self.bias)
Expand Down
4 changes: 2 additions & 2 deletions gentle_web_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Created on Sat Apr 21 09:06:37 2018
Phoneme alignment and conversion in HTK-style label file using Web-served Gentle
This works on any type of english dataset.
This allows its usage on Windows (Via Docker) and external server.
Unlike prepare_htk_alignments_vctk.py, this is Python3 and Windows(with Docker) compatible.
Preliminary results show that gentle has better performance with noisy dataset
(e.g. movie extracted audioclips)
*This work was derived from vctk_preprocess/prepare_htk_alignments_vctk.py
Expand Down Expand Up @@ -109,7 +109,7 @@ def gentle_request(wav_path,txt_path, server_addr, port, debug=False):
server_addr = arguments['--server_addr']
port = int(arguments['--port'])
max_unalign = float(arguments['--max_unalign'])
if arguments['--nested-directories'] == None:
if arguments['--nested-directories'] is None:
wav_paths = sorted(glob(arguments['--wav_pattern']))
txt_paths = sorted(glob(arguments['--txt_pattern']))
else:
Expand Down
10 changes: 5 additions & 5 deletions hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@
# Forced garbage collection probability
# Use only when MemoryError continues in Windows (Disabled by default)
#gc_probability = 0.001,
# json_meta mode only
# 0: "use all",
# json_meta mode only
# 0: "use all",
# 1: "ignore only unmatched_alignment",
# 2: "fully ignore recognition",
ignore_recognition_level = 2,
min_text=20,
process_only_htk_aligned = False,
min_text=20, # when dealing with non-dedicated speech dataset(e.g. movie excerpts), setting min_text above 15 is desirable. Can be adjusted by dataset.
process_only_htk_aligned = False, # if true, data without phoneme alignment file(.lab) will be ignored
)


Expand Down
4 changes: 2 additions & 2 deletions json_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x):
# ignore_recognition_level check
for path in info.keys():
is_aligned[path] = True
if type(info[path]) == list:
if isinstance(info[path], list):
if hparams.ignore_recognition_level == 1 and len(info[path]) == 1 or \
hparams.ignore_recognition_level == 2:
# flag the path to be 'non-aligned' text
Expand All @@ -96,7 +96,7 @@ def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x):
# Reserve for future processing
queue_count = 0
for audio_path, text in info.items():
if type(text)==list:
if isinstance(text, list):
if hparams.ignore_recognition_level == 0:
text = text[-1]
else:
Expand Down
7 changes: 1 addition & 6 deletions nikl_m.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
import audio
import re
from hparams import hparams

from hparams import hparams

Expand Down Expand Up @@ -73,11 +72,7 @@ def _process_utterance(out_dir, index, speaker_id, wav_path, text):
wav = wav / np.abs(wav).max() * hparams.rescaling_max

# Compute the linear-scale spectrogram from the wav:
try:
spectrogram = audio.spectrogram(wav).astype(np.float32)
except:
print(wav_path)
print(wav)
spectrogram = audio.spectrogram(wav).astype(np.float32)
n_frames = spectrogram.shape[1]

# Compute a mel-scale spectrogram from the wav:
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,12 @@ def create_readme_rst():
"torch >= 0.3.0",
"unidecode",
"inflect",
"librosa == 0.5.1",
"librosa",
"numba",
"lws <= 1.0",
"nltk",
"requests",
"requests",
"PyQt5",
],
extras_require={
"train": [
Expand Down
6 changes: 4 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@

fs = hparams.sample_rate

# Prevent Issue #5
plt.switch_backend('Qt5Agg')

global_step = 0
global_epoch = 0
use_cuda = torch.cuda.is_available()
Expand Down Expand Up @@ -916,8 +919,7 @@ def restore_parts(path, model):
# Preventing Windows specific error such as MemoryError
# Also reduces the occurrence of THAllocator.c 0x05 error in Widows build of PyTorch
if platform.system() == "Windows":
print("Windows Detected - num_workers set to 1")
hparams.set_hparam('num_workers', 1)
print(" [!] Windows Detected - IF THAllocator.c 0x05 error occurs SET num_workers to 1")

assert hparams.name == "deepvoice3"
print(hparams_debug_string())
Expand Down

0 comments on commit 92a94b8

Please sign in to comment.