Code refactor + remove unused imports

This commit is contained in:
Blaise 2024-01-16 17:10:27 +01:00
parent 9031ac9a92
commit 0d92575115
5 changed files with 671 additions and 335 deletions

View File

@ -1,49 +1,55 @@
import os
gpt_path=os.environ.get("gpt_path","pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
sovits_path=os.environ.get("sovits_path","pretrained_models/s2G488k.pth")
cnhubert_base_path=os.environ.get("cnhubert_base_path","pretrained_models/chinese-hubert-base")
bert_path=os.environ.get("bert_path","pretrained_models/chinese-roberta-wwm-ext-large")
infer_ttswebui=os.environ.get("infer_ttswebui",9872)
infer_ttswebui=int(infer_ttswebui)
if("_CUDA_VISIBLE_DEVICES"in os.environ):
os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"]
is_half=eval(os.environ.get("is_half","True"))
gpt_path = os.environ.get(
"gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
)
sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
cnhubert_base_path = os.environ.get(
"cnhubert_base_path", "pretrained_models/chinese-hubert-base"
)
bert_path = os.environ.get(
"bert_path", "pretrained_models/chinese-roberta-wwm-ext-large"
)
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
infer_ttswebui = int(infer_ttswebui)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True"))
import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer
import sys,torch,numpy as np
from pathlib import Path
import os,pdb,utils,librosa,math,traceback,requests,argparse,torch,multiprocessing,pandas as pd,torch.multiprocessing as mp,soundfile
import torch, numpy as np
import os, librosa, torch
# torch.backends.cuda.sdp_kernel("flash")
# torch.backends.cuda.enable_flash_sdp(True)
# torch.backends.cuda.enable_mem_efficient_sdp(True) # Not avaliable if torch version is lower than 2.0
# torch.backends.cuda.enable_math_sdp(True)
from random import shuffle
from AR.utils import get_newest_ckpt
from glob import glob
from tqdm import tqdm
from feature_extractor import cnhubert
cnhubert.cnhubert_base_path=cnhubert_base_path
from io import BytesIO
cnhubert.cnhubert_base_path = cnhubert_base_path
from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
from text import cleaned_text_to_sequence
from text.cleaner import text_to_sequence, clean_text
from text.cleaner import clean_text
from time import time as ttime
from module.mel_processing import spectrogram_torch
from my_utils import load_audio
device="cuda"
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model=AutoModelForMaskedLM.from_pretrained(bert_path)
if(is_half==True):bert_model=bert_model.half().to(device)
else:bert_model=bert_model.to(device)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half == True:
bert_model = bert_model.half().to(device)
else:
bert_model = bert_model.to(device)
# bert_model=bert_model.to(device)
def get_bert_feature(text, word2ph):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)#####输入是long不用管精度问题精度随bert_model
inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题精度随bert_model
res = bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
@ -55,9 +61,12 @@ def get_bert_feature(text, word2ph):
# if(is_half==True):phone_level_feature=phone_level_feature.half()
return phone_level_feature.T
n_semantic = 1024
dict_s2=torch.load(sovits_path,map_location="cpu")
hps=dict_s2["config"]
dict_s2 = torch.load(sovits_path, map_location="cpu")
hps = dict_s2["config"]
class DictToAttrRecursive:
def __init__(self, input_dict):
for key, value in input_dict.items():
@ -67,206 +76,271 @@ class DictToAttrRecursive:
else:
setattr(self, key, value)
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate="25hz"
dict_s1=torch.load(gpt_path,map_location="cpu")
config=dict_s1["config"]
ssl_model=cnhubert.get_model()
if(is_half==True):ssl_model=ssl_model.half().to(device)
else:ssl_model=ssl_model.to(device)
hps.model.semantic_frame_rate = "25hz"
dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"]
ssl_model = cnhubert.get_model()
if is_half == True:
ssl_model = ssl_model.half().to(device)
else:
ssl_model = ssl_model.to(device)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model)
if(is_half==True):vq_model=vq_model.half().to(device)
else:vq_model=vq_model.to(device)
**hps.model
)
if is_half == True:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
vq_model.eval()
print(vq_model.load_state_dict(dict_s2["weight"],strict=False))
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
hz = 50
max_sec = config['data']['max_sec']
max_sec = config["data"]["max_sec"]
# t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
t2s_model = Text2SemanticLightningModule(config,"ojbk",is_train=False)
t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
if(is_half==True):t2s_model=t2s_model.half()
t2s_model=t2s_model.to(device)
if is_half == True:
t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device)
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
def get_spepc(hps, filename):
audio=load_audio(filename,int(hps.data.sampling_rate))
audio=torch.FloatTensor(audio)
audio = load_audio(filename, int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio)
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False)
spec = spectrogram_torch(
audio_norm,
hps.data.filter_length,
hps.data.sampling_rate,
hps.data.hop_length,
hps.data.win_length,
center=False,
)
return spec
dict_language={
"中文":"zh",
"英文":"en",
"日文":"ja"
}
def get_tts_wav(ref_wav_path,prompt_text,prompt_language,text,text_language):
dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
t0 = ttime()
prompt_text=prompt_text.strip("\n")
prompt_language,text=prompt_language,text.strip("\n")
prompt_text = prompt_text.strip("\n")
prompt_language, text = prompt_language, text.strip("\n")
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
wav16k = torch.from_numpy(wav16k)
if(is_half==True):wav16k=wav16k.half().to(device)
else:wav16k=wav16k.to(device)
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float()
if is_half == True:
wav16k = wav16k.half().to(device)
else:
wav16k = wav16k.to(device)
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
prompt_language=dict_language[prompt_language]
text_language=dict_language[text_language]
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
phones1=cleaned_text_to_sequence(phones1)
texts=text.split("\n")
phones1 = cleaned_text_to_sequence(phones1)
texts = text.split("\n")
audio_opt = []
zero_wav=np.zeros(int(hps.data.sampling_rate*0.3),dtype=np.float16 if is_half==True else np.float32)
zero_wav = np.zeros(
int(hps.data.sampling_rate * 0.3),
dtype=np.float16 if is_half == True else np.float32,
)
for text in texts:
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
phones2 = cleaned_text_to_sequence(phones2)
if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
else:bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device)
if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
else:bert2 = torch.zeros((1024, len(phones2))).to(bert1)
if prompt_language == "zh":
bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
else:
bert1 = torch.zeros(
(1024, len(phones1)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
if text_language == "zh":
bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
else:
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
t2 = ttime()
with torch.no_grad():
# pred_semantic = t2s_model.model.infer(
pred_semantic,idx = t2s_model.model.infer_panel(
pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
prompt,
bert,
# prompt_phone_len=ph_offset,
top_k=config['inference']['top_k'],
early_stop_num=hz * max_sec)
top_k=config["inference"]["top_k"],
early_stop_num=hz * max_sec,
)
t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:,-idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
refer = get_spepc(hps, ref_wav_path)#.to(device)
if(is_half==True):refer=refer.half().to(device)
else:refer=refer.to(device)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
0
) # .unsqueeze(0)#mq要多unsqueeze一次
refer = get_spepc(hps, ref_wav_path) # .to(device)
if is_half == True:
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]###试试重建不带上prompt部分
audio = (
vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
)
.detach()
.cpu()
.numpy()[0, 0]
) ###试试重建不带上prompt部分
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
yield hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16)
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
np.int16
)
splits = {
"",
"",
"",
"",
",",
".",
"?",
"!",
"~",
":",
"",
"",
"",
} # 不考虑省略号
splits={"","","","",",",".","?","!","~",":","","","",}#不考虑省略号
def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "")
if (todo_text[-1] not in splits): todo_text += ""
if todo_text[-1] not in splits:
todo_text += ""
i_split_head = i_split_tail = 0
len_text = len(todo_text)
todo_texts = []
while (1):
if (i_split_head >= len_text): break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
if (todo_text[i_split_head] in splits):
while 1:
if i_split_head >= len_text:
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
if todo_text[i_split_head] in splits:
i_split_head += 1
todo_texts.append(todo_text[i_split_tail:i_split_head])
i_split_tail = i_split_head
else:
i_split_head += 1
return todo_texts
def cut1(inp):
inp=inp.strip("\n")
inps=split(inp)
split_idx=list(range(0,len(inps),5))
split_idx[-1]=None
if(len(split_idx)>1):
opts=[]
for idx in range(len(split_idx)-1):
opts.append("".join(inps[split_idx[idx]:split_idx[idx+1]]))
inp = inp.strip("\n")
inps = split(inp)
split_idx = list(range(0, len(inps), 5))
split_idx[-1] = None
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
else:
opts=[inp]
opts = [inp]
return "\n".join(opts)
def cut2(inp):
inp=inp.strip("\n")
inps=split(inp)
if(len(inps)<2):return [inp]
opts=[]
summ=0
tmp_str=""
inp = inp.strip("\n")
inps = split(inp)
if len(inps) < 2:
return [inp]
opts = []
summ = 0
tmp_str = ""
for i in range(len(inps)):
summ+=len(inps[i])
tmp_str+=inps[i]
if(summ>50):
summ=0
summ += len(inps[i])
tmp_str += inps[i]
if summ > 50:
summ = 0
opts.append(tmp_str)
tmp_str=""
if(tmp_str!=""):opts.append(tmp_str)
if(len(opts[-1])<50):##如果最后一个太短了,和前一个合一起
opts[-2]=opts[-2]+opts[-1]
opts=opts[:-1]
tmp_str = ""
if tmp_str != "":
opts.append(tmp_str)
if len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
opts[-2] = opts[-2] + opts[-1]
opts = opts[:-1]
return "\n".join(opts)
def cut3(inp):
inp=inp.strip("\n")
return "\n".join(["%s"%item for item in inp.strip("").split("")])
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip("").split("")])
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(
value=
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
value="本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
)
# with gr.Tabs():
# with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
with gr.Group():
gr.Markdown(
value=
"*请上传并填写参考信息"
)
gr.Markdown(value="*请上传并填写参考信息")
with gr.Row():
inp_ref = gr.Audio(label="请上传参考音频", type="filepath")
prompt_text= gr.Textbox(label="参考音频的文本",value="")
prompt_language= gr.Dropdown(label="参考音频的语种",choices=["中文","英文","日文"],value="中文")
gr.Markdown(
value=
"*请填写需要合成的目标文本"
)
prompt_text = gr.Textbox(label="参考音频的文本", value="")
prompt_language = gr.Dropdown(
label="参考音频的语种", choices=["中文", "英文", "日文"], value="中文"
)
gr.Markdown(value="*请填写需要合成的目标文本")
with gr.Row():
text=gr.Textbox(label="需要合成的文本",value="")
text_language = gr.Dropdown(label="需要合成的语种", choices=["中文", "英文", "日文"],value="中文")
inference_button=gr.Button("合成语音", variant="primary")
text = gr.Textbox(label="需要合成的文本", value="")
text_language = gr.Dropdown(
label="需要合成的语种", choices=["中文", "英文", "日文"], value="中文"
)
inference_button = gr.Button("合成语音", variant="primary")
output = gr.Audio(label="输出的语音")
inference_button.click(get_tts_wav, [inp_ref, prompt_text,prompt_language, text,text_language], [output])
gr.Markdown(
value=
"文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
inference_button.click(
get_tts_wav,
[inp_ref, prompt_text, prompt_language, text, text_language],
[output],
)
gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
with gr.Row():
text_inp=gr.Textbox(label="需要合成的切分前文本",value="")
text_inp = gr.Textbox(label="需要合成的切分前文本", value="")
button1 = gr.Button("凑五句一切", variant="primary")
button2 = gr.Button("凑50字一切", variant="primary")
button3 = gr.Button("按中文句号。切", variant="primary")
text_opt = gr.Textbox(label="切分后文本", value="")
button1.click(cut1,[text_inp],[text_opt])
button2.click(cut2,[text_inp],[text_opt])
button3.click(cut3,[text_inp],[text_opt])
gr.Markdown(
value=
"后续将支持混合语种编码文本输入。"
)
button1.click(cut1, [text_inp], [text_opt])
button2.click(cut2, [text_inp], [text_opt])
button3.click(cut3, [text_inp], [text_opt])
gr.Markdown(value="后续将支持混合语种编码文本输入。")
app.queue(concurrency_count=511, max_size=1022).launch(
server_name="0.0.0.0",
inbrowser=True,
server_port=infer_ttswebui,
quiet=True,
)
)

View File

@ -1,11 +1,12 @@
import os
import sys
import traceback
from collections import OrderedDict
import torch
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
def savee(ckpt, name, epoch, steps, hps):
try:
opt = OrderedDict()
@ -15,8 +16,8 @@ def savee(ckpt, name, epoch, steps, hps):
continue
opt["weight"][key] = ckpt[key].half()
opt["config"] = hps
opt["info"] = "%sepoch_%siteration" % (epoch,steps)
torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir,name))
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
return "Success."
except:
return traceback.format_exc()

View File

@ -2,56 +2,84 @@
import os
import pdb
if("_CUDA_VISIBLE_DEVICES"in os.environ):
os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"]
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
import argparse
import logging
from pathlib import Path
import torch,platform
import torch, platform
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger#WandbLogger
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
logging.getLogger('numba').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
torch.set_float32_matmul_precision('high')
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
torch.set_float32_matmul_precision("high")
from AR.utils import get_newest_ckpt
from collections import OrderedDict
class my_model_ckpt(ModelCheckpoint):
def __init__(self,config,if_save_latest,if_save_every_weights,half_weights_save_dir,exp_name,**kwargs):
def __init__(
self,
config,
if_save_latest,
if_save_every_weights,
half_weights_save_dir,
exp_name,
**kwargs
):
super().__init__(**kwargs)
self.if_save_latest=if_save_latest
self.if_save_every_weights=if_save_every_weights
self.half_weights_save_dir=half_weights_save_dir
self.exp_name=exp_name
self.config=config
self.if_save_latest = if_save_latest
self.if_save_every_weights = if_save_every_weights
self.half_weights_save_dir = half_weights_save_dir
self.exp_name = exp_name
self.config = config
def on_train_epoch_end(self, trainer, pl_module):
if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
if not self._should_skip_saving_checkpoint(
trainer
) and self._should_save_on_train_epoch_end(trainer):
monitor_candidates = self._monitor_candidates(trainer)
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
if(self.if_save_latest==True):####如果设置只保存最后一个ckpt在保存下一个ckpt后要清理掉之前的所有ckpt
to_clean=list(os.listdir(self.dirpath))
if (
self._every_n_epochs >= 1
and (trainer.current_epoch + 1) % self._every_n_epochs == 0
):
if (
self.if_save_latest == True
): ####如果设置只保存最后一个ckpt在保存下一个ckpt后要清理掉之前的所有ckpt
to_clean = list(os.listdir(self.dirpath))
self._save_topk_checkpoint(trainer, monitor_candidates)
if (self.if_save_latest == True):
if self.if_save_latest == True:
for name in to_clean:
try:
os.remove("%s/%s"%(self.dirpath,name))
except:pass
if(self.if_save_every_weights==True):
to_save_od=OrderedDict()
to_save_od["weight"]=OrderedDict()
dictt=trainer.strategy._lightning_module.state_dict()
for key in dictt:to_save_od["weight"][key]=dictt[key].half()
to_save_od["config"]=self.config
to_save_od["info"]="GPT-e%s"%(trainer.current_epoch+1)
torch.save(to_save_od,"%s/%s-e%s.ckpt"%(self.half_weights_save_dir,self.exp_name,trainer.current_epoch+1))
os.remove("%s/%s" % (self.dirpath, name))
except:
pass
if self.if_save_every_weights == True:
to_save_od = OrderedDict()
to_save_od["weight"] = OrderedDict()
dictt = trainer.strategy._lightning_module.state_dict()
for key in dictt:
to_save_od["weight"][key] = dictt[key].half()
to_save_od["config"] = self.config
to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
torch.save(
to_save_od,
"%s/%s-e%s.ckpt"
% (
self.half_weights_save_dir,
self.exp_name,
trainer.current_epoch + 1,
),
)
self._save_last_checkpoint(trainer, monitor_candidates)
@ -61,41 +89,45 @@ def main(args):
output_dir = Path(config["output_dir"])
output_dir.mkdir(parents=True, exist_ok=True)
ckpt_dir = output_dir / 'ckpt'
ckpt_dir = output_dir / "ckpt"
ckpt_dir.mkdir(parents=True, exist_ok=True)
seed_everything(config["train"]["seed"], workers=True)
ckpt_callback: ModelCheckpoint = my_model_ckpt(
config=config,
if_save_latest=config["train"]["if_save_latest"], if_save_every_weights=config["train"]["if_save_every_weights"], half_weights_save_dir=config["train"]["half_weights_save_dir"], exp_name=config["train"]["exp_name"],
if_save_latest=config["train"]["if_save_latest"],
if_save_every_weights=config["train"]["if_save_every_weights"],
half_weights_save_dir=config["train"]["half_weights_save_dir"],
exp_name=config["train"]["exp_name"],
save_top_k=-1,
monitor='top_3_acc',
mode='max',
monitor="top_3_acc",
mode="max",
save_on_train_epoch_end=True,
every_n_epochs=config["train"]["save_every_n_epoch"],
dirpath=ckpt_dir,
)
logger = TensorBoardLogger(
name=output_dir.stem,
save_dir=output_dir
)
logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
trainer: Trainer = Trainer(
max_epochs=config["train"]["epochs"],
accelerator='gpu',
accelerator="gpu",
# val_check_interval=9999999999999999999999,###不要验证
# check_val_every_n_epoch=None,
limit_val_batches=0,
devices=-1,
benchmark=False,
fast_dev_run=False,
strategy=DDPStrategy(process_group_backend="nccl"if platform.system()!="Windows"else "gloo"),
strategy=DDPStrategy(
process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
),
precision=config["train"]["precision"],
logger=logger,num_sanity_val_steps=0,
callbacks=[ckpt_callback])
logger=logger,
num_sanity_val_steps=0,
callbacks=[ckpt_callback],
)
model: Text2SemanticLightningModule = Text2SemanticLightningModule(
config, output_dir)
config, output_dir
)
data_module: Text2SemanticDataModule = Text2SemanticDataModule(
config,
@ -116,14 +148,15 @@ def main(args):
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'-c',
'--config_file',
"-c",
"--config_file",
type=str,
default='configs/s1longer.yaml',
help='path of config file')
default="configs/s1longer.yaml",
help="path of config file",
)
# args for dataset
# parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv')
# parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt')

View File

@ -1,4 +1,5 @@
import utils,os
import utils, os
hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import torch
@ -6,11 +7,12 @@ from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist,traceback
import torch.distributed as dist, traceback
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import logging,traceback
import logging, traceback
logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
@ -20,37 +22,42 @@ from module import commons
from module.data_utils import (
TextAudioSpeakerLoader,
TextAudioSpeakerCollate,
DistributedBucketSampler
DistributedBucketSampler,
)
from module.models import (
SynthesizerTrn,
MultiPeriodDiscriminator,
)
from module.losses import (
generator_loss,
discriminator_loss,
feature_loss,
kl_loss
)
from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from process_ckpt import savee
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
###反正A100fp32更快那试试tf32吧
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('medium')#最低精度但最快(也就快一丁点),对于结果造成不了影响
torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响
# from config import pretrained_s2G,pretrained_s2D
global_step = 0
def main():
"""Assume Single Node Multi GPUs Training Only"""
assert torch.cuda.is_available(), "CPU training is not allowed."
n_gpus = torch.cuda.device_count()
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = str(randint(20000, 55555))
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(randint(20000, 55555))
mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))
mp.spawn(
run,
nprocs=n_gpus,
args=(
n_gpus,
hps,
),
)
def run(rank, n_gpus, hps):
@ -62,21 +69,54 @@ def run(rank, n_gpus, hps):
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
dist.init_process_group(backend='gloo' if os.name == 'nt' else 'nccl', init_method='env://', world_size=n_gpus,rank=rank)
dist.init_process_group(
backend="gloo" if os.name == "nt" else "nccl",
init_method="env://",
world_size=n_gpus,
rank=rank,
)
torch.manual_seed(hps.train.seed)
torch.cuda.set_device(rank)
train_dataset = TextAudioSpeakerLoader(hps.data)########
train_dataset = TextAudioSpeakerLoader(hps.data) ########
train_sampler = DistributedBucketSampler(
train_dataset,
hps.train.batch_size,
[32, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900],
[
32,
300,
400,
500,
600,
700,
800,
900,
1000,
1100,
1200,
1300,
1400,
1500,
1600,
1700,
1800,
1900,
],
num_replicas=n_gpus,
rank=rank,
shuffle=True)
shuffle=True,
)
collate_fn = TextAudioSpeakerCollate()
train_loader = DataLoader(train_dataset, num_workers=6, shuffle=False, pin_memory=True,
collate_fn=collate_fn, batch_sampler=train_sampler,persistent_workers=True,prefetch_factor=16)
train_loader = DataLoader(
train_dataset,
num_workers=6,
shuffle=False,
pin_memory=True,
collate_fn=collate_fn,
batch_sampler=train_sampler,
persistent_workers=True,
prefetch_factor=16,
)
# if rank == 0:
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
# eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
@ -87,17 +127,21 @@ def run(rank, n_gpus, hps):
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model).cuda(rank)
**hps.model,
).cuda(rank)
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
for name, param in net_g.named_parameters():
if not param.requires_grad:
print(name,"not requires_grad")
print(name, "not requires_grad")
te_p = list(map(id, net_g.enc_p.text_embedding.parameters()))
et_p = list(map(id, net_g.enc_p.encoder_text.parameters()))
mrte_p = list(map(id, net_g.enc_p.mrte.parameters()))
base_params = filter(lambda p: id(p) not in te_p+et_p+mrte_p and p.requires_grad, net_g.parameters())
base_params = filter(
lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad,
net_g.parameters(),
)
# te_p=net_g.enc_p.text_embedding.parameters()
# et_p=net_g.enc_p.encoder_text.parameters()
@ -106,31 +150,46 @@ def run(rank, n_gpus, hps):
optim_g = torch.optim.AdamW(
# filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致
[
{"params":base_params,"lr":hps.train.learning_rate},
{"params":net_g.enc_p.text_embedding.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate},
{"params":net_g.enc_p.encoder_text.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate},
{"params":net_g.enc_p.mrte.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate},
{"params": base_params, "lr": hps.train.learning_rate},
{
"params": net_g.enc_p.text_embedding.parameters(),
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
},
{
"params": net_g.enc_p.encoder_text.parameters(),
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
},
{
"params": net_g.enc_p.mrte.parameters(),
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
},
],
hps.train.learning_rate,
betas=hps.train.betas,
eps=hps.train.eps)
eps=hps.train.eps,
)
optim_d = torch.optim.AdamW(
net_d.parameters(),
hps.train.learning_rate,
betas=hps.train.betas,
eps=hps.train.eps)
net_g = DDP(net_g, device_ids=[rank],find_unused_parameters=True)
net_d = DDP(net_d, device_ids=[rank],find_unused_parameters=True)
eps=hps.train.eps,
)
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
try: # 如果能加载自动resume
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2"%hps.data.exp_dir, "D_*.pth"), net_d, optim_d
utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "D_*.pth"),
net_d,
optim_d,
) # D多半加载没事
if rank == 0:
logger.info("loaded D")
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2"%hps.data.exp_dir, "G_*.pth"), net_g, optim_g
utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "G_*.pth"),
net_g,
optim_g,
)
global_step = (epoch_str - 1) * len(train_loader)
# epoch_str = 1
@ -144,7 +203,8 @@ def run(rank, n_gpus, hps):
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
print(
net_g.module.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],strict=False
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
)
) ##测试不加载优化器
if hps.train.pretrained_s2D != "":
@ -159,8 +219,12 @@ def run(rank, n_gpus, hps):
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=-1)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
optim_d, gamma=hps.train.lr_decay, last_epoch=-1
)
for _ in range(epoch_str):
scheduler_g.step()
scheduler_d.step()
@ -169,17 +233,39 @@ def run(rank, n_gpus, hps):
for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0:
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
# [train_loader, eval_loader], logger, [writer, writer_eval])
[train_loader, None], logger, [writer, writer_eval])
train_and_evaluate(
rank,
epoch,
hps,
[net_g, net_d],
[optim_g, optim_d],
[scheduler_g, scheduler_d],
scaler,
# [train_loader, eval_loader], logger, [writer, writer_eval])
[train_loader, None],
logger,
[writer, writer_eval],
)
else:
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
[train_loader, None], None, None)
train_and_evaluate(
rank,
epoch,
hps,
[net_g, net_d],
[optim_g, optim_d],
[scheduler_g, scheduler_d],
scaler,
[train_loader, None],
None,
None,
)
scheduler_g.step()
scheduler_d.step()
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
net_g, net_d = nets
optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers
@ -192,17 +278,39 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
net_g.train()
net_d.train()
for batch_idx, (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths) in tqdm(enumerate(train_loader)):
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True)
for batch_idx, (
ssl,
ssl_lengths,
spec,
spec_lengths,
y,
y_lengths,
text,
text_lengths,
) in tqdm(enumerate(train_loader)):
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
rank, non_blocking=True
)
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
rank, non_blocking=True
)
ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad=False
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
rank, non_blocking=True
)
with autocast(enabled=hps.train.fp16_run):
y_hat, kl_ssl, ids_slice, x_mask, z_mask, \
(z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl = net_g(ssl, spec, spec_lengths, text, text_lengths)
(
y_hat,
kl_ssl,
ids_slice,
x_mask,
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
stats_ssl,
) = net_g(ssl, spec, spec_lengths, text, text_lengths)
mel = spec_to_mel_torch(
spec,
@ -210,8 +318,11 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
hps.data.n_mel_channels,
hps.data.sampling_rate,
hps.data.mel_fmin,
hps.data.mel_fmax)
y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
hps.data.mel_fmax,
)
y_mel = commons.slice_segments(
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
)
y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1),
hps.data.filter_length,
@ -220,15 +331,19 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
hps.data.hop_length,
hps.data.win_length,
hps.data.mel_fmin,
hps.data.mel_fmax
hps.data.mel_fmax,
)
y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
y = commons.slice_segments(
y, ids_slice * hps.data.hop_length, hps.train.segment_size
) # slice
# Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
with autocast(enabled=False):
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
y_d_hat_r, y_d_hat_g
)
loss_disc_all = loss_disc
optim_d.zero_grad()
scaler.scale(loss_disc_all).backward()
@ -256,32 +371,54 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
if rank == 0:
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]['lr']
lr = optim_g.param_groups[0]["lr"]
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
logger.info('Train Epoch: {} [{:.0f}%]'.format(
epoch,
100. * batch_idx / len(train_loader)))
logger.info(
"Train Epoch: {} [{:.0f}%]".format(
epoch, 100.0 * batch_idx / len(train_loader)
)
)
logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr,
"grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
scalar_dict = {
"loss/g/total": loss_gen_all,
"loss/d/total": loss_disc_all,
"learning_rate": lr,
"grad_norm_d": grad_norm_d,
"grad_norm_g": grad_norm_g,
}
scalar_dict.update(
{"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl_ssl": kl_ssl, "loss/g/kl": loss_kl})
{
"loss/g/fm": loss_fm,
"loss/g/mel": loss_mel,
"loss/g/kl_ssl": kl_ssl,
"loss/g/kl": loss_kl,
}
)
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
image_dict = {
"slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
"slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
"all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
"all/stats_ssl": utils.plot_spectrogram_to_numpy(stats_ssl[0].data.cpu().numpy()),
"slice/mel_org": utils.plot_spectrogram_to_numpy(
y_mel[0].data.cpu().numpy()
),
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].data.cpu().numpy()
),
"all/mel": utils.plot_spectrogram_to_numpy(
mel[0].data.cpu().numpy()
),
"all/stats_ssl": utils.plot_spectrogram_to_numpy(
stats_ssl[0].data.cpu().numpy()
),
}
utils.summarize(
writer=writer,
global_step=global_step,
images=image_dict,
scalars=scalar_dict)
scalars=scalar_dict,
)
global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if hps.train.if_save_latest == 0:
@ -290,14 +427,18 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
optim_g,
hps.train.learning_rate,
epoch,
os.path.join("%s/logs_s2"%hps.data.exp_dir, "G_{}.pth".format(global_step)),
os.path.join(
"%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(global_step)
),
)
utils.save_checkpoint(
net_d,
optim_d,
hps.train.learning_rate,
epoch,
os.path.join("%s/logs_s2"%hps.data.exp_dir, "D_{}.pth".format(global_step)),
os.path.join(
"%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(global_step)
),
)
else:
utils.save_checkpoint(
@ -305,14 +446,18 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
optim_g,
hps.train.learning_rate,
epoch,
os.path.join("%s/logs_s2"%hps.data.exp_dir, "G_{}.pth".format(233333333333)),
os.path.join(
"%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(233333333333)
),
)
utils.save_checkpoint(
net_d,
optim_d,
hps.train.learning_rate,
epoch,
os.path.join("%s/logs_s2"%hps.data.exp_dir, "D_{}.pth".format(233333333333)),
os.path.join(
"%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(233333333333)
),
)
if rank == 0 and hps.train.if_save_every_weights == True:
if hasattr(net_g, "module"):
@ -334,11 +479,8 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
)
)
if rank == 0:
logger.info('====> Epoch: {}'.format(epoch))
logger.info("====> Epoch: {}".format(epoch))
def evaluate(hps, generator, eval_loader, writer_eval):
@ -347,15 +489,25 @@ def evaluate(hps, generator, eval_loader, writer_eval):
audio_dict = {}
print("Evaluating ...")
with torch.no_grad():
for batch_idx, (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths) in enumerate(eval_loader):
for batch_idx, (
ssl,
ssl_lengths,
spec,
spec_lengths,
y,
y_lengths,
text,
text_lengths,
) in enumerate(eval_loader):
print(111)
spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
y, y_lengths = y.cuda(), y_lengths.cuda()
ssl = ssl.cuda()
text, text_lengths = text.cuda(), text_lengths.cuda()
for test in [0, 1]:
y_hat, mask, *_ = generator.module.infer(ssl,spec, spec_lengths,text, text_lengths, test=test)
y_hat, mask, *_ = generator.module.infer(
ssl, spec, spec_lengths, text, text_lengths, test=test
)
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
mel = spec_to_mel_torch(
@ -364,7 +516,8 @@ def evaluate(hps, generator, eval_loader, writer_eval):
hps.data.n_mel_channels,
hps.data.sampling_rate,
hps.data.mel_fmin,
hps.data.mel_fmax)
hps.data.mel_fmax,
)
y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1).float(),
hps.data.filter_length,
@ -373,16 +526,26 @@ def evaluate(hps, generator, eval_loader, writer_eval):
hps.data.hop_length,
hps.data.win_length,
hps.data.mel_fmin,
hps.data.mel_fmax
hps.data.mel_fmax,
)
image_dict.update({
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
})
audio_dict.update({
f"gen/audio_{batch_idx}_{test}": y_hat[0, :, :y_hat_lengths[0]]
})
image_dict.update({f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())})
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, :y_lengths[0]]})
image_dict.update(
{
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].cpu().numpy()
)
}
)
audio_dict.update(
{f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
)
image_dict.update(
{
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
mel[0].cpu().numpy()
)
}
)
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
# y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None)
# audio_dict.update({
@ -394,9 +557,10 @@ def evaluate(hps, generator, eval_loader, writer_eval):
global_step=global_step,
images=image_dict,
audios=audio_dict,
audio_sampling_rate=hps.data.sampling_rate
audio_sampling_rate=hps.data.sampling_rate,
)
generator.train()
if __name__ == "__main__":
main()

View File

@ -12,8 +12,9 @@ import numpy as np
from scipy.io.wavfile import read
import torch
import logging
logging.getLogger('numba').setLevel(logging.ERROR)
logging.getLogger('matplotlib').setLevel(logging.ERROR)
logging.getLogger("numba").setLevel(logging.ERROR)
logging.getLogger("matplotlib").setLevel(logging.ERROR)
MATPLOTLIB_FLAG = False
@ -23,13 +24,17 @@ logger = logging
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
iteration = checkpoint_dict['iteration']
learning_rate = checkpoint_dict['learning_rate']
if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
optimizer.load_state_dict(checkpoint_dict['optimizer'])
saved_state_dict = checkpoint_dict['model']
if hasattr(model, 'module'):
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
iteration = checkpoint_dict["iteration"]
learning_rate = checkpoint_dict["learning_rate"]
if (
optimizer is not None
and not skip_optimizer
and checkpoint_dict["optimizer"] is not None
):
optimizer.load_state_dict(checkpoint_dict["optimizer"])
saved_state_dict = checkpoint_dict["model"]
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
@ -39,41 +44,63 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
# assert "quantizer" not in k
# print("load", k)
new_state_dict[k] = saved_state_dict[k]
assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
assert saved_state_dict[k].shape == v.shape, (
saved_state_dict[k].shape,
v.shape,
)
except:
traceback.print_exc()
print("error, %s is not in the checkpoint" % k)#shape不对也会比如text_embedding当cleaner修改时
print(
"error, %s is not in the checkpoint" % k
) # shape不对也会比如text_embedding当cleaner修改时
new_state_dict[k] = v
if hasattr(model, 'module'):
if hasattr(model, "module"):
model.module.load_state_dict(new_state_dict)
else:
model.load_state_dict(new_state_dict)
print("load ")
logger.info("Loaded checkpoint '{}' (iteration {})".format(
checkpoint_path, iteration))
logger.info(
"Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
)
return model, optimizer, learning_rate, iteration
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
logger.info("Saving model and optimizer state at iteration {} to {}".format(
iteration, checkpoint_path))
if hasattr(model, 'module'):
logger.info(
"Saving model and optimizer state at iteration {} to {}".format(
iteration, checkpoint_path
)
)
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save({'model': state_dict,
'iteration': iteration,
'optimizer': optimizer.state_dict(),
'learning_rate': learning_rate}, checkpoint_path)
torch.save(
{
"model": state_dict,
"iteration": iteration,
"optimizer": optimizer.state_dict(),
"learning_rate": learning_rate,
},
checkpoint_path,
)
def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
def summarize(
writer,
global_step,
scalars={},
histograms={},
images={},
audios={},
audio_sampling_rate=22050,
):
for k, v in scalars.items():
writer.add_scalar(k, v, global_step)
for k, v in histograms.items():
writer.add_histogram(k, v, global_step)
for k, v in images.items():
writer.add_image(k, v, global_step, dataformats='HWC')
writer.add_image(k, v, global_step, dataformats="HWC")
for k, v in audios.items():
writer.add_audio(k, v, global_step, audio_sampling_rate)
@ -90,23 +117,23 @@ def plot_spectrogram_to_numpy(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger('matplotlib')
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
interpolation='none')
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
@ -116,26 +143,28 @@ def plot_alignment_to_numpy(alignment, info=None):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger('matplotlib')
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
interpolation='none')
im = ax.imshow(
alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
)
fig.colorbar(im, ax=ax)
xlabel = 'Decoder timestep'
xlabel = "Decoder timestep"
if info is not None:
xlabel += '\n\n' + info
xlabel += "\n\n" + info
plt.xlabel(xlabel)
plt.ylabel('Encoder timestep')
plt.ylabel("Encoder timestep")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
@ -147,16 +176,31 @@ def load_wav_to_torch(full_path):
def load_filepaths_and_text(filename, split="|"):
with open(filename, encoding='utf-8') as f:
with open(filename, encoding="utf-8") as f:
filepaths_and_text = [line.strip().split(split) for line in f]
return filepaths_and_text
def get_hparams(init=True, stage=1):
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default="./configs/s2.json",help='JSON file for configuration')
parser.add_argument('-p', '--pretrain', type=str, required=False,default=None,help='pretrain dir')
parser.add_argument('-rs', '--resume_step', type=int, required=False,default=None,help='resume step')
parser.add_argument(
"-c",
"--config",
type=str,
default="./configs/s2.json",
help="JSON file for configuration",
)
parser.add_argument(
"-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir"
)
parser.add_argument(
"-rs",
"--resume_step",
type=int,
required=False,
default=None,
help="resume step",
)
# parser.add_argument('-e', '--exp_dir', type=str, required=False,default=None,help='experiment directory')
# parser.add_argument('-g', '--pretrained_s2G', type=str, required=False,default=None,help='pretrained sovits gererator weights')
# parser.add_argument('-d', '--pretrained_s2D', type=str, required=False,default=None,help='pretrained sovits discriminator weights')
@ -172,7 +216,7 @@ def get_hparams(init=True, stage=1):
hparams.pretrain = args.pretrain
hparams.resume_step = args.resume_step
# hparams.data.exp_dir = args.exp_dir
if stage ==1:
if stage == 1:
model_dir = hparams.s1_ckpt_dir
else:
model_dir = hparams.s2_ckpt_dir
@ -186,29 +230,38 @@ def get_hparams(init=True, stage=1):
return hparams
def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True):
def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
"""Freeing up space by deleting saved ckpts
Arguments:
path_to_models -- Path to the model directory
n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
sort_by_time -- True -> chronologically delete ckpts
False -> lexicographically delete ckpts
"""
Arguments:
path_to_models -- Path to the model directory
n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
sort_by_time -- True -> chronologically delete ckpts
False -> lexicographically delete ckpts
"""
import re
ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1)))
time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
ckpts_files = [
f
for f in os.listdir(path_to_models)
if os.path.isfile(os.path.join(path_to_models, f))
]
name_key = lambda _f: int(re.compile("._(\d+)\.pth").match(_f).group(1))
time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))
sort_key = time_key if sort_by_time else name_key
x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')],
key=sort_key)
to_del = [os.path.join(path_to_models, fn) for fn in
(x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
x_sorted = lambda _x: sorted(
[f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
key=sort_key,
)
to_del = [
os.path.join(path_to_models, fn)
for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
]
del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
del_routine = lambda x: [os.remove(x), del_info(x)]
rs = [del_routine(fn) for fn in to_del]
def get_hparams_from_dir(model_dir):
config_save_path = os.path.join(model_dir, "config.json")
with open(config_save_path, "r") as f:
@ -228,12 +281,15 @@ def get_hparams_from_file(config_path):
hparams = HParams(**config)
return hparams
def check_git_hash(model_dir):
source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")):
logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
source_dir
))
logger.warn(
"{} is not a git repository, therefore hash value comparison will be ignored.".format(
source_dir
)
)
return
cur_hash = subprocess.getoutput("git rev-parse HEAD")
@ -242,8 +298,11 @@ def check_git_hash(model_dir):
if os.path.exists(path):
saved_hash = open(path).read()
if saved_hash != cur_hash:
logger.warn("git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8], cur_hash[:8]))
logger.warn(
"git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8], cur_hash[:8]
)
)
else:
open(path, "w").write(cur_hash)
@ -263,7 +322,7 @@ def get_logger(model_dir, filename="train.log"):
return logger
class HParams():
class HParams:
def __init__(self, **kwargs):
for k, v in kwargs.items():
if type(v) == dict:
@ -294,5 +353,10 @@ class HParams():
def __repr__(self):
return self.__dict__.__repr__()
if __name__ == '__main__':
print(load_wav_to_torch('/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac'))
if __name__ == "__main__":
print(
load_wav_to_torch(
"/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac"
)
)