mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-01-10 12:26:57 +08:00
update dev branch
This commit is contained in:
parent
22726cacee
commit
7de7f68161
@ -1,5 +1,5 @@
|
|||||||
import os, sys
|
import os, sys
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
sys.path.insert(0, now_dir)
|
sys.path.insert(0, now_dir)
|
||||||
from .text.g2pw import G2PWPinyin
|
from text.g2pw import G2PWPinyin
|
||||||
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
|
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
|
||||||
@ -1,6 +1,7 @@
|
|||||||
import os, re, logging
|
import os, re, logging
|
||||||
import LangSegment
|
import LangSegment
|
||||||
import pdb
|
import pdb
|
||||||
|
import json
|
||||||
import torch
|
import torch
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||||
@ -21,7 +22,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|||||||
|
|
||||||
i18n = I18nAuto()
|
i18n = I18nAuto()
|
||||||
|
|
||||||
dict_language = {
|
version = os.environ.get("version", "v2")
|
||||||
|
dict_language_v1 = {
|
||||||
i18n("中文"): "all_zh", # 全部按中文识别
|
i18n("中文"): "all_zh", # 全部按中文识别
|
||||||
i18n("英文"): "en", # 全部按英文识别#######不变
|
i18n("英文"): "en", # 全部按英文识别#######不变
|
||||||
i18n("日文"): "all_ja", # 全部按日文识别
|
i18n("日文"): "all_ja", # 全部按日文识别
|
||||||
@ -29,6 +31,20 @@ dict_language = {
|
|||||||
i18n("日英混合"): "ja", # 按日英混合识别####不变
|
i18n("日英混合"): "ja", # 按日英混合识别####不变
|
||||||
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
|
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
|
||||||
}
|
}
|
||||||
|
dict_language_v2 = {
|
||||||
|
i18n("中文"): "all_zh", # 全部按中文识别
|
||||||
|
i18n("英文"): "en", # 全部按英文识别#######不变
|
||||||
|
i18n("日文"): "all_ja", # 全部按日文识别
|
||||||
|
i18n("粤语"): "all_yue", # 全部按粤语识别
|
||||||
|
i18n("韩文"): "all_ko", # 全部按韩文识别
|
||||||
|
i18n("中英混合"): "zh", # 按中英混合识别####不变
|
||||||
|
i18n("日英混合"): "ja", # 按日英混合识别####不变
|
||||||
|
i18n("粤英混合"): "yue", # 按粤英混合识别####不变
|
||||||
|
i18n("韩英混合"): "ko", # 按韩英混合识别####不变
|
||||||
|
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
|
||||||
|
i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种
|
||||||
|
}
|
||||||
|
dict_language = dict_language_v1 if version == 'v1' else dict_language_v2
|
||||||
|
|
||||||
is_share = os.environ.get("is_share", "False")
|
is_share = os.environ.get("is_share", "False")
|
||||||
is_share = eval(is_share)
|
is_share = eval(is_share)
|
||||||
@ -283,12 +299,18 @@ def set_gpt_weights(gpt_path):
|
|||||||
with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)
|
with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)
|
||||||
|
|
||||||
|
|
||||||
def set_sovits_weights(sovits_path):
|
def set_sovits_weights(sovits_path, prompt_language=None, text_language=None):
|
||||||
global vq_model, hps
|
global vq_model, hps, version, dict_language
|
||||||
dict_s2 = torch.load(sovits_path, map_location="cpu")
|
dict_s2 = torch.load(sovits_path, map_location="cpu")
|
||||||
hps = dict_s2["config"]
|
hps = dict_s2["config"]
|
||||||
hps = DictToAttrRecursive(hps)
|
hps = DictToAttrRecursive(hps)
|
||||||
hps.model.semantic_frame_rate = "25hz"
|
hps.model.semantic_frame_rate = "25hz"
|
||||||
|
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
|
||||||
|
hps.model.version = "v1"
|
||||||
|
else:
|
||||||
|
hps.model.version = "v2"
|
||||||
|
version = hps.model.version
|
||||||
|
print("sovits版本:", hps.model.version)
|
||||||
vq_model = SynthesizerTrn(
|
vq_model = SynthesizerTrn(
|
||||||
hps.data.filter_length // 2 + 1,
|
hps.data.filter_length // 2 + 1,
|
||||||
hps.train.segment_size // hps.data.hop_length,
|
hps.train.segment_size // hps.data.hop_length,
|
||||||
@ -303,8 +325,26 @@ def set_sovits_weights(sovits_path):
|
|||||||
vq_model = vq_model.to(device)
|
vq_model = vq_model.to(device)
|
||||||
vq_model.eval()
|
vq_model.eval()
|
||||||
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
||||||
with open("./sweight.txt", "w", encoding="utf-8") as f:
|
dict_language = dict_language_v1 if version == 'v1' else dict_language_v2
|
||||||
f.write(sovits_path)
|
with open("./weight.json") as f:
|
||||||
|
data = f.read()
|
||||||
|
data = json.loads(data)
|
||||||
|
data["SoVITS"][version] = sovits_path
|
||||||
|
with open("./weight.json", "w") as f:
|
||||||
|
f.write(json.dumps(data))
|
||||||
|
if prompt_language is not None and text_language is not None:
|
||||||
|
if prompt_language in list(dict_language.keys()):
|
||||||
|
prompt_text_update, prompt_language_update = {'__type__': 'update'}, {'__type__': 'update', 'value': prompt_language}
|
||||||
|
else:
|
||||||
|
prompt_text_update = {'__type__': 'update', 'value': ''}
|
||||||
|
prompt_language_update = {'__type__': 'update', 'value': i18n("中文")}
|
||||||
|
if text_language in list(dict_language.keys()):
|
||||||
|
text_update, text_language_update = {'__type__': 'update'}, {'__type__': 'update', 'value': text_language}
|
||||||
|
else:
|
||||||
|
text_update = {'__type__': 'update', 'value': ''}
|
||||||
|
text_language_update = {'__type__': 'update', 'value': i18n("中文")}
|
||||||
|
return {'__type__': 'update', 'choices': list(dict_language.keys())}, {'__type__': 'update', 'choices': list(
|
||||||
|
dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update
|
||||||
|
|
||||||
|
|
||||||
def gen_audio(ref_wav_path, prompt_text, text_to_speak, output_file, top_k=20, top_p=0.6, temperature=0.6, ref_free=False):
|
def gen_audio(ref_wav_path, prompt_text, text_to_speak, output_file, top_k=20, top_p=0.6, temperature=0.6, ref_free=False):
|
||||||
@ -454,12 +494,12 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
def speak(text_to_speak):
|
def speak(text_to_speak):
|
||||||
sovits_path = "SoVITS_weights/阿贝多_e12_s2748.pth"
|
sovits_path = "SoVITS_weights/迪希雅_e15_s1050.pth"
|
||||||
set_sovits_weights(sovits_path)
|
set_sovits_weights(sovits_path)
|
||||||
gpt_path = "GPT_weights/阿贝多-e10.ckpt"
|
gpt_path = "GPT_weights/迪希雅-e10.ckpt"
|
||||||
set_gpt_weights(gpt_path)
|
set_gpt_weights(gpt_path)
|
||||||
ref_wav_path = "audio/首先,先看看这不明来源的元素力,究竟是如何对外流动的.wav"
|
ref_wav_path = "audio/呼玛伊家也还会招工,报酬优厚,我和兄弟们自然没有拒绝的理由.wav"
|
||||||
prompt_text = "首先,先看看这不明来源的元素力,究竟是如何对外流动的。"
|
prompt_text = "呼玛伊家也还会招工,报酬优厚,我和兄弟们自然没有拒绝的理由。"
|
||||||
# text_to_speak = "我...我...我不知道你在说什么,我们之间没有秘密呀。可能你弄错了,我们平时关系很好的,请不要误会。"
|
# text_to_speak = "我...我...我不知道你在说什么,我们之间没有秘密呀。可能你弄错了,我们平时关系很好的,请不要误会。"
|
||||||
# 创建一个时间戳的文件名
|
# 创建一个时间戳的文件名
|
||||||
output_file = "outputs/" + str(int(ttime())) + ".wav"
|
output_file = "outputs/" + str(int(ttime())) + ".wav"
|
||||||
@ -468,7 +508,7 @@ def speak(text_to_speak):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
speak("放学了,我该回家了,你叫我留下来干什么?")
|
speak("你就是没用!")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user