Merge remote-tracking branch 'beta/fast_inference_'

This commit is contained in:
XTer 2024-03-12 22:36:05 +08:00
commit a057c697e7
2 changed files with 24 additions and 11 deletions

View File

@ -59,6 +59,8 @@ class TextPreprocessor:
print(i18n("############ 提取文本Bert特征 ############"))
for text in tqdm(texts):
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
if phones is None:
continue
res={
"phones": phones,
"bert_features": bert_features,
@ -79,12 +81,10 @@ class TextPreprocessor:
while "\n\n" in text:
text = text.replace("\n\n", "\n")
print(i18n("实际输入的目标文本(切句后):"))
print(text)
_texts = text.split("\n")
_texts = merge_short_text_in_array(_texts, 5)
texts = []
for text in _texts:
@ -94,15 +94,21 @@ class TextPreprocessor:
if (text[-1] not in splits): text += "" if lang != "en" else "."
# 解决句子过长导致Bert报错的问题
texts.extend(split_big_text(text))
if (len(text) > 510):
texts.extend(split_big_text(text))
else:
texts.append(text)
print(i18n("实际输入的目标文本(切句后):"))
print(texts)
return texts
def segment_and_extract_feature_for_text(self, texts:list, language:str)->Tuple[list, torch.Tensor, str]:
textlist, langlist = self.seg_text(texts, language)
phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist)
if len(textlist) == 0:
return None, None, None
phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist)
return phones, bert_features, norm_text
@ -113,6 +119,8 @@ class TextPreprocessor:
if language in ["auto", "zh", "ja"]:
LangSegment.setfilters(["zh","ja","en","ko"])
for tmp in LangSegment.getTexts(text):
if tmp["text"] == "":
continue
if tmp["lang"] == "ko":
langlist.append("zh")
elif tmp["lang"] == "en":
@ -126,14 +134,18 @@ class TextPreprocessor:
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
while " " in formattext:
formattext = formattext.replace(" ", " ")
textlist.append(formattext)
langlist.append("en")
if formattext != "":
textlist.append(formattext)
langlist.append("en")
elif language in ["all_zh","all_ja"]:
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
language = language.replace("all_","")
if text == "":
return [],[]
textlist.append(formattext)
langlist.append(language)

View File

@ -45,8 +45,8 @@ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
# elif torch.backends.mps.is_available():
# device = "mps"
else:
device = "cpu"
@ -80,6 +80,7 @@ if cnhubert_base_path is not None:
if bert_path is not None:
tts_config.bert_base_path = bert_path
print(tts_config)
tts_pipline = TTS(tts_config)
gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path
@ -186,7 +187,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
with gr.Row():
with gr.Column():
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=1,interactive=True)
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)