Merge branch 'main' into patch-4

commit bbae213502
RVC-Boss, 2024-01-26 17:49:38 +08:00 (committed by GitHub)
2 changed files with 76 additions and 51 deletions

View File

@@ -4,12 +4,30 @@ logging.getLogger("urllib3").setLevel(logging.ERROR)
 logging.getLogger("httpcore").setLevel(logging.ERROR)
 logging.getLogger("httpx").setLevel(logging.ERROR)
 logging.getLogger("asyncio").setLevel(logging.ERROR)
+logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
+logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
 import pdb
-gpt_path = os.environ.get(
-    "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
-)
-sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
+
+if os.path.exists("./gweight.txt"):
+    with open("./gweight.txt", 'r', encoding="utf-8") as file:
+        gweight_data = file.read()
+        gpt_path = os.environ.get(
+            "gpt_path", gweight_data)
+else:
+    gpt_path = os.environ.get(
+        "gpt_path", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
+
+if os.path.exists("./sweight.txt"):
+    with open("./sweight.txt", 'r', encoding="utf-8") as file:
+        sweight_data = file.read()
+        sovits_path = os.environ.get("sovits_path", sweight_data)
+else:
+    sovits_path = os.environ.get("sovits_path", "GPT_SoVITS/pretrained_models/s2G488k.pth")
+
+# gpt_path = os.environ.get(
+#     "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
+# )
+# sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
 cnhubert_base_path = os.environ.get(
     "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
 )
@@ -60,7 +78,7 @@ def get_bert_feature(text, word2ph):
     with torch.no_grad():
         inputs = tokenizer(text, return_tensors="pt")
         for i in inputs:
-            inputs[i] = inputs[i].to(device)  ##### inputs are long tensors, so precision is not a concern; precision follows bert_model
+            inputs[i] = inputs[i].to(device)
         res = bert_model(**inputs, output_hidden_states=True)
         res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
     assert len(word2ph) == len(text)
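
The change above only drops an inline comment; the surrounding get_bert_feature logic is untouched. For readers of the diff, here is a minimal standalone sketch of that feature-extraction pattern. The checkpoint name is illustrative, and the final repeat-by-word2ph step is assumed from the rest of the function, which this hunk does not show:

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Illustrative checkpoint; the webui loads its own local BERT path.
MODEL_ID = "hfl/chinese-roberta-wwm-ext-large"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
bert_model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)

def bert_phone_features(text, word2ph, device="cpu"):
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}  # long tensors, dtype is irrelevant here
        res = bert_model(**inputs, output_hidden_states=True)
        # third-to-last hidden layer, with [CLS]/[SEP] stripped: one 1024-dim row per character
        hidden = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
    assert len(word2ph) == len(text)
    # repeat each character's vector once per phoneme (assumed from the full function)
    phone_level = [hidden[i].repeat(word2ph[i], 1) for i in range(len(word2ph))]
    return torch.cat(phone_level, dim=0).T  # shape (1024, total phonemes)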
@@ -124,6 +142,7 @@ def change_sovits_weights(sovits_path):
     vq_model = vq_model.to(device)
     vq_model.eval()
     print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
+    with open("./sweight.txt","w",encoding="utf-8")as f:f.write(sovits_path)

 change_sovits_weights(sovits_path)

 def change_gpt_weights(gpt_path):
@@ -140,6 +159,7 @@ def change_gpt_weights(gpt_path):
     t2s_model.eval()
     total = sum([param.nelement() for param in t2s_model.parameters()])
     print("Number of parameter: %.2fM" % (total / 1e6))
+    with open("./gweight.txt","w",encoding="utf-8")as f:f.write(gpt_path)

 change_gpt_weights(gpt_path)

 def get_spepc(hps, filename):
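
Taken together with the gweight.txt/sweight.txt block added at the top of the file, these two one-line additions give the webui a memory of the last loaded weights: the path is written out whenever weights are switched and read back (behind the gpt_path/sovits_path environment overrides) on the next launch. A condensed sketch of that round trip, with hypothetical helper names:

import os

def resolve_weight_path(env_var, cache_file, default_path):
    # Precedence mirrors the diff: environment variable, then the cached
    # last-used path, then the bundled default.
    if os.path.exists(cache_file):
        with open(cache_file, "r", encoding="utf-8") as f:
            return os.environ.get(env_var, f.read())
    return os.environ.get(env_var, default_path)

def remember_weight_path(cache_file, weight_path):
    # Called after a successful load, like the added with-open one-liners above.
    with open(cache_file, "w", encoding="utf-8") as f:
        f.write(weight_path)

gpt_path = resolve_weight_path(
    "gpt_path", "./gweight.txt",
    "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
sovits_path = resolve_weight_path(
    "sovits_path", "./sweight.txt", "GPT_SoVITS/pretrained_models/s2G488k.pth")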
@@ -188,19 +208,19 @@ def splite_en_inf(sentence, language):
 def clean_text_inf(text, language):
     phones, word2ph, norm_text = clean_text(text, language)
     phones = cleaned_text_to_sequence(phones)
     return phones, word2ph, norm_text


 def get_bert_inf(phones, word2ph, norm_text, language):
     if language == "zh":
         bert = get_bert_feature(norm_text, word2ph).to(device)
     else:
         bert = torch.zeros(
             (1024, len(phones)),
             dtype=torch.float16 if is_half == True else torch.float32,
         ).to(device)

     return bert
@@ -213,7 +233,7 @@ def nonen_clean_text_inf(text, language):
         lang = langlist[i]
         phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
         phones_list.append(phones)
-        if lang=="en" or "ja":
+        if lang == "en" or "ja":
             pass
         else:
             word2ph_list.append(word2ph)
@@ -222,7 +242,7 @@ def nonen_clean_text_inf(text, language):
     phones = sum(phones_list, [])
     word2ph = sum(word2ph_list, [])
     norm_text = ' '.join(norm_text_list)
     return phones, word2ph, norm_text
@@ -238,7 +258,7 @@ def nonen_get_bert_inf(text, language):
         bert = get_bert_inf(phones, word2ph, norm_text, lang)
         bert_list.append(bert)
     bert = torch.cat(bert_list, dim=1)
     return bert
@@ -271,6 +291,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
     t1 = ttime()
     prompt_language = dict_language[prompt_language]
     text_language = dict_language[text_language]
+
     if prompt_language == "en":
         phones1, word2ph1, norm_text1 = clean_text_inf(prompt_text, prompt_language)
     else:
@@ -281,21 +302,21 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
         bert1 = get_bert_inf(phones1, word2ph1, norm_text1, prompt_language)
     else:
         bert1 = nonen_get_bert_inf(prompt_text, prompt_language)

     for text in texts:
         # skip blank lines in the target text; they would otherwise raise an error
         if (len(text.strip()) == 0):
             continue
         if text_language == "en":
             phones2, word2ph2, norm_text2 = clean_text_inf(text, text_language)
         else:
             phones2, word2ph2, norm_text2 = nonen_clean_text_inf(text, text_language)

         if text_language == "en":
             bert2 = get_bert_inf(phones2, word2ph2, norm_text2, text_language)
         else:
             bert2 = nonen_get_bert_inf(text, text_language)

         bert = torch.cat([bert1, bert2], 1)
         all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
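
The unchanged tail of this hunk shows how prompt and target are combined before the AR stage: their phoneme ID lists are concatenated, and the two BERT matrices are concatenated along the time axis so they stay aligned with the phonemes. A toy shape check with made-up lengths:

import torch

n_prompt, n_target = 12, 34            # made-up phoneme counts
phones1, phones2 = list(range(n_prompt)), list(range(n_target))
bert1 = torch.zeros(1024, n_prompt)    # (feature_dim, n_phones), as produced by get_bert_inf
bert2 = torch.zeros(1024, n_target)

bert = torch.cat([bert1, bert2], 1)                                 # (1024, n_prompt + n_target)
all_phoneme_ids = torch.LongTensor(phones1 + phones2).unsqueeze(0)  # (1, n_prompt + n_target)
assert bert.shape[1] == all_phoneme_ids.shape[1]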

View File

@@ -52,39 +52,32 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
             paths = [path.name for path in paths]
         for path in paths:
             inp_path = os.path.join(inp_root, path)
-            need_reformat = 1
-            done = 0
-            try:
-                y, sr = librosa.load(inp_path, sr=None)
-                info = sf.info(inp_path)
-                channels = info.channels
-                if channels == 2 and sr == 44100:
-                    need_reformat = 0
-                    pre_fun._path_audio_(
-                        inp_path, save_root_ins, save_root_vocal, format0, is_hp3=is_hp3
-                    )
-                    done = 1
-                else:
-                    need_reformat = 1
-            except:
-                need_reformat = 1
-                traceback.print_exc()
-            if need_reformat == 1:
-                tmp_path = "%s/%s.reformatted.wav" % (
-                    os.path.join(os.environ["TEMP"]),
-                    os.path.basename(inp_path),
-                )
-                y_resampled = librosa.resample(y, sr, 44100)
-                sf.write(tmp_path, y_resampled, 44100, "PCM_16")
-                inp_path = tmp_path
-            try:
-                if done == 0:
-                    pre_fun._path_audio_(
-                        inp_path, save_root_ins, save_root_vocal, format0
-                    )
-                infos.append("%s->Success" % (os.path.basename(inp_path)))
-                yield "\n".join(infos)
-            except:
-                try:
-                    if done == 0:
-                        pre_fun._path_audio_(
+            if(os.path.isfile(inp_path)==False):continue
+            try:
+                done = 0
+                try:
+                    y, sr = librosa.load(inp_path, sr=None)
+                    info = sf.info(inp_path)
+                    channels = info.channels
+                    if channels == 2 and sr == 44100:
+                        need_reformat = 0
+                        pre_fun._path_audio_(
+                            inp_path, save_root_ins, save_root_vocal, format0, is_hp3=is_hp3
+                        )
+                        done = 1
+                    else:
+                        need_reformat = 1
+                except:
+                    need_reformat = 1
+                    traceback.print_exc()
+                if need_reformat == 1:
+                    tmp_path = "%s/%s.reformatted.wav" % (
+                        os.path.join(os.environ["TEMP"]),
+                        os.path.basename(inp_path),
+                    )
+                    y_resampled = librosa.resample(y, sr, 44100)
+                    sf.write(tmp_path, y_resampled, 44100, "PCM_16")
+                    inp_path = tmp_path
+                try:
+                    if done == 0:
+                        pre_fun._path_audio_(
@@ -93,10 +86,21 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
-                    infos.append("%s->Success" % (os.path.basename(inp_path)))
-                    yield "\n".join(infos)
-                except:
-                    infos.append(
-                        "%s->%s" % (os.path.basename(inp_path), traceback.format_exc())
-                    )
-                    yield "\n".join(infos)
+                    infos.append("%s->Success" % (os.path.basename(inp_path)))
+                    yield "\n".join(infos)
+                except:
+                    try:
+                        if done == 0:
+                            pre_fun._path_audio_(
+                                inp_path, save_root_ins, save_root_vocal, format0
+                            )
+                        infos.append("%s->Success" % (os.path.basename(inp_path)))
+                        yield "\n".join(infos)
+                    except:
+                        infos.append(
+                            "%s->%s" % (os.path.basename(inp_path), traceback.format_exc())
+                        )
+                        yield "\n".join(infos)
+            except:
+                infos.append("Oh my god. %s->%s"%(os.path.basename(inp_path), traceback.format_exc()))
+                yield "\n".join(infos)
     except:
         infos.append(traceback.format_exc())
         yield "\n".join(infos)
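
For reference, the fallback that both versions of this loop share can be read in isolation: files that are already stereo 44.1 kHz go straight to pre_fun._path_audio_, anything else is resampled to 44.1 kHz and rewritten as 16-bit PCM first. A minimal sketch using only the librosa/soundfile calls that appear in the diff (keyword arguments spelled out for newer librosa releases):

import os
import librosa
import soundfile as sf

def ensure_44k_input(inp_path, tmp_dir):
    # Sketch of the reformat branch in uvr(), not the exact code: return a
    # path whose audio the separator can consume directly.
    y, sr = librosa.load(inp_path, sr=None)      # keep the native sample rate
    info = sf.info(inp_path)
    if info.channels == 2 and sr == 44100:
        return inp_path                          # already stereo 44.1 kHz
    tmp_path = os.path.join(tmp_dir, "%s.reformatted.wav" % os.path.basename(inp_path))
    y_resampled = librosa.resample(y, orig_sr=sr, target_sr=44100)
    sf.write(tmp_path, y_resampled, 44100, subtype="PCM_16")
    return tmp_path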