Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git, synced 2025-04-06 03:57:44 +08:00

Mix timbre

commit 8b0fbe6d18 (parent c9b6945b22)

Changed file: api.py (44 lines changed)
@@ -153,7 +153,7 @@ from time import time as ttime
 import torch
 import librosa
 import soundfile as sf
-from fastapi import FastAPI, Request, HTTPException
+from fastapi import FastAPI, Request, Query, HTTPException
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn
 from transformers import AutoModelForMaskedLM, AutoTokenizer
@@ -411,7 +411,7 @@ class DictToAttrRecursive(dict):


 def get_spepc(hps, filename):
-    audio = load_audio(filename, int(hps.data.sampling_rate))
+    audio,_ = librosa.load(filename, int(hps.data.sampling_rate))
     audio = torch.FloatTensor(audio)
     audio_norm = audio
     audio_norm = audio_norm.unsqueeze(0)
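Context for the get_spepc change above: librosa.load returns a (waveform, sample_rate) tuple, which is why the new line unpacks two values, whereas the previous load_audio helper returned only the array. A minimal sketch, not part of the commit; the file path and sample rate below are placeholder values:

    import librosa

    # librosa.load returns (waveform, sample_rate); get_spepc discards the
    # second value the same way, via "audio,_ = ...".
    wav, sr = librosa.load("reference.wav", sr=32000)  # placeholder path and rate
    print(wav.shape, sr)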
@@ -540,7 +540,7 @@ def only_punc(text):
     return not any(t.isalnum() or t.isalpha() for t in text)


-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 20, top_p = 0.6, temperature = 0.6, speed = 1, spk = "default"):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, spk = "default"):
     infer_sovits = speaker_list[spk].sovits
     vq_model = infer_sovits.vq_model
     hps = infer_sovits.hps
@@ -552,6 +552,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     t0 = ttime()
     prompt_text = prompt_text.strip("\n")
     prompt_language, text = prompt_language, text.strip("\n")
+    dtype = torch.float16 if is_half == True else torch.float32
     zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
     with torch.no_grad():
         wav16k, sr = librosa.load(ref_wav_path, sr=16000)
@@ -568,6 +569,18 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         codes = vq_model.extract_latent(ssl_content)
         prompt_semantic = codes[0, 0]
         prompt = prompt_semantic.unsqueeze(0).to(device)

+    refers=[]
+    if(inp_refs):
+        for path in inp_refs:
+            try:
+                refer = get_spepc(hps, path).to(dtype).to(device)
+                refers.append(refer)
+            except Exception as e:
+                logger.error(e)
+    if(len(refers)==0):
+        refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
+
     t1 = ttime()
     version = vq_model.version
     os.environ['version'] = version
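The block added above is the core of the timbre-mixing feature: every path in inp_refs is converted to a reference spectrogram, unreadable paths are logged and skipped, and the main ref_wav_path is used as a fallback when no extra reference loads. A standalone sketch of that logic, assuming get_spepc, hps, dtype, device, and logger as defined elsewhere in api.py:

    def gather_reference_spectrograms(inp_refs, ref_wav_path, hps, dtype, device, logger):
        # Collect one spectrogram per usable extra reference wav.
        refers = []
        for path in (inp_refs or []):
            try:
                refers.append(get_spepc(hps, path).to(dtype).to(device))
            except Exception as e:
                logger.error(e)  # skip unreadable references instead of aborting
        # Fall back to the primary reference when nothing else is usable.
        if not refers:
            refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
        return refers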
@@ -605,16 +618,13 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         t3 = ttime()
         # print(pred_semantic.shape,idx)
         pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)  # .unsqueeze(0)#mq要多unsqueeze一次
-        refer = get_spepc(hps, ref_wav_path)  # .to(device)
-        if (is_half == True):
-            refer = refer.half().to(device)
-        else:
-            refer = refer.to(device)
         # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
         audio = \
             vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
-                            refer,speed=speed).detach().cpu().numpy()[
+                            refers,speed=speed).detach().cpu().numpy()[
                 0, 0]  ###试试重建不带上prompt部分
+        max_audio=np.abs(audio).max()#简单防止16bit爆音
+        if max_audio>1:audio/=max_audio
         audio_opt.append(audio)
         audio_opt.append(zero_wav)
         t4 = ttime()
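The two lines added before audio_opt.append act as a simple peak guard: if the decoded float waveform exceeds ±1.0 it is scaled back so the later 16-bit conversion cannot clip. A tiny numpy illustration with made-up sample values:

    import numpy as np

    audio = np.array([0.2, -1.6, 0.9], dtype=np.float32)  # made-up samples
    max_audio = np.abs(audio).max()
    if max_audio > 1:
        audio /= max_audio  # peak is now 1.0, so int16 conversion cannot wrap
    print(audio)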
@@ -659,7 +669,7 @@ def handle_change(path, text, language):
     return JSONResponse({"code": 0, "message": "Success"}, status_code=200)


-def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed):
+def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs):
     if (
             refer_wav_path == "" or refer_wav_path is None
             or prompt_text == "" or prompt_text is None
@@ -678,7 +688,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
     else:
         text = cut_text(text,cut_punc)

-    return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed), media_type="audio/"+media_type)
+    return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs), media_type="audio/"+media_type)



@@ -873,10 +883,11 @@ async def tts_endpoint(request: Request):
         json_post_raw.get("text"),
         json_post_raw.get("text_language"),
         json_post_raw.get("cut_punc"),
-        json_post_raw.get("top_k", 10),
+        json_post_raw.get("top_k", 15),
         json_post_raw.get("top_p", 1.0),
         json_post_raw.get("temperature", 1.0),
-        json_post_raw.get("speed", 1.0)
+        json_post_raw.get("speed", 1.0),
+        json_post_raw.get("inp_refs", [])
     )


@@ -888,12 +899,13 @@ async def tts_endpoint(
         text: str = None,
         text_language: str = None,
         cut_punc: str = None,
-        top_k: int = 10,
+        top_k: int = 15,
         top_p: float = 1.0,
         temperature: float = 1.0,
-        speed: float = 1.0
+        speed: float = 1.0,
+        inp_refs: list = Query(default=[])
 ):
-    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed)
+    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs)


 if __name__ == "__main__":
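With Query(default=[]) on the GET handler and json_post_raw.get("inp_refs", []) on the POST handler, a client can pass the extra timbre references either as repeated inp_refs query parameters or as a JSON list. A hedged usage sketch; the route "/", the port 9880, and the file paths and texts are assumptions, since the diff does not show the route decorators or server config:

    import requests

    payload = {
        "refer_wav_path": "ref/main.wav",              # placeholder paths
        "prompt_text": "main reference transcript",    # placeholder texts
        "prompt_language": "zh",
        "text": "text to synthesize",
        "text_language": "zh",
        "inp_refs": ["ref/timbre_a.wav", "ref/timbre_b.wav"],  # extra timbres to mix
    }
    # POST with a JSON body; passing the same dict via params= on a GET call
    # would send inp_refs as repeated query parameters instead.
    resp = requests.post("http://127.0.0.1:9880/", json=payload)
    with open("out.wav", "wb") as f:
        f.write(resp.content)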