import argparse
import os

import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils.model import get_model, get_vocoder, get_param_num
from utils.tools import get_configs_of, to_device, log, synth_one_sample
from model import SpecDiffGANLoss
from dataset import Dataset

from evaluate import evaluate

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args, configs):
    print("Prepare training ...")
    preprocess_config, model_config, train_config = configs

    # Get dataset
    dataset = Dataset(
        "train.txt", args, preprocess_config, model_config, train_config, sort=True, drop_last=True
    )
    batch_size = train_config["optimizer"]["batch_size"]
    group_size = 4  # Set this larger than 1 to enable sorting in Dataset
    assert batch_size * group_size < len(dataset)
    loader = DataLoader(
        dataset,
        batch_size=batch_size * group_size,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )

    # Prepare model
    model, diff_d, spec_d, optG_fs2, optG, optD, sdlG, sdlD, epoch = get_model(
        args, configs, device, train=True
    )
    num_params_G = get_param_num(model)
    num_params_D_Diff = get_param_num(diff_d)
    num_params_D_Spec = get_param_num(spec_d)
    Loss = SpecDiffGANLoss(args, preprocess_config, model_config, train_config).to(device)
    print("Number of SpecDiff-GAN Parameters          :", num_params_G)
    print("    DiffusionDiscriminator Parameters      :", num_params_D_Diff)
    print("    SpectrogramDiscriminator Parameters    :", num_params_D_Spec)
    print("    All Parameters                         :", num_params_G + num_params_D_Diff + num_params_D_Spec)

    # Load vocoder
    vocoder = get_vocoder(model_config, device)

    # Init logger
    for p in train_config["path"].values():
        os.makedirs(p, exist_ok=True)
    train_log_path = os.path.join(train_config["path"]["log_path"], "train")
    val_log_path = os.path.join(train_config["path"]["log_path"], "val")
    os.makedirs(train_log_path, exist_ok=True)
    os.makedirs(val_log_path, exist_ok=True)
    train_logger = SummaryWriter(train_log_path)
    val_logger = SummaryWriter(val_log_path)

    # Training
    step = args.restore_step + 1
    grad_acc_step = train_config["optimizer"]["grad_acc_step"]
    grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]
    total_step = train_config["step"]["total_step"]
    log_step = train_config["step"]["log_step"]
    save_step = train_config["step"]["save_step"]
    synth_step = train_config["step"]["synth_step"]
    val_step = train_config["step"]["val_step"]

    def model_update(model, step, loss, optimizer):
        # Backward: scale the loss for gradient accumulation. Note that
        # backward() returns None, so its result must not be assigned.
        (loss / grad_acc_step).backward()
        if step % grad_acc_step == 0:
            # Clipping gradients to avoid gradient explosion
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
            # Update weights
            optimizer.step()
            optimizer.zero_grad()
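    # Illustrative schedule, assuming grad_acc_step = 2 (hypothetical value):
    # steps 1, 3, 5, ... only accumulate scaled gradients, while steps 2, 4,
    # 6, ... additionally clip, call optimizer.step(), and zero the gradients.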

    outer_bar = tqdm(total=total_step, desc="Training", position=0)
    outer_bar.n = args.restore_step
    outer_bar.update()

    while True:
        inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)
        for batchs in loader:
            for batch in batchs:
                batch = to_device(batch, device)

                #######################
                # Train Discriminator #
                #######################

                # Forward
                output, *_ = model(*(batch[2:]))
                xs, spk_emb, t, mel_masks = *(output[1:4]), output[9]
                x_ts, x_t_prevs, x_t_prev_preds, spk_emb, t = [
                    x.detach() if x is not None else x for x in (list(xs) + [spk_emb, t])
                ]

                D_fake_diff = diff_d(x_ts, x_t_prev_preds, t)
                D_real_diff = diff_d(x_ts, x_t_prevs, t)
                D_fake_spec_feats, D_fake_spec = spec_d(output[0], spk_emb)
                D_real_spec_feats, D_real_spec = spec_d(batch[6], spk_emb)

                D_loss_real, D_loss_fake = Loss.d_loss_fn(
                    D_real_diff[-1], D_real_spec, D_fake_diff[-1], D_fake_spec
                )

                D_loss = D_loss_real + D_loss_fake
                model_update(diff_d, step, D_loss, optD)

                ###################
                # Train Generator #
                ###################

                # Forward
                output, p_targets, coarse_mels = model(*(batch[2:]))
                # Update Batch
                batch[9] = p_targets
                (x_ts, x_t_prevs, x_t_prev_preds), spk_emb, t, mel_masks = *(output[1:4]), output[9]

                D_fake_diff = diff_d(x_ts, x_t_prev_preds, t)
                D_real_diff = diff_d(x_ts, x_t_prevs, t)
                D_fake_spec_feats, D_fake_spec = spec_d(output[0], spk_emb)
                D_real_spec_feats, D_real_spec = spec_d(batch[6], spk_emb)

                adv_loss = Loss.g_loss_fn(D_fake_diff[-1], D_fake_spec)
                (
                    fm_loss,
                    recon_loss,
                    mel_loss,
                    pitch_loss,
                    energy_loss,
                    duration_loss,
                ) = Loss(
                    model,
                    batch,
                    output,
                    coarse_mels,
                    (D_real_diff, D_real_spec_feats, D_fake_diff, D_fake_spec_feats),
                )

                G_loss = adv_loss + recon_loss + fm_loss
                model_update(model, step, G_loss, optG)

                losses = [
                    D_loss + G_loss, D_loss, G_loss, recon_loss, fm_loss,
                    adv_loss, mel_loss, pitch_loss, energy_loss, duration_loss,
                ]
                losses_msg = [D_loss + G_loss, D_loss, adv_loss, mel_loss, pitch_loss, energy_loss, duration_loss]

                if step % log_step == 0:
                    losses_msg = [sum(l.values()).item() if isinstance(l, dict) else l.item() for l in losses_msg]
                    message1 = "Step {}/{}, ".format(step, total_step)
                    message2 = (
                        "Total Loss: {:.4f}, D_loss: {:.4f}, adv_loss: {:.4f}, mel_loss: {:.4f}, "
                        "pitch_loss: {:.4f}, energy_loss: {:.4f}, duration_loss: {:.4f}".format(*losses_msg)
                    )

                    with open(os.path.join(train_log_path, "log.txt"), "a") as f:
                        f.write(message1 + message2 + "\n")

                    outer_bar.write(message1 + message2)

                    log(train_logger, step, losses=losses, lr=sdlG.get_last_lr()[-1])

                if step % synth_step == 0:
                    figs, wav_reconstruction, wav_prediction, tag = synth_one_sample(
                        args,
                        batch,
                        output,
                        coarse_mels,
                        vocoder,
                        model_config,
                        preprocess_config,
                        model.diffusion,
                    )
                    log(
                        train_logger,
                        step,
                        figs=figs,
                        tag="Training",
                    )
                    sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
                    log(
                        train_logger,
                        step,
                        audio=wav_reconstruction,
                        sampling_rate=sampling_rate,
                        tag="Training/reconstructed",
                    )
                    log(
                        train_logger,
                        step,
                        audio=wav_prediction,
                        sampling_rate=sampling_rate,
                        tag="Training/synthesized",
                    )

                if step % val_step == 0:
                    model.eval()
                    message = evaluate(args, model, diff_d, spec_d, step, configs, val_logger, vocoder, losses)
                    with open(os.path.join(val_log_path, "log.txt"), "a") as f:
                        f.write(message + "\n")
                    outer_bar.write(message)

                    model.train()

                if step % save_step == 0:
                    torch.save(
                        {
                            "epoch": epoch,
                            "G": model.state_dict(),
                            "D_Diff": diff_d.state_dict(),
                            "D_Spec": spec_d.state_dict(),
                            "optG_fs2": optG_fs2._optimizer.state_dict(),
                            "optG": optG.state_dict(),
                            "optD": optD.state_dict(),
                            "sdlG": sdlG.state_dict(),
                            "sdlD": sdlD.state_dict(),
                        },
                        os.path.join(
                            train_config["path"]["ckpt_path"],
                            "{}.pth.tar".format(step),
                        ),
                    )

                if step >= total_step:
                    quit()
                step += 1
                outer_bar.update(1)

            inner_bar.update(1)
        epoch += 1
        sdlG.step()
        sdlD.step()
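
# Note: the checkpoint saved above bundles the generator, both discriminators,
# all three optimizers, and both schedulers. Resuming with --restore_step N
# assumes (not verified here) that get_model() reloads these state dicts from
# "<ckpt_path>/N.pth.tar".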
train_config["path"]["ckpt_path"] + "_{}".format(path_tag) if path_tag != '' else '' train_config["path"]["log_path"] = train_config["path"]["log_path"] + "_{}".format(path_tag)if path_tag != '' else '' train_config["path"]["result_path"] = train_config["path"]["result_path"] + "_{}".format(path_tag)if path_tag != '' else '' if preprocess_config["preprocessing"]["pitch"]["pitch_type"] == "cwt": from utils.pitch_tools import get_lf0_cwt preprocess_config["preprocessing"]["pitch"]["cwt_scales"] = get_lf0_cwt(np.ones(10))[1] # Log Configuration print("\n==================================== Training Configuration ====================================") if model_config["multi_speaker"]: print(" ---> Type of Speaker Embedder:", preprocess_config["preprocessing"]["speaker_embedder"]) print(" ---> Total Batch Size:", int(train_config["optimizer"]["batch_size"])) print(" ---> Use Pitch Embed:", model_config["variance_embedding"]["use_pitch_embed"]) print(" ---> Use Energy Embed:", model_config["variance_embedding"]["use_energy_embed"]) print(" ---> Path of ckpt:", train_config["path"]["ckpt_path"]) print(" ---> Path of log:", train_config["path"]["log_path"]) print(" ---> Path of result:", train_config["path"]["result_path"]) print("================================================================================================") main(args, configs)