Merge 4afa9bf4733c26fea7ee8389962c6d9faab15500 into 7fc216107baadebc6fd5ae4d9c6ad9e0f8bb5f7c

刘悦 committed 2024-02-16 11:41:01 +08:00 (committed by GitHub)
commit 1ccbfc5f55
2 changed files with 75 additions and 145 deletions

GPT_SoVITS/inference_webui.py

@@ -365,7 +365,7 @@ def merge_short_text_in_array(texts, threshold):
             result[len(result) - 1] += text
     return result
 
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=5, top_p=1, temperature=1):
     t0 = ttime()
     prompt_language = dict_language[prompt_language]
     text_language = dict_language[text_language]
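The only substantive change in this hunk is the sampling defaults: top_k drops from 20 to 5, while top_p and temperature rise from 0.6 to 1. Callers who prefer the earlier defaults can still request them explicitly; a minimal sketch, with the values copied from the removed line and the surrounding variables assumed to be in scope:

```python
# Hypothetical call restoring the pre-change defaults; the argument values
# come from the '-' line above, and ref_wav_path etc. are placeholders.
synthesis = get_tts_wav(
    ref_wav_path, prompt_text, prompt_language, text, text_language,
    how_to_cut=i18n("不切"),  # the "no cut" default
    top_k=20, top_p=0.6, temperature=0.6,
)
```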
@@ -591,7 +591,9 @@ def get_weights_names():
 
 SoVITS_names, GPT_names = get_weights_names()
 
-with gr.Blocks(title="GPT-SoVITS WebUI") as app:
+def main():
+    with gr.Blocks(title="GPT-SoVITS WebUI") as app:
         gr.Markdown(
             value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>.")
         )
@@ -652,10 +654,13 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
         button5.click(cut5, [text_inp], [text_opt])
         gr.Markdown(value=i18n("后续将支持混合语种编码文本输入。"))
 
-app.queue(concurrency_count=511, max_size=1022).launch(
+    app.queue(concurrency_count=511, max_size=1022).launch(
         server_name="0.0.0.0",
         inbrowser=True,
         share=is_share,
         server_port=infer_ttswebui,
         quiet=True,
-)
+    )
+
+if __name__ == '__main__':
+    main()
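The point of the main() wrapper and the __name__ guard is that importing inference_webui no longer launches the Gradio app as a side effect, which is what lets api.py (below) reuse get_tts_wav directly. A minimal sketch of the now-safe import; note that model loading at import time still happens, only the UI launch is deferred:

```python
# With this commit applied, the import below loads the TTS models but no
# longer calls app.queue(...).launch(...), so it can run inside api.py.
from inference_webui import get_tts_wav

# Launching the WebUI is now an explicit, separate step:
#   python inference_webui.py   # runs main() via the __name__ guard
```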

api.py (81 changes)

@@ -30,7 +30,7 @@ endpoint: `/`
 Use the reference audio specified by the startup parameters:
 GET:
-`http://127.0.0.1:9880?text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=zh`
+`http://127.0.0.1:9880?text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=中文`
 POST:
 ```json
 {
@@ -41,7 +41,7 @@ POST:
 Manually specify the reference audio for this request:
 GET:
-`http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三&prompt_language=zh&text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=zh`
+`http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三&prompt_language=中文&text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=中文`
 POST:
 ```json
 {
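Both pairs of doc examples make the same edit: the language fields now take the display names used by inference_webui's dict_language ("中文") instead of the short codes ("zh"). A minimal client sketch under that convention, assuming the server runs on the default 127.0.0.1:9880 and returns a WAV body on success; the output filename is illustrative:

```python
import requests

payload = {
    "refer_wav_path": "123.wav",      # reference audio for this request
    "prompt_text": "一二三",          # transcript of the reference audio
    "prompt_language": "中文",        # display name, no longer "zh"
    "text": "先帝创业未半而中道崩殂",
    "text_language": "中文",
}
resp = requests.post("http://127.0.0.1:9880", json=payload)
resp.raise_for_status()
with open("output.wav", "wb") as f:   # assumed: endpoint streams WAV audio
    f.write(resp.content)
```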
@@ -129,6 +129,7 @@ from text.cleaner import clean_text
 from module.mel_processing import spectrogram_torch
 from my_utils import load_audio
 import config as global_config
+from inference_webui import get_tts_wav
 
 g_config = global_config.Config()
@@ -316,82 +317,6 @@ dict_language = {
 }
 
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
-    t0 = ttime()
-    prompt_text = prompt_text.strip("\n")
-    prompt_language, text = prompt_language, text.strip("\n")
-    zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
-    with torch.no_grad():
-        wav16k, sr = librosa.load(ref_wav_path, sr=16000)
-        wav16k = torch.from_numpy(wav16k)
-        zero_wav_torch = torch.from_numpy(zero_wav)
-        if (is_half == True):
-            wav16k = wav16k.half().to(device)
-            zero_wav_torch = zero_wav_torch.half().to(device)
-        else:
-            wav16k = wav16k.to(device)
-            zero_wav_torch = zero_wav_torch.to(device)
-        wav16k = torch.cat([wav16k, zero_wav_torch])
-        ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
-        codes = vq_model.extract_latent(ssl_content)
-        prompt_semantic = codes[0, 0]
-    t1 = ttime()
-    prompt_language = dict_language[prompt_language]
-    text_language = dict_language[text_language]
-    phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
-    phones1 = cleaned_text_to_sequence(phones1)
-    texts = text.split("\n")
-    audio_opt = []
-    for text in texts:
-        phones2, word2ph2, norm_text2 = clean_text(text, text_language)
-        phones2 = cleaned_text_to_sequence(phones2)
-        if (prompt_language == "zh"):
-            bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
-        else:
-            bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to(device)
-        if (text_language == "zh"):
-            bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
-        else:
-            bert2 = torch.zeros((1024, len(phones2))).to(bert1)
-        bert = torch.cat([bert1, bert2], 1)
-        all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
-        bert = bert.to(device).unsqueeze(0)
-        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
-        prompt = prompt_semantic.unsqueeze(0).to(device)
-        t2 = ttime()
-        with torch.no_grad():
-            # pred_semantic = t2s_model.model.infer(
-            pred_semantic, idx = t2s_model.model.infer_panel(
-                all_phoneme_ids,
-                all_phoneme_len,
-                prompt,
-                bert,
-                # prompt_phone_len=ph_offset,
-                top_k=config['inference']['top_k'],
-                early_stop_num=hz * max_sec)
-        t3 = ttime()
-        # print(pred_semantic.shape,idx)
-        pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)  # .unsqueeze(0)  # mq: needs one extra unsqueeze
-        refer = get_spepc(hps, ref_wav_path)  # .to(device)
-        if (is_half == True):
-            refer = refer.half().to(device)
-        else:
-            refer = refer.to(device)
-        # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
-        audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
-                                refer).detach().cpu().numpy()[0, 0]  # try reconstructing without the prompt part
-        audio_opt.append(audio)
-        audio_opt.append(zero_wav)
-        t4 = ttime()
-    print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
-    yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
-
 def handle_control(command):
     if command == "restart":
         os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
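Taken together, the deleted local get_tts_wav and the new import mean api.py now delegates to the single shared implementation in inference_webui, which also exposes how_to_cut, top_k, top_p, and temperature. A sketch of how a handler might consume the shared generator, with hypothetical names; the real handler's request parsing and response plumbing are not part of this commit:

```python
import soundfile as sf  # assumed WAV writer; any encoder works

from inference_webui import get_tts_wav

def synthesize_to_wav(refer_wav_path, prompt_text, prompt_language,
                      text, text_language, out_path="out.wav"):
    # get_tts_wav is a generator yielding (sampling_rate, int16 ndarray);
    # the WebUI defaults for how_to_cut/top_k/top_p/temperature apply here.
    sampling_rate, audio = next(get_tts_wav(
        refer_wav_path, prompt_text, prompt_language, text, text_language))
    sf.write(out_path, audio, sampling_rate)
    return out_path
```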