API fixes and optimizations (#1503)

* model control

* Mix timbre

* Fix several small issues

* Optimize details

* Add int32

* Add example

* Add aac pcm32 support

KamioRinn 2024-08-20 11:47:24 +08:00 committed by GitHub
parent b3e8eb40c2
commit 6ca4aecea2

api.py
@@ -20,6 +20,7 @@
`-hp` - `Override config.py and use half precision`
`-sm` - `Streaming mode, disabled by default: "close","c", "normal","n", "keepalive","k"`
`-mt` - `Returned audio codec; streaming defaults to ogg, non-streaming to wav: "wav", "ogg", "aac"`
`-st` - `Returned audio sample type, default int16: "int16", "int32"`
`-cp` - `Text splitting punctuation, empty by default, passed as a string such as ",.,。"`
`-hb` - `cnhubert path`
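For example, a server using the new options might be launched with `python api.py -sm normal -mt aac -st int32` (flag values taken from the list above; adjust as needed).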
@@ -74,7 +75,7 @@ RESP:
Manually specify the reference audio used for this inference and pass parameters:
GET:
`http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三&prompt_language=zh&text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1`
`http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三&prompt_language=zh&text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&inp_refs="456.wav"&inp_refs="789.wav"`
POST:
```json
{
@@ -86,7 +87,8 @@ POST:
"top_k": 20,
"top_p": 0.6,
"temperature": 0.6,
"speed": 1
"speed": 1,
"inp_refs": ["456.wav","789.wav"]
}
```
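A minimal Python client sketch for the POST form above, assuming the server is running on the default 127.0.0.1:9880 and the wav paths exist on the server side (they are the placeholder names from the example):

```python
import requests

# Sketch of the POST request shown above; all paths are example placeholders.
payload = {
    "refer_wav_path": "123.wav",
    "prompt_text": "一二三",
    "prompt_language": "zh",
    "text": "先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也",
    "text_language": "zh",
    "top_k": 20,
    "top_p": 0.6,
    "temperature": 0.6,
    "speed": 1,
    "inp_refs": ["456.wav", "789.wav"],
}
resp = requests.post("http://127.0.0.1:9880", json=payload)
with open("output.wav", "wb") as f:
    f.write(resp.content)
```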
@@ -153,7 +155,7 @@ from time import time as ttime
import torch
import librosa
import soundfile as sf
from fastapi import FastAPI, Request, HTTPException
from fastapi import FastAPI, Request, Query, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
from transformers import AutoModelForMaskedLM, AutoTokenizer
@@ -195,8 +197,24 @@ def is_full(*items): # returns False if any item is empty
return True
def change_sovits_weights(sovits_path):
global vq_model, hps
class Speaker:
def __init__(self, name, gpt, sovits, phones = None, bert = None, prompt = None):
self.name = name
self.sovits = sovits
self.gpt = gpt
self.phones = phones
self.bert = bert
self.prompt = prompt
speaker_list = {}
class Sovits:
def __init__(self, vq_model, hps):
self.vq_model = vq_model
self.hps = hps
def get_sovits_weights(sovits_path):
dict_s2 = torch.load(sovits_path, map_location="cpu")
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
@@ -205,7 +223,7 @@ def change_sovits_weights(sovits_path):
hps.model.version = "v1"
else:
hps.model.version = "v2"
print("sovits版本:",hps.model.version)
logger.info(f"模型版本: {hps.model.version}")
model_params_dict = vars(hps.model)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
@@ -222,10 +240,17 @@
vq_model.eval()
vq_model.load_state_dict(dict_s2["weight"], strict=False)
sovits = Sovits(vq_model, hps)
return sovits
def change_gpt_weights(gpt_path):
global hz, max_sec, t2s_model, config
hz = 50
class Gpt:
def __init__(self, max_sec, t2s_model):
self.max_sec = max_sec
self.t2s_model = t2s_model
global hz
hz = 50
def get_gpt_weights(gpt_path):
dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"]
max_sec = config["data"]["max_sec"]
@@ -238,6 +263,19 @@ def change_gpt_weights(gpt_path):
total = sum([param.nelement() for param in t2s_model.parameters()])
logger.info("Number of parameter: %.2fM" % (total / 1e6))
gpt = Gpt(max_sec, t2s_model)
return gpt
def change_gpt_sovits_weights(gpt_path,sovits_path):
try:
gpt = get_gpt_weights(gpt_path)
sovits = get_sovits_weights(sovits_path)
except Exception as e:
return JSONResponse({"code": 400, "message": str(e)}, status_code=400)
speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits)
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
def get_bert_feature(text, word2ph):
with torch.no_grad():
@@ -289,14 +327,14 @@ def get_phones_and_bert(text,language,version,final=False):
if language == "zh":
if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.text_normalize(formattext)
formattext = chinese.mix_text_normalize(formattext)
return get_phones_and_bert(formattext,"zh",version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = get_bert_feature(norm_text, word2ph).to(device)
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.text_normalize(formattext)
formattext = chinese.mix_text_normalize(formattext)
return get_phones_and_bert(formattext,"yue",version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
@@ -375,8 +413,11 @@ class DictToAttrRecursive(dict):
def get_spepc(hps, filename):
audio = load_audio(filename, int(hps.data.sampling_rate))
audio,_ = librosa.load(filename, int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio)
maxx=audio.abs().max()
if(maxx>1):
audio/=min(2,maxx)
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length,
@@ -448,22 +489,32 @@ def pack_raw(audio_bytes, data, rate):
def pack_wav(audio_bytes, rate):
data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16)
wav_bytes = BytesIO()
sf.write(wav_bytes, data, rate, format='wav')
if is_int32:
data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int32)
wav_bytes = BytesIO()
sf.write(wav_bytes, data, rate, format='WAV', subtype='PCM_32')
else:
data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16)
wav_bytes = BytesIO()
sf.write(wav_bytes, data, rate, format='WAV')
return wav_bytes
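The is_int32 switch also changes how the float waveform is quantized later in get_tts_wav: samples in [-1, 1] are scaled by 32768 for int16 and by 2147483647 for int32. A minimal sketch of that relationship:

```python
import numpy as np

# Float audio in [-1.0, 1.0] quantized to the two supported PCM widths.
audio = np.array([0.0, 0.5, -1.0], dtype=np.float32)
pcm16 = (audio * 32768).astype(np.int16)       # int16 range: -32768..32767
pcm32 = (audio * 2147483647).astype(np.int32)  # int32 range: -2**31..2**31-1
```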
def pack_aac(audio_bytes, data, rate):
if is_int32:
pcm = 's32le'
bit_rate = '256k'
else:
pcm = 's16le'
bit_rate = '128k'
process = subprocess.Popen([
'ffmpeg',
'-f', 's16le', # input: 16-bit signed little-endian PCM
'-f', pcm, # input PCM format (s16le or s32le)
'-ar', str(rate), # sample rate
'-ac', '1', # mono
'-i', 'pipe:0', # read input from the pipe
'-c:a', 'aac', # encode audio as AAC
'-b:a', '192k', # bitrate
'-b:a', bit_rate, # bitrate
'-vn', # no video
'-f', 'adts', # output an ADTS AAC stream
'pipe:1' # write output to the pipe
@@ -504,10 +555,21 @@ def only_punc(text):
return not any(t.isalnum() or t.isalpha() for t in text)
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 20, top_p = 0.6, temperature = 0.6, speed = 1):
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, spk = "default"):
infer_sovits = speaker_list[spk].sovits
vq_model = infer_sovits.vq_model
hps = infer_sovits.hps
infer_gpt = speaker_list[spk].gpt
t2s_model = infer_gpt.t2s_model
max_sec = infer_gpt.max_sec
t0 = ttime()
prompt_text = prompt_text.strip("\n")
if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
prompt_language, text = prompt_language, text.strip("\n")
dtype = torch.float16 if is_half == True else torch.float32
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
@@ -523,6 +585,19 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device)
refers=[]
if(inp_refs):
for path in inp_refs:
try:
refer = get_spepc(hps, path).to(dtype).to(device)
refers.append(refer)
except Exception as e:
logger.error(e)
if(len(refers)==0):
refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
t1 = ttime()
version = vq_model.version
os.environ['version'] = version
@@ -538,16 +613,15 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
continue
audio_opt = []
if (text[-1] not in splits): text += "。" if text_language != "en" else "."
phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
t2 = ttime()
with torch.no_grad():
# pred_semantic = t2s_model.model.infer(
pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
@@ -558,23 +632,22 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
top_p = top_p,
temperature = temperature,
early_stop_num=hz * max_sec)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)# mq needs one extra unsqueeze
refer = get_spepc(hps, ref_wav_path) # .to(device)
if (is_half == True):
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = \
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
refer,speed=speed).detach().cpu().numpy()[
refers,speed=speed).detach().cpu().numpy()[
0, 0] ### try reconstructing without the prompt part
max_audio=np.abs(audio).max()
if max_audio>1:
audio/=max_audio
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate)
if is_int32:
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 2147483647).astype(np.int32),hps.data.sampling_rate)
else:
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate)
# logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
if stream_mode == "normal":
audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
@@ -615,7 +688,7 @@ def handle_change(path, text, language):
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed):
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs):
if (
refer_wav_path == "" or refer_wav_path is None
or prompt_text == "" or prompt_text is None
@@ -634,7 +707,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
else:
text = cut_text(text,cut_punc)
return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed), media_type="audio/"+media_type)
return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs), media_type="audio/"+media_type)
@@ -691,6 +764,7 @@ parser.add_argument("-hp", "--half_precision", action="store_true", default=Fals
# at this point full_precision==True, half_precision==False
parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="Streaming mode, close / normal / keepalive")
parser.add_argument("-mt", "--media_type", type=str, default="wav", help="Audio codec, wav / ogg / aac")
parser.add_argument("-st", "--sub_type", type=str, default="int16", help="Audio sample type, int16 / int32")
parser.add_argument("-cp", "--cut_punc", type=str, default="", help="Text splitting punctuation, allowed symbols ,.;?!、,。?!;:…")
# e.g. split on common sentence punctuation: `python ./api.py -cp ".?!。?!"`
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="Override config.cnhubert_path")
@@ -752,6 +826,14 @@ else:
media_type = "ogg"
logger.info(f"编码格式: {media_type}")
# 音频数据类型
if args.sub_type.lower() == 'int32':
is_int32 = True
logger.info(f"数据类型: int32")
else:
is_int32 = False
logger.info(f"数据类型: int16")
# 初始化模型
cnhubert.cnhubert_base_path = cnhubert_base_path
tokenizer = AutoTokenizer.from_pretrained(bert_path)
@@ -763,9 +845,7 @@ if is_half:
else:
bert_model = bert_model.to(device)
ssl_model = ssl_model.to(device)
change_sovits_weights(sovits_path)
change_gpt_weights(gpt_path)
change_gpt_sovits_weights(gpt_path = gpt_path, sovits_path = sovits_path)
@@ -777,14 +857,18 @@ app = FastAPI()
@app.post("/set_model")
async def set_model(request: Request):
json_post_raw = await request.json()
global gpt_path
gpt_path=json_post_raw.get("gpt_model_path")
global sovits_path
sovits_path=json_post_raw.get("sovits_model_path")
logger.info("gptpath"+gpt_path+";vitspath"+sovits_path)
change_sovits_weights(sovits_path)
change_gpt_weights(gpt_path)
return "ok"
return change_gpt_sovits_weights(
gpt_path = json_post_raw.get("gpt_model_path"),
sovits_path = json_post_raw.get("sovits_model_path")
)
@app.get("/set_model")
async def set_model(
gpt_model_path: str = None,
sovits_model_path: str = None,
):
return change_gpt_sovits_weights(gpt_path = gpt_model_path, sovits_path = sovits_model_path)
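A sketch of switching checkpoints at runtime through the POST variant above (the paths are placeholders; the endpoint returns the JSON produced by change_gpt_sovits_weights):

```python
import requests

# Hot-swap the GPT and SoVITS weights; the checkpoint paths are placeholders.
resp = requests.post("http://127.0.0.1:9880/set_model", json={
    "gpt_model_path": "GPT_weights/example.ckpt",
    "sovits_model_path": "SoVITS_weights/example.pth",
})
print(resp.json())  # {"code": 0, "message": "Success"} on success
```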
@app.post("/control")
@@ -827,10 +911,11 @@ async def tts_endpoint(request: Request):
json_post_raw.get("text"),
json_post_raw.get("text_language"),
json_post_raw.get("cut_punc"),
json_post_raw.get("top_k", 10),
json_post_raw.get("top_k", 15),
json_post_raw.get("top_p", 1.0),
json_post_raw.get("temperature", 1.0),
json_post_raw.get("speed", 1.0)
json_post_raw.get("speed", 1.0),
json_post_raw.get("inp_refs", [])
)
@@ -842,12 +927,13 @@ async def tts_endpoint(
text: str = None,
text_language: str = None,
cut_punc: str = None,
top_k: int = 10,
top_k: int = 15,
top_p: float = 1.0,
temperature: float = 1.0,
speed: float = 1.0
speed: float = 1.0,
inp_refs: list = Query(default=[])
):
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed)
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs)
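For the GET form, multiple inp_refs values are passed as repeated query parameters, which FastAPI's Query(default=[]) collects into a list; a minimal client sketch (paths are placeholders):

```python
import requests

# Repeated inp_refs keys become a list on the server side.
params = [
    ("refer_wav_path", "123.wav"),
    ("prompt_text", "一二三"),
    ("prompt_language", "zh"),
    ("text", "先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也"),
    ("text_language", "zh"),
    ("inp_refs", "456.wav"),
    ("inp_refs", "789.wav"),
]
resp = requests.get("http://127.0.0.1:9880", params=params)
with open("output.wav", "wb") as f:
    f.write(resp.content)
```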
if __name__ == "__main__":