From 873b9f0ec4d9dea59d6334e46b9bc54c35659fbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A7=E8=92=9F=E8=92=BB?= <6648049+tooyoungtoosimp@users.noreply.github.com> Date: Mon, 13 Feb 2023 16:46:21 +0800 Subject: [PATCH] fix: tensorboard --- configs/config.json | 2 +- train_v2.py | 38 ++++++++++++++++++++------------------ 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/configs/config.json b/configs/config.json index 7ec9159..36613c3 100644 --- a/configs/config.json +++ b/configs/config.json @@ -1,7 +1,7 @@ { "train": { "log_interval": 200, - "eval_interval": 1000, + "eval_interval": 200, "seed": 1234, "epochs": 10000, "learning_rate": 0.0001, diff --git a/train_v2.py b/train_v2.py index 665b3d4..196098e 100644 --- a/train_v2.py +++ b/train_v2.py @@ -170,25 +170,27 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade # logger.info([x.item() for x in losses] + [global_step, lr]) pbar_train_loader.set_postfix(loss=[round(y, 2) for y in [x.item() for x in losses]], lr=lr) - scalar_dict = { - "loss/g/total": loss_gen_all, - "loss/d/total": loss_disc_all, - "learning_rate": lr, - "grad_norm_d": grad_norm_d, - "grad_norm_g": grad_norm_g - } - scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl}) - - scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) - scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) - scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) - image_dict = { - "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), - "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), - "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), - } + if global_step % hps.train.log_interval == 0: + scalar_dict = { + "loss/g/total": loss_gen_all, + "loss/d/total": loss_disc_all, + "learning_rate": lr, + "grad_norm_d": grad_norm_d, + "grad_norm_g": grad_norm_g + } + scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl}) + + scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) + scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) + scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) + image_dict = { + "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), + "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), + "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), + } + utils.summarize(writer=writer, global_step=global_step, images=image_dict, scalars=scalar_dict) + global_step += 1 - utils.summarize(writer=writer, global_step=global_step, images=image_dict, scalars=scalar_dict) if rank == 0: # logger.info('====> Epoch: {}'.format(epoch))