Mix timbre

This commit is contained in:
KamioRinn 2024-08-19 04:54:23 +08:00
parent c9b6945b22
commit 8b0fbe6d18

44
api.py
View File

@ -153,7 +153,7 @@ from time import time as ttime
import torch import torch
import librosa import librosa
import soundfile as sf import soundfile as sf
from fastapi import FastAPI, Request, HTTPException from fastapi import FastAPI, Request, Query, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn import uvicorn
from transformers import AutoModelForMaskedLM, AutoTokenizer from transformers import AutoModelForMaskedLM, AutoTokenizer
@ -411,7 +411,7 @@ class DictToAttrRecursive(dict):
def get_spepc(hps, filename): def get_spepc(hps, filename):
audio = load_audio(filename, int(hps.data.sampling_rate)) audio,_ = librosa.load(filename, int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio) audio = torch.FloatTensor(audio)
audio_norm = audio audio_norm = audio
audio_norm = audio_norm.unsqueeze(0) audio_norm = audio_norm.unsqueeze(0)
@ -540,7 +540,7 @@ def only_punc(text):
return not any(t.isalnum() or t.isalpha() for t in text) return not any(t.isalnum() or t.isalpha() for t in text)
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 20, top_p = 0.6, temperature = 0.6, speed = 1, spk = "default"): def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, spk = "default"):
infer_sovits = speaker_list[spk].sovits infer_sovits = speaker_list[spk].sovits
vq_model = infer_sovits.vq_model vq_model = infer_sovits.vq_model
hps = infer_sovits.hps hps = infer_sovits.hps
@ -552,6 +552,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
t0 = ttime() t0 = ttime()
prompt_text = prompt_text.strip("\n") prompt_text = prompt_text.strip("\n")
prompt_language, text = prompt_language, text.strip("\n") prompt_language, text = prompt_language, text.strip("\n")
dtype = torch.float16 if is_half == True else torch.float32
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
with torch.no_grad(): with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000) wav16k, sr = librosa.load(ref_wav_path, sr=16000)
@ -568,6 +569,18 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
codes = vq_model.extract_latent(ssl_content) codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0] prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device) prompt = prompt_semantic.unsqueeze(0).to(device)
refers=[]
if(inp_refs):
for path in inp_refs:
try:
refer = get_spepc(hps, path).to(dtype).to(device)
refers.append(refer)
except Exception as e:
logger.error(e)
if(len(refers)==0):
refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
t1 = ttime() t1 = ttime()
version = vq_model.version version = vq_model.version
os.environ['version'] = version os.environ['version'] = version
@ -605,16 +618,13 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
t3 = ttime() t3 = ttime()
# print(pred_semantic.shape,idx) # print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次 pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
refer = get_spepc(hps, ref_wav_path) # .to(device)
if (is_half == True):
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0] # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = \ audio = \
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
refer,speed=speed).detach().cpu().numpy()[ refers,speed=speed).detach().cpu().numpy()[
0, 0] ###试试重建不带上prompt部分 0, 0] ###试试重建不带上prompt部分
max_audio=np.abs(audio).max()#简单防止16bit爆音
if max_audio>1:audio/=max_audio
audio_opt.append(audio) audio_opt.append(audio)
audio_opt.append(zero_wav) audio_opt.append(zero_wav)
t4 = ttime() t4 = ttime()
@ -659,7 +669,7 @@ def handle_change(path, text, language):
return JSONResponse({"code": 0, "message": "Success"}, status_code=200) return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed): def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs):
if ( if (
refer_wav_path == "" or refer_wav_path is None refer_wav_path == "" or refer_wav_path is None
or prompt_text == "" or prompt_text is None or prompt_text == "" or prompt_text is None
@ -678,7 +688,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
else: else:
text = cut_text(text,cut_punc) text = cut_text(text,cut_punc)
return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed), media_type="audio/"+media_type) return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs), media_type="audio/"+media_type)
@ -873,10 +883,11 @@ async def tts_endpoint(request: Request):
json_post_raw.get("text"), json_post_raw.get("text"),
json_post_raw.get("text_language"), json_post_raw.get("text_language"),
json_post_raw.get("cut_punc"), json_post_raw.get("cut_punc"),
json_post_raw.get("top_k", 10), json_post_raw.get("top_k", 15),
json_post_raw.get("top_p", 1.0), json_post_raw.get("top_p", 1.0),
json_post_raw.get("temperature", 1.0), json_post_raw.get("temperature", 1.0),
json_post_raw.get("speed", 1.0) json_post_raw.get("speed", 1.0),
json_post_raw.get("inp_refs", [])
) )
@ -888,12 +899,13 @@ async def tts_endpoint(
text: str = None, text: str = None,
text_language: str = None, text_language: str = None,
cut_punc: str = None, cut_punc: str = None,
top_k: int = 10, top_k: int = 15,
top_p: float = 1.0, top_p: float = 1.0,
temperature: float = 1.0, temperature: float = 1.0,
speed: float = 1.0 speed: float = 1.0,
inp_refs: list = Query(default=[])
): ):
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed) return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs)
if __name__ == "__main__": if __name__ == "__main__":