mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
Code refactor + remove unused imports
This commit is contained in:
parent
9031ac9a92
commit
0d92575115
@ -1,49 +1,55 @@
|
|||||||
import os
|
import os
|
||||||
gpt_path=os.environ.get("gpt_path","pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
|
|
||||||
sovits_path=os.environ.get("sovits_path","pretrained_models/s2G488k.pth")
|
gpt_path = os.environ.get(
|
||||||
cnhubert_base_path=os.environ.get("cnhubert_base_path","pretrained_models/chinese-hubert-base")
|
"gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
||||||
bert_path=os.environ.get("bert_path","pretrained_models/chinese-roberta-wwm-ext-large")
|
)
|
||||||
infer_ttswebui=os.environ.get("infer_ttswebui",9872)
|
sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
|
||||||
infer_ttswebui=int(infer_ttswebui)
|
cnhubert_base_path = os.environ.get(
|
||||||
if("_CUDA_VISIBLE_DEVICES"in os.environ):
|
"cnhubert_base_path", "pretrained_models/chinese-hubert-base"
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"]
|
)
|
||||||
is_half=eval(os.environ.get("is_half","True"))
|
bert_path = os.environ.get(
|
||||||
|
"bert_path", "pretrained_models/chinese-roberta-wwm-ext-large"
|
||||||
|
)
|
||||||
|
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
|
||||||
|
infer_ttswebui = int(infer_ttswebui)
|
||||||
|
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||||
|
is_half = eval(os.environ.get("is_half", "True"))
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||||
import sys,torch,numpy as np
|
import torch, numpy as np
|
||||||
from pathlib import Path
|
import os, librosa, torch
|
||||||
import os,pdb,utils,librosa,math,traceback,requests,argparse,torch,multiprocessing,pandas as pd,torch.multiprocessing as mp,soundfile
|
|
||||||
# torch.backends.cuda.sdp_kernel("flash")
|
# torch.backends.cuda.sdp_kernel("flash")
|
||||||
# torch.backends.cuda.enable_flash_sdp(True)
|
# torch.backends.cuda.enable_flash_sdp(True)
|
||||||
# torch.backends.cuda.enable_mem_efficient_sdp(True) # Not avaliable if torch version is lower than 2.0
|
# torch.backends.cuda.enable_mem_efficient_sdp(True) # Not avaliable if torch version is lower than 2.0
|
||||||
# torch.backends.cuda.enable_math_sdp(True)
|
# torch.backends.cuda.enable_math_sdp(True)
|
||||||
from random import shuffle
|
|
||||||
from AR.utils import get_newest_ckpt
|
|
||||||
from glob import glob
|
|
||||||
from tqdm import tqdm
|
|
||||||
from feature_extractor import cnhubert
|
from feature_extractor import cnhubert
|
||||||
cnhubert.cnhubert_base_path=cnhubert_base_path
|
|
||||||
from io import BytesIO
|
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||||
from module.models import SynthesizerTrn
|
from module.models import SynthesizerTrn
|
||||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||||
from AR.utils.io import load_yaml_config
|
|
||||||
from text import cleaned_text_to_sequence
|
from text import cleaned_text_to_sequence
|
||||||
from text.cleaner import text_to_sequence, clean_text
|
from text.cleaner import clean_text
|
||||||
from time import time as ttime
|
from time import time as ttime
|
||||||
from module.mel_processing import spectrogram_torch
|
from module.mel_processing import spectrogram_torch
|
||||||
from my_utils import load_audio
|
from my_utils import load_audio
|
||||||
|
|
||||||
device="cuda"
|
device = "cuda"
|
||||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||||
bert_model=AutoModelForMaskedLM.from_pretrained(bert_path)
|
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
|
||||||
if(is_half==True):bert_model=bert_model.half().to(device)
|
if is_half == True:
|
||||||
else:bert_model=bert_model.to(device)
|
bert_model = bert_model.half().to(device)
|
||||||
|
else:
|
||||||
|
bert_model = bert_model.to(device)
|
||||||
|
|
||||||
|
|
||||||
# bert_model=bert_model.to(device)
|
# bert_model=bert_model.to(device)
|
||||||
def get_bert_feature(text, word2ph):
|
def get_bert_feature(text, word2ph):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = tokenizer(text, return_tensors="pt")
|
inputs = tokenizer(text, return_tensors="pt")
|
||||||
for i in inputs:
|
for i in inputs:
|
||||||
inputs[i] = inputs[i].to(device)#####输入是long不用管精度问题,精度随bert_model
|
inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model
|
||||||
res = bert_model(**inputs, output_hidden_states=True)
|
res = bert_model(**inputs, output_hidden_states=True)
|
||||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||||
assert len(word2ph) == len(text)
|
assert len(word2ph) == len(text)
|
||||||
@ -55,9 +61,12 @@ def get_bert_feature(text, word2ph):
|
|||||||
# if(is_half==True):phone_level_feature=phone_level_feature.half()
|
# if(is_half==True):phone_level_feature=phone_level_feature.half()
|
||||||
return phone_level_feature.T
|
return phone_level_feature.T
|
||||||
|
|
||||||
|
|
||||||
n_semantic = 1024
|
n_semantic = 1024
|
||||||
dict_s2=torch.load(sovits_path,map_location="cpu")
|
dict_s2 = torch.load(sovits_path, map_location="cpu")
|
||||||
hps=dict_s2["config"]
|
hps = dict_s2["config"]
|
||||||
|
|
||||||
|
|
||||||
class DictToAttrRecursive:
|
class DictToAttrRecursive:
|
||||||
def __init__(self, input_dict):
|
def __init__(self, input_dict):
|
||||||
for key, value in input_dict.items():
|
for key, value in input_dict.items():
|
||||||
@ -67,206 +76,271 @@ class DictToAttrRecursive:
|
|||||||
else:
|
else:
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
|
||||||
hps = DictToAttrRecursive(hps)
|
hps = DictToAttrRecursive(hps)
|
||||||
hps.model.semantic_frame_rate="25hz"
|
hps.model.semantic_frame_rate = "25hz"
|
||||||
dict_s1=torch.load(gpt_path,map_location="cpu")
|
dict_s1 = torch.load(gpt_path, map_location="cpu")
|
||||||
config=dict_s1["config"]
|
config = dict_s1["config"]
|
||||||
ssl_model=cnhubert.get_model()
|
ssl_model = cnhubert.get_model()
|
||||||
if(is_half==True):ssl_model=ssl_model.half().to(device)
|
if is_half == True:
|
||||||
else:ssl_model=ssl_model.to(device)
|
ssl_model = ssl_model.half().to(device)
|
||||||
|
else:
|
||||||
|
ssl_model = ssl_model.to(device)
|
||||||
|
|
||||||
vq_model = SynthesizerTrn(
|
vq_model = SynthesizerTrn(
|
||||||
hps.data.filter_length // 2 + 1,
|
hps.data.filter_length // 2 + 1,
|
||||||
hps.train.segment_size // hps.data.hop_length,
|
hps.train.segment_size // hps.data.hop_length,
|
||||||
n_speakers=hps.data.n_speakers,
|
n_speakers=hps.data.n_speakers,
|
||||||
**hps.model)
|
**hps.model
|
||||||
if(is_half==True):vq_model=vq_model.half().to(device)
|
)
|
||||||
else:vq_model=vq_model.to(device)
|
if is_half == True:
|
||||||
|
vq_model = vq_model.half().to(device)
|
||||||
|
else:
|
||||||
|
vq_model = vq_model.to(device)
|
||||||
vq_model.eval()
|
vq_model.eval()
|
||||||
print(vq_model.load_state_dict(dict_s2["weight"],strict=False))
|
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
||||||
hz = 50
|
hz = 50
|
||||||
max_sec = config['data']['max_sec']
|
max_sec = config["data"]["max_sec"]
|
||||||
# t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
|
# t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
|
||||||
t2s_model = Text2SemanticLightningModule(config,"ojbk",is_train=False)
|
t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
|
||||||
t2s_model.load_state_dict(dict_s1["weight"])
|
t2s_model.load_state_dict(dict_s1["weight"])
|
||||||
if(is_half==True):t2s_model=t2s_model.half()
|
if is_half == True:
|
||||||
t2s_model=t2s_model.to(device)
|
t2s_model = t2s_model.half()
|
||||||
|
t2s_model = t2s_model.to(device)
|
||||||
t2s_model.eval()
|
t2s_model.eval()
|
||||||
total = sum([param.nelement() for param in t2s_model.parameters()])
|
total = sum([param.nelement() for param in t2s_model.parameters()])
|
||||||
print("Number of parameter: %.2fM" % (total / 1e6))
|
print("Number of parameter: %.2fM" % (total / 1e6))
|
||||||
|
|
||||||
|
|
||||||
def get_spepc(hps, filename):
|
def get_spepc(hps, filename):
|
||||||
audio=load_audio(filename,int(hps.data.sampling_rate))
|
audio = load_audio(filename, int(hps.data.sampling_rate))
|
||||||
audio=torch.FloatTensor(audio)
|
audio = torch.FloatTensor(audio)
|
||||||
audio_norm = audio
|
audio_norm = audio
|
||||||
audio_norm = audio_norm.unsqueeze(0)
|
audio_norm = audio_norm.unsqueeze(0)
|
||||||
spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False)
|
spec = spectrogram_torch(
|
||||||
|
audio_norm,
|
||||||
|
hps.data.filter_length,
|
||||||
|
hps.data.sampling_rate,
|
||||||
|
hps.data.hop_length,
|
||||||
|
hps.data.win_length,
|
||||||
|
center=False,
|
||||||
|
)
|
||||||
return spec
|
return spec
|
||||||
|
|
||||||
dict_language={
|
|
||||||
"中文":"zh",
|
dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
|
||||||
"英文":"en",
|
|
||||||
"日文":"ja"
|
|
||||||
}
|
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
|
||||||
def get_tts_wav(ref_wav_path,prompt_text,prompt_language,text,text_language):
|
|
||||||
t0 = ttime()
|
t0 = ttime()
|
||||||
prompt_text=prompt_text.strip("\n")
|
prompt_text = prompt_text.strip("\n")
|
||||||
prompt_language,text=prompt_language,text.strip("\n")
|
prompt_language, text = prompt_language, text.strip("\n")
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
|
wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
|
||||||
wav16k = torch.from_numpy(wav16k)
|
wav16k = torch.from_numpy(wav16k)
|
||||||
if(is_half==True):wav16k=wav16k.half().to(device)
|
if is_half == True:
|
||||||
else:wav16k=wav16k.to(device)
|
wav16k = wav16k.half().to(device)
|
||||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float()
|
else:
|
||||||
|
wav16k = wav16k.to(device)
|
||||||
|
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
|
||||||
|
"last_hidden_state"
|
||||||
|
].transpose(
|
||||||
|
1, 2
|
||||||
|
) # .float()
|
||||||
codes = vq_model.extract_latent(ssl_content)
|
codes = vq_model.extract_latent(ssl_content)
|
||||||
prompt_semantic = codes[0, 0]
|
prompt_semantic = codes[0, 0]
|
||||||
t1 = ttime()
|
t1 = ttime()
|
||||||
prompt_language=dict_language[prompt_language]
|
prompt_language = dict_language[prompt_language]
|
||||||
text_language=dict_language[text_language]
|
text_language = dict_language[text_language]
|
||||||
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
|
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
|
||||||
phones1=cleaned_text_to_sequence(phones1)
|
phones1 = cleaned_text_to_sequence(phones1)
|
||||||
texts=text.split("\n")
|
texts = text.split("\n")
|
||||||
audio_opt = []
|
audio_opt = []
|
||||||
zero_wav=np.zeros(int(hps.data.sampling_rate*0.3),dtype=np.float16 if is_half==True else np.float32)
|
zero_wav = np.zeros(
|
||||||
|
int(hps.data.sampling_rate * 0.3),
|
||||||
|
dtype=np.float16 if is_half == True else np.float32,
|
||||||
|
)
|
||||||
for text in texts:
|
for text in texts:
|
||||||
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
|
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
|
||||||
phones2 = cleaned_text_to_sequence(phones2)
|
phones2 = cleaned_text_to_sequence(phones2)
|
||||||
if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
|
if prompt_language == "zh":
|
||||||
else:bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device)
|
bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
|
||||||
if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
|
else:
|
||||||
else:bert2 = torch.zeros((1024, len(phones2))).to(bert1)
|
bert1 = torch.zeros(
|
||||||
|
(1024, len(phones1)),
|
||||||
|
dtype=torch.float16 if is_half == True else torch.float32,
|
||||||
|
).to(device)
|
||||||
|
if text_language == "zh":
|
||||||
|
bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
|
||||||
|
else:
|
||||||
|
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
|
||||||
bert = torch.cat([bert1, bert2], 1)
|
bert = torch.cat([bert1, bert2], 1)
|
||||||
|
|
||||||
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
|
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
||||||
bert = bert.to(device).unsqueeze(0)
|
bert = bert.to(device).unsqueeze(0)
|
||||||
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
||||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||||
t2 = ttime()
|
t2 = ttime()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# pred_semantic = t2s_model.model.infer(
|
# pred_semantic = t2s_model.model.infer(
|
||||||
pred_semantic,idx = t2s_model.model.infer_panel(
|
pred_semantic, idx = t2s_model.model.infer_panel(
|
||||||
all_phoneme_ids,
|
all_phoneme_ids,
|
||||||
all_phoneme_len,
|
all_phoneme_len,
|
||||||
prompt,
|
prompt,
|
||||||
bert,
|
bert,
|
||||||
# prompt_phone_len=ph_offset,
|
# prompt_phone_len=ph_offset,
|
||||||
top_k=config['inference']['top_k'],
|
top_k=config["inference"]["top_k"],
|
||||||
early_stop_num=hz * max_sec)
|
early_stop_num=hz * max_sec,
|
||||||
|
)
|
||||||
t3 = ttime()
|
t3 = ttime()
|
||||||
# print(pred_semantic.shape,idx)
|
# print(pred_semantic.shape,idx)
|
||||||
pred_semantic = pred_semantic[:,-idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
|
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
|
||||||
refer = get_spepc(hps, ref_wav_path)#.to(device)
|
0
|
||||||
if(is_half==True):refer=refer.half().to(device)
|
) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||||
else:refer=refer.to(device)
|
refer = get_spepc(hps, ref_wav_path) # .to(device)
|
||||||
|
if is_half == True:
|
||||||
|
refer = refer.half().to(device)
|
||||||
|
else:
|
||||||
|
refer = refer.to(device)
|
||||||
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
|
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
|
||||||
audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]###试试重建不带上prompt部分
|
audio = (
|
||||||
|
vq_model.decode(
|
||||||
|
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
|
||||||
|
)
|
||||||
|
.detach()
|
||||||
|
.cpu()
|
||||||
|
.numpy()[0, 0]
|
||||||
|
) ###试试重建不带上prompt部分
|
||||||
audio_opt.append(audio)
|
audio_opt.append(audio)
|
||||||
audio_opt.append(zero_wav)
|
audio_opt.append(zero_wav)
|
||||||
t4 = ttime()
|
t4 = ttime()
|
||||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
||||||
yield hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16)
|
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
|
||||||
|
np.int16
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
splits = {
|
||||||
|
",",
|
||||||
|
"。",
|
||||||
|
"?",
|
||||||
|
"!",
|
||||||
|
",",
|
||||||
|
".",
|
||||||
|
"?",
|
||||||
|
"!",
|
||||||
|
"~",
|
||||||
|
":",
|
||||||
|
":",
|
||||||
|
"—",
|
||||||
|
"…",
|
||||||
|
} # 不考虑省略号
|
||||||
|
|
||||||
|
|
||||||
splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}#不考虑省略号
|
|
||||||
def split(todo_text):
|
def split(todo_text):
|
||||||
todo_text = todo_text.replace("……", "。").replace("——", ",")
|
todo_text = todo_text.replace("……", "。").replace("——", ",")
|
||||||
if (todo_text[-1] not in splits): todo_text += "。"
|
if todo_text[-1] not in splits:
|
||||||
|
todo_text += "。"
|
||||||
i_split_head = i_split_tail = 0
|
i_split_head = i_split_tail = 0
|
||||||
len_text = len(todo_text)
|
len_text = len(todo_text)
|
||||||
todo_texts = []
|
todo_texts = []
|
||||||
while (1):
|
while 1:
|
||||||
if (i_split_head >= len_text): break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
|
if i_split_head >= len_text:
|
||||||
if (todo_text[i_split_head] in splits):
|
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
|
||||||
|
if todo_text[i_split_head] in splits:
|
||||||
i_split_head += 1
|
i_split_head += 1
|
||||||
todo_texts.append(todo_text[i_split_tail:i_split_head])
|
todo_texts.append(todo_text[i_split_tail:i_split_head])
|
||||||
i_split_tail = i_split_head
|
i_split_tail = i_split_head
|
||||||
else:
|
else:
|
||||||
i_split_head += 1
|
i_split_head += 1
|
||||||
return todo_texts
|
return todo_texts
|
||||||
|
|
||||||
|
|
||||||
def cut1(inp):
|
def cut1(inp):
|
||||||
inp=inp.strip("\n")
|
inp = inp.strip("\n")
|
||||||
inps=split(inp)
|
inps = split(inp)
|
||||||
split_idx=list(range(0,len(inps),5))
|
split_idx = list(range(0, len(inps), 5))
|
||||||
split_idx[-1]=None
|
split_idx[-1] = None
|
||||||
if(len(split_idx)>1):
|
if len(split_idx) > 1:
|
||||||
opts=[]
|
opts = []
|
||||||
for idx in range(len(split_idx)-1):
|
for idx in range(len(split_idx) - 1):
|
||||||
opts.append("".join(inps[split_idx[idx]:split_idx[idx+1]]))
|
opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
|
||||||
else:
|
else:
|
||||||
opts=[inp]
|
opts = [inp]
|
||||||
return "\n".join(opts)
|
return "\n".join(opts)
|
||||||
|
|
||||||
|
|
||||||
def cut2(inp):
|
def cut2(inp):
|
||||||
inp=inp.strip("\n")
|
inp = inp.strip("\n")
|
||||||
inps=split(inp)
|
inps = split(inp)
|
||||||
if(len(inps)<2):return [inp]
|
if len(inps) < 2:
|
||||||
opts=[]
|
return [inp]
|
||||||
summ=0
|
opts = []
|
||||||
tmp_str=""
|
summ = 0
|
||||||
|
tmp_str = ""
|
||||||
for i in range(len(inps)):
|
for i in range(len(inps)):
|
||||||
summ+=len(inps[i])
|
summ += len(inps[i])
|
||||||
tmp_str+=inps[i]
|
tmp_str += inps[i]
|
||||||
if(summ>50):
|
if summ > 50:
|
||||||
summ=0
|
summ = 0
|
||||||
opts.append(tmp_str)
|
opts.append(tmp_str)
|
||||||
tmp_str=""
|
tmp_str = ""
|
||||||
if(tmp_str!=""):opts.append(tmp_str)
|
if tmp_str != "":
|
||||||
if(len(opts[-1])<50):##如果最后一个太短了,和前一个合一起
|
opts.append(tmp_str)
|
||||||
opts[-2]=opts[-2]+opts[-1]
|
if len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
|
||||||
opts=opts[:-1]
|
opts[-2] = opts[-2] + opts[-1]
|
||||||
|
opts = opts[:-1]
|
||||||
return "\n".join(opts)
|
return "\n".join(opts)
|
||||||
|
|
||||||
|
|
||||||
def cut3(inp):
|
def cut3(inp):
|
||||||
inp=inp.strip("\n")
|
inp = inp.strip("\n")
|
||||||
return "\n".join(["%s。"%item for item in inp.strip("。").split("。")])
|
return "\n".join(["%s。" % item for item in inp.strip("。").split("。")])
|
||||||
|
|
||||||
|
|
||||||
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
value=
|
value="本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
|
||||||
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
|
|
||||||
)
|
)
|
||||||
# with gr.Tabs():
|
# with gr.Tabs():
|
||||||
# with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
|
# with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
gr.Markdown(
|
gr.Markdown(value="*请上传并填写参考信息")
|
||||||
value=
|
|
||||||
"*请上传并填写参考信息"
|
|
||||||
)
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
inp_ref = gr.Audio(label="请上传参考音频", type="filepath")
|
inp_ref = gr.Audio(label="请上传参考音频", type="filepath")
|
||||||
prompt_text= gr.Textbox(label="参考音频的文本",value="")
|
prompt_text = gr.Textbox(label="参考音频的文本", value="")
|
||||||
prompt_language= gr.Dropdown(label="参考音频的语种",choices=["中文","英文","日文"],value="中文")
|
prompt_language = gr.Dropdown(
|
||||||
gr.Markdown(
|
label="参考音频的语种", choices=["中文", "英文", "日文"], value="中文"
|
||||||
value=
|
)
|
||||||
"*请填写需要合成的目标文本"
|
gr.Markdown(value="*请填写需要合成的目标文本")
|
||||||
)
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
text=gr.Textbox(label="需要合成的文本",value="")
|
text = gr.Textbox(label="需要合成的文本", value="")
|
||||||
text_language = gr.Dropdown(label="需要合成的语种", choices=["中文", "英文", "日文"],value="中文")
|
text_language = gr.Dropdown(
|
||||||
inference_button=gr.Button("合成语音", variant="primary")
|
label="需要合成的语种", choices=["中文", "英文", "日文"], value="中文"
|
||||||
|
)
|
||||||
|
inference_button = gr.Button("合成语音", variant="primary")
|
||||||
output = gr.Audio(label="输出的语音")
|
output = gr.Audio(label="输出的语音")
|
||||||
inference_button.click(get_tts_wav, [inp_ref, prompt_text,prompt_language, text,text_language], [output])
|
inference_button.click(
|
||||||
|
get_tts_wav,
|
||||||
gr.Markdown(
|
[inp_ref, prompt_text, prompt_language, text, text_language],
|
||||||
value=
|
[output],
|
||||||
"文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
text_inp=gr.Textbox(label="需要合成的切分前文本",value="")
|
text_inp = gr.Textbox(label="需要合成的切分前文本", value="")
|
||||||
button1 = gr.Button("凑五句一切", variant="primary")
|
button1 = gr.Button("凑五句一切", variant="primary")
|
||||||
button2 = gr.Button("凑50字一切", variant="primary")
|
button2 = gr.Button("凑50字一切", variant="primary")
|
||||||
button3 = gr.Button("按中文句号。切", variant="primary")
|
button3 = gr.Button("按中文句号。切", variant="primary")
|
||||||
text_opt = gr.Textbox(label="切分后文本", value="")
|
text_opt = gr.Textbox(label="切分后文本", value="")
|
||||||
button1.click(cut1,[text_inp],[text_opt])
|
button1.click(cut1, [text_inp], [text_opt])
|
||||||
button2.click(cut2,[text_inp],[text_opt])
|
button2.click(cut2, [text_inp], [text_opt])
|
||||||
button3.click(cut3,[text_inp],[text_opt])
|
button3.click(cut3, [text_inp], [text_opt])
|
||||||
gr.Markdown(
|
gr.Markdown(value="后续将支持混合语种编码文本输入。")
|
||||||
value=
|
|
||||||
"后续将支持混合语种编码文本输入。"
|
|
||||||
)
|
|
||||||
|
|
||||||
app.queue(concurrency_count=511, max_size=1022).launch(
|
app.queue(concurrency_count=511, max_size=1022).launch(
|
||||||
server_name="0.0.0.0",
|
server_name="0.0.0.0",
|
||||||
inbrowser=True,
|
inbrowser=True,
|
||||||
server_port=infer_ttswebui,
|
server_port=infer_ttswebui,
|
||||||
quiet=True,
|
quiet=True,
|
||||||
)
|
)
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import traceback
|
import traceback
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tools.i18n.i18n import I18nAuto
|
from tools.i18n.i18n import I18nAuto
|
||||||
|
|
||||||
i18n = I18nAuto()
|
i18n = I18nAuto()
|
||||||
|
|
||||||
|
|
||||||
def savee(ckpt, name, epoch, steps, hps):
|
def savee(ckpt, name, epoch, steps, hps):
|
||||||
try:
|
try:
|
||||||
opt = OrderedDict()
|
opt = OrderedDict()
|
||||||
@ -15,8 +16,8 @@ def savee(ckpt, name, epoch, steps, hps):
|
|||||||
continue
|
continue
|
||||||
opt["weight"][key] = ckpt[key].half()
|
opt["weight"][key] = ckpt[key].half()
|
||||||
opt["config"] = hps
|
opt["config"] = hps
|
||||||
opt["info"] = "%sepoch_%siteration" % (epoch,steps)
|
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
|
||||||
torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir,name))
|
torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
|
||||||
return "Success."
|
return "Success."
|
||||||
except:
|
except:
|
||||||
return traceback.format_exc()
|
return traceback.format_exc()
|
||||||
|
@ -2,56 +2,84 @@
|
|||||||
import os
|
import os
|
||||||
import pdb
|
import pdb
|
||||||
|
|
||||||
if("_CUDA_VISIBLE_DEVICES"in os.environ):
|
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"]
|
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch,platform
|
import torch, platform
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||||
from pytorch_lightning.loggers import TensorBoardLogger#WandbLogger
|
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
|
||||||
from pytorch_lightning.strategies import DDPStrategy
|
from pytorch_lightning.strategies import DDPStrategy
|
||||||
from AR.data.data_module import Text2SemanticDataModule
|
from AR.data.data_module import Text2SemanticDataModule
|
||||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||||
from AR.utils.io import load_yaml_config
|
from AR.utils.io import load_yaml_config
|
||||||
logging.getLogger('numba').setLevel(logging.WARNING)
|
|
||||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||||
torch.set_float32_matmul_precision('high')
|
logging.getLogger("matplotlib").setLevel(logging.WARNING)
|
||||||
|
torch.set_float32_matmul_precision("high")
|
||||||
from AR.utils import get_newest_ckpt
|
from AR.utils import get_newest_ckpt
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
class my_model_ckpt(ModelCheckpoint):
|
class my_model_ckpt(ModelCheckpoint):
|
||||||
def __init__(self,config,if_save_latest,if_save_every_weights,half_weights_save_dir,exp_name,**kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
if_save_latest,
|
||||||
|
if_save_every_weights,
|
||||||
|
half_weights_save_dir,
|
||||||
|
exp_name,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.if_save_latest=if_save_latest
|
self.if_save_latest = if_save_latest
|
||||||
self.if_save_every_weights=if_save_every_weights
|
self.if_save_every_weights = if_save_every_weights
|
||||||
self.half_weights_save_dir=half_weights_save_dir
|
self.half_weights_save_dir = half_weights_save_dir
|
||||||
self.exp_name=exp_name
|
self.exp_name = exp_name
|
||||||
self.config=config
|
self.config = config
|
||||||
|
|
||||||
def on_train_epoch_end(self, trainer, pl_module):
|
def on_train_epoch_end(self, trainer, pl_module):
|
||||||
if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
|
if not self._should_skip_saving_checkpoint(
|
||||||
|
trainer
|
||||||
|
) and self._should_save_on_train_epoch_end(trainer):
|
||||||
monitor_candidates = self._monitor_candidates(trainer)
|
monitor_candidates = self._monitor_candidates(trainer)
|
||||||
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
|
if (
|
||||||
if(self.if_save_latest==True):####如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt
|
self._every_n_epochs >= 1
|
||||||
to_clean=list(os.listdir(self.dirpath))
|
and (trainer.current_epoch + 1) % self._every_n_epochs == 0
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
self.if_save_latest == True
|
||||||
|
): ####如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt
|
||||||
|
to_clean = list(os.listdir(self.dirpath))
|
||||||
self._save_topk_checkpoint(trainer, monitor_candidates)
|
self._save_topk_checkpoint(trainer, monitor_candidates)
|
||||||
if (self.if_save_latest == True):
|
if self.if_save_latest == True:
|
||||||
for name in to_clean:
|
for name in to_clean:
|
||||||
try:
|
try:
|
||||||
os.remove("%s/%s"%(self.dirpath,name))
|
os.remove("%s/%s" % (self.dirpath, name))
|
||||||
except:pass
|
except:
|
||||||
if(self.if_save_every_weights==True):
|
pass
|
||||||
to_save_od=OrderedDict()
|
if self.if_save_every_weights == True:
|
||||||
to_save_od["weight"]=OrderedDict()
|
to_save_od = OrderedDict()
|
||||||
dictt=trainer.strategy._lightning_module.state_dict()
|
to_save_od["weight"] = OrderedDict()
|
||||||
for key in dictt:to_save_od["weight"][key]=dictt[key].half()
|
dictt = trainer.strategy._lightning_module.state_dict()
|
||||||
to_save_od["config"]=self.config
|
for key in dictt:
|
||||||
to_save_od["info"]="GPT-e%s"%(trainer.current_epoch+1)
|
to_save_od["weight"][key] = dictt[key].half()
|
||||||
torch.save(to_save_od,"%s/%s-e%s.ckpt"%(self.half_weights_save_dir,self.exp_name,trainer.current_epoch+1))
|
to_save_od["config"] = self.config
|
||||||
|
to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
|
||||||
|
torch.save(
|
||||||
|
to_save_od,
|
||||||
|
"%s/%s-e%s.ckpt"
|
||||||
|
% (
|
||||||
|
self.half_weights_save_dir,
|
||||||
|
self.exp_name,
|
||||||
|
trainer.current_epoch + 1,
|
||||||
|
),
|
||||||
|
)
|
||||||
self._save_last_checkpoint(trainer, monitor_candidates)
|
self._save_last_checkpoint(trainer, monitor_candidates)
|
||||||
|
|
||||||
|
|
||||||
@ -61,41 +89,45 @@ def main(args):
|
|||||||
output_dir = Path(config["output_dir"])
|
output_dir = Path(config["output_dir"])
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
ckpt_dir = output_dir / 'ckpt'
|
ckpt_dir = output_dir / "ckpt"
|
||||||
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
seed_everything(config["train"]["seed"], workers=True)
|
seed_everything(config["train"]["seed"], workers=True)
|
||||||
ckpt_callback: ModelCheckpoint = my_model_ckpt(
|
ckpt_callback: ModelCheckpoint = my_model_ckpt(
|
||||||
config=config,
|
config=config,
|
||||||
if_save_latest=config["train"]["if_save_latest"], if_save_every_weights=config["train"]["if_save_every_weights"], half_weights_save_dir=config["train"]["half_weights_save_dir"], exp_name=config["train"]["exp_name"],
|
if_save_latest=config["train"]["if_save_latest"],
|
||||||
|
if_save_every_weights=config["train"]["if_save_every_weights"],
|
||||||
|
half_weights_save_dir=config["train"]["half_weights_save_dir"],
|
||||||
|
exp_name=config["train"]["exp_name"],
|
||||||
save_top_k=-1,
|
save_top_k=-1,
|
||||||
monitor='top_3_acc',
|
monitor="top_3_acc",
|
||||||
mode='max',
|
mode="max",
|
||||||
save_on_train_epoch_end=True,
|
save_on_train_epoch_end=True,
|
||||||
every_n_epochs=config["train"]["save_every_n_epoch"],
|
every_n_epochs=config["train"]["save_every_n_epoch"],
|
||||||
dirpath=ckpt_dir,
|
dirpath=ckpt_dir,
|
||||||
)
|
)
|
||||||
logger = TensorBoardLogger(
|
logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
|
||||||
name=output_dir.stem,
|
|
||||||
save_dir=output_dir
|
|
||||||
)
|
|
||||||
trainer: Trainer = Trainer(
|
trainer: Trainer = Trainer(
|
||||||
max_epochs=config["train"]["epochs"],
|
max_epochs=config["train"]["epochs"],
|
||||||
accelerator='gpu',
|
accelerator="gpu",
|
||||||
# val_check_interval=9999999999999999999999,###不要验证
|
# val_check_interval=9999999999999999999999,###不要验证
|
||||||
# check_val_every_n_epoch=None,
|
# check_val_every_n_epoch=None,
|
||||||
limit_val_batches=0,
|
limit_val_batches=0,
|
||||||
devices=-1,
|
devices=-1,
|
||||||
benchmark=False,
|
benchmark=False,
|
||||||
fast_dev_run=False,
|
fast_dev_run=False,
|
||||||
strategy=DDPStrategy(process_group_backend="nccl"if platform.system()!="Windows"else "gloo"),
|
strategy=DDPStrategy(
|
||||||
|
process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
|
||||||
|
),
|
||||||
precision=config["train"]["precision"],
|
precision=config["train"]["precision"],
|
||||||
logger=logger,num_sanity_val_steps=0,
|
logger=logger,
|
||||||
callbacks=[ckpt_callback])
|
num_sanity_val_steps=0,
|
||||||
|
callbacks=[ckpt_callback],
|
||||||
|
)
|
||||||
|
|
||||||
model: Text2SemanticLightningModule = Text2SemanticLightningModule(
|
model: Text2SemanticLightningModule = Text2SemanticLightningModule(
|
||||||
config, output_dir)
|
config, output_dir
|
||||||
|
)
|
||||||
|
|
||||||
data_module: Text2SemanticDataModule = Text2SemanticDataModule(
|
data_module: Text2SemanticDataModule = Text2SemanticDataModule(
|
||||||
config,
|
config,
|
||||||
@ -116,14 +148,15 @@ def main(args):
|
|||||||
|
|
||||||
|
|
||||||
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
|
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-c',
|
"-c",
|
||||||
'--config_file',
|
"--config_file",
|
||||||
type=str,
|
type=str,
|
||||||
default='configs/s1longer.yaml',
|
default="configs/s1longer.yaml",
|
||||||
help='path of config file')
|
help="path of config file",
|
||||||
|
)
|
||||||
# args for dataset
|
# args for dataset
|
||||||
# parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv')
|
# parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv')
|
||||||
# parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt')
|
# parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt')
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import utils,os
|
import utils, os
|
||||||
|
|
||||||
hps = utils.get_hparams(stage=2)
|
hps = utils.get_hparams(stage=2)
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
|
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
|
||||||
import torch
|
import torch
|
||||||
@ -6,11 +7,12 @@ from torch.nn import functional as F
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.distributed as dist,traceback
|
import torch.distributed as dist, traceback
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.cuda.amp import autocast, GradScaler
|
from torch.cuda.amp import autocast, GradScaler
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import logging,traceback
|
import logging, traceback
|
||||||
|
|
||||||
logging.getLogger("matplotlib").setLevel(logging.INFO)
|
logging.getLogger("matplotlib").setLevel(logging.INFO)
|
||||||
logging.getLogger("h5py").setLevel(logging.INFO)
|
logging.getLogger("h5py").setLevel(logging.INFO)
|
||||||
logging.getLogger("numba").setLevel(logging.INFO)
|
logging.getLogger("numba").setLevel(logging.INFO)
|
||||||
@ -20,37 +22,42 @@ from module import commons
|
|||||||
from module.data_utils import (
|
from module.data_utils import (
|
||||||
TextAudioSpeakerLoader,
|
TextAudioSpeakerLoader,
|
||||||
TextAudioSpeakerCollate,
|
TextAudioSpeakerCollate,
|
||||||
DistributedBucketSampler
|
DistributedBucketSampler,
|
||||||
)
|
)
|
||||||
from module.models import (
|
from module.models import (
|
||||||
SynthesizerTrn,
|
SynthesizerTrn,
|
||||||
MultiPeriodDiscriminator,
|
MultiPeriodDiscriminator,
|
||||||
)
|
)
|
||||||
from module.losses import (
|
from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
|
||||||
generator_loss,
|
|
||||||
discriminator_loss,
|
|
||||||
feature_loss,
|
|
||||||
kl_loss
|
|
||||||
)
|
|
||||||
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
||||||
from process_ckpt import savee
|
from process_ckpt import savee
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = False
|
torch.backends.cudnn.benchmark = False
|
||||||
torch.backends.cudnn.deterministic = False
|
torch.backends.cudnn.deterministic = False
|
||||||
###反正A100fp32更快,那试试tf32吧
|
###反正A100fp32更快,那试试tf32吧
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
torch.set_float32_matmul_precision('medium')#最低精度但最快(也就快一丁点),对于结果造成不了影响
|
torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响
|
||||||
# from config import pretrained_s2G,pretrained_s2D
|
# from config import pretrained_s2G,pretrained_s2D
|
||||||
global_step = 0
|
global_step = 0
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Assume Single Node Multi GPUs Training Only"""
|
"""Assume Single Node Multi GPUs Training Only"""
|
||||||
assert torch.cuda.is_available(), "CPU training is not allowed."
|
assert torch.cuda.is_available(), "CPU training is not allowed."
|
||||||
|
|
||||||
n_gpus = torch.cuda.device_count()
|
n_gpus = torch.cuda.device_count()
|
||||||
os.environ['MASTER_ADDR'] = 'localhost'
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
os.environ['MASTER_PORT'] = str(randint(20000, 55555))
|
os.environ["MASTER_PORT"] = str(randint(20000, 55555))
|
||||||
|
|
||||||
mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))
|
mp.spawn(
|
||||||
|
run,
|
||||||
|
nprocs=n_gpus,
|
||||||
|
args=(
|
||||||
|
n_gpus,
|
||||||
|
hps,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def run(rank, n_gpus, hps):
|
def run(rank, n_gpus, hps):
|
||||||
@ -62,21 +69,54 @@ def run(rank, n_gpus, hps):
|
|||||||
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
|
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
|
||||||
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
|
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
|
||||||
|
|
||||||
dist.init_process_group(backend='gloo' if os.name == 'nt' else 'nccl', init_method='env://', world_size=n_gpus,rank=rank)
|
dist.init_process_group(
|
||||||
|
backend="gloo" if os.name == "nt" else "nccl",
|
||||||
|
init_method="env://",
|
||||||
|
world_size=n_gpus,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
torch.manual_seed(hps.train.seed)
|
torch.manual_seed(hps.train.seed)
|
||||||
torch.cuda.set_device(rank)
|
torch.cuda.set_device(rank)
|
||||||
|
|
||||||
train_dataset = TextAudioSpeakerLoader(hps.data)########
|
train_dataset = TextAudioSpeakerLoader(hps.data) ########
|
||||||
train_sampler = DistributedBucketSampler(
|
train_sampler = DistributedBucketSampler(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
hps.train.batch_size,
|
hps.train.batch_size,
|
||||||
[32, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900],
|
[
|
||||||
|
32,
|
||||||
|
300,
|
||||||
|
400,
|
||||||
|
500,
|
||||||
|
600,
|
||||||
|
700,
|
||||||
|
800,
|
||||||
|
900,
|
||||||
|
1000,
|
||||||
|
1100,
|
||||||
|
1200,
|
||||||
|
1300,
|
||||||
|
1400,
|
||||||
|
1500,
|
||||||
|
1600,
|
||||||
|
1700,
|
||||||
|
1800,
|
||||||
|
1900,
|
||||||
|
],
|
||||||
num_replicas=n_gpus,
|
num_replicas=n_gpus,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
shuffle=True)
|
shuffle=True,
|
||||||
|
)
|
||||||
collate_fn = TextAudioSpeakerCollate()
|
collate_fn = TextAudioSpeakerCollate()
|
||||||
train_loader = DataLoader(train_dataset, num_workers=6, shuffle=False, pin_memory=True,
|
train_loader = DataLoader(
|
||||||
collate_fn=collate_fn, batch_sampler=train_sampler,persistent_workers=True,prefetch_factor=16)
|
train_dataset,
|
||||||
|
num_workers=6,
|
||||||
|
shuffle=False,
|
||||||
|
pin_memory=True,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
batch_sampler=train_sampler,
|
||||||
|
persistent_workers=True,
|
||||||
|
prefetch_factor=16,
|
||||||
|
)
|
||||||
# if rank == 0:
|
# if rank == 0:
|
||||||
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
|
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
|
||||||
# eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
|
# eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
|
||||||
@ -87,17 +127,21 @@ def run(rank, n_gpus, hps):
|
|||||||
hps.data.filter_length // 2 + 1,
|
hps.data.filter_length // 2 + 1,
|
||||||
hps.train.segment_size // hps.data.hop_length,
|
hps.train.segment_size // hps.data.hop_length,
|
||||||
n_speakers=hps.data.n_speakers,
|
n_speakers=hps.data.n_speakers,
|
||||||
**hps.model).cuda(rank)
|
**hps.model,
|
||||||
|
).cuda(rank)
|
||||||
|
|
||||||
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
|
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
|
||||||
for name, param in net_g.named_parameters():
|
for name, param in net_g.named_parameters():
|
||||||
if not param.requires_grad:
|
if not param.requires_grad:
|
||||||
print(name,"not requires_grad")
|
print(name, "not requires_grad")
|
||||||
|
|
||||||
te_p = list(map(id, net_g.enc_p.text_embedding.parameters()))
|
te_p = list(map(id, net_g.enc_p.text_embedding.parameters()))
|
||||||
et_p = list(map(id, net_g.enc_p.encoder_text.parameters()))
|
et_p = list(map(id, net_g.enc_p.encoder_text.parameters()))
|
||||||
mrte_p = list(map(id, net_g.enc_p.mrte.parameters()))
|
mrte_p = list(map(id, net_g.enc_p.mrte.parameters()))
|
||||||
base_params = filter(lambda p: id(p) not in te_p+et_p+mrte_p and p.requires_grad, net_g.parameters())
|
base_params = filter(
|
||||||
|
lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad,
|
||||||
|
net_g.parameters(),
|
||||||
|
)
|
||||||
|
|
||||||
# te_p=net_g.enc_p.text_embedding.parameters()
|
# te_p=net_g.enc_p.text_embedding.parameters()
|
||||||
# et_p=net_g.enc_p.encoder_text.parameters()
|
# et_p=net_g.enc_p.encoder_text.parameters()
|
||||||
@ -106,31 +150,46 @@ def run(rank, n_gpus, hps):
|
|||||||
optim_g = torch.optim.AdamW(
|
optim_g = torch.optim.AdamW(
|
||||||
# filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致
|
# filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致
|
||||||
[
|
[
|
||||||
{"params":base_params,"lr":hps.train.learning_rate},
|
{"params": base_params, "lr": hps.train.learning_rate},
|
||||||
{"params":net_g.enc_p.text_embedding.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate},
|
{
|
||||||
{"params":net_g.enc_p.encoder_text.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate},
|
"params": net_g.enc_p.text_embedding.parameters(),
|
||||||
{"params":net_g.enc_p.mrte.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate},
|
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": net_g.enc_p.encoder_text.parameters(),
|
||||||
|
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": net_g.enc_p.mrte.parameters(),
|
||||||
|
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
|
||||||
|
},
|
||||||
],
|
],
|
||||||
hps.train.learning_rate,
|
hps.train.learning_rate,
|
||||||
betas=hps.train.betas,
|
betas=hps.train.betas,
|
||||||
eps=hps.train.eps)
|
eps=hps.train.eps,
|
||||||
|
)
|
||||||
optim_d = torch.optim.AdamW(
|
optim_d = torch.optim.AdamW(
|
||||||
net_d.parameters(),
|
net_d.parameters(),
|
||||||
hps.train.learning_rate,
|
hps.train.learning_rate,
|
||||||
betas=hps.train.betas,
|
betas=hps.train.betas,
|
||||||
eps=hps.train.eps)
|
eps=hps.train.eps,
|
||||||
net_g = DDP(net_g, device_ids=[rank],find_unused_parameters=True)
|
)
|
||||||
net_d = DDP(net_d, device_ids=[rank],find_unused_parameters=True)
|
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
|
||||||
|
net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
|
||||||
|
|
||||||
try: # 如果能加载自动resume
|
try: # 如果能加载自动resume
|
||||||
_, _, _, epoch_str = utils.load_checkpoint(
|
_, _, _, epoch_str = utils.load_checkpoint(
|
||||||
utils.latest_checkpoint_path("%s/logs_s2"%hps.data.exp_dir, "D_*.pth"), net_d, optim_d
|
utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "D_*.pth"),
|
||||||
|
net_d,
|
||||||
|
optim_d,
|
||||||
) # D多半加载没事
|
) # D多半加载没事
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.info("loaded D")
|
logger.info("loaded D")
|
||||||
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
|
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
|
||||||
_, _, _, epoch_str = utils.load_checkpoint(
|
_, _, _, epoch_str = utils.load_checkpoint(
|
||||||
utils.latest_checkpoint_path("%s/logs_s2"%hps.data.exp_dir, "G_*.pth"), net_g, optim_g
|
utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "G_*.pth"),
|
||||||
|
net_g,
|
||||||
|
optim_g,
|
||||||
)
|
)
|
||||||
global_step = (epoch_str - 1) * len(train_loader)
|
global_step = (epoch_str - 1) * len(train_loader)
|
||||||
# epoch_str = 1
|
# epoch_str = 1
|
||||||
@ -144,7 +203,8 @@ def run(rank, n_gpus, hps):
|
|||||||
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
|
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
|
||||||
print(
|
print(
|
||||||
net_g.module.load_state_dict(
|
net_g.module.load_state_dict(
|
||||||
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],strict=False
|
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
|
||||||
|
strict=False,
|
||||||
)
|
)
|
||||||
) ##测试不加载优化器
|
) ##测试不加载优化器
|
||||||
if hps.train.pretrained_s2D != "":
|
if hps.train.pretrained_s2D != "":
|
||||||
@ -159,8 +219,12 @@ def run(rank, n_gpus, hps):
|
|||||||
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
||||||
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
||||||
|
|
||||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
|
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
|
||||||
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=-1)
|
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
|
||||||
|
)
|
||||||
|
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
|
||||||
|
optim_d, gamma=hps.train.lr_decay, last_epoch=-1
|
||||||
|
)
|
||||||
for _ in range(epoch_str):
|
for _ in range(epoch_str):
|
||||||
scheduler_g.step()
|
scheduler_g.step()
|
||||||
scheduler_d.step()
|
scheduler_d.step()
|
||||||
@ -169,17 +233,39 @@ def run(rank, n_gpus, hps):
|
|||||||
|
|
||||||
for epoch in range(epoch_str, hps.train.epochs + 1):
|
for epoch in range(epoch_str, hps.train.epochs + 1):
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
|
train_and_evaluate(
|
||||||
# [train_loader, eval_loader], logger, [writer, writer_eval])
|
rank,
|
||||||
[train_loader, None], logger, [writer, writer_eval])
|
epoch,
|
||||||
|
hps,
|
||||||
|
[net_g, net_d],
|
||||||
|
[optim_g, optim_d],
|
||||||
|
[scheduler_g, scheduler_d],
|
||||||
|
scaler,
|
||||||
|
# [train_loader, eval_loader], logger, [writer, writer_eval])
|
||||||
|
[train_loader, None],
|
||||||
|
logger,
|
||||||
|
[writer, writer_eval],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
|
train_and_evaluate(
|
||||||
[train_loader, None], None, None)
|
rank,
|
||||||
|
epoch,
|
||||||
|
hps,
|
||||||
|
[net_g, net_d],
|
||||||
|
[optim_g, optim_d],
|
||||||
|
[scheduler_g, scheduler_d],
|
||||||
|
scaler,
|
||||||
|
[train_loader, None],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
scheduler_g.step()
|
scheduler_g.step()
|
||||||
scheduler_d.step()
|
scheduler_d.step()
|
||||||
|
|
||||||
|
|
||||||
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
|
def train_and_evaluate(
|
||||||
|
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
|
||||||
|
):
|
||||||
net_g, net_d = nets
|
net_g, net_d = nets
|
||||||
optim_g, optim_d = optims
|
optim_g, optim_d = optims
|
||||||
# scheduler_g, scheduler_d = schedulers
|
# scheduler_g, scheduler_d = schedulers
|
||||||
@ -192,17 +278,39 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
|||||||
|
|
||||||
net_g.train()
|
net_g.train()
|
||||||
net_d.train()
|
net_d.train()
|
||||||
for batch_idx, (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths) in tqdm(enumerate(train_loader)):
|
for batch_idx, (
|
||||||
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
|
ssl,
|
||||||
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True)
|
ssl_lengths,
|
||||||
|
spec,
|
||||||
|
spec_lengths,
|
||||||
|
y,
|
||||||
|
y_lengths,
|
||||||
|
text,
|
||||||
|
text_lengths,
|
||||||
|
) in tqdm(enumerate(train_loader)):
|
||||||
|
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
|
||||||
|
rank, non_blocking=True
|
||||||
|
)
|
||||||
|
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
|
||||||
|
rank, non_blocking=True
|
||||||
|
)
|
||||||
ssl = ssl.cuda(rank, non_blocking=True)
|
ssl = ssl.cuda(rank, non_blocking=True)
|
||||||
ssl.requires_grad=False
|
ssl.requires_grad = False
|
||||||
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
|
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
|
||||||
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(rank, non_blocking=True)
|
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
|
||||||
|
rank, non_blocking=True
|
||||||
|
)
|
||||||
|
|
||||||
with autocast(enabled=hps.train.fp16_run):
|
with autocast(enabled=hps.train.fp16_run):
|
||||||
y_hat, kl_ssl, ids_slice, x_mask, z_mask, \
|
(
|
||||||
(z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl = net_g(ssl, spec, spec_lengths, text, text_lengths)
|
y_hat,
|
||||||
|
kl_ssl,
|
||||||
|
ids_slice,
|
||||||
|
x_mask,
|
||||||
|
z_mask,
|
||||||
|
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||||
|
stats_ssl,
|
||||||
|
) = net_g(ssl, spec, spec_lengths, text, text_lengths)
|
||||||
|
|
||||||
mel = spec_to_mel_torch(
|
mel = spec_to_mel_torch(
|
||||||
spec,
|
spec,
|
||||||
@ -210,8 +318,11 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
|||||||
hps.data.n_mel_channels,
|
hps.data.n_mel_channels,
|
||||||
hps.data.sampling_rate,
|
hps.data.sampling_rate,
|
||||||
hps.data.mel_fmin,
|
hps.data.mel_fmin,
|
||||||
hps.data.mel_fmax)
|
hps.data.mel_fmax,
|
||||||
y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
|
)
|
||||||
|
y_mel = commons.slice_segments(
|
||||||
|
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
||||||
|
)
|
||||||
y_hat_mel = mel_spectrogram_torch(
|
y_hat_mel = mel_spectrogram_torch(
|
||||||
y_hat.squeeze(1),
|
y_hat.squeeze(1),
|
||||||
hps.data.filter_length,
|
hps.data.filter_length,
|
||||||
@ -220,15 +331,19 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
|||||||
hps.data.hop_length,
|
hps.data.hop_length,
|
||||||
hps.data.win_length,
|
hps.data.win_length,
|
||||||
hps.data.mel_fmin,
|
hps.data.mel_fmin,
|
||||||
hps.data.mel_fmax
|
hps.data.mel_fmax,
|
||||||
)
|
)
|
||||||
|
|
||||||
y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
|
y = commons.slice_segments(
|
||||||
|
y, ids_slice * hps.data.hop_length, hps.train.segment_size
|
||||||
|
) # slice
|
||||||
|
|
||||||
# Discriminator
|
# Discriminator
|
||||||
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
|
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
|
||||||
with autocast(enabled=False):
|
with autocast(enabled=False):
|
||||||
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
|
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
||||||
|
y_d_hat_r, y_d_hat_g
|
||||||
|
)
|
||||||
loss_disc_all = loss_disc
|
loss_disc_all = loss_disc
|
||||||
optim_d.zero_grad()
|
optim_d.zero_grad()
|
||||||
scaler.scale(loss_disc_all).backward()
|
scaler.scale(loss_disc_all).backward()
|
||||||
@ -256,32 +371,54 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
|||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
if global_step % hps.train.log_interval == 0:
|
if global_step % hps.train.log_interval == 0:
|
||||||
lr = optim_g.param_groups[0]['lr']
|
lr = optim_g.param_groups[0]["lr"]
|
||||||
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
|
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
|
||||||
logger.info('Train Epoch: {} [{:.0f}%]'.format(
|
logger.info(
|
||||||
epoch,
|
"Train Epoch: {} [{:.0f}%]".format(
|
||||||
100. * batch_idx / len(train_loader)))
|
epoch, 100.0 * batch_idx / len(train_loader)
|
||||||
|
)
|
||||||
|
)
|
||||||
logger.info([x.item() for x in losses] + [global_step, lr])
|
logger.info([x.item() for x in losses] + [global_step, lr])
|
||||||
|
|
||||||
scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr,
|
scalar_dict = {
|
||||||
"grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
|
"loss/g/total": loss_gen_all,
|
||||||
|
"loss/d/total": loss_disc_all,
|
||||||
|
"learning_rate": lr,
|
||||||
|
"grad_norm_d": grad_norm_d,
|
||||||
|
"grad_norm_g": grad_norm_g,
|
||||||
|
}
|
||||||
scalar_dict.update(
|
scalar_dict.update(
|
||||||
{"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl_ssl": kl_ssl, "loss/g/kl": loss_kl})
|
{
|
||||||
|
"loss/g/fm": loss_fm,
|
||||||
|
"loss/g/mel": loss_mel,
|
||||||
|
"loss/g/kl_ssl": kl_ssl,
|
||||||
|
"loss/g/kl": loss_kl,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
|
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
|
||||||
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
|
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
|
||||||
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
|
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
|
||||||
image_dict = {
|
image_dict = {
|
||||||
"slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
|
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
||||||
"slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
|
y_mel[0].data.cpu().numpy()
|
||||||
"all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
|
),
|
||||||
"all/stats_ssl": utils.plot_spectrogram_to_numpy(stats_ssl[0].data.cpu().numpy()),
|
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
|
||||||
|
y_hat_mel[0].data.cpu().numpy()
|
||||||
|
),
|
||||||
|
"all/mel": utils.plot_spectrogram_to_numpy(
|
||||||
|
mel[0].data.cpu().numpy()
|
||||||
|
),
|
||||||
|
"all/stats_ssl": utils.plot_spectrogram_to_numpy(
|
||||||
|
stats_ssl[0].data.cpu().numpy()
|
||||||
|
),
|
||||||
}
|
}
|
||||||
utils.summarize(
|
utils.summarize(
|
||||||
writer=writer,
|
writer=writer,
|
||||||
global_step=global_step,
|
global_step=global_step,
|
||||||
images=image_dict,
|
images=image_dict,
|
||||||
scalars=scalar_dict)
|
scalars=scalar_dict,
|
||||||
|
)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
|
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
|
||||||
if hps.train.if_save_latest == 0:
|
if hps.train.if_save_latest == 0:
|
||||||
@ -290,14 +427,18 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
|||||||
optim_g,
|
optim_g,
|
||||||
hps.train.learning_rate,
|
hps.train.learning_rate,
|
||||||
epoch,
|
epoch,
|
||||||
os.path.join("%s/logs_s2"%hps.data.exp_dir, "G_{}.pth".format(global_step)),
|
os.path.join(
|
||||||
|
"%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(global_step)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
utils.save_checkpoint(
|
utils.save_checkpoint(
|
||||||
net_d,
|
net_d,
|
||||||
optim_d,
|
optim_d,
|
||||||
hps.train.learning_rate,
|
hps.train.learning_rate,
|
||||||
epoch,
|
epoch,
|
||||||
os.path.join("%s/logs_s2"%hps.data.exp_dir, "D_{}.pth".format(global_step)),
|
os.path.join(
|
||||||
|
"%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(global_step)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
utils.save_checkpoint(
|
utils.save_checkpoint(
|
||||||
@ -305,14 +446,18 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
|||||||
optim_g,
|
optim_g,
|
||||||
hps.train.learning_rate,
|
hps.train.learning_rate,
|
||||||
epoch,
|
epoch,
|
||||||
os.path.join("%s/logs_s2"%hps.data.exp_dir, "G_{}.pth".format(233333333333)),
|
os.path.join(
|
||||||
|
"%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(233333333333)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
utils.save_checkpoint(
|
utils.save_checkpoint(
|
||||||
net_d,
|
net_d,
|
||||||
optim_d,
|
optim_d,
|
||||||
hps.train.learning_rate,
|
hps.train.learning_rate,
|
||||||
epoch,
|
epoch,
|
||||||
os.path.join("%s/logs_s2"%hps.data.exp_dir, "D_{}.pth".format(233333333333)),
|
os.path.join(
|
||||||
|
"%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(233333333333)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if rank == 0 and hps.train.if_save_every_weights == True:
|
if rank == 0 and hps.train.if_save_every_weights == True:
|
||||||
if hasattr(net_g, "module"):
|
if hasattr(net_g, "module"):
|
||||||
@ -334,11 +479,8 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.info('====> Epoch: {}'.format(epoch))
|
logger.info("====> Epoch: {}".format(epoch))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(hps, generator, eval_loader, writer_eval):
|
def evaluate(hps, generator, eval_loader, writer_eval):
|
||||||
@ -347,15 +489,25 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
|||||||
audio_dict = {}
|
audio_dict = {}
|
||||||
print("Evaluating ...")
|
print("Evaluating ...")
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch_idx, (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths) in enumerate(eval_loader):
|
for batch_idx, (
|
||||||
|
ssl,
|
||||||
|
ssl_lengths,
|
||||||
|
spec,
|
||||||
|
spec_lengths,
|
||||||
|
y,
|
||||||
|
y_lengths,
|
||||||
|
text,
|
||||||
|
text_lengths,
|
||||||
|
) in enumerate(eval_loader):
|
||||||
print(111)
|
print(111)
|
||||||
spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
|
spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
|
||||||
y, y_lengths = y.cuda(), y_lengths.cuda()
|
y, y_lengths = y.cuda(), y_lengths.cuda()
|
||||||
ssl = ssl.cuda()
|
ssl = ssl.cuda()
|
||||||
text, text_lengths = text.cuda(), text_lengths.cuda()
|
text, text_lengths = text.cuda(), text_lengths.cuda()
|
||||||
for test in [0, 1]:
|
for test in [0, 1]:
|
||||||
|
y_hat, mask, *_ = generator.module.infer(
|
||||||
y_hat, mask, *_ = generator.module.infer(ssl,spec, spec_lengths,text, text_lengths, test=test)
|
ssl, spec, spec_lengths, text, text_lengths, test=test
|
||||||
|
)
|
||||||
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
|
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
|
||||||
|
|
||||||
mel = spec_to_mel_torch(
|
mel = spec_to_mel_torch(
|
||||||
@ -364,7 +516,8 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
|||||||
hps.data.n_mel_channels,
|
hps.data.n_mel_channels,
|
||||||
hps.data.sampling_rate,
|
hps.data.sampling_rate,
|
||||||
hps.data.mel_fmin,
|
hps.data.mel_fmin,
|
||||||
hps.data.mel_fmax)
|
hps.data.mel_fmax,
|
||||||
|
)
|
||||||
y_hat_mel = mel_spectrogram_torch(
|
y_hat_mel = mel_spectrogram_torch(
|
||||||
y_hat.squeeze(1).float(),
|
y_hat.squeeze(1).float(),
|
||||||
hps.data.filter_length,
|
hps.data.filter_length,
|
||||||
@ -373,16 +526,26 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
|||||||
hps.data.hop_length,
|
hps.data.hop_length,
|
||||||
hps.data.win_length,
|
hps.data.win_length,
|
||||||
hps.data.mel_fmin,
|
hps.data.mel_fmin,
|
||||||
hps.data.mel_fmax
|
hps.data.mel_fmax,
|
||||||
)
|
)
|
||||||
image_dict.update({
|
image_dict.update(
|
||||||
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
|
{
|
||||||
})
|
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
|
||||||
audio_dict.update({
|
y_hat_mel[0].cpu().numpy()
|
||||||
f"gen/audio_{batch_idx}_{test}": y_hat[0, :, :y_hat_lengths[0]]
|
)
|
||||||
})
|
}
|
||||||
image_dict.update({f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())})
|
)
|
||||||
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, :y_lengths[0]]})
|
audio_dict.update(
|
||||||
|
{f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
|
||||||
|
)
|
||||||
|
image_dict.update(
|
||||||
|
{
|
||||||
|
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
|
||||||
|
mel[0].cpu().numpy()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
|
||||||
|
|
||||||
# y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None)
|
# y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None)
|
||||||
# audio_dict.update({
|
# audio_dict.update({
|
||||||
@ -394,9 +557,10 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
|||||||
global_step=global_step,
|
global_step=global_step,
|
||||||
images=image_dict,
|
images=image_dict,
|
||||||
audios=audio_dict,
|
audios=audio_dict,
|
||||||
audio_sampling_rate=hps.data.sampling_rate
|
audio_sampling_rate=hps.data.sampling_rate,
|
||||||
)
|
)
|
||||||
generator.train()
|
generator.train()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -12,8 +12,9 @@ import numpy as np
|
|||||||
from scipy.io.wavfile import read
|
from scipy.io.wavfile import read
|
||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
logging.getLogger('numba').setLevel(logging.ERROR)
|
|
||||||
logging.getLogger('matplotlib').setLevel(logging.ERROR)
|
logging.getLogger("numba").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("matplotlib").setLevel(logging.ERROR)
|
||||||
|
|
||||||
MATPLOTLIB_FLAG = False
|
MATPLOTLIB_FLAG = False
|
||||||
|
|
||||||
@ -23,13 +24,17 @@ logger = logging
|
|||||||
|
|
||||||
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
|
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
|
||||||
assert os.path.isfile(checkpoint_path)
|
assert os.path.isfile(checkpoint_path)
|
||||||
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
|
||||||
iteration = checkpoint_dict['iteration']
|
iteration = checkpoint_dict["iteration"]
|
||||||
learning_rate = checkpoint_dict['learning_rate']
|
learning_rate = checkpoint_dict["learning_rate"]
|
||||||
if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
|
if (
|
||||||
optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
optimizer is not None
|
||||||
saved_state_dict = checkpoint_dict['model']
|
and not skip_optimizer
|
||||||
if hasattr(model, 'module'):
|
and checkpoint_dict["optimizer"] is not None
|
||||||
|
):
|
||||||
|
optimizer.load_state_dict(checkpoint_dict["optimizer"])
|
||||||
|
saved_state_dict = checkpoint_dict["model"]
|
||||||
|
if hasattr(model, "module"):
|
||||||
state_dict = model.module.state_dict()
|
state_dict = model.module.state_dict()
|
||||||
else:
|
else:
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
@ -39,41 +44,63 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
|
|||||||
# assert "quantizer" not in k
|
# assert "quantizer" not in k
|
||||||
# print("load", k)
|
# print("load", k)
|
||||||
new_state_dict[k] = saved_state_dict[k]
|
new_state_dict[k] = saved_state_dict[k]
|
||||||
assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
|
assert saved_state_dict[k].shape == v.shape, (
|
||||||
|
saved_state_dict[k].shape,
|
||||||
|
v.shape,
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
print("error, %s is not in the checkpoint" % k)#shape不对也会,比如text_embedding当cleaner修改时
|
print(
|
||||||
|
"error, %s is not in the checkpoint" % k
|
||||||
|
) # shape不对也会,比如text_embedding当cleaner修改时
|
||||||
new_state_dict[k] = v
|
new_state_dict[k] = v
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, "module"):
|
||||||
model.module.load_state_dict(new_state_dict)
|
model.module.load_state_dict(new_state_dict)
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(new_state_dict)
|
model.load_state_dict(new_state_dict)
|
||||||
print("load ")
|
print("load ")
|
||||||
logger.info("Loaded checkpoint '{}' (iteration {})".format(
|
logger.info(
|
||||||
checkpoint_path, iteration))
|
"Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
|
||||||
|
)
|
||||||
return model, optimizer, learning_rate, iteration
|
return model, optimizer, learning_rate, iteration
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
|
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
|
||||||
logger.info("Saving model and optimizer state at iteration {} to {}".format(
|
logger.info(
|
||||||
iteration, checkpoint_path))
|
"Saving model and optimizer state at iteration {} to {}".format(
|
||||||
if hasattr(model, 'module'):
|
iteration, checkpoint_path
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if hasattr(model, "module"):
|
||||||
state_dict = model.module.state_dict()
|
state_dict = model.module.state_dict()
|
||||||
else:
|
else:
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
torch.save({'model': state_dict,
|
torch.save(
|
||||||
'iteration': iteration,
|
{
|
||||||
'optimizer': optimizer.state_dict(),
|
"model": state_dict,
|
||||||
'learning_rate': learning_rate}, checkpoint_path)
|
"iteration": iteration,
|
||||||
|
"optimizer": optimizer.state_dict(),
|
||||||
|
"learning_rate": learning_rate,
|
||||||
|
},
|
||||||
|
checkpoint_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
|
def summarize(
|
||||||
|
writer,
|
||||||
|
global_step,
|
||||||
|
scalars={},
|
||||||
|
histograms={},
|
||||||
|
images={},
|
||||||
|
audios={},
|
||||||
|
audio_sampling_rate=22050,
|
||||||
|
):
|
||||||
for k, v in scalars.items():
|
for k, v in scalars.items():
|
||||||
writer.add_scalar(k, v, global_step)
|
writer.add_scalar(k, v, global_step)
|
||||||
for k, v in histograms.items():
|
for k, v in histograms.items():
|
||||||
writer.add_histogram(k, v, global_step)
|
writer.add_histogram(k, v, global_step)
|
||||||
for k, v in images.items():
|
for k, v in images.items():
|
||||||
writer.add_image(k, v, global_step, dataformats='HWC')
|
writer.add_image(k, v, global_step, dataformats="HWC")
|
||||||
for k, v in audios.items():
|
for k, v in audios.items():
|
||||||
writer.add_audio(k, v, global_step, audio_sampling_rate)
|
writer.add_audio(k, v, global_step, audio_sampling_rate)
|
||||||
|
|
||||||
@ -90,23 +117,23 @@ def plot_spectrogram_to_numpy(spectrogram):
|
|||||||
global MATPLOTLIB_FLAG
|
global MATPLOTLIB_FLAG
|
||||||
if not MATPLOTLIB_FLAG:
|
if not MATPLOTLIB_FLAG:
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
MATPLOTLIB_FLAG = True
|
MATPLOTLIB_FLAG = True
|
||||||
mpl_logger = logging.getLogger('matplotlib')
|
mpl_logger = logging.getLogger("matplotlib")
|
||||||
mpl_logger.setLevel(logging.WARNING)
|
mpl_logger.setLevel(logging.WARNING)
|
||||||
import matplotlib.pylab as plt
|
import matplotlib.pylab as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
fig, ax = plt.subplots(figsize=(10, 2))
|
fig, ax = plt.subplots(figsize=(10, 2))
|
||||||
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
|
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
||||||
interpolation='none')
|
|
||||||
plt.colorbar(im, ax=ax)
|
plt.colorbar(im, ax=ax)
|
||||||
plt.xlabel("Frames")
|
plt.xlabel("Frames")
|
||||||
plt.ylabel("Channels")
|
plt.ylabel("Channels")
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
fig.canvas.draw()
|
fig.canvas.draw()
|
||||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
|
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
|
||||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||||
plt.close()
|
plt.close()
|
||||||
return data
|
return data
|
||||||
@ -116,26 +143,28 @@ def plot_alignment_to_numpy(alignment, info=None):
|
|||||||
global MATPLOTLIB_FLAG
|
global MATPLOTLIB_FLAG
|
||||||
if not MATPLOTLIB_FLAG:
|
if not MATPLOTLIB_FLAG:
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
MATPLOTLIB_FLAG = True
|
MATPLOTLIB_FLAG = True
|
||||||
mpl_logger = logging.getLogger('matplotlib')
|
mpl_logger = logging.getLogger("matplotlib")
|
||||||
mpl_logger.setLevel(logging.WARNING)
|
mpl_logger.setLevel(logging.WARNING)
|
||||||
import matplotlib.pylab as plt
|
import matplotlib.pylab as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
fig, ax = plt.subplots(figsize=(6, 4))
|
fig, ax = plt.subplots(figsize=(6, 4))
|
||||||
im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
|
im = ax.imshow(
|
||||||
interpolation='none')
|
alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
|
||||||
|
)
|
||||||
fig.colorbar(im, ax=ax)
|
fig.colorbar(im, ax=ax)
|
||||||
xlabel = 'Decoder timestep'
|
xlabel = "Decoder timestep"
|
||||||
if info is not None:
|
if info is not None:
|
||||||
xlabel += '\n\n' + info
|
xlabel += "\n\n" + info
|
||||||
plt.xlabel(xlabel)
|
plt.xlabel(xlabel)
|
||||||
plt.ylabel('Encoder timestep')
|
plt.ylabel("Encoder timestep")
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
fig.canvas.draw()
|
fig.canvas.draw()
|
||||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
|
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
|
||||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||||
plt.close()
|
plt.close()
|
||||||
return data
|
return data
|
||||||
@ -147,16 +176,31 @@ def load_wav_to_torch(full_path):
|
|||||||
|
|
||||||
|
|
||||||
def load_filepaths_and_text(filename, split="|"):
|
def load_filepaths_and_text(filename, split="|"):
|
||||||
with open(filename, encoding='utf-8') as f:
|
with open(filename, encoding="utf-8") as f:
|
||||||
filepaths_and_text = [line.strip().split(split) for line in f]
|
filepaths_and_text = [line.strip().split(split) for line in f]
|
||||||
return filepaths_and_text
|
return filepaths_and_text
|
||||||
|
|
||||||
|
|
||||||
def get_hparams(init=True, stage=1):
|
def get_hparams(init=True, stage=1):
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-c', '--config', type=str, default="./configs/s2.json",help='JSON file for configuration')
|
parser.add_argument(
|
||||||
parser.add_argument('-p', '--pretrain', type=str, required=False,default=None,help='pretrain dir')
|
"-c",
|
||||||
parser.add_argument('-rs', '--resume_step', type=int, required=False,default=None,help='resume step')
|
"--config",
|
||||||
|
type=str,
|
||||||
|
default="./configs/s2.json",
|
||||||
|
help="JSON file for configuration",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-rs",
|
||||||
|
"--resume_step",
|
||||||
|
type=int,
|
||||||
|
required=False,
|
||||||
|
default=None,
|
||||||
|
help="resume step",
|
||||||
|
)
|
||||||
# parser.add_argument('-e', '--exp_dir', type=str, required=False,default=None,help='experiment directory')
|
# parser.add_argument('-e', '--exp_dir', type=str, required=False,default=None,help='experiment directory')
|
||||||
# parser.add_argument('-g', '--pretrained_s2G', type=str, required=False,default=None,help='pretrained sovits gererator weights')
|
# parser.add_argument('-g', '--pretrained_s2G', type=str, required=False,default=None,help='pretrained sovits gererator weights')
|
||||||
# parser.add_argument('-d', '--pretrained_s2D', type=str, required=False,default=None,help='pretrained sovits discriminator weights')
|
# parser.add_argument('-d', '--pretrained_s2D', type=str, required=False,default=None,help='pretrained sovits discriminator weights')
|
||||||
@ -172,7 +216,7 @@ def get_hparams(init=True, stage=1):
|
|||||||
hparams.pretrain = args.pretrain
|
hparams.pretrain = args.pretrain
|
||||||
hparams.resume_step = args.resume_step
|
hparams.resume_step = args.resume_step
|
||||||
# hparams.data.exp_dir = args.exp_dir
|
# hparams.data.exp_dir = args.exp_dir
|
||||||
if stage ==1:
|
if stage == 1:
|
||||||
model_dir = hparams.s1_ckpt_dir
|
model_dir = hparams.s1_ckpt_dir
|
||||||
else:
|
else:
|
||||||
model_dir = hparams.s2_ckpt_dir
|
model_dir = hparams.s2_ckpt_dir
|
||||||
@ -186,29 +230,38 @@ def get_hparams(init=True, stage=1):
|
|||||||
return hparams
|
return hparams
|
||||||
|
|
||||||
|
|
||||||
|
def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
|
||||||
def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True):
|
|
||||||
"""Freeing up space by deleting saved ckpts
|
"""Freeing up space by deleting saved ckpts
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
path_to_models -- Path to the model directory
|
path_to_models -- Path to the model directory
|
||||||
n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
|
n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
|
||||||
sort_by_time -- True -> chronologically delete ckpts
|
sort_by_time -- True -> chronologically delete ckpts
|
||||||
False -> lexicographically delete ckpts
|
False -> lexicographically delete ckpts
|
||||||
"""
|
"""
|
||||||
import re
|
import re
|
||||||
ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
|
|
||||||
name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1)))
|
ckpts_files = [
|
||||||
time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
|
f
|
||||||
|
for f in os.listdir(path_to_models)
|
||||||
|
if os.path.isfile(os.path.join(path_to_models, f))
|
||||||
|
]
|
||||||
|
name_key = lambda _f: int(re.compile("._(\d+)\.pth").match(_f).group(1))
|
||||||
|
time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))
|
||||||
sort_key = time_key if sort_by_time else name_key
|
sort_key = time_key if sort_by_time else name_key
|
||||||
x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')],
|
x_sorted = lambda _x: sorted(
|
||||||
key=sort_key)
|
[f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
|
||||||
to_del = [os.path.join(path_to_models, fn) for fn in
|
key=sort_key,
|
||||||
(x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
|
)
|
||||||
|
to_del = [
|
||||||
|
os.path.join(path_to_models, fn)
|
||||||
|
for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
|
||||||
|
]
|
||||||
del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
|
del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
|
||||||
del_routine = lambda x: [os.remove(x), del_info(x)]
|
del_routine = lambda x: [os.remove(x), del_info(x)]
|
||||||
rs = [del_routine(fn) for fn in to_del]
|
rs = [del_routine(fn) for fn in to_del]
|
||||||
|
|
||||||
|
|
||||||
def get_hparams_from_dir(model_dir):
|
def get_hparams_from_dir(model_dir):
|
||||||
config_save_path = os.path.join(model_dir, "config.json")
|
config_save_path = os.path.join(model_dir, "config.json")
|
||||||
with open(config_save_path, "r") as f:
|
with open(config_save_path, "r") as f:
|
||||||
@ -228,12 +281,15 @@ def get_hparams_from_file(config_path):
|
|||||||
hparams = HParams(**config)
|
hparams = HParams(**config)
|
||||||
return hparams
|
return hparams
|
||||||
|
|
||||||
|
|
||||||
def check_git_hash(model_dir):
|
def check_git_hash(model_dir):
|
||||||
source_dir = os.path.dirname(os.path.realpath(__file__))
|
source_dir = os.path.dirname(os.path.realpath(__file__))
|
||||||
if not os.path.exists(os.path.join(source_dir, ".git")):
|
if not os.path.exists(os.path.join(source_dir, ".git")):
|
||||||
logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
|
logger.warn(
|
||||||
source_dir
|
"{} is not a git repository, therefore hash value comparison will be ignored.".format(
|
||||||
))
|
source_dir
|
||||||
|
)
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
cur_hash = subprocess.getoutput("git rev-parse HEAD")
|
cur_hash = subprocess.getoutput("git rev-parse HEAD")
|
||||||
@ -242,8 +298,11 @@ def check_git_hash(model_dir):
|
|||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
saved_hash = open(path).read()
|
saved_hash = open(path).read()
|
||||||
if saved_hash != cur_hash:
|
if saved_hash != cur_hash:
|
||||||
logger.warn("git hash values are different. {}(saved) != {}(current)".format(
|
logger.warn(
|
||||||
saved_hash[:8], cur_hash[:8]))
|
"git hash values are different. {}(saved) != {}(current)".format(
|
||||||
|
saved_hash[:8], cur_hash[:8]
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
open(path, "w").write(cur_hash)
|
open(path, "w").write(cur_hash)
|
||||||
|
|
||||||
@ -263,7 +322,7 @@ def get_logger(model_dir, filename="train.log"):
|
|||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|
||||||
class HParams():
|
class HParams:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
if type(v) == dict:
|
if type(v) == dict:
|
||||||
@ -294,5 +353,10 @@ class HParams():
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return self.__dict__.__repr__()
|
return self.__dict__.__repr__()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
print(load_wav_to_torch('/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac'))
|
if __name__ == "__main__":
|
||||||
|
print(
|
||||||
|
load_wav_to_torch(
|
||||||
|
"/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user