fix: tensorboard

32k
大蒟蒻 3 years ago
parent 86f7be1310
commit 873b9f0ec4

@ -1,7 +1,7 @@
{
"train": {
"log_interval": 200,
"eval_interval": 1000,
"eval_interval": 200,
"seed": 1234,
"epochs": 10000,
"learning_rate": 0.0001,

@ -170,25 +170,27 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
# logger.info([x.item() for x in losses] + [global_step, lr])
pbar_train_loader.set_postfix(loss=[round(y, 2) for y in [x.item() for x in losses]], lr=lr)
scalar_dict = {
"loss/g/total": loss_gen_all,
"loss/d/total": loss_disc_all,
"learning_rate": lr,
"grad_norm_d": grad_norm_d,
"grad_norm_g": grad_norm_g
}
scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl})
scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
image_dict = {
"slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
"slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
"all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
}
if global_step % hps.train.log_interval == 0:
scalar_dict = {
"loss/g/total": loss_gen_all,
"loss/d/total": loss_disc_all,
"learning_rate": lr,
"grad_norm_d": grad_norm_d,
"grad_norm_g": grad_norm_g
}
scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl})
scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
image_dict = {
"slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
"slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
"all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
}
utils.summarize(writer=writer, global_step=global_step, images=image_dict, scalars=scalar_dict)
global_step += 1
utils.summarize(writer=writer, global_step=global_step, images=image_dict, scalars=scalar_dict)
if rank == 0:
# logger.info('====> Epoch: {}'.format(epoch))

Loading…
Cancel
Save