diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 4917d32..7d79fbd 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -1,49 +1,55 @@ import os -gpt_path=os.environ.get("gpt_path","pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt") -sovits_path=os.environ.get("sovits_path","pretrained_models/s2G488k.pth") -cnhubert_base_path=os.environ.get("cnhubert_base_path","pretrained_models/chinese-hubert-base") -bert_path=os.environ.get("bert_path","pretrained_models/chinese-roberta-wwm-ext-large") -infer_ttswebui=os.environ.get("infer_ttswebui",9872) -infer_ttswebui=int(infer_ttswebui) -if("_CUDA_VISIBLE_DEVICES"in os.environ): - os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"] -is_half=eval(os.environ.get("is_half","True")) + +gpt_path = os.environ.get( + "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt" +) +sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth") +cnhubert_base_path = os.environ.get( + "cnhubert_base_path", "pretrained_models/chinese-hubert-base" +) +bert_path = os.environ.get( + "bert_path", "pretrained_models/chinese-roberta-wwm-ext-large" +) +infer_ttswebui = os.environ.get("infer_ttswebui", 9872) +infer_ttswebui = int(infer_ttswebui) +if "_CUDA_VISIBLE_DEVICES" in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] +is_half = eval(os.environ.get("is_half", "True")) import gradio as gr from transformers import AutoModelForMaskedLM, AutoTokenizer -import sys,torch,numpy as np -from pathlib import Path -import os,pdb,utils,librosa,math,traceback,requests,argparse,torch,multiprocessing,pandas as pd,torch.multiprocessing as mp,soundfile +import torch, numpy as np +import os, librosa, torch + # torch.backends.cuda.sdp_kernel("flash") # torch.backends.cuda.enable_flash_sdp(True) # torch.backends.cuda.enable_mem_efficient_sdp(True) # Not avaliable if torch version is lower than 2.0 # torch.backends.cuda.enable_math_sdp(True) -from random import shuffle -from AR.utils import get_newest_ckpt -from glob import glob -from tqdm import tqdm from feature_extractor import cnhubert -cnhubert.cnhubert_base_path=cnhubert_base_path -from io import BytesIO + +cnhubert.cnhubert_base_path = cnhubert_base_path from module.models import SynthesizerTrn from AR.models.t2s_lightning_module import Text2SemanticLightningModule -from AR.utils.io import load_yaml_config from text import cleaned_text_to_sequence -from text.cleaner import text_to_sequence, clean_text +from text.cleaner import clean_text from time import time as ttime from module.mel_processing import spectrogram_torch from my_utils import load_audio -device="cuda" +device = "cuda" tokenizer = AutoTokenizer.from_pretrained(bert_path) -bert_model=AutoModelForMaskedLM.from_pretrained(bert_path) -if(is_half==True):bert_model=bert_model.half().to(device) -else:bert_model=bert_model.to(device) +bert_model = AutoModelForMaskedLM.from_pretrained(bert_path) +if is_half == True: + bert_model = bert_model.half().to(device) +else: + bert_model = bert_model.to(device) + + # bert_model=bert_model.to(device) def get_bert_feature(text, word2ph): with torch.no_grad(): inputs = tokenizer(text, return_tensors="pt") for i in inputs: - inputs[i] = inputs[i].to(device)#####输入是long不用管精度问题,精度随bert_model + inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model res = bert_model(**inputs, output_hidden_states=True) res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] assert len(word2ph) == len(text) @@ -55,9 +61,12 @@ def get_bert_feature(text, word2ph): # if(is_half==True):phone_level_feature=phone_level_feature.half() return phone_level_feature.T + n_semantic = 1024 -dict_s2=torch.load(sovits_path,map_location="cpu") -hps=dict_s2["config"] +dict_s2 = torch.load(sovits_path, map_location="cpu") +hps = dict_s2["config"] + + class DictToAttrRecursive: def __init__(self, input_dict): for key, value in input_dict.items(): @@ -67,206 +76,271 @@ class DictToAttrRecursive: else: setattr(self, key, value) + hps = DictToAttrRecursive(hps) -hps.model.semantic_frame_rate="25hz" -dict_s1=torch.load(gpt_path,map_location="cpu") -config=dict_s1["config"] -ssl_model=cnhubert.get_model() -if(is_half==True):ssl_model=ssl_model.half().to(device) -else:ssl_model=ssl_model.to(device) +hps.model.semantic_frame_rate = "25hz" +dict_s1 = torch.load(gpt_path, map_location="cpu") +config = dict_s1["config"] +ssl_model = cnhubert.get_model() +if is_half == True: + ssl_model = ssl_model.half().to(device) +else: + ssl_model = ssl_model.to(device) vq_model = SynthesizerTrn( hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, - **hps.model) -if(is_half==True):vq_model=vq_model.half().to(device) -else:vq_model=vq_model.to(device) + **hps.model +) +if is_half == True: + vq_model = vq_model.half().to(device) +else: + vq_model = vq_model.to(device) vq_model.eval() -print(vq_model.load_state_dict(dict_s2["weight"],strict=False)) +print(vq_model.load_state_dict(dict_s2["weight"], strict=False)) hz = 50 -max_sec = config['data']['max_sec'] +max_sec = config["data"]["max_sec"] # t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo -t2s_model = Text2SemanticLightningModule(config,"ojbk",is_train=False) +t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False) t2s_model.load_state_dict(dict_s1["weight"]) -if(is_half==True):t2s_model=t2s_model.half() -t2s_model=t2s_model.to(device) +if is_half == True: + t2s_model = t2s_model.half() +t2s_model = t2s_model.to(device) t2s_model.eval() total = sum([param.nelement() for param in t2s_model.parameters()]) print("Number of parameter: %.2fM" % (total / 1e6)) + + def get_spepc(hps, filename): - audio=load_audio(filename,int(hps.data.sampling_rate)) - audio=torch.FloatTensor(audio) + audio = load_audio(filename, int(hps.data.sampling_rate)) + audio = torch.FloatTensor(audio) audio_norm = audio audio_norm = audio_norm.unsqueeze(0) - spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False) + spec = spectrogram_torch( + audio_norm, + hps.data.filter_length, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + center=False, + ) return spec -dict_language={ - "中文":"zh", - "英文":"en", - "日文":"ja" -} -def get_tts_wav(ref_wav_path,prompt_text,prompt_language,text,text_language): + +dict_language = {"中文": "zh", "英文": "en", "日文": "ja"} + + +def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language): t0 = ttime() - prompt_text=prompt_text.strip("\n") - prompt_language,text=prompt_language,text.strip("\n") + prompt_text = prompt_text.strip("\n") + prompt_language, text = prompt_language, text.strip("\n") with torch.no_grad(): wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙 wav16k = torch.from_numpy(wav16k) - if(is_half==True):wav16k=wav16k.half().to(device) - else:wav16k=wav16k.to(device) - ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float() + if is_half == True: + wav16k = wav16k.half().to(device) + else: + wav16k = wav16k.to(device) + ssl_content = ssl_model.model(wav16k.unsqueeze(0))[ + "last_hidden_state" + ].transpose( + 1, 2 + ) # .float() codes = vq_model.extract_latent(ssl_content) prompt_semantic = codes[0, 0] t1 = ttime() - prompt_language=dict_language[prompt_language] - text_language=dict_language[text_language] + prompt_language = dict_language[prompt_language] + text_language = dict_language[text_language] phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language) - phones1=cleaned_text_to_sequence(phones1) - texts=text.split("\n") + phones1 = cleaned_text_to_sequence(phones1) + texts = text.split("\n") audio_opt = [] - zero_wav=np.zeros(int(hps.data.sampling_rate*0.3),dtype=np.float16 if is_half==True else np.float32) + zero_wav = np.zeros( + int(hps.data.sampling_rate * 0.3), + dtype=np.float16 if is_half == True else np.float32, + ) for text in texts: phones2, word2ph2, norm_text2 = clean_text(text, text_language) phones2 = cleaned_text_to_sequence(phones2) - if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device) - else:bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device) - if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device) - else:bert2 = torch.zeros((1024, len(phones2))).to(bert1) + if prompt_language == "zh": + bert1 = get_bert_feature(norm_text1, word2ph1).to(device) + else: + bert1 = torch.zeros( + (1024, len(phones1)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) + if text_language == "zh": + bert2 = get_bert_feature(norm_text2, word2ph2).to(device) + else: + bert2 = torch.zeros((1024, len(phones2))).to(bert1) bert = torch.cat([bert1, bert2], 1) - all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0) + all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) bert = bert.to(device).unsqueeze(0) all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) prompt = prompt_semantic.unsqueeze(0).to(device) t2 = ttime() with torch.no_grad(): # pred_semantic = t2s_model.model.infer( - pred_semantic,idx = t2s_model.model.infer_panel( + pred_semantic, idx = t2s_model.model.infer_panel( all_phoneme_ids, all_phoneme_len, prompt, bert, # prompt_phone_len=ph_offset, - top_k=config['inference']['top_k'], - early_stop_num=hz * max_sec) + top_k=config["inference"]["top_k"], + early_stop_num=hz * max_sec, + ) t3 = ttime() # print(pred_semantic.shape,idx) - pred_semantic = pred_semantic[:,-idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次 - refer = get_spepc(hps, ref_wav_path)#.to(device) - if(is_half==True):refer=refer.half().to(device) - else:refer=refer.to(device) + pred_semantic = pred_semantic[:, -idx:].unsqueeze( + 0 + ) # .unsqueeze(0)#mq要多unsqueeze一次 + refer = get_spepc(hps, ref_wav_path) # .to(device) + if is_half == True: + refer = refer.half().to(device) + else: + refer = refer.to(device) # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0] - audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]###试试重建不带上prompt部分 + audio = ( + vq_model.decode( + pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer + ) + .detach() + .cpu() + .numpy()[0, 0] + ) ###试试重建不带上prompt部分 audio_opt.append(audio) audio_opt.append(zero_wav) t4 = ttime() print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) - yield hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16) + yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype( + np.int16 + ) + + +splits = { + ",", + "。", + "?", + "!", + ",", + ".", + "?", + "!", + "~", + ":", + ":", + "—", + "…", +} # 不考虑省略号 -splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}#不考虑省略号 def split(todo_text): todo_text = todo_text.replace("……", "。").replace("——", ",") - if (todo_text[-1] not in splits): todo_text += "。" + if todo_text[-1] not in splits: + todo_text += "。" i_split_head = i_split_tail = 0 len_text = len(todo_text) todo_texts = [] - while (1): - if (i_split_head >= len_text): break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入 - if (todo_text[i_split_head] in splits): + while 1: + if i_split_head >= len_text: + break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入 + if todo_text[i_split_head] in splits: i_split_head += 1 todo_texts.append(todo_text[i_split_tail:i_split_head]) i_split_tail = i_split_head else: i_split_head += 1 return todo_texts + + def cut1(inp): - inp=inp.strip("\n") - inps=split(inp) - split_idx=list(range(0,len(inps),5)) - split_idx[-1]=None - if(len(split_idx)>1): - opts=[] - for idx in range(len(split_idx)-1): - opts.append("".join(inps[split_idx[idx]:split_idx[idx+1]])) + inp = inp.strip("\n") + inps = split(inp) + split_idx = list(range(0, len(inps), 5)) + split_idx[-1] = None + if len(split_idx) > 1: + opts = [] + for idx in range(len(split_idx) - 1): + opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]])) else: - opts=[inp] + opts = [inp] return "\n".join(opts) + def cut2(inp): - inp=inp.strip("\n") - inps=split(inp) - if(len(inps)<2):return [inp] - opts=[] - summ=0 - tmp_str="" + inp = inp.strip("\n") + inps = split(inp) + if len(inps) < 2: + return [inp] + opts = [] + summ = 0 + tmp_str = "" for i in range(len(inps)): - summ+=len(inps[i]) - tmp_str+=inps[i] - if(summ>50): - summ=0 + summ += len(inps[i]) + tmp_str += inps[i] + if summ > 50: + summ = 0 opts.append(tmp_str) - tmp_str="" - if(tmp_str!=""):opts.append(tmp_str) - if(len(opts[-1])<50):##如果最后一个太短了,和前一个合一起 - opts[-2]=opts[-2]+opts[-1] - opts=opts[:-1] + tmp_str = "" + if tmp_str != "": + opts.append(tmp_str) + if len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起 + opts[-2] = opts[-2] + opts[-1] + opts = opts[:-1] return "\n".join(opts) + def cut3(inp): - inp=inp.strip("\n") - return "\n".join(["%s。"%item for item in inp.strip("。").split("。")]) + inp = inp.strip("\n") + return "\n".join(["%s。" % item for item in inp.strip("。").split("。")]) + with gr.Blocks(title="GPT-SoVITS WebUI") as app: gr.Markdown( - value= - "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE." + value="本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE." ) # with gr.Tabs(): # with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")): with gr.Group(): - gr.Markdown( - value= - "*请上传并填写参考信息" - ) + gr.Markdown(value="*请上传并填写参考信息") with gr.Row(): inp_ref = gr.Audio(label="请上传参考音频", type="filepath") - prompt_text= gr.Textbox(label="参考音频的文本",value="") - prompt_language= gr.Dropdown(label="参考音频的语种",choices=["中文","英文","日文"],value="中文") - gr.Markdown( - value= - "*请填写需要合成的目标文本" - ) + prompt_text = gr.Textbox(label="参考音频的文本", value="") + prompt_language = gr.Dropdown( + label="参考音频的语种", choices=["中文", "英文", "日文"], value="中文" + ) + gr.Markdown(value="*请填写需要合成的目标文本") with gr.Row(): - text=gr.Textbox(label="需要合成的文本",value="") - text_language = gr.Dropdown(label="需要合成的语种", choices=["中文", "英文", "日文"],value="中文") - inference_button=gr.Button("合成语音", variant="primary") + text = gr.Textbox(label="需要合成的文本", value="") + text_language = gr.Dropdown( + label="需要合成的语种", choices=["中文", "英文", "日文"], value="中文" + ) + inference_button = gr.Button("合成语音", variant="primary") output = gr.Audio(label="输出的语音") - inference_button.click(get_tts_wav, [inp_ref, prompt_text,prompt_language, text,text_language], [output]) - - gr.Markdown( - value= - "文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。" + inference_button.click( + get_tts_wav, + [inp_ref, prompt_text, prompt_language, text, text_language], + [output], ) + + gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。") with gr.Row(): - text_inp=gr.Textbox(label="需要合成的切分前文本",value="") + text_inp = gr.Textbox(label="需要合成的切分前文本", value="") button1 = gr.Button("凑五句一切", variant="primary") button2 = gr.Button("凑50字一切", variant="primary") button3 = gr.Button("按中文句号。切", variant="primary") text_opt = gr.Textbox(label="切分后文本", value="") - button1.click(cut1,[text_inp],[text_opt]) - button2.click(cut2,[text_inp],[text_opt]) - button3.click(cut3,[text_inp],[text_opt]) - gr.Markdown( - value= - "后续将支持混合语种编码文本输入。" - ) + button1.click(cut1, [text_inp], [text_opt]) + button2.click(cut2, [text_inp], [text_opt]) + button3.click(cut3, [text_inp], [text_opt]) + gr.Markdown(value="后续将支持混合语种编码文本输入。") app.queue(concurrency_count=511, max_size=1022).launch( server_name="0.0.0.0", inbrowser=True, server_port=infer_ttswebui, quiet=True, -) \ No newline at end of file +) diff --git a/GPT_SoVITS/process_ckpt.py b/GPT_SoVITS/process_ckpt.py index 170dbb3..7483337 100644 --- a/GPT_SoVITS/process_ckpt.py +++ b/GPT_SoVITS/process_ckpt.py @@ -1,11 +1,12 @@ -import os -import sys import traceback from collections import OrderedDict import torch from tools.i18n.i18n import I18nAuto + i18n = I18nAuto() + + def savee(ckpt, name, epoch, steps, hps): try: opt = OrderedDict() @@ -15,8 +16,8 @@ def savee(ckpt, name, epoch, steps, hps): continue opt["weight"][key] = ckpt[key].half() opt["config"] = hps - opt["info"] = "%sepoch_%siteration" % (epoch,steps) - torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir,name)) + opt["info"] = "%sepoch_%siteration" % (epoch, steps) + torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name)) return "Success." except: return traceback.format_exc() diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py index 37166cb..4a77006 100644 --- a/GPT_SoVITS/s1_train.py +++ b/GPT_SoVITS/s1_train.py @@ -2,56 +2,84 @@ import os import pdb -if("_CUDA_VISIBLE_DEVICES"in os.environ): - os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"] +if "_CUDA_VISIBLE_DEVICES" in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] import argparse import logging from pathlib import Path -import torch,platform +import torch, platform from pytorch_lightning import seed_everything from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger#WandbLogger +from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger from pytorch_lightning.strategies import DDPStrategy from AR.data.data_module import Text2SemanticDataModule from AR.models.t2s_lightning_module import Text2SemanticLightningModule from AR.utils.io import load_yaml_config -logging.getLogger('numba').setLevel(logging.WARNING) -logging.getLogger('matplotlib').setLevel(logging.WARNING) -torch.set_float32_matmul_precision('high') + +logging.getLogger("numba").setLevel(logging.WARNING) +logging.getLogger("matplotlib").setLevel(logging.WARNING) +torch.set_float32_matmul_precision("high") from AR.utils import get_newest_ckpt from collections import OrderedDict + + class my_model_ckpt(ModelCheckpoint): - def __init__(self,config,if_save_latest,if_save_every_weights,half_weights_save_dir,exp_name,**kwargs): + def __init__( + self, + config, + if_save_latest, + if_save_every_weights, + half_weights_save_dir, + exp_name, + **kwargs + ): super().__init__(**kwargs) - self.if_save_latest=if_save_latest - self.if_save_every_weights=if_save_every_weights - self.half_weights_save_dir=half_weights_save_dir - self.exp_name=exp_name - self.config=config + self.if_save_latest = if_save_latest + self.if_save_every_weights = if_save_every_weights + self.half_weights_save_dir = half_weights_save_dir + self.exp_name = exp_name + self.config = config def on_train_epoch_end(self, trainer, pl_module): - if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer): + if not self._should_skip_saving_checkpoint( + trainer + ) and self._should_save_on_train_epoch_end(trainer): monitor_candidates = self._monitor_candidates(trainer) - if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: - if(self.if_save_latest==True):####如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt - to_clean=list(os.listdir(self.dirpath)) + if ( + self._every_n_epochs >= 1 + and (trainer.current_epoch + 1) % self._every_n_epochs == 0 + ): + if ( + self.if_save_latest == True + ): ####如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt + to_clean = list(os.listdir(self.dirpath)) self._save_topk_checkpoint(trainer, monitor_candidates) - if (self.if_save_latest == True): + if self.if_save_latest == True: for name in to_clean: try: - os.remove("%s/%s"%(self.dirpath,name)) - except:pass - if(self.if_save_every_weights==True): - to_save_od=OrderedDict() - to_save_od["weight"]=OrderedDict() - dictt=trainer.strategy._lightning_module.state_dict() - for key in dictt:to_save_od["weight"][key]=dictt[key].half() - to_save_od["config"]=self.config - to_save_od["info"]="GPT-e%s"%(trainer.current_epoch+1) - torch.save(to_save_od,"%s/%s-e%s.ckpt"%(self.half_weights_save_dir,self.exp_name,trainer.current_epoch+1)) + os.remove("%s/%s" % (self.dirpath, name)) + except: + pass + if self.if_save_every_weights == True: + to_save_od = OrderedDict() + to_save_od["weight"] = OrderedDict() + dictt = trainer.strategy._lightning_module.state_dict() + for key in dictt: + to_save_od["weight"][key] = dictt[key].half() + to_save_od["config"] = self.config + to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1) + torch.save( + to_save_od, + "%s/%s-e%s.ckpt" + % ( + self.half_weights_save_dir, + self.exp_name, + trainer.current_epoch + 1, + ), + ) self._save_last_checkpoint(trainer, monitor_candidates) @@ -61,41 +89,45 @@ def main(args): output_dir = Path(config["output_dir"]) output_dir.mkdir(parents=True, exist_ok=True) - ckpt_dir = output_dir / 'ckpt' + ckpt_dir = output_dir / "ckpt" ckpt_dir.mkdir(parents=True, exist_ok=True) - seed_everything(config["train"]["seed"], workers=True) ckpt_callback: ModelCheckpoint = my_model_ckpt( config=config, - if_save_latest=config["train"]["if_save_latest"], if_save_every_weights=config["train"]["if_save_every_weights"], half_weights_save_dir=config["train"]["half_weights_save_dir"], exp_name=config["train"]["exp_name"], + if_save_latest=config["train"]["if_save_latest"], + if_save_every_weights=config["train"]["if_save_every_weights"], + half_weights_save_dir=config["train"]["half_weights_save_dir"], + exp_name=config["train"]["exp_name"], save_top_k=-1, - monitor='top_3_acc', - mode='max', + monitor="top_3_acc", + mode="max", save_on_train_epoch_end=True, every_n_epochs=config["train"]["save_every_n_epoch"], dirpath=ckpt_dir, ) - logger = TensorBoardLogger( - name=output_dir.stem, - save_dir=output_dir - ) + logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir) trainer: Trainer = Trainer( max_epochs=config["train"]["epochs"], - accelerator='gpu', + accelerator="gpu", # val_check_interval=9999999999999999999999,###不要验证 # check_val_every_n_epoch=None, limit_val_batches=0, devices=-1, benchmark=False, fast_dev_run=False, - strategy=DDPStrategy(process_group_backend="nccl"if platform.system()!="Windows"else "gloo"), + strategy=DDPStrategy( + process_group_backend="nccl" if platform.system() != "Windows" else "gloo" + ), precision=config["train"]["precision"], - logger=logger,num_sanity_val_steps=0, - callbacks=[ckpt_callback]) + logger=logger, + num_sanity_val_steps=0, + callbacks=[ckpt_callback], + ) model: Text2SemanticLightningModule = Text2SemanticLightningModule( - config, output_dir) + config, output_dir + ) data_module: Text2SemanticDataModule = Text2SemanticDataModule( config, @@ -116,14 +148,15 @@ def main(args): # srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - '-c', - '--config_file', + "-c", + "--config_file", type=str, - default='configs/s1longer.yaml', - help='path of config file') + default="configs/s1longer.yaml", + help="path of config file", + ) # args for dataset # parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv') # parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt') diff --git a/GPT_SoVITS/s2_train.py b/GPT_SoVITS/s2_train.py index 7a455eb..d2ec262 100644 --- a/GPT_SoVITS/s2_train.py +++ b/GPT_SoVITS/s2_train.py @@ -1,4 +1,5 @@ -import utils,os +import utils, os + hps = utils.get_hparams(stage=2) os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",") import torch @@ -6,11 +7,12 @@ from torch.nn import functional as F from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter import torch.multiprocessing as mp -import torch.distributed as dist,traceback +import torch.distributed as dist, traceback from torch.nn.parallel import DistributedDataParallel as DDP from torch.cuda.amp import autocast, GradScaler from tqdm import tqdm -import logging,traceback +import logging, traceback + logging.getLogger("matplotlib").setLevel(logging.INFO) logging.getLogger("h5py").setLevel(logging.INFO) logging.getLogger("numba").setLevel(logging.INFO) @@ -20,37 +22,42 @@ from module import commons from module.data_utils import ( TextAudioSpeakerLoader, TextAudioSpeakerCollate, - DistributedBucketSampler + DistributedBucketSampler, ) from module.models import ( SynthesizerTrn, MultiPeriodDiscriminator, ) -from module.losses import ( - generator_loss, - discriminator_loss, - feature_loss, - kl_loss -) +from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch from process_ckpt import savee + torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = False ###反正A100fp32更快,那试试tf32吧 torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True -torch.set_float32_matmul_precision('medium')#最低精度但最快(也就快一丁点),对于结果造成不了影响 +torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响 # from config import pretrained_s2G,pretrained_s2D global_step = 0 + + def main(): """Assume Single Node Multi GPUs Training Only""" assert torch.cuda.is_available(), "CPU training is not allowed." n_gpus = torch.cuda.device_count() - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = str(randint(20000, 55555)) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(randint(20000, 55555)) - mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,)) + mp.spawn( + run, + nprocs=n_gpus, + args=( + n_gpus, + hps, + ), + ) def run(rank, n_gpus, hps): @@ -62,21 +69,54 @@ def run(rank, n_gpus, hps): writer = SummaryWriter(log_dir=hps.s2_ckpt_dir) writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) - dist.init_process_group(backend='gloo' if os.name == 'nt' else 'nccl', init_method='env://', world_size=n_gpus,rank=rank) + dist.init_process_group( + backend="gloo" if os.name == "nt" else "nccl", + init_method="env://", + world_size=n_gpus, + rank=rank, + ) torch.manual_seed(hps.train.seed) torch.cuda.set_device(rank) - train_dataset = TextAudioSpeakerLoader(hps.data)######## + train_dataset = TextAudioSpeakerLoader(hps.data) ######## train_sampler = DistributedBucketSampler( train_dataset, hps.train.batch_size, - [32, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900], + [ + 32, + 300, + 400, + 500, + 600, + 700, + 800, + 900, + 1000, + 1100, + 1200, + 1300, + 1400, + 1500, + 1600, + 1700, + 1800, + 1900, + ], num_replicas=n_gpus, rank=rank, - shuffle=True) + shuffle=True, + ) collate_fn = TextAudioSpeakerCollate() - train_loader = DataLoader(train_dataset, num_workers=6, shuffle=False, pin_memory=True, - collate_fn=collate_fn, batch_sampler=train_sampler,persistent_workers=True,prefetch_factor=16) + train_loader = DataLoader( + train_dataset, + num_workers=6, + shuffle=False, + pin_memory=True, + collate_fn=collate_fn, + batch_sampler=train_sampler, + persistent_workers=True, + prefetch_factor=16, + ) # if rank == 0: # eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True) # eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False, @@ -87,17 +127,21 @@ def run(rank, n_gpus, hps): hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, - **hps.model).cuda(rank) + **hps.model, + ).cuda(rank) net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) for name, param in net_g.named_parameters(): if not param.requires_grad: - print(name,"not requires_grad") + print(name, "not requires_grad") te_p = list(map(id, net_g.enc_p.text_embedding.parameters())) et_p = list(map(id, net_g.enc_p.encoder_text.parameters())) mrte_p = list(map(id, net_g.enc_p.mrte.parameters())) - base_params = filter(lambda p: id(p) not in te_p+et_p+mrte_p and p.requires_grad, net_g.parameters()) + base_params = filter( + lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad, + net_g.parameters(), + ) # te_p=net_g.enc_p.text_embedding.parameters() # et_p=net_g.enc_p.encoder_text.parameters() @@ -106,31 +150,46 @@ def run(rank, n_gpus, hps): optim_g = torch.optim.AdamW( # filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致 [ - {"params":base_params,"lr":hps.train.learning_rate}, - {"params":net_g.enc_p.text_embedding.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate}, - {"params":net_g.enc_p.encoder_text.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate}, - {"params":net_g.enc_p.mrte.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate}, + {"params": base_params, "lr": hps.train.learning_rate}, + { + "params": net_g.enc_p.text_embedding.parameters(), + "lr": hps.train.learning_rate * hps.train.text_low_lr_rate, + }, + { + "params": net_g.enc_p.encoder_text.parameters(), + "lr": hps.train.learning_rate * hps.train.text_low_lr_rate, + }, + { + "params": net_g.enc_p.mrte.parameters(), + "lr": hps.train.learning_rate * hps.train.text_low_lr_rate, + }, ], hps.train.learning_rate, betas=hps.train.betas, - eps=hps.train.eps) + eps=hps.train.eps, + ) optim_d = torch.optim.AdamW( net_d.parameters(), hps.train.learning_rate, betas=hps.train.betas, - eps=hps.train.eps) - net_g = DDP(net_g, device_ids=[rank],find_unused_parameters=True) - net_d = DDP(net_d, device_ids=[rank],find_unused_parameters=True) + eps=hps.train.eps, + ) + net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) + net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True) try: # 如果能加载自动resume _, _, _, epoch_str = utils.load_checkpoint( - utils.latest_checkpoint_path("%s/logs_s2"%hps.data.exp_dir, "D_*.pth"), net_d, optim_d + utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "D_*.pth"), + net_d, + optim_d, ) # D多半加载没事 if rank == 0: logger.info("loaded D") # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0) _, _, _, epoch_str = utils.load_checkpoint( - utils.latest_checkpoint_path("%s/logs_s2"%hps.data.exp_dir, "G_*.pth"), net_g, optim_g + utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "G_*.pth"), + net_g, + optim_g, ) global_step = (epoch_str - 1) * len(train_loader) # epoch_str = 1 @@ -144,7 +203,8 @@ def run(rank, n_gpus, hps): logger.info("loaded pretrained %s" % hps.train.pretrained_s2G) print( net_g.module.load_state_dict( - torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],strict=False + torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"], + strict=False, ) ) ##测试不加载优化器 if hps.train.pretrained_s2D != "": @@ -159,8 +219,12 @@ def run(rank, n_gpus, hps): # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) - scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1) - scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=-1) + scheduler_g = torch.optim.lr_scheduler.ExponentialLR( + optim_g, gamma=hps.train.lr_decay, last_epoch=-1 + ) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR( + optim_d, gamma=hps.train.lr_decay, last_epoch=-1 + ) for _ in range(epoch_str): scheduler_g.step() scheduler_d.step() @@ -169,17 +233,39 @@ def run(rank, n_gpus, hps): for epoch in range(epoch_str, hps.train.epochs + 1): if rank == 0: - train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, - # [train_loader, eval_loader], logger, [writer, writer_eval]) - [train_loader, None], logger, [writer, writer_eval]) + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + # [train_loader, eval_loader], logger, [writer, writer_eval]) + [train_loader, None], + logger, + [writer, writer_eval], + ) else: - train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, - [train_loader, None], None, None) + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + [train_loader, None], + None, + None, + ) scheduler_g.step() scheduler_d.step() -def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers): +def train_and_evaluate( + rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers +): net_g, net_d = nets optim_g, optim_d = optims # scheduler_g, scheduler_d = schedulers @@ -192,17 +278,39 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade net_g.train() net_d.train() - for batch_idx, (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths) in tqdm(enumerate(train_loader)): - spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True) - y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True) + for batch_idx, ( + ssl, + ssl_lengths, + spec, + spec_lengths, + y, + y_lengths, + text, + text_lengths, + ) in tqdm(enumerate(train_loader)): + spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda( + rank, non_blocking=True + ) + y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda( + rank, non_blocking=True + ) ssl = ssl.cuda(rank, non_blocking=True) - ssl.requires_grad=False + ssl.requires_grad = False # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) - text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(rank, non_blocking=True) + text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda( + rank, non_blocking=True + ) with autocast(enabled=hps.train.fp16_run): - y_hat, kl_ssl, ids_slice, x_mask, z_mask, \ - (z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl = net_g(ssl, spec, spec_lengths, text, text_lengths) + ( + y_hat, + kl_ssl, + ids_slice, + x_mask, + z_mask, + (z, z_p, m_p, logs_p, m_q, logs_q), + stats_ssl, + ) = net_g(ssl, spec, spec_lengths, text, text_lengths) mel = spec_to_mel_torch( spec, @@ -210,8 +318,11 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade hps.data.n_mel_channels, hps.data.sampling_rate, hps.data.mel_fmin, - hps.data.mel_fmax) - y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) + hps.data.mel_fmax, + ) + y_mel = commons.slice_segments( + mel, ids_slice, hps.train.segment_size // hps.data.hop_length + ) y_hat_mel = mel_spectrogram_torch( y_hat.squeeze(1), hps.data.filter_length, @@ -220,15 +331,19 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade hps.data.hop_length, hps.data.win_length, hps.data.mel_fmin, - hps.data.mel_fmax + hps.data.mel_fmax, ) - y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice + y = commons.slice_segments( + y, ids_slice * hps.data.hop_length, hps.train.segment_size + ) # slice # Discriminator y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) with autocast(enabled=False): - loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) + loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( + y_d_hat_r, y_d_hat_g + ) loss_disc_all = loss_disc optim_d.zero_grad() scaler.scale(loss_disc_all).backward() @@ -256,32 +371,54 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade if rank == 0: if global_step % hps.train.log_interval == 0: - lr = optim_g.param_groups[0]['lr'] + lr = optim_g.param_groups[0]["lr"] losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl] - logger.info('Train Epoch: {} [{:.0f}%]'.format( - epoch, - 100. * batch_idx / len(train_loader))) + logger.info( + "Train Epoch: {} [{:.0f}%]".format( + epoch, 100.0 * batch_idx / len(train_loader) + ) + ) logger.info([x.item() for x in losses] + [global_step, lr]) - scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, - "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g} + scalar_dict = { + "loss/g/total": loss_gen_all, + "loss/d/total": loss_disc_all, + "learning_rate": lr, + "grad_norm_d": grad_norm_d, + "grad_norm_g": grad_norm_g, + } scalar_dict.update( - {"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl_ssl": kl_ssl, "loss/g/kl": loss_kl}) + { + "loss/g/fm": loss_fm, + "loss/g/mel": loss_mel, + "loss/g/kl_ssl": kl_ssl, + "loss/g/kl": loss_kl, + } + ) # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) image_dict = { - "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), - "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), - "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), - "all/stats_ssl": utils.plot_spectrogram_to_numpy(stats_ssl[0].data.cpu().numpy()), + "slice/mel_org": utils.plot_spectrogram_to_numpy( + y_mel[0].data.cpu().numpy() + ), + "slice/mel_gen": utils.plot_spectrogram_to_numpy( + y_hat_mel[0].data.cpu().numpy() + ), + "all/mel": utils.plot_spectrogram_to_numpy( + mel[0].data.cpu().numpy() + ), + "all/stats_ssl": utils.plot_spectrogram_to_numpy( + stats_ssl[0].data.cpu().numpy() + ), } utils.summarize( writer=writer, global_step=global_step, images=image_dict, - scalars=scalar_dict) + scalars=scalar_dict, + ) global_step += 1 if epoch % hps.train.save_every_epoch == 0 and rank == 0: if hps.train.if_save_latest == 0: @@ -290,14 +427,18 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade optim_g, hps.train.learning_rate, epoch, - os.path.join("%s/logs_s2"%hps.data.exp_dir, "G_{}.pth".format(global_step)), + os.path.join( + "%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(global_step) + ), ) utils.save_checkpoint( net_d, optim_d, hps.train.learning_rate, epoch, - os.path.join("%s/logs_s2"%hps.data.exp_dir, "D_{}.pth".format(global_step)), + os.path.join( + "%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(global_step) + ), ) else: utils.save_checkpoint( @@ -305,14 +446,18 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade optim_g, hps.train.learning_rate, epoch, - os.path.join("%s/logs_s2"%hps.data.exp_dir, "G_{}.pth".format(233333333333)), + os.path.join( + "%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(233333333333) + ), ) utils.save_checkpoint( net_d, optim_d, hps.train.learning_rate, epoch, - os.path.join("%s/logs_s2"%hps.data.exp_dir, "D_{}.pth".format(233333333333)), + os.path.join( + "%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(233333333333) + ), ) if rank == 0 and hps.train.if_save_every_weights == True: if hasattr(net_g, "module"): @@ -334,11 +479,8 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade ) ) - - if rank == 0: - logger.info('====> Epoch: {}'.format(epoch)) - + logger.info("====> Epoch: {}".format(epoch)) def evaluate(hps, generator, eval_loader, writer_eval): @@ -347,15 +489,25 @@ def evaluate(hps, generator, eval_loader, writer_eval): audio_dict = {} print("Evaluating ...") with torch.no_grad(): - for batch_idx, (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths) in enumerate(eval_loader): + for batch_idx, ( + ssl, + ssl_lengths, + spec, + spec_lengths, + y, + y_lengths, + text, + text_lengths, + ) in enumerate(eval_loader): print(111) spec, spec_lengths = spec.cuda(), spec_lengths.cuda() y, y_lengths = y.cuda(), y_lengths.cuda() ssl = ssl.cuda() text, text_lengths = text.cuda(), text_lengths.cuda() for test in [0, 1]: - - y_hat, mask, *_ = generator.module.infer(ssl,spec, spec_lengths,text, text_lengths, test=test) + y_hat, mask, *_ = generator.module.infer( + ssl, spec, spec_lengths, text, text_lengths, test=test + ) y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length mel = spec_to_mel_torch( @@ -364,7 +516,8 @@ def evaluate(hps, generator, eval_loader, writer_eval): hps.data.n_mel_channels, hps.data.sampling_rate, hps.data.mel_fmin, - hps.data.mel_fmax) + hps.data.mel_fmax, + ) y_hat_mel = mel_spectrogram_torch( y_hat.squeeze(1).float(), hps.data.filter_length, @@ -373,16 +526,26 @@ def evaluate(hps, generator, eval_loader, writer_eval): hps.data.hop_length, hps.data.win_length, hps.data.mel_fmin, - hps.data.mel_fmax + hps.data.mel_fmax, ) - image_dict.update({ - f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()) - }) - audio_dict.update({ - f"gen/audio_{batch_idx}_{test}": y_hat[0, :, :y_hat_lengths[0]] - }) - image_dict.update({f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())}) - audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, :y_lengths[0]]}) + image_dict.update( + { + f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy( + y_hat_mel[0].cpu().numpy() + ) + } + ) + audio_dict.update( + {f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]} + ) + image_dict.update( + { + f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy( + mel[0].cpu().numpy() + ) + } + ) + audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]}) # y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None) # audio_dict.update({ @@ -394,9 +557,10 @@ def evaluate(hps, generator, eval_loader, writer_eval): global_step=global_step, images=image_dict, audios=audio_dict, - audio_sampling_rate=hps.data.sampling_rate + audio_sampling_rate=hps.data.sampling_rate, ) generator.train() + if __name__ == "__main__": main() diff --git a/GPT_SoVITS/utils.py b/GPT_SoVITS/utils.py index e3ed89b..0ce03b3 100644 --- a/GPT_SoVITS/utils.py +++ b/GPT_SoVITS/utils.py @@ -12,8 +12,9 @@ import numpy as np from scipy.io.wavfile import read import torch import logging -logging.getLogger('numba').setLevel(logging.ERROR) -logging.getLogger('matplotlib').setLevel(logging.ERROR) + +logging.getLogger("numba").setLevel(logging.ERROR) +logging.getLogger("matplotlib").setLevel(logging.ERROR) MATPLOTLIB_FLAG = False @@ -23,13 +24,17 @@ logger = logging def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False): assert os.path.isfile(checkpoint_path) - checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') - iteration = checkpoint_dict['iteration'] - learning_rate = checkpoint_dict['learning_rate'] - if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None: - optimizer.load_state_dict(checkpoint_dict['optimizer']) - saved_state_dict = checkpoint_dict['model'] - if hasattr(model, 'module'): + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + iteration = checkpoint_dict["iteration"] + learning_rate = checkpoint_dict["learning_rate"] + if ( + optimizer is not None + and not skip_optimizer + and checkpoint_dict["optimizer"] is not None + ): + optimizer.load_state_dict(checkpoint_dict["optimizer"]) + saved_state_dict = checkpoint_dict["model"] + if hasattr(model, "module"): state_dict = model.module.state_dict() else: state_dict = model.state_dict() @@ -39,41 +44,63 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False # assert "quantizer" not in k # print("load", k) new_state_dict[k] = saved_state_dict[k] - assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape) + assert saved_state_dict[k].shape == v.shape, ( + saved_state_dict[k].shape, + v.shape, + ) except: traceback.print_exc() - print("error, %s is not in the checkpoint" % k)#shape不对也会,比如text_embedding当cleaner修改时 + print( + "error, %s is not in the checkpoint" % k + ) # shape不对也会,比如text_embedding当cleaner修改时 new_state_dict[k] = v - if hasattr(model, 'module'): + if hasattr(model, "module"): model.module.load_state_dict(new_state_dict) else: model.load_state_dict(new_state_dict) print("load ") - logger.info("Loaded checkpoint '{}' (iteration {})".format( - checkpoint_path, iteration)) + logger.info( + "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration) + ) return model, optimizer, learning_rate, iteration def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): - logger.info("Saving model and optimizer state at iteration {} to {}".format( - iteration, checkpoint_path)) - if hasattr(model, 'module'): + logger.info( + "Saving model and optimizer state at iteration {} to {}".format( + iteration, checkpoint_path + ) + ) + if hasattr(model, "module"): state_dict = model.module.state_dict() else: state_dict = model.state_dict() - torch.save({'model': state_dict, - 'iteration': iteration, - 'optimizer': optimizer.state_dict(), - 'learning_rate': learning_rate}, checkpoint_path) + torch.save( + { + "model": state_dict, + "iteration": iteration, + "optimizer": optimizer.state_dict(), + "learning_rate": learning_rate, + }, + checkpoint_path, + ) -def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): +def summarize( + writer, + global_step, + scalars={}, + histograms={}, + images={}, + audios={}, + audio_sampling_rate=22050, +): for k, v in scalars.items(): writer.add_scalar(k, v, global_step) for k, v in histograms.items(): writer.add_histogram(k, v, global_step) for k, v in images.items(): - writer.add_image(k, v, global_step, dataformats='HWC') + writer.add_image(k, v, global_step, dataformats="HWC") for k, v in audios.items(): writer.add_audio(k, v, global_step, audio_sampling_rate) @@ -90,23 +117,23 @@ def plot_spectrogram_to_numpy(spectrogram): global MATPLOTLIB_FLAG if not MATPLOTLIB_FLAG: import matplotlib + matplotlib.use("Agg") MATPLOTLIB_FLAG = True - mpl_logger = logging.getLogger('matplotlib') + mpl_logger = logging.getLogger("matplotlib") mpl_logger.setLevel(logging.WARNING) import matplotlib.pylab as plt import numpy as np fig, ax = plt.subplots(figsize=(10, 2)) - im = ax.imshow(spectrogram, aspect="auto", origin="lower", - interpolation='none') + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") plt.colorbar(im, ax=ax) plt.xlabel("Frames") plt.ylabel("Channels") plt.tight_layout() fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) plt.close() return data @@ -116,26 +143,28 @@ def plot_alignment_to_numpy(alignment, info=None): global MATPLOTLIB_FLAG if not MATPLOTLIB_FLAG: import matplotlib + matplotlib.use("Agg") MATPLOTLIB_FLAG = True - mpl_logger = logging.getLogger('matplotlib') + mpl_logger = logging.getLogger("matplotlib") mpl_logger.setLevel(logging.WARNING) import matplotlib.pylab as plt import numpy as np fig, ax = plt.subplots(figsize=(6, 4)) - im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', - interpolation='none') + im = ax.imshow( + alignment.transpose(), aspect="auto", origin="lower", interpolation="none" + ) fig.colorbar(im, ax=ax) - xlabel = 'Decoder timestep' + xlabel = "Decoder timestep" if info is not None: - xlabel += '\n\n' + info + xlabel += "\n\n" + info plt.xlabel(xlabel) - plt.ylabel('Encoder timestep') + plt.ylabel("Encoder timestep") plt.tight_layout() fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) plt.close() return data @@ -147,16 +176,31 @@ def load_wav_to_torch(full_path): def load_filepaths_and_text(filename, split="|"): - with open(filename, encoding='utf-8') as f: + with open(filename, encoding="utf-8") as f: filepaths_and_text = [line.strip().split(split) for line in f] return filepaths_and_text def get_hparams(init=True, stage=1): parser = argparse.ArgumentParser() - parser.add_argument('-c', '--config', type=str, default="./configs/s2.json",help='JSON file for configuration') - parser.add_argument('-p', '--pretrain', type=str, required=False,default=None,help='pretrain dir') - parser.add_argument('-rs', '--resume_step', type=int, required=False,default=None,help='resume step') + parser.add_argument( + "-c", + "--config", + type=str, + default="./configs/s2.json", + help="JSON file for configuration", + ) + parser.add_argument( + "-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir" + ) + parser.add_argument( + "-rs", + "--resume_step", + type=int, + required=False, + default=None, + help="resume step", + ) # parser.add_argument('-e', '--exp_dir', type=str, required=False,default=None,help='experiment directory') # parser.add_argument('-g', '--pretrained_s2G', type=str, required=False,default=None,help='pretrained sovits gererator weights') # parser.add_argument('-d', '--pretrained_s2D', type=str, required=False,default=None,help='pretrained sovits discriminator weights') @@ -172,7 +216,7 @@ def get_hparams(init=True, stage=1): hparams.pretrain = args.pretrain hparams.resume_step = args.resume_step # hparams.data.exp_dir = args.exp_dir - if stage ==1: + if stage == 1: model_dir = hparams.s1_ckpt_dir else: model_dir = hparams.s2_ckpt_dir @@ -186,29 +230,38 @@ def get_hparams(init=True, stage=1): return hparams - -def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True): +def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True): """Freeing up space by deleting saved ckpts - Arguments: - path_to_models -- Path to the model directory - n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth - sort_by_time -- True -> chronologically delete ckpts - False -> lexicographically delete ckpts - """ + Arguments: + path_to_models -- Path to the model directory + n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth + sort_by_time -- True -> chronologically delete ckpts + False -> lexicographically delete ckpts + """ import re - ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))] - name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1))) - time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))) + + ckpts_files = [ + f + for f in os.listdir(path_to_models) + if os.path.isfile(os.path.join(path_to_models, f)) + ] + name_key = lambda _f: int(re.compile("._(\d+)\.pth").match(_f).group(1)) + time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)) sort_key = time_key if sort_by_time else name_key - x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')], - key=sort_key) - to_del = [os.path.join(path_to_models, fn) for fn in - (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])] + x_sorted = lambda _x: sorted( + [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")], + key=sort_key, + ) + to_del = [ + os.path.join(path_to_models, fn) + for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep]) + ] del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}") del_routine = lambda x: [os.remove(x), del_info(x)] rs = [del_routine(fn) for fn in to_del] + def get_hparams_from_dir(model_dir): config_save_path = os.path.join(model_dir, "config.json") with open(config_save_path, "r") as f: @@ -228,12 +281,15 @@ def get_hparams_from_file(config_path): hparams = HParams(**config) return hparams + def check_git_hash(model_dir): source_dir = os.path.dirname(os.path.realpath(__file__)) if not os.path.exists(os.path.join(source_dir, ".git")): - logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( - source_dir - )) + logger.warn( + "{} is not a git repository, therefore hash value comparison will be ignored.".format( + source_dir + ) + ) return cur_hash = subprocess.getoutput("git rev-parse HEAD") @@ -242,8 +298,11 @@ def check_git_hash(model_dir): if os.path.exists(path): saved_hash = open(path).read() if saved_hash != cur_hash: - logger.warn("git hash values are different. {}(saved) != {}(current)".format( - saved_hash[:8], cur_hash[:8])) + logger.warn( + "git hash values are different. {}(saved) != {}(current)".format( + saved_hash[:8], cur_hash[:8] + ) + ) else: open(path, "w").write(cur_hash) @@ -263,7 +322,7 @@ def get_logger(model_dir, filename="train.log"): return logger -class HParams(): +class HParams: def __init__(self, **kwargs): for k, v in kwargs.items(): if type(v) == dict: @@ -294,5 +353,10 @@ class HParams(): def __repr__(self): return self.__dict__.__repr__() -if __name__ == '__main__': - print(load_wav_to_torch('/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac')) \ No newline at end of file + +if __name__ == "__main__": + print( + load_wav_to_torch( + "/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac" + ) + )