Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-04-05 19:41:56 +08:00)

commit e939bd61b2 (parent c4606c1cc1): custom modifications
@@ -285,10 +285,11 @@ class TTS:
     def init_cnhuhbert_weights(self, base_path: str):
         print(f"Loading CNHuBERT weights from {base_path}")
         self.cnhuhbert_model = CNHubert(base_path)
-        self.cnhuhbert_model=self.cnhuhbert_model.eval()
+        self.cnhuhbert_model = self.cnhuhbert_model.eval()
         self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
-        if self.configs.is_half and str(self.configs.device)!="cpu":
+        if self.configs.is_half and str(self.configs.device) != "cpu":
             self.cnhuhbert_model = self.cnhuhbert_model.half()
+        self.cnhuhbert_model = torch.compile(self.cnhuhbert_model)
@@ -296,10 +297,11 @@ class TTS:
         print(f"Loading BERT weights from {base_path}")
         self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
         self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
-        self.bert_model=self.bert_model.eval()
+        self.bert_model = self.bert_model.eval()
         self.bert_model = self.bert_model.to(self.configs.device)
-        if self.configs.is_half and str(self.configs.device)!="cpu":
+        if self.configs.is_half and str(self.configs.device) != "cpu":
             self.bert_model = self.bert_model.half()
+        self.bert_model = torch.compile(self.bert_model)

     def init_vits_weights(self, weights_path: str):
         print(f"Loading VITS weights from {weights_path}")
@@ -334,26 +336,28 @@ class TTS:
         vits_model = vits_model.to(self.configs.device)
         vits_model = vits_model.eval()
         vits_model.load_state_dict(dict_s2["weight"], strict=False)
+        vits_model = torch.compile(vits_model)
         self.vits_model = vits_model
-        if self.configs.is_half and str(self.configs.device)!="cpu":
+        if self.configs.is_half and str(self.configs.device) != "cpu":
             self.vits_model = self.vits_model.half()

     def init_t2s_weights(self, weights_path: str):
         print(f"Loading Text2Semantic weights from {weights_path}")
         self.configs.t2s_weights_path = weights_path
         self.configs.save_configs()
         self.configs.hz = 50
         dict_s1 = torch.load(weights_path, map_location=self.configs.device)
         config = dict_s1["config"]
         self.configs.max_sec = config["data"]["max_sec"]
         t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
         t2s_model.load_state_dict(dict_s1["weight"])
         t2s_model = t2s_model.to(self.configs.device)
         t2s_model = t2s_model.eval()
+        t2s_model = torch.compile(t2s_model)
         self.t2s_model = t2s_model
-        if self.configs.is_half and str(self.configs.device)!="cpu":
+        if self.configs.is_half and str(self.configs.device) != "cpu":
             self.t2s_model = self.t2s_model.half()

     def enable_half_precision(self, enable: bool = True, save: bool = True):
         '''
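The recurring addition across these hunks wraps each model with `torch.compile`, PyTorch 2.x's graph-capturing JIT; compilation happens lazily on the first forward call. A minimal standalone illustration of the pattern (toy module, not from this repo):

```python
import torch

model = torch.nn.Linear(4, 4).eval()
model = torch.compile(model)  # compiled lazily on the first call
with torch.no_grad():
    out = model(torch.randn(1, 4))
```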
@@ -1,11 +1,11 @@
 custom:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
-  device: cuda
-  is_half: true
-  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
+  device: cpu
+  is_half: false
+  t2s_weights_path: GPT_weights_v2/amiya-e50.ckpt
   version: v2
-  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
+  vits_weights_path: SoVITS_weights_v2/amiya_e25_s950.pth
 default:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
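A minimal sketch of how this YAML is consumed (mirroring the `TTS_Config`/`TTS` usage in `api_amiya.py` below; the path assumes the repo root as the working directory):

```python
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config

tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
tts_pipeline = TTS(tts_config)
```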
GPT_SoVITS/text/ja_userdic/user.dict  (new file; binary file not shown)
GPT_SoVITS/text/ja_userdic/userdict.md5  (new file, 1 line)

d36bd5ffba62f195d22bf4f1a41cd08f
api_amiya.py  (new file, 442 lines)
"""
# WebAPI documentation

` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml `

## Command-line arguments:
`-a` - `bind address, default "127.0.0.1"`
`-p` - `bind port, default 9880`
`-c` - `path to the TTS config file, default "GPT_SoVITS/configs/tts_infer.yaml"`

## Usage:

### Inference

endpoint: `/tts`
GET:
```
http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true
```

POST:
```json
{
    "input": "",                  # str. (required) text to be synthesized
    "text_lang": "",              # str. (required) language of the text to be synthesized
    "ref_audio_path": "",         # str. (required) reference audio path
    "text_split_method": "cut0",  # str. text split method, see text_segmentation_method.py for details.
    "speed_factor": 1.0,          # float. controls the speed of the synthesized audio.
    "seed": -1,                   # int. random seed for reproducibility.
    "parallel_infer": True,       # bool. whether to use parallel inference.
    "repetition_penalty": 1.35    # float. repetition penalty for the T2S model.
}
```

RESP:
success: returns the wav audio stream directly, HTTP code 200
failure: returns JSON with the error message, HTTP code 400
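A minimal Python client sketch for `/tts` (illustrative, not part of the upstream docs; assumes the server is running on the default address, `requests` is installed, and "output.wav" is a made-up filename):

```python
import requests

params = {
    "text": "text to be synthesized",
    "text_lang": "zh",
    "ref_audio_path": "archive_jingyuan_1.wav",  # example path from the GET sample above
    "prompt_lang": "zh",
    "text_split_method": "cut5",
    "media_type": "wav",
}
resp = requests.get("http://127.0.0.1:9880/tts", params=params)
if resp.status_code == 200:
    with open("output.wav", "wb") as f:
        f.write(resp.content)
else:
    print(resp.json())
```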

### Command control

endpoint: `/control`

command:
"restart": restart the service
"exit": stop the service

GET:
```
http://127.0.0.1:9880/control?command=restart
```
POST:
```json
{
    "command": "restart"
}
```

RESP: none


### Switch GPT model

endpoint: `/set_gpt_weights`

GET:
```
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
```
RESP:
success: returns "success", HTTP code 200
failure: returns JSON with the error message, HTTP code 400


### Switch SoVITS model

endpoint: `/set_sovits_weights`

GET:
```
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth
```

RESP:
success: returns "success", HTTP code 200
failure: returns JSON with the error message, HTTP code 400

"""
import os
import sys
import traceback
from typing import Generator

now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % now_dir)

import argparse
import subprocess
import wave
import signal
import numpy as np
import soundfile as sf
from fastapi import FastAPI, Response
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
from io import BytesIO
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
from pydantic import BaseModel

cut_method_names = get_cut_method_names()

parser = argparse.ArgumentParser(description="GPT-SoVITS api")
parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="path to the TTS config file")
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
parser.add_argument("-p", "--port", type=int, default=9880, help="default: 9880")
args = parser.parse_args()
config_path = args.tts_config
# device = args.device
port = args.port
host = args.bind_addr
argv = sys.argv

if config_path in [None, ""]:
    config_path = "GPT_SoVITS/configs/tts_infer.yaml"

tts_config = TTS_Config(config_path)
print(tts_config)
tts_pipeline = TTS(tts_config)

REFERENCE_AUDIO = "/Users/baysonfox/Desktop/amiya-chatbot/reference.mp3"
AUXILARY_AUDIO = [
    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_001_seg_0.mp3",
    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_002_seg_0.mp3",
    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_005_seg_0.mp3",
    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_007_seg_1.mp3",
    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_008_seg_0.mp3",
    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_009_seg_0.mp3",
    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_010_seg_0.mp3",
    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_011_seg_0.mp3",
    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_012_seg_0.mp3",
    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_028_seg_0.mp3",
    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_029_seg_0.mp3",
    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_032_seg_0.mp3",
    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_043_seg_0.mp3",
    "/Users/baysonfox/Desktop/amiya-chatbot/references/data_char_1037_amiya3_CN_044_seg_0.mp3"
]
APP = FastAPI()
class TTS_Request(BaseModel):
    input: str = None
    text_lang: str = "all_zh"
    ref_audio_path: str = REFERENCE_AUDIO
    aux_ref_audio_paths: list = AUXILARY_AUDIO
    prompt_lang: str = "all_zh"
    prompt_text: str = "博士,休息好了吗?还觉得累的话,不用勉强的。有我在呢。"
    top_k: int = 5
    top_p: float = 1
    temperature: float = 1
    text_split_method: str = "cut5"
    batch_size: int = 1
    batch_threshold: float = 0.75
    split_bucket: bool = True
    speed_factor: float = 1.0
    fragment_interval: float = 0.3
    seed: int = -1
    media_type: str = "wav"
    streaming_mode: bool = False
    parallel_infer: bool = True
    repetition_penalty: float = 1.35

### modified from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
    with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
        audio_file.write(data)
    return io_buffer


def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
    io_buffer.write(data.tobytes())
    return io_buffer


def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
    io_buffer = BytesIO()
    sf.write(io_buffer, data, rate, format='wav')
    return io_buffer


def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
    process = subprocess.Popen([
        'ffmpeg',
        '-f', 's16le',     # input: 16-bit signed little-endian PCM
        '-ar', str(rate),  # sample rate
        '-ac', '1',        # mono
        '-i', 'pipe:0',    # read input from stdin
        '-c:a', 'aac',     # encode audio as AAC
        '-b:a', '192k',    # bitrate
        '-vn',             # no video
        '-f', 'adts',      # output a raw AAC (ADTS) stream
        'pipe:1'           # write output to stdout
    ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, _ = process.communicate(input=data.tobytes())
    io_buffer.write(out)
    return io_buffer


def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
    if media_type == "ogg":
        io_buffer = pack_ogg(io_buffer, data, rate)
    elif media_type == "aac":
        io_buffer = pack_aac(io_buffer, data, rate)
    elif media_type == "wav":
        io_buffer = pack_wav(io_buffer, data, rate)
    else:
        io_buffer = pack_raw(io_buffer, data, rate)
    io_buffer.seek(0)
    return io_buffer


# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
    # This creates a wave header and appends the frame input.
    # It should come first in a streaming wav file;
    # later chunks should not include it (or you will hear artifacts at each chunk start).
    wav_buf = BytesIO()
    with wave.open(wav_buf, "wb") as vfout:
        vfout.setnchannels(channels)
        vfout.setsampwidth(sample_width)
        vfout.setframerate(sample_rate)
        vfout.writeframes(frame_input)

    wav_buf.seek(0)
    return wav_buf.read()
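
# Client-side sketch of consuming a stream built from the header above (illustrative,
# not part of the original file; "stream.wav" is a made-up name): the first chunk
# carries the WAV header, later chunks are raw PCM, so writing them in order yields
# a playable file.
#
#   import requests
#   with requests.get("http://127.0.0.1:9880/tts",
#                     params={"text": "...", "streaming_mode": True}, stream=True) as r:
#       with open("stream.wav", "wb") as f:
#           for chunk in r.iter_content(chunk_size=None):
#               f.write(chunk)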


def handle_control(command: str):
    if command == "restart":
        os.execl(sys.executable, sys.executable, *argv)
    elif command == "exit":
        os.kill(os.getpid(), signal.SIGTERM)
        exit(0)


def check_params(req: dict):
    text: str = req.get("text", "")
    text_lang: str = req.get("text_lang", "")
    ref_audio_path: str = req.get("ref_audio_path", "")
    streaming_mode: bool = req.get("streaming_mode", False)
    media_type: str = req.get("media_type", "wav")
    prompt_lang: str = req.get("prompt_lang", "")
    text_split_method: str = req.get("text_split_method", "cut5")

    if ref_audio_path in [None, ""]:
        return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
    if text in [None, ""]:
        return JSONResponse(status_code=400, content={"message": "text is required"})
    if text_lang in [None, ""]:
        return JSONResponse(status_code=400, content={"message": "text_lang is required"})
    elif text_lang.lower() not in tts_config.languages:
        return JSONResponse(status_code=400, content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"})
    if prompt_lang in [None, ""]:
        return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
    elif prompt_lang.lower() not in tts_config.languages:
        return JSONResponse(status_code=400, content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"})
    if media_type not in ["wav", "raw", "ogg", "aac"]:
        return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"})
    elif media_type == "ogg" and not streaming_mode:
        return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})

    if text_split_method not in cut_method_names:
        return JSONResponse(status_code=400, content={"message": f"text_split_method: {text_split_method} is not supported"})

    return None


async def tts_handle(req: dict):
    """
    Text to speech handler.

    Args:
        req (dict):
            {
                "text": "",                   # str. (required) text to be synthesized
                "text_lang": "",              # str. (required) language of the text to be synthesized
                "ref_audio_path": "",         # str. (required) reference audio path
                "aux_ref_audio_paths": [],    # list. (optional) auxiliary reference audio paths for multi-speaker synthesis
                "prompt_text": "",            # str. (optional) prompt text for the reference audio
                "prompt_lang": "",            # str. (required) language of the prompt text for the reference audio
                "top_k": 5,                   # int. top-k sampling
                "top_p": 1,                   # float. top-p sampling
                "temperature": 1,             # float. sampling temperature
                "text_split_method": "cut5",  # str. text split method, see text_segmentation_method.py for details.
                "batch_size": 1,              # int. batch size for inference
                "batch_threshold": 0.75,      # float. threshold for batch splitting.
                "split_bucket": True,         # bool. whether to split the batch into multiple buckets.
                "speed_factor": 1.0,          # float. controls the speed of the synthesized audio.
                "fragment_interval": 0.3,     # float. controls the interval between audio fragments.
                "seed": -1,                   # int. random seed for reproducibility.
                "media_type": "wav",          # str. media type of the output audio; supports "wav", "raw", "ogg", "aac".
                "streaming_mode": False,      # bool. whether to return a streaming response.
                "parallel_infer": True,       # bool. (optional) whether to use parallel inference.
                "repetition_penalty": 1.35    # float. (optional) repetition penalty for the T2S model.
            }
    returns:
        StreamingResponse: audio stream response.
    """
    # POST bodies carry the text in "input"; GET requests already set "text" directly
    req["text"] = req.get("input") or req.get("text", "")
    streaming_mode = req.get("streaming_mode", False)
    return_fragment = req.get("return_fragment", False)
    media_type = req.get("media_type", "wav")

    check_res = check_params(req)
    if check_res is not None:
        return check_res

    if streaming_mode or return_fragment:
        req["return_fragment"] = True

    try:
        tts_generator = tts_pipeline.run(req)

        if streaming_mode:
            def streaming_generator(tts_generator: Generator, media_type: str):
                if media_type == "wav":
                    # send the WAV header once, then stream the remaining chunks as raw PCM
                    yield wave_header_chunk()
                    media_type = "raw"
                for sr, chunk in tts_generator:
                    yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
            # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
            return StreamingResponse(streaming_generator(tts_generator, media_type), media_type=f"audio/{media_type}")

        else:
            sr, audio_data = next(tts_generator)
            audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
            return Response(audio_data, media_type=f"audio/{media_type}")
    except Exception as e:
        return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)})


@APP.get("/control")
async def control(command: str = None):
    if command is None:
        return JSONResponse(status_code=400, content={"message": "command is required"})
    handle_control(command)


@APP.get("/tts")
async def tts_get_endpoint(
    text: str = None,
    text_lang: str = "all_zh",
    ref_audio_path: str = REFERENCE_AUDIO,
    aux_ref_audio_paths: list = AUXILARY_AUDIO,
    prompt_lang: str = "all_zh",
    prompt_text: str = "博士,休息好了吗?还觉得累的话,不用勉强的。有我在呢。",
    top_k: int = 5,
    top_p: float = 1,
    temperature: float = 1,
    text_split_method: str = "cut0",
    batch_size: int = 1,
    batch_threshold: float = 0.75,
    split_bucket: bool = True,
    speed_factor: float = 1.0,
    fragment_interval: float = 0.3,
    seed: int = -1,
    media_type: str = "wav",
    streaming_mode: bool = False,
    parallel_infer: bool = True,
    repetition_penalty: float = 1.35
):
    req = {
        "text": text,
        "text_lang": text_lang.lower(),
        "ref_audio_path": REFERENCE_AUDIO,
        "aux_ref_audio_paths": AUXILARY_AUDIO,
        "prompt_text": prompt_text,
        "prompt_lang": prompt_lang.lower(),
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "text_split_method": text_split_method,
        "batch_size": int(batch_size),
        "batch_threshold": float(batch_threshold),
        "speed_factor": float(speed_factor),
        "split_bucket": split_bucket,
        "fragment_interval": fragment_interval,
        "seed": seed,
        "media_type": media_type,
        "streaming_mode": streaming_mode,
        "parallel_infer": parallel_infer,
        "repetition_penalty": float(repetition_penalty)
    }
    return await tts_handle(req)


@APP.post("/tts")
async def tts_post_endpoint(request: TTS_Request):
    req = request.model_dump()
    return await tts_handle(req)


@APP.get("/set_refer_audio")
async def set_refer_audio(refer_audio_path: str = None):
    try:
        tts_pipeline.set_ref_audio(refer_audio_path)
    except Exception as e:
        return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)})
    return JSONResponse(status_code=200, content={"message": "success"})

@APP.get("/set_gpt_weights")
async def set_gpt_weights(weights_path: str = None):
    try:
        if weights_path in ["", None]:
            return JSONResponse(status_code=400, content={"message": "gpt weight path is required"})
        tts_pipeline.init_t2s_weights(weights_path)
    except Exception as e:
        return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)})

    return JSONResponse(status_code=200, content={"message": "success"})


@APP.get("/set_sovits_weights")
async def set_sovits_weights(weights_path: str = None):
    try:
        if weights_path in ["", None]:
            return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})
        tts_pipeline.init_vits_weights(weights_path)
    except Exception as e:
        return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)})
    return JSONResponse(status_code=200, content={"message": "success"})


if __name__ == "__main__":
    try:
        if host == 'None':  # passing "-a None" makes the API listen on both IPv4 and IPv6 (dual stack)
            host = None
        uvicorn.run(app=APP, host=host, port=port, workers=1)
    except Exception:
        traceback.print_exc()
        os.kill(os.getpid(), signal.SIGTERM)
        exit(0)

webui_shared.py  (new file, 630 lines)
#!/usr/bin/env python
# coding=utf-8
import gradio as gr
import numpy as np
from transformers import AutoModelForMaskedLM, AutoTokenizer
import librosa
from feature_extractor import cnhubert
from GPT_SoVITS.module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import chinese, cleaned_text_to_sequence
from text.cleaner import clean_text
from text.LangSegmenter import LangSegmenter
from time import time as ttime
from module.mel_processing import spectrogram_torch, spec_to_mel_torch
from tools.my_utils import load_audio
import torch, torchaudio
import traceback
import os, re

# model paths
gpt_path = 'GPT_weights_v2/amiya-e50.ckpt'
sovits_path = 'SoVITS_weights_v2/amiya_e25_s950.pth'
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
cnhubert.cnhubert_base_path = cnhubert_base_path

# reference audio configuration
REFERENCE_AUDIO = "/Users/baysonfox/Desktop/amiya-chatbot/reference.mp3"
REFERENCE_TEXT = "博士,休息好了吗?还觉得累的话,不用勉强的。有我在呢。"
INF_REFS = [os.path.join("/Users/baysonfox/Desktop/amiya-chatbot/references", f) for f in os.listdir("/Users/baysonfox/Desktop/amiya-chatbot/references")]

# model settings
DEVICE = 'cpu'
dict_language = {
    "中文": "all_zh",      # treat everything as Chinese
    "英文": "en",          # treat everything as English
    "日文": "all_ja",      # treat everything as Japanese
    "中英混合": "zh",      # mixed Chinese/English
    "日英混合": "ja",      # mixed Japanese/English
    "多语种混合": "auto",  # multilingual: detect the language per segment
}
os.environ["TOKENIZERS_PARALLELISM"] = "False"

DTYPE = torch.float32
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
ssl_model = cnhubert.get_model()

# accelerated inference
tokenizer = torch.compile(tokenizer)
bert_model = torch.compile(bert_model)
ssl_model = torch.compile(ssl_model)

# punctuation
PUNCTUATION = {'!', '?', '…', ',', '.', '-', " "}
# sentence-splitting punctuation (Chinese and English)
SPLITS = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }

class DictToAttrRecursive(dict):
    def __init__(self, input_dict):
        super().__init__(input_dict)
        for key, value in input_dict.items():
            if isinstance(value, dict):
                value = DictToAttrRecursive(value)
            self[key] = value
            setattr(self, key, value)

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")

    def __setattr__(self, key, value):
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super(DictToAttrRecursive, self).__setitem__(key, value)
        super().__setattr__(key, value)

    def __delattr__(self, item):
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")
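
# Illustrative use of DictToAttrRecursive (hypothetical values, not part of the
# original file): nested dicts become attribute-accessible recursively, which is
# how the s2 checkpoint's "config" dict is consumed below as `hps.data.*` etc.
#
#   cfg = DictToAttrRecursive({"data": {"sampling_rate": 32000}})
#   assert cfg.data.sampling_rate == 32000
#   assert cfg["data"]["sampling_rate"] == 32000  # plain dict access still works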


def get_bert_feature(text, word2ph):
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        for i in inputs:
            inputs[i] = inputs[i].to(DEVICE)
        res = bert_model(**inputs, output_hidden_states=True)
        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
    assert len(word2ph) == len(text)
    phone_level_feature = []
    for i in range(len(word2ph)):
        repeat_feature = res[i].repeat(word2ph[i], 1)
        phone_level_feature.append(repeat_feature)
    phone_level_feature = torch.cat(phone_level_feature, dim=0)
    return phone_level_feature.T


resample_transform_dict = {}
def resample(audio_tensor, sr0):
    global resample_transform_dict
    if sr0 not in resample_transform_dict:
        resample_transform_dict[sr0] = torchaudio.transforms.Resample(
            sr0, 24000
        ).to(DEVICE)
    return resample_transform_dict[sr0](audio_tensor)

def change_sovits_weights(prompt_language=None, text_language=None):
    global vq_model, hps, version, model_version, dict_language
    model_version = version = "v2"

    if prompt_language is not None and text_language is not None:
        prompt_text_update, prompt_language_update = {'__type__': 'update'}, {'__type__': 'update', 'value': "all_zh"}

        if text_language in list(dict_language.keys()):
            text_update, text_language_update = {'__type__': 'update'}, {'__type__': 'update', 'value': text_language}
        else:
            text_update = {'__type__': 'update', 'value': ''}
            text_language_update = {'__type__': 'update', 'value': "中文"}
        yield (
            {'__type__': 'update', 'choices': list(dict_language.keys())},
            {'__type__': 'update', 'choices': list(dict_language.keys())},
            prompt_text_update,
            prompt_language_update,
            text_update,
            text_language_update,
            {"__type__": "update", "visible": False},
            {"__type__": "update", "visible": False},
            {"__type__": "update", "value": False, "interactive": False},
        )

    dict_s2 = torch.load(sovits_path, map_location="cpu")

    hps = DictToAttrRecursive(dict_s2["config"])

    hps.model.semantic_frame_rate = "25hz"
    hps.model.version = "v2"
    version = hps.model.version
    vq_model = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model
    )
    model_version = version

    vq_model = vq_model.to(DEVICE)
    print("loading sovits_%s" % model_version, vq_model.load_state_dict(
        dict_s2["weight"],
        strict=False))
    vq_model.eval()
    vq_model = torch.compile(vq_model)


def change_gpt_weights(gpt_path):
    global hz, max_sec, t2s_model, config
    hz = 50
    dict_s1 = torch.load(gpt_path, map_location="cpu")
    config = dict_s1["config"]
    max_sec = config["data"]["max_sec"]
    t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
    t2s_model.load_state_dict(dict_s1["weight"])
    t2s_model = t2s_model.to(DEVICE)
    t2s_model.eval()
    t2s_model = torch.compile(t2s_model)


def get_spepc(hps, filename):
    print("hps samplingrate", hps.data.sampling_rate)
    audio = load_audio(filename, int(hps.data.sampling_rate))
    audio = torch.FloatTensor(audio)
    maxx = audio.abs().max()
    if maxx > 1:
        audio /= min(2, maxx.item())
    audio_norm = audio
    audio_norm = audio_norm.unsqueeze(0)
    spec = spectrogram_torch(
        audio_norm,
        hps.data.filter_length,
        hps.data.sampling_rate,
        hps.data.hop_length,
        hps.data.win_length,
        center=False,
    )
    return spec


def clean_text_inf(text, language, version):
    phones, word2ph, norm_text = clean_text(text, language, version)
    print("phones: ", phones)
    print("word2ph: ", word2ph)
    print("norm_text: ", norm_text)
    phones = cleaned_text_to_sequence(phones, version)
    return phones, word2ph, norm_text


def get_bert_inf(phones, word2ph, norm_text, language):
    language = language.replace("all_", "")
    if language == "zh":
        bert = get_bert_feature(norm_text, word2ph).to(DEVICE)
    else:
        bert = torch.zeros(
            (1024, len(phones)),
            dtype=torch.float32
        ).to(DEVICE)

    return bert


def get_first(text):
    pattern = "[" + "".join(re.escape(sep) for sep in SPLITS) + "]"
    text = re.split(pattern, text)[0].strip()
    return text


def get_phones_and_bert(text, language, version, final=False):
    if language in {"en", "all_zh", "all_ja"}:
        language = language.replace("all_", "")
        formattext = text
        while "  " in formattext:
            formattext = formattext.replace("  ", " ")
        if language == "zh":
            if re.search(r'[A-Za-z]', formattext):
                formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
                formattext = chinese.mix_text_normalize(formattext)
                return get_phones_and_bert(formattext, "zh", version)
            else:
                phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
                bert = get_bert_feature(norm_text, word2ph).to(DEVICE)
        elif language == "yue" and re.search(r'[A-Za-z]', formattext):
            formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
            formattext = chinese.mix_text_normalize(formattext)
            return get_phones_and_bert(formattext, "yue", version)
        else:
            phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
            bert = torch.zeros(
                (1024, len(phones)),
                dtype=torch.float32
            ).to(DEVICE)
    elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
        textlist = []
        langlist = []
        if language == "auto":
            for tmp in LangSegmenter.getTexts(text):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        else:
            for tmp in LangSegmenter.getTexts(text):
                if tmp["lang"] == "en":
                    langlist.append(tmp["lang"])
                else:
                    # CJK han characters cannot be told apart, so trust the user's language choice
                    langlist.append(language)
                textlist.append(tmp["text"])
        print(textlist)
        print(langlist)
        phones_list = []
        bert_list = []
        norm_text_list = []
        for i in range(len(textlist)):
            lang = langlist[i]
            phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
            bert = get_bert_inf(phones, word2ph, norm_text, lang)
            phones_list.append(phones)
            norm_text_list.append(norm_text)
            bert_list.append(bert)
        bert = torch.cat(bert_list, dim=1)
        phones = sum(phones_list, [])
        norm_text = ''.join(norm_text_list)

    if not final and len(phones) < 6:
        return get_phones_and_bert("." + text, language, version, final=True)

    return phones, bert.to(DTYPE), norm_text

def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center)
    mel = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax)
    return mel

mel_fn_args = {
    "n_fft": 1024,
    "win_size": 1024,
    "hop_size": 256,
    "num_mels": 100,
    "sampling_rate": 24000,
    "fmin": 0,
    "fmax": None,
    "center": False
}

spec_min = -12
spec_max = 2
def norm_spec(x):
    return (x - spec_min) / (spec_max - spec_min) * 2 - 1
def denorm_spec(x):
    return (x + 1) / 2 * (spec_max - spec_min) + spec_min
mel_fn = lambda x: mel_spectrogram(x, **mel_fn_args)
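
# Sanity-check sketch for the (de)normalisation pair above (illustrative, not part
# of the original file): norm_spec maps [spec_min, spec_max] linearly onto [-1, 1],
# and denorm_spec inverts it.
#
#   assert norm_spec(spec_min) == -1.0 and norm_spec(spec_max) == 1.0
#   assert denorm_spec(norm_spec(-5.0)) == -5.0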


def merge_short_text_in_array(texts, threshold):
    if len(texts) < 2:
        return texts
    result = []
    text = ""
    for ele in texts:
        text += ele
        if len(text) >= threshold:
            result.append(text)
            text = ""
    if len(text) > 0:
        if len(result) == 0:
            result.append(text)
        else:
            result[len(result) - 1] += text
    return result


def split(todo_text):
    todo_text = todo_text.replace("……", "。").replace("——", ",")
    if todo_text[-1] not in SPLITS:
        todo_text += "。"
    i_split_head = i_split_tail = 0
    len_text = len(todo_text)
    todo_texts = []
    while 1:
        if i_split_head >= len_text:
            break  # the text is guaranteed to end with punctuation, so the last segment was already appended
        if todo_text[i_split_head] in SPLITS:
            i_split_head += 1
            todo_texts.append(todo_text[i_split_tail:i_split_head])
            i_split_tail = i_split_head
        else:
            i_split_head += 1
    return todo_texts


def cut1(inp):
    inp = inp.strip("\n")
    inps = split(inp)
    split_idx = list(range(0, len(inps), 4))
    split_idx[-1] = None
    if len(split_idx) > 1:
        opts = []
        for idx in range(len(split_idx) - 1):
            opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
    else:
        opts = [inp]
    opts = [item for item in opts if not set(item).issubset(PUNCTUATION)]
    return "\n".join(opts)


def cut2(inp):
    inp = inp.strip("\n")
    inps = split(inp)
    if len(inps) < 2:
        return inp
    opts = []
    summ = 0
    tmp_str = ""
    for i in range(len(inps)):
        summ += len(inps[i])
        tmp_str += inps[i]
        if summ > 50:
            summ = 0
            opts.append(tmp_str)
            tmp_str = ""
    if tmp_str != "":
        opts.append(tmp_str)
    # print(opts)
    if len(opts) > 1 and len(opts[-1]) < 50:  # if the last chunk is too short, merge it into the previous one
        opts[-2] = opts[-2] + opts[-1]
        opts = opts[:-1]
    opts = [item for item in opts if not set(item).issubset(PUNCTUATION)]
    return "\n".join(opts)


def cut3(inp):
    inp = inp.strip("\n")
    opts = ["%s" % item for item in inp.strip("。").split("。")]
    opts = [item for item in opts if not set(item).issubset(PUNCTUATION)]
    return "\n".join(opts)


def cut4(inp):
    inp = inp.strip("\n")
    opts = re.split(r'(?<!\d)\.(?!\d)', inp.strip("."))
    opts = [item for item in opts if not set(item).issubset(PUNCTUATION)]
    return "\n".join(opts)


def cut5(inp):
    inp = inp.strip("\n")
    punds = {',', '.', ';', '?', '!', '、', ',', '。', '?', '!', ';', ':', '…'}
    mergeitems = []
    items = []

    for i, char in enumerate(inp):
        if char in punds:
            if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
                items.append(char)  # keep decimal points inside numbers
            else:
                items.append(char)
                mergeitems.append("".join(items))
                items = []
        else:
            items.append(char)

    if items:
        mergeitems.append("".join(items))

    opt = [item for item in mergeitems if not set(item).issubset(punds)]
    return "\n".join(opt)
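
# Illustrative behaviour of cut5 (hypothetical input, not part of the original file):
# punctuation-delimited segments are emitted one per line, while decimal points
# inside numbers are left intact:
#
#   cut5("你好,世界。圆周率约为3.14。")  ->  "你好,\n世界。\n圆周率约为3.14。"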


def custom_sort_key(s):
    # split the string into digit and non-digit runs
    parts = re.split(r'(\d+)', s)
    # convert the digit runs to integers so they compare numerically
    parts = [int(part) if part.isdigit() else part for part in parts]
    return parts


def process_text(texts):
    _text = []
    if all(text in [None, " ", "\n", ""] for text in texts):
        raise ValueError("请输入有效文本")
    for text in texts:
        if text in [None, " ", ""]:
            pass
        else:
            _text.append(text)
    return _text


cache = {}
def get_tts_wav(text,
                text_language,
                how_to_cut="不切",
                top_k=20,
                top_p=0.6,
                temperature=0.6,
                speed=1,
                if_freeze=False,
                sample_steps=8):

    global cache
    prompt_text = REFERENCE_TEXT
    prompt_language = "中文"
    ref_wav_path = REFERENCE_AUDIO
    inp_refs = INF_REFS
    t = []
    t0 = ttime()
    prompt_language = dict_language[prompt_language]
    text_language = dict_language[text_language]

    # strip newlines from prompt_text
    prompt_text = prompt_text.strip("\n")
    # append trailing punctuation if it is missing
    if prompt_text[-1] not in SPLITS:
        prompt_text += "。" if prompt_language != "en" else "."
    print("实际输入的参考文本:", prompt_text)
    text = text.strip("\n")

    print("实际输入的目标文本:", text)
    zero_wav = np.zeros(
        int(hps.data.sampling_rate * 0.3),
        dtype=np.float32
    )

    with torch.no_grad():
        # reference audio and its sampling rate, as numpy
        wav16k, sr = librosa.load(ref_wav_path, sr=16000)
        # numpy -> torch
        wav16k = torch.from_numpy(wav16k)
        zero_wav_torch = torch.from_numpy(zero_wav)
        wav16k = wav16k.to(DEVICE)
        zero_wav_torch = zero_wav_torch.to(DEVICE)
        wav16k = torch.cat([wav16k, zero_wav_torch])
        ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
            "last_hidden_state"
        ].transpose(
            1, 2
        )
        codes = vq_model.extract_latent(ssl_content)
        prompt_semantic = codes[0, 0]
        prompt = prompt_semantic.unsqueeze(0).to(DEVICE)

    t1 = ttime()
    t.append(t1 - t0)

    if how_to_cut != "不切":
        cut_map = {
            "凑四句一切": cut1,
            "凑50字一切": cut2,
            "按中文句号。切": cut3,
            "按英文句号.切": cut4,
            "按标点符号切": cut5
        }
        text = cut_map[how_to_cut](text)  # the cut functions return the re-split text

    while "\n\n" in text:
        text = text.replace("\n\n", "\n")

    print("实际输入的目标文本(切句后):", text)

    texts = text.split("\n")
    texts = process_text(texts)
    texts = merge_short_text_in_array(texts, 5)
    audio_opt = []
    phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)

    for i_text, text in enumerate(texts):
        # skip empty lines in the target text to avoid errors
        if len(text.strip()) == 0:
            continue
        if text[-1] not in SPLITS:
            text += "。" if text_language != "en" else "."
        print("实际输入的目标文本(每句):", text)
        phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
        print("前端处理后的文本(每句):", norm_text2)
        bert = torch.cat([bert1, bert2], 1)
        all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(DEVICE).unsqueeze(0)

        bert = bert.to(DEVICE).unsqueeze(0)
        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(DEVICE)

        t2 = ttime()
        if i_text in cache and if_freeze:
            pred_semantic = cache[i_text]
        else:
            with torch.no_grad():
                pred_semantic, idx = t2s_model.model.infer_panel(
                    all_phoneme_ids,
                    all_phoneme_len,
                    prompt,
                    bert,
                    # prompt_phone_len=ph_offset,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    early_stop_num=hz * max_sec,
                )
                pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
                cache[i_text] = pred_semantic
        t3 = ttime()
        refers = []
        if inp_refs:
            for path in inp_refs:
                try:
                    refer = get_spepc(hps, path).to(DTYPE).to(DEVICE)
                    refers.append(refer)
                except:
                    traceback.print_exc()
        if len(refers) == 0:
            refers = [get_spepc(hps, ref_wav_path).to(DTYPE).to(DEVICE)]
        audio = (
            vq_model.decode(
                pred_semantic,
                torch.LongTensor(phones2).to(DEVICE).unsqueeze(0),
                refers,
                speed=speed,
            ).detach().cpu().numpy()[0, 0]
        )

        audio_opt.append(audio)
        audio_opt.append(zero_wav)
        t4 = ttime()
        t.extend([t2 - t1, t3 - t2, t4 - t3])
        t1 = ttime()
    print("%.3f\t%.3f\t%.3f\t%.3f" %
          (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3]))
          )
    sr = hps.data.sampling_rate if model_version != "v3" else 24000
    yield sr, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)


with gr.Blocks(title="GPT-SoVITS WebUI") as app:

    try:
        next(change_sovits_weights(sovits_path))
    except:
        pass
    change_gpt_weights(gpt_path)  # initialize the GPT model

    gr.Markdown(value="<h1 style='color: #2b5278; font-size: 24px; text-align: center;'>大概可能也许是阿米娅的声音(</h1>")
    with gr.Row() as main_row:
        with gr.Column(scale=7) as text_column:
            text = gr.Textbox(
                label="需要合成的文本",
                value="",
                lines=13,
                max_lines=13
            )

            with gr.Row():
                inference_button = gr.Button("合成语音", variant="primary", size='lg')
                output = gr.Audio(label="输出的语音")

        with gr.Column(scale=5) as control_column:
            text_language = gr.Dropdown(
                label="需要合成的语种。限制范围越小判别效果越好。",
                choices=list(dict_language.keys()),
                value="中文"
            )

            how_to_cut = gr.Dropdown(
                label="怎么切",
                choices=["不切", "凑四句一切", "凑50字一切", "按中文句号。切", "按英文句号.切", "按标点符号切"],
                value="凑四句一切"
            )

            gr.Markdown(value="语速调整,高为更快")

            if_freeze = gr.Checkbox(
                label="是否直接对上次合成结果调整语速和音色。防止随机性。",
                value=False
            )

            speed = gr.Slider(
                minimum=0.6,
                maximum=1.65,
                step=0.05,
                label="语速",
                value=1
            )

            gr.Markdown("GPT采样参数(无参考文本时不要太低。不懂就用默认):")

            top_k = gr.Slider(
                minimum=1,
                maximum=100,
                step=1,
                label="top_k",
                value=15
            )

            top_p = gr.Slider(
                minimum=0,
                maximum=1,
                step=0.05,
                label="top_p",
                value=1
            )

            temperature = gr.Slider(
                minimum=0,
                maximum=1,
                step=0.05,
                label="temperature",
                value=1
            )

    inference_button.click(
        get_tts_wav,
        [text, text_language, how_to_cut, top_k, top_p, temperature, speed, if_freeze],
        [output],
    )


if __name__ == '__main__':
    app.queue().launch(  # concurrency_count=511, max_size=1022
        server_name="0.0.0.0",
        inbrowser=False,
        share=False,
        server_port=9872,
        quiet=True,
    )