Mix timbre

This commit is contained in:
KamioRinn 2024-08-19 04:54:23 +08:00
parent c9b6945b22
commit 8b0fbe6d18

44
api.py
View File

@ -153,7 +153,7 @@ from time import time as ttime
import torch
import librosa
import soundfile as sf
from fastapi import FastAPI, Request, HTTPException
from fastapi import FastAPI, Request, Query, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
from transformers import AutoModelForMaskedLM, AutoTokenizer
@ -411,7 +411,7 @@ class DictToAttrRecursive(dict):
def get_spepc(hps, filename):
audio = load_audio(filename, int(hps.data.sampling_rate))
audio,_ = librosa.load(filename, int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio)
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
@ -540,7 +540,7 @@ def only_punc(text):
return not any(t.isalnum() or t.isalpha() for t in text)
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 20, top_p = 0.6, temperature = 0.6, speed = 1, spk = "default"):
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, spk = "default"):
infer_sovits = speaker_list[spk].sovits
vq_model = infer_sovits.vq_model
hps = infer_sovits.hps
@ -552,6 +552,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
t0 = ttime()
prompt_text = prompt_text.strip("\n")
prompt_language, text = prompt_language, text.strip("\n")
dtype = torch.float16 if is_half == True else torch.float32
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
@ -568,6 +569,18 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device)
refers=[]
if(inp_refs):
for path in inp_refs:
try:
refer = get_spepc(hps, path).to(dtype).to(device)
refers.append(refer)
except Exception as e:
logger.error(e)
if(len(refers)==0):
refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
t1 = ttime()
version = vq_model.version
os.environ['version'] = version
@ -605,16 +618,13 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
refer = get_spepc(hps, ref_wav_path) # .to(device)
if (is_half == True):
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = \
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
refer,speed=speed).detach().cpu().numpy()[
refers,speed=speed).detach().cpu().numpy()[
0, 0] ###试试重建不带上prompt部分
max_audio=np.abs(audio).max()#简单防止16bit爆音
if max_audio>1:audio/=max_audio
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
@ -659,7 +669,7 @@ def handle_change(path, text, language):
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed):
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs):
if (
refer_wav_path == "" or refer_wav_path is None
or prompt_text == "" or prompt_text is None
@ -678,7 +688,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
else:
text = cut_text(text,cut_punc)
return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed), media_type="audio/"+media_type)
return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs), media_type="audio/"+media_type)
@ -873,10 +883,11 @@ async def tts_endpoint(request: Request):
json_post_raw.get("text"),
json_post_raw.get("text_language"),
json_post_raw.get("cut_punc"),
json_post_raw.get("top_k", 10),
json_post_raw.get("top_k", 15),
json_post_raw.get("top_p", 1.0),
json_post_raw.get("temperature", 1.0),
json_post_raw.get("speed", 1.0)
json_post_raw.get("speed", 1.0),
json_post_raw.get("inp_refs", [])
)
@ -888,12 +899,13 @@ async def tts_endpoint(
text: str = None,
text_language: str = None,
cut_punc: str = None,
top_k: int = 10,
top_k: int = 15,
top_p: float = 1.0,
temperature: float = 1.0,
speed: float = 1.0
speed: float = 1.0,
inp_refs: list = Query(default=[])
):
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed)
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs)
if __name__ == "__main__":