修改了个远古bug,增加了更友好的提示信息

This commit is contained in:
ChasonJiang 2025-03-14 13:50:42 +08:00
parent f25d830ef4
commit 1f88d5b854
3 changed files with 22 additions and 17 deletions

View File

@ -116,7 +116,8 @@ class DictToAttrRecursive(dict):
raise AttributeError(f"Attribute {item} not found")
class NO_PROMPT_ERROR(Exception):
pass
# configs/tts_infer.yaml
@ -431,12 +432,12 @@ class TTS:
hps = dict_s2["config"]
hps["model"]["semantic_frame_rate"] = "25hz"
# if 'enc_p.text_embedding.weight'not in dict_s2['weight']:
# hps["model"]["version"] = "v2"#v3model,v2sybomls
# elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
# hps["model"]["version"] = "v1"
# else:
# hps["model"]["version"] = "v2"
if 'enc_p.text_embedding.weight'not in dict_s2['weight']:
hps["model"]["version"] = "v2"#v3model,v2sybomls
elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
hps["model"]["version"] = "v1"
else:
hps["model"]["version"] = "v2"
# version = hps["model"]["version"]
self.configs.filter_length = hps["data"]["filter_length"]
@ -451,7 +452,8 @@ class TTS:
self.configs.update_version(model_version)
# print(f"model_version:{model_version}")
# print(f'hps["model"]["version"]:{hps["model"]["version"]}')
if model_version!="v3":
vits_model = SynthesizerTrn(
self.configs.filter_length // 2 + 1,
@ -926,7 +928,7 @@ class TTS:
assert prompt_lang in self.configs.languages
if no_prompt_text and self.configs.is_v3_synthesizer:
raise RuntimeError("prompt_text cannot be empty when using SoVITS_V3")
raise NO_PROMPT_ERROR("prompt_text cannot be empty when using SoVITS_V3")
if ref_audio_path in [None, ""] and \
((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []])):
@ -957,13 +959,13 @@ class TTS:
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_lang != "en" else "."
print(i18n("实际输入的参考文本:"), prompt_text)
if self.prompt_cache["prompt_text"] != prompt_text:
self.prompt_cache["prompt_text"] = prompt_text
self.prompt_cache["prompt_lang"] = prompt_lang
phones, bert_features, norm_text = \
self.text_preprocessor.segment_and_extract_feature_for_text(
prompt_text,
prompt_lang,
self.configs.version)
self.prompt_cache["prompt_text"] = prompt_text
self.prompt_cache["prompt_lang"] = prompt_lang
self.prompt_cache["phones"] = phones
self.prompt_cache["bert_features"] = bert_features
self.prompt_cache["norm_text"] = norm_text

View File

@ -3,9 +3,9 @@ custom:
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cuda
is_half: true
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
version: v3
vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
version: v2
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
v1:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base

View File

@ -44,7 +44,7 @@ bert_path = os.environ.get("bert_path", None)
version=os.environ.get("version","v2")
import gradio as gr
from TTS_infer_pack.TTS import TTS, TTS_Config
from TTS_infer_pack.TTS import TTS, TTS_Config, NO_PROMPT_ERROR
from TTS_infer_pack.text_segmentation_method import get_method
from tools.i18n.i18n import I18nAuto, scan_language_list
@ -153,8 +153,11 @@ def inference(text, text_lang,
"sample_steps": int(sample_steps),
"super_sampling": super_sampling,
}
for item in tts_pipeline.run(inputs):
yield item, actual_seed
try:
for item in tts_pipeline.run(inputs):
yield item, actual_seed
except NO_PROMPT_ERROR:
gr.Warning(i18n('V3不支持无参考文本模式请填写参考文本'))
def custom_sort_key(s):
# 使用正则表达式提取字符串中的数字部分和非数字部分