diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 606c1a8..d77ae37 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -285,10 +285,11 @@ class TTS:
     def init_cnhuhbert_weights(self, base_path: str):
         print(f"Loading CNHuBERT weights from {base_path}")
         self.cnhuhbert_model = CNHubert(base_path)
-        self.cnhuhbert_model=self.cnhuhbert_model.eval()
+        self.cnhuhbert_model = self.cnhuhbert_model.eval()
         self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
-        if self.configs.is_half and str(self.configs.device)!="cpu":
+        if self.configs.is_half and str(self.configs.device) != "cpu":
             self.cnhuhbert_model = self.cnhuhbert_model.half()
+        self.cnhuhbert_model = torch.compile(self.cnhuhbert_model)
@@ -296,10 +297,11 @@ class TTS:
         print(f"Loading BERT weights from {base_path}")
         self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
         self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
-        self.bert_model=self.bert_model.eval()
+        self.bert_model = self.bert_model.eval()
         self.bert_model = self.bert_model.to(self.configs.device)
-        if self.configs.is_half and str(self.configs.device)!="cpu":
+        if self.configs.is_half and str(self.configs.device) != "cpu":
             self.bert_model = self.bert_model.half()
+        self.bert_model = torch.compile(self.bert_model)
 
     def init_vits_weights(self, weights_path: str):
         print(f"Loading VITS weights from {weights_path}")
@@ -334,26 +336,28 @@ class TTS:
         vits_model = vits_model.to(self.configs.device)
         vits_model = vits_model.eval()
         vits_model.load_state_dict(dict_s2["weight"], strict=False)
+        vits_model = torch.compile(vits_model)
         self.vits_model = vits_model
-        if self.configs.is_half and str(self.configs.device)!="cpu":
+        if self.configs.is_half and str(self.configs.device) != "cpu":
             self.vits_model = self.vits_model.half()
 
     def init_t2s_weights(self, weights_path: str):
-        print(f"Loading Text2Semantic weights from {weights_path}")
-        self.configs.t2s_weights_path = weights_path
-        self.configs.save_configs()
-        self.configs.hz = 50
-        dict_s1 = torch.load(weights_path, map_location=self.configs.device)
-        config = dict_s1["config"]
-        self.configs.max_sec = config["data"]["max_sec"]
-        t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
-        t2s_model.load_state_dict(dict_s1["weight"])
-        t2s_model = t2s_model.to(self.configs.device)
-        t2s_model = t2s_model.eval()
-        self.t2s_model = t2s_model
-        if self.configs.is_half and str(self.configs.device)!="cpu":
-            self.t2s_model = self.t2s_model.half()
+        print(f"Loading Text2Semantic weights from {weights_path}")
+        self.configs.t2s_weights_path = weights_path
+        self.configs.save_configs()
+        self.configs.hz = 50
+        dict_s1 = torch.load(weights_path, map_location=self.configs.device)
+        config = dict_s1["config"]
+        self.configs.max_sec = config["data"]["max_sec"]
+        t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
+        t2s_model.load_state_dict(dict_s1["weight"])
+        t2s_model = t2s_model.to(self.configs.device)
+        t2s_model = t2s_model.eval()
+        t2s_model = torch.compile(t2s_model)
+        self.t2s_model = t2s_model
+        if self.configs.is_half and str(self.configs.device) != "cpu":
+            self.t2s_model = self.t2s_model.half()
 
     def enable_half_precision(self, enable: bool = True, save: bool = True):
         '''
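`torch.compile` wraps each model lazily: nothing is actually compiled until the first forward pass, so with this patch the first TTS request absorbs the whole compilation cost. A minimal warm-up sketch (the helper and the example input shape are illustrative assumptions, not part of this patch):

```python
import torch

def warm_up(module: torch.nn.Module, *example_inputs) -> None:
    # One throwaway forward pass forces torch.compile to trace and
    # compile now, rather than on the first real request.
    with torch.no_grad():
        module(*example_inputs)

# Hypothetical usage after init_cnhuhbert_weights():
#   warm_up(tts.cnhuhbert_model, torch.zeros(1, 16000, device=tts.configs.device))
```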
diff --git a/GPT_SoVITS/configs/tts_infer.yaml b/GPT_SoVITS/configs/tts_infer.yaml
index 66f1193..34c460c 100644
--- a/GPT_SoVITS/configs/tts_infer.yaml
+++ b/GPT_SoVITS/configs/tts_infer.yaml
@@ -1,11 +1,11 @@
 custom:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
-  device: cuda
-  is_half: true
-  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
+  device: cpu
+  is_half: false
+  t2s_weights_path: GPT_weights_v2/amiya-e50.ckpt
   version: v2
-  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
+  vits_weights_path: SoVITS_weights_v2/amiya_e25_s950.pth
 default:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
diff --git a/GPT_SoVITS/text/ja_userdic/user.dict b/GPT_SoVITS/text/ja_userdic/user.dict
new file mode 100644
index 0000000..6ddcfef
Binary files /dev/null and b/GPT_SoVITS/text/ja_userdic/user.dict differ
diff --git a/GPT_SoVITS/text/ja_userdic/userdict.md5 b/GPT_SoVITS/text/ja_userdic/userdict.md5
new file mode 100644
index 0000000..7848c97
--- /dev/null
+++ b/GPT_SoVITS/text/ja_userdic/userdict.md5
@@ -0,0 +1 @@
+d36bd5ffba62f195d22bf4f1a41cd08f
\ No newline at end of file
diff --git a/api_amiya.py b/api_amiya.py
new file mode 100644
index 0000000..443eab1
--- /dev/null
+++ b/api_amiya.py
@@ -0,0 +1,442 @@
+"""
+# WebAPI documentation
+
+` python api_amiya.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml `
+
+## Command-line arguments:
+    `-a` - `bind address, default "127.0.0.1"`
+    `-p` - `bind port, default 9880`
+    `-c` - `path to the TTS config file, default "GPT_SoVITS/configs/tts_infer.yaml"`
+
+## Endpoints:
+
+### Inference
+
+endpoint: `/tts`
+GET:
+```
+http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true
+```
+
+POST:
+```json
+{
+    "input": "",                  # str. (required) text to be synthesized
+    "text_lang": "",              # str. (required) language of the text to be synthesized
+    "ref_audio_path": "",         # str. (required) reference audio path
+    "text_split_method": "cut0",  # str. text split method, see text_segmentation_method.py for details.
+    "speed_factor": 1.0,          # float. controls the speed of the synthesized audio.
+    "seed": -1,                   # int. random seed for reproducibility.
+    "parallel_infer": true,       # bool. whether to use parallel inference.
+    "repetition_penalty": 1.35    # float. repetition penalty for the T2S model.
+}
+```
+
+RESP:
+success: returns the wav audio stream directly, http code 200
+failure: returns json with the error message, http code 400
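+
+A minimal Python client for the POST endpoint (illustrative; assumes the server
+above is running on 127.0.0.1:9880 with its built-in reference audio defaults):
+
+```python
+import requests
+
+resp = requests.post(
+    "http://127.0.0.1:9880/tts",
+    json={"input": "博士,早上好。", "text_lang": "all_zh", "media_type": "wav"},
+)
+resp.raise_for_status()
+with open("out.wav", "wb") as f:
+    f.write(resp.content)  # non-streaming mode returns the whole wav body at once
+```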
+
+### Command control
+
+endpoint: `/control`
+
+command:
+"restart": restart the service
+"exit": shut down the service
+
+GET:
+```
+http://127.0.0.1:9880/control?command=restart
+```
+POST:
+```json
+{
+    "command": "restart"
+}
+```
+
+RESP: none
+
+
+### Switch GPT model
+
+endpoint: `/set_gpt_weights`
+
+GET:
+```
+http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
+```
+RESP:
+success: returns "success", http code 200
+failure: returns json with the error message, http code 400
+
+
+### Switch SoVITS model
+
+endpoint: `/set_sovits_weights`
+
+GET:
+```
+http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth
+```
+
+RESP:
+success: returns "success", http code 200
+failure: returns json with the error message, http code 400
+
+"""
+import os
+import sys
+import traceback
+from typing import Generator
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+sys.path.append("%s/GPT_SoVITS" % now_dir)
+
+import argparse
+import subprocess
+import wave
+import signal
+import numpy as np
+import soundfile as sf
+from fastapi import FastAPI, Response
+from fastapi.responses import StreamingResponse, JSONResponse
+import uvicorn
+from io import BytesIO
+from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
+from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
+from pydantic import BaseModel
+
+cut_method_names = get_cut_method_names()
+
+parser = argparse.ArgumentParser(description="GPT-SoVITS api")
+parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="path to the tts_infer config")
+parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
+parser.add_argument("-p", "--port", type=int, default=9880, help="default: 9880")
+args = parser.parse_args()
+config_path = args.tts_config
+# device = args.device
+port = args.port
+host = args.bind_addr
+argv = sys.argv
+
+if config_path in [None, ""]:
+    config_path = "GPT_SoVITS/configs/tts_infer.yaml"
+
+tts_config = TTS_Config(config_path)
+print(tts_config)
+tts_pipeline = TTS(tts_config)
+
+REFERENCE_AUDIO = "/Users/baysonfox/Desktop/amiya-chatbot/reference.mp3"
+AUXILIARY_AUDIO = [
+    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_001_seg_0.mp3",
+    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_002_seg_0.mp3",
+    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_005_seg_0.mp3",
+    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_007_seg_1.mp3",
+    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_008_seg_0.mp3",
+    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_009_seg_0.mp3",
+    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_010_seg_0.mp3",
+    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_011_seg_0.mp3",
+    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_012_seg_0.mp3",
+    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_028_seg_0.mp3",
+    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_029_seg_0.mp3",
+    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_032_seg_0.mp3",
+    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_043_seg_0.mp3",
+    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_044_seg_0.mp3"
+]
+
+APP = FastAPI()
+
+class TTS_Request(BaseModel):
+    input: str = None
+    text_lang: str = "all_zh"
+    ref_audio_path: str = REFERENCE_AUDIO
+    aux_ref_audio_paths: list = AUXILIARY_AUDIO
+    prompt_lang: str = "all_zh"
+    prompt_text: str = "博士,休息好了吗?还觉得累的话,不用勉强的。有我在呢。"
+    top_k: int = 5
+    top_p: float = 1
+    temperature: float = 1
+    text_split_method: str = "cut5"
+    batch_size: int = 1
+    batch_threshold: float = 0.75
+    split_bucket: bool = True
+    speed_factor: float = 1.0
+    fragment_interval: float = 0.3
+    seed: int = -1
+    media_type: str = "wav"
+    streaming_mode: bool = False
+    parallel_infer: bool = True
+    repetition_penalty: float = 1.35
+
+### adapted from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
+def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
+    with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
+        audio_file.write(data)
+    return io_buffer
+
+
+def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
+    io_buffer.write(data.tobytes())
+    return io_buffer
+
+
+def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
+    sf.write(io_buffer, data, rate, format='wav')
+    return io_buffer
+
+
+def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
+    process = subprocess.Popen([
+        'ffmpeg',
+        '-f', 's16le',     # input: 16-bit signed little-endian PCM
+        '-ar', str(rate),  # input sample rate
+        '-ac', '1',        # mono
+        '-i', 'pipe:0',    # read input from stdin
+        '-c:a', 'aac',     # encode audio as AAC
+        '-b:a', '192k',    # bitrate
+        '-vn',             # no video
+        '-f', 'adts',      # output an ADTS AAC stream
+        'pipe:1'           # write output to stdout
+    ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    out, _ = process.communicate(input=data.tobytes())
+    io_buffer.write(out)
+    return io_buffer
+
+
+def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
+    if media_type == "ogg":
+        io_buffer = pack_ogg(io_buffer, data, rate)
+    elif media_type == "aac":
+        io_buffer = pack_aac(io_buffer, data, rate)
+    elif media_type == "wav":
+        io_buffer = pack_wav(io_buffer, data, rate)
+    else:
+        io_buffer = pack_raw(io_buffer, data, rate)
+    io_buffer.seek(0)
+    return io_buffer
+
+
+# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
+def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
+    # Creates a wave header and appends the frame input. The header must come
+    # first in a streaming wav file; later chunks should not repeat it, or you
+    # will hear an artifact at the start of each chunk.
+    wav_buf = BytesIO()
+    with wave.open(wav_buf, "wb") as vfout:
+        vfout.setnchannels(channels)
+        vfout.setsampwidth(sample_width)
+        vfout.setframerate(sample_rate)
+        vfout.writeframes(frame_input)
+
+    wav_buf.seek(0)
+    return wav_buf.read()
+
+
+def handle_control(command: str):
+    if command == "restart":
+        os.execl(sys.executable, sys.executable, *argv)
+    elif command == "exit":
+        os.kill(os.getpid(), signal.SIGTERM)
+        exit(0)
+
+
+def check_params(req: dict):
+    text: str = req.get("text", "")
+    text_lang: str = req.get("text_lang", "")
+    ref_audio_path: str = req.get("ref_audio_path", "")
+    streaming_mode: bool = req.get("streaming_mode", False)
+    media_type: str = req.get("media_type", "wav")
+    prompt_lang: str = req.get("prompt_lang", "")
+    text_split_method: str = req.get("text_split_method", "cut5")
+
+    if ref_audio_path in [None, ""]:
+        return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
+    if text in [None, ""]:
+        return JSONResponse(status_code=400, content={"message": "text is required"})
+    if text_lang in [None, ""]:
+        return JSONResponse(status_code=400, content={"message": "text_lang is required"})
+    elif text_lang.lower() not in tts_config.languages:
+        return JSONResponse(status_code=400, content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"})
+    if prompt_lang in [None, ""]:
+        return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
+    elif prompt_lang.lower() not in tts_config.languages:
+        return JSONResponse(status_code=400, content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"})
+    if media_type not in ["wav", "raw", "ogg", "aac"]:
+        return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"})
+    elif media_type == "ogg" and not streaming_mode:
+        return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
+
+    if text_split_method not in cut_method_names:
+        return JSONResponse(status_code=400, content={"message": f"text_split_method: {text_split_method} is not supported"})
+
+    return None
+
+
+async def tts_handle(req: dict):
+    """
+    Text to speech handler.
+
+    Args:
+        req (dict):
+            {
+                "input": "",                  # str. (required) text to be synthesized
+                "text_lang": "",              # str. (required) language of the text to be synthesized
+                "ref_audio_path": "",         # str. (required) reference audio path
+                "aux_ref_audio_paths": [],    # list. (optional) auxiliary reference audio paths for multi-speaker synthesis
+                "prompt_text": "",            # str. (optional) prompt text for the reference audio
+                "prompt_lang": "",            # str. (required) language of the prompt text for the reference audio
+                "top_k": 5,                   # int. top k sampling
+                "top_p": 1,                   # float. top p sampling
+                "temperature": 1,             # float. temperature for sampling
+                "text_split_method": "cut5",  # str. text split method, see text_segmentation_method.py for details.
+                "batch_size": 1,              # int. batch size for inference
+                "batch_threshold": 0.75,      # float. threshold for batch splitting.
+                "split_bucket": True,         # bool. whether to split the batch into multiple buckets.
+                "speed_factor": 1.0,          # float. controls the speed of the synthesized audio.
+                "fragment_interval": 0.3,     # float. controls the interval between audio fragments.
+                "seed": -1,                   # int. random seed for reproducibility.
+                "media_type": "wav",          # str. media type of the output audio, one of "wav", "raw", "ogg", "aac".
+                "streaming_mode": False,      # bool. whether to return a streaming response.
+                "parallel_infer": True,       # bool. (optional) whether to use parallel inference.
+                "repetition_penalty": 1.35    # float. (optional) repetition penalty for the T2S model.
+            }
+    returns:
+        StreamingResponse: audio stream response.
+    """
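+    # The public API uses an OpenAI-style "input" field; the pipeline and
+    # check_params() expect "text", so alias it here.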
+ """ + req['text'] = req.get('input') + streaming_mode = req.get("streaming_mode", False) + return_fragment = req.get("return_fragment", False) + media_type = req.get("media_type", "wav") + + check_res = check_params(req) + if check_res is not None: + return check_res + + if streaming_mode or return_fragment: + req["return_fragment"] = True + + try: + tts_generator=tts_pipeline.run(req) + + if streaming_mode: + def streaming_generator(tts_generator:Generator, media_type:str): + if media_type == "wav": + yield wave_header_chunk() + media_type = "raw" + for sr, chunk in tts_generator: + yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue() + # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}" + return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}") + + else: + sr, audio_data = next(tts_generator) + audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() + return Response(audio_data, media_type=f"audio/{media_type}") + except Exception as e: + return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)}) + + + + + + +@APP.get("/control") +async def control(command: str = None): + if command is None: + return JSONResponse(status_code=400, content={"message": "command is required"}) + handle_control(command) + + + +@APP.get("/tts") +async def tts_get_endpoint( + text: str = None, + text_lang: str = "all_zh", + ref_audio_path: str = REFERENCE_AUDIO, + aux_ref_audio_paths:list = AUXILARY_AUDIO, + prompt_lang: str = "all_zh", + prompt_text: str = "博士,休息好了吗?还觉得累的话,不用勉强的。有我在呢。", + top_k:int = 5, + top_p:float = 1, + temperature:float = 1, + text_split_method:str = "cut0", + batch_size:int = 1, + batch_threshold:float = 0.75, + split_bucket:bool = True, + speed_factor:float = 1.0, + fragment_interval:float = 0.3, + seed:int = -1, + media_type:str = "wav", + streaming_mode:bool = False, + parallel_infer:bool = True, + repetition_penalty:float = 1.35 + ): + req = { + "text": text, + "text_lang": text_lang.lower(), + "ref_audio_path": REFERENCE_AUDIO, + "aux_ref_audio_paths": AUXILARY_AUDIO, + "prompt_text": prompt_text, + "prompt_lang": prompt_lang.lower(), + "top_k": top_k, + "top_p": top_p, + "temperature": temperature, + "text_split_method": text_split_method, + "batch_size":int(batch_size), + "batch_threshold":float(batch_threshold), + "speed_factor":float(speed_factor), + "split_bucket":split_bucket, + "fragment_interval":fragment_interval, + "seed":seed, + "media_type":media_type, + "streaming_mode":streaming_mode, + "parallel_infer":parallel_infer, + "repetition_penalty":float(repetition_penalty) + } + return await tts_handle(req) + + +@APP.post("/tts") +async def tts_post_endpoint(request: TTS_Request): + req = request.model_dump() + return await tts_handle(req) + + +@APP.get("/set_refer_audio") +async def set_refer_aduio(refer_audio_path: str = None): + try: + tts_pipeline.set_ref_audio(refer_audio_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)}) + return JSONResponse(status_code=200, content={"message": "success"}) + +@APP.get("/set_gpt_weights") +async def set_gpt_weights(weights_path: str = None): + try: + if weights_path in ["", None]: + return JSONResponse(status_code=400, content={"message": "gpt weight path is required"}) + tts_pipeline.init_t2s_weights(weights_path) + except Exception as e: + return 
+
+
+@APP.get("/set_sovits_weights")
+async def set_sovits_weights(weights_path: str = None):
+    try:
+        if weights_path in ["", None]:
+            return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})
+        tts_pipeline.init_vits_weights(weights_path)
+    except Exception as e:
+        return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)})
+    return JSONResponse(status_code=200, content={"message": "success"})
+
+
+if __name__ == "__main__":
+    try:
+        if host == 'None':  # pass `-a None` to make the api listen on both IPv4 and IPv6 (dual stack)
+            host = None
+        uvicorn.run(app=APP, host=host, port=port, workers=1)
+    except Exception:
+        traceback.print_exc()
+        os.kill(os.getpid(), signal.SIGTERM)
+        exit(0)
diff --git a/webui_shared.py b/webui_shared.py
new file mode 100644
index 0000000..ec160d1
--- /dev/null
+++ b/webui_shared.py
@@ -0,0 +1,630 @@
+#!/usr/bin/env python
+# coding=utf-8
+import gradio as gr
+import numpy as np
+from transformers import AutoModelForMaskedLM, AutoTokenizer
+import librosa
+from feature_extractor import cnhubert
+from GPT_SoVITS.module.models import SynthesizerTrn
+from AR.models.t2s_lightning_module import Text2SemanticLightningModule
+from text import chinese, cleaned_text_to_sequence
+from text.cleaner import clean_text
+from text.LangSegmenter import LangSegmenter
+from time import time as ttime
+from module.mel_processing import spectrogram_torch, spec_to_mel_torch
+from tools.my_utils import load_audio
+import torch, torchaudio
+import traceback
+import os, re
+
+# Model paths
+gpt_path = 'GPT_weights_v2/amiya-e50.ckpt'
+sovits_path = 'SoVITS_weights_v2/amiya_e25_s950.pth'
+cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
+bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
+cnhubert.cnhubert_base_path = cnhubert_base_path
+
+# Reference audio configuration
+REFERENCE_AUDIO = "/Users/baysonfox/Desktop/amiya-chatbot/reference.mp3"
+REFERENCE_TEXT = "博士,休息好了吗?还觉得累的话,不用勉强的。有我在呢。"
+INF_REFS = [os.path.join("/Users/baysonfox/Desktop/amiya-chatbot/references", f) for f in os.listdir("/Users/baysonfox/Desktop/amiya-chatbot/references")]
+
+# Model settings
+DEVICE = 'cpu'
+dict_language = {
+    "中文": "all_zh",      # treat all input as Chinese
+    "英文": "en",          # treat all input as English
+    "日文": "all_ja",      # treat all input as Japanese
+    "中英混合": "zh",      # mixed Chinese/English
+    "日英混合": "ja",      # mixed Japanese/English
+    "多语种混合": "auto",  # auto-detect the language of each segment
+}
+os.environ["TOKENIZERS_PARALLELISM"] = "False"
+
+DTYPE = torch.float32
+tokenizer = AutoTokenizer.from_pretrained(bert_path)
+bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
+ssl_model = cnhubert.get_model()
+
+# Accelerated inference. Only the nn.Module models are compiled; the tokenizer
+# is not a torch module, so torch.compile would buy nothing there.
+bert_model = torch.compile(bert_model)
+ssl_model = torch.compile(ssl_model)
+
+# Punctuation
+PUNCTUATION = {'!', '?', '…', ',', '.', '-', " "}
+# Sentence-splitting punctuation (Chinese and Western)
+SPLITS = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
+
+class DictToAttrRecursive(dict):
+    def __init__(self, input_dict):
+        super().__init__(input_dict)
+        for key, value in input_dict.items():
+            if isinstance(value, dict):
+                value = DictToAttrRecursive(value)
+            self[key] = value
+            setattr(self, key, value)
+
+    def __getattr__(self, item):
+        try:
+            return self[item]
+        except KeyError:
+            raise AttributeError(f"Attribute {item} not found")
+
+    def __setattr__(self, key, value):
+        if isinstance(value, dict):
+            value = DictToAttrRecursive(value)
+        super(DictToAttrRecursive, self).__setitem__(key, value)
+        super().__setattr__(key, value)
+
+    def __delattr__(self, item):
+        try:
+            del self[item]
+        except KeyError:
+            raise AttributeError(f"Attribute {item} not found")
+
+
+def get_bert_feature(text, word2ph):
+    with torch.no_grad():
+        inputs = tokenizer(text, return_tensors="pt")
+        for i in inputs:
+            inputs[i] = inputs[i].to(DEVICE)
+        res = bert_model(**inputs, output_hidden_states=True)
+        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
+    assert len(word2ph) == len(text)
+    phone_level_feature = []
+    for i in range(len(word2ph)):
+        repeat_feature = res[i].repeat(word2ph[i], 1)
+        phone_level_feature.append(repeat_feature)
+    phone_level_feature = torch.cat(phone_level_feature, dim=0)
+    return phone_level_feature.T
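+
+# Illustrative example of the word2ph expansion above (values are made up):
+# for text="你好" with word2ph=[2, 2], each character maps to 2 phones, so its
+# 1024-dim BERT vector is repeated twice; the result is a
+# (1024, sum(word2ph)) tensor aligned with the phone sequence.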
+
+
+resample_transform_dict = {}
+def resample(audio_tensor, sr0):
+    global resample_transform_dict
+    if sr0 not in resample_transform_dict:
+        resample_transform_dict[sr0] = torchaudio.transforms.Resample(
+            sr0, 24000
+        ).to(DEVICE)
+    return resample_transform_dict[sr0](audio_tensor)
+
+
+def change_sovits_weights(prompt_language=None, text_language=None):
+    global vq_model, hps, version, model_version, dict_language
+    model_version = version = "v2"
+
+    if prompt_language is not None and text_language is not None:
+        prompt_text_update, prompt_language_update = {'__type__': 'update'}, {'__type__': 'update', 'value': "all_zh"}
+
+        if text_language in list(dict_language.keys()):
+            text_update, text_language_update = {'__type__': 'update'}, {'__type__': 'update', 'value': text_language}
+        else:
+            text_update = {'__type__': 'update', 'value': ''}
+            text_language_update = {'__type__': 'update', 'value': "中文"}
+        yield {'__type__': 'update', 'choices': list(dict_language.keys())}, {'__type__': 'update', 'choices': list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update, {"__type__": "update", "visible": False}, {"__type__": "update", "visible": False}, {"__type__": "update", "value": False, "interactive": False}
+
+    dict_s2 = torch.load(sovits_path, map_location="cpu")
+
+    hps = DictToAttrRecursive(dict_s2["config"])
+
+    hps.model.semantic_frame_rate = "25hz"
+    hps.model.version = "v2"
+    version = hps.model.version
+    vq_model = SynthesizerTrn(
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        **hps.model
+    )
+    model_version = version
+
+    vq_model = vq_model.to(DEVICE)
+    print("loading sovits_%s" % model_version, vq_model.load_state_dict(
+        dict_s2["weight"],
+        strict=False))
+    vq_model.eval()
+    vq_model = torch.compile(vq_model)
+
+
+def change_gpt_weights(gpt_path):
+    global hz, max_sec, t2s_model, config
+    hz = 50
+    dict_s1 = torch.load(gpt_path, map_location="cpu")
+    config = dict_s1["config"]
+    max_sec = config["data"]["max_sec"]
+    t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
+    t2s_model.load_state_dict(dict_s1["weight"])
+    t2s_model = t2s_model.to(DEVICE)
+    t2s_model.eval()
+    t2s_model = torch.compile(t2s_model)
+
+
+def get_spepc(hps, filename):
+    print("hps samplingrate", hps.data.sampling_rate)
+    audio = load_audio(filename, int(hps.data.sampling_rate))
+    audio = torch.FloatTensor(audio)
+    maxx = audio.abs().max()
+    if maxx > 1:
+        audio /= min(2, maxx.item())
+    audio_norm = audio
+    audio_norm = audio_norm.unsqueeze(0)
+    spec = spectrogram_torch(
+        audio_norm,
+        hps.data.filter_length,
+        hps.data.sampling_rate,
+        hps.data.hop_length,
+        hps.data.win_length,
+        center=False,
+    )
+    return spec
+
+
+def clean_text_inf(text, language, version):
+    phones, word2ph, norm_text = clean_text(text, language, version)
+    print("phones: ", phones)
+    print("word2ph: ", word2ph)
+    print("norm_text: ", norm_text)
+    phones = cleaned_text_to_sequence(phones, version)
+    return phones, word2ph, norm_text
+
+
+def get_bert_inf(phones, word2ph, norm_text, language):
+    language = language.replace("all_", "")
+    if language == "zh":
+        bert = get_bert_feature(norm_text, word2ph).to(DEVICE)
+    else:
+        bert = torch.zeros(
+            (1024, len(phones)),
+            dtype=torch.float32
+        ).to(DEVICE)
+
+    return bert
+
+
+def get_first(text):
+    pattern = "[" + "".join(re.escape(sep) for sep in SPLITS) + "]"
+    text = re.split(pattern, text)[0].strip()
+    return text
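+
+# get_phones_and_bert() below relies on LangSegmenter.getTexts() yielding
+# dicts of the form {"lang": ..., "text": ...}. Illustrative example (output
+# values assumed, for a mixed Chinese/English input):
+#   LangSegmenter.getTexts("你好hello世界")
+#   -> [{"lang": "zh", "text": "你好"},
+#       {"lang": "en", "text": "hello"},
+#       {"lang": "zh", "text": "世界"}]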
+
+
+def get_phones_and_bert(text, language, version, final=False):
+    if language in {"en", "all_zh", "all_ja"}:
+        language = language.replace("all_", "")
+        formattext = text
+        while "  " in formattext:
+            formattext = formattext.replace("  ", " ")
+        if language == "zh":
+            if re.search(r'[A-Za-z]', formattext):
+                formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
+                formattext = chinese.mix_text_normalize(formattext)
+                return get_phones_and_bert(formattext, "zh", version)
+            else:
+                phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
+                bert = get_bert_feature(norm_text, word2ph).to(DEVICE)
+        elif language == "yue" and re.search(r'[A-Za-z]', formattext):
+            formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
+            formattext = chinese.mix_text_normalize(formattext)
+            return get_phones_and_bert(formattext, "yue", version)
+        else:
+            phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
+            bert = torch.zeros(
+                (1024, len(phones)),
+                dtype=torch.float32
+            ).to(DEVICE)
+    elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
+        textlist = []
+        langlist = []
+        if language == "auto":
+            for tmp in LangSegmenter.getTexts(text):
+                langlist.append(tmp["lang"])
+                textlist.append(tmp["text"])
+        else:
+            for tmp in LangSegmenter.getTexts(text):
+                if tmp["lang"] == "en":
+                    langlist.append(tmp["lang"])
+                else:
+                    # CJK scripts cannot be told apart reliably, so trust the user's language choice
+                    langlist.append(language)
+                textlist.append(tmp["text"])
+        print(textlist)
+        print(langlist)
+        phones_list = []
+        bert_list = []
+        norm_text_list = []
+        for i in range(len(textlist)):
+            lang = langlist[i]
+            phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
+            bert = get_bert_inf(phones, word2ph, norm_text, lang)
+            phones_list.append(phones)
+            norm_text_list.append(norm_text)
+            bert_list.append(bert)
+        bert = torch.cat(bert_list, dim=1)
+        phones = sum(phones_list, [])
+        norm_text = ''.join(norm_text_list)
+
+    if not final and len(phones) < 6:
+        return get_phones_and_bert("." + text, language, version, final=True)
+
+    return phones, bert.to(DTYPE), norm_text
+
+
+def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+    spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center)
+    mel = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax)
+    return mel
+
+
+mel_fn_args = {
+    "n_fft": 1024,
+    "win_size": 1024,
+    "hop_size": 256,
+    "num_mels": 100,
+    "sampling_rate": 24000,
+    "fmin": 0,
+    "fmax": None,
+    "center": False
+}
+
+spec_min = -12
+spec_max = 2
+def norm_spec(x):
+    return (x - spec_min) / (spec_max - spec_min) * 2 - 1
+def denorm_spec(x):
+    return (x + 1) / 2 * (spec_max - spec_min) + spec_min
+mel_fn = lambda x: mel_spectrogram(x, **mel_fn_args)
+
+
+def merge_short_text_in_array(texts, threshold):
+    if (len(texts)) < 2:
+        return texts
+    result = []
+    text = ""
+    for ele in texts:
+        text += ele
+        if len(text) >= threshold:
+            result.append(text)
+            text = ""
+    if (len(text) > 0):
+        if len(result) == 0:
+            result.append(text)
+        else:
+            result[len(result) - 1] += text
+    return result
+
+
+def split(todo_text):
+    todo_text = todo_text.replace("……", "。").replace("——", ",")
+    if todo_text[-1] not in SPLITS:
+        todo_text += "。"
+    i_split_head = i_split_tail = 0
+    len_text = len(todo_text)
+    todo_texts = []
+    while 1:
+        if i_split_head >= len_text:
+            break  # the text is guaranteed to end with a split mark, so the last segment was already appended
+        if todo_text[i_split_head] in SPLITS:
+            i_split_head += 1
+            todo_texts.append(todo_text[i_split_tail:i_split_head])
+            i_split_tail = i_split_head
+        else:
+            i_split_head += 1
+    return todo_texts
+
+
+def cut1(inp):
+    inp = inp.strip("\n")
+    inps = split(inp)
+    split_idx = list(range(0, len(inps), 4))
+    split_idx[-1] = None
+    if len(split_idx) > 1:
+        opts = []
+        for idx in range(len(split_idx) - 1):
+            opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
+    else:
+        opts = [inp]
+    opts = [item for item in opts if not set(item).issubset(PUNCTUATION)]
+    return "\n".join(opts)
+
+
+def cut2(inp):
+    inp = inp.strip("\n")
+    inps = split(inp)
+    if len(inps) < 2:
+        return inp
+    opts = []
+    summ = 0
+    tmp_str = ""
+    for i in range(len(inps)):
+        summ += len(inps[i])
+        tmp_str += inps[i]
+        if summ > 50:
+            summ = 0
+            opts.append(tmp_str)
+            tmp_str = ""
+    if tmp_str != "":
+        opts.append(tmp_str)
+    # print(opts)
+    if len(opts) > 1 and len(opts[-1]) < 50:  # if the last chunk is too short, merge it into the previous one
+        opts[-2] = opts[-2] + opts[-1]
+        opts = opts[:-1]
+    opts = [item for item in opts if not set(item).issubset(PUNCTUATION)]
+    return "\n".join(opts)
+
+
+def cut3(inp):
+    inp = inp.strip("\n")
+    opts = ["%s" % item for item in inp.strip("。").split("。")]
+    opts = [item for item in opts if not set(item).issubset(PUNCTUATION)]
+    return "\n".join(opts)
+
+
+def cut4(inp):
+    inp = inp.strip("\n")
+    opts = re.split(r'(?<!\d)\.(?!\d)', inp.strip("."))
+    opts = [item for item in opts if not set(item).issubset(PUNCTUATION)]
+    return "\n".join(opts)
+
+
+def cut5(inp):
+    inp = inp.strip("\n")
+    punds = {',', '.', ';', '?', '!', '、', ',', '。', '?', '!', ';', ':', '…'}
+    mergeitems = []
+    items = []
+
+    for i, char in enumerate(inp):
+        if char in punds:
+            # keep a '.' that sits between two digits (e.g. "3.14") inside the segment
+            if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
+                items.append(char)
+            else:
+                items.append(char)
+                mergeitems.append("".join(items))
+                items = []
+        else:
+            items.append(char)
+
+    if items:
+        mergeitems.append("".join(items))
+
+    opt = [item for item in mergeitems if not set(item).issubset(punds)]
+    return "\n".join(opt)
+
+
+def custom_sort_key(s):
+    # use a regex to split the string into digit and non-digit runs
+    parts = re.split(r'(\d+)', s)
+    # convert the digit runs to integers, leave the rest unchanged
+    parts = [int(part) if part.isdigit() else part for part in parts]
+    return parts
+
+
+def process_text(texts):
+    _text = []
+    if all(text in [None, " ", "\n", ""] for text in texts):
+        raise ValueError("please enter valid text")
+    for text in texts:
+        if text in [None, " ", ""]:
+            pass
+        else:
+            _text.append(text)
+    return _text
+
+
+cache = {}
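+
+# Illustrative behaviour of the splitters above (made-up sample, assumed output):
+#   cut5("你好,世界。圆周率约为3.14。")
+#   -> "你好,\n世界。\n圆周率约为3.14。"
+# i.e. cut5 breaks after every punctuation mark but keeps the decimal point in
+# "3.14" intact, while cut3 would only break on Chinese full stops.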
+def get_tts_wav(text,
+                text_language,
+                how_to_cut="不切",
+                top_k=20,
+                top_p=0.6,
+                temperature=0.6,
+                speed=1,
+                if_freeze=False,
+                sample_steps=8):
+
+    global cache
+    prompt_text = REFERENCE_TEXT
+    prompt_language = "中文"
+    ref_wav_path = REFERENCE_AUDIO
+    inp_refs = INF_REFS
+    t = []
+    t0 = ttime()
+    prompt_language = dict_language[prompt_language]
+    text_language = dict_language[text_language]
+
+    # strip newlines from the prompt text
+    prompt_text = prompt_text.strip("\n")
+    # make sure the prompt ends with a sentence mark
+    if prompt_text[-1] not in SPLITS:
+        prompt_text += "。" if prompt_language != "en" else "."
+    print("actual prompt text:", prompt_text)
+    text = text.strip("\n")
+
+    print("actual target text:", text)
+    zero_wav = np.zeros(
+        int(hps.data.sampling_rate * 0.3),
+        dtype=np.float32
+    )
+
+    with torch.no_grad():
+        # reference audio and sampling rate, as numpy
+        wav16k, sr = librosa.load(ref_wav_path, sr=16000)
+        # numpy -> torch
+        wav16k = torch.from_numpy(wav16k)
+        zero_wav_torch = torch.from_numpy(zero_wav)
+        wav16k = wav16k.to(DEVICE)
+        zero_wav_torch = zero_wav_torch.to(DEVICE)
+        wav16k = torch.cat([wav16k, zero_wav_torch])
+        ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
+            "last_hidden_state"
+        ].transpose(
+            1, 2
+        )
+        codes = vq_model.extract_latent(ssl_content)
+        prompt_semantic = codes[0, 0]
+        prompt = prompt_semantic.unsqueeze(0).to(DEVICE)
+
+    t1 = ttime()
+    t.append(t1 - t0)
+
+    if how_to_cut != "不切":
+        cut_map = {
+            "凑四句一切": cut1,
+            "凑50字一切": cut2,
+            "按中文句号。切": cut3,
+            "按英文句号.切": cut4,
+            "按标点符号切": cut5
+        }
+        text = cut_map[how_to_cut](text)
+
+    while "\n\n" in text:
+        text = text.replace("\n\n", "\n")
+
+    print("actual target text (after splitting):", text)
+
+    texts = text.split("\n")
+    texts = process_text(texts)
+    texts = merge_short_text_in_array(texts, 5)
+    audio_opt = []
+    phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
+
+    for i_text, text in enumerate(texts):
+        # skip empty lines in the target text, which would otherwise crash
+        if (len(text.strip()) == 0):
+            continue
+        if (text[-1] not in SPLITS):
+            text += "。" if text_language != "en" else "."
+        print("actual target text (per sentence):", text)
+        phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
+        print("front-end processed text (per sentence):", norm_text2)
+        bert = torch.cat([bert1, bert2], 1)
+        all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(DEVICE).unsqueeze(0)
+
+        bert = bert.to(DEVICE).unsqueeze(0)
+        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(DEVICE)
+
+        t2 = ttime()
+        if i_text in cache and if_freeze == True:
+            pred_semantic = cache[i_text]
+        else:
+            with torch.no_grad():
+                pred_semantic, idx = t2s_model.model.infer_panel(
+                    all_phoneme_ids,
+                    all_phoneme_len,
+                    prompt,
+                    bert,
+                    # prompt_phone_len=ph_offset,
+                    top_k=top_k,
+                    top_p=top_p,
+                    temperature=temperature,
+                    early_stop_num=hz * max_sec,
+                )
+                pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
+            cache[i_text] = pred_semantic
+        t3 = ttime()
+        refers = []
+        if inp_refs:
+            for path in inp_refs:
+                try:
+                    refer = get_spepc(hps, path).to(DTYPE).to(DEVICE)
+                    refers.append(refer)
+                except Exception:
+                    traceback.print_exc()
+        if len(refers) == 0:
+            refers = [get_spepc(hps, ref_wav_path).to(DTYPE).to(DEVICE)]
+        audio = (vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(DEVICE).unsqueeze(0), refers, speed=speed).detach().cpu().numpy()[0, 0])
+
+        audio_opt.append(audio)
+        audio_opt.append(zero_wav)
+        t4 = ttime()
+        t.extend([t2 - t1, t3 - t2, t4 - t3])
+        t1 = ttime()
+    # timing: reference encoding, then summed frontend, T2S and vocoder times
+    print("%.3f\t%.3f\t%.3f\t%.3f" %
+          (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3]))
+          )
+    sr = hps.data.sampling_rate if model_version != "v3" else 24000
+    yield sr, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
+
+
+with gr.Blocks(title="GPT-SoVITS WebUI") as app:
+
+    try:
+        next(change_sovits_weights(sovits_path))
+    except Exception:
+        pass
+    change_gpt_weights(gpt_path)  # initialize the GPT model
+

+    gr.Markdown(value="大概可能也许是阿米娅的声音(")
+
+    with gr.Row() as main_row:
+        with gr.Column(scale=7) as text_column:
+            text = gr.Textbox(
+                label="需要合成的文本",
+                value="",
+                lines=13,
+                max_lines=13
+            )
+
+            with gr.Row():
+                inference_button = gr.Button("合成语音", variant="primary", size='lg')
+                output = gr.Audio(label="输出的语音")
+
+        with gr.Column(scale=5) as control_column:
+            text_language = gr.Dropdown(
+                label="需要合成的语种。限制范围越小判别效果越好。",
+                choices=list(dict_language.keys()),
+                value="中文"
+            )
+
+            how_to_cut = gr.Dropdown(
+                label="怎么切",
+                choices=["不切", "凑四句一切", "凑50字一切", "按中文句号。切", "按英文句号.切", "按标点符号切"],
+                value="凑四句一切"
+            )
+
+            if_freeze = gr.Checkbox(
+                label="是否直接对上次合成结果调整语速和音色。防止随机性。",
+                value=False
+            )
+
+            gr.Markdown(value="语速调整,高为更快")
+
+            speed = gr.Slider(
+                minimum=0.6,
+                maximum=1.65,
+                step=0.05,
+                label="语速",
+                value=1
+            )
+
+            gr.Markdown("GPT采样参数(无参考文本时不要太低。不懂就用默认):")
+
+            top_k = gr.Slider(
+                minimum=1,
+                maximum=100,
+                step=1,
+                label="top_k",
+                value=15
+            )
+
+            top_p = gr.Slider(
+                minimum=0,
+                maximum=1,
+                step=0.05,
+                label="top_p",
+                value=1
+            )
+
+            temperature = gr.Slider(
+                minimum=0,
+                maximum=1,
+                step=0.05,
+                label="temperature",
+                value=1
+            )
+
+    inference_button.click(
+        get_tts_wav,
+        [text, text_language, how_to_cut, top_k, top_p, temperature, speed, if_freeze],
+        [output],
+    )
+
+
+if __name__ == '__main__':
+    app.queue().launch(  # concurrency_count=511, max_size=1022
+        server_name="0.0.0.0",
+        inbrowser=False,
+        share=False,
+        server_port=9872,
+        quiet=True,
+    )