diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index e5e604f5..550722fb 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -1,66 +1,17 @@
import os
-
-gpt_path = os.environ.get(
- "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
-)
-sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
-cnhubert_base_path = os.environ.get(
- "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
-)
-bert_path = os.environ.get(
- "bert_path", "pretrained_models/chinese-roberta-wwm-ext-large"
-)
-infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
-infer_ttswebui = int(infer_ttswebui)
-if "_CUDA_VISIBLE_DEVICES" in os.environ:
- os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
-is_half = eval(os.environ.get("is_half", "True"))
-import gradio as gr
-from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
import librosa,torch
from feature_extractor import cnhubert
-cnhubert.cnhubert_base_path=cnhubert_base_path
from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
+from transformers import AutoModelForMaskedLM, AutoTokenizer
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
from time import time as ttime
from module.mel_processing import spectrogram_torch
from my_utils import load_audio
-device = "cuda"
-tokenizer = AutoTokenizer.from_pretrained(bert_path)
-bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
-if is_half == True:
- bert_model = bert_model.half().to(device)
-else:
- bert_model = bert_model.to(device)
-
-
-# bert_model=bert_model.to(device)
-def get_bert_feature(text, word2ph):
- with torch.no_grad():
- inputs = tokenizer(text, return_tensors="pt")
- for i in inputs:
- inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model
- res = bert_model(**inputs, output_hidden_states=True)
- res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
- assert len(word2ph) == len(text)
- phone_level_feature = []
- for i in range(len(word2ph)):
- repeat_feature = res[i].repeat(word2ph[i], 1)
- phone_level_feature.append(repeat_feature)
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
- # if(is_half==True):phone_level_feature=phone_level_feature.half()
- return phone_level_feature.T
-
-
-n_semantic = 1024
-
-dict_s2=torch.load(sovits_path,map_location="cpu")
-hps=dict_s2["config"]
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
@@ -90,40 +41,205 @@ class DictToAttrRecursive(dict):
raise AttributeError(f"Attribute {item} not found")
-hps = DictToAttrRecursive(hps)
+class Inference:
+ def __init__(self, is_half, GPT_weight_root, SoVITS_weight_root):
+ self.n_semantic = 1024
+ self.model_loaded = False
+ self.is_half = is_half
+ self.GPT_weight_root = GPT_weight_root
+ self.SoVITS_weight_root = SoVITS_weight_root
-hps.model.semantic_frame_rate = "25hz"
-dict_s1 = torch.load(gpt_path, map_location="cpu")
-config = dict_s1["config"]
-ssl_model = cnhubert.get_model()
-if is_half == True:
- ssl_model = ssl_model.half().to(device)
-else:
- ssl_model = ssl_model.to(device)
-vq_model = SynthesizerTrn(
- hps.data.filter_length // 2 + 1,
- hps.train.segment_size // hps.data.hop_length,
- n_speakers=hps.data.n_speakers,
- **hps.model
-)
-if is_half == True:
- vq_model = vq_model.half().to(device)
-else:
- vq_model = vq_model.to(device)
-vq_model.eval()
-print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
-hz = 50
-max_sec = config["data"]["max_sec"]
-# t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
-t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
-t2s_model.load_state_dict(dict_s1["weight"])
-if is_half == True:
- t2s_model = t2s_model.half()
-t2s_model = t2s_model.to(device)
-t2s_model.eval()
-total = sum([param.nelement() for param in t2s_model.parameters()])
-print("Number of parameter: %.2fM" % (total / 1e6))
+ def update_envs(self, gpt_path, sovits_path, cnhubert_base_path, bert_path):
+ self.gpt_path = os.path.join(self.GPT_weight_root, gpt_path)
+ self.sovits_path = os.path.join(self.SoVITS_weight_root, sovits_path)
+ self.cnhubert_base_path = cnhubert_base_path
+ self.bert_path = bert_path
+
+ cnhubert.cnhubert_base_path=cnhubert_base_path
+
+ yield self.load_model()
+
+ def load_model(self, device='cuda'):
+ try:
+ # Load bert model
+ self.device = device
+ self.tokenizer = AutoTokenizer.from_pretrained(self.bert_path)
+ self.bert_model = AutoModelForMaskedLM.from_pretrained(self.bert_path)
+ if self.is_half == True:
+ self.bert_model = self.bert_model.half().to(device)
+ else:
+ self.bert_model = self.bert_model.to(device)
+
+ # Load ssl model
+ dict_s1 = torch.load(self.gpt_path, map_location="cpu")
+ self.config = dict_s1["config"]
+ self.ssl_model = cnhubert.get_model()
+ if self.is_half == True:
+ self.ssl_model = self.ssl_model.half().to(device)
+ else:
+ self.ssl_model = self.ssl_model.to(device)
+
+ dict_s2=torch.load(self.sovits_path,map_location="cpu")
+ self.hps=dict_s2["config"]
+ self.hps = DictToAttrRecursive(self.hps)
+ self.hps.model.semantic_frame_rate = "25hz"
+
+ # Load vq model
+ self.vq_model = SynthesizerTrn(
+ self.hps.data.filter_length // 2 + 1,
+ self.hps.train.segment_size // self.hps.data.hop_length,
+ n_speakers=self.hps.data.n_speakers,
+ **self.hps.model
+ )
+ if self.is_half == True:
+ self.vq_model = self.vq_model.half().to(device)
+ else:
+ self.vq_model = self.vq_model.to(device)
+ self.vq_model.eval()
+ self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
+
+ # Load t2s model
+ # t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
+ self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
+ self.t2s_model.load_state_dict(dict_s1["weight"])
+ if self.is_half == True:
+ self.t2s_model = self.t2s_model.half()
+ self.t2s_model = self.t2s_model.to(device)
+ self.t2s_model.eval()
+ total = sum([param.nelement() for param in self.t2s_model.parameters()])
+ print("Number of parameter: %.2fM" % (total / 1e6))
+
+ self.model_loaded = True
+ return '模型加载成功'
+ except Exception as e:
+ return f'模型加载失败:{e}'
+
+ def unload_model(self):
+ if self.model_loaded:
+ try:
+ del self.bert_model, self.ssl_model, self.hps, self.vq_model, self.t2s_model
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ self.model_loaded = False
+ yield '模型卸载成功'
+ except Exception as e:
+ yield f'模型卸载失败:{e}'
+ else:
+ yield '模型未加载'
+
+ def get_bert_feature(self, text, word2ph):
+ with torch.no_grad():
+ inputs = self.tokenizer(text, return_tensors="pt")
+ for i in inputs:
+ inputs[i] = inputs[i].to(self.device) #####输入是long不用管精度问题,精度随bert_model
+ res = self.bert_model(**inputs, output_hidden_states=True)
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
+ assert len(word2ph) == len(text)
+ phone_level_feature = []
+ for i in range(len(word2ph)):
+ repeat_feature = res[i].repeat(word2ph[i], 1)
+ phone_level_feature.append(repeat_feature)
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
+ # if(is_half==True):phone_level_feature=phone_level_feature.half()
+ return phone_level_feature.T
+
+
+ def get_tts_wav(self, ref_wav_path, prompt_text, prompt_language, text, text_language):
+ if not self.model_loaded:
+ return
+ hz = 50
+ dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
+ t0 = ttime()
+ prompt_text = prompt_text.strip("\n")
+ prompt_language, text = prompt_language, text.strip("\n")
+ with torch.no_grad():
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
+ wav16k = torch.from_numpy(wav16k)
+ if self.is_half == True:
+ wav16k = wav16k.half().to(self.device)
+ else:
+ wav16k = wav16k.to(self.device)
+ ssl_content = self.ssl_model.model(wav16k.unsqueeze(0))[
+ "last_hidden_state"
+ ].transpose(
+ 1, 2
+ ) # .float()
+ codes = self.vq_model.extract_latent(ssl_content)
+ prompt_semantic = codes[0, 0]
+ t1 = ttime()
+ prompt_language = dict_language[prompt_language]
+ text_language = dict_language[text_language]
+ phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
+ phones1 = cleaned_text_to_sequence(phones1)
+ texts = text.split("\n")
+ audio_opt = []
+ zero_wav = np.zeros(
+ int(self.hps.data.sampling_rate * 0.3),
+ dtype=np.float16 if self.is_half == True else np.float32,
+ )
+ for text in texts:
+ # 解决输入目标文本的空行导致报错的问题
+ if (len(text.strip()) == 0):
+ continue
+ phones2, word2ph2, norm_text2 = clean_text(text, text_language)
+ phones2 = cleaned_text_to_sequence(phones2)
+ if prompt_language == "zh":
+ bert1 = self.get_bert_feature(norm_text1, word2ph1).to(self.device)
+ else:
+ bert1 = torch.zeros(
+ (1024, len(phones1)),
+ dtype=torch.float16 if self.is_half == True else torch.float32,
+ ).to(self.device)
+ if text_language == "zh":
+ bert2 = self.get_bert_feature(norm_text2, word2ph2).to(self.device)
+ else:
+ bert2 = torch.zeros((1024, len(phones2))).to(bert1)
+ bert = torch.cat([bert1, bert2], 1)
+
+ all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(self.device).unsqueeze(0)
+ bert = bert.to(self.device).unsqueeze(0)
+ all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(self.device)
+ prompt = prompt_semantic.unsqueeze(0).to(self.device)
+ t2 = ttime()
+ with torch.no_grad():
+ # pred_semantic = t2s_model.model.infer(
+ pred_semantic, idx = self.t2s_model.model.infer_panel(
+ all_phoneme_ids,
+ all_phoneme_len,
+ prompt,
+ bert,
+ # prompt_phone_len=ph_offset,
+ top_k=self.config["inference"]["top_k"],
+ early_stop_num=hz * self.config["data"]["max_sec"]
+ )
+ t3 = ttime()
+ # print(pred_semantic.shape,idx)
+ pred_semantic = pred_semantic[:, -idx:].unsqueeze(
+ 0
+ ) # .unsqueeze(0)#mq要多unsqueeze一次
+ refer = get_spepc(self.hps, ref_wav_path) # .to(device)
+ if self.is_half == True:
+ refer = refer.half().to(self.device)
+ else:
+ refer = refer.to(self.device)
+ # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
+ audio = (
+ self.vq_model.decode(
+ pred_semantic, torch.LongTensor(phones2).to(self.device).unsqueeze(0), refer
+ )
+ .detach()
+ .cpu()
+ .numpy()[0, 0]
+ ) ###试试重建不带上prompt部分
+ audio_opt.append(audio)
+ audio_opt.append(zero_wav)
+ t4 = ttime()
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
+ yield self.hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
+ np.int16
+ )
+
def get_spepc(hps, filename):
@@ -142,119 +258,9 @@ def get_spepc(hps, filename):
return spec
-dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
-
-
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
- t0 = ttime()
- prompt_text = prompt_text.strip("\n")
- prompt_language, text = prompt_language, text.strip("\n")
- with torch.no_grad():
- wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
- wav16k = torch.from_numpy(wav16k)
- if is_half == True:
- wav16k = wav16k.half().to(device)
- else:
- wav16k = wav16k.to(device)
- ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
- "last_hidden_state"
- ].transpose(
- 1, 2
- ) # .float()
- codes = vq_model.extract_latent(ssl_content)
- prompt_semantic = codes[0, 0]
- t1 = ttime()
- prompt_language = dict_language[prompt_language]
- text_language = dict_language[text_language]
- phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
- phones1 = cleaned_text_to_sequence(phones1)
- texts = text.split("\n")
- audio_opt = []
- zero_wav = np.zeros(
- int(hps.data.sampling_rate * 0.3),
- dtype=np.float16 if is_half == True else np.float32,
- )
- for text in texts:
- # 解决输入目标文本的空行导致报错的问题
- if (len(text.strip()) == 0):
- continue
- phones2, word2ph2, norm_text2 = clean_text(text, text_language)
- phones2 = cleaned_text_to_sequence(phones2)
- if prompt_language == "zh":
- bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
- else:
- bert1 = torch.zeros(
- (1024, len(phones1)),
- dtype=torch.float16 if is_half == True else torch.float32,
- ).to(device)
- if text_language == "zh":
- bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
- else:
- bert2 = torch.zeros((1024, len(phones2))).to(bert1)
- bert = torch.cat([bert1, bert2], 1)
-
- all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
- bert = bert.to(device).unsqueeze(0)
- all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
- prompt = prompt_semantic.unsqueeze(0).to(device)
- t2 = ttime()
- with torch.no_grad():
- # pred_semantic = t2s_model.model.infer(
- pred_semantic, idx = t2s_model.model.infer_panel(
- all_phoneme_ids,
- all_phoneme_len,
- prompt,
- bert,
- # prompt_phone_len=ph_offset,
- top_k=config["inference"]["top_k"],
- early_stop_num=hz * max_sec,
- )
- t3 = ttime()
- # print(pred_semantic.shape,idx)
- pred_semantic = pred_semantic[:, -idx:].unsqueeze(
- 0
- ) # .unsqueeze(0)#mq要多unsqueeze一次
- refer = get_spepc(hps, ref_wav_path) # .to(device)
- if is_half == True:
- refer = refer.half().to(device)
- else:
- refer = refer.to(device)
- # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
- audio = (
- vq_model.decode(
- pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
- )
- .detach()
- .cpu()
- .numpy()[0, 0]
- ) ###试试重建不带上prompt部分
- audio_opt.append(audio)
- audio_opt.append(zero_wav)
- t4 = ttime()
- print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
- yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
- np.int16
- )
-
-
-splits = {
- ",",
- "。",
- "?",
- "!",
- ",",
- ".",
- "?",
- "!",
- "~",
- ":",
- ":",
- "—",
- "…",
-} # 不考虑省略号
-
def split(todo_text):
+ splits = { ",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", } # 不考虑省略号
todo_text = todo_text.replace("……", "。").replace("——", ",")
if todo_text[-1] not in splits:
todo_text += "。"
@@ -314,50 +320,3 @@ def cut3(inp):
inp = inp.strip("\n")
return "\n".join(["%s。" % item for item in inp.strip("。").split("。")])
-
-with gr.Blocks(title="GPT-SoVITS WebUI") as app:
- gr.Markdown(
- value="本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
- )
- # with gr.Tabs():
- # with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
- with gr.Group():
- gr.Markdown(value="*请上传并填写参考信息")
- with gr.Row():
- inp_ref = gr.Audio(label="请上传参考音频", type="filepath")
- prompt_text = gr.Textbox(label="参考音频的文本", value="")
- prompt_language = gr.Dropdown(
- label="参考音频的语种", choices=["中文", "英文", "日文"], value="中文"
- )
- gr.Markdown(value="*请填写需要合成的目标文本")
- with gr.Row():
- text = gr.Textbox(label="需要合成的文本", value="")
- text_language = gr.Dropdown(
- label="需要合成的语种", choices=["中文", "英文", "日文"], value="中文"
- )
- inference_button = gr.Button("合成语音", variant="primary")
- output = gr.Audio(label="输出的语音")
- inference_button.click(
- get_tts_wav,
- [inp_ref, prompt_text, prompt_language, text, text_language],
- [output],
- )
-
- gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
- with gr.Row():
- text_inp = gr.Textbox(label="需要合成的切分前文本", value="")
- button1 = gr.Button("凑五句一切", variant="primary")
- button2 = gr.Button("凑50字一切", variant="primary")
- button3 = gr.Button("按中文句号。切", variant="primary")
- text_opt = gr.Textbox(label="切分后文本", value="")
- button1.click(cut1, [text_inp], [text_opt])
- button2.click(cut2, [text_inp], [text_opt])
- button3.click(cut3, [text_inp], [text_opt])
- gr.Markdown(value="后续将支持混合语种编码文本输入。")
-
-app.queue(concurrency_count=511, max_size=1022).launch(
- server_name="0.0.0.0",
- inbrowser=True,
- server_port=infer_ttswebui,
- quiet=True,
-)