From 8b0fbe6d18b8fb32434fb4e44e47b4a89820e9fa Mon Sep 17 00:00:00 2001 From: KamioRinn Date: Mon, 19 Aug 2024 04:54:23 +0800 Subject: [PATCH] Mix timbre --- api.py | 44 ++++++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/api.py b/api.py index a6ec11f..e1528d1 100644 --- a/api.py +++ b/api.py @@ -153,7 +153,7 @@ from time import time as ttime import torch import librosa import soundfile as sf -from fastapi import FastAPI, Request, HTTPException +from fastapi import FastAPI, Request, Query, HTTPException from fastapi.responses import StreamingResponse, JSONResponse import uvicorn from transformers import AutoModelForMaskedLM, AutoTokenizer @@ -411,7 +411,7 @@ class DictToAttrRecursive(dict): def get_spepc(hps, filename): - audio = load_audio(filename, int(hps.data.sampling_rate)) + audio,_ = librosa.load(filename, int(hps.data.sampling_rate)) audio = torch.FloatTensor(audio) audio_norm = audio audio_norm = audio_norm.unsqueeze(0) @@ -540,7 +540,7 @@ def only_punc(text): return not any(t.isalnum() or t.isalpha() for t in text) -def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 20, top_p = 0.6, temperature = 0.6, speed = 1, spk = "default"): +def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, spk = "default"): infer_sovits = speaker_list[spk].sovits vq_model = infer_sovits.vq_model hps = infer_sovits.hps @@ -552,6 +552,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, t0 = ttime() prompt_text = prompt_text.strip("\n") prompt_language, text = prompt_language, text.strip("\n") + dtype = torch.float16 if is_half == True else torch.float32 zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) with torch.no_grad(): wav16k, sr = librosa.load(ref_wav_path, sr=16000) @@ -568,6 +569,18 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, codes = vq_model.extract_latent(ssl_content) prompt_semantic = codes[0, 0] prompt = prompt_semantic.unsqueeze(0).to(device) + + refers=[] + if(inp_refs): + for path in inp_refs: + try: + refer = get_spepc(hps, path).to(dtype).to(device) + refers.append(refer) + except Exception as e: + logger.error(e) + if(len(refers)==0): + refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)] + t1 = ttime() version = vq_model.version os.environ['version'] = version @@ -605,16 +618,13 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, t3 = ttime() # print(pred_semantic.shape,idx) pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次 - refer = get_spepc(hps, ref_wav_path) # .to(device) - if (is_half == True): - refer = refer.half().to(device) - else: - refer = refer.to(device) # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0] audio = \ vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), - refer,speed=speed).detach().cpu().numpy()[ + refers,speed=speed).detach().cpu().numpy()[ 0, 0] ###试试重建不带上prompt部分 + max_audio=np.abs(audio).max()#简单防止16bit爆音 + if max_audio>1:audio/=max_audio audio_opt.append(audio) audio_opt.append(zero_wav) t4 = ttime() @@ -659,7 +669,7 @@ def handle_change(path, text, language): return JSONResponse({"code": 0, "message": "Success"}, status_code=200) -def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed): +def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs): if ( refer_wav_path == "" or refer_wav_path is None or prompt_text == "" or prompt_text is None @@ -678,7 +688,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu else: text = cut_text(text,cut_punc) - return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed), media_type="audio/"+media_type) + return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs), media_type="audio/"+media_type) @@ -873,10 +883,11 @@ async def tts_endpoint(request: Request): json_post_raw.get("text"), json_post_raw.get("text_language"), json_post_raw.get("cut_punc"), - json_post_raw.get("top_k", 10), + json_post_raw.get("top_k", 15), json_post_raw.get("top_p", 1.0), json_post_raw.get("temperature", 1.0), - json_post_raw.get("speed", 1.0) + json_post_raw.get("speed", 1.0), + json_post_raw.get("inp_refs", []) ) @@ -888,12 +899,13 @@ async def tts_endpoint( text: str = None, text_language: str = None, cut_punc: str = None, - top_k: int = 10, + top_k: int = 15, top_p: float = 1.0, temperature: float = 1.0, - speed: float = 1.0 + speed: float = 1.0, + inp_refs: list = Query(default=[]) ): - return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed) + return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs) if __name__ == "__main__":