Merge e9625c3a9edddd06ce87eda21665015113455bf6 into 959269b5ae2db5d0f0aead15b91c7e1e120f6303

This commit is contained in:
Ann-yang00 2024-04-12 23:59:44 +08:00 committed by GitHub
commit acfe2ddd6a

api.py

@@ -35,7 +35,8 @@ POST:
```json
{
    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
    "text_language": "zh",
    "slice": "按标点符号切"
}
```
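A minimal client sketch for the extended request (illustrative only: it assumes the default port 9880 used elsewhere in this file, a root TTS path `/`, and that WAV bytes are streamed back on success):

```python
import requests

# Illustrative sketch; the root TTS path "/" and WAV streaming are assumptions.
resp = requests.post(
    "http://127.0.0.1:9880/",
    json={
        "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
        "text_language": "zh",
        "slice": "按标点符号切",  # any key of slice_option below; omit for no slicing
    },
)
if resp.status_code == 200:
    with open("out.wav", "wb") as f:
        f.write(resp.content)
```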
@@ -80,6 +81,23 @@ RESP:
Failure: json, 400

### Dynamically switch models
endpoint: `/set_model`

GET:
`http://127.0.0.1:9880/set_model?gpt_model_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt&sovits_model_path=GPT_SoVITS/pretrained_models/s2G488k.pth`
POST:
```json
{
    "gpt_model_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
    "sovits_model_path": "GPT_SoVITS/pretrained_models/s2G488k.pth"
}
```
RESP:
Success: json, http code 200
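For example, the documented POST body can be sent as follows (a sketch using `requests`; the model paths are the pretrained defaults from the GET example above):

```python
import requests

# Switch both models at runtime via the documented /set_model endpoint.
resp = requests.post(
    "http://127.0.0.1:9880/set_model",
    json={
        "gpt_model_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
        "sovits_model_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
    },
)
print(resp.status_code)  # 200 on success, per RESP above
```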
### Command control
endpoint: `/control`
@@ -103,7 +121,7 @@ RESP: none

import argparse
import os, re
import sys

now_dir = os.getcwd()
@@ -126,6 +144,7 @@ from module.models import SynthesizerTrn

from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
import LangSegment
from module.mel_processing import spectrogram_torch
from my_utils import load_audio
import config as global_config
@@ -148,6 +167,7 @@ parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="defa

parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
#parser.add_argument("-sl", "--slice", type=str, default="No slice", help="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
# Boolean flags are used as `python ./api.py -fp ...`
# in which case full_precision==True, half_precision==False
@@ -350,17 +370,133 @@ dict_language = {
    "JA": "ja",
    "zh": "zh",
    "en": "en",
    "ja": "ja",
    "auto": "auto",
    "中英混合": "zh",
    "日英混合": "ja",
    "多语种混合": "auto"
}
slice_option = {
    "凑四句一切": "凑四句一切",
    "凑50字一切": "凑50字一切",
    "按中文句号。切": "按中文句号。切",
    "按英文句号.切": "按英文句号.切",
    "按标点符号切": "按标点符号切",
    "per 4 sentences": "凑四句一切",
    "per 50 letters": "凑50字一切",
    "per period": "按英文句号.切",
    "per punctuation mark": "按标点符号切",
    None: "No slice"
}
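# Illustrative lookups (not part of this commit's code):
#   slice_option["per punctuation mark"] -> "按标点符号切"
#   slice_option[None]                   -> "No slice"
# Unlisted values raise KeyError, since get_tts_wav indexes this dict directly.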
dtype = torch.float16 if is_half == True else torch.float32


def get_bert_inf(phones, word2ph, norm_text, language):
    language = language.replace("all_", "")
    if language == "zh":
        bert = get_bert_feature(norm_text, word2ph).to(device)  # .to(dtype)
    else:
        bert = torch.zeros(
            (1024, len(phones)),
            dtype=torch.float16 if is_half == True else torch.float32,
        ).to(device)

    return bert


def clean_text_inf(text, language):
    phones, word2ph, norm_text = clean_text(text, language)
    phones = cleaned_text_to_sequence(phones)
    return phones, word2ph, norm_text
# Mixed language support
def get_phones_and_bert(text, language):
    if language in {"en", "all_zh", "all_ja"}:
        language = language.replace("all_", "")
        if language == "en":
            LangSegment.setfilters(["en"])
            formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
        else:
            # Chinese and Japanese kanji cannot be told apart, so follow the user-specified language
            formattext = text
        while "  " in formattext:
            formattext = formattext.replace("  ", " ")
        phones, word2ph, norm_text = clean_text_inf(formattext, language)
        if language == "zh":
            bert = get_bert_feature(norm_text, word2ph).to(device)
        else:
            bert = torch.zeros(
                (1024, len(phones)),
                dtype=torch.float16 if is_half == True else torch.float32,
            ).to(device)
    elif language in {"zh", "ja", "auto"}:
        textlist = []
        langlist = []
        LangSegment.setfilters(["zh", "ja", "en"])
        if language == "auto":
            for tmp in LangSegment.getTexts(text):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        else:
            for tmp in LangSegment.getTexts(text):
                if tmp["lang"] == "en":
                    langlist.append(tmp["lang"])
                else:
                    # Chinese and Japanese kanji cannot be told apart, so follow the user-specified language
                    langlist.append(language)
                textlist.append(tmp["text"])
        print(textlist)
        print(langlist)
        phones_list = []
        bert_list = []
        norm_text_list = []
        for i in range(len(textlist)):
            lang = langlist[i]
            phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
            bert = get_bert_inf(phones, word2ph, norm_text, lang)
            phones_list.append(phones)
            norm_text_list.append(norm_text)
            bert_list.append(bert)
        bert = torch.cat(bert_list, dim=1)
        phones = sum(phones_list, [])
        norm_text = ''.join(norm_text_list)

    return phones, bert.to(dtype), norm_text
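# Illustrative (not part of this commit): with language="auto", a mixed input such as
# "你好, hello, こんにちは" is segmented by LangSegment into chunks like
# {"lang": "zh", ...}, {"lang": "en", ...}, {"lang": "ja", ...}; each chunk is cleaned
# with its own language and the per-chunk BERT features are concatenated along dim=1.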
def merge_short_text_in_array(texts, threshold):
    if (len(texts)) < 2:
        return texts
    result = []
    text = ""
    for ele in texts:
        text += ele
        if len(text) >= threshold:
            result.append(text)
            text = ""
    if (len(text) > 0):
        if len(result) == 0:
            result.append(text)
        else:
            result[len(result) - 1] += text
    return result
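# Illustrative (not part of this commit):
#   merge_short_text_in_array(["你好。", "我很好。", "Today is a good day.", "嗯。"], 5)
#   -> ["你好。我很好。", "Today is a good day.嗯。"]
# Fragments are accumulated until they reach the threshold, and a short tail is
# folded into the last result instead of being emitted on its own.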
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut):
    # not supporting ref_free
    t0 = ttime()
    prompt_language = dict_language[prompt_language]
    prompt_text = prompt_text.strip("\n")
    if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
    print(f"实际输入的参考文本: {prompt_text}")
    text = text.strip("\n")
    if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
    print(f"实际输入的目标文本: {text}")
    zero_wav = np.zeros(
        int(hps.data.sampling_rate * 0.3),
        dtype=np.float16 if is_half == True else np.float32
    )
    with torch.no_grad():
        wav16k, sr = librosa.load(ref_wav_path, sr=16000)
        # neglected error checking for reference audio duration
        wav16k = torch.from_numpy(wav16k)
        zero_wav_torch = torch.from_numpy(zero_wav)
        if (is_half == True):
@@ -370,32 +506,51 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
            wav16k = wav16k.to(device)
            zero_wav_torch = zero_wav_torch.to(device)
        wav16k = torch.cat([wav16k, zero_wav_torch])
        ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
            "last_hidden_state"
        ].transpose(
            1, 2
        )  # .float()
        codes = vq_model.extract_latent(ssl_content)
        prompt_semantic = codes[0, 0]
    t1 = ttime()

    # Split the text into sentences
    # default to no slice if no argument is provided
    text_language = dict_language[text_language]
    how_to_cut = slice_option[how_to_cut]
    print(f"[INFO] 文本切句選項: {how_to_cut}")
    if (how_to_cut == "凑四句一切"):
        text = cut1(text)
    elif (how_to_cut == "凑50字一切"):
        text = cut2(text)
    elif (how_to_cut == "按中文句号。切"):
        text = cut3(text)
    elif (how_to_cut == "按英文句号.切"):
        text = cut4(text)
    elif (how_to_cut == "按标点符号切"):
        text = cut5(text)
    while "\n\n" in text:
        text = text.replace("\n\n", "\n")
    print(f"实际输入的目标文本(切句后): {text}")
    texts = text.split("\n")
    texts = merge_short_text_in_array(texts, 5)
    audio_opt = []
    phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language)

    for text in texts:
        # Skip blank lines in the target text so they do not raise errors
        if (len(text.strip()) == 0):
            continue
        if (text[-1] not in splits): text += "。" if text_language != "en" else "."
        print(f"实际输入的目标文本(每句): {text}")
        phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language)
        print(f"前端处理后的文本(每句): {norm_text2}")
        bert = torch.cat([bert1, bert2], 1)

        all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
        prompt = prompt_semantic.unsqueeze(0).to(device)
@@ -412,23 +567,118 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
                early_stop_num=hz * max_sec)
        t3 = ttime()
        # print(pred_semantic.shape,idx)
        pred_semantic = pred_semantic[:, -idx:].unsqueeze(
            0
        )  # .unsqueeze(0)  # mq needs one extra unsqueeze
        refer = get_spepc(hps, ref_wav_path)  # .to(device)
        if (is_half == True):
            refer = refer.half().to(device)
        else:
            refer = refer.to(device)
        # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
        audio = (
            vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer)
            .detach()
            .cpu()
            .numpy()[0, 0]
        )  ### Try reconstructing without the prompt part
        max_audio = np.abs(audio).max()  # Simple guard against 16-bit clipping
        if max_audio > 1: audio /= max_audio
        audio_opt.append(audio)
        audio_opt.append(zero_wav)
    t4 = ttime()
    print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
    yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
        np.int16
    )
splits = {"，", "。", "？", "！", ",", ".", "?", "!", "~", ":", "：", "—", "…", }


def get_first(text):
    pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
    text = re.split(pattern, text)[0].strip()
    return text
def split(todo_text):
    todo_text = todo_text.replace("……", "。").replace("——", "，")
    if todo_text[-1] not in splits:
        todo_text += "。"
    i_split_head = i_split_tail = 0
    len_text = len(todo_text)
    todo_texts = []
    while 1:
        if i_split_head >= len_text:
            break  # the text is guaranteed to end with punctuation, so just exit; the last segment was appended in the previous iteration
        if todo_text[i_split_head] in splits:
            i_split_head += 1
            todo_texts.append(todo_text[i_split_tail:i_split_head])
            i_split_tail = i_split_head
        else:
            i_split_head += 1
    return todo_texts
def cut1(inp):
    inp = inp.strip("\n")
    inps = split(inp)
    split_idx = list(range(0, len(inps), 4))
    split_idx[-1] = None
    if len(split_idx) > 1:
        opts = []
        for idx in range(len(split_idx) - 1):
            opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
    else:
        opts = [inp]
    return "\n".join(opts)
def cut2(inp):
    inp = inp.strip("\n")
    inps = split(inp)
    if len(inps) < 2:
        return inp
    opts = []
    summ = 0
    tmp_str = ""
    for i in range(len(inps)):
        summ += len(inps[i])
        tmp_str += inps[i]
        if summ > 50:
            summ = 0
            opts.append(tmp_str)
            tmp_str = ""
    if tmp_str != "":
        opts.append(tmp_str)
    # print(opts)
    if len(opts) > 1 and len(opts[-1]) < 50:  ## if the last segment is too short, merge it with the previous one
        opts[-2] = opts[-2] + opts[-1]
        opts = opts[:-1]
    return "\n".join(opts)
def cut3(inp):
    inp = inp.strip("\n")
    return "\n".join(["%s" % item for item in inp.strip("。").split("。")])


def cut4(inp):
    inp = inp.strip("\n")
    return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
def cut5(inp):
    # if not re.search(r'[^\w\s]', inp[-1]):
    #     inp += '。'
    inp = inp.strip("\n")
    punds = r'[,.;?!、,。?!;:…]'
    items = re.split(f'({punds})', inp)
    mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
    # keep the text intact when it contains no punctuation or does not end with punctuation
    if len(items) % 2 == 1:
        mergeitems.append(items[-1])
    opt = "\n".join(mergeitems)
    return opt
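# Illustrative (not part of this commit): cut5("你好,世界") -> "你好,\n世界".
# Captured punctuation stays attached to its sentence, and the odd trailing item
# ("世界") is appended so an unpunctuated tail is not dropped.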
def handle_control(command):
    if command == "restart":
@@ -457,7 +707,7 @@ def handle_change(path, text, language):
    return JSONResponse({"code": 0, "message": "Success"}, status_code=200)


def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, slice):
    if (
            refer_wav_path == "" or refer_wav_path is None
            or prompt_text == "" or prompt_text is None
@@ -473,7 +723,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
    with torch.no_grad():
        gen = get_tts_wav(
            refer_wav_path, prompt_text, prompt_language, text, text_language, slice
        )
        sampling_rate, audio_data = next(gen)
@@ -541,6 +791,7 @@ async def tts_endpoint(request: Request):
        json_post_raw.get("prompt_language"),
        json_post_raw.get("text"),
        json_post_raw.get("text_language"),
        json_post_raw.get("slice"),
    )
@@ -551,8 +802,9 @@ async def tts_endpoint(
        prompt_language: str = None,
        text: str = None,
        text_language: str = None,
        slice: str = None
):
    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, slice)


if __name__ == "__main__":
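Taken together with the docstring changes above, a GET client can pass the new parameter as well. A hedged sketch (the root TTS route and the reference-audio path are assumptions for illustration; parameters mirror `tts_endpoint`'s new signature):

```python
import requests

# Illustrative GET call exercising the new "slice" query parameter.
params = {
    "refer_wav_path": "ref.wav",   # hypothetical reference audio
    "prompt_text": "你好。",
    "prompt_language": "zh",
    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
    "text_language": "zh",
    "slice": "凑50字一切",
}
resp = requests.get("http://127.0.0.1:9880/", params=params)
```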