mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-08 16:00:01 +08:00
Merge e9625c3a9edddd06ce87eda21665015113455bf6 into 959269b5ae2db5d0f0aead15b91c7e1e120f6303
This commit is contained in:
commit
acfe2ddd6a
320
api.py
320
api.py
@ -35,7 +35,8 @@ POST:
|
|||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
|
"text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
|
||||||
"text_language": "zh"
|
"text_language": "zh",
|
||||||
|
"slice": "按标点符号切"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -80,6 +81,23 @@ RESP:
|
|||||||
失败: json, 400
|
失败: json, 400
|
||||||
|
|
||||||
|
|
||||||
|
### 动态更换底模
|
||||||
|
|
||||||
|
endpoint: `/set_model`
|
||||||
|
|
||||||
|
GET:
|
||||||
|
`http://127.0.0.1:9880/set_model?gpt_model_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt&sovits_model_path=GPT_SoVITS/pretrained_models/s2G488k.pth`
|
||||||
|
POST:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"gpt_model_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
||||||
|
"sovits_model_path": "GPT_SoVITS/pretrained_models/s2G488k.pth"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
RESP:
|
||||||
|
成功: json, http code 200
|
||||||
|
|
||||||
### 命令控制
|
### 命令控制
|
||||||
|
|
||||||
endpoint: `/control`
|
endpoint: `/control`
|
||||||
@ -103,7 +121,7 @@ RESP: 无
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os, re
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
@ -126,6 +144,7 @@ from module.models import SynthesizerTrn
|
|||||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||||
from text import cleaned_text_to_sequence
|
from text import cleaned_text_to_sequence
|
||||||
from text.cleaner import clean_text
|
from text.cleaner import clean_text
|
||||||
|
import LangSegment
|
||||||
from module.mel_processing import spectrogram_torch
|
from module.mel_processing import spectrogram_torch
|
||||||
from my_utils import load_audio
|
from my_utils import load_audio
|
||||||
import config as global_config
|
import config as global_config
|
||||||
@ -148,6 +167,7 @@ parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="defa
|
|||||||
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
|
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
|
||||||
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
|
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
|
||||||
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
|
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
|
||||||
|
#parser.add_argument("-sl", "--slice", type=str, default="No slice", help="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
|
||||||
# bool值的用法为 `python ./api.py -fp ...`
|
# bool值的用法为 `python ./api.py -fp ...`
|
||||||
# 此时 full_precision==True, half_precision==False
|
# 此时 full_precision==True, half_precision==False
|
||||||
|
|
||||||
@ -350,17 +370,133 @@ dict_language = {
|
|||||||
"JA": "ja",
|
"JA": "ja",
|
||||||
"zh": "zh",
|
"zh": "zh",
|
||||||
"en": "en",
|
"en": "en",
|
||||||
"ja": "ja"
|
"ja": "ja",
|
||||||
|
"auto": "auto",
|
||||||
|
"中英混合": "zh",
|
||||||
|
"日英混合": "ja",
|
||||||
|
"多语种混合": "auto"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slice_option = {
|
||||||
|
"凑四句一切": "凑四句一切",
|
||||||
|
"凑50字一切": "凑50字一切",
|
||||||
|
"按中文句号。切": "按中文句号。切",
|
||||||
|
"按英文句号.切": "按英文句号.切",
|
||||||
|
"按标点符号切": "按标点符号切",
|
||||||
|
"per 4 sentences": "凑四句一切",
|
||||||
|
"per 50 letters": "凑50字一切",
|
||||||
|
"per period": "按英文句号.切",
|
||||||
|
"per punctuation mark": "按标点符号切",
|
||||||
|
None: "No slice"
|
||||||
|
}
|
||||||
|
|
||||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
|
dtype=torch.float16 if is_half == True else torch.float32
|
||||||
|
def get_bert_inf(phones, word2ph, norm_text, language):
|
||||||
|
language=language.replace("all_","")
|
||||||
|
if language == "zh":
|
||||||
|
bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
|
||||||
|
else:
|
||||||
|
bert = torch.zeros(
|
||||||
|
(1024, len(phones)),
|
||||||
|
dtype=torch.float16 if is_half == True else torch.float32,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
return bert
|
||||||
|
|
||||||
|
def clean_text_inf(text, language):
|
||||||
|
phones, word2ph, norm_text = clean_text(text, language)
|
||||||
|
phones = cleaned_text_to_sequence(phones)
|
||||||
|
return phones, word2ph, norm_text
|
||||||
|
|
||||||
|
# Mixed language support
|
||||||
|
def get_phones_and_bert(text,language):
|
||||||
|
if language in {"en","all_zh","all_ja"}:
|
||||||
|
language = language.replace("all_","")
|
||||||
|
if language == "en":
|
||||||
|
LangSegment.setfilters(["en"])
|
||||||
|
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
|
||||||
|
else:
|
||||||
|
# 因无法区别中日文汉字,以用户输入为准
|
||||||
|
formattext = text
|
||||||
|
while " " in formattext:
|
||||||
|
formattext = formattext.replace(" ", " ")
|
||||||
|
phones, word2ph, norm_text = clean_text_inf(formattext, language)
|
||||||
|
if language == "zh":
|
||||||
|
bert = get_bert_feature(norm_text, word2ph).to(device)
|
||||||
|
else:
|
||||||
|
bert = torch.zeros(
|
||||||
|
(1024, len(phones)),
|
||||||
|
dtype=torch.float16 if is_half == True else torch.float32,
|
||||||
|
).to(device)
|
||||||
|
elif language in {"zh", "ja","auto"}:
|
||||||
|
textlist=[]
|
||||||
|
langlist=[]
|
||||||
|
LangSegment.setfilters(["zh","ja","en"])
|
||||||
|
if language == "auto":
|
||||||
|
for tmp in LangSegment.getTexts(text):
|
||||||
|
langlist.append(tmp["lang"])
|
||||||
|
textlist.append(tmp["text"])
|
||||||
|
else:
|
||||||
|
for tmp in LangSegment.getTexts(text):
|
||||||
|
if tmp["lang"] == "en":
|
||||||
|
langlist.append(tmp["lang"])
|
||||||
|
else:
|
||||||
|
# 因无法区别中日文汉字,以用户输入为准
|
||||||
|
langlist.append(language)
|
||||||
|
textlist.append(tmp["text"])
|
||||||
|
print(textlist)
|
||||||
|
print(langlist)
|
||||||
|
phones_list = []
|
||||||
|
bert_list = []
|
||||||
|
norm_text_list = []
|
||||||
|
for i in range(len(textlist)):
|
||||||
|
lang = langlist[i]
|
||||||
|
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
|
||||||
|
bert = get_bert_inf(phones, word2ph, norm_text, lang)
|
||||||
|
phones_list.append(phones)
|
||||||
|
norm_text_list.append(norm_text)
|
||||||
|
bert_list.append(bert)
|
||||||
|
bert = torch.cat(bert_list, dim=1)
|
||||||
|
phones = sum(phones_list, [])
|
||||||
|
norm_text = ''.join(norm_text_list)
|
||||||
|
|
||||||
|
return phones,bert.to(dtype),norm_text
|
||||||
|
|
||||||
|
def merge_short_text_in_array(texts, threshold):
|
||||||
|
if (len(texts)) < 2:
|
||||||
|
return texts
|
||||||
|
result = []
|
||||||
|
text = ""
|
||||||
|
for ele in texts:
|
||||||
|
text += ele
|
||||||
|
if len(text) >= threshold:
|
||||||
|
result.append(text)
|
||||||
|
text = ""
|
||||||
|
if (len(text) > 0):
|
||||||
|
if len(result) == 0:
|
||||||
|
result.append(text)
|
||||||
|
else:
|
||||||
|
result[len(result) - 1] += text
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut):
|
||||||
|
# not supporting ref_free
|
||||||
t0 = ttime()
|
t0 = ttime()
|
||||||
|
prompt_language = dict_language[prompt_language]
|
||||||
prompt_text = prompt_text.strip("\n")
|
prompt_text = prompt_text.strip("\n")
|
||||||
prompt_language, text = prompt_language, text.strip("\n")
|
if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
|
||||||
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
|
print(f"实际输入的参考文本: {prompt_text}")
|
||||||
|
text = text.strip("\n")
|
||||||
|
if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
|
||||||
|
print(f"实际输入的目标文本: {text}")
|
||||||
|
|
||||||
|
zero_wav = np.zeros(
|
||||||
|
int(hps.data.sampling_rate * 0.3),
|
||||||
|
dtype=np.float16 if is_half == True else np.float32
|
||||||
|
)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
||||||
|
# neglected error checking for reference audio duration
|
||||||
wav16k = torch.from_numpy(wav16k)
|
wav16k = torch.from_numpy(wav16k)
|
||||||
zero_wav_torch = torch.from_numpy(zero_wav)
|
zero_wav_torch = torch.from_numpy(zero_wav)
|
||||||
if (is_half == True):
|
if (is_half == True):
|
||||||
@ -370,32 +506,51 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
|||||||
wav16k = wav16k.to(device)
|
wav16k = wav16k.to(device)
|
||||||
zero_wav_torch = zero_wav_torch.to(device)
|
zero_wav_torch = zero_wav_torch.to(device)
|
||||||
wav16k = torch.cat([wav16k, zero_wav_torch])
|
wav16k = torch.cat([wav16k, zero_wav_torch])
|
||||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
|
||||||
|
"last_hidden_state"
|
||||||
|
].transpose(
|
||||||
|
1, 2
|
||||||
|
) # .float()
|
||||||
codes = vq_model.extract_latent(ssl_content)
|
codes = vq_model.extract_latent(ssl_content)
|
||||||
|
|
||||||
prompt_semantic = codes[0, 0]
|
prompt_semantic = codes[0, 0]
|
||||||
t1 = ttime()
|
t1 = ttime()
|
||||||
prompt_language = dict_language[prompt_language]
|
|
||||||
text_language = dict_language[text_language]
|
# 文本切句
|
||||||
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
|
# default to no slice if no argument is provided
|
||||||
phones1 = cleaned_text_to_sequence(phones1)
|
how_to_cut = slice_option[how_to_cut]
|
||||||
|
print(f"[INFO] 文本切句選項: {how_to_cut}")
|
||||||
|
if (how_to_cut == "凑四句一切"):
|
||||||
|
text = cut1(text)
|
||||||
|
elif (how_to_cut == "凑50字一切"):
|
||||||
|
text = cut2(text)
|
||||||
|
elif (how_to_cut == "按英文句号.切"):
|
||||||
|
text = cut3(text)
|
||||||
|
elif (how_to_cut == "按英文句号.切"):
|
||||||
|
text = cut4(text)
|
||||||
|
elif (how_to_cut == "按标点符号切"):
|
||||||
|
text = cut5(text)
|
||||||
|
|
||||||
|
while "\n\n" in text:
|
||||||
|
text = text.replace("\n\n", "\n")
|
||||||
|
print(f"实际输入的目标文本(切句后): {text}")
|
||||||
texts = text.split("\n")
|
texts = text.split("\n")
|
||||||
|
texts = merge_short_text_in_array(texts, 5)
|
||||||
audio_opt = []
|
audio_opt = []
|
||||||
|
|
||||||
for text in texts:
|
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
|
||||||
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
|
|
||||||
phones2 = cleaned_text_to_sequence(phones2)
|
for text in texts:
|
||||||
if (prompt_language == "zh"):
|
# 解决输入目标文本的空行导致报错的问题
|
||||||
bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
|
if (len(text.strip()) == 0):
|
||||||
else:
|
continue
|
||||||
bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to(
|
if (text[-1] not in splits): text += "。" if text_language != "en" else "."
|
||||||
device)
|
print(f"实际输入的目标文本(每句): {text}")
|
||||||
if (text_language == "zh"):
|
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
|
||||||
bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
|
print(f"前端处理后的文本(每句): {norm_text2}")
|
||||||
else:
|
bert = torch.cat([bert1, bert2], 1)
|
||||||
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
|
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
|
||||||
bert = torch.cat([bert1, bert2], 1)
|
|
||||||
|
|
||||||
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
|
||||||
bert = bert.to(device).unsqueeze(0)
|
bert = bert.to(device).unsqueeze(0)
|
||||||
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
||||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||||
@ -412,23 +567,118 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
|||||||
early_stop_num=hz * max_sec)
|
early_stop_num=hz * max_sec)
|
||||||
t3 = ttime()
|
t3 = ttime()
|
||||||
# print(pred_semantic.shape,idx)
|
# print(pred_semantic.shape,idx)
|
||||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
|
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
|
||||||
|
0
|
||||||
|
) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||||
refer = get_spepc(hps, ref_wav_path) # .to(device)
|
refer = get_spepc(hps, ref_wav_path) # .to(device)
|
||||||
if (is_half == True):
|
if (is_half == True):
|
||||||
refer = refer.half().to(device)
|
refer = refer.half().to(device)
|
||||||
else:
|
else:
|
||||||
refer = refer.to(device)
|
refer = refer.to(device)
|
||||||
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
|
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
|
||||||
audio = \
|
audio = (
|
||||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
|
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
|
||||||
refer).detach().cpu().numpy()[
|
)
|
||||||
0, 0] ###试试重建不带上prompt部分
|
.detach()
|
||||||
|
.cpu()
|
||||||
|
.numpy()[0, 0]
|
||||||
|
) ###试试重建不带上prompt部分
|
||||||
|
max_audio=np.abs(audio).max()#简单防止16bit爆音
|
||||||
|
if max_audio>1:audio/=max_audio
|
||||||
audio_opt.append(audio)
|
audio_opt.append(audio)
|
||||||
audio_opt.append(zero_wav)
|
audio_opt.append(zero_wav)
|
||||||
t4 = ttime()
|
t4 = ttime()
|
||||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
||||||
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
|
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
|
||||||
|
np.int16
|
||||||
|
)
|
||||||
|
|
||||||
|
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
|
||||||
|
|
||||||
|
def get_first(text):
|
||||||
|
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
|
||||||
|
text = re.split(pattern, text)[0].strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
def split(todo_text):
|
||||||
|
todo_text = todo_text.replace("……", "。").replace("——", ",")
|
||||||
|
if todo_text[-1] not in splits:
|
||||||
|
todo_text += "。"
|
||||||
|
i_split_head = i_split_tail = 0
|
||||||
|
len_text = len(todo_text)
|
||||||
|
todo_texts = []
|
||||||
|
while 1:
|
||||||
|
if i_split_head >= len_text:
|
||||||
|
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
|
||||||
|
if todo_text[i_split_head] in splits:
|
||||||
|
i_split_head += 1
|
||||||
|
todo_texts.append(todo_text[i_split_tail:i_split_head])
|
||||||
|
i_split_tail = i_split_head
|
||||||
|
else:
|
||||||
|
i_split_head += 1
|
||||||
|
return todo_texts
|
||||||
|
|
||||||
|
def cut1(inp):
|
||||||
|
inp = inp.strip("\n")
|
||||||
|
inps = split(inp)
|
||||||
|
split_idx = list(range(0, len(inps), 4))
|
||||||
|
split_idx[-1] = None
|
||||||
|
if len(split_idx) > 1:
|
||||||
|
opts = []
|
||||||
|
for idx in range(len(split_idx) - 1):
|
||||||
|
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
|
||||||
|
else:
|
||||||
|
opts = [inp]
|
||||||
|
return "\n".join(opts)
|
||||||
|
|
||||||
|
|
||||||
|
def cut2(inp):
|
||||||
|
inp = inp.strip("\n")
|
||||||
|
inps = split(inp)
|
||||||
|
if len(inps) < 2:
|
||||||
|
return inp
|
||||||
|
opts = []
|
||||||
|
summ = 0
|
||||||
|
tmp_str = ""
|
||||||
|
for i in range(len(inps)):
|
||||||
|
summ += len(inps[i])
|
||||||
|
tmp_str += inps[i]
|
||||||
|
if summ > 50:
|
||||||
|
summ = 0
|
||||||
|
opts.append(tmp_str)
|
||||||
|
tmp_str = ""
|
||||||
|
if tmp_str != "":
|
||||||
|
opts.append(tmp_str)
|
||||||
|
# print(opts)
|
||||||
|
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
|
||||||
|
opts[-2] = opts[-2] + opts[-1]
|
||||||
|
opts = opts[:-1]
|
||||||
|
return "\n".join(opts)
|
||||||
|
|
||||||
|
|
||||||
|
def cut3(inp):
|
||||||
|
inp = inp.strip("\n")
|
||||||
|
return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
|
||||||
|
|
||||||
|
|
||||||
|
def cut4(inp):
|
||||||
|
inp = inp.strip("\n")
|
||||||
|
return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
|
||||||
|
|
||||||
|
|
||||||
|
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
|
||||||
|
def cut5(inp):
|
||||||
|
# if not re.search(r'[^\w\s]', inp[-1]):
|
||||||
|
# inp += '。'
|
||||||
|
inp = inp.strip("\n")
|
||||||
|
punds = r'[,.;?!、,。?!;:…]'
|
||||||
|
items = re.split(f'({punds})', inp)
|
||||||
|
mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
|
||||||
|
# 在句子不存在符号或句尾无符号的时候保证文本完整
|
||||||
|
if len(items)%2 == 1:
|
||||||
|
mergeitems.append(items[-1])
|
||||||
|
opt = "\n".join(mergeitems)
|
||||||
|
return opt
|
||||||
|
|
||||||
def handle_control(command):
|
def handle_control(command):
|
||||||
if command == "restart":
|
if command == "restart":
|
||||||
@ -457,7 +707,7 @@ def handle_change(path, text, language):
|
|||||||
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
|
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, slice):
|
||||||
if (
|
if (
|
||||||
refer_wav_path == "" or refer_wav_path is None
|
refer_wav_path == "" or refer_wav_path is None
|
||||||
or prompt_text == "" or prompt_text is None
|
or prompt_text == "" or prompt_text is None
|
||||||
@ -473,7 +723,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
gen = get_tts_wav(
|
gen = get_tts_wav(
|
||||||
refer_wav_path, prompt_text, prompt_language, text, text_language
|
refer_wav_path, prompt_text, prompt_language, text, text_language, slice
|
||||||
)
|
)
|
||||||
sampling_rate, audio_data = next(gen)
|
sampling_rate, audio_data = next(gen)
|
||||||
|
|
||||||
@ -541,6 +791,7 @@ async def tts_endpoint(request: Request):
|
|||||||
json_post_raw.get("prompt_language"),
|
json_post_raw.get("prompt_language"),
|
||||||
json_post_raw.get("text"),
|
json_post_raw.get("text"),
|
||||||
json_post_raw.get("text_language"),
|
json_post_raw.get("text_language"),
|
||||||
|
json_post_raw.get("slice"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -551,8 +802,9 @@ async def tts_endpoint(
|
|||||||
prompt_language: str = None,
|
prompt_language: str = None,
|
||||||
text: str = None,
|
text: str = None,
|
||||||
text_language: str = None,
|
text_language: str = None,
|
||||||
|
slice: str = None
|
||||||
):
|
):
|
||||||
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language)
|
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, slice)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user