From 7de7f6816137c8f0cae35bdafa8f0fe136be867b Mon Sep 17 00:00:00 2001
From: Leon <88978827@qq.com>
Date: Thu, 15 Aug 2024 16:17:09 +0800
Subject: [PATCH] update dev branch

---
 GPT_SoVITS/download.py |  2 +-
 quick_inference.py     | 60 +++++++++++++++++++++++++++++++++++-------
 2 files changed, 51 insertions(+), 11 deletions(-)

diff --git a/GPT_SoVITS/download.py b/GPT_SoVITS/download.py
index b6393e2c..e7c8c97b 100644
--- a/GPT_SoVITS/download.py
+++ b/GPT_SoVITS/download.py
@@ -1,5 +1,5 @@
 import os, sys
 now_dir = os.getcwd()
 sys.path.insert(0, now_dir)
-from .text.g2pw import G2PWPinyin
+from text.g2pw import G2PWPinyin
 g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
\ No newline at end of file
diff --git a/quick_inference.py b/quick_inference.py
index ded1f3eb..2fb49049 100644
--- a/quick_inference.py
+++ b/quick_inference.py
@@ -1,6 +1,7 @@
 import os, re, logging
 import LangSegment
 import pdb
+import json
 import torch
 import gradio as gr
 from transformers import AutoModelForMaskedLM, AutoTokenizer
@@ -21,7 +22,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 i18n = I18nAuto()
 
-dict_language = {
+version = os.environ.get("version", "v2")
+dict_language_v1 = {
     i18n("中文"): "all_zh",  # 全部按中文识别
     i18n("英文"): "en",  # 全部按英文识别#######不变
     i18n("日文"): "all_ja",  # 全部按日文识别
@@ -29,6 +31,20 @@ dict_language = {
     i18n("日英混合"): "ja",  # 按日英混合识别####不变
     i18n("多语种混合"): "auto",  # 多语种启动切分识别语种
 }
+dict_language_v2 = {
+    i18n("中文"): "all_zh",  # 全部按中文识别
+    i18n("英文"): "en",  # 全部按英文识别#######不变
+    i18n("日文"): "all_ja",  # 全部按日文识别
+    i18n("粤语"): "all_yue",  # 全部按中文识别
+    i18n("韩文"): "all_ko",  # 全部按韩文识别
+    i18n("中英混合"): "zh",  # 按中英混合识别####不变
+    i18n("日英混合"): "ja",  # 按日英混合识别####不变
+    i18n("粤英混合"): "yue",  # 按粤英混合识别####不变
+    i18n("韩英混合"): "ko",  # 按韩英混合识别####不变
+    i18n("多语种混合"): "auto",  # 多语种启动切分识别语种
+    i18n("多语种混合(粤语)"): "auto_yue",  # 多语种启动切分识别语种
+}
+dict_language = dict_language_v1 if version == 'v1' else dict_language_v2
 
 is_share = os.environ.get("is_share", "False")
 is_share = eval(is_share)
@@ -283,12 +299,18 @@ def set_gpt_weights(gpt_path):
     with open("./gweight.txt", "w", encoding="utf-8") as f:
         f.write(gpt_path)
 
-def set_sovits_weights(sovits_path):
-    global vq_model, hps
+def set_sovits_weights(sovits_path, prompt_language=None, text_language=None):
+    global vq_model, hps, version, dict_language
     dict_s2 = torch.load(sovits_path, map_location="cpu")
     hps = dict_s2["config"]
     hps = DictToAttrRecursive(hps)
     hps.model.semantic_frame_rate = "25hz"
+    if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
+        hps.model.version = "v1"
+    else:
+        hps.model.version = "v2"
+    version = hps.model.version
+    print("sovits版本:", hps.model.version)
     vq_model = SynthesizerTrn(
         hps.data.filter_length // 2 + 1,
         hps.train.segment_size // hps.data.hop_length,
@@ -303,8 +325,26 @@ def set_sovits_weights(sovits_path):
     vq_model = vq_model.to(device)
     vq_model.eval()
     print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
-    with open("./sweight.txt", "w", encoding="utf-8") as f:
-        f.write(sovits_path)
+    dict_language = dict_language_v1 if version == 'v1' else dict_language_v2
+    with open("./weight.json") as f:
+        data = f.read()
+        data = json.loads(data)
+        data["SoVITS"][version] = sovits_path
+    with open("./weight.json", "w") as f:
+        f.write(json.dumps(data))
+    if prompt_language is not None and text_language is not None:
+        if prompt_language in list(dict_language.keys()):
+            prompt_text_update, prompt_language_update = {'__type__': 'update'}, {'__type__': 'update', 'value': prompt_language}
+        else:
+            prompt_text_update = {'__type__': 'update', 'value': ''}
+            prompt_language_update = {'__type__': 'update', 'value': i18n("中文")}
+        if text_language in list(dict_language.keys()):
+            text_update, text_language_update = {'__type__': 'update'}, {'__type__': 'update', 'value': text_language}
+        else:
+            text_update = {'__type__': 'update', 'value': ''}
+            text_language_update = {'__type__': 'update', 'value': i18n("中文")}
+        return {'__type__': 'update', 'choices': list(dict_language.keys())}, {'__type__': 'update', 'choices': list(
+            dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update
 
 
 def gen_audio(ref_wav_path, prompt_text, text_to_speak, output_file, top_k=20, top_p=0.6, temperature=0.6, ref_free=False):
@@ -454,12 +494,12 @@ else:
 
 
 def speak(text_to_speak):
-    sovits_path = "SoVITS_weights/阿贝多_e12_s2748.pth"
+    sovits_path = "SoVITS_weights/迪希雅_e15_s1050.pth"
     set_sovits_weights(sovits_path)
-    gpt_path = "GPT_weights/阿贝多-e10.ckpt"
+    gpt_path = "GPT_weights/迪希雅-e10.ckpt"
     set_gpt_weights(gpt_path)
-    ref_wav_path = "audio/首先,先看看这不明来源的元素力,究竟是如何对外流动的.wav"
-    prompt_text = "首先,先看看这不明来源的元素力,究竟是如何对外流动的。"
+    ref_wav_path = "audio/呼玛伊家也还会招工,报酬优厚,我和兄弟们自然没有拒绝的理由.wav"
+    prompt_text = "呼玛伊家也还会招工,报酬优厚,我和兄弟们自然没有拒绝的理由。"
     # text_to_speak = "我...我...我不知道你在说什么,我们之间没有秘密呀。可能你弄错了,我们平时关系很好的,请不要误会。"
     # 创建一个时间戳的文件名
     output_file = "outputs/" + str(int(ttime())) + ".wav"
@@ -468,7 +508,7 @@ def speak(text_to_speak):
 
 
 def main():
-    speak("放学了,我该回家了,你叫我留下来干什么?")
+    speak("你就是没用!")
 
 
 if __name__ == '__main__':
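
For reference: with both prompt_language and text_language supplied, the patched set_sovits_weights returns six Gradio-style update dicts (two language-dropdown choice updates, then prompt text, prompt language, target text and target language updates), and it expects ./weight.json to already exist with a "SoVITS" mapping, since it writes data["SoVITS"][version] back to that file. A minimal Python usage sketch follows; it is an illustration only, not part of the patch, the unpacked variable names are hypothetical, and the checkpoint path is copied from the hunk above.

    # Illustration only -- not part of the patch. Assumes quick_inference.py has been
    # loaded so set_sovits_weights and i18n are in scope, the checkpoint file exists,
    # and ./weight.json is present with a "SoVITS" mapping.
    updates = set_sovits_weights(
        "SoVITS_weights/迪希雅_e15_s1050.pth",  # SoVITS checkpoint path from the patch
        prompt_language=i18n("中文"),
        text_language=i18n("中文"),
    )
    # Unpack in the order the function returns them (names here are illustrative).
    (prompt_lang_dd, text_lang_dd,
     prompt_text_upd, prompt_lang_upd,
     text_upd, text_lang_upd) = updates
    print(prompt_lang_dd['choices'])  # v1 or v2 language list, per the detected model version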