update dev branch

Leon 2024-08-15 16:17:09 +08:00
parent 22726cacee
commit 7de7f68161
2 changed files with 51 additions and 11 deletions

View File

@@ -1,5 +1,5 @@
 import os, sys
 now_dir = os.getcwd()
 sys.path.insert(0, now_dir)
-from .text.g2pw import G2PWPinyin
+from text.g2pw import G2PWPinyin
 g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
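A note on the import fix above: when this file is run directly it is a top-level script with no parent package, so the old relative form ("from .text.g2pw import ...") fails with "attempted relative import with no known parent package"; the absolute form resolves through sys.path instead. A minimal sketch of that setup (it assumes the text/ package sits next to this script or under the working directory):

    import os
    import sys

    # The script's own directory is importable automatically when it is run directly;
    # the insert below also makes the current working directory importable.
    now_dir = os.getcwd()
    sys.path.insert(0, now_dir)

    from text.g2pw import G2PWPinyin  # absolute import; was "from .text.g2pw import G2PWPinyin"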

View File

@@ -1,6 +1,7 @@
 import os, re, logging
 import LangSegment
 import pdb
+import json
 import torch
 import gradio as gr
 from transformers import AutoModelForMaskedLM, AutoTokenizer
@@ -21,7 +22,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 i18n = I18nAuto()
-dict_language = {
+version = os.environ.get("version", "v2")
+dict_language_v1 = {
     i18n("中文"): "all_zh",  # treat all input as Chinese
     i18n("英文"): "en",  # treat all input as English  ####### unchanged
     i18n("日文"): "all_ja",  # treat all input as Japanese
@@ -29,6 +31,20 @@ dict_language = {
     i18n("日英混合"): "ja",  # mixed Japanese and English  #### unchanged
     i18n("多语种混合"): "auto",  # multilingual: split the text and detect the language of each segment
 }
+dict_language_v2 = {
+    i18n("中文"): "all_zh",  # treat all input as Chinese
+    i18n("英文"): "en",  # treat all input as English  ####### unchanged
+    i18n("日文"): "all_ja",  # treat all input as Japanese
+    i18n("粤语"): "all_yue",  # treat all input as Cantonese
+    i18n("韩文"): "all_ko",  # treat all input as Korean
+    i18n("中英混合"): "zh",  # mixed Chinese and English  #### unchanged
+    i18n("日英混合"): "ja",  # mixed Japanese and English  #### unchanged
+    i18n("粤英混合"): "yue",  # mixed Cantonese and English  #### unchanged
+    i18n("韩英混合"): "ko",  # mixed Korean and English  #### unchanged
+    i18n("多语种混合"): "auto",  # multilingual: split the text and detect the language of each segment
+    i18n("多语种混合(粤语)"): "auto_yue",  # multilingual with Cantonese: split and detect per segment
+}
+dict_language = dict_language_v1 if version == 'v1' else dict_language_v2
 is_share = os.environ.get("is_share", "False")
 is_share = eval(is_share)
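The v2 table extends the v1 set with Cantonese and Korean modes; which table is active starts from the "version" environment variable and is later re-derived from the loaded SoVITS checkpoint in set_sovits_weights below. A small usage sketch (the chosen label is illustrative):

    # Resolve the label picked in the language dropdown to the internal code
    # the synthesis pipeline expects.
    label = i18n("中英混合")          # e.g. the user selected mixed Chinese/English
    lang_code = dict_language[label]  # -> "zh"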
@@ -283,12 +299,18 @@ def set_gpt_weights(gpt_path):
     with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)

-def set_sovits_weights(sovits_path):
-    global vq_model, hps
+def set_sovits_weights(sovits_path, prompt_language=None, text_language=None):
+    global vq_model, hps, version, dict_language
     dict_s2 = torch.load(sovits_path, map_location="cpu")
     hps = dict_s2["config"]
     hps = DictToAttrRecursive(hps)
     hps.model.semantic_frame_rate = "25hz"
+    if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
+        hps.model.version = "v1"
+    else:
+        hps.model.version = "v2"
+    version = hps.model.version
+    print("sovits版本:", hps.model.version)
     vq_model = SynthesizerTrn(
         hps.data.filter_length // 2 + 1,
         hps.train.segment_size // hps.data.hop_length,
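The added branch infers the model version from the checkpoint itself: the first dimension of enc_p.text_embedding.weight is the size of the text-symbol vocabulary, and 322 rows corresponds to the v1 symbol set. A standalone sketch of the same check (helper name invented here, not part of the commit):

    import torch

    def detect_sovits_version(sovits_path: str) -> str:
        # Only the text-embedding shape is inspected; 322 symbol rows marks a v1 model.
        dict_s2 = torch.load(sovits_path, map_location="cpu")
        n_symbols = dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0]
        return "v1" if n_symbols == 322 else "v2"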
@@ -303,8 +325,26 @@ def set_sovits_weights(sovits_path):
     vq_model = vq_model.to(device)
     vq_model.eval()
     print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
-    with open("./sweight.txt", "w", encoding="utf-8") as f:
-        f.write(sovits_path)
+    dict_language = dict_language_v1 if version == 'v1' else dict_language_v2
+    with open("./weight.json") as f:
+        data = f.read()
+        data = json.loads(data)
+        data["SoVITS"][version] = sovits_path
+    with open("./weight.json", "w") as f:
+        f.write(json.dumps(data))
+    if prompt_language is not None and text_language is not None:
+        if prompt_language in list(dict_language.keys()):
+            prompt_text_update, prompt_language_update = {'__type__': 'update'}, {'__type__': 'update', 'value': prompt_language}
+        else:
+            prompt_text_update = {'__type__': 'update', 'value': ''}
+            prompt_language_update = {'__type__': 'update', 'value': i18n("中文")}
+        if text_language in list(dict_language.keys()):
+            text_update, text_language_update = {'__type__': 'update'}, {'__type__': 'update', 'value': text_language}
+        else:
+            text_update = {'__type__': 'update', 'value': ''}
+            text_language_update = {'__type__': 'update', 'value': i18n("中文")}
+        return {'__type__': 'update', 'choices': list(dict_language.keys())}, {'__type__': 'update', 'choices': list(
+            dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update


 def gen_audio(ref_wav_path, prompt_text, text_to_speak, output_file, top_k=20, top_p=0.6, temperature=0.6, ref_free=False):
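Beyond loading the checkpoint, the new tail of set_sovits_weights does two things: it persists the selected path per model version into weight.json (which supersedes the old sweight.txt and is assumed to already contain a "SoVITS" mapping), and, when prompt_language/text_language are supplied, it returns hand-rolled Gradio component updates ({'__type__': 'update', ...}) that reload both language dropdowns with the version-appropriate choices and fall back to i18n("中文") when the previous selection is no longer available. A minimal sketch of the bookkeeping half, with the helper name invented here:

    import json

    def remember_sovits_path(version: str, sovits_path: str, path: str = "./weight.json") -> None:
        # weight.json is assumed to already exist with a "SoVITS" section keyed by version.
        with open(path) as f:
            data = json.loads(f.read())
        data["SoVITS"][version] = sovits_path
        with open(path, "w") as f:
            f.write(json.dumps(data))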
@@ -454,12 +494,12 @@ else:
 def speak(text_to_speak):
-    sovits_path = "SoVITS_weights/阿贝多_e12_s2748.pth"
+    sovits_path = "SoVITS_weights/迪希雅_e15_s1050.pth"
     set_sovits_weights(sovits_path)
-    gpt_path = "GPT_weights/阿贝多-e10.ckpt"
+    gpt_path = "GPT_weights/迪希雅-e10.ckpt"
     set_gpt_weights(gpt_path)
-    ref_wav_path = "audio/首先,先看看这不明来源的元素力,究竟是如何对外流动的.wav"
-    prompt_text = "首先,先看看这不明来源的元素力,究竟是如何对外流动的"
+    ref_wav_path = "audio/呼玛伊家也还会招工,报酬优厚,我和兄弟们自然没有拒绝的理由.wav"
+    prompt_text = "呼玛伊家也还会招工,报酬优厚,我和兄弟们自然没有拒绝的理由"
     # text_to_speak = "我...我...我不知道你在说什么,我们之间没有秘密呀。可能你弄错了,我们平时关系很好的,请不要误会。"
     # Build a timestamped output filename
     output_file = "outputs/" + str(int(ttime())) + ".wav"
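The rest of speak() is unchanged and not shown in this hunk; judging from the gen_audio signature above, it presumably finishes by handing the prepared reference clip, prompt text, and output path to gen_audio, roughly like this (a guess, not visible in the diff):

    # Hypothetical tail of speak(), inferred from gen_audio's signature.
    gen_audio(ref_wav_path, prompt_text, text_to_speak, output_file,
              top_k=20, top_p=0.6, temperature=0.6, ref_free=False)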
@@ -468,7 +508,7 @@ def speak(text_to_speak):
 def main():
-    speak("放学了,我该回家了,你叫我留下来干什么?")
+    speak("你就是没用!")

 if __name__ == '__main__':