mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 19:41:56 +08:00
Code refactor + remove unused imports
This commit is contained in:
parent
9031ac9a92
commit
0d92575115
@ -1,49 +1,55 @@
|
||||
import os
|
||||
# Runtime configuration. Every value can be overridden through an environment
# variable of the same name; the literals below are the defaults.
gpt_path = os.environ.get(
    "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
)
sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
cnhubert_base_path = os.environ.get(
    "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
)
bert_path = os.environ.get(
    "bert_path", "pretrained_models/chinese-roberta-wwm-ext-large"
)
# The environment delivers strings, so the value is always coerced to int.
infer_ttswebui = int(os.environ.get("infer_ttswebui", 9872))
# Copy the private `_CUDA_VISIBLE_DEVICES` value, when present, onto the
# variable that CUDA itself reads.
if "_CUDA_VISIBLE_DEVICES" in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
# NOTE(review): eval() executes whatever the `is_half` environment variable
# contains. Kept for compatibility with existing launchers; consider a strict
# "True"/"False" parse if the environment is not fully trusted.
is_half = eval(os.environ.get("is_half", "True"))
|
||||
import gradio as gr
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import sys,torch,numpy as np
|
||||
from pathlib import Path
|
||||
import os,pdb,utils,librosa,math,traceback,requests,argparse,torch,multiprocessing,pandas as pd,torch.multiprocessing as mp,soundfile
|
||||
import torch, numpy as np
|
||||
import os, librosa, torch
|
||||
|
||||
# torch.backends.cuda.sdp_kernel("flash")
|
||||
# torch.backends.cuda.enable_flash_sdp(True)
|
||||
# torch.backends.cuda.enable_mem_efficient_sdp(True)  # Not available if torch version is lower than 2.0
|
||||
# torch.backends.cuda.enable_math_sdp(True)
|
||||
from random import shuffle
|
||||
from AR.utils import get_newest_ckpt
|
||||
from glob import glob
|
||||
from tqdm import tqdm
|
||||
from feature_extractor import cnhubert
|
||||
cnhubert.cnhubert_base_path=cnhubert_base_path
|
||||
from io import BytesIO
|
||||
|
||||
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||
from module.models import SynthesizerTrn
|
||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from AR.utils.io import load_yaml_config
|
||||
from text import cleaned_text_to_sequence
|
||||
from text.cleaner import text_to_sequence, clean_text
|
||||
from text.cleaner import clean_text
|
||||
from time import time as ttime
|
||||
from module.mel_processing import spectrogram_torch
|
||||
from my_utils import load_audio
|
||||
|
||||
# Tokenizer and masked-LM BERT used to extract text features. The model is
# optionally cast to half precision before being moved to the GPU.
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half:
    bert_model = bert_model.half()
bert_model = bert_model.to(device)
|
||||
|
||||
|
||||
# bert_model=bert_model.to(device)
|
||||
def get_bert_feature(text, word2ph):
|
||||
with torch.no_grad():
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(device)#####输入是long不用管精度问题,精度随bert_model
|
||||
inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model
|
||||
res = bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
assert len(word2ph) == len(text)
|
||||
@ -55,9 +61,12 @@ def get_bert_feature(text, word2ph):
|
||||
# if(is_half==True):phone_level_feature=phone_level_feature.half()
|
||||
return phone_level_feature.T
|
||||
|
||||
|
||||
n_semantic = 1024
# The SoVITS checkpoint is deserialized onto the CPU; its "config" entry holds
# the hyper-parameters and its "weight" entry is loaded into the model later.
dict_s2 = torch.load(sovits_path, map_location="cpu")
hps = dict_s2["config"]
|
||||
|
||||
|
||||
class DictToAttrRecursive:
|
||||
def __init__(self, input_dict):
|
||||
for key, value in input_dict.items():
|
||||
@ -67,206 +76,271 @@ class DictToAttrRecursive:
|
||||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
# Expose the hyper-parameter dict through attribute access.
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"

# GPT (text-to-semantic) checkpoint, deserialized onto the CPU.
dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"]

# SSL feature extractor, applied to the reference audio in get_tts_wav.
ssl_model = cnhubert.get_model()
if is_half:
    ssl_model = ssl_model.half()
ssl_model = ssl_model.to(device)

# SoVITS synthesizer.
vq_model = SynthesizerTrn(
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model
)
if is_half:
    vq_model = vq_model.half()
vq_model = vq_model.to(device)
vq_model.eval()
# strict=False tolerates key mismatches; the printed result lists them.
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))

hz = 50
max_sec = config["data"]["max_sec"]

# TODO: load through Text2SemanticLightningModule.load_from_checkpoint instead.
t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
if is_half:
    t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device)
t2s_model.eval()
total = sum(param.nelement() for param in t2s_model.parameters())
print("Number of parameter: %.2fM" % (total / 1e6))
|
||||
|
||||
|
||||
def get_spepc(hps, filename):
    """Load an audio file and return its spectrogram.

    Args:
        hps: hyper-parameter object; `hps.data` supplies `sampling_rate`,
            `filter_length`, `hop_length` and `win_length`.
        filename: path of the audio file to load.

    Returns:
        Whatever `spectrogram_torch` produces for the loaded audio.
    """
    audio = torch.FloatTensor(load_audio(filename, int(hps.data.sampling_rate)))
    # A leading axis of size 1 is added before the spectrogram is computed.
    audio_norm = audio.unsqueeze(0)
    return spectrogram_torch(
        audio_norm,
        hps.data.filter_length,
        hps.data.sampling_rate,
        hps.data.hop_length,
        hps.data.win_length,
        center=False,
    )
|
||||
|
||||
# Maps the language names shown in the UI to the codes used by the text cleaner.
dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}


def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
    """Synthesize `text` in the voice of the reference audio.

    `text` is split on newlines; each line is synthesized separately and the
    pieces are concatenated, each followed by 0.3 s of silence.

    Args:
        ref_wav_path: path of the reference audio file.
        prompt_text: transcript of the reference audio.
        prompt_language: UI language name of the transcript (key of
            `dict_language`).
        text: text to synthesize, one segment per line.
        text_language: UI language name of `text` (key of `dict_language`).

    Yields:
        A single `(sampling_rate, waveform)` tuple, the waveform being an
        int16 numpy array.

    Raises:
        KeyError: if a language name is not a key of `dict_language`.
    """
    t0 = ttime()
    prompt_text = prompt_text.strip("\n")
    text = text.strip("\n")
    float_dtype = torch.float16 if is_half else torch.float32

    # Encode the reference audio into the semantic tokens used as the prompt.
    with torch.no_grad():
        wav16k, _ = librosa.load(ref_wav_path, sr=16000)
        wav16k = torch.from_numpy(wav16k)
        if is_half:
            wav16k = wav16k.half()
        wav16k = wav16k.to(device)
        ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
            "last_hidden_state"
        ].transpose(1, 2)
        codes = vq_model.extract_latent(ssl_content)
        prompt_semantic = codes[0, 0]
    t1 = ttime()

    prompt_language = dict_language[prompt_language]
    text_language = dict_language[text_language]
    phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
    phones1 = cleaned_text_to_sequence(phones1)

    # Everything below depends only on the prompt / reference audio, so it is
    # computed once instead of once per text segment.
    if prompt_language == "zh":
        bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
    else:
        # BERT features are only extracted for Chinese; other languages get zeros.
        bert1 = torch.zeros((1024, len(phones1)), dtype=float_dtype).to(device)
    prompt = prompt_semantic.unsqueeze(0).to(device)
    refer = get_spepc(hps, ref_wav_path)
    if is_half:
        refer = refer.half()
    refer = refer.to(device)

    audio_opt = []
    # 0.3 s of silence appended after every synthesized segment.
    zero_wav = np.zeros(
        int(hps.data.sampling_rate * 0.3),
        dtype=np.float16 if is_half else np.float32,
    )
    for segment in text.split("\n"):
        phones2, word2ph2, norm_text2 = clean_text(segment, text_language)
        phones2 = cleaned_text_to_sequence(phones2)
        if text_language == "zh":
            bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
        else:
            # .to(bert1) matches both the device and the dtype of bert1.
            bert2 = torch.zeros((1024, len(phones2))).to(bert1)
        bert = torch.cat([bert1, bert2], 1).to(device).unsqueeze(0)

        all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
        t2 = ttime()
        with torch.no_grad():
            pred_semantic, idx = t2s_model.model.infer_panel(
                all_phoneme_ids,
                all_phoneme_len,
                prompt,
                bert,
                # prompt_phone_len=ph_offset,
                top_k=config["inference"]["top_k"],
                early_stop_num=hz * max_sec,
            )
        t3 = ttime()
        # Keep only the last `idx` tokens; "mq" needs one extra unsqueeze.
        pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
        # Decode with the target phonemes only (reconstruction without the
        # prompt part).
        audio = (
            vq_model.decode(
                pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
            )
            .detach()
            .cpu()
            .numpy()[0, 0]
        )
        audio_opt.append(audio)
        audio_opt.append(zero_wav)
        t4 = ttime()
    # str.split always returns at least one element, so t2..t4 are defined here.
    print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
    yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
        np.int16
    )
|
||||
|
||||
|
||||
# Punctuation marks that end a segment. A doubled ellipsis/dash is normalized
# to a single mark inside split() before this set is consulted.
splits = {
    ",",
    "。",
    "?",
    "!",
    ",",
    ".",
    "?",
    "!",
    "~",
    ":",
    ":",
    "—",
    "…",
}


def split(todo_text):
    """Cut `todo_text` into segments that each end with a mark from `splits`.

    "……" is first replaced by "。" and "——" by ",". If the text does not end
    with a punctuation mark, "。" is appended so the last segment is closed.

    Args:
        todo_text: the text to cut.

    Returns:
        The list of segments, in order; an empty list for empty input.
    """
    todo_text = todo_text.replace("……", "。").replace("——", ",")
    if not todo_text:
        # Indexing todo_text[-1] below would raise IndexError on "".
        return []
    if todo_text[-1] not in splits:
        todo_text += "。"
    todo_texts = []
    segment_start = 0
    for segment_end, char in enumerate(todo_text, 1):
        if char in splits:
            todo_texts.append(todo_text[segment_start:segment_end])
            segment_start = segment_end
    # The text is guaranteed to end with punctuation, so nothing is left over.
    return todo_texts
|
||||
|
||||
|
||||
def cut1(inp):
    """Regroup the text so that each output line holds five segments.

    The text is cut with `split`; boundaries are placed every five segments.
    The last boundary is replaced by the end of the text, so the final line
    also absorbs whatever follows it. With five segments or fewer the input
    is returned unchanged.

    Args:
        inp: the text to regroup; leading/trailing newlines are ignored.

    Returns:
        The groups joined with "\n"; an empty string for empty input.
    """
    inp = inp.strip("\n")
    if not inp:
        # Nothing to cut, and the index arithmetic below needs >= 1 segment.
        return ""
    inps = split(inp)
    split_idx = list(range(0, len(inps), 5))
    # Overwrite (not append): the final group runs to the end of the text.
    split_idx[-1] = None
    if len(split_idx) > 1:
        opts = [
            "".join(inps[start:stop])
            for start, stop in zip(split_idx, split_idx[1:])
        ]
    else:
        opts = [inp]
    return "\n".join(opts)
|
||||
|
||||
|
||||
def cut2(inp):
    """Regroup the text into lines of roughly 50 characters.

    Segments produced by `split` are accumulated until the running length
    exceeds 50 characters, at which point a line is emitted. A trailing line
    shorter than 50 characters is merged into the previous one.

    Args:
        inp: the text to regroup; leading/trailing newlines are ignored.

    Returns:
        The lines joined with "\n". Text with fewer than two segments is
        returned unchanged, as a string.
    """
    inp = inp.strip("\n")
    if not inp:
        return ""
    inps = split(inp)
    if len(inps) < 2:
        # Return a string here too; the result feeds a text box like the
        # other cut functions' results.
        return inp
    opts = []
    summ = 0
    tmp_str = ""
    for piece in inps:
        summ += len(piece)
        tmp_str += piece
        if summ > 50:
            summ = 0
            opts.append(tmp_str)
            tmp_str = ""
    if tmp_str:
        opts.append(tmp_str)
    # If the last line is too short, merge it into the previous one. This is
    # only possible when a previous line exists.
    if len(opts) > 1 and len(opts[-1]) < 50:
        opts[-2] += opts[-1]
        opts.pop()
    return "\n".join(opts)
|
||||
|
||||
|
||||
def cut3(inp):
    """Put every sentence ending in a Chinese full stop on its own line.

    Args:
        inp: the text to cut; leading/trailing newlines are ignored.

    Returns:
        The sentences joined with "\n", each terminated by "。"; an empty
        string for empty input.
    """
    inp = inp.strip("\n")
    if not inp:
        # Without this guard empty input would come back as a lone "。".
        return ""
    return "\n".join("%s。" % item for item in inp.strip("。").split("。"))
|
||||
|
||||
|
||||
# Gradio front end: one section for synthesis, one for the text-cutting helpers.
# NOTE(review): the container nesting below is reconstructed — confirm the
# Row/Group layout against the rendered page.
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
    gr.Markdown(
        value="本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
    )
    with gr.Group():
        # Reference audio and its transcript.
        gr.Markdown(value="*请上传并填写参考信息")
        with gr.Row():
            inp_ref = gr.Audio(label="请上传参考音频", type="filepath")
            prompt_text = gr.Textbox(label="参考音频的文本", value="")
            prompt_language = gr.Dropdown(
                label="参考音频的语种", choices=["中文", "英文", "日文"], value="中文"
            )
        # Target text and the synthesis trigger.
        gr.Markdown(value="*请填写需要合成的目标文本")
        with gr.Row():
            text = gr.Textbox(label="需要合成的文本", value="")
            text_language = gr.Dropdown(
                label="需要合成的语种", choices=["中文", "英文", "日文"], value="中文"
            )
            inference_button = gr.Button("合成语音", variant="primary")
            output = gr.Audio(label="输出的语音")
        inference_button.click(
            get_tts_wav,
            [inp_ref, prompt_text, prompt_language, text, text_language],
            [output],
        )

        # Text-cutting helpers; each button rewrites text_inp into text_opt.
        gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
        with gr.Row():
            text_inp = gr.Textbox(label="需要合成的切分前文本", value="")
            button1 = gr.Button("凑五句一切", variant="primary")
            button2 = gr.Button("凑50字一切", variant="primary")
            button3 = gr.Button("按中文句号。切", variant="primary")
            text_opt = gr.Textbox(label="切分后文本", value="")
            button1.click(cut1, [text_inp], [text_opt])
            button2.click(cut2, [text_inp], [text_opt])
            button3.click(cut3, [text_inp], [text_opt])
        gr.Markdown(value="后续将支持混合语种编码文本输入。")

# Listens on all interfaces, on the port configured through `infer_ttswebui`.
app.queue(concurrency_count=511, max_size=1022).launch(
    server_name="0.0.0.0",
    inbrowser=True,
    server_port=infer_ttswebui,
    quiet=True,
)
|
||||
|
@ -1,11 +1,12 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
|
||||
i18n = I18nAuto()
|
||||
|
||||
|
||||
def savee(ckpt, name, epoch, steps, hps):
|
||||
try:
|
||||
opt = OrderedDict()
|
||||
@ -15,8 +16,8 @@ def savee(ckpt, name, epoch, steps, hps):
|
||||
continue
|
||||
opt["weight"][key] = ckpt[key].half()
|
||||
opt["config"] = hps
|
||||
opt["info"] = "%sepoch_%siteration" % (epoch,steps)
|
||||
torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir,name))
|
||||
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
|
||||
torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
|
||||
return "Success."
|
||||
except:
|
||||
return traceback.format_exc()
|
||||
|
@ -2,56 +2,84 @@
|
||||
import os
|
||||
import pdb
|
||||
|
||||
# Copy the private `_CUDA_VISIBLE_DEVICES` value, when present, onto the
# variable that CUDA itself reads. This runs before torch is imported below.
if "_CUDA_VISIBLE_DEVICES" in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch,platform
|
||||
import torch, platform
|
||||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.loggers import TensorBoardLogger#WandbLogger
|
||||
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
|
||||
from pytorch_lightning.strategies import DDPStrategy
|
||||
from AR.data.data_module import Text2SemanticDataModule
|
||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from AR.utils.io import load_yaml_config
|
||||
# Raise the log threshold of two chatty third-party loggers.
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
torch.set_float32_matmul_precision("high")
|
||||
from AR.utils import get_newest_ckpt
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class my_model_ckpt(ModelCheckpoint):
    """ModelCheckpoint that can prune old checkpoints and export half weights.

    Args:
        config: training configuration, stored inside every exported weight file.
        if_save_latest: when true, files that were in `dirpath` before a save
            are deleted after the new checkpoint is written.
        if_save_every_weights: when true, a half-precision copy of the model
            weights is written to `half_weights_save_dir` on every save.
        half_weights_save_dir: directory receiving the exported weights.
        exp_name: prefix of the exported weight file names.
        **kwargs: forwarded unchanged to `ModelCheckpoint`.
    """

    def __init__(
        self,
        config,
        if_save_latest,
        if_save_every_weights,
        half_weights_save_dir,
        exp_name,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.if_save_latest = if_save_latest
        self.if_save_every_weights = if_save_every_weights
        self.half_weights_save_dir = half_weights_save_dir
        self.exp_name = exp_name
        self.config = config

    def on_train_epoch_end(self, trainer, pl_module):
        """Save a checkpoint at the end of every `every_n_epochs`-th epoch."""
        if self._should_skip_saving_checkpoint(trainer):
            return
        if not self._should_save_on_train_epoch_end(trainer):
            return
        monitor_candidates = self._monitor_candidates(trainer)
        epoch = trainer.current_epoch + 1
        if self._every_n_epochs >= 1 and epoch % self._every_n_epochs == 0:
            # When only the latest checkpoint is kept, remember what exists
            # now so it can be removed once the next checkpoint is written.
            to_clean = os.listdir(self.dirpath) if self.if_save_latest else []
            self._save_topk_checkpoint(trainer, monitor_candidates)
            for name in to_clean:
                try:
                    os.remove(os.path.join(self.dirpath, name))
                except OSError:
                    # Best effort: an entry that cannot be removed (already
                    # gone, a directory, locked) must not abort training.
                    pass
            if self.if_save_every_weights:
                state = trainer.strategy._lightning_module.state_dict()
                to_save_od = OrderedDict()
                to_save_od["weight"] = OrderedDict(
                    (key, value.half()) for key, value in state.items()
                )
                to_save_od["config"] = self.config
                to_save_od["info"] = "GPT-e%s" % epoch
                torch.save(
                    to_save_od,
                    "%s/%s-e%s.ckpt"
                    % (self.half_weights_save_dir, self.exp_name, epoch),
                )
        self._save_last_checkpoint(trainer, monitor_candidates)
|
||||
|
||||
|
||||
@ -61,41 +89,45 @@ def main(args):
|
||||
output_dir = Path(config["output_dir"])
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ckpt_dir = output_dir / 'ckpt'
|
||||
ckpt_dir = output_dir / "ckpt"
|
||||
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
seed_everything(config["train"]["seed"], workers=True)
|
||||
ckpt_callback: ModelCheckpoint = my_model_ckpt(
|
||||
config=config,
|
||||
if_save_latest=config["train"]["if_save_latest"], if_save_every_weights=config["train"]["if_save_every_weights"], half_weights_save_dir=config["train"]["half_weights_save_dir"], exp_name=config["train"]["exp_name"],
|
||||
if_save_latest=config["train"]["if_save_latest"],
|
||||
if_save_every_weights=config["train"]["if_save_every_weights"],
|
||||
half_weights_save_dir=config["train"]["half_weights_save_dir"],
|
||||
exp_name=config["train"]["exp_name"],
|
||||
save_top_k=-1,
|
||||
monitor='top_3_acc',
|
||||
mode='max',
|
||||
monitor="top_3_acc",
|
||||
mode="max",
|
||||
save_on_train_epoch_end=True,
|
||||
every_n_epochs=config["train"]["save_every_n_epoch"],
|
||||
dirpath=ckpt_dir,
|
||||
)
|
||||
logger = TensorBoardLogger(
|
||||
name=output_dir.stem,
|
||||
save_dir=output_dir
|
||||
)
|
||||
logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
|
||||
trainer: Trainer = Trainer(
|
||||
max_epochs=config["train"]["epochs"],
|
||||
accelerator='gpu',
|
||||
accelerator="gpu",
|
||||
# val_check_interval=9999999999999999999999,###不要验证
|
||||
# check_val_every_n_epoch=None,
|
||||
limit_val_batches=0,
|
||||
devices=-1,
|
||||
benchmark=False,
|
||||
fast_dev_run=False,
|
||||
strategy=DDPStrategy(process_group_backend="nccl"if platform.system()!="Windows"else "gloo"),
|
||||
strategy=DDPStrategy(
|
||||
process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
|
||||
),
|
||||
precision=config["train"]["precision"],
|
||||
logger=logger,num_sanity_val_steps=0,
|
||||
callbacks=[ckpt_callback])
|
||||
logger=logger,
|
||||
num_sanity_val_steps=0,
|
||||
callbacks=[ckpt_callback],
|
||||
)
|
||||
|
||||
model: Text2SemanticLightningModule = Text2SemanticLightningModule(
|
||||
config, output_dir)
|
||||
config, output_dir
|
||||
)
|
||||
|
||||
data_module: Text2SemanticDataModule = Text2SemanticDataModule(
|
||||
config,
|
||||
@ -116,14 +148,15 @@ def main(args):
|
||||
|
||||
|
||||
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'-c',
|
||||
'--config_file',
|
||||
"-c",
|
||||
"--config_file",
|
||||
type=str,
|
||||
default='configs/s1longer.yaml',
|
||||
help='path of config file')
|
||||
default="configs/s1longer.yaml",
|
||||
help="path of config file",
|
||||
)
|
||||
# args for dataset
|
||||
# parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv')
|
||||
# parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt')
|
||||
|
@ -1,4 +1,5 @@
|
||||
import utils,os
|
||||
import utils, os
|
||||
|
||||
hps = utils.get_hparams(stage=2)
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
|
||||
import torch
|
||||
@ -6,11 +7,12 @@ from torch.nn import functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist,traceback
|
||||
import torch.distributed as dist, traceback
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from tqdm import tqdm
|
||||
import logging,traceback
|
||||
import logging, traceback
|
||||
|
||||
logging.getLogger("matplotlib").setLevel(logging.INFO)
|
||||
logging.getLogger("h5py").setLevel(logging.INFO)
|
||||
logging.getLogger("numba").setLevel(logging.INFO)
|
||||
@ -20,37 +22,42 @@ from module import commons
|
||||
from module.data_utils import (
|
||||
TextAudioSpeakerLoader,
|
||||
TextAudioSpeakerCollate,
|
||||
DistributedBucketSampler
|
||||
DistributedBucketSampler,
|
||||
)
|
||||
from module.models import (
|
||||
SynthesizerTrn,
|
||||
MultiPeriodDiscriminator,
|
||||
)
|
||||
from module.losses import (
|
||||
generator_loss,
|
||||
discriminator_loss,
|
||||
feature_loss,
|
||||
kl_loss
|
||||
)
|
||||
from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
|
||||
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
||||
from process_ckpt import savee
|
||||
|
||||
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
# fp32 is faster on the A100 anyway, so give TF32 a try.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Lowest precision but fastest (only marginally so); does not affect results.
torch.set_float32_matmul_precision("medium")
# from config import pretrained_s2G,pretrained_s2D
global_step = 0
|
||||
|
||||
|
||||
def main():
    """Spawn one training process per visible GPU (single node only).

    Sets the rendezvous address to localhost with a random port, then starts
    `run` once per GPU; each process receives its rank followed by
    `(n_gpus, hps)`.

    Raises:
        AssertionError: if CUDA is not available.
    """
    assert torch.cuda.is_available(), "CPU training is not allowed."

    n_gpus = torch.cuda.device_count()
    os.environ["MASTER_ADDR"] = "localhost"
    # Random port, so a fixed port is not reused from one launch to the next.
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))

    mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps))
|
||||
|
||||
|
||||
def run(rank, n_gpus, hps):
|
||||
@ -62,21 +69,54 @@ def run(rank, n_gpus, hps):
|
||||
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
|
||||
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
|
||||
|
||||
dist.init_process_group(backend='gloo' if os.name == 'nt' else 'nccl', init_method='env://', world_size=n_gpus,rank=rank)
|
||||
dist.init_process_group(
|
||||
backend="gloo" if os.name == "nt" else "nccl",
|
||||
init_method="env://",
|
||||
world_size=n_gpus,
|
||||
rank=rank,
|
||||
)
|
||||
torch.manual_seed(hps.train.seed)
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
train_dataset = TextAudioSpeakerLoader(hps.data)########
|
||||
train_dataset = TextAudioSpeakerLoader(hps.data) ########
|
||||
train_sampler = DistributedBucketSampler(
|
||||
train_dataset,
|
||||
hps.train.batch_size,
|
||||
[32, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900],
|
||||
[
|
||||
32,
|
||||
300,
|
||||
400,
|
||||
500,
|
||||
600,
|
||||
700,
|
||||
800,
|
||||
900,
|
||||
1000,
|
||||
1100,
|
||||
1200,
|
||||
1300,
|
||||
1400,
|
||||
1500,
|
||||
1600,
|
||||
1700,
|
||||
1800,
|
||||
1900,
|
||||
],
|
||||
num_replicas=n_gpus,
|
||||
rank=rank,
|
||||
shuffle=True)
|
||||
shuffle=True,
|
||||
)
|
||||
collate_fn = TextAudioSpeakerCollate()
|
||||
train_loader = DataLoader(train_dataset, num_workers=6, shuffle=False, pin_memory=True,
|
||||
collate_fn=collate_fn, batch_sampler=train_sampler,persistent_workers=True,prefetch_factor=16)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
num_workers=6,
|
||||
shuffle=False,
|
||||
pin_memory=True,
|
||||
collate_fn=collate_fn,
|
||||
batch_sampler=train_sampler,
|
||||
persistent_workers=True,
|
||||
prefetch_factor=16,
|
||||
)
|
||||
# if rank == 0:
|
||||
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
|
||||
# eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
|
||||
@ -87,17 +127,21 @@ def run(rank, n_gpus, hps):
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model).cuda(rank)
|
||||
**hps.model,
|
||||
).cuda(rank)
|
||||
|
||||
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
|
||||
for name, param in net_g.named_parameters():
|
||||
if not param.requires_grad:
|
||||
print(name,"not requires_grad")
|
||||
print(name, "not requires_grad")
|
||||
|
||||
te_p = list(map(id, net_g.enc_p.text_embedding.parameters()))
|
||||
et_p = list(map(id, net_g.enc_p.encoder_text.parameters()))
|
||||
mrte_p = list(map(id, net_g.enc_p.mrte.parameters()))
|
||||
base_params = filter(lambda p: id(p) not in te_p+et_p+mrte_p and p.requires_grad, net_g.parameters())
|
||||
base_params = filter(
|
||||
lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad,
|
||||
net_g.parameters(),
|
||||
)
|
||||
|
||||
# te_p=net_g.enc_p.text_embedding.parameters()
|
||||
# et_p=net_g.enc_p.encoder_text.parameters()
|
||||
@ -106,31 +150,46 @@ def run(rank, n_gpus, hps):
|
||||
optim_g = torch.optim.AdamW(
|
||||
# filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致
|
||||
[
|
||||
{"params":base_params,"lr":hps.train.learning_rate},
|
||||
{"params":net_g.enc_p.text_embedding.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate},
|
||||
{"params":net_g.enc_p.encoder_text.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate},
|
||||
{"params":net_g.enc_p.mrte.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate},
|
||||
{"params": base_params, "lr": hps.train.learning_rate},
|
||||
{
|
||||
"params": net_g.enc_p.text_embedding.parameters(),
|
||||
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
|
||||
},
|
||||
{
|
||||
"params": net_g.enc_p.encoder_text.parameters(),
|
||||
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
|
||||
},
|
||||
{
|
||||
"params": net_g.enc_p.mrte.parameters(),
|
||||
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
|
||||
},
|
||||
],
|
||||
hps.train.learning_rate,
|
||||
betas=hps.train.betas,
|
||||
eps=hps.train.eps)
|
||||
eps=hps.train.eps,
|
||||
)
|
||||
optim_d = torch.optim.AdamW(
|
||||
net_d.parameters(),
|
||||
hps.train.learning_rate,
|
||||
betas=hps.train.betas,
|
||||
eps=hps.train.eps)
|
||||
net_g = DDP(net_g, device_ids=[rank],find_unused_parameters=True)
|
||||
net_d = DDP(net_d, device_ids=[rank],find_unused_parameters=True)
|
||||
eps=hps.train.eps,
|
||||
)
|
||||
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
|
||||
net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
|
||||
|
||||
try: # 如果能加载自动resume
|
||||
_, _, _, epoch_str = utils.load_checkpoint(
|
||||
utils.latest_checkpoint_path("%s/logs_s2"%hps.data.exp_dir, "D_*.pth"), net_d, optim_d
|
||||
utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "D_*.pth"),
|
||||
net_d,
|
||||
optim_d,
|
||||
) # D多半加载没事
|
||||
if rank == 0:
|
||||
logger.info("loaded D")
|
||||
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
|
||||
_, _, _, epoch_str = utils.load_checkpoint(
|
||||
utils.latest_checkpoint_path("%s/logs_s2"%hps.data.exp_dir, "G_*.pth"), net_g, optim_g
|
||||
utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "G_*.pth"),
|
||||
net_g,
|
||||
optim_g,
|
||||
)
|
||||
global_step = (epoch_str - 1) * len(train_loader)
|
||||
# epoch_str = 1
|
||||
@ -144,7 +203,8 @@ def run(rank, n_gpus, hps):
|
||||
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
|
||||
print(
|
||||
net_g.module.load_state_dict(
|
||||
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],strict=False
|
||||
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
|
||||
strict=False,
|
||||
)
|
||||
) ##测试不加载优化器
|
||||
if hps.train.pretrained_s2D != "":
|
||||
@ -159,8 +219,12 @@ def run(rank, n_gpus, hps):
|
||||
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
||||
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
||||
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
|
||||
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=-1)
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
|
||||
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
|
||||
)
|
||||
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
|
||||
optim_d, gamma=hps.train.lr_decay, last_epoch=-1
|
||||
)
|
||||
for _ in range(epoch_str):
|
||||
scheduler_g.step()
|
||||
scheduler_d.step()
|
||||
@ -169,17 +233,39 @@ def run(rank, n_gpus, hps):
|
||||
|
||||
for epoch in range(epoch_str, hps.train.epochs + 1):
|
||||
if rank == 0:
|
||||
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
|
||||
# [train_loader, eval_loader], logger, [writer, writer_eval])
|
||||
[train_loader, None], logger, [writer, writer_eval])
|
||||
train_and_evaluate(
|
||||
rank,
|
||||
epoch,
|
||||
hps,
|
||||
[net_g, net_d],
|
||||
[optim_g, optim_d],
|
||||
[scheduler_g, scheduler_d],
|
||||
scaler,
|
||||
# [train_loader, eval_loader], logger, [writer, writer_eval])
|
||||
[train_loader, None],
|
||||
logger,
|
||||
[writer, writer_eval],
|
||||
)
|
||||
else:
|
||||
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
|
||||
[train_loader, None], None, None)
|
||||
train_and_evaluate(
|
||||
rank,
|
||||
epoch,
|
||||
hps,
|
||||
[net_g, net_d],
|
||||
[optim_g, optim_d],
|
||||
[scheduler_g, scheduler_d],
|
||||
scaler,
|
||||
[train_loader, None],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
scheduler_g.step()
|
||||
scheduler_d.step()
|
||||
|
||||
|
||||
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
|
||||
def train_and_evaluate(
|
||||
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
|
||||
):
|
||||
net_g, net_d = nets
|
||||
optim_g, optim_d = optims
|
||||
# scheduler_g, scheduler_d = schedulers
|
||||
@ -192,17 +278,39 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||
|
||||
net_g.train()
|
||||
net_d.train()
|
||||
for batch_idx, (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths) in tqdm(enumerate(train_loader)):
|
||||
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
|
||||
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True)
|
||||
for batch_idx, (
|
||||
ssl,
|
||||
ssl_lengths,
|
||||
spec,
|
||||
spec_lengths,
|
||||
y,
|
||||
y_lengths,
|
||||
text,
|
||||
text_lengths,
|
||||
) in tqdm(enumerate(train_loader)):
|
||||
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
|
||||
rank, non_blocking=True
|
||||
)
|
||||
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
|
||||
rank, non_blocking=True
|
||||
)
|
||||
ssl = ssl.cuda(rank, non_blocking=True)
|
||||
ssl.requires_grad=False
|
||||
ssl.requires_grad = False
|
||||
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
|
||||
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(rank, non_blocking=True)
|
||||
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
|
||||
rank, non_blocking=True
|
||||
)
|
||||
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
y_hat, kl_ssl, ids_slice, x_mask, z_mask, \
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl = net_g(ssl, spec, spec_lengths, text, text_lengths)
|
||||
(
|
||||
y_hat,
|
||||
kl_ssl,
|
||||
ids_slice,
|
||||
x_mask,
|
||||
z_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
stats_ssl,
|
||||
) = net_g(ssl, spec, spec_lengths, text, text_lengths)
|
||||
|
||||
mel = spec_to_mel_torch(
|
||||
spec,
|
||||
@ -210,8 +318,11 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||
hps.data.n_mel_channels,
|
||||
hps.data.sampling_rate,
|
||||
hps.data.mel_fmin,
|
||||
hps.data.mel_fmax)
|
||||
y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
|
||||
hps.data.mel_fmax,
|
||||
)
|
||||
y_mel = commons.slice_segments(
|
||||
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
||||
)
|
||||
y_hat_mel = mel_spectrogram_torch(
|
||||
y_hat.squeeze(1),
|
||||
hps.data.filter_length,
|
||||
@ -220,15 +331,19 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||
hps.data.hop_length,
|
||||
hps.data.win_length,
|
||||
hps.data.mel_fmin,
|
||||
hps.data.mel_fmax
|
||||
hps.data.mel_fmax,
|
||||
)
|
||||
|
||||
y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
|
||||
y = commons.slice_segments(
|
||||
y, ids_slice * hps.data.hop_length, hps.train.segment_size
|
||||
) # slice
|
||||
|
||||
# Discriminator
|
||||
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
|
||||
with autocast(enabled=False):
|
||||
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
|
||||
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
||||
y_d_hat_r, y_d_hat_g
|
||||
)
|
||||
loss_disc_all = loss_disc
|
||||
optim_d.zero_grad()
|
||||
scaler.scale(loss_disc_all).backward()
|
||||
@ -256,32 +371,54 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||
|
||||
if rank == 0:
|
||||
if global_step % hps.train.log_interval == 0:
|
||||
lr = optim_g.param_groups[0]['lr']
|
||||
lr = optim_g.param_groups[0]["lr"]
|
||||
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
|
||||
logger.info('Train Epoch: {} [{:.0f}%]'.format(
|
||||
epoch,
|
||||
100. * batch_idx / len(train_loader)))
|
||||
logger.info(
|
||||
"Train Epoch: {} [{:.0f}%]".format(
|
||||
epoch, 100.0 * batch_idx / len(train_loader)
|
||||
)
|
||||
)
|
||||
logger.info([x.item() for x in losses] + [global_step, lr])
|
||||
|
||||
scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr,
|
||||
"grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
|
||||
scalar_dict = {
|
||||
"loss/g/total": loss_gen_all,
|
||||
"loss/d/total": loss_disc_all,
|
||||
"learning_rate": lr,
|
||||
"grad_norm_d": grad_norm_d,
|
||||
"grad_norm_g": grad_norm_g,
|
||||
}
|
||||
scalar_dict.update(
|
||||
{"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl_ssl": kl_ssl, "loss/g/kl": loss_kl})
|
||||
{
|
||||
"loss/g/fm": loss_fm,
|
||||
"loss/g/mel": loss_mel,
|
||||
"loss/g/kl_ssl": kl_ssl,
|
||||
"loss/g/kl": loss_kl,
|
||||
}
|
||||
)
|
||||
|
||||
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
|
||||
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
|
||||
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
|
||||
image_dict = {
|
||||
"slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
|
||||
"slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
|
||||
"all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
|
||||
"all/stats_ssl": utils.plot_spectrogram_to_numpy(stats_ssl[0].data.cpu().numpy()),
|
||||
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
||||
y_mel[0].data.cpu().numpy()
|
||||
),
|
||||
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
|
||||
y_hat_mel[0].data.cpu().numpy()
|
||||
),
|
||||
"all/mel": utils.plot_spectrogram_to_numpy(
|
||||
mel[0].data.cpu().numpy()
|
||||
),
|
||||
"all/stats_ssl": utils.plot_spectrogram_to_numpy(
|
||||
stats_ssl[0].data.cpu().numpy()
|
||||
),
|
||||
}
|
||||
utils.summarize(
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
images=image_dict,
|
||||
scalars=scalar_dict)
|
||||
scalars=scalar_dict,
|
||||
)
|
||||
global_step += 1
|
||||
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
|
||||
if hps.train.if_save_latest == 0:
|
||||
@ -290,14 +427,18 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join("%s/logs_s2"%hps.data.exp_dir, "G_{}.pth".format(global_step)),
|
||||
os.path.join(
|
||||
"%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(global_step)
|
||||
),
|
||||
)
|
||||
utils.save_checkpoint(
|
||||
net_d,
|
||||
optim_d,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join("%s/logs_s2"%hps.data.exp_dir, "D_{}.pth".format(global_step)),
|
||||
os.path.join(
|
||||
"%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(global_step)
|
||||
),
|
||||
)
|
||||
else:
|
||||
utils.save_checkpoint(
|
||||
@ -305,14 +446,18 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join("%s/logs_s2"%hps.data.exp_dir, "G_{}.pth".format(233333333333)),
|
||||
os.path.join(
|
||||
"%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(233333333333)
|
||||
),
|
||||
)
|
||||
utils.save_checkpoint(
|
||||
net_d,
|
||||
optim_d,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join("%s/logs_s2"%hps.data.exp_dir, "D_{}.pth".format(233333333333)),
|
||||
os.path.join(
|
||||
"%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(233333333333)
|
||||
),
|
||||
)
|
||||
if rank == 0 and hps.train.if_save_every_weights == True:
|
||||
if hasattr(net_g, "module"):
|
||||
@ -334,11 +479,8 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
||||
if rank == 0:
|
||||
logger.info('====> Epoch: {}'.format(epoch))
|
||||
|
||||
logger.info("====> Epoch: {}".format(epoch))
|
||||
|
||||
|
||||
def evaluate(hps, generator, eval_loader, writer_eval):
|
||||
@ -347,15 +489,25 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
||||
audio_dict = {}
|
||||
print("Evaluating ...")
|
||||
with torch.no_grad():
|
||||
for batch_idx, (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths) in enumerate(eval_loader):
|
||||
for batch_idx, (
|
||||
ssl,
|
||||
ssl_lengths,
|
||||
spec,
|
||||
spec_lengths,
|
||||
y,
|
||||
y_lengths,
|
||||
text,
|
||||
text_lengths,
|
||||
) in enumerate(eval_loader):
|
||||
print(111)
|
||||
spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
|
||||
y, y_lengths = y.cuda(), y_lengths.cuda()
|
||||
ssl = ssl.cuda()
|
||||
text, text_lengths = text.cuda(), text_lengths.cuda()
|
||||
for test in [0, 1]:
|
||||
|
||||
y_hat, mask, *_ = generator.module.infer(ssl,spec, spec_lengths,text, text_lengths, test=test)
|
||||
y_hat, mask, *_ = generator.module.infer(
|
||||
ssl, spec, spec_lengths, text, text_lengths, test=test
|
||||
)
|
||||
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
|
||||
|
||||
mel = spec_to_mel_torch(
|
||||
@ -364,7 +516,8 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
||||
hps.data.n_mel_channels,
|
||||
hps.data.sampling_rate,
|
||||
hps.data.mel_fmin,
|
||||
hps.data.mel_fmax)
|
||||
hps.data.mel_fmax,
|
||||
)
|
||||
y_hat_mel = mel_spectrogram_torch(
|
||||
y_hat.squeeze(1).float(),
|
||||
hps.data.filter_length,
|
||||
@ -373,16 +526,26 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
||||
hps.data.hop_length,
|
||||
hps.data.win_length,
|
||||
hps.data.mel_fmin,
|
||||
hps.data.mel_fmax
|
||||
hps.data.mel_fmax,
|
||||
)
|
||||
image_dict.update({
|
||||
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
|
||||
})
|
||||
audio_dict.update({
|
||||
f"gen/audio_{batch_idx}_{test}": y_hat[0, :, :y_hat_lengths[0]]
|
||||
})
|
||||
image_dict.update({f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())})
|
||||
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, :y_lengths[0]]})
|
||||
image_dict.update(
|
||||
{
|
||||
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
|
||||
y_hat_mel[0].cpu().numpy()
|
||||
)
|
||||
}
|
||||
)
|
||||
audio_dict.update(
|
||||
{f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
|
||||
)
|
||||
image_dict.update(
|
||||
{
|
||||
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
|
||||
mel[0].cpu().numpy()
|
||||
)
|
||||
}
|
||||
)
|
||||
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
|
||||
|
||||
# y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None)
|
||||
# audio_dict.update({
|
||||
@ -394,9 +557,10 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
||||
global_step=global_step,
|
||||
images=image_dict,
|
||||
audios=audio_dict,
|
||||
audio_sampling_rate=hps.data.sampling_rate
|
||||
audio_sampling_rate=hps.data.sampling_rate,
|
||||
)
|
||||
generator.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -12,8 +12,9 @@ import numpy as np
|
||||
from scipy.io.wavfile import read
|
||||
import torch
|
||||
import logging
|
||||
logging.getLogger('numba').setLevel(logging.ERROR)
|
||||
logging.getLogger('matplotlib').setLevel(logging.ERROR)
|
||||
|
||||
logging.getLogger("numba").setLevel(logging.ERROR)
|
||||
logging.getLogger("matplotlib").setLevel(logging.ERROR)
|
||||
|
||||
MATPLOTLIB_FLAG = False
|
||||
|
||||
@ -23,13 +24,17 @@ logger = logging
|
||||
|
||||
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
|
||||
assert os.path.isfile(checkpoint_path)
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
||||
iteration = checkpoint_dict['iteration']
|
||||
learning_rate = checkpoint_dict['learning_rate']
|
||||
if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
|
||||
optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
||||
saved_state_dict = checkpoint_dict['model']
|
||||
if hasattr(model, 'module'):
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
|
||||
iteration = checkpoint_dict["iteration"]
|
||||
learning_rate = checkpoint_dict["learning_rate"]
|
||||
if (
|
||||
optimizer is not None
|
||||
and not skip_optimizer
|
||||
and checkpoint_dict["optimizer"] is not None
|
||||
):
|
||||
optimizer.load_state_dict(checkpoint_dict["optimizer"])
|
||||
saved_state_dict = checkpoint_dict["model"]
|
||||
if hasattr(model, "module"):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
@ -39,41 +44,63 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
|
||||
# assert "quantizer" not in k
|
||||
# print("load", k)
|
||||
new_state_dict[k] = saved_state_dict[k]
|
||||
assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
|
||||
assert saved_state_dict[k].shape == v.shape, (
|
||||
saved_state_dict[k].shape,
|
||||
v.shape,
|
||||
)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
print("error, %s is not in the checkpoint" % k)#shape不对也会,比如text_embedding当cleaner修改时
|
||||
print(
|
||||
"error, %s is not in the checkpoint" % k
|
||||
) # shape不对也会,比如text_embedding当cleaner修改时
|
||||
new_state_dict[k] = v
|
||||
if hasattr(model, 'module'):
|
||||
if hasattr(model, "module"):
|
||||
model.module.load_state_dict(new_state_dict)
|
||||
else:
|
||||
model.load_state_dict(new_state_dict)
|
||||
print("load ")
|
||||
logger.info("Loaded checkpoint '{}' (iteration {})".format(
|
||||
checkpoint_path, iteration))
|
||||
logger.info(
|
||||
"Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
|
||||
)
|
||||
return model, optimizer, learning_rate, iteration
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
|
||||
logger.info("Saving model and optimizer state at iteration {} to {}".format(
|
||||
iteration, checkpoint_path))
|
||||
if hasattr(model, 'module'):
|
||||
logger.info(
|
||||
"Saving model and optimizer state at iteration {} to {}".format(
|
||||
iteration, checkpoint_path
|
||||
)
|
||||
)
|
||||
if hasattr(model, "module"):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
torch.save({'model': state_dict,
|
||||
'iteration': iteration,
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'learning_rate': learning_rate}, checkpoint_path)
|
||||
torch.save(
|
||||
{
|
||||
"model": state_dict,
|
||||
"iteration": iteration,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"learning_rate": learning_rate,
|
||||
},
|
||||
checkpoint_path,
|
||||
)
|
||||
|
||||
|
||||
def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
|
||||
def summarize(
|
||||
writer,
|
||||
global_step,
|
||||
scalars={},
|
||||
histograms={},
|
||||
images={},
|
||||
audios={},
|
||||
audio_sampling_rate=22050,
|
||||
):
|
||||
for k, v in scalars.items():
|
||||
writer.add_scalar(k, v, global_step)
|
||||
for k, v in histograms.items():
|
||||
writer.add_histogram(k, v, global_step)
|
||||
for k, v in images.items():
|
||||
writer.add_image(k, v, global_step, dataformats='HWC')
|
||||
writer.add_image(k, v, global_step, dataformats="HWC")
|
||||
for k, v in audios.items():
|
||||
writer.add_audio(k, v, global_step, audio_sampling_rate)
|
||||
|
||||
@ -90,23 +117,23 @@ def plot_spectrogram_to_numpy(spectrogram):
|
||||
global MATPLOTLIB_FLAG
|
||||
if not MATPLOTLIB_FLAG:
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
MATPLOTLIB_FLAG = True
|
||||
mpl_logger = logging.getLogger('matplotlib')
|
||||
mpl_logger = logging.getLogger("matplotlib")
|
||||
mpl_logger.setLevel(logging.WARNING)
|
||||
import matplotlib.pylab as plt
|
||||
import numpy as np
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 2))
|
||||
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
|
||||
interpolation='none')
|
||||
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
||||
plt.colorbar(im, ax=ax)
|
||||
plt.xlabel("Frames")
|
||||
plt.ylabel("Channels")
|
||||
plt.tight_layout()
|
||||
|
||||
fig.canvas.draw()
|
||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
|
||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||
plt.close()
|
||||
return data
|
||||
@ -116,26 +143,28 @@ def plot_alignment_to_numpy(alignment, info=None):
|
||||
global MATPLOTLIB_FLAG
|
||||
if not MATPLOTLIB_FLAG:
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
MATPLOTLIB_FLAG = True
|
||||
mpl_logger = logging.getLogger('matplotlib')
|
||||
mpl_logger = logging.getLogger("matplotlib")
|
||||
mpl_logger.setLevel(logging.WARNING)
|
||||
import matplotlib.pylab as plt
|
||||
import numpy as np
|
||||
|
||||
fig, ax = plt.subplots(figsize=(6, 4))
|
||||
im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
|
||||
interpolation='none')
|
||||
im = ax.imshow(
|
||||
alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
|
||||
)
|
||||
fig.colorbar(im, ax=ax)
|
||||
xlabel = 'Decoder timestep'
|
||||
xlabel = "Decoder timestep"
|
||||
if info is not None:
|
||||
xlabel += '\n\n' + info
|
||||
xlabel += "\n\n" + info
|
||||
plt.xlabel(xlabel)
|
||||
plt.ylabel('Encoder timestep')
|
||||
plt.ylabel("Encoder timestep")
|
||||
plt.tight_layout()
|
||||
|
||||
fig.canvas.draw()
|
||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
|
||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||
plt.close()
|
||||
return data
|
||||
@ -147,16 +176,31 @@ def load_wav_to_torch(full_path):
|
||||
|
||||
|
||||
def load_filepaths_and_text(filename, split="|"):
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
with open(filename, encoding="utf-8") as f:
|
||||
filepaths_and_text = [line.strip().split(split) for line in f]
|
||||
return filepaths_and_text
|
||||
|
||||
|
||||
def get_hparams(init=True, stage=1):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-c', '--config', type=str, default="./configs/s2.json",help='JSON file for configuration')
|
||||
parser.add_argument('-p', '--pretrain', type=str, required=False,default=None,help='pretrain dir')
|
||||
parser.add_argument('-rs', '--resume_step', type=int, required=False,default=None,help='resume step')
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--config",
|
||||
type=str,
|
||||
default="./configs/s2.json",
|
||||
help="JSON file for configuration",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-rs",
|
||||
"--resume_step",
|
||||
type=int,
|
||||
required=False,
|
||||
default=None,
|
||||
help="resume step",
|
||||
)
|
||||
# parser.add_argument('-e', '--exp_dir', type=str, required=False,default=None,help='experiment directory')
|
||||
# parser.add_argument('-g', '--pretrained_s2G', type=str, required=False,default=None,help='pretrained sovits gererator weights')
|
||||
# parser.add_argument('-d', '--pretrained_s2D', type=str, required=False,default=None,help='pretrained sovits discriminator weights')
|
||||
@ -172,7 +216,7 @@ def get_hparams(init=True, stage=1):
|
||||
hparams.pretrain = args.pretrain
|
||||
hparams.resume_step = args.resume_step
|
||||
# hparams.data.exp_dir = args.exp_dir
|
||||
if stage ==1:
|
||||
if stage == 1:
|
||||
model_dir = hparams.s1_ckpt_dir
|
||||
else:
|
||||
model_dir = hparams.s2_ckpt_dir
|
||||
@ -186,29 +230,38 @@ def get_hparams(init=True, stage=1):
|
||||
return hparams
|
||||
|
||||
|
||||
|
||||
def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True):
|
||||
def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
|
||||
"""Freeing up space by deleting saved ckpts
|
||||
|
||||
Arguments:
|
||||
path_to_models -- Path to the model directory
|
||||
n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
|
||||
sort_by_time -- True -> chronologically delete ckpts
|
||||
False -> lexicographically delete ckpts
|
||||
"""
|
||||
Arguments:
|
||||
path_to_models -- Path to the model directory
|
||||
n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
|
||||
sort_by_time -- True -> chronologically delete ckpts
|
||||
False -> lexicographically delete ckpts
|
||||
"""
|
||||
import re
|
||||
ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
|
||||
name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1)))
|
||||
time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
|
||||
|
||||
ckpts_files = [
|
||||
f
|
||||
for f in os.listdir(path_to_models)
|
||||
if os.path.isfile(os.path.join(path_to_models, f))
|
||||
]
|
||||
name_key = lambda _f: int(re.compile("._(\d+)\.pth").match(_f).group(1))
|
||||
time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))
|
||||
sort_key = time_key if sort_by_time else name_key
|
||||
x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')],
|
||||
key=sort_key)
|
||||
to_del = [os.path.join(path_to_models, fn) for fn in
|
||||
(x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
|
||||
x_sorted = lambda _x: sorted(
|
||||
[f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
|
||||
key=sort_key,
|
||||
)
|
||||
to_del = [
|
||||
os.path.join(path_to_models, fn)
|
||||
for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
|
||||
]
|
||||
del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
|
||||
del_routine = lambda x: [os.remove(x), del_info(x)]
|
||||
rs = [del_routine(fn) for fn in to_del]
|
||||
|
||||
|
||||
def get_hparams_from_dir(model_dir):
|
||||
config_save_path = os.path.join(model_dir, "config.json")
|
||||
with open(config_save_path, "r") as f:
|
||||
@ -228,12 +281,15 @@ def get_hparams_from_file(config_path):
|
||||
hparams = HParams(**config)
|
||||
return hparams
|
||||
|
||||
|
||||
def check_git_hash(model_dir):
|
||||
source_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
if not os.path.exists(os.path.join(source_dir, ".git")):
|
||||
logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
|
||||
source_dir
|
||||
))
|
||||
logger.warn(
|
||||
"{} is not a git repository, therefore hash value comparison will be ignored.".format(
|
||||
source_dir
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
cur_hash = subprocess.getoutput("git rev-parse HEAD")
|
||||
@ -242,8 +298,11 @@ def check_git_hash(model_dir):
|
||||
if os.path.exists(path):
|
||||
saved_hash = open(path).read()
|
||||
if saved_hash != cur_hash:
|
||||
logger.warn("git hash values are different. {}(saved) != {}(current)".format(
|
||||
saved_hash[:8], cur_hash[:8]))
|
||||
logger.warn(
|
||||
"git hash values are different. {}(saved) != {}(current)".format(
|
||||
saved_hash[:8], cur_hash[:8]
|
||||
)
|
||||
)
|
||||
else:
|
||||
open(path, "w").write(cur_hash)
|
||||
|
||||
@ -263,7 +322,7 @@ def get_logger(model_dir, filename="train.log"):
|
||||
return logger
|
||||
|
||||
|
||||
class HParams():
|
||||
class HParams:
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if type(v) == dict:
|
||||
@ -294,5 +353,10 @@ class HParams():
|
||||
def __repr__(self):
|
||||
return self.__dict__.__repr__()
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(load_wav_to_torch('/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac'))
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(
|
||||
load_wav_to_torch(
|
||||
"/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac"
|
||||
)
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user