Skip to content

Commit

Permalink
Use weight normalization disabled model for evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
r9y9 committed May 4, 2018
1 parent 77b4642 commit ec4cf2a
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def prepare_spec_image(spectrogram):
return np.uint8(cm.magma(spectrogram.T) * 255)


def eval_model(global_step, writer, model, checkpoint_dir, ismultispeaker):
def eval_model(global_step, writer, device, model, checkpoint_dir, ismultispeaker):
# harded coded
texts = [
"Scientists at the CERN laboratory say they have discovered a new particle.",
Expand All @@ -395,14 +395,18 @@ def eval_model(global_step, writer, model, checkpoint_dir, ismultispeaker):
eval_output_dir = join(checkpoint_dir, "eval")
os.makedirs(eval_output_dir, exist_ok=True)

# Prepare model for evaluation
model_eval = build_model().to(device)
model_eval.load_state_dict(model.state_dict())

# hard coded
speaker_ids = [0, 1, 10] if ismultispeaker else [None]
for speaker_id in speaker_ids:
speaker_str = "multispeaker{}".format(speaker_id) if speaker_id is not None else "single"

for idx, text in enumerate(texts):
signal, alignment, _, mel = synthesis.tts(
model, text, p=0, speaker_id=speaker_id, fast=False)
model_eval, text, p=0, speaker_id=speaker_id, fast=True)
signal /= np.max(np.abs(signal))

# Alignment
Expand Down Expand Up @@ -713,7 +717,7 @@ def train(device, model, data_loader, optimizer, writer,
train_seq2seq, train_postnet)

if global_step > 0 and global_step % hparams.eval_interval == 0:
eval_model(global_step, writer, model, checkpoint_dir, ismultispeaker)
eval_model(global_step, writer, device, model, checkpoint_dir, ismultispeaker)

# Update
loss.backward()
Expand Down

0 comments on commit ec4cf2a

Please sign in to comment.