Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git, synced 2025-04-06 03:57:44 +08:00

Merge branch 'main' into patch-4

Commit bbae213502
@@ -4,12 +4,30 @@ logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpcore").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("asyncio").setLevel(logging.ERROR)

logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
import pdb

gpt_path = os.environ.get(
    "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
)
sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
if os.path.exists("./gweight.txt"):
    with open("./gweight.txt", 'r', encoding="utf-8") as file:
        gweight_data = file.read()
        gpt_path = os.environ.get(
            "gpt_path", gweight_data)
else:
    gpt_path = os.environ.get(
        "gpt_path", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")

if os.path.exists("./sweight.txt"):
    with open("./sweight.txt", 'r', encoding="utf-8") as file:
        sweight_data = file.read()
        sovits_path = os.environ.get("sovits_path", sweight_data)
else:
    sovits_path = os.environ.get("sovits_path", "GPT_SoVITS/pretrained_models/s2G488k.pth")
# gpt_path = os.environ.get(
#     "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
# )
# sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
cnhubert_base_path = os.environ.get(
    "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
)
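The hunk above makes the webui remember the last weights it loaded: if ./gweight.txt or ./sweight.txt exists, its contents become the default checkpoint path, the bundled pretrained models are the fallback, and the gpt_path / sovits_path environment variables still take priority. A minimal, self-contained sketch of that resolution order (the helper name resolve_weight_path is illustrative, not something from the repo):

import os


def resolve_weight_path(env_var: str, memo_file: str, fallback: str) -> str:
    # Priority: explicit environment variable > last-used path saved in memo_file > fallback.
    if os.path.exists(memo_file):
        with open(memo_file, "r", encoding="utf-8") as f:
            saved = f.read().strip()
        if saved:
            fallback = saved
    return os.environ.get(env_var, fallback)


gpt_path = resolve_weight_path(
    "gpt_path", "./gweight.txt",
    "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
)
sovits_path = resolve_weight_path("sovits_path", "./sweight.txt", "GPT_SoVITS/pretrained_models/s2G488k.pth")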
@@ -60,7 +78,7 @@ def get_bert_feature(text, word2ph):
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        for i in inputs:
            inputs[i] = inputs[i].to(device)  # inputs are long (integer) tensors, so precision is not a concern; dtype follows bert_model
            inputs[i] = inputs[i].to(device)
        res = bert_model(**inputs, output_hidden_states=True)
        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
    assert len(word2ph) == len(text)
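In this function, res["hidden_states"][-3:-2] is a one-element list slice, so the torch.cat along the last dimension simply selects the third-from-last hidden layer, and the [1:-1] slice drops the [CLS]/[SEP] positions so one vector remains per character, which is why len(word2ph) must equal len(text). The repo then repeats each character's vector once per phoneme; the helper below is an illustrative reconstruction of that expansion, not the file's exact code:

import torch


def expand_char_features(char_feats: torch.Tensor, word2ph: list) -> torch.Tensor:
    # char_feats: (num_chars, hidden), one row per character (CLS/SEP already stripped).
    # word2ph[i]: how many phonemes character i expands to.
    assert char_feats.shape[0] == len(word2ph)
    phone_level = [char_feats[i].repeat(word2ph[i], 1) for i in range(len(word2ph))]
    return torch.cat(phone_level, dim=0)  # (num_phonemes, hidden)


feats = torch.randn(2, 1024)                      # two characters
print(expand_char_features(feats, [3, 3]).shape)  # torch.Size([6, 1024])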
@@ -124,6 +142,7 @@ def change_sovits_weights(sovits_path):
    vq_model = vq_model.to(device)
    vq_model.eval()
    print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
    with open("./sweight.txt", "w", encoding="utf-8") as f: f.write(sovits_path)
change_sovits_weights(sovits_path)

def change_gpt_weights(gpt_path):
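Printing the return value of load_state_dict(..., strict=False), as the hunk does, is a cheap sanity check: PyTorch returns a named tuple of missing and unexpected keys, so a mismatched SoVITS checkpoint shows up immediately in the console. A small self-contained illustration (the nn.Linear module and the {"weight": state_dict} layout merely stand in for vq_model and dict_s2):

import torch
from torch import nn

model = nn.Linear(4, 4)
ckpt = {"weight": {"weight": torch.zeros(4, 4)}}  # deliberately missing the "bias" entry

result = model.load_state_dict(ckpt["weight"], strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # []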
@@ -140,6 +159,7 @@ def change_gpt_weights(gpt_path):
    t2s_model.eval()
    total = sum([param.nelement() for param in t2s_model.parameters()])
    print("Number of parameter: %.2fM" % (total / 1e6))
    with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)
change_gpt_weights(gpt_path)

def get_spepc(hps, filename):
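The parameter count printed here is just the sum of element counts over all parameter tensors, divided by 1e6 to report millions. A standalone sketch of the same arithmetic (count_params_millions is an illustrative name, and nn.Linear(1000, 1000) stands in for t2s_model):

from torch import nn


def count_params_millions(model: nn.Module) -> float:
    # Same computation as the hunk: total elements across all parameters, in millions.
    return sum(p.nelement() for p in model.parameters()) / 1e6


toy = nn.Linear(1000, 1000)  # 1,000,000 weights + 1,000 biases = 1.001M parameters
print("Number of parameter: %.2fM" % count_params_millions(toy))  # -> 1.00M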
@@ -213,7 +233,7 @@ def nonen_clean_text_inf(text, language):
        lang = langlist[i]
        phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
        phones_list.append(phones)
        if lang=="en" or "ja":
        if lang == "en" or "ja":
            pass
        else:
            word2ph_list.append(word2ph)
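Note that the change in this hunk only adjusts spacing: lang == "en" or "ja" still evaluates lang == "en" first and, when that is False, falls back to the non-empty string "ja", which is always truthy, so the pass branch runs for every language and word2ph is never appended. If the intent is "English or Japanese", a membership test expresses it directly; the snippet below is a sketch of that reading, not a change the commit makes:

lang = "zh"
print(lang == "en" or "ja")  # -> 'ja', which is truthy, so the original condition always holds

word2ph_list = []
if lang in ("en", "ja"):     # explicit membership test for "English or Japanese"
    pass
else:
    word2ph_list.append([2, 1, 3])  # illustrative word2ph for a Chinese segment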
@@ -271,6 +291,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
    t1 = ttime()
    prompt_language = dict_language[prompt_language]
    text_language = dict_language[text_language]

    if prompt_language == "en":
        phones1, word2ph1, norm_text1 = clean_text_inf(prompt_text, prompt_language)
    else:
@@ -281,11 +302,11 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
        bert1 = get_bert_inf(phones1, word2ph1, norm_text1, prompt_language)
    else:
        bert1 = nonen_get_bert_inf(prompt_text, prompt_language)

    for text in texts:
        # Fix the error caused by blank lines in the target text
        if (len(text.strip()) == 0):
            continue

        if text_language == "en":
            phones2, word2ph2, norm_text2 = clean_text_inf(text, text_language)
        else:
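The translated comment in this hunk explains the guard: blank lines in the target text would otherwise reach the cleaner and raise, so whitespace-only segments are skipped before synthesis. A tiny standalone illustration of the same check (splitting on newlines is an assumption for the example; the hunk itself only shows the loop over texts):

target_text = "First sentence.\n\nSecond sentence.\n   \n"
texts = target_text.split("\n")

for text in texts:
    if len(text.strip()) == 0:  # same condition as the hunk: skip empty or whitespace-only lines
        continue
    print("synthesizing:", text)
# synthesizing: First sentence.
# synthesizing: Second sentence.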
@@ -52,7 +52,8 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
    paths = [path.name for path in paths]
    for path in paths:
        inp_path = os.path.join(inp_root, path)
        need_reformat = 1
        if (os.path.isfile(inp_path) == False): continue
        try:
            done = 0
            try:
                y, sr = librosa.load(inp_path, sr=None)
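need_reformat starts at 1 and the inner try probes the file with librosa.load(inp_path, sr=None), i.e. at its native sample rate, to decide whether it must be re-encoded before separation. A hedged sketch of that kind of probe; the 44.1 kHz stereo target and the needs_reformat helper are assumptions for illustration and do not come from this hunk:

import os

import librosa


def needs_reformat(inp_path: str, want_sr: int = 44100, want_channels: int = 2) -> bool:
    # Mirror of the hunk's guard: only probe paths that are real files.
    if not os.path.isfile(inp_path):
        raise FileNotFoundError(inp_path)
    y, sr = librosa.load(inp_path, sr=None, mono=False)  # native rate, keep channel layout
    channels = 1 if y.ndim == 1 else y.shape[0]
    return not (sr == want_sr and channels == want_channels)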
@@ -97,6 +98,9 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
                "%s->%s" % (os.path.basename(inp_path), traceback.format_exc())
            )
            yield "\n".join(infos)
        except:
            infos.append("Oh my god. %s->%s" % (os.path.basename(inp_path), traceback.format_exc()))
            yield "\n".join(infos)
    except:
        infos.append(traceback.format_exc())
        yield "\n".join(infos)
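Every except branch in this hunk appends to infos and then yields "\n".join(infos), so uvr() behaves as a generator whose latest value is always the full log, errors included, which is what lets a streaming UI textbox keep updating instead of dying on the first bad file. A compact sketch of the same pattern (process_all and the ".bad" suffix are purely illustrative):

import traceback


def process_all(paths):
    infos = []
    for path in paths:
        try:
            if path.endswith(".bad"):
                raise ValueError("unreadable input")  # stand-in for a failing separation
            infos.append("%s->Success" % path)
        except Exception:
            # Same shape as the hunk: keep going, but surface the traceback in the log.
            infos.append("%s->%s" % (path, traceback.format_exc()))
        yield "\n".join(infos)


log = ""
for log in process_all(["a.wav", "b.bad", "c.wav"]):
    pass
print(log)  # final accumulated log, including the traceback for b.bad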