This commit is contained in:
Jarod Mica 2024-12-23 04:57:56 -08:00
parent 894d724b36
commit 60ddc7a4a4
2 changed files with 21 additions and 21 deletions

View File

@ -1,7 +1,7 @@
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
import utils, os import GPT_SoVITS.utils, os
hps = utils.get_hparams(stage=2) hps = GPT_SoVITS.utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",") os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
@ -67,7 +67,7 @@ def main():
def run(rank, n_gpus, hps): def run(rank, n_gpus, hps):
global global_step global global_step
if rank == 0: if rank == 0:
logger = utils.get_logger(hps.data.exp_dir) logger = GPT_SoVITS.utils.get_logger(hps.data.exp_dir)
logger.info(hps) logger.info(hps)
# utils.check_git_hash(hps.s2_ckpt_dir) # utils.check_git_hash(hps.s2_ckpt_dir)
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir) writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
@ -192,16 +192,16 @@ def run(rank, n_gpus, hps):
net_d = net_d.to(device) net_d = net_d.to(device)
try: # 如果能加载自动resume try: # 如果能加载自动resume
_, _, _, epoch_str = utils.load_checkpoint( _, _, _, epoch_str = GPT_SoVITS.utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "D_*.pth"), GPT_SoVITS.utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "D_*.pth"),
net_d, net_d,
optim_d, optim_d,
) # D多半加载没事 ) # D多半加载没事
if rank == 0: if rank == 0:
logger.info("loaded D") logger.info("loaded D")
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0) # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint( _, _, _, epoch_str = GPT_SoVITS.utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "G_*.pth"), GPT_SoVITS.utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "G_*.pth"),
net_g, net_g,
optim_g, optim_g,
) )
@ -427,20 +427,20 @@ def train_and_evaluate(
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
image_dict = { image_dict = {
"slice/mel_org": utils.plot_spectrogram_to_numpy( "slice/mel_org": GPT_SoVITS.utils.plot_spectrogram_to_numpy(
y_mel[0].data.cpu().numpy() y_mel[0].data.cpu().numpy()
), ),
"slice/mel_gen": utils.plot_spectrogram_to_numpy( "slice/mel_gen": GPT_SoVITS.utils.plot_spectrogram_to_numpy(
y_hat_mel[0].data.cpu().numpy() y_hat_mel[0].data.cpu().numpy()
), ),
"all/mel": utils.plot_spectrogram_to_numpy( "all/mel": GPT_SoVITS.utils.plot_spectrogram_to_numpy(
mel[0].data.cpu().numpy() mel[0].data.cpu().numpy()
), ),
"all/stats_ssl": utils.plot_spectrogram_to_numpy( "all/stats_ssl": GPT_SoVITS.utils.plot_spectrogram_to_numpy(
stats_ssl[0].data.cpu().numpy() stats_ssl[0].data.cpu().numpy()
), ),
} }
utils.summarize( GPT_SoVITS.utils.summarize(
writer=writer, writer=writer,
global_step=global_step, global_step=global_step,
images=image_dict, images=image_dict,
@ -449,7 +449,7 @@ def train_and_evaluate(
global_step += 1 global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0: if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if hps.train.if_save_latest == 0: if hps.train.if_save_latest == 0:
utils.save_checkpoint( GPT_SoVITS.utils.save_checkpoint(
net_g, net_g,
optim_g, optim_g,
hps.train.learning_rate, hps.train.learning_rate,
@ -458,7 +458,7 @@ def train_and_evaluate(
"%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(global_step) "%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(global_step)
), ),
) )
utils.save_checkpoint( GPT_SoVITS.utils.save_checkpoint(
net_d, net_d,
optim_d, optim_d,
hps.train.learning_rate, hps.train.learning_rate,
@ -468,7 +468,7 @@ def train_and_evaluate(
), ),
) )
else: else:
utils.save_checkpoint( GPT_SoVITS.utils.save_checkpoint(
net_g, net_g,
optim_g, optim_g,
hps.train.learning_rate, hps.train.learning_rate,
@ -477,7 +477,7 @@ def train_and_evaluate(
"%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(233333333333) "%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(233333333333)
), ),
) )
utils.save_checkpoint( GPT_SoVITS.utils.save_checkpoint(
net_d, net_d,
optim_d, optim_d,
hps.train.learning_rate, hps.train.learning_rate,
@ -565,7 +565,7 @@ def evaluate(hps, generator, eval_loader, writer_eval):
) )
image_dict.update( image_dict.update(
{ {
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy( f"gen/mel_{batch_idx}_{test}": GPT_SoVITS.utils.plot_spectrogram_to_numpy(
y_hat_mel[0].cpu().numpy() y_hat_mel[0].cpu().numpy()
) )
} }
@ -575,7 +575,7 @@ def evaluate(hps, generator, eval_loader, writer_eval):
) )
image_dict.update( image_dict.update(
{ {
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy( f"gt/mel_{batch_idx}": GPT_SoVITS.utils.plot_spectrogram_to_numpy(
mel[0].cpu().numpy() mel[0].cpu().numpy()
) )
} }
@ -587,7 +587,7 @@ def evaluate(hps, generator, eval_loader, writer_eval):
# f"gen/audio_{batch_idx}_style_pred": y_hat[0, :, :] # f"gen/audio_{batch_idx}_style_pred": y_hat[0, :, :]
# }) # })
utils.summarize( GPT_SoVITS.utils.summarize(
writer=writer_eval, writer=writer_eval,
global_step=global_step, global_step=global_step,
images=image_dict, images=image_dict,

View File

@ -33,7 +33,7 @@ def clean_text(text, language, version=None):
for special_s, special_l, target_symbol in special: for special_s, special_l, target_symbol in special:
if special_s in text and language == special_l: if special_s in text and language == special_l:
return clean_special(text, language, special_s, target_symbol, version) return clean_special(text, language, special_s, target_symbol, version)
language_module = __import__("text."+language_module_map[language],fromlist=[language_module_map[language]]) language_module = __import__("GPT_SoVITS.text."+language_module_map[language],fromlist=[language_module_map[language]])
if hasattr(language_module,"text_normalize"): if hasattr(language_module,"text_normalize"):
norm_text = language_module.text_normalize(text) norm_text = language_module.text_normalize(text)
else: else:
@ -67,7 +67,7 @@ def clean_special(text, language, special_s, target_symbol, version=None):
特殊静音段sp符号处理 特殊静音段sp符号处理
""" """
text = text.replace(special_s, ",") text = text.replace(special_s, ",")
language_module = __import__("text."+language_module_map[language],fromlist=[language_module_map[language]]) language_module = __import__("GPT_SoVITS.text."+language_module_map[language],fromlist=[language_module_map[language]])
norm_text = language_module.text_normalize(text) norm_text = language_module.text_normalize(text)
phones = language_module.g2p(norm_text) phones = language_module.g2p(norm_text)
new_ph = [] new_ph = []