diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 5aff4ae..9bada9d 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -7,8 +7,7 @@
 All text is recognized as Japanese
 '''
 import logging
-import traceback
-
+import traceback,torchaudio,warnings
 logging.getLogger("markdown_it").setLevel(logging.ERROR)
 logging.getLogger("urllib3").setLevel(logging.ERROR)
 logging.getLogger("httpcore").setLevel(logging.ERROR)
@@ -17,6 +16,8 @@ logging.getLogger("asyncio").setLevel(logging.ERROR)
 logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
 logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
 logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
+warnings.simplefilter(action='ignore', category=FutureWarning)
+
 import LangSegment, os, re, sys, json
 import pdb
 import torch
@@ -25,20 +26,17 @@ try:
     import gradio.analytics as analytics
     analytics.version_check = lambda:None
 except:...
+version=model_version=os.environ.get("version","v2")
+pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "runtime/GPT_SoVITS/s2Gv3.pth"]
+pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "runtime/GPT_SoVITS/s1v3.ckpt"]
-version=os.environ.get("version","v2")
-pretrained_sovits_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"]
-pretrained_gpt_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"]
 _ =[[],[]]
-for i in range(2):
-    if os.path.exists(pretrained_gpt_name[i]):
-        _[0].append(pretrained_gpt_name[i])
-    if os.path.exists(pretrained_sovits_name[i]):
-        _[-1].append(pretrained_sovits_name[i])
+for i in range(3):
+    if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i])
+    if os.path.exists(pretrained_sovits_name[i]):_[-1].append(pretrained_sovits_name[i])
 pretrained_gpt_name,pretrained_sovits_name = _
-
-
+
 if os.path.exists(f"./weight.json"):
     pass
@@ -83,7 +81,7 @@
 from feature_extractor import cnhubert
 cnhubert.cnhubert_base_path = cnhubert_base_path
 
-from module.models import SynthesizerTrn
+from GPT_SoVITS.module.models import SynthesizerTrn,SynthesizerTrnV3
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from text import cleaned_text_to_sequence
 from text.cleaner import clean_text
@@ -184,9 +182,17 @@ if is_half == True:
 else:
     ssl_model = ssl_model.to(device)
 
+resample_transform_dict={}
+def resample(audio_tensor, sr0):
+    global resample_transform_dict
+    if sr0 not in resample_transform_dict:
+        resample_transform_dict[sr0] = torchaudio.transforms.Resample(
+            sr0, 24000
+        ).to(device)
+    return resample_transform_dict[sr0](audio_tensor)
 
 def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
-    global vq_model, hps, version, dict_language
+    global vq_model, hps, version, model_version, dict_language
     dict_s2 = torch.load(sovits_path, map_location="cpu")
     hps = dict_s2["config"]
     hps = DictToAttrRecursive(hps)
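(Note: the `resample` helper added in the hunk above builds one `torchaudio.transforms.Resample` per source rate and caches it, since constructing the transform precomputes an interpolation kernel. A minimal CPU sketch of the same pattern; the fixed 24 kHz target comes from the diff, while the function name here is illustrative:)

    import torch
    import torchaudio

    _resamplers = {}  # cache: source sample rate -> Resample module

    def resample_to_24k(audio: torch.Tensor, sr0: int) -> torch.Tensor:
        # Building Resample computes a sinc interpolation kernel once,
        # so reusing the instance pays off when many clips share a rate.
        if sr0 not in _resamplers:
            _resamplers[sr0] = torchaudio.transforms.Resample(sr0, 24000)
        return _resamplers[sr0](audio)

    wav = torch.randn(1, 32000)               # one second of dummy 32 kHz audio
    print(resample_to_24k(wav, 32000).shape)  # torch.Size([1, 24000])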
@@ -196,21 +202,41 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
     else:
         hps.model.version = "v2"
     version = hps.model.version
+    if os.path.getsize(sovits_path)>700*1024*1024:
+        model_version="v3"
+    else:
+        model_version=version
+    '''
+    v1:about 82942KB
+        half thr:82978KB
+    v2:about 83014KB
+    v3:about 750MB
+    '''
     # print("sovits version:",hps.model.version)
-    vq_model = SynthesizerTrn(
-        hps.data.filter_length // 2 + 1,
-        hps.train.segment_size // hps.data.hop_length,
-        n_speakers=hps.data.n_speakers,
-        **hps.model
-    )
+    if model_version!="v3":
+        vq_model = SynthesizerTrn(
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **hps.model
+        )
+    else:
+        vq_model = SynthesizerTrnV3(
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **hps.model
+        )
     if ("pretrained" not in sovits_path):
-        del vq_model.enc_q
+        try:
+            del vq_model.enc_q
+        except:pass
     if is_half == True:
         vq_model = vq_model.half().to(device)
     else:
         vq_model = vq_model.to(device)
     vq_model.eval()
-    print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
+    print("loading sovits_%s"%model_version,vq_model.load_state_dict(dict_s2["weight"], strict=False))
     dict_language = dict_language_v1 if version =='v1' else dict_language_v2
     with open("./weight.json")as f:
         data=f.read()
@@ -228,13 +254,17 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
         else:
             text_update = {'__type__':'update', 'value':''}
             text_language_update = {'__type__':'update', 'value':i18n("中文")}
-    return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update
-
+    if model_version=="v3":
+        visible_sample_steps=True
+        visible_inp_refs=False
+    else:
+        visible_sample_steps=False
+        visible_inp_refs=True
+    return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs}
 
 
 change_sovits_weights(sovits_path)
 
-
 def change_gpt_weights(gpt_path):
     global hz, max_sec, t2s_model, config
     hz = 50
@@ -247,8 +277,8 @@ def change_gpt_weights(gpt_path):
         t2s_model = t2s_model.half()
     t2s_model = t2s_model.to(device)
     t2s_model.eval()
-    total = sum([param.nelement() for param in t2s_model.parameters()])
-    print("Number of parameter: %.2fM" % (total / 1e6))
+    # total = sum([param.nelement() for param in t2s_model.parameters()])
+    # print("Number of parameter: %.2fM" % (total / 1e6))
     with open("./weight.json")as f:
         data=f.read()
         data=json.loads(data)
@@ -257,6 +287,25 @@ def change_gpt_weights(gpt_path):
 
 change_gpt_weights(gpt_path)
+os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
+import torch,soundfile
+now_dir = os.getcwd()
+
+def init_bigvgan():
+    global model
+    from BigVGAN import bigvgan
+    model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False)  # if True, RuntimeError: Ninja is required to load C++ extensions
+    # remove weight norm in the model and set to eval mode
+    model.remove_weight_norm()
+    model = model.eval()
+    if is_half == True:
+        model = model.half().to(device)
+    else:
+        model = model.to(device)
+
+if model_version!="v3":model=None
+else:init_bigvgan()
 
 
 def get_spepc(hps, filename):
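(Note: per the inline size comment above, v1/v2 generator checkpoints weigh in around 81 MB while v3 is roughly 750 MB, so the loader can distinguish them with a single `os.path.getsize` threshold before parsing any weights. A restatement of that heuristic; the 700 MB cut-off is the diff's own, the function name is illustrative:)

    import os

    def sovits_generation(sovits_path: str, header_version: str) -> str:
        # v1 ~= 82942 KB, v2 ~= 83014 KB, v3 ~= 750 MB, so anything over
        # 700 MB must be a v3 checkpoint; otherwise trust the config header.
        if os.path.getsize(sovits_path) > 700 * 1024 * 1024:
            return "v3"
        return header_version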
@@ -376,6 +425,30 @@ def get_phones_and_bert(text,language,version,final=False):
 
     return phones,bert.to(dtype),norm_text
 
+from module.mel_processing import spectrogram_torch,spec_to_mel_torch
+def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+    spec=spectrogram_torch(y,n_fft,sampling_rate,hop_size,win_size,center)
+    mel=spec_to_mel_torch(spec,n_fft,num_mels,sampling_rate,fmin,fmax)
+    return mel
+mel_fn_args = {
+    "n_fft": 1024,
+    "win_size": 1024,
+    "hop_size": 256,
+    "num_mels": 100,
+    "sampling_rate": 24000,
+    "fmin": 0,
+    "fmax": None,
+    "center": False
+}
+
+spec_min = -12
+spec_max = 2
+def norm_spec(x):
+    return (x - spec_min) / (spec_max - spec_min) * 2 - 1
+def denorm_spec(x):
+    return (x + 1) / 2 * (spec_max - spec_min) + spec_min
+mel_fn=lambda x: mel_spectrogram(x, **mel_fn_args)
+
 
 def merge_short_text_in_array(texts, threshold):
     if (len(texts)) < 2:
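(Note: `norm_spec`/`denorm_spec` above are the affine maps between the empirical log-mel range [-12, 2] and the [-1, 1] range the v3 CFM decoder works in. A quick round-trip check, constants copied from the diff:)

    import torch

    spec_min, spec_max = -12, 2
    norm_spec = lambda x: (x - spec_min) / (spec_max - spec_min) * 2 - 1
    denorm_spec = lambda x: (x + 1) / 2 * (spec_max - spec_min) + spec_min

    mel = torch.empty(1, 100, 50).uniform_(spec_min, spec_max)
    assert torch.allclose(denorm_spec(norm_spec(mel)), mel, atol=1e-6)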
@@ -397,8 +470,7 @@
 ##ref_wav_path+prompt_text+prompt_language+text(single segment)+text_language+top_k+top_p+temperature
 # cache_tokens={}# no cleanup mechanism implemented yet
 cache= {}
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free
-                =False,speed=1,if_freeze=False,inp_refs=None):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False,speed=1,if_freeze=False,inp_refs=None,sample_steps=8):
     global cache
     if ref_wav_path:pass
     else:gr.Warning(i18n('请上传参考音频'))
@@ -468,6 +540,7 @@
     texts = process_text(texts)
     texts = merge_short_text_in_array(texts, 5)
     audio_opt = []
+    ###s2v3 does not support ref_free yet
     if not ref_free:
         phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)
 
@@ -509,18 +582,60 @@
             pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
             cache[i_text]=pred_semantic
         t3 = ttime()
-        refers=[]
-        if(inp_refs):
-            for path in inp_refs:
-                try:
-                    refer = get_spepc(hps, path.name).to(dtype).to(device)
-                    refers.append(refer)
-                except:
-                    traceback.print_exc()
-        if(len(refers)==0):refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
-        audio = (vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers,speed=speed).detach().cpu().numpy()[0, 0])
-        max_audio=np.abs(audio).max()  # simple guard against 16-bit clipping
-        if max_audio>1:audio/=max_audio
+        ###the logic below and inp_refs do not exist for v3
+        if model_version!="v3":
+            refers=[]
+            if(inp_refs):
+                for path in inp_refs:
+                    try:
+                        refer = get_spepc(hps, path.name).to(dtype).to(device)
+                        refers.append(refer)
+                    except:
+                        traceback.print_exc()
+            if(len(refers)==0):refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
+            audio = (vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers,speed=speed).detach().cpu().numpy()[0, 0])
+        else:
+            refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)  ####### this has to stay resampled/cut to 32k: src audio is 24k and there is no separate 32k src, so it cannot be split into two paths
+            phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0)
+            phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0)
+            fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
+            ref_audio, sr = torchaudio.load(ref_wav_path)
+            ref_audio=ref_audio.to(device).float()
+            if (ref_audio.shape[0] == 2):
+                ref_audio = ref_audio.mean(0).unsqueeze(0)
+            if sr!=24000:
+                ref_audio=resample(ref_audio,sr)
+            mel2 = mel_fn(ref_audio.to(dtype))
+            mel2 = norm_spec(mel2)
+            T_min = min(mel2.shape[2], fea_ref.shape[2])
+            mel2 = mel2[:, :, :T_min]
+            fea_ref = fea_ref[:, :, :T_min]
+            if (T_min > 468):
+                mel2 = mel2[:, :, -468:]
+                fea_ref = fea_ref[:, :, -468:]
+                T_min = 468
+            chunk_len = 934 - T_min
+            fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge)
+            cfm_resss = []
+            idx = 0
+            while (1):
+                fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
+                if (fea_todo_chunk.shape[-1] == 0): break
+                idx += chunk_len
+                fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
+                cfm_res = vq_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
+                cfm_res = cfm_res[:, :, mel2.shape[2]:]
+                mel2 = cfm_res[:, :, -468:]
+                fea_ref = fea_todo_chunk[:, :, -468:]
+                cfm_resss.append(cfm_res)
+            cfm_res = torch.cat(cfm_resss, 2)
+            cfm_res = denorm_spec(cfm_res)
+            if model==None:init_bigvgan()
+            with torch.inference_mode():
+                wav_gen = model(cfm_res)
+                audio=wav_gen[0][0].cpu().detach().numpy()
+        max_audio=np.abs(audio).max()  # simple guard against 16-bit clipping
+        if max_audio>1:audio/=max_audio
         audio_opt.append(audio)
         audio_opt.append(zero_wav)
         t4 = ttime()
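(Note: the v3 branch above synthesizes mel in a sliding window: the reference prompt is capped at 468 frames, each window holds at most 934 frames, so `chunk_len = 934 - T_min` new frames are produced per step, and the last 468 generated frames become the conditioning for the next step. A shape-only sketch of that schedule, with no model calls:)

    def chunk_schedule(total_frames: int, t_min: int, window: int = 934):
        # Mirrors the while-loop above: (start, length) of each generated span.
        chunk_len = window - t_min
        spans, idx = [], 0
        while idx < total_frames:
            spans.append((idx, min(chunk_len, total_frames - idx)))
            idx += chunk_len
        return spans

    print(chunk_schedule(total_frames=1200, t_min=468))
    # [(0, 466), (466, 466), (932, 268)] -- each step conditions on 468
    # prompt frames and generates up to 466 new ones.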
@@ -529,9 +644,8 @@
         print("%.3f\t%.3f\t%.3f\t%.3f" %
               (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3]))
               )
-        yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
-            np.int16
-        )
+        sr=hps.data.sampling_rate if model_version!="v3" else 24000
+        yield sr, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
 
 
 def split(todo_text):
@@ -655,8 +769,8 @@ def change_choices():
     return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
 
 
-SoVITS_weight_root=["SoVITS_weights_v2","SoVITS_weights"]
-GPT_weight_root=["GPT_weights_v2","GPT_weights"]
+SoVITS_weight_root=["SoVITS_weights","SoVITS_weights_v2","SoVITS_weights_v3"]
+GPT_weight_root=["GPT_weights","GPT_weights_v2","GPT_weights_v3"]
 for path in SoVITS_weight_root+GPT_weight_root:
     os.makedirs(path,exist_ok=True)
@@ -708,7 +822,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                 prompt_language = gr.Dropdown(
                     label=i18n("参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文"),
                 )
-                inp_refs = gr.File(label=i18n("可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"),file_count="multiple")
+                inp_refs = gr.File(label=i18n("可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"),file_count="multiple") if model_version!="v3" else gr.File(label=i18n("可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"),file_count="multiple",visible=False)
+                sample_steps = gr.Radio(label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),value=32,choices=[4,8,16,32],visible=True) if model_version=="v3" else gr.Radio(label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),value=8,choices=[4,8,16,32],visible=False)
             gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"),'h3'))
         with gr.Row():
             with gr.Column(scale=13):
@@ -740,10 +855,10 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
 
         inference_button.click(
             get_tts_wav,
-            [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free,speed,if_freeze,inp_refs],
+            [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free,speed,if_freeze,inp_refs,sample_steps],
            [output],
         )
-        SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language])
+        SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language,sample_steps,inp_refs])
         GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
 
 # gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py
index 898ca54..4b510d9 100644
--- a/GPT_SoVITS/s1_train.py
+++ b/GPT_SoVITS/s1_train.py
@@ -118,6 +118,7 @@ def main(args):
     )
     logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
     os.environ["MASTER_ADDR"]="localhost"
+    os.environ["USE_LIBUV"] = "0"
     trainer: Trainer = Trainer(
         max_epochs=config["train"]["epochs"],
         accelerator="gpu" if torch.cuda.is_available() else "cpu",
diff --git a/GPT_SoVITS/s2_train.py b/GPT_SoVITS/s2_train.py
index f5de615..5be43c9 100644
--- a/GPT_SoVITS/s2_train.py
+++ b/GPT_SoVITS/s2_train.py
@@ -75,7 +75,7 @@ def run(rank, n_gpus, hps):
     dist.init_process_group(
         backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
-        init_method="env://",
+        init_method="env://?use_libuv=False",
         world_size=n_gpus,
         rank=rank,
     )
@@ -193,7 +193,7 @@ def run(rank, n_gpus, hps):
    try:  # auto-resume if a checkpoint can be loaded
        _, _, _, epoch_str = utils.load_checkpoint(
-            utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "D_*.pth"),
+            utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_*.pth"),
             net_d,
             optim_d,
         )  # loading D usually goes fine
@@ -201,7 +201,7 @@ def run(rank, n_gpus, hps):
             logger.info("loaded D")
         # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
         _, _, _, epoch_str = utils.load_checkpoint(
-            utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "G_*.pth"),
+            utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_*.pth"),
             net_g,
             optim_g,
         )
@@ -455,7 +455,7 @@ def train_and_evaluate(
                     hps.train.learning_rate,
                     epoch,
                     os.path.join(
-                        "%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(global_step)
+                        "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(global_step)
                     ),
                 )
                 utils.save_checkpoint(
@@ -464,7 +464,7 @@ def train_and_evaluate(
                     hps.train.learning_rate,
                     epoch,
                     os.path.join(
-                        "%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(global_step)
+                        "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(global_step)
                     ),
                 )
             else:
@@ -474,7 +474,7 @@ def train_and_evaluate(
                     hps.train.learning_rate,
                     epoch,
                     os.path.join(
-                        "%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(233333333333)
+                        "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(233333333333)
                     ),
                 )
                 utils.save_checkpoint(
@@ -483,7 +483,7 @@ def train_and_evaluate(
                     hps.train.learning_rate,
                     epoch,
                     os.path.join(
-                        "%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(233333333333)
+                        "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(233333333333)
                     ),
                 )
         if rank == 0 and hps.train.if_save_every_weights == True:
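(Note: both training scripts above work around the libuv-backed TCPStore that newer PyTorch builds default to on Windows and that can fail to initialize there: s1_train.py sets `USE_LIBUV=0` in the environment, s2_train.py appends `?use_libuv=False` to the rendezvous URL. A minimal sketch of the same guard, assuming a PyTorch 2.4+ build where libuv is the default; the address and port are placeholders:)

    import os
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")

    # Windows only ships the Gloo backend, and disabling libuv in the
    # rendezvous URL avoids TCPStore init failures on affected builds.
    dist.init_process_group(
        backend="gloo" if os.name == "nt" else "nccl",
        init_method="env://?use_libuv=False",
        world_size=1,
        rank=0,
    )
    dist.destroy_process_group()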
diff --git a/GPT_SoVITS/s2_train_v3.py b/GPT_SoVITS/s2_train_v3.py
new file mode 100644
index 0000000..d2e72c8
--- /dev/null
+++ b/GPT_SoVITS/s2_train_v3.py
@@ -0,0 +1,413 @@
+import warnings
+warnings.filterwarnings("ignore")
+import utils, os
+hps = utils.get_hparams(stage=2)
+os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
+import torch
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+import torch.multiprocessing as mp
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.cuda.amp import autocast, GradScaler
+from tqdm import tqdm
+import logging, traceback
+
+logging.getLogger("matplotlib").setLevel(logging.INFO)
+logging.getLogger("h5py").setLevel(logging.INFO)
+logging.getLogger("numba").setLevel(logging.INFO)
+from random import randint
+from module import commons
+
+from module.data_utils import (
+    TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
+    TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
+    DistributedBucketSampler,
+)
+from module.models import (
+    SynthesizerTrnV3 as SynthesizerTrn,
+    MultiPeriodDiscriminator,
+)
+from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
+from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
+from process_ckpt import savee
+
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = False
+### fp32 is faster on an A100 anyway, so give tf32 a try
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally), makes no difference to the results
+# from config import pretrained_s2G,pretrained_s2D
+global_step = 0
+
+device = "cpu"  # for devices other than cuda; mps will be added once it is optimized
+
+
+def main():
+
+    if torch.cuda.is_available():
+        n_gpus = torch.cuda.device_count()
+    else:
+        n_gpus = 1
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = str(randint(20000, 55555))
+
+    mp.spawn(
+        run,
+        nprocs=n_gpus,
+        args=(
+            n_gpus,
+            hps,
+        ),
+    )
+
+
+def run(rank, n_gpus, hps):
+    global global_step
+    if rank == 0:
+        logger = utils.get_logger(hps.data.exp_dir)
+        logger.info(hps)
+        # utils.check_git_hash(hps.s2_ckpt_dir)
+        writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
+        writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
+
+    dist.init_process_group(
+        backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+        init_method="env://?use_libuv=False",
+        world_size=n_gpus,
+        rank=rank,
+    )
+    torch.manual_seed(hps.train.seed)
+    if torch.cuda.is_available():
+        torch.cuda.set_device(rank)
+
+    train_dataset = TextAudioSpeakerLoader(hps.data)  ########
+    train_sampler = DistributedBucketSampler(
+        train_dataset,
+        hps.train.batch_size,
+        [
+            32,
+            300,
+            400,
+            500,
+            600,
+            700,
+            800,
+            900,
+            1000,
+            # 1100,
+            # 1200,
+            # 1300,
+            # 1400,
+            # 1500,
+            # 1600,
+            # 1700,
+            # 1800,
+            # 1900,
+        ],
+        num_replicas=n_gpus,
+        rank=rank,
+        shuffle=True,
+    )
+    collate_fn = TextAudioSpeakerCollate()
+    train_loader = DataLoader(
+        train_dataset,
+        num_workers=6,
+        shuffle=False,
+        pin_memory=True,
+        collate_fn=collate_fn,
+        batch_sampler=train_sampler,
+        persistent_workers=True,
+        prefetch_factor=4,
+    )
+    # if rank == 0:
+    #     eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
+    #     eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
+    #                              batch_size=1, pin_memory=True,
+    #                              drop_last=False, collate_fn=collate_fn)
+
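(Note: `DistributedBucketSampler` above batches utterances whose spectrogram lengths fall between consecutive boundaries (32, 300, 400, ... 1000 frames) so that each batch mixes similarly sized samples and padding waste stays low; samples outside every bucket are dropped. A toy restatement of the bucket-assignment rule only, independent of the project's sampler class:)

    import bisect

    boundaries = [32, 300, 400, 500, 600, 700, 800, 900, 1000]

    def bucket_of(length: int) -> int:
        # Bucket i covers lengths in (boundaries[i], boundaries[i+1]];
        # -1 means the sample falls outside every bucket and is skipped.
        i = bisect.bisect_left(boundaries, length)
        return i - 1 if 0 < i < len(boundaries) else -1

    print(bucket_of(350), bucket_of(20), bucket_of(2000))  # 1 -1 -1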
+    net_g = SynthesizerTrn(
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        **hps.model,
+    ).cuda(rank) if torch.cuda.is_available() else SynthesizerTrn(
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        **hps.model,
+    ).to(device)
+
+    # net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
+    # for name, param in net_g.named_parameters():
+    #     if not param.requires_grad:
+    #         print(name, "not requires_grad")
+
+    optim_g = torch.optim.AdamW(
+        filter(lambda p: p.requires_grad, net_g.parameters()),  ### the same lr for every layer by default
+        hps.train.learning_rate,
+        betas=hps.train.betas,
+        eps=hps.train.eps,
+    )
+    # optim_d = torch.optim.AdamW(
+    #     net_d.parameters(),
+    #     hps.train.learning_rate,
+    #     betas=hps.train.betas,
+    #     eps=hps.train.eps,
+    # )
+    if torch.cuda.is_available():
+        net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
+        # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
+    else:
+        net_g = net_g.to(device)
+        # net_d = net_d.to(device)
+
+    try:  # auto-resume if a checkpoint can be loaded
+        # _, _, _, epoch_str = utils.load_checkpoint(
+        #     utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_*.pth"),
+        #     net_d,
+        #     optim_d,
+        # )  # loading D usually goes fine
+        # if rank == 0:
+        #     logger.info("loaded D")
+        # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
+        _, _, _, epoch_str = utils.load_checkpoint(
+            utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_*.pth"),
+            net_g,
+            optim_g,
+        )
+        global_step = (epoch_str - 1) * len(train_loader)
+        # epoch_str = 1
+        # global_step = 0
+    except:  # if nothing can be loaded on the first run, load the pretrained model
+        # traceback.print_exc()
+        epoch_str = 1
+        global_step = 0
+        if hps.train.pretrained_s2G != "" and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
+            if rank == 0:
+                logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
+            print(
+                net_g.module.load_state_dict(
+                    torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
+                    strict=False,
+                ) if torch.cuda.is_available() else net_g.load_state_dict(
+                    torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
+                    strict=False,
+                )
+            )  ## testing: do not load the optimizer state
+        # if hps.train.pretrained_s2D != "" and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D):
+        #     if rank == 0:
+        #         logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
+        #     print(
+        #         net_d.module.load_state_dict(
+        #             torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
+        #         ) if torch.cuda.is_available() else net_d.load_state_dict(
+        #             torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
+        #         )
+        #     )
+
+    # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
+    # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
+
+    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
+        optim_g, gamma=hps.train.lr_decay, last_epoch=-1
+    )
+    # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
+    #     optim_d, gamma=hps.train.lr_decay, last_epoch=-1
+    # )
+    for _ in range(epoch_str):
+        scheduler_g.step()
+        # scheduler_d.step()
+
+    scaler = GradScaler(enabled=hps.train.fp16_run)
+
+    net_d=optim_d=scheduler_d=None
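(Note: the try/except above implements auto-resume: it first attempts the newest `G_*.pth` in the version-suffixed log dir, and only on failure falls back to `pretrained_s2G` with `strict=False`, since the v3 model adds submodules the pretrain does not have. A condensed sketch of that fallback order; `utils.latest_checkpoint_path`/`utils.load_checkpoint` are the project's own helpers, and the wrapper function is illustrative:)

    import os
    import torch
    import utils  # GPT_SoVITS project-local helpers

    def resume_or_pretrain(net_g, optim_g, hps):
        log_dir = "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version)
        try:
            # Auto-resume: the newest generator checkpoint wins.
            ckpt = utils.latest_checkpoint_path(log_dir, "G_*.pth")
            _, _, _, epoch_str = utils.load_checkpoint(ckpt, net_g, optim_g)
            return epoch_str
        except Exception:
            # First run: seed from the pretrained s2G, tolerating
            # missing/unexpected keys because v3 adds new modules.
            if hps.train.pretrained_s2G and os.path.exists(hps.train.pretrained_s2G):
                weight = torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"]
                target = net_g.module if hasattr(net_g, "module") else net_g
                target.load_state_dict(weight, strict=False)
            return 1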
+    for epoch in range(epoch_str, hps.train.epochs + 1):
+        if rank == 0:
+            train_and_evaluate(
+                rank,
+                epoch,
+                hps,
+                [net_g, net_d],
+                [optim_g, optim_d],
+                [scheduler_g, scheduler_d],
+                scaler,
+                # [train_loader, eval_loader], logger, [writer, writer_eval])
+                [train_loader, None],
+                logger,
+                [writer, writer_eval],
+            )
+        else:
+            train_and_evaluate(
+                rank,
+                epoch,
+                hps,
+                [net_g, net_d],
+                [optim_g, optim_d],
+                [scheduler_g, scheduler_d],
+                scaler,
+                [train_loader, None],
+                None,
+                None,
+            )
+        scheduler_g.step()
+        # scheduler_d.step()
+
+
+def train_and_evaluate(
+    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
+):
+    net_g, net_d = nets
+    optim_g, optim_d = optims
+    # scheduler_g, scheduler_d = schedulers
+    train_loader, eval_loader = loaders
+    if writers is not None:
+        writer, writer_eval = writers
+
+    train_loader.batch_sampler.set_epoch(epoch)
+    global global_step
+
+    net_g.train()
+    # net_d.train()
+    # for batch_idx, (
+    #     ssl,
+    #     ssl_lengths,
+    #     spec,
+    #     spec_lengths,
+    #     y,
+    #     y_lengths,
+    #     text,
+    #     text_lengths,
+    # ) in enumerate(tqdm(train_loader)):
+    for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in tqdm(enumerate(train_loader)):
+        if torch.cuda.is_available():
+            spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
+                rank, non_blocking=True
+            )
+            mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(
+                rank, non_blocking=True
+            )
+            ssl = ssl.cuda(rank, non_blocking=True)
+            ssl.requires_grad = False
+            # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
+            text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
+                rank, non_blocking=True
+            )
+        else:
+            spec, spec_lengths = spec.to(device), spec_lengths.to(device)
+            mel, mel_lengths = mel.to(device), mel_lengths.to(device)
+            ssl = ssl.to(device)
+            ssl.requires_grad = False
+            # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
+            text, text_lengths = text.to(device), text_lengths.to(device)
+
+        with autocast(enabled=hps.train.fp16_run):
+            cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths)
+            loss_gen_all=cfm_loss
+        optim_g.zero_grad()
+        scaler.scale(loss_gen_all).backward()
+        scaler.unscale_(optim_g)
+        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
+        scaler.step(optim_g)
+        scaler.update()
+
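(Note: unlike s2_train.py, this loop trains only the generator: the v3 forward pass returns a flow-matching (CFM) loss directly, and the discriminator slots stay `None` purely so `train_and_evaluate` keeps its v1/v2 signature. The AMP update above, reduced to a self-contained skeleton; the linear model and the clip value are stand-ins, and the project's `commons.clip_grad_value_(params, None)` only records the gradient norm without clipping:)

    import torch
    from torch.cuda.amp import autocast, GradScaler

    model = torch.nn.Linear(8, 1)   # stand-in for SynthesizerTrnV3
    optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
    amp = torch.cuda.is_available()
    scaler = GradScaler(enabled=amp)

    x = torch.randn(4, 8)
    with autocast(enabled=amp):
        cfm_loss = model(x).pow(2).mean()   # the real forward returns cfm_loss itself

    optim.zero_grad()
    scaler.scale(cfm_loss).backward()
    scaler.unscale_(optim)                  # so clipping sees true gradient magnitudes
    torch.nn.utils.clip_grad_value_(model.parameters(), 100)
    scaler.step(optim)
    scaler.update()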
+        if rank == 0:
+            if global_step % hps.train.log_interval == 0:
+                lr = optim_g.param_groups[0]['lr']
+                # losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
+                losses = [cfm_loss]
+                logger.info('Train Epoch: {} [{:.0f}%]'.format(
+                    epoch,
+                    100. * batch_idx / len(train_loader)))
+                logger.info([x.item() for x in losses] + [global_step, lr])
+
+                scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
+                # image_dict = {
+                #     "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
+                #     "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
+                #     "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
+                #     "all/stats_ssl": utils.plot_spectrogram_to_numpy(stats_ssl[0].data.cpu().numpy()),
+                # }
+                utils.summarize(
+                    writer=writer,
+                    global_step=global_step,
+                    # images=image_dict,
+                    scalars=scalar_dict)
+
+        # if global_step % hps.train.eval_interval == 0:
+        #     # evaluate(hps, net_g, eval_loader, writer_eval)
+        #     utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch,os.path.join(hps.s2_ckpt_dir, "G_{}.pth".format(global_step)),scaler)
+        #     # utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch,os.path.join(hps.s2_ckpt_dir, "D_{}.pth".format(global_step)),scaler)
+        #     # keep_ckpts = getattr(hps.train, 'keep_ckpts', 3)
+        #     # if keep_ckpts > 0:
+        #     #     utils.clean_checkpoints(path_to_models=hps.s2_ckpt_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
+
+
+        global_step += 1
+    if epoch % hps.train.save_every_epoch == 0 and rank == 0:
+        if hps.train.if_save_latest == 0:
+            utils.save_checkpoint(
+                net_g,
+                optim_g,
+                hps.train.learning_rate,
+                epoch,
+                os.path.join(
+                    "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(global_step)
+                ),
+            )
+            # utils.save_checkpoint(
+            #     net_d,
+            #     optim_d,
+            #     hps.train.learning_rate,
+            #     epoch,
+            #     os.path.join(
+            #         "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(global_step)
+            #     ),
+            # )
+        else:
+            utils.save_checkpoint(
+                net_g,
+                optim_g,
+                hps.train.learning_rate,
+                epoch,
+                os.path.join(
+                    "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(233333333333)
+                ),
+            )
+            # utils.save_checkpoint(
+            #     net_d,
+            #     optim_d,
+            #     hps.train.learning_rate,
+            #     epoch,
+            #     os.path.join(
+            #         "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(233333333333)
+            #     ),
+            # )
+    if rank == 0 and hps.train.if_save_every_weights == True:
+        if hasattr(net_g, "module"):
+            ckpt = net_g.module.state_dict()
+        else:
+            ckpt = net_g.state_dict()
+        logger.info(
+            "saving ckpt %s_e%s:%s"
+            % (
+                hps.name,
+                epoch,
+                savee(
+                    ckpt,
+                    hps.name + "_e%s_s%s" % (epoch, global_step),
+                    epoch,
+                    global_step,
+                    hps,
+                ),
+            )
+        )
+
+    if rank == 0:
+        logger.info("====> Epoch: {}".format(epoch))
+
+
+if __name__ == "__main__":
+    main()
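(Note: the `logs_s2` -> `logs_s2_%s` rename that runs through s2_train.py and s2_train_v3.py keys every checkpoint directory by model version, so v1/v2 and v3 runs of the same experiment no longer overwrite each other's `G_*.pth`/`D_*.pth` files. A tiny illustrative helper for the new layout:)

    import os

    def s2_log_dir(exp_dir: str, model_version: str) -> str:
        # e.g. logs/exp1/logs_s2_v3 instead of the formerly shared logs/exp1/logs_s2
        return os.path.join(exp_dir, "logs_s2_%s" % model_version)

    print(s2_log_dir("logs/exp1", "v3"))  # logs/exp1/logs_s2_v3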