Skip to content

Commit

Permalink
Fix for #37, #50 and #53 (Windows specific issues) (#54)
Browse files Browse the repository at this point in the history
* Fixed typeerror (torch.index_select received an invalid combination of arguments)

  File "synthesis.py", line 137, in <module>
    model, text, p=replace_pronunciation_prob, speaker_id=speaker_id, fast=True)
  File "synthesis.py", line 66, in tts
    sequence, text_positions=text_positions, speaker_ids=speaker_ids)
  File "H:\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "H:\Tensorflow_Study\git\deepvoice3_pytorch\deepvoice3_pytorch\__init__.py", line 79, in forward
    text_positions, frame_positions, input_lengths)
  File "H:\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "H:\Tensorflow_Study\git\deepvoice3_pytorch\deepvoice3_pytorch\__init__.py", line 116, in forward
    text_sequences, lengths=input_lengths, speaker_embed=speaker_embed)
  File "H:\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "H:\Tensorflow_Study\git\deepvoice3_pytorch\deepvoice3_pytorch\deepvoice3.py", line 75, in forward
    x = self.embed_tokens(text_sequences) <- change this to long!
  File "H:\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "H:\envs\pytorch\lib\site-packages\torch\nn\modules\sparse.py", line 103, in forward
    self.scale_grad_by_freq, self.sparse
  File "H:\envs\pytorch\lib\site-packages\torch\nn\_functions\thnn\sparse.py", line 59, in forward
    output = torch.index_select(weight, 0, indices.view(-1))
TypeError: torch.index_select received an invalid combination of arguments - got (�[32;1mtorch.cuda.FloatTensor�[0m, �[32;1mint�[0m, �[31;1mtorch.cuda.IntTensor�[0m), but expected (torch.cuda.FloatTensor source, int dim, torch.cuda.LongTensor index)

changed text_sequence to long, as required by torch.index_select.

* Fixed Nonetype error in collect_features

* requirements.txt fix

* Memory Leakage bugfix + hparams change

* Pre-PR modifications

* Pre-PR modifications 2

* Pre-PR modifications 3

* Post-PR modification

* remove requirements.txt

* num_workers to 1 in train.py
  • Loading branch information
engiecat authored and r9y9 committed Mar 10, 2018
1 parent 52d1026 commit 9bc4943
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
2 changes: 1 addition & 1 deletion deepvoice3_pytorch/deepvoice3.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def forward(self, text_sequences, text_positions=None, lengths=None,
assert self.n_speakers == 1 or speaker_embed is not None

# embed text_sequences
x = self.embed_tokens(text_sequences)
x = self.embed_tokens(text_sequences.long())
x = F.dropout(x, p=self.dropout, training=self.training)

# expand speaker embedding for all time steps
Expand Down
12 changes: 9 additions & 3 deletions hparams.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import tensorflow as tf

# NOTE: If you want full control for model architecture. please take a look
# at the code and change whatever you want. Some hyper parameters are hardcoded.

Expand Down Expand Up @@ -28,7 +27,7 @@
builder="deepvoice3",

# Must be configured depends on the dataset and model you use
n_speakers=1,
n_speakers=1,
speaker_embed_dim=16,

# Audio:
Expand Down Expand Up @@ -81,7 +80,7 @@

# Data loader
pin_memory=True,
num_workers=2,
num_workers=2, # Set it to 1 when in Windows (MemoryError, THAllocator.c 0x5)

# Loss
masked_loss_weight=0.5, # (1-w)*loss + w * masked_loss
Expand Down Expand Up @@ -121,9 +120,16 @@
# 0 tends to prevent word repretetion, but sometime causes skip words
window_backward=1,
power=1.4, # Power to raise magnitudes to prior to phase retrieval

# GC:
# Forced garbage collection probability
# Use only when MemoryError continues in Windows (Disabled by default)
#gc_probability = 0.001,
)




def hparams_debug_string():
values = hparams.values()
hp = [' %s: %s' % (name, values[name]) for name in sorted(values)]
Expand Down
24 changes: 23 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"""
from docopt import docopt

import sys
import sys, gc, platform
from os.path import dirname, join
from tqdm import tqdm, trange
from datetime import datetime
Expand Down Expand Up @@ -130,7 +130,18 @@ def collect_features(self, *args):
text, speaker_id = args
else:
text = args[0]
global _frontend
if _frontend is None:
_frontend = getattr(frontend, hparams.frontend)
seq = _frontend.text_to_sequence(text, p=hparams.replace_pronunciation_prob)

if platform.system() == "Windows":
if hasattr(hparams, 'gc_probability'):
_frontend = None # memory leaking prevention in Windows
if np.random.rand() < hparams.gc_probability:
gc.collect() # garbage collection enforced
print("GC done")

if self.multi_speaker:
return np.asarray(seq, dtype=np.int32), int(speaker_id)
else:
Expand Down Expand Up @@ -712,6 +723,7 @@ def train(model, data_loader, optimizer, writer,
if global_step > 0 and global_step % hparams.eval_interval == 0:
eval_model(global_step, writer, model, checkpoint_dir, ismultispeaker)


# Update
loss.backward()
if clip_thresh > 0:
Expand Down Expand Up @@ -876,6 +888,16 @@ def restore_parts(path, model):
hparams.parse_json(f.read())
# Override hyper parameters
hparams.parse(args["--hparams"])

# Preventing Windows specific error such as MemoryError
# Also reduces the occurrence of THAllocator.c 0x05 error in Widows build of PyTorch
if platform.system() == "Windows":
print("Windows Detected - num_workers set to 1")
hparams.set_hparam('num_workers',1)

# Now, print the finalized hparams.
print(hparams_debug_string())

assert hparams.name == "deepvoice3"
print(hparams_debug_string())

Expand Down

0 comments on commit 9bc4943

Please sign in to comment.