mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
Merge branch 'main' into patch-4
This commit is contained in:
commit
bbae213502
@ -4,12 +4,30 @@ logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||
logging.getLogger("httpcore").setLevel(logging.ERROR)
|
||||
logging.getLogger("httpx").setLevel(logging.ERROR)
|
||||
logging.getLogger("asyncio").setLevel(logging.ERROR)
|
||||
|
||||
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
|
||||
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
|
||||
import pdb
|
||||
|
||||
gpt_path = os.environ.get(
|
||||
"gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
||||
)
|
||||
sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
|
||||
if os.path.exists("./gweight.txt"):
|
||||
with open("./gweight.txt", 'r',encoding="utf-8") as file:
|
||||
gweight_data = file.read()
|
||||
gpt_path = os.environ.get(
|
||||
"gpt_path", gweight_data)
|
||||
else:
|
||||
gpt_path = os.environ.get(
|
||||
"gpt_path", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
|
||||
|
||||
if os.path.exists("./sweight.txt"):
|
||||
with open("./sweight.txt", 'r',encoding="utf-8") as file:
|
||||
sweight_data = file.read()
|
||||
sovits_path = os.environ.get("sovits_path", sweight_data)
|
||||
else:
|
||||
sovits_path = os.environ.get("sovits_path", "GPT_SoVITS/pretrained_models/s2G488k.pth")
|
||||
# gpt_path = os.environ.get(
|
||||
# "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
||||
# )
|
||||
# sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
|
||||
cnhubert_base_path = os.environ.get(
|
||||
"cnhubert_base_path", "pretrained_models/chinese-hubert-base"
|
||||
)
|
||||
@ -60,7 +78,7 @@ def get_bert_feature(text, word2ph):
|
||||
with torch.no_grad():
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model
|
||||
inputs[i] = inputs[i].to(device)
|
||||
res = bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
assert len(word2ph) == len(text)
|
||||
@ -124,6 +142,7 @@ def change_sovits_weights(sovits_path):
|
||||
vq_model = vq_model.to(device)
|
||||
vq_model.eval()
|
||||
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
||||
with open("./sweight.txt","w",encoding="utf-8")as f:f.write(sovits_path)
|
||||
change_sovits_weights(sovits_path)
|
||||
|
||||
def change_gpt_weights(gpt_path):
|
||||
@ -140,6 +159,7 @@ def change_gpt_weights(gpt_path):
|
||||
t2s_model.eval()
|
||||
total = sum([param.nelement() for param in t2s_model.parameters()])
|
||||
print("Number of parameter: %.2fM" % (total / 1e6))
|
||||
with open("./gweight.txt","w",encoding="utf-8")as f:f.write(gpt_path)
|
||||
change_gpt_weights(gpt_path)
|
||||
|
||||
def get_spepc(hps, filename):
|
||||
@ -188,19 +208,19 @@ def splite_en_inf(sentence, language):
|
||||
def clean_text_inf(text, language):
|
||||
phones, word2ph, norm_text = clean_text(text, language)
|
||||
phones = cleaned_text_to_sequence(phones)
|
||||
|
||||
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
|
||||
|
||||
def get_bert_inf(phones, word2ph, norm_text, language):
|
||||
if language == "zh":
|
||||
bert = get_bert_feature(norm_text, word2ph).to(device)
|
||||
else:
|
||||
bert = torch.zeros(
|
||||
(1024, len(phones)),
|
||||
dtype=torch.float16 if is_half == True else torch.float32,
|
||||
).to(device)
|
||||
|
||||
(1024, len(phones)),
|
||||
dtype=torch.float16 if is_half == True else torch.float32,
|
||||
).to(device)
|
||||
|
||||
return bert
|
||||
|
||||
|
||||
@ -213,7 +233,7 @@ def nonen_clean_text_inf(text, language):
|
||||
lang = langlist[i]
|
||||
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
|
||||
phones_list.append(phones)
|
||||
if lang=="en" or "ja":
|
||||
if lang == "en" or "ja":
|
||||
pass
|
||||
else:
|
||||
word2ph_list.append(word2ph)
|
||||
@ -222,7 +242,7 @@ def nonen_clean_text_inf(text, language):
|
||||
phones = sum(phones_list, [])
|
||||
word2ph = sum(word2ph_list, [])
|
||||
norm_text = ' '.join(norm_text_list)
|
||||
|
||||
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
|
||||
@ -238,7 +258,7 @@ def nonen_get_bert_inf(text, language):
|
||||
bert = get_bert_inf(phones, word2ph, norm_text, lang)
|
||||
bert_list.append(bert)
|
||||
bert = torch.cat(bert_list, dim=1)
|
||||
|
||||
|
||||
return bert
|
||||
|
||||
|
||||
@ -271,6 +291,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
||||
t1 = ttime()
|
||||
prompt_language = dict_language[prompt_language]
|
||||
text_language = dict_language[text_language]
|
||||
|
||||
if prompt_language == "en":
|
||||
phones1, word2ph1, norm_text1 = clean_text_inf(prompt_text, prompt_language)
|
||||
else:
|
||||
@ -281,21 +302,21 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
||||
bert1 = get_bert_inf(phones1, word2ph1, norm_text1, prompt_language)
|
||||
else:
|
||||
bert1 = nonen_get_bert_inf(prompt_text, prompt_language)
|
||||
|
||||
for text in texts:
|
||||
# 解决输入目标文本的空行导致报错的问题
|
||||
if (len(text.strip()) == 0):
|
||||
continue
|
||||
|
||||
if text_language == "en":
|
||||
phones2, word2ph2, norm_text2 = clean_text_inf(text, text_language)
|
||||
else:
|
||||
phones2, word2ph2, norm_text2 = nonen_clean_text_inf(text, text_language)
|
||||
|
||||
|
||||
if text_language == "en":
|
||||
bert2 = get_bert_inf(phones2, word2ph2, norm_text2, text_language)
|
||||
else:
|
||||
bert2 = nonen_get_bert_inf(text, text_language)
|
||||
|
||||
|
||||
bert = torch.cat([bert1, bert2], 1)
|
||||
|
||||
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
||||
|
@ -52,39 +52,32 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
|
||||
paths = [path.name for path in paths]
|
||||
for path in paths:
|
||||
inp_path = os.path.join(inp_root, path)
|
||||
need_reformat = 1
|
||||
done = 0
|
||||
if(os.path.isfile(inp_path)==False):continue
|
||||
try:
|
||||
y, sr = librosa.load(inp_path, sr=None)
|
||||
info = sf.info(inp_path)
|
||||
channels = info.channels
|
||||
if channels == 2 and sr == 44100:
|
||||
need_reformat = 0
|
||||
pre_fun._path_audio_(
|
||||
inp_path, save_root_ins, save_root_vocal, format0, is_hp3=is_hp3
|
||||
)
|
||||
done = 1
|
||||
else:
|
||||
done = 0
|
||||
try:
|
||||
y, sr = librosa.load(inp_path, sr=None)
|
||||
info = sf.info(inp_path)
|
||||
channels = info.channels
|
||||
if channels == 2 and sr == 44100:
|
||||
need_reformat = 0
|
||||
pre_fun._path_audio_(
|
||||
inp_path, save_root_ins, save_root_vocal, format0, is_hp3=is_hp3
|
||||
)
|
||||
done = 1
|
||||
else:
|
||||
need_reformat = 1
|
||||
except:
|
||||
need_reformat = 1
|
||||
except:
|
||||
need_reformat = 1
|
||||
traceback.print_exc()
|
||||
if need_reformat == 1:
|
||||
tmp_path = "%s/%s.reformatted.wav" % (
|
||||
os.path.join(os.environ["TEMP"]),
|
||||
os.path.basename(inp_path),
|
||||
)
|
||||
y_resampled = librosa.resample(y, sr, 44100)
|
||||
sf.write(tmp_path, y_resampled, 44100, "PCM_16")
|
||||
inp_path = tmp_path
|
||||
try:
|
||||
if done == 0:
|
||||
pre_fun._path_audio_(
|
||||
inp_path, save_root_ins, save_root_vocal, format0
|
||||
traceback.print_exc()
|
||||
if need_reformat == 1:
|
||||
tmp_path = "%s/%s.reformatted.wav" % (
|
||||
os.path.join(os.environ["TEMP"]),
|
||||
os.path.basename(inp_path),
|
||||
)
|
||||
infos.append("%s->Success" % (os.path.basename(inp_path)))
|
||||
yield "\n".join(infos)
|
||||
except:
|
||||
y_resampled = librosa.resample(y, sr, 44100)
|
||||
sf.write(tmp_path, y_resampled, 44100, "PCM_16")
|
||||
inp_path = tmp_path
|
||||
try:
|
||||
if done == 0:
|
||||
pre_fun._path_audio_(
|
||||
@ -93,10 +86,21 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
|
||||
infos.append("%s->Success" % (os.path.basename(inp_path)))
|
||||
yield "\n".join(infos)
|
||||
except:
|
||||
infos.append(
|
||||
"%s->%s" % (os.path.basename(inp_path), traceback.format_exc())
|
||||
)
|
||||
yield "\n".join(infos)
|
||||
try:
|
||||
if done == 0:
|
||||
pre_fun._path_audio_(
|
||||
inp_path, save_root_ins, save_root_vocal, format0
|
||||
)
|
||||
infos.append("%s->Success" % (os.path.basename(inp_path)))
|
||||
yield "\n".join(infos)
|
||||
except:
|
||||
infos.append(
|
||||
"%s->%s" % (os.path.basename(inp_path), traceback.format_exc())
|
||||
)
|
||||
yield "\n".join(infos)
|
||||
except:
|
||||
infos.append("Oh my god. %s->%s"%(os.path.basename(inp_path), traceback.format_exc()))
|
||||
yield "\n".join(infos)
|
||||
except:
|
||||
infos.append(traceback.format_exc())
|
||||
yield "\n".join(infos)
|
||||
|
Loading…
x
Reference in New Issue
Block a user