Merge branch 'main' into patch-4

This commit is contained in:
RVC-Boss 2024-01-26 17:49:38 +08:00 committed by GitHub
commit bbae213502
2 changed files with 76 additions and 51 deletions

View File

@ -4,12 +4,30 @@ logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpcore").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("asyncio").setLevel(logging.ERROR)
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
import pdb
gpt_path = os.environ.get(
"gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
)
sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
if os.path.exists("./gweight.txt"):
with open("./gweight.txt", 'r',encoding="utf-8") as file:
gweight_data = file.read()
gpt_path = os.environ.get(
"gpt_path", gweight_data)
else:
gpt_path = os.environ.get(
"gpt_path", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
if os.path.exists("./sweight.txt"):
with open("./sweight.txt", 'r',encoding="utf-8") as file:
sweight_data = file.read()
sovits_path = os.environ.get("sovits_path", sweight_data)
else:
sovits_path = os.environ.get("sovits_path", "GPT_SoVITS/pretrained_models/s2G488k.pth")
# gpt_path = os.environ.get(
# "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
# )
# sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
cnhubert_base_path = os.environ.get(
"cnhubert_base_path", "pretrained_models/chinese-hubert-base"
)
@ -60,7 +78,7 @@ def get_bert_feature(text, word2ph):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题精度随bert_model
inputs[i] = inputs[i].to(device)
res = bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
@ -124,6 +142,7 @@ def change_sovits_weights(sovits_path):
vq_model = vq_model.to(device)
vq_model.eval()
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
with open("./sweight.txt","w",encoding="utf-8")as f:f.write(sovits_path)
change_sovits_weights(sovits_path)
def change_gpt_weights(gpt_path):
@ -140,6 +159,7 @@ def change_gpt_weights(gpt_path):
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
with open("./gweight.txt","w",encoding="utf-8")as f:f.write(gpt_path)
change_gpt_weights(gpt_path)
def get_spepc(hps, filename):
@ -213,7 +233,7 @@ def nonen_clean_text_inf(text, language):
lang = langlist[i]
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
phones_list.append(phones)
if lang=="en" or "ja":
if lang == "en" or "ja":
pass
else:
word2ph_list.append(word2ph)
@ -271,6 +291,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
t1 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
if prompt_language == "en":
phones1, word2ph1, norm_text1 = clean_text_inf(prompt_text, prompt_language)
else:
@ -281,11 +302,11 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
bert1 = get_bert_inf(phones1, word2ph1, norm_text1, prompt_language)
else:
bert1 = nonen_get_bert_inf(prompt_text, prompt_language)
for text in texts:
# 解决输入目标文本的空行导致报错的问题
if (len(text.strip()) == 0):
continue
if text_language == "en":
phones2, word2ph2, norm_text2 = clean_text_inf(text, text_language)
else:

View File

@ -52,7 +52,8 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
paths = [path.name for path in paths]
for path in paths:
inp_path = os.path.join(inp_root, path)
need_reformat = 1
if(os.path.isfile(inp_path)==False):continue
try:
done = 0
try:
y, sr = librosa.load(inp_path, sr=None)
@ -97,6 +98,9 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
"%s->%s" % (os.path.basename(inp_path), traceback.format_exc())
)
yield "\n".join(infos)
except:
infos.append("Oh my god. %s->%s"%(os.path.basename(inp_path), traceback.format_exc()))
yield "\n".join(infos)
except:
infos.append(traceback.format_exc())
yield "\n".join(infos)