From ec4cf2ae3c2a3e9289b3958ff28d5412710f85ca Mon Sep 17 00:00:00 2001 From: Ryuichi Yamamoto Date: Fri, 4 May 2018 14:31:32 +0900 Subject: [PATCH] Use weight normalization disabled model for evaluation fixes https://github.com/r9y9/deepvoice3_pytorch/issues/77#issuecomment-385574856 --- train.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 9550a666..6cc9e773 100644 --- a/train.py +++ b/train.py @@ -379,7 +379,7 @@ def prepare_spec_image(spectrogram): return np.uint8(cm.magma(spectrogram.T) * 255) -def eval_model(global_step, writer, model, checkpoint_dir, ismultispeaker): +def eval_model(global_step, writer, device, model, checkpoint_dir, ismultispeaker): # harded coded texts = [ "Scientists at the CERN laboratory say they have discovered a new particle.", @@ -395,6 +395,10 @@ def eval_model(global_step, writer, model, checkpoint_dir, ismultispeaker): eval_output_dir = join(checkpoint_dir, "eval") os.makedirs(eval_output_dir, exist_ok=True) + # Prepare model for evaluation + model_eval = build_model().to(device) + model_eval.load_state_dict(model.state_dict()) + # hard coded speaker_ids = [0, 1, 10] if ismultispeaker else [None] for speaker_id in speaker_ids: @@ -402,7 +406,7 @@ def eval_model(global_step, writer, model, checkpoint_dir, ismultispeaker): for idx, text in enumerate(texts): signal, alignment, _, mel = synthesis.tts( - model, text, p=0, speaker_id=speaker_id, fast=False) + model_eval, text, p=0, speaker_id=speaker_id, fast=True) signal /= np.max(np.abs(signal)) # Alignment @@ -713,7 +717,7 @@ def train(device, model, data_loader, optimizer, writer, train_seq2seq, train_postnet) if global_step > 0 and global_step % hparams.eval_interval == 0: - eval_model(global_step, writer, model, checkpoint_dir, ismultispeaker) + eval_model(global_step, writer, device, model, checkpoint_dir, ismultispeaker) # Update loss.backward()