Google's Tacotron2 published #90

Closed
zuoxiang95 opened this issue Dec 20, 2017 · 29 comments

Comments

@zuoxiang95

@keithito A few days ago, Google published Tacotron 2. Do you plan to implement it?

@keithito
Owner

keithito commented Dec 21, 2017

Yeah, I heard their audio. It sounds pretty amazing!

I'm hoping to find some time to try to implement it next week. In some ways it's a simpler architecture, for example they got rid of the CBHGs. And they use WaveNet as a vocoder, but it's trained independently after the main model is trained, so hopefully it won't be too hard to get something working...

@zuoxiang95
Author

I am so glad to hear that you will implement it, and I'm willing to help you finish it : )

@saxenauts

I've been trying to reproduce the Deep Voice WaveNet with Tacotron. I would also love to volunteer. I'm currently working on parallel wavenet.

@candlewill
Contributor

It would be very helpful.

@MXGray

MXGray commented Dec 31, 2017

Wishing you all a prosperous 2018! ;)
Looking forward to helping out with this in any way that I can. :)

@mrsylerpowers

This could be a problem. The old WaveNet vocoder takes way too long, but Google released a new WaveNet. It uses probability density distillation along with other things. This WaveNet vocoder is faster than real time (1000 times faster than the original). It is outlined here, and I have yet to see anyone implement it: https://deepmind.com/blog/high-fidelity-speech-synthesis-wavenet/

@EzequielAdrianM

@keithito We would LOVE to see some Tacotron2 implementation here if you have a little time for it.
As always thank you for making a better Open Source world :)

@keithito
Owner

keithito commented Jan 25, 2018

Hi, I just pushed a branch https://github.com/keithito/tacotron/tree/tacotron2-work-in-progress

It's kind of a hybrid of Tacotron 1 and 2. The encoder, decoder, and attention mechanism are from the Tacotron 2 paper. However, it's still using Griffin-Lim to go from spectrogram to waveform. I was hoping to be able to use an existing open source WaveNet implementation, but I couldn't find one that runs in real-time. If anyone knows of one, please let me know.

The location-sensitive attention seems like it leads to cleaner alignments:
step-143000-align

It's only been training for 143k steps, but I think it's already sounding pretty good.
eval-143000.zip

We'll see how it goes...

@luopengxo1

https://github.com/r9y9/wavenet_vocoder looks like it works.

@MXGray

MXGray commented Feb 6, 2018

@keithito I hope you don't mind sharing your CKPT files for this (you're probably up to more than 143K GS by now)? I used your hybrid Tacotron 1 / Tacotron 2 to train from scratch on the Nancy dataset, but even though I'm now at 335K GS, it isn't producing intelligible audio (unlike yours at 143K GS). It would be good to start training on the Nancy dataset from your checkpoint. :)

@rafaelvalle

@keithito After how many iterations did the attention start to show proper alignment?
Can you share some attention plots at different iterations so that we can see how they evolve during the training process?

@keithito
Owner

keithito commented Feb 7, 2018

@rafaelvalle it was around 15k steps. Attention plots are attached: attention.zip

@keithito
Owner

keithito commented Feb 7, 2018

@MXGray: here's a checkpoint after 800k steps:
http://data.keithito.com/data/speech/model.ckpt-800000.zip

@MXGray

MXGray commented Feb 7, 2018

@keithito Thanks a lot! I'll share the trained Nancy model under hybrid Tacotron 1 & 2 once I get good results.

@MXGray

MXGray commented Feb 7, 2018

@keithito
Using your hybrid Tacotron 1 / 2 implementation, I'm getting the errors below with your 800K model.
I also get the same errors with the hybrid implementation when I use my old Nancy checkpoint (trained using your original implementation, not the hybrid one) ...
Here are the errors:

  1. I get this when I enter "python3 train.py --restore_step=(directorypath)/model.ckpt-800000" in the terminal:
    NotFoundError (see above for traceback): Unsuccessful TensorSliceReader constructor: Failed to find any matching files for model.ckpt--800000

  2. And, I get this when I enter "python3 train.py --restore_step=800000" in the terminal:
    NotFoundError: NewRandomAccessFile failed to Create/Open: model.ckpt-800000.data-00000-of-00001 : The system cannot find the file specified.

I double checked directory and file paths. Everything's correct.
Also, those terminal commands are the same ones I use when training models using your original implementation ...

Was your 800K model trained using your hybrid Tacotron 1 / 2? Or, was it trained using your original implementation?

Please advise. Thanks!

@rafaelvalle

@keithito can you please share the loss curve of your Tacotron2 model?

@keithito
Owner

Hi @MXGray - the model is trained with the hybrid Tacotron 1/2, the same code that's checked into the tacotron2-work-in-progress branch. Not sure what's going wrong, but the naming is a little bit funny because the directory and checkpoint name are the same. I just verified that the following works:

curl http://data.keithito.com/data/speech/model.ckpt-800000.zip > /tmp/model.ckpt-800000.zip
unzip -d /tmp /tmp/model.ckpt-800000.zip
python3 eval.py --checkpoint /tmp/model.ckpt-800000/model.ckpt-800000
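
The naming confusion above comes from the fact that a TensorFlow checkpoint path is a file prefix, not a file or a directory. Below is a minimal standard-library sketch that checks whether the files for a given prefix are present. The helper names are hypothetical, and the `.index` / `.data-*` suffixes are those of TensorFlow's V2 checkpoint format (the `.data-00000-of-00001` name also appears in the error quoted earlier).

```python
from pathlib import Path


def checkpoint_files(prefix):
    """Return the names of files on disk that start with `prefix` + '.'."""
    prefix = Path(prefix)
    if not prefix.parent.is_dir():
        return []
    return sorted(
        p.name for p in prefix.parent.glob(prefix.name + ".*") if p.is_file()
    )


def looks_restorable(prefix):
    """True if both an index file and at least one data shard exist."""
    names = checkpoint_files(prefix)
    has_index = Path(prefix).name + ".index" in names
    has_data = any(".data-" in name for name in names)
    return has_index and has_data
```

With the archive unpacked as in the commands above, `/tmp/model.ckpt-800000/model.ckpt-800000` should pass this check, while the directory `/tmp/model.ckpt-800000` on its own should not.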

@keithito
Owner

@rafaelvalle: Here you go. The smoothing is 0.98.
loss

all

@rafaelvalle

rafaelvalle commented Feb 11, 2018

@keithito thanks a lot! Did the loss explode at any time?
Also, did you try running it with the learning rate scheme described in the paper, i.e. 0.001 for 50k iterations, then decay?

@keithito
Owner

No loss explosion. I did not try the learning rate scheme in the paper -- I suspect it would not make a huge difference from starting the decay at step 0.
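
For reference, the schedule asked about above (0.001 for the first 50k iterations, then decay) can be sketched as a plain function. The decay rate, the decay interval and the 1e-5 floor are assumptions chosen for illustration; the thread does not specify them, and this is not the code used in the branch.

```python
def paper_style_lr(step, init_lr=1e-3, start_decay=50_000,
                   decay_steps=50_000, decay_rate=0.5, final_lr=1e-5):
    """Constant learning rate until `start_decay`, then exponential decay.

    decay_steps, decay_rate and final_lr are illustrative assumptions.
    """
    if step <= start_decay:
        return init_lr
    exponent = (step - start_decay) / decay_steps
    return max(init_lr * decay_rate ** exponent, final_lr)


if __name__ == "__main__":
    for step in (0, 50_000, 100_000, 400_000):
        print(step, paper_style_lr(step))
```

Setting `start_decay=0` gives the "decay from step 0" variant of the same function, which makes the two options easy to compare on a plot.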

@tuong-olli

tuong-olli commented Mar 30, 2018

Output of tacotron 1:
https://drive.google.com/file/d/1NfT_oEXNfPQJ-3pHCpMEEsdgDiw2wXOu/view?usp=sharing
Attention alignments of tacotron 1:
step-113000-align

Output of tacotron 2:
https://drive.google.com/file/d/1kX3oWl-aCVugrsZDrmaJNRe8-3oiGFp3/view?usp=sharing
Attention alignments of tacotron 2:
step-314000-align

Outputs of tacotron 1 are clearer than tacotron2's outputs, but tacotron1's outputs are easily interrupted.

@twidddj

twidddj commented Apr 3, 2018

@keithito Hi, your work is very helpful to me. I really appreciate it.
Have you tried training your model with the "reduction factor" set to 1, as in the paper? I think it's a key factor in connecting Tacotron and WaveNet.

@keithito
Owner

keithito commented Apr 3, 2018

No, I used the default (5). I'm not sure if it will make that much of a difference when integrating wavenet -- you're still generating the same number of mel spectrogram frames regardless of the reduction factor.
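
To illustrate the point about the reduction factor: with a reduction factor r, the decoder emits r mel frames per step, so r changes the number of decoder steps and the width of each step's output, not the number of frames. A small sketch, assuming 80 mel channels; the function names are hypothetical.

```python
import math


def decoder_steps(num_mel_frames, reduction_factor):
    """Decoder steps needed to cover all frames (last step may be padded)."""
    return math.ceil(num_mel_frames / reduction_factor)


def decoder_output_width(num_mels, reduction_factor):
    """Values emitted per decoder step: r frames of num_mels channels each."""
    return num_mels * reduction_factor


if __name__ == "__main__":
    for r in (1, 5):
        steps = decoder_steps(1000, r)
        print("r=%d: %d steps, %d values per step"
              % (r, steps, decoder_output_width(80, r)))
```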

@twidddj

twidddj commented Apr 3, 2018

It is required for generating ground-truth-aligned predictions, in which each frame aligns exactly with the target waveform. They generated the aligned predictions with the new Tacotron in teacher-forcing mode and used them to train their WaveNet.

If we use a reduction factor > 1, the prediction looks like this, which would be bad news for WaveNet performance.
teacher_forced_mel_prediction

This is the true mel.
true_mel
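
The alignment requirement can be made concrete: for a vocoder to be trained on predicted mels, each mel frame has to correspond to a fixed number of waveform samples (the hop size). A rough sketch, assuming the hop size is derived from a frame shift in milliseconds and a sample rate; the function names are hypothetical and the example values are only illustrative.

```python
def hop_size(frame_shift_ms, sample_rate):
    """Number of waveform samples covered by one mel frame."""
    return round(frame_shift_ms * sample_rate / 1000)


def fit_waveform_to_frames(wav, num_frames, hop):
    """Trim or zero-pad `wav` (a list of samples) to num_frames * hop samples."""
    target = num_frames * hop
    if len(wav) >= target:
        return wav[:target]
    return wav + [0.0] * (target - len(wav))


if __name__ == "__main__":
    hop = hop_size(12.5, 20000)           # illustrative values
    wav = [0.1] * 1000                    # stand-in for real audio samples
    aligned = fit_waveform_to_frames(wav, num_frames=3, hop=hop)
    print(hop, len(aligned))
```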

@keithito
Owner

keithito commented Apr 3, 2018

I see. Guess I'll have to try with r=1 :-)

@rishikksh20

@keithito Is the Tacotron 2 implementation in this branch complete?

https://github.com/keithito/tacotron/tree/tacotron2-work-in-progress
Can I integrate it with r9y9's wavenet_vocoder using mel spectrograms?

@keithito
Owner

Just to follow up, there are some great Tacotron 2 implementations here:
https://github.com/Rayhane-mamah/Tacotron-2
https://github.com/NVIDIA/tacotron2

If you're interested in Tacotron 2, please use one of them.

@vinnitu

vinnitu commented May 23, 2019

@MXGray: here's a checkpoint after 800k steps:
http://data.keithito.com/data/speech/model.ckpt-800000.zip

How do I run your model?

Key model/inference/decoder/output_projection_wrapper/multi_rnn_cell/cell_0/output_projection_wrapper/bias not found in checkpoint

pipenv run python demo_server.py --checkpoint ~/tacotron/model.ckpt-800000/model.ckpt-800000

Hyperparameters:
  adam_beta1: 0.9
  adam_beta2: 0.999
  attention_depth: 256
  batch_size: 8
  cleaners: transliteration_cleaners
  decay_learning_rate: True
  decoder_depth: 256
  embed_depth: 256
  encoder_depth: 256
  frame_length_ms: 50
  frame_shift_ms: 12.5
  griffin_lim_iters: 60
  initial_learning_rate: 0.002
  max_iters: 1000
  min_level_db: -100
  num_freq: 1025
  num_mels: 80
  outputs_per_step: 5
  postnet_depth: 256
  power: 1.5
  preemphasis: 0.97
  prenet_depths: [256, 128]
  ref_level_db: 20
  sample_rate: 20000
  use_cmudict: False
Constructing model: tacotron
Initialized Tacotron model. Dimensions: 
  embedding:               256
  prenet out:              128
  encoder out:             256
  attention out:           256
  concat attn & out:       512
  decoder cell out:        256
  decoder out (5 frames):  400
  decoder out (1 frame):   80
  postnet out:             256
  linear out:              1025
Loading checkpoint: /home/ai/tacotron/model.ckpt-800000/model.ckpt-800000
Traceback (most recent call last):
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1334, in _do_call
    return fn(*args)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1319, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1407, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.NotFoundError: Key model/inference/decoder/output_projection_wrapper/multi_rnn_cell/cell_0/output_projection_wrapper/bias not found in checkpoint
	 [[{{node save/RestoreV2}}]]
	 [[{{node save/RestoreV2}}]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1276, in restore
    {self.saver_def.filename_tensor_name: save_path})
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 929, in run
    run_metadata_ptr)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1152, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1328, in _do_run
    run_metadata)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1348, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Key model/inference/decoder/output_projection_wrapper/multi_rnn_cell/cell_0/output_projection_wrapper/bias not found in checkpoint
	 [[node save/RestoreV2 (defined at /home/ai/projects/tacotron/synthesizer.py:27) ]]
	 [[node save/RestoreV2 (defined at /home/ai/projects/tacotron/synthesizer.py:27) ]]

Caused by op 'save/RestoreV2', defined at:
  File "demo_server.py", line 91, in <module>
    synthesizer.load(args.checkpoint)
  File "/home/ai/projects/tacotron/synthesizer.py", line 27, in load
    saver = tf.train.Saver()
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 832, in __init__
    self.build()
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 844, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 881, in _build
    build_save=build_save, build_restore=build_restore)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 513, in _build_internal
    restore_sequentially, reshape)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 332, in _AddRestoreOps
    restore_sequentially)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 580, in bulk_restore
    return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1572, in restore_v2
    name=name)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3300, in create_op
    op_def=op_def)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1801, in __init__
    self._traceback = tf_stack.extract_stack()

NotFoundError (see above for traceback): Key model/inference/decoder/output_projection_wrapper/multi_rnn_cell/cell_0/output_projection_wrapper/bias not found in checkpoint
	 [[node save/RestoreV2 (defined at /home/ai/projects/tacotron/synthesizer.py:27) ]]
	 [[node save/RestoreV2 (defined at /home/ai/projects/tacotron/synthesizer.py:27) ]]


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1286, in restore
    names_to_keys = object_graph_key_mapping(save_path)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1591, in object_graph_key_mapping
    checkpointable.OBJECT_GRAPH_PROTO_KEY)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/pywrap_tensorflow_internal.py", line 370, in get_tensor
    status)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 528, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.NotFoundError: Key _CHECKPOINTABLE_OBJECT_GRAPH not found in checkpoint

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "demo_server.py", line 91, in <module>
    synthesizer.load(args.checkpoint)
  File "/home/ai/projects/tacotron/synthesizer.py", line 28, in load
    saver.restore(self.session, checkpoint_path)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1292, in restore
    err, "a Variable name or other graph key that is missing")
tensorflow.python.framework.errors_impl.NotFoundError: Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Key model/inference/decoder/output_projection_wrapper/multi_rnn_cell/cell_0/output_projection_wrapper/bias not found in checkpoint
	 [[node save/RestoreV2 (defined at /home/ai/projects/tacotron/synthesizer.py:27) ]]
	 [[node save/RestoreV2 (defined at /home/ai/projects/tacotron/synthesizer.py:27) ]]

Caused by op 'save/RestoreV2', defined at:
  File "demo_server.py", line 91, in <module>
    synthesizer.load(args.checkpoint)
  File "/home/ai/projects/tacotron/synthesizer.py", line 27, in load
    saver = tf.train.Saver()
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 832, in __init__
    self.build()
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 844, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 881, in _build
    build_save=build_save, build_restore=build_restore)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 513, in _build_internal
    restore_sequentially, reshape)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 332, in _AddRestoreOps
    restore_sequentially)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 580, in bulk_restore
    return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1572, in restore_v2
    name=name)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3300, in create_op
    op_def=op_def)
  File "/home/ai/projects/tacotron/.venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1801, in __init__
    self._traceback = tf_stack.extract_stack()

NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Key model/inference/decoder/output_projection_wrapper/multi_rnn_cell/cell_0/output_projection_wrapper/bias not found in checkpoint
	 [[node save/RestoreV2 (defined at /home/ai/projects/tacotron/synthesizer.py:27) ]]
	 [[node save/RestoreV2 (defined at /home/ai/projects/tacotron/synthesizer.py:27) ]]
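
The NotFoundError above means the graph built by the code asks for a variable name that the checkpoint does not contain. Earlier in the thread the checkpoint is said to come from the tacotron2-work-in-progress branch, so one possible (unconfirmed) cause is that the graph here was built from a different version of the code. A sketch for comparing the two sets of names follows; `diff_variable_names` and `checkpoint_variable_names` are hypothetical helpers, and TensorFlow is imported only when the second one is called.

```python
def diff_variable_names(graph_names, checkpoint_names):
    """Return (missing_from_checkpoint, unused_in_graph) as sorted lists."""
    graph, ckpt = set(graph_names), set(checkpoint_names)
    return sorted(graph - ckpt), sorted(ckpt - graph)


def checkpoint_variable_names(checkpoint_prefix):
    """List variable names stored in a checkpoint (requires TensorFlow)."""
    import tensorflow as tf  # imported lazily so the diff helper works without it
    return [name for name, _shape in tf.train.list_variables(checkpoint_prefix)]


if __name__ == "__main__":
    # Made-up names, for illustration only.
    wanted = ["model/decoder/bias", "model/encoder/kernel"]
    stored = ["model/encoder/kernel", "model/postnet/kernel"]
    print(diff_variable_names(wanted, stored))
```

If the first list is non-empty, the code and the checkpoint most likely do not describe the same model.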

@rabbia970

Hey, I am getting some errors when running eval.py or demo_server.py.
Can you please help me?

Whenever I execute eval.py, I encounter the following errors.

Traceback (most recent call last):
  File "eval.py", line 54, in <module>
    main()
  File "eval.py", line 49, in main
    hparams.parse(100)
  File "C:\Users\Dell\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\contrib\training\python\training\hparam.py", line 523, in parse
    values_map = parse_values(values, type_map)
  File "C:\Users\Dell\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\contrib\training\python\training\hparam.py", line 246, in parse_values
    while pos < len(values):
TypeError: object of type 'int' has no len()
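
The TypeError above comes from `hparams.parse(100)`: `parse` expects a string of comma-separated `name=value` overrides (for example `'max_iters=100'`), not an integer. The following simplified stand-in (not TensorFlow's implementation, and limited to scalar values) shows the expected input format; `parse_overrides` is a hypothetical name.

```python
def _convert(raw):
    """Try int, then float, and fall back to the raw string."""
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def parse_overrides(values):
    """Parse 'name=value,name=value' into a dict of overrides."""
    if not isinstance(values, str):
        raise TypeError(
            "expected a string like 'name=value', got %s" % type(values).__name__
        )
    overrides = {}
    for item in (part.strip() for part in values.split(",")):
        if not item:
            continue
        name, sep, raw = item.partition("=")
        if not sep:
            raise ValueError("missing '=' in %r" % item)
        overrides[name.strip()] = _convert(raw.strip())
    return overrides
```

Passing `"max_iters=100"` rather than `100` is the shape of argument the real `parse` method is documented to accept.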

The same type of error occurs when I run it the other way, as shown below.

Hyperparameters:
GL_on_GPU: True
NN_init: True
NN_scaler: 0.3
allow_clipping_in_normalization: True
attention_dim: 128
attention_filters: 32
attention_kernel: (31,)
attention_win_size: 7
batch_norm_position: after
cbhg_conv_channels: 128
cbhg_highway_units: 128
cbhg_highwaynet_layers: 4
cbhg_kernels: 8
cbhg_pool_size: 2
cbhg_projection: 256
cbhg_projection_kernel_size: 3
cbhg_rnn_units: 128
cdf_loss: False
cin_channels: 80
cleaners: english_cleaners
clip_for_wavenet: True
clip_mels_length: True
clip_outputs: True
cross_entropy_pos_weight: 1
cumulative_weights: True
decoder_layers: 2
decoder_lstm_units: 1024
embedding_dim: 512
enc_conv_channels: 512
enc_conv_kernel_size: (5,)
enc_conv_num_layers: 3
encoder_lstm_units: 256
fmax: 7600
fmin: 55
frame_shift_ms: None
freq_axis_kernel_size: 3
gate_channels: 256
gin_channels: 16
griffin_lim_iters: 60
hop_size: 200
input_type: raw
kernel_size: 3
layers: 20
leaky_alpha: 0.4
legacy: True
log_scale_min: -32.23619130191664
log_scale_min_gauss: -16.11809565095832
lower_bound_decay: 0.1
magnitude_power: 2.0
mask_decoder: False
mask_encoder: True
max_abs_value: 4.0
max_iters: 100
max_mel_frames: 900
max_time_sec: None
max_time_steps: 11000
min_level_db: -100
n_fft: 1024
n_speakers: 6
normalize_for_wavenet: True
num_freq: 513
num_mels: 80
out_channels: 30
outputs_per_step: 1
postnet_channels: 512
postnet_kernel_size: (5,)
postnet_num_layers: 5
power: 1.5
predict_linear: True
preemphasis: 0.97
preemphasize: True
prenet_layers: [256, 256]
quantize_channels: 65536
ref_level_db: 20
rescale: True
rescaling_max: 0.999
residual_channels: 128
residual_legacy: True
sample_rate: 16000
signal_normalization: True
silence_threshold: 2
skip_out_channels: 128
smoothing: False
speakers: ['awb', 'bdl', 'clb', 'ksp', 'rms', 'slt']
speakers_path: None
split_on_cpu: True
stacks: 2
stop_at_any: True
symmetric_mels: True
synthesis_constraint: False
synthesis_constraint_type: window
tacotron_adam_beta1: 0.9
tacotron_adam_beta2: 0.999
tacotron_adam_epsilon: 1e-06
tacotron_batch_size: 32
tacotron_clip_gradients: True
tacotron_data_random_state: 1234
tacotron_decay_learning_rate: True
tacotron_decay_rate: 0.5
tacotron_decay_steps: 18000
tacotron_dropout_rate: 0.5
tacotron_final_learning_rate: 0.0001
tacotron_fine_tuning: False
tacotron_initial_learning_rate: 0.001
tacotron_natural_eval: False
tacotron_num_gpus: 1
tacotron_random_seed: 5339
tacotron_reg_weight: 1e-06
tacotron_scale_regularization: False
tacotron_start_decay: 40000
tacotron_swap_with_cpu: False
tacotron_synthesis_batch_size: 1
tacotron_teacher_forcing_decay_alpha: None
tacotron_teacher_forcing_decay_steps: 40000
tacotron_teacher_forcing_final_ratio: 0.0
tacotron_teacher_forcing_init_ratio: 1.0
tacotron_teacher_forcing_mode: constant
tacotron_teacher_forcing_ratio: 1.0
tacotron_teacher_forcing_start_decay: 10000
tacotron_test_batches: None
tacotron_test_size: 0.05
tacotron_zoneout_rate: 0.1
train_with_GTA: True
trim_fft_size: 2048
trim_hop_size: 512
trim_silence: True
trim_top_db: 40
upsample_activation: Relu
upsample_scales: [10, 20]
upsample_type: SubPixel
use_bias: True
use_lws: False
use_speaker_embedding: True
wavenet_adam_beta1: 0.9
wavenet_adam_beta2: 0.999
wavenet_adam_epsilon: 1e-06
wavenet_batch_size: 8
wavenet_clip_gradients: True
wavenet_data_random_state: 1234
wavenet_debug_mels: ['training_data/mels/mel-LJ001-0008.npy']
wavenet_debug_wavs: ['training_data/audio/audio-LJ001-0008.npy']
wavenet_decay_rate: 0.5
wavenet_decay_steps: 200000
wavenet_dropout: 0.05
wavenet_ema_decay: 0.9999
wavenet_gradient_max_norm: 100.0
wavenet_gradient_max_value: 5.0
wavenet_init_scale: 1.0
wavenet_learning_rate: 0.001
wavenet_lr_schedule: exponential
wavenet_natural_eval: False
wavenet_num_gpus: 1
wavenet_pad_sides: 1
wavenet_random_seed: 5339
wavenet_swap_with_cpu: False
wavenet_synth_debug: False
wavenet_synthesis_batch_size: 20
wavenet_test_batches: 1
wavenet_test_size: None
wavenet_warmup: 4000.0
wavenet_weight_normalization: False
win_size: 800
Constructing model: Tacotron
Traceback (most recent call last):
  File "C:\Users\Dell\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\execute.py", line 141, in make_shape
    shape = tensor_shape.as_shape(v)
  File "C:\Users\Dell\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\tensor_shape.py", line 849, in as_shape
    return TensorShape(shape)
  File "C:\Users\Dell\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\tensor_shape.py", line 455, in __init__
    self._dims = [as_dimension(d) for d in dims_iter]
  File "C:\Users\Dell\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\tensor_shape.py", line 455, in <listcomp>
    self._dims = [as_dimension(d) for d in dims_iter]
  File "C:\Users\Dell\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\tensor_shape.py", line 397, in as_dimension
    return Dimension(value)
  File "C:\Users\Dell\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\tensor_shape.py", line 32, in __init__
    self._value = int(value)
TypeError: int() argument must be a string, a bytes-like object or a number, not 'HParams'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "demo_server.py", line 92, in <module>
    synthesizer.load(args.checkpoint,hparams.parse(args.hparams))
  File "C:\Users\Dell\Desktop\Tacotron-2-master\tacotron\synthesizer.py", line 21, in load
    targets = tf.placeholder(tf.float32, (None, None, hparams), name='mel_targets')
  File "C:\Users\Dell\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\array_ops.py", line 1680, in placeholder
    return gen_array_ops._placeholder(dtype=dtype, shape=shape, name=name)
  File "C:\Users\Dell\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 4101, in _placeholder
    shape = _execute.make_shape(shape, "shape")
  File "C:\Users\Dell\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\execute.py", line 143, in make_shape
    raise TypeError("Error converting %s to a TensorShape: %s." % (arg_name, e))
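
In this second traceback the placeholder shape is `(None, None, hparams)`, so TensorFlow tries to convert the whole HParams object into an integer dimension. The last dimension presumably should be a single integer hyperparameter such as `hparams.num_mels` (80 in the dump above); that is an assumption, since the thread does not show the intended line. A TensorFlow-free sketch of the check that fails, with a hypothetical `validate_shape` helper and a stand-in hparams object:

```python
from types import SimpleNamespace


def validate_shape(shape):
    """Return the shape as a tuple; raise TypeError unless each dim is None or int."""
    for dim in shape:
        if dim is not None and not isinstance(dim, int):
            raise TypeError("invalid dimension: %r" % (dim,))
    return tuple(shape)


# Stand-in for the real HParams object; only num_mels is modeled here.
hparams = SimpleNamespace(num_mels=80)

# Passing an integer attribute works ...
mel_target_shape = validate_shape((None, None, hparams.num_mels))

# ... while passing the whole object reproduces the kind of error seen above.
try:
    validate_shape((None, None, hparams))
except TypeError as err:
    print("rejected:", err)
```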
