Code refactor + remove unused imports

This commit is contained in:
Blaise 2024-01-16 17:10:27 +01:00
parent 9031ac9a92
commit 0d92575115
5 changed files with 671 additions and 335 deletions

View File

@ -1,49 +1,55 @@
import os import os
gpt_path=os.environ.get("gpt_path","pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
sovits_path=os.environ.get("sovits_path","pretrained_models/s2G488k.pth") gpt_path = os.environ.get(
cnhubert_base_path=os.environ.get("cnhubert_base_path","pretrained_models/chinese-hubert-base") "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
bert_path=os.environ.get("bert_path","pretrained_models/chinese-roberta-wwm-ext-large") )
infer_ttswebui=os.environ.get("infer_ttswebui",9872) sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
infer_ttswebui=int(infer_ttswebui) cnhubert_base_path = os.environ.get(
if("_CUDA_VISIBLE_DEVICES"in os.environ): "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"] )
is_half=eval(os.environ.get("is_half","True")) bert_path = os.environ.get(
"bert_path", "pretrained_models/chinese-roberta-wwm-ext-large"
)
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
infer_ttswebui = int(infer_ttswebui)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True"))
import gradio as gr import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer from transformers import AutoModelForMaskedLM, AutoTokenizer
import sys,torch,numpy as np import torch, numpy as np
from pathlib import Path import os, librosa, torch
import os,pdb,utils,librosa,math,traceback,requests,argparse,torch,multiprocessing,pandas as pd,torch.multiprocessing as mp,soundfile
# torch.backends.cuda.sdp_kernel("flash") # torch.backends.cuda.sdp_kernel("flash")
# torch.backends.cuda.enable_flash_sdp(True) # torch.backends.cuda.enable_flash_sdp(True)
# torch.backends.cuda.enable_mem_efficient_sdp(True) # Not avaliable if torch version is lower than 2.0 # torch.backends.cuda.enable_mem_efficient_sdp(True) # Not avaliable if torch version is lower than 2.0
# torch.backends.cuda.enable_math_sdp(True) # torch.backends.cuda.enable_math_sdp(True)
from random import shuffle
from AR.utils import get_newest_ckpt
from glob import glob
from tqdm import tqdm
from feature_extractor import cnhubert from feature_extractor import cnhubert
cnhubert.cnhubert_base_path=cnhubert_base_path
from io import BytesIO cnhubert.cnhubert_base_path = cnhubert_base_path
from module.models import SynthesizerTrn from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
from text import cleaned_text_to_sequence from text import cleaned_text_to_sequence
from text.cleaner import text_to_sequence, clean_text from text.cleaner import clean_text
from time import time as ttime from time import time as ttime
from module.mel_processing import spectrogram_torch from module.mel_processing import spectrogram_torch
from my_utils import load_audio from my_utils import load_audio
device="cuda" device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(bert_path) tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model=AutoModelForMaskedLM.from_pretrained(bert_path) bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if(is_half==True):bert_model=bert_model.half().to(device) if is_half == True:
else:bert_model=bert_model.to(device) bert_model = bert_model.half().to(device)
else:
bert_model = bert_model.to(device)
# bert_model=bert_model.to(device) # bert_model=bert_model.to(device)
def get_bert_feature(text, word2ph): def get_bert_feature(text, word2ph):
with torch.no_grad(): with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt") inputs = tokenizer(text, return_tensors="pt")
for i in inputs: for i in inputs:
inputs[i] = inputs[i].to(device)#####输入是long不用管精度问题精度随bert_model inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题精度随bert_model
res = bert_model(**inputs, output_hidden_states=True) res = bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text) assert len(word2ph) == len(text)
@ -55,9 +61,12 @@ def get_bert_feature(text, word2ph):
# if(is_half==True):phone_level_feature=phone_level_feature.half() # if(is_half==True):phone_level_feature=phone_level_feature.half()
return phone_level_feature.T return phone_level_feature.T
n_semantic = 1024 n_semantic = 1024
dict_s2=torch.load(sovits_path,map_location="cpu") dict_s2 = torch.load(sovits_path, map_location="cpu")
hps=dict_s2["config"] hps = dict_s2["config"]
class DictToAttrRecursive: class DictToAttrRecursive:
def __init__(self, input_dict): def __init__(self, input_dict):
for key, value in input_dict.items(): for key, value in input_dict.items():
@ -67,206 +76,271 @@ class DictToAttrRecursive:
else: else:
setattr(self, key, value) setattr(self, key, value)
hps = DictToAttrRecursive(hps) hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate="25hz" hps.model.semantic_frame_rate = "25hz"
dict_s1=torch.load(gpt_path,map_location="cpu") dict_s1 = torch.load(gpt_path, map_location="cpu")
config=dict_s1["config"] config = dict_s1["config"]
ssl_model=cnhubert.get_model() ssl_model = cnhubert.get_model()
if(is_half==True):ssl_model=ssl_model.half().to(device) if is_half == True:
else:ssl_model=ssl_model.to(device) ssl_model = ssl_model.half().to(device)
else:
ssl_model = ssl_model.to(device)
vq_model = SynthesizerTrn( vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1, hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length, hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers, n_speakers=hps.data.n_speakers,
**hps.model) **hps.model
if(is_half==True):vq_model=vq_model.half().to(device) )
else:vq_model=vq_model.to(device) if is_half == True:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
vq_model.eval() vq_model.eval()
print(vq_model.load_state_dict(dict_s2["weight"],strict=False)) print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
hz = 50 hz = 50
max_sec = config['data']['max_sec'] max_sec = config["data"]["max_sec"]
# t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo # t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
t2s_model = Text2SemanticLightningModule(config,"ojbk",is_train=False) t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"]) t2s_model.load_state_dict(dict_s1["weight"])
if(is_half==True):t2s_model=t2s_model.half() if is_half == True:
t2s_model=t2s_model.to(device) t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device)
t2s_model.eval() t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()]) total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6)) print("Number of parameter: %.2fM" % (total / 1e6))
def get_spepc(hps, filename): def get_spepc(hps, filename):
audio=load_audio(filename,int(hps.data.sampling_rate)) audio = load_audio(filename, int(hps.data.sampling_rate))
audio=torch.FloatTensor(audio) audio = torch.FloatTensor(audio)
audio_norm = audio audio_norm = audio
audio_norm = audio_norm.unsqueeze(0) audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False) spec = spectrogram_torch(
audio_norm,
hps.data.filter_length,
hps.data.sampling_rate,
hps.data.hop_length,
hps.data.win_length,
center=False,
)
return spec return spec
dict_language={
"中文":"zh", dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
"英文":"en",
"日文":"ja"
} def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
def get_tts_wav(ref_wav_path,prompt_text,prompt_language,text,text_language):
t0 = ttime() t0 = ttime()
prompt_text=prompt_text.strip("\n") prompt_text = prompt_text.strip("\n")
prompt_language,text=prompt_language,text.strip("\n") prompt_language, text = prompt_language, text.strip("\n")
with torch.no_grad(): with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙 wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
wav16k = torch.from_numpy(wav16k) wav16k = torch.from_numpy(wav16k)
if(is_half==True):wav16k=wav16k.half().to(device) if is_half == True:
else:wav16k=wav16k.to(device) wav16k = wav16k.half().to(device)
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float() else:
wav16k = wav16k.to(device)
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
codes = vq_model.extract_latent(ssl_content) codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0] prompt_semantic = codes[0, 0]
t1 = ttime() t1 = ttime()
prompt_language=dict_language[prompt_language] prompt_language = dict_language[prompt_language]
text_language=dict_language[text_language] text_language = dict_language[text_language]
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language) phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
phones1=cleaned_text_to_sequence(phones1) phones1 = cleaned_text_to_sequence(phones1)
texts=text.split("\n") texts = text.split("\n")
audio_opt = [] audio_opt = []
zero_wav=np.zeros(int(hps.data.sampling_rate*0.3),dtype=np.float16 if is_half==True else np.float32) zero_wav = np.zeros(
int(hps.data.sampling_rate * 0.3),
dtype=np.float16 if is_half == True else np.float32,
)
for text in texts: for text in texts:
phones2, word2ph2, norm_text2 = clean_text(text, text_language) phones2, word2ph2, norm_text2 = clean_text(text, text_language)
phones2 = cleaned_text_to_sequence(phones2) phones2 = cleaned_text_to_sequence(phones2)
if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device) if prompt_language == "zh":
else:bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device) bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device) else:
else:bert2 = torch.zeros((1024, len(phones2))).to(bert1) bert1 = torch.zeros(
(1024, len(phones1)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
if text_language == "zh":
bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
else:
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
bert = torch.cat([bert1, bert2], 1) bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0) all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0) bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device) prompt = prompt_semantic.unsqueeze(0).to(device)
t2 = ttime() t2 = ttime()
with torch.no_grad(): with torch.no_grad():
# pred_semantic = t2s_model.model.infer( # pred_semantic = t2s_model.model.infer(
pred_semantic,idx = t2s_model.model.infer_panel( pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids, all_phoneme_ids,
all_phoneme_len, all_phoneme_len,
prompt, prompt,
bert, bert,
# prompt_phone_len=ph_offset, # prompt_phone_len=ph_offset,
top_k=config['inference']['top_k'], top_k=config["inference"]["top_k"],
early_stop_num=hz * max_sec) early_stop_num=hz * max_sec,
)
t3 = ttime() t3 = ttime()
# print(pred_semantic.shape,idx) # print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:,-idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次 pred_semantic = pred_semantic[:, -idx:].unsqueeze(
refer = get_spepc(hps, ref_wav_path)#.to(device) 0
if(is_half==True):refer=refer.half().to(device) ) # .unsqueeze(0)#mq要多unsqueeze一次
else:refer=refer.to(device) refer = get_spepc(hps, ref_wav_path) # .to(device)
if is_half == True:
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0] # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]###试试重建不带上prompt部分 audio = (
vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
)
.detach()
.cpu()
.numpy()[0, 0]
) ###试试重建不带上prompt部分
audio_opt.append(audio) audio_opt.append(audio)
audio_opt.append(zero_wav) audio_opt.append(zero_wav)
t4 = ttime() t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
yield hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16) yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
np.int16
)
splits = {
"",
"",
"",
"",
",",
".",
"?",
"!",
"~",
":",
"",
"",
"",
} # 不考虑省略号
splits={"","","","",",",".","?","!","~",":","","","",}#不考虑省略号
def split(todo_text): def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "") todo_text = todo_text.replace("……", "").replace("——", "")
if (todo_text[-1] not in splits): todo_text += "" if todo_text[-1] not in splits:
todo_text += ""
i_split_head = i_split_tail = 0 i_split_head = i_split_tail = 0
len_text = len(todo_text) len_text = len(todo_text)
todo_texts = [] todo_texts = []
while (1): while 1:
if (i_split_head >= len_text): break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入 if i_split_head >= len_text:
if (todo_text[i_split_head] in splits): break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
if todo_text[i_split_head] in splits:
i_split_head += 1 i_split_head += 1
todo_texts.append(todo_text[i_split_tail:i_split_head]) todo_texts.append(todo_text[i_split_tail:i_split_head])
i_split_tail = i_split_head i_split_tail = i_split_head
else: else:
i_split_head += 1 i_split_head += 1
return todo_texts return todo_texts
def cut1(inp): def cut1(inp):
inp=inp.strip("\n") inp = inp.strip("\n")
inps=split(inp) inps = split(inp)
split_idx=list(range(0,len(inps),5)) split_idx = list(range(0, len(inps), 5))
split_idx[-1]=None split_idx[-1] = None
if(len(split_idx)>1): if len(split_idx) > 1:
opts=[] opts = []
for idx in range(len(split_idx)-1): for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx]:split_idx[idx+1]])) opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
else: else:
opts=[inp] opts = [inp]
return "\n".join(opts) return "\n".join(opts)
def cut2(inp): def cut2(inp):
inp=inp.strip("\n") inp = inp.strip("\n")
inps=split(inp) inps = split(inp)
if(len(inps)<2):return [inp] if len(inps) < 2:
opts=[] return [inp]
summ=0 opts = []
tmp_str="" summ = 0
tmp_str = ""
for i in range(len(inps)): for i in range(len(inps)):
summ+=len(inps[i]) summ += len(inps[i])
tmp_str+=inps[i] tmp_str += inps[i]
if(summ>50): if summ > 50:
summ=0 summ = 0
opts.append(tmp_str) opts.append(tmp_str)
tmp_str="" tmp_str = ""
if(tmp_str!=""):opts.append(tmp_str) if tmp_str != "":
if(len(opts[-1])<50):##如果最后一个太短了,和前一个合一起 opts.append(tmp_str)
opts[-2]=opts[-2]+opts[-1] if len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
opts=opts[:-1] opts[-2] = opts[-2] + opts[-1]
opts = opts[:-1]
return "\n".join(opts) return "\n".join(opts)
def cut3(inp): def cut3(inp):
inp=inp.strip("\n") inp = inp.strip("\n")
return "\n".join(["%s"%item for item in inp.strip("").split("")]) return "\n".join(["%s" % item for item in inp.strip("").split("")])
with gr.Blocks(title="GPT-SoVITS WebUI") as app: with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown( gr.Markdown(
value= value="本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
) )
# with gr.Tabs(): # with gr.Tabs():
# with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")): # with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
with gr.Group(): with gr.Group():
gr.Markdown( gr.Markdown(value="*请上传并填写参考信息")
value=
"*请上传并填写参考信息"
)
with gr.Row(): with gr.Row():
inp_ref = gr.Audio(label="请上传参考音频", type="filepath") inp_ref = gr.Audio(label="请上传参考音频", type="filepath")
prompt_text= gr.Textbox(label="参考音频的文本",value="") prompt_text = gr.Textbox(label="参考音频的文本", value="")
prompt_language= gr.Dropdown(label="参考音频的语种",choices=["中文","英文","日文"],value="中文") prompt_language = gr.Dropdown(
gr.Markdown( label="参考音频的语种", choices=["中文", "英文", "日文"], value="中文"
value= )
"*请填写需要合成的目标文本" gr.Markdown(value="*请填写需要合成的目标文本")
)
with gr.Row(): with gr.Row():
text=gr.Textbox(label="需要合成的文本",value="") text = gr.Textbox(label="需要合成的文本", value="")
text_language = gr.Dropdown(label="需要合成的语种", choices=["中文", "英文", "日文"],value="中文") text_language = gr.Dropdown(
inference_button=gr.Button("合成语音", variant="primary") label="需要合成的语种", choices=["中文", "英文", "日文"], value="中文"
)
inference_button = gr.Button("合成语音", variant="primary")
output = gr.Audio(label="输出的语音") output = gr.Audio(label="输出的语音")
inference_button.click(get_tts_wav, [inp_ref, prompt_text,prompt_language, text,text_language], [output]) inference_button.click(
get_tts_wav,
gr.Markdown( [inp_ref, prompt_text, prompt_language, text, text_language],
value= [output],
"文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
) )
gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
with gr.Row(): with gr.Row():
text_inp=gr.Textbox(label="需要合成的切分前文本",value="") text_inp = gr.Textbox(label="需要合成的切分前文本", value="")
button1 = gr.Button("凑五句一切", variant="primary") button1 = gr.Button("凑五句一切", variant="primary")
button2 = gr.Button("凑50字一切", variant="primary") button2 = gr.Button("凑50字一切", variant="primary")
button3 = gr.Button("按中文句号。切", variant="primary") button3 = gr.Button("按中文句号。切", variant="primary")
text_opt = gr.Textbox(label="切分后文本", value="") text_opt = gr.Textbox(label="切分后文本", value="")
button1.click(cut1,[text_inp],[text_opt]) button1.click(cut1, [text_inp], [text_opt])
button2.click(cut2,[text_inp],[text_opt]) button2.click(cut2, [text_inp], [text_opt])
button3.click(cut3,[text_inp],[text_opt]) button3.click(cut3, [text_inp], [text_opt])
gr.Markdown( gr.Markdown(value="后续将支持混合语种编码文本输入。")
value=
"后续将支持混合语种编码文本输入。"
)
app.queue(concurrency_count=511, max_size=1022).launch( app.queue(concurrency_count=511, max_size=1022).launch(
server_name="0.0.0.0", server_name="0.0.0.0",
inbrowser=True, inbrowser=True,
server_port=infer_ttswebui, server_port=infer_ttswebui,
quiet=True, quiet=True,
) )

View File

@ -1,11 +1,12 @@
import os
import sys
import traceback import traceback
from collections import OrderedDict from collections import OrderedDict
import torch import torch
from tools.i18n.i18n import I18nAuto from tools.i18n.i18n import I18nAuto
i18n = I18nAuto() i18n = I18nAuto()
def savee(ckpt, name, epoch, steps, hps): def savee(ckpt, name, epoch, steps, hps):
try: try:
opt = OrderedDict() opt = OrderedDict()
@ -15,8 +16,8 @@ def savee(ckpt, name, epoch, steps, hps):
continue continue
opt["weight"][key] = ckpt[key].half() opt["weight"][key] = ckpt[key].half()
opt["config"] = hps opt["config"] = hps
opt["info"] = "%sepoch_%siteration" % (epoch,steps) opt["info"] = "%sepoch_%siteration" % (epoch, steps)
torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir,name)) torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
return "Success." return "Success."
except: except:
return traceback.format_exc() return traceback.format_exc()

View File

@ -2,56 +2,84 @@
import os import os
import pdb import pdb
if("_CUDA_VISIBLE_DEVICES"in os.environ): if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"] os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
import argparse import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import torch,platform import torch, platform
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger#WandbLogger from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
from pytorch_lightning.strategies import DDPStrategy from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module import Text2SemanticDataModule from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config from AR.utils.io import load_yaml_config
logging.getLogger('numba').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING) logging.getLogger("numba").setLevel(logging.WARNING)
torch.set_float32_matmul_precision('high') logging.getLogger("matplotlib").setLevel(logging.WARNING)
torch.set_float32_matmul_precision("high")
from AR.utils import get_newest_ckpt from AR.utils import get_newest_ckpt
from collections import OrderedDict from collections import OrderedDict
class my_model_ckpt(ModelCheckpoint): class my_model_ckpt(ModelCheckpoint):
def __init__(self,config,if_save_latest,if_save_every_weights,half_weights_save_dir,exp_name,**kwargs): def __init__(
self,
config,
if_save_latest,
if_save_every_weights,
half_weights_save_dir,
exp_name,
**kwargs
):
super().__init__(**kwargs) super().__init__(**kwargs)
self.if_save_latest=if_save_latest self.if_save_latest = if_save_latest
self.if_save_every_weights=if_save_every_weights self.if_save_every_weights = if_save_every_weights
self.half_weights_save_dir=half_weights_save_dir self.half_weights_save_dir = half_weights_save_dir
self.exp_name=exp_name self.exp_name = exp_name
self.config=config self.config = config
def on_train_epoch_end(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module):
if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer): if not self._should_skip_saving_checkpoint(
trainer
) and self._should_save_on_train_epoch_end(trainer):
monitor_candidates = self._monitor_candidates(trainer) monitor_candidates = self._monitor_candidates(trainer)
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: if (
if(self.if_save_latest==True):####如果设置只保存最后一个ckpt在保存下一个ckpt后要清理掉之前的所有ckpt self._every_n_epochs >= 1
to_clean=list(os.listdir(self.dirpath)) and (trainer.current_epoch + 1) % self._every_n_epochs == 0
):
if (
self.if_save_latest == True
): ####如果设置只保存最后一个ckpt在保存下一个ckpt后要清理掉之前的所有ckpt
to_clean = list(os.listdir(self.dirpath))
self._save_topk_checkpoint(trainer, monitor_candidates) self._save_topk_checkpoint(trainer, monitor_candidates)
if (self.if_save_latest == True): if self.if_save_latest == True:
for name in to_clean: for name in to_clean:
try: try:
os.remove("%s/%s"%(self.dirpath,name)) os.remove("%s/%s" % (self.dirpath, name))
except:pass except:
if(self.if_save_every_weights==True): pass
to_save_od=OrderedDict() if self.if_save_every_weights == True:
to_save_od["weight"]=OrderedDict() to_save_od = OrderedDict()
dictt=trainer.strategy._lightning_module.state_dict() to_save_od["weight"] = OrderedDict()
for key in dictt:to_save_od["weight"][key]=dictt[key].half() dictt = trainer.strategy._lightning_module.state_dict()
to_save_od["config"]=self.config for key in dictt:
to_save_od["info"]="GPT-e%s"%(trainer.current_epoch+1) to_save_od["weight"][key] = dictt[key].half()
torch.save(to_save_od,"%s/%s-e%s.ckpt"%(self.half_weights_save_dir,self.exp_name,trainer.current_epoch+1)) to_save_od["config"] = self.config
to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
torch.save(
to_save_od,
"%s/%s-e%s.ckpt"
% (
self.half_weights_save_dir,
self.exp_name,
trainer.current_epoch + 1,
),
)
self._save_last_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates)
@ -61,41 +89,45 @@ def main(args):
output_dir = Path(config["output_dir"]) output_dir = Path(config["output_dir"])
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
ckpt_dir = output_dir / 'ckpt' ckpt_dir = output_dir / "ckpt"
ckpt_dir.mkdir(parents=True, exist_ok=True) ckpt_dir.mkdir(parents=True, exist_ok=True)
seed_everything(config["train"]["seed"], workers=True) seed_everything(config["train"]["seed"], workers=True)
ckpt_callback: ModelCheckpoint = my_model_ckpt( ckpt_callback: ModelCheckpoint = my_model_ckpt(
config=config, config=config,
if_save_latest=config["train"]["if_save_latest"], if_save_every_weights=config["train"]["if_save_every_weights"], half_weights_save_dir=config["train"]["half_weights_save_dir"], exp_name=config["train"]["exp_name"], if_save_latest=config["train"]["if_save_latest"],
if_save_every_weights=config["train"]["if_save_every_weights"],
half_weights_save_dir=config["train"]["half_weights_save_dir"],
exp_name=config["train"]["exp_name"],
save_top_k=-1, save_top_k=-1,
monitor='top_3_acc', monitor="top_3_acc",
mode='max', mode="max",
save_on_train_epoch_end=True, save_on_train_epoch_end=True,
every_n_epochs=config["train"]["save_every_n_epoch"], every_n_epochs=config["train"]["save_every_n_epoch"],
dirpath=ckpt_dir, dirpath=ckpt_dir,
) )
logger = TensorBoardLogger( logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
name=output_dir.stem,
save_dir=output_dir
)
trainer: Trainer = Trainer( trainer: Trainer = Trainer(
max_epochs=config["train"]["epochs"], max_epochs=config["train"]["epochs"],
accelerator='gpu', accelerator="gpu",
# val_check_interval=9999999999999999999999,###不要验证 # val_check_interval=9999999999999999999999,###不要验证
# check_val_every_n_epoch=None, # check_val_every_n_epoch=None,
limit_val_batches=0, limit_val_batches=0,
devices=-1, devices=-1,
benchmark=False, benchmark=False,
fast_dev_run=False, fast_dev_run=False,
strategy=DDPStrategy(process_group_backend="nccl"if platform.system()!="Windows"else "gloo"), strategy=DDPStrategy(
process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
),
precision=config["train"]["precision"], precision=config["train"]["precision"],
logger=logger,num_sanity_val_steps=0, logger=logger,
callbacks=[ckpt_callback]) num_sanity_val_steps=0,
callbacks=[ckpt_callback],
)
model: Text2SemanticLightningModule = Text2SemanticLightningModule( model: Text2SemanticLightningModule = Text2SemanticLightningModule(
config, output_dir) config, output_dir
)
data_module: Text2SemanticDataModule = Text2SemanticDataModule( data_module: Text2SemanticDataModule = Text2SemanticDataModule(
config, config,
@ -116,14 +148,15 @@ def main(args):
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml # srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'-c', "-c",
'--config_file', "--config_file",
type=str, type=str,
default='configs/s1longer.yaml', default="configs/s1longer.yaml",
help='path of config file') help="path of config file",
)
# args for dataset # args for dataset
# parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv') # parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv')
# parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt') # parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt')

View File

@ -1,4 +1,5 @@
import utils,os import utils, os
hps = utils.get_hparams(stage=2) hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",") os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import torch import torch
@ -6,11 +7,12 @@ from torch.nn import functional as F
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.distributed as dist,traceback import torch.distributed as dist, traceback
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm from tqdm import tqdm
import logging,traceback import logging, traceback
logging.getLogger("matplotlib").setLevel(logging.INFO) logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO) logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO) logging.getLogger("numba").setLevel(logging.INFO)
@ -20,37 +22,42 @@ from module import commons
from module.data_utils import ( from module.data_utils import (
TextAudioSpeakerLoader, TextAudioSpeakerLoader,
TextAudioSpeakerCollate, TextAudioSpeakerCollate,
DistributedBucketSampler DistributedBucketSampler,
) )
from module.models import ( from module.models import (
SynthesizerTrn, SynthesizerTrn,
MultiPeriodDiscriminator, MultiPeriodDiscriminator,
) )
from module.losses import ( from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
generator_loss,
discriminator_loss,
feature_loss,
kl_loss
)
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from process_ckpt import savee from process_ckpt import savee
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False torch.backends.cudnn.deterministic = False
###反正A100fp32更快那试试tf32吧 ###反正A100fp32更快那试试tf32吧
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('medium')#最低精度但最快(也就快一丁点),对于结果造成不了影响 torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响
# from config import pretrained_s2G,pretrained_s2D # from config import pretrained_s2G,pretrained_s2D
global_step = 0 global_step = 0
def main(): def main():
"""Assume Single Node Multi GPUs Training Only""" """Assume Single Node Multi GPUs Training Only"""
assert torch.cuda.is_available(), "CPU training is not allowed." assert torch.cuda.is_available(), "CPU training is not allowed."
n_gpus = torch.cuda.device_count() n_gpus = torch.cuda.device_count()
os.environ['MASTER_ADDR'] = 'localhost' os.environ["MASTER_ADDR"] = "localhost"
os.environ['MASTER_PORT'] = str(randint(20000, 55555)) os.environ["MASTER_PORT"] = str(randint(20000, 55555))
mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,)) mp.spawn(
run,
nprocs=n_gpus,
args=(
n_gpus,
hps,
),
)
def run(rank, n_gpus, hps): def run(rank, n_gpus, hps):
@ -62,21 +69,54 @@ def run(rank, n_gpus, hps):
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir) writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
dist.init_process_group(backend='gloo' if os.name == 'nt' else 'nccl', init_method='env://', world_size=n_gpus,rank=rank) dist.init_process_group(
backend="gloo" if os.name == "nt" else "nccl",
init_method="env://",
world_size=n_gpus,
rank=rank,
)
torch.manual_seed(hps.train.seed) torch.manual_seed(hps.train.seed)
torch.cuda.set_device(rank) torch.cuda.set_device(rank)
train_dataset = TextAudioSpeakerLoader(hps.data)######## train_dataset = TextAudioSpeakerLoader(hps.data) ########
train_sampler = DistributedBucketSampler( train_sampler = DistributedBucketSampler(
train_dataset, train_dataset,
hps.train.batch_size, hps.train.batch_size,
[32, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900], [
32,
300,
400,
500,
600,
700,
800,
900,
1000,
1100,
1200,
1300,
1400,
1500,
1600,
1700,
1800,
1900,
],
num_replicas=n_gpus, num_replicas=n_gpus,
rank=rank, rank=rank,
shuffle=True) shuffle=True,
)
collate_fn = TextAudioSpeakerCollate() collate_fn = TextAudioSpeakerCollate()
train_loader = DataLoader(train_dataset, num_workers=6, shuffle=False, pin_memory=True, train_loader = DataLoader(
collate_fn=collate_fn, batch_sampler=train_sampler,persistent_workers=True,prefetch_factor=16) train_dataset,
num_workers=6,
shuffle=False,
pin_memory=True,
collate_fn=collate_fn,
batch_sampler=train_sampler,
persistent_workers=True,
prefetch_factor=16,
)
# if rank == 0: # if rank == 0:
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True) # eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
# eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False, # eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
@ -87,17 +127,21 @@ def run(rank, n_gpus, hps):
hps.data.filter_length // 2 + 1, hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length, hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers, n_speakers=hps.data.n_speakers,
**hps.model).cuda(rank) **hps.model,
).cuda(rank)
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
for name, param in net_g.named_parameters(): for name, param in net_g.named_parameters():
if not param.requires_grad: if not param.requires_grad:
print(name,"not requires_grad") print(name, "not requires_grad")
te_p = list(map(id, net_g.enc_p.text_embedding.parameters())) te_p = list(map(id, net_g.enc_p.text_embedding.parameters()))
et_p = list(map(id, net_g.enc_p.encoder_text.parameters())) et_p = list(map(id, net_g.enc_p.encoder_text.parameters()))
mrte_p = list(map(id, net_g.enc_p.mrte.parameters())) mrte_p = list(map(id, net_g.enc_p.mrte.parameters()))
base_params = filter(lambda p: id(p) not in te_p+et_p+mrte_p and p.requires_grad, net_g.parameters()) base_params = filter(
lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad,
net_g.parameters(),
)
# te_p=net_g.enc_p.text_embedding.parameters() # te_p=net_g.enc_p.text_embedding.parameters()
# et_p=net_g.enc_p.encoder_text.parameters() # et_p=net_g.enc_p.encoder_text.parameters()
@ -106,31 +150,46 @@ def run(rank, n_gpus, hps):
optim_g = torch.optim.AdamW( optim_g = torch.optim.AdamW(
# filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致 # filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致
[ [
{"params":base_params,"lr":hps.train.learning_rate}, {"params": base_params, "lr": hps.train.learning_rate},
{"params":net_g.enc_p.text_embedding.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate}, {
{"params":net_g.enc_p.encoder_text.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate}, "params": net_g.enc_p.text_embedding.parameters(),
{"params":net_g.enc_p.mrte.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate}, "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
},
{
"params": net_g.enc_p.encoder_text.parameters(),
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
},
{
"params": net_g.enc_p.mrte.parameters(),
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
},
], ],
hps.train.learning_rate, hps.train.learning_rate,
betas=hps.train.betas, betas=hps.train.betas,
eps=hps.train.eps) eps=hps.train.eps,
)
optim_d = torch.optim.AdamW( optim_d = torch.optim.AdamW(
net_d.parameters(), net_d.parameters(),
hps.train.learning_rate, hps.train.learning_rate,
betas=hps.train.betas, betas=hps.train.betas,
eps=hps.train.eps) eps=hps.train.eps,
net_g = DDP(net_g, device_ids=[rank],find_unused_parameters=True) )
net_d = DDP(net_d, device_ids=[rank],find_unused_parameters=True) net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
try: # 如果能加载自动resume try: # 如果能加载自动resume
_, _, _, epoch_str = utils.load_checkpoint( _, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2"%hps.data.exp_dir, "D_*.pth"), net_d, optim_d utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "D_*.pth"),
net_d,
optim_d,
) # D多半加载没事 ) # D多半加载没事
if rank == 0: if rank == 0:
logger.info("loaded D") logger.info("loaded D")
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0) # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint( _, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2"%hps.data.exp_dir, "G_*.pth"), net_g, optim_g utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "G_*.pth"),
net_g,
optim_g,
) )
global_step = (epoch_str - 1) * len(train_loader) global_step = (epoch_str - 1) * len(train_loader)
# epoch_str = 1 # epoch_str = 1
@ -144,7 +203,8 @@ def run(rank, n_gpus, hps):
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G) logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
print( print(
net_g.module.load_state_dict( net_g.module.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],strict=False torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
) )
) ##测试不加载优化器 ) ##测试不加载优化器
if hps.train.pretrained_s2D != "": if hps.train.pretrained_s2D != "":
@ -159,8 +219,12 @@ def run(rank, n_gpus, hps):
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1) scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=-1) optim_g, gamma=hps.train.lr_decay, last_epoch=-1
)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
optim_d, gamma=hps.train.lr_decay, last_epoch=-1
)
for _ in range(epoch_str): for _ in range(epoch_str):
scheduler_g.step() scheduler_g.step()
scheduler_d.step() scheduler_d.step()
@ -169,17 +233,39 @@ def run(rank, n_gpus, hps):
for epoch in range(epoch_str, hps.train.epochs + 1): for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0: if rank == 0:
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, train_and_evaluate(
# [train_loader, eval_loader], logger, [writer, writer_eval]) rank,
[train_loader, None], logger, [writer, writer_eval]) epoch,
hps,
[net_g, net_d],
[optim_g, optim_d],
[scheduler_g, scheduler_d],
scaler,
# [train_loader, eval_loader], logger, [writer, writer_eval])
[train_loader, None],
logger,
[writer, writer_eval],
)
else: else:
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, train_and_evaluate(
[train_loader, None], None, None) rank,
epoch,
hps,
[net_g, net_d],
[optim_g, optim_d],
[scheduler_g, scheduler_d],
scaler,
[train_loader, None],
None,
None,
)
scheduler_g.step() scheduler_g.step()
scheduler_d.step() scheduler_d.step()
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers): def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
net_g, net_d = nets net_g, net_d = nets
optim_g, optim_d = optims optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers # scheduler_g, scheduler_d = schedulers
@ -192,17 +278,39 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
net_g.train() net_g.train()
net_d.train() net_d.train()
for batch_idx, (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths) in tqdm(enumerate(train_loader)): for batch_idx, (
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True) ssl,
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True) ssl_lengths,
spec,
spec_lengths,
y,
y_lengths,
text,
text_lengths,
) in tqdm(enumerate(train_loader)):
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
rank, non_blocking=True
)
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
rank, non_blocking=True
)
ssl = ssl.cuda(rank, non_blocking=True) ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad=False ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(rank, non_blocking=True) text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
rank, non_blocking=True
)
with autocast(enabled=hps.train.fp16_run): with autocast(enabled=hps.train.fp16_run):
y_hat, kl_ssl, ids_slice, x_mask, z_mask, \ (
(z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl = net_g(ssl, spec, spec_lengths, text, text_lengths) y_hat,
kl_ssl,
ids_slice,
x_mask,
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
stats_ssl,
) = net_g(ssl, spec, spec_lengths, text, text_lengths)
mel = spec_to_mel_torch( mel = spec_to_mel_torch(
spec, spec,
@ -210,8 +318,11 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
hps.data.n_mel_channels, hps.data.n_mel_channels,
hps.data.sampling_rate, hps.data.sampling_rate,
hps.data.mel_fmin, hps.data.mel_fmin,
hps.data.mel_fmax) hps.data.mel_fmax,
y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) )
y_mel = commons.slice_segments(
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
)
y_hat_mel = mel_spectrogram_torch( y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1), y_hat.squeeze(1),
hps.data.filter_length, hps.data.filter_length,
@ -220,15 +331,19 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
hps.data.hop_length, hps.data.hop_length,
hps.data.win_length, hps.data.win_length,
hps.data.mel_fmin, hps.data.mel_fmin,
hps.data.mel_fmax hps.data.mel_fmax,
) )
y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice y = commons.slice_segments(
y, ids_slice * hps.data.hop_length, hps.train.segment_size
) # slice
# Discriminator # Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
with autocast(enabled=False): with autocast(enabled=False):
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
y_d_hat_r, y_d_hat_g
)
loss_disc_all = loss_disc loss_disc_all = loss_disc
optim_d.zero_grad() optim_d.zero_grad()
scaler.scale(loss_disc_all).backward() scaler.scale(loss_disc_all).backward()
@ -256,32 +371,54 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
if rank == 0: if rank == 0:
if global_step % hps.train.log_interval == 0: if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]['lr'] lr = optim_g.param_groups[0]["lr"]
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl] losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
logger.info('Train Epoch: {} [{:.0f}%]'.format( logger.info(
epoch, "Train Epoch: {} [{:.0f}%]".format(
100. * batch_idx / len(train_loader))) epoch, 100.0 * batch_idx / len(train_loader)
)
)
logger.info([x.item() for x in losses] + [global_step, lr]) logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, scalar_dict = {
"grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g} "loss/g/total": loss_gen_all,
"loss/d/total": loss_disc_all,
"learning_rate": lr,
"grad_norm_d": grad_norm_d,
"grad_norm_g": grad_norm_g,
}
scalar_dict.update( scalar_dict.update(
{"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl_ssl": kl_ssl, "loss/g/kl": loss_kl}) {
"loss/g/fm": loss_fm,
"loss/g/mel": loss_mel,
"loss/g/kl_ssl": kl_ssl,
"loss/g/kl": loss_kl,
}
)
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
image_dict = { image_dict = {
"slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), "slice/mel_org": utils.plot_spectrogram_to_numpy(
"slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), y_mel[0].data.cpu().numpy()
"all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), ),
"all/stats_ssl": utils.plot_spectrogram_to_numpy(stats_ssl[0].data.cpu().numpy()), "slice/mel_gen": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].data.cpu().numpy()
),
"all/mel": utils.plot_spectrogram_to_numpy(
mel[0].data.cpu().numpy()
),
"all/stats_ssl": utils.plot_spectrogram_to_numpy(
stats_ssl[0].data.cpu().numpy()
),
} }
utils.summarize( utils.summarize(
writer=writer, writer=writer,
global_step=global_step, global_step=global_step,
images=image_dict, images=image_dict,
scalars=scalar_dict) scalars=scalar_dict,
)
global_step += 1 global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0: if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if hps.train.if_save_latest == 0: if hps.train.if_save_latest == 0:
@ -290,14 +427,18 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
optim_g, optim_g,
hps.train.learning_rate, hps.train.learning_rate,
epoch, epoch,
os.path.join("%s/logs_s2"%hps.data.exp_dir, "G_{}.pth".format(global_step)), os.path.join(
"%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(global_step)
),
) )
utils.save_checkpoint( utils.save_checkpoint(
net_d, net_d,
optim_d, optim_d,
hps.train.learning_rate, hps.train.learning_rate,
epoch, epoch,
os.path.join("%s/logs_s2"%hps.data.exp_dir, "D_{}.pth".format(global_step)), os.path.join(
"%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(global_step)
),
) )
else: else:
utils.save_checkpoint( utils.save_checkpoint(
@ -305,14 +446,18 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
optim_g, optim_g,
hps.train.learning_rate, hps.train.learning_rate,
epoch, epoch,
os.path.join("%s/logs_s2"%hps.data.exp_dir, "G_{}.pth".format(233333333333)), os.path.join(
"%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(233333333333)
),
) )
utils.save_checkpoint( utils.save_checkpoint(
net_d, net_d,
optim_d, optim_d,
hps.train.learning_rate, hps.train.learning_rate,
epoch, epoch,
os.path.join("%s/logs_s2"%hps.data.exp_dir, "D_{}.pth".format(233333333333)), os.path.join(
"%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(233333333333)
),
) )
if rank == 0 and hps.train.if_save_every_weights == True: if rank == 0 and hps.train.if_save_every_weights == True:
if hasattr(net_g, "module"): if hasattr(net_g, "module"):
@ -334,11 +479,8 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
) )
) )
if rank == 0: if rank == 0:
logger.info('====> Epoch: {}'.format(epoch)) logger.info("====> Epoch: {}".format(epoch))
def evaluate(hps, generator, eval_loader, writer_eval): def evaluate(hps, generator, eval_loader, writer_eval):
@ -347,15 +489,25 @@ def evaluate(hps, generator, eval_loader, writer_eval):
audio_dict = {} audio_dict = {}
print("Evaluating ...") print("Evaluating ...")
with torch.no_grad(): with torch.no_grad():
for batch_idx, (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths) in enumerate(eval_loader): for batch_idx, (
ssl,
ssl_lengths,
spec,
spec_lengths,
y,
y_lengths,
text,
text_lengths,
) in enumerate(eval_loader):
print(111) print(111)
spec, spec_lengths = spec.cuda(), spec_lengths.cuda() spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
y, y_lengths = y.cuda(), y_lengths.cuda() y, y_lengths = y.cuda(), y_lengths.cuda()
ssl = ssl.cuda() ssl = ssl.cuda()
text, text_lengths = text.cuda(), text_lengths.cuda() text, text_lengths = text.cuda(), text_lengths.cuda()
for test in [0, 1]: for test in [0, 1]:
y_hat, mask, *_ = generator.module.infer(
y_hat, mask, *_ = generator.module.infer(ssl,spec, spec_lengths,text, text_lengths, test=test) ssl, spec, spec_lengths, text, text_lengths, test=test
)
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
mel = spec_to_mel_torch( mel = spec_to_mel_torch(
@ -364,7 +516,8 @@ def evaluate(hps, generator, eval_loader, writer_eval):
hps.data.n_mel_channels, hps.data.n_mel_channels,
hps.data.sampling_rate, hps.data.sampling_rate,
hps.data.mel_fmin, hps.data.mel_fmin,
hps.data.mel_fmax) hps.data.mel_fmax,
)
y_hat_mel = mel_spectrogram_torch( y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1).float(), y_hat.squeeze(1).float(),
hps.data.filter_length, hps.data.filter_length,
@ -373,16 +526,26 @@ def evaluate(hps, generator, eval_loader, writer_eval):
hps.data.hop_length, hps.data.hop_length,
hps.data.win_length, hps.data.win_length,
hps.data.mel_fmin, hps.data.mel_fmin,
hps.data.mel_fmax hps.data.mel_fmax,
) )
image_dict.update({ image_dict.update(
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()) {
}) f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
audio_dict.update({ y_hat_mel[0].cpu().numpy()
f"gen/audio_{batch_idx}_{test}": y_hat[0, :, :y_hat_lengths[0]] )
}) }
image_dict.update({f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())}) )
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, :y_lengths[0]]}) audio_dict.update(
{f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
)
image_dict.update(
{
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
mel[0].cpu().numpy()
)
}
)
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
# y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None) # y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None)
# audio_dict.update({ # audio_dict.update({
@ -394,9 +557,10 @@ def evaluate(hps, generator, eval_loader, writer_eval):
global_step=global_step, global_step=global_step,
images=image_dict, images=image_dict,
audios=audio_dict, audios=audio_dict,
audio_sampling_rate=hps.data.sampling_rate audio_sampling_rate=hps.data.sampling_rate,
) )
generator.train() generator.train()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -12,8 +12,9 @@ import numpy as np
from scipy.io.wavfile import read from scipy.io.wavfile import read
import torch import torch
import logging import logging
logging.getLogger('numba').setLevel(logging.ERROR)
logging.getLogger('matplotlib').setLevel(logging.ERROR) logging.getLogger("numba").setLevel(logging.ERROR)
logging.getLogger("matplotlib").setLevel(logging.ERROR)
MATPLOTLIB_FLAG = False MATPLOTLIB_FLAG = False
@ -23,13 +24,17 @@ logger = logging
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False): def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
assert os.path.isfile(checkpoint_path) assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
iteration = checkpoint_dict['iteration'] iteration = checkpoint_dict["iteration"]
learning_rate = checkpoint_dict['learning_rate'] learning_rate = checkpoint_dict["learning_rate"]
if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None: if (
optimizer.load_state_dict(checkpoint_dict['optimizer']) optimizer is not None
saved_state_dict = checkpoint_dict['model'] and not skip_optimizer
if hasattr(model, 'module'): and checkpoint_dict["optimizer"] is not None
):
optimizer.load_state_dict(checkpoint_dict["optimizer"])
saved_state_dict = checkpoint_dict["model"]
if hasattr(model, "module"):
state_dict = model.module.state_dict() state_dict = model.module.state_dict()
else: else:
state_dict = model.state_dict() state_dict = model.state_dict()
@ -39,41 +44,63 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
# assert "quantizer" not in k # assert "quantizer" not in k
# print("load", k) # print("load", k)
new_state_dict[k] = saved_state_dict[k] new_state_dict[k] = saved_state_dict[k]
assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape) assert saved_state_dict[k].shape == v.shape, (
saved_state_dict[k].shape,
v.shape,
)
except: except:
traceback.print_exc() traceback.print_exc()
print("error, %s is not in the checkpoint" % k)#shape不对也会比如text_embedding当cleaner修改时 print(
"error, %s is not in the checkpoint" % k
) # shape不对也会比如text_embedding当cleaner修改时
new_state_dict[k] = v new_state_dict[k] = v
if hasattr(model, 'module'): if hasattr(model, "module"):
model.module.load_state_dict(new_state_dict) model.module.load_state_dict(new_state_dict)
else: else:
model.load_state_dict(new_state_dict) model.load_state_dict(new_state_dict)
print("load ") print("load ")
logger.info("Loaded checkpoint '{}' (iteration {})".format( logger.info(
checkpoint_path, iteration)) "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
)
return model, optimizer, learning_rate, iteration return model, optimizer, learning_rate, iteration
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
logger.info("Saving model and optimizer state at iteration {} to {}".format( logger.info(
iteration, checkpoint_path)) "Saving model and optimizer state at iteration {} to {}".format(
if hasattr(model, 'module'): iteration, checkpoint_path
)
)
if hasattr(model, "module"):
state_dict = model.module.state_dict() state_dict = model.module.state_dict()
else: else:
state_dict = model.state_dict() state_dict = model.state_dict()
torch.save({'model': state_dict, torch.save(
'iteration': iteration, {
'optimizer': optimizer.state_dict(), "model": state_dict,
'learning_rate': learning_rate}, checkpoint_path) "iteration": iteration,
"optimizer": optimizer.state_dict(),
"learning_rate": learning_rate,
},
checkpoint_path,
)
def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): def summarize(
writer,
global_step,
scalars={},
histograms={},
images={},
audios={},
audio_sampling_rate=22050,
):
for k, v in scalars.items(): for k, v in scalars.items():
writer.add_scalar(k, v, global_step) writer.add_scalar(k, v, global_step)
for k, v in histograms.items(): for k, v in histograms.items():
writer.add_histogram(k, v, global_step) writer.add_histogram(k, v, global_step)
for k, v in images.items(): for k, v in images.items():
writer.add_image(k, v, global_step, dataformats='HWC') writer.add_image(k, v, global_step, dataformats="HWC")
for k, v in audios.items(): for k, v in audios.items():
writer.add_audio(k, v, global_step, audio_sampling_rate) writer.add_audio(k, v, global_step, audio_sampling_rate)
@ -90,23 +117,23 @@ def plot_spectrogram_to_numpy(spectrogram):
global MATPLOTLIB_FLAG global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG: if not MATPLOTLIB_FLAG:
import matplotlib import matplotlib
matplotlib.use("Agg") matplotlib.use("Agg")
MATPLOTLIB_FLAG = True MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger('matplotlib') mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING) mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt import matplotlib.pylab as plt
import numpy as np import numpy as np
fig, ax = plt.subplots(figsize=(10, 2)) fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
interpolation='none')
plt.colorbar(im, ax=ax) plt.colorbar(im, ax=ax)
plt.xlabel("Frames") plt.xlabel("Frames")
plt.ylabel("Channels") plt.ylabel("Channels")
plt.tight_layout() plt.tight_layout()
fig.canvas.draw() fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close() plt.close()
return data return data
@ -116,26 +143,28 @@ def plot_alignment_to_numpy(alignment, info=None):
global MATPLOTLIB_FLAG global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG: if not MATPLOTLIB_FLAG:
import matplotlib import matplotlib
matplotlib.use("Agg") matplotlib.use("Agg")
MATPLOTLIB_FLAG = True MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger('matplotlib') mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING) mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt import matplotlib.pylab as plt
import numpy as np import numpy as np
fig, ax = plt.subplots(figsize=(6, 4)) fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', im = ax.imshow(
interpolation='none') alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
)
fig.colorbar(im, ax=ax) fig.colorbar(im, ax=ax)
xlabel = 'Decoder timestep' xlabel = "Decoder timestep"
if info is not None: if info is not None:
xlabel += '\n\n' + info xlabel += "\n\n" + info
plt.xlabel(xlabel) plt.xlabel(xlabel)
plt.ylabel('Encoder timestep') plt.ylabel("Encoder timestep")
plt.tight_layout() plt.tight_layout()
fig.canvas.draw() fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close() plt.close()
return data return data
@ -147,16 +176,31 @@ def load_wav_to_torch(full_path):
def load_filepaths_and_text(filename, split="|"): def load_filepaths_and_text(filename, split="|"):
with open(filename, encoding='utf-8') as f: with open(filename, encoding="utf-8") as f:
filepaths_and_text = [line.strip().split(split) for line in f] filepaths_and_text = [line.strip().split(split) for line in f]
return filepaths_and_text return filepaths_and_text
def get_hparams(init=True, stage=1): def get_hparams(init=True, stage=1):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default="./configs/s2.json",help='JSON file for configuration') parser.add_argument(
parser.add_argument('-p', '--pretrain', type=str, required=False,default=None,help='pretrain dir') "-c",
parser.add_argument('-rs', '--resume_step', type=int, required=False,default=None,help='resume step') "--config",
type=str,
default="./configs/s2.json",
help="JSON file for configuration",
)
parser.add_argument(
"-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir"
)
parser.add_argument(
"-rs",
"--resume_step",
type=int,
required=False,
default=None,
help="resume step",
)
# parser.add_argument('-e', '--exp_dir', type=str, required=False,default=None,help='experiment directory') # parser.add_argument('-e', '--exp_dir', type=str, required=False,default=None,help='experiment directory')
# parser.add_argument('-g', '--pretrained_s2G', type=str, required=False,default=None,help='pretrained sovits gererator weights') # parser.add_argument('-g', '--pretrained_s2G', type=str, required=False,default=None,help='pretrained sovits gererator weights')
# parser.add_argument('-d', '--pretrained_s2D', type=str, required=False,default=None,help='pretrained sovits discriminator weights') # parser.add_argument('-d', '--pretrained_s2D', type=str, required=False,default=None,help='pretrained sovits discriminator weights')
@ -172,7 +216,7 @@ def get_hparams(init=True, stage=1):
hparams.pretrain = args.pretrain hparams.pretrain = args.pretrain
hparams.resume_step = args.resume_step hparams.resume_step = args.resume_step
# hparams.data.exp_dir = args.exp_dir # hparams.data.exp_dir = args.exp_dir
if stage ==1: if stage == 1:
model_dir = hparams.s1_ckpt_dir model_dir = hparams.s1_ckpt_dir
else: else:
model_dir = hparams.s2_ckpt_dir model_dir = hparams.s2_ckpt_dir
@ -186,29 +230,38 @@ def get_hparams(init=True, stage=1):
return hparams return hparams
def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
    """Free up disk space by deleting old saved checkpoints.

    Keeps the newest ``n_ckpts_to_keep`` generator (``G*``) and discriminator
    (``D*``) checkpoints in ``path_to_models``; ``G_0.pth`` and ``D_0.pth``
    (pretrained bases) are never deleted.

    Arguments:
        path_to_models -- Path to the model directory
        n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
        sort_by_time -- True -> chronologically delete ckpts
                        False -> lexicographically delete ckpts
    """
    import re

    ckpts_files = [
        f
        for f in os.listdir(path_to_models)
        if os.path.isfile(os.path.join(path_to_models, f))
    ]
    # Raw string: "\d" in a non-raw literal is an invalid escape sequence
    # (DeprecationWarning today, a SyntaxError in future Python versions).
    name_pattern = re.compile(r"._(\d+)\.pth")

    def name_key(_f):
        # Sort by the numeric step/epoch embedded in the filename.
        return int(name_pattern.match(_f).group(1))

    def time_key(_f):
        return os.path.getmtime(os.path.join(path_to_models, _f))

    sort_key = time_key if sort_by_time else name_key

    def x_sorted(prefix):
        # All checkpoints with this prefix, oldest first; *_0.pth are protected.
        return sorted(
            [f for f in ckpts_files if f.startswith(prefix) and not f.endswith("_0.pth")],
            key=sort_key,
        )

    to_del = [
        os.path.join(path_to_models, fn)
        for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
    ]
    # Plain loop instead of a side-effect-only comprehension whose result was unused.
    for fn in to_del:
        os.remove(fn)
        logger.info(f".. Free up space by deleting ckpt {fn}")
def get_hparams_from_dir(model_dir): def get_hparams_from_dir(model_dir):
config_save_path = os.path.join(model_dir, "config.json") config_save_path = os.path.join(model_dir, "config.json")
with open(config_save_path, "r") as f: with open(config_save_path, "r") as f:
@ -228,12 +281,15 @@ def get_hparams_from_file(config_path):
hparams = HParams(**config) hparams = HParams(**config)
return hparams return hparams
def check_git_hash(model_dir): def check_git_hash(model_dir):
source_dir = os.path.dirname(os.path.realpath(__file__)) source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")): if not os.path.exists(os.path.join(source_dir, ".git")):
logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( logger.warn(
source_dir "{} is not a git repository, therefore hash value comparison will be ignored.".format(
)) source_dir
)
)
return return
cur_hash = subprocess.getoutput("git rev-parse HEAD") cur_hash = subprocess.getoutput("git rev-parse HEAD")
@ -242,8 +298,11 @@ def check_git_hash(model_dir):
if os.path.exists(path): if os.path.exists(path):
saved_hash = open(path).read() saved_hash = open(path).read()
if saved_hash != cur_hash: if saved_hash != cur_hash:
logger.warn("git hash values are different. {}(saved) != {}(current)".format( logger.warn(
saved_hash[:8], cur_hash[:8])) "git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8], cur_hash[:8]
)
)
else: else:
open(path, "w").write(cur_hash) open(path, "w").write(cur_hash)
@ -263,7 +322,7 @@ def get_logger(model_dir, filename="train.log"):
return logger return logger
class HParams(): class HParams:
def __init__(self, **kwargs): def __init__(self, **kwargs):
for k, v in kwargs.items(): for k, v in kwargs.items():
if type(v) == dict: if type(v) == dict:
@ -294,5 +353,10 @@ class HParams():
def __repr__(self): def __repr__(self):
return self.__dict__.__repr__() return self.__dict__.__repr__()
if __name__ == "__main__":
    # Manual smoke test: load one sample audio file and print the result.
    sample_path = (
        "/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac"
    )
    print(load_wav_to_torch(sample_path))