diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 57d72a4..df9f104 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -116,7 +116,8 @@ class DictToAttrRecursive(dict): raise AttributeError(f"Attribute {item} not found") - +class NO_PROMPT_ERROR(Exception): + pass # configs/tts_infer.yaml @@ -431,12 +432,12 @@ class TTS: hps = dict_s2["config"] hps["model"]["semantic_frame_rate"] = "25hz" - # if 'enc_p.text_embedding.weight'not in dict_s2['weight']: - # hps["model"]["version"] = "v2"#v3model,v2sybomls - # elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: - # hps["model"]["version"] = "v1" - # else: - # hps["model"]["version"] = "v2" + if 'enc_p.text_embedding.weight'not in dict_s2['weight']: + hps["model"]["version"] = "v2"#v3model,v2sybomls + elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: + hps["model"]["version"] = "v1" + else: + hps["model"]["version"] = "v2" # version = hps["model"]["version"] self.configs.filter_length = hps["data"]["filter_length"] @@ -451,7 +452,8 @@ class TTS: self.configs.update_version(model_version) - + # print(f"model_version:{model_version}") + # print(f'hps["model"]["version"]:{hps["model"]["version"]}') if model_version!="v3": vits_model = SynthesizerTrn( self.configs.filter_length // 2 + 1, @@ -926,7 +928,7 @@ class TTS: assert prompt_lang in self.configs.languages if no_prompt_text and self.configs.is_v3_synthesizer: - raise RuntimeError("prompt_text cannot be empty when using SoVITS_V3") + raise NO_PROMPT_ERROR("prompt_text cannot be empty when using SoVITS_V3") if ref_audio_path in [None, ""] and \ ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []])): @@ -957,13 +959,13 @@ class TTS: if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "." print(i18n("实际输入的参考文本:"), prompt_text) if self.prompt_cache["prompt_text"] != prompt_text: - self.prompt_cache["prompt_text"] = prompt_text - self.prompt_cache["prompt_lang"] = prompt_lang phones, bert_features, norm_text = \ self.text_preprocessor.segment_and_extract_feature_for_text( prompt_text, prompt_lang, self.configs.version) + self.prompt_cache["prompt_text"] = prompt_text + self.prompt_cache["prompt_lang"] = prompt_lang self.prompt_cache["phones"] = phones self.prompt_cache["bert_features"] = bert_features self.prompt_cache["norm_text"] = norm_text diff --git a/GPT_SoVITS/configs/tts_infer.yaml b/GPT_SoVITS/configs/tts_infer.yaml index 578a794..34878ae 100644 --- a/GPT_SoVITS/configs/tts_infer.yaml +++ b/GPT_SoVITS/configs/tts_infer.yaml @@ -3,9 +3,9 @@ custom: cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base device: cuda is_half: true - t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt - version: v3 - vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth + t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt + version: v2 + vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth v1: bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base diff --git a/GPT_SoVITS/inference_webui_fast.py b/GPT_SoVITS/inference_webui_fast.py index 703df7e..0578bab 100644 --- a/GPT_SoVITS/inference_webui_fast.py +++ b/GPT_SoVITS/inference_webui_fast.py @@ -44,7 +44,7 @@ bert_path = os.environ.get("bert_path", None) version=os.environ.get("version","v2") import gradio as gr -from TTS_infer_pack.TTS import TTS, TTS_Config +from TTS_infer_pack.TTS import TTS, TTS_Config, NO_PROMPT_ERROR from TTS_infer_pack.text_segmentation_method import get_method from tools.i18n.i18n import I18nAuto, scan_language_list @@ -153,8 +153,11 @@ def inference(text, text_lang, "sample_steps": int(sample_steps), "super_sampling": super_sampling, } - for item in tts_pipeline.run(inputs): - yield item, actual_seed + try: + for item in tts_pipeline.run(inputs): + yield item, actual_seed + except NO_PROMPT_ERROR: + gr.Warning(i18n('V3不支持无参考文本模式,请填写参考文本!')) def custom_sort_key(s): # 使用正则表达式提取字符串中的数字部分和非数字部分