From 6ca4aecea2b154551949b4f184fa9c0555b1395d Mon Sep 17 00:00:00 2001 From: KamioRinn <63162909+KamioRinn@users.noreply.github.com> Date: Tue, 20 Aug 2024 11:47:24 +0800 Subject: [PATCH] =?UTF-8?q?API=E4=BF=AE=E5=A4=8D=E4=BC=98=E5=8C=96=20(#150?= =?UTF-8?q?3)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * model control * Mix timbre * Fix some detail problems * Optimize detail * Add int32 * Add example * Add aac pcm32 support --- api.py | 184 ++++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 135 insertions(+), 49 deletions(-) diff --git a/api.py b/api.py index 3b17394..c5f7024 100644 --- a/api.py +++ b/api.py @@ -20,6 +20,7 @@ `-hp` - `覆盖 config.py 使用半精度` `-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"` ·-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"` +·-st` - `返回的音频数据类型, 默认int16, "int16", "int32"` ·-cp` - `文本切分符号设定, 默认为空, 以",.,。"字符串的方式传入` `-hb` - `cnhubert路径` @@ -74,7 +75,7 @@ RESP: 手动指定当次推理所使用的参考音频,并提供参数: GET: - `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1` + `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&inp_refs="456.wav"&inp_refs="789.wav"` POST: ```json { @@ -86,7 +87,8 @@ POST: "top_k": 20, "top_p": 0.6, "temperature": 0.6, - "speed": 1 + "speed": 1, + "inp_refs": ["456.wav","789.wav"] } ``` @@ -153,7 +155,7 @@ from time import time as ttime import torch import librosa import soundfile as sf -from fastapi import FastAPI, Request, HTTPException +from fastapi import FastAPI, Request, Query, HTTPException from fastapi.responses import StreamingResponse, JSONResponse import uvicorn from transformers import AutoModelForMaskedLM, AutoTokenizer @@ -195,8 +197,24 @@ def is_full(*items): # 任意一项为空返回False return True -def change_sovits_weights(sovits_path): - global vq_model, hps +class Speaker: + def __init__(self, name, gpt, sovits, phones = None, bert = None, prompt = None): + self.name = name + self.sovits = sovits + self.gpt = gpt + self.phones = phones + self.bert = bert + self.prompt = prompt + +speaker_list = {} + + +class Sovits: + def __init__(self, vq_model, hps): + self.vq_model = vq_model + self.hps = hps + +def get_sovits_weights(sovits_path): dict_s2 = torch.load(sovits_path, map_location="cpu") hps = dict_s2["config"] hps = DictToAttrRecursive(hps) @@ -205,7 +223,7 @@ def change_sovits_weights(sovits_path): hps.model.version = "v1" else: hps.model.version = "v2" - print("sovits版本:",hps.model.version) + logger.info(f"模型版本: {hps.model.version}") model_params_dict = vars(hps.model) vq_model = SynthesizerTrn( hps.data.filter_length // 2 + 1, @@ -222,10 +240,17 @@ def change_sovits_weights(sovits_path): vq_model.eval() vq_model.load_state_dict(dict_s2["weight"], strict=False) + sovits = Sovits(vq_model, hps) + return sovits -def change_gpt_weights(gpt_path): - global hz, max_sec, t2s_model, config - hz = 50 +class Gpt: + def __init__(self, max_sec, t2s_model): + self.max_sec = max_sec + self.t2s_model = t2s_model + +global hz +hz = 50 +def get_gpt_weights(gpt_path): dict_s1 = torch.load(gpt_path, map_location="cpu") config = dict_s1["config"] max_sec = config["data"]["max_sec"] @@ -238,6 +263,19 @@ def change_gpt_weights(gpt_path): total = sum([param.nelement() for param in t2s_model.parameters()]) logger.info("Number of parameter: %.2fM" % (total / 1e6)) + gpt = Gpt(max_sec, t2s_model) + return gpt + +def change_gpt_sovits_weights(gpt_path,sovits_path): + try: + gpt = get_gpt_weights(gpt_path) + sovits = get_sovits_weights(sovits_path) + except Exception as e: + return JSONResponse({"code": 400, "message": str(e)}, status_code=400) + + speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits) + return JSONResponse({"code": 0, "message": "Success"}, status_code=200) + def get_bert_feature(text, word2ph): with torch.no_grad(): @@ -289,14 +327,14 @@ def get_phones_and_bert(text,language,version,final=False): if language == "zh": if re.search(r'[A-Za-z]', formattext): formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) - formattext = chinese.text_normalize(formattext) + formattext = chinese.mix_text_normalize(formattext) return get_phones_and_bert(formattext,"zh",version) else: phones, word2ph, norm_text = clean_text_inf(formattext, language, version) bert = get_bert_feature(norm_text, word2ph).to(device) elif language == "yue" and re.search(r'[A-Za-z]', formattext): formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) - formattext = chinese.text_normalize(formattext) + formattext = chinese.mix_text_normalize(formattext) return get_phones_and_bert(formattext,"yue",version) else: phones, word2ph, norm_text = clean_text_inf(formattext, language, version) @@ -375,8 +413,11 @@ class DictToAttrRecursive(dict): def get_spepc(hps, filename): - audio = load_audio(filename, int(hps.data.sampling_rate)) + audio,_ = librosa.load(filename, int(hps.data.sampling_rate)) audio = torch.FloatTensor(audio) + maxx=audio.abs().max() + if(maxx>1): + audio/=min(2,maxx) audio_norm = audio audio_norm = audio_norm.unsqueeze(0) spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length, @@ -448,22 +489,32 @@ def pack_raw(audio_bytes, data, rate): def pack_wav(audio_bytes, rate): - data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16) - wav_bytes = BytesIO() - sf.write(wav_bytes, data, rate, format='wav') - + if is_int32: + data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int32) + wav_bytes = BytesIO() + sf.write(wav_bytes, data, rate, format='WAV', subtype='PCM_32') + else: + data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16) + wav_bytes = BytesIO() + sf.write(wav_bytes, data, rate, format='WAV') return wav_bytes def pack_aac(audio_bytes, data, rate): + if is_int32: + pcm = 's32le' + bit_rate = '256k' + else: + pcm = 's16le' + bit_rate = '128k' process = subprocess.Popen([ 'ffmpeg', - '-f', 's16le', # 输入16位有符号小端整数PCM + '-f', pcm, # 输入16位有符号小端整数PCM '-ar', str(rate), # 设置采样率 '-ac', '1', # 单声道 '-i', 'pipe:0', # 从管道读取输入 '-c:a', 'aac', # 音频编码器为AAC - '-b:a', '192k', # 比特率 + '-b:a', bit_rate, # 比特率 '-vn', # 不包含视频 '-f', 'adts', # 输出AAC数据流格式 'pipe:1' # 将输出写入管道 @@ -504,10 +555,21 @@ def only_punc(text): return not any(t.isalnum() or t.isalpha() for t in text) -def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 20, top_p = 0.6, temperature = 0.6, speed = 1): +splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", } +def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, spk = "default"): + infer_sovits = speaker_list[spk].sovits + vq_model = infer_sovits.vq_model + hps = infer_sovits.hps + + infer_gpt = speaker_list[spk].gpt + t2s_model = infer_gpt.t2s_model + max_sec = infer_gpt.max_sec + t0 = ttime() prompt_text = prompt_text.strip("\n") + if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "." prompt_language, text = prompt_language, text.strip("\n") + dtype = torch.float16 if is_half == True else torch.float32 zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) with torch.no_grad(): wav16k, sr = librosa.load(ref_wav_path, sr=16000) @@ -523,6 +585,19 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() codes = vq_model.extract_latent(ssl_content) prompt_semantic = codes[0, 0] + prompt = prompt_semantic.unsqueeze(0).to(device) + + refers=[] + if(inp_refs): + for path in inp_refs: + try: + refer = get_spepc(hps, path).to(dtype).to(device) + refers.append(refer) + except Exception as e: + logger.error(e) + if(len(refers)==0): + refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)] + t1 = ttime() version = vq_model.version os.environ['version'] = version @@ -538,16 +613,15 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, continue audio_opt = [] + if (text[-1] not in splits): text += "。" if text_language != "en" else "." phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version) bert = torch.cat([bert1, bert2], 1) all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) bert = bert.to(device).unsqueeze(0) all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) - prompt = prompt_semantic.unsqueeze(0).to(device) t2 = ttime() with torch.no_grad(): - # pred_semantic = t2s_model.model.infer( pred_semantic, idx = t2s_model.model.infer_panel( all_phoneme_ids, all_phoneme_len, @@ -558,23 +632,22 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_p = top_p, temperature = temperature, early_stop_num=hz * max_sec) + pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) t3 = ttime() - # print(pred_semantic.shape,idx) - pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次 - refer = get_spepc(hps, ref_wav_path) # .to(device) - if (is_half == True): - refer = refer.half().to(device) - else: - refer = refer.to(device) - # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0] audio = \ vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), - refer,speed=speed).detach().cpu().numpy()[ + refers,speed=speed).detach().cpu().numpy()[ 0, 0] ###试试重建不带上prompt部分 + max_audio=np.abs(audio).max() + if max_audio>1: + audio/=max_audio audio_opt.append(audio) audio_opt.append(zero_wav) t4 = ttime() - audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate) + if is_int32: + audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 2147483647).astype(np.int32),hps.data.sampling_rate) + else: + audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate) # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) if stream_mode == "normal": audio_bytes, audio_chunk = read_clean_buffer(audio_bytes) @@ -615,7 +688,7 @@ def handle_change(path, text, language): return JSONResponse({"code": 0, "message": "Success"}, status_code=200) -def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed): +def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs): if ( refer_wav_path == "" or refer_wav_path is None or prompt_text == "" or prompt_text is None @@ -634,7 +707,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu else: text = cut_text(text,cut_punc) - return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed), media_type="audio/"+media_type) + return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs), media_type="audio/"+media_type) @@ -691,6 +764,7 @@ parser.add_argument("-hp", "--half_precision", action="store_true", default=Fals # 此时 full_precision==True, half_precision==False parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive") parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac") +parser.add_argument("-st", "--sub_type", type=str, default="int16", help="音频数据类型, int16 / int32") parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…") # 切割常用分句符为 `python ./api.py -cp ".?!。?!"` parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path") @@ -752,6 +826,14 @@ else: media_type = "ogg" logger.info(f"编码格式: {media_type}") +# 音频数据类型 +if args.sub_type.lower() == 'int32': + is_int32 = True + logger.info(f"数据类型: int32") +else: + is_int32 = False + logger.info(f"数据类型: int16") + # 初始化模型 cnhubert.cnhubert_base_path = cnhubert_base_path tokenizer = AutoTokenizer.from_pretrained(bert_path) @@ -763,9 +845,7 @@ if is_half: else: bert_model = bert_model.to(device) ssl_model = ssl_model.to(device) -change_sovits_weights(sovits_path) -change_gpt_weights(gpt_path) - +change_gpt_sovits_weights(gpt_path = gpt_path, sovits_path = sovits_path) @@ -777,14 +857,18 @@ app = FastAPI() @app.post("/set_model") async def set_model(request: Request): json_post_raw = await request.json() - global gpt_path - gpt_path=json_post_raw.get("gpt_model_path") - global sovits_path - sovits_path=json_post_raw.get("sovits_model_path") - logger.info("gptpath"+gpt_path+";vitspath"+sovits_path) - change_sovits_weights(sovits_path) - change_gpt_weights(gpt_path) - return "ok" + return change_gpt_sovits_weights( + gpt_path = json_post_raw.get("gpt_model_path"), + sovits_path = json_post_raw.get("sovits_model_path") + ) + + +@app.get("/set_model") +async def set_model( + gpt_model_path: str = None, + sovits_model_path: str = None, +): + return change_gpt_sovits_weights(gpt_path = gpt_model_path, sovits_path = sovits_model_path) @app.post("/control") @@ -827,10 +911,11 @@ async def tts_endpoint(request: Request): json_post_raw.get("text"), json_post_raw.get("text_language"), json_post_raw.get("cut_punc"), - json_post_raw.get("top_k", 10), + json_post_raw.get("top_k", 15), json_post_raw.get("top_p", 1.0), json_post_raw.get("temperature", 1.0), - json_post_raw.get("speed", 1.0) + json_post_raw.get("speed", 1.0), + json_post_raw.get("inp_refs", []) ) @@ -842,12 +927,13 @@ async def tts_endpoint( text: str = None, text_language: str = None, cut_punc: str = None, - top_k: int = 10, + top_k: int = 15, top_p: float = 1.0, temperature: float = 1.0, - speed: float = 1.0 + speed: float = 1.0, + inp_refs: list = Query(default=[]) ): - return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed) + return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs) if __name__ == "__main__":