Fixed some bugs

XXXXRT 2024-04-03 23:33:02 +01:00 committed by XXXXRT666
parent b50ab5419f
commit f45fff72d9

api.py (338 changed lines)

@@ -158,6 +158,78 @@ import math
i18n = I18nAuto()

class REF:
    def __init__(self, ref_path="", ref_text="", ref_language=""):
        if ref_text:
            ref_text = ref_text.strip("\n")
            if ref_text[-1] not in splits:
                ref_text += "。" if ref_language != "en" else "."
        if ref_language:
            ref_language = dict_language[ref_language.lower()]
        self.path = ref_path
        self.text = ref_text
        self.language = ref_language
        self.prompt_semantic = None
        self.refer_spec = None

    def set_prompt_semantic(self, ref_wav_path: str):
        zero_wav = np.zeros(
            int(hps.data.sampling_rate * 0.3),
            dtype=np.float16 if is_half else np.float32,
        )
        with torch.no_grad():
            wav16k, sr = librosa.load(ref_wav_path, sr=16000)
            if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
                # "Reference audio is outside the 3~10 second range; please replace it"
                raise OSError(i18n("参考音频在3~10秒范围外请更换"))
            wav16k = torch.from_numpy(wav16k)
            zero_wav_torch = torch.from_numpy(zero_wav)
            wav16k = wav16k.to(device)
            zero_wav_torch = zero_wav_torch.to(device)
            if is_half:
                wav16k = wav16k.half()
                zero_wav_torch = zero_wav_torch.half()
            wav16k = torch.cat([wav16k, zero_wav_torch])
            ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
            codes = vq_model.extract_latent(ssl_content)
            prompt_semantic = codes[0, 0].to(device)
        self.prompt_semantic = prompt_semantic
        self.codes = codes
        self.ssl_content = ssl_content
    def set_ref_spec(self, ref_audio_path):
        audio = load_audio(ref_audio_path, int(hps.data.sampling_rate))
        audio = torch.FloatTensor(audio)
        audio_norm = audio
        audio_norm = audio_norm.unsqueeze(0)
        spec = spectrogram_torch(
            audio_norm,
            hps.data.filter_length,
            hps.data.sampling_rate,
            hps.data.hop_length,
            hps.data.win_length,
            center=False,
        )
        spec = spec.to(device)
        if is_half:
            spec = spec.half()
        # self.refer_spec = spec
        self.refer_spec = spec

    def set_ref_audio(self):
        '''
        Set the reference audio for the TTS model,
        filling in prompt_semantic and refer_spec.
        Uses self.path as the reference audio file.
        '''
        self.set_prompt_semantic(self.path)
        self.set_ref_spec(self.path)
        self.phone, self.bert_feature, self.norm_text = get_phones_and_bert(self.text, self.language)

    def is_ready(self) -> bool:
        return is_full(self.path, self.text, self.language)

def is_empty(*items):  # returns False if any item is non-empty
    for item in items:
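For context, a minimal usage sketch of the new REF class (hypothetical path and text; assumes the module-level globals hps, ssl_model, vq_model, device, and is_half have already been initialized by the startup code further down in api.py):

    # hypothetical example, not part of the commit
    ref = REF(ref_path="reference.wav", ref_text="Hello there.", ref_language="en")
    if ref.is_ready():
        ref.set_ref_audio()  # fills prompt_semantic, refer_spec, phone, bert_feature, norm_text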
@@ -195,6 +267,7 @@ def change_sovits_weights(sovits_path):
    vq_model.load_state_dict(dict_s2["weight"], strict=False)

def change_gpt_weights(gpt_path):
    global hz, max_sec, t2s_model, config, is_fast_inference
    hz = 50
@@ -243,10 +316,6 @@ def get_bert_inf(phones, word2ph, norm_text, language):
    language = language.replace("all_", "")
    if language == "zh":
        bert = get_bert_feature(norm_text, word2ph).to(device)  # .to(dtype)
        bert = torch.zeros(
            (1024, len(phones)),
            dtype=precision,
        ).to(device)
    else:
        bert = torch.zeros(
            (1024, len(phones)),
@@ -332,77 +401,6 @@ class DictToAttrRecursive:
        setattr(self, key, value)

class REF:
    def __init__(self, ref_path="", ref_text="", ref_language=""):
        ref_text = ref_text.strip("\n")
        if ref_text:
            if ref_text[-1] not in splits:
                ref_text += "。" if ref_language != "en" else "."
        if ref_language:
            ref_language = dict_language[ref_language.lower()]
        self.path = ref_path
        self.text = ref_text
        self.language = ref_language

    def set_prompt_semantic(self, ref_wav_path: str):
        zero_wav = np.zeros(
            int(hps.data.sampling_rate * 0.3),
            dtype=np.float16 if is_half else np.float32,
        )
        with torch.no_grad():
            wav16k, sr = librosa.load(ref_wav_path, sr=16000)
            if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
                # "Reference audio is outside the 3~10 second range; please replace it"
                raise OSError(i18n("参考音频在3~10秒范围外请更换"))
            wav16k = torch.from_numpy(wav16k)
            zero_wav_torch = torch.from_numpy(zero_wav)
            wav16k = wav16k.to(device)
            zero_wav_torch = zero_wav_torch.to(device)
            if is_half:
                wav16k = wav16k.half()
                zero_wav_torch = zero_wav_torch.half()
            wav16k = torch.cat([wav16k, zero_wav_torch])
            ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
            codes = vq_model.extract_latent(ssl_content)
            prompt_semantic = codes[0, 0].to(device)
        self.prompt_semantic = prompt_semantic
        self.codes = codes
        self.ssl_content = ssl_content

    def set_ref_spec(self, ref_audio_path):
        audio = load_audio(ref_audio_path, int(hps.data.sampling_rate))
        audio = torch.FloatTensor(audio)
        audio_norm = audio
        audio_norm = audio_norm.unsqueeze(0)
        spec = spectrogram_torch(
            audio_norm,
            hps.data.filter_length,
            hps.data.sampling_rate,
            hps.data.hop_length,
            hps.data.win_length,
            center=False,
        )
        spec = spec.to(device)
        if is_half:
            spec = spec.half()
        # self.refer_spec = spec
        self.refer_spec = spec

    def set_ref_audio(self):
        '''
        Set the reference audio for the TTS model,
        filling in prompt_semantic and refer_spec.
        Uses self.path as the reference audio file.
        '''
        self.set_prompt_semantic(self.path)
        self.set_ref_spec(self.path)
        self.phone, self.bert_feature, self.norm_text = get_phones_and_bert(self.text, self.language)

    def is_ready(self) -> bool:
        return is_full(self.path, self.text, self.language)

def pack_audio(audio_bytes, data, rate):
    if media_type == "ogg":
        audio_bytes = pack_ogg(audio_bytes, data, rate)
@@ -485,6 +483,73 @@ def only_punc(text):
    return not any(t.isalnum() or t.isalpha() for t in text)

def get_tts_wav(ref: REF, text, text_language):
    logger.info("get_tts_wav")
    t0 = ttime()
    t1 = ttime()
    text_language = dict_language[text_language.lower()]
    phones1, bert1, norm_text1 = ref.phone, ref.bert_feature, ref.norm_text
    texts = text.split("\n")
    audio_bytes = BytesIO()
    for text in texts:
        # Simple guard: skip punctuation-only segments to avoid leaking the reference audio
        if only_punc(text):
            continue
        print(text)
        audio_opt = []
        phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language)
        bert = torch.cat([bert1, bert2], 1)
        all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
        prompt = ref.prompt_semantic.unsqueeze(0).to(device)
        t2 = ttime()
        zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half else np.float32)
        with torch.no_grad():
            # pred_semantic = t2s_model.model.infer(
            pred_semantic, idx = t2s_model.model.infer_panel(
                all_phoneme_ids,
                all_phoneme_len,
                prompt,
                bert,
                # prompt_phone_len=ph_offset,
                top_k=config['inference']['top_k'],
                early_stop_num=hz * max_sec)
        t3 = ttime()
        # print(pred_semantic.shape, idx)
        # Workaround: sys.path sometimes gets confused and the AR module from the
        # fast-inference branch is imported, whose infer_panel returns lists instead of tensors
        if isinstance(pred_semantic, list) and isinstance(idx, list):
            pred_semantic = pred_semantic[0]
            idx = idx[0]
            pred_semantic = pred_semantic[-idx:]
            pred_semantic = pred_semantic.unsqueeze(0).unsqueeze(0)
        else:
            pred_semantic = pred_semantic[:, -idx:]
            pred_semantic = pred_semantic.unsqueeze(0)  # mq needs one extra unsqueeze
        # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
        audio = \
            vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
                            ref.refer_spec).detach().cpu().numpy()[0, 0]  # try reconstructing without the prompt part
        audio_opt.append(audio)
        audio_opt.append(zero_wav)
        t4 = ttime()
        audio_bytes = pack_audio(audio_bytes, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16), hps.data.sampling_rate)
        logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
        if return_fragment:
            audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
            yield audio_chunk
    if not return_fragment:
        if media_type == "wav":
            audio_bytes = pack_wav(audio_bytes, hps.data.sampling_rate)
        yield audio_bytes.getvalue()
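A sketch of how this generator is consumed (hypothetical call, not part of the commit): with return_fragment enabled it yields one encoded chunk per text segment, otherwise a single buffer at the end.

    # hypothetical example, not part of the commit
    with open("out." + media_type, "wb") as f:
        for chunk in get_tts_wav(ref, "First line.\nSecond line.", "en"):
            f.write(chunk)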
def preprocess(text: list, lang: str) -> List[Dict]:
    result = []
    for _text in text:
@@ -572,7 +637,7 @@ def to_batch(data:list, ref:REF,
    if split_bucket:
        index_and_len_list.sort(key=lambda x: x[1])
        index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
        logger.info("batch_size: " + str(batch_size))

        batch_index_list_len = 0
        pos = 0
        while pos < index_and_len_list.shape[0]:
@@ -664,6 +729,7 @@ def recovery_order(data:list, batch_index_list:list)->list:
def run(ref: REF, text, text_lang):
    logger.info("run")
    logger.info(f"batch_size: {batch_size}")

    ########## variables initialization ###########
    top_k = 5
@@ -779,7 +845,6 @@ def run(ref:REF, text, text_lang):
        else:
            audio.append(batch_audio_fragment)

    logger.info("return_fragment: " + str(return_fragment) + " split_bucket: " + str(split_bucket) + " batch_size: " + str(batch_size) + " media_type: " + media_type)
    if not return_fragment:
        logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
        yield audio_postprocess(audio,
@@ -796,72 +861,6 @@ def run(ref:REF, text, text_lang):
    pass

def get_tts_wav(ref: REF, text, text_language):
    logger.info("get_tts_wav")
    t0 = ttime()
    t1 = ttime()
    text_language = dict_language[text_language.lower()]
    phones1, bert1, norm_text1 = ref.phone, ref.bert_feature, ref.norm_text
    texts = text.split("\n")
    audio_bytes = BytesIO()
    for text in texts:
        # Simple guard: skip punctuation-only segments to avoid leaking the reference audio
        if only_punc(text):
            continue
        print(text)
        audio_opt = []
        phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language)
        bert = torch.cat([bert1, bert2], 1)
        all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
        prompt = ref.prompt_semantic.unsqueeze(0).to(device)
        t2 = ttime()
        zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half else np.float32)
        with torch.no_grad():
            # pred_semantic = t2s_model.model.infer(
            pred_semantic, idx = t2s_model.model.infer_panel(
                all_phoneme_ids,
                all_phoneme_len,
                prompt,
                bert,
                # prompt_phone_len=ph_offset,
                top_k=config['inference']['top_k'],
                early_stop_num=hz * max_sec)
        t3 = ttime()
        # print(pred_semantic.shape, idx)
        # Workaround: sys.path sometimes gets confused and the AR module from the
        # fast-inference branch is imported, whose infer_panel returns lists instead of tensors
        if isinstance(pred_semantic, list) and isinstance(idx, list):
            pred_semantic = pred_semantic[0]
            idx = idx[0]
            pred_semantic = pred_semantic[-idx:]
            pred_semantic = pred_semantic.unsqueeze(0).unsqueeze(0)
        else:
            pred_semantic = pred_semantic[:, -idx:]
            pred_semantic = pred_semantic.unsqueeze(0)  # mq needs one extra unsqueeze
        # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
        audio = \
            vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
                            ref.refer_spec).detach().cpu().numpy()[0, 0]  # try reconstructing without the prompt part
        audio_opt.append(audio)
        audio_opt.append(zero_wav)
        t4 = ttime()
        audio_bytes = pack_audio(audio_bytes, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16), hps.data.sampling_rate)
        logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
        if return_fragment:
            audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
            yield audio_chunk
    if not return_fragment:
        if media_type == "wav":
            audio_bytes = pack_wav(audio_bytes, hps.data.sampling_rate)
        yield audio_bytes.getvalue()
# --------------------------------
# Initialization
# --------------------------------
@@ -885,10 +884,10 @@ parser.add_argument("-dl", "--default_refer_language", type=str, default="", hel
parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
parser.add_argument("-bs", "--batch_size", type=int, default=1, help="batch size")
parser.add_argument("-bs", "--batch_size", type=int, default=2, help="batch size")
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="override config.is_half to False, use full precision")
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="override config.is_half to True, use half precision")
parser.add_argument("-rf", "--return_fragment", action="store_true", default=False, help="enable fragment (streaming) return")
parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="streaming mode: close / normal / keepalive")
parser.add_argument("-sb", "--split_bucket", action="store_true", default=False, help="split batches into buckets")
parser.add_argument("-fa", "--flash_atten", action="store_true", default=False, help="enable flash_attention")
# boolean flags are passed as `python ./api.py -fp ...`
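A hypothetical invocation using the flags touched by this commit (all flags appear in the argparse block above):

    # hypothetical example, not part of the commit
    python ./api.py -p 9880 -bs 2 -sm normal -sb -fa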
@@ -909,7 +908,7 @@ cnhubert_base_path = args.hubert_path
bert_path = args.bert_path
default_cut_punc = args.cut_punc
batch_size = args.batch_size
return_fragment = args.return_fragment
stream_mode = args.stream_mode
split_bucket = args.split_bucket
flash_atten = args.flash_atten
@@ -955,6 +954,16 @@ precision = torch.float16 if is_half else torch.float32
device = torch.device(device)

## Streaming return
if stream_mode.lower() in ["normal", "n"]:
    stream_mode = "normal"
    return_fragment = True
    logger.info("Streaming return enabled")
else:
    stream_mode = "close"
    return_fragment = False

# Audio encoding format
if args.media_type.lower() in ["aac", "ogg"]:
    media_type = args.media_type.lower()
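A sketch of a streaming client against this API (hypothetical; assumes the server was started with -sm normal on the default port 9880 and that a default reference was set via -dr/-dt/-dl):

    # hypothetical example, not part of the commit
    import requests

    params = {"text": "Some text to synthesize.", "text_language": "en"}
    with requests.get("http://127.0.0.1:9880/", params=params, stream=True) as r:
        with open("out." + "wav", "wb") as f:  # extension should match the server's media_type
            for chunk in r.iter_content(chunk_size=4096):
                f.write(chunk)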
@@ -1019,8 +1028,8 @@ def handle_control(command):
    exit(0)

def handle_change(path, text, language):
    global default_refer

def handle_change(path, text, language, cut_punc):
    global default_refer, default_cut_punc

    if is_empty(path, text, language):
        return JSONResponse({"code": 400, "message": 'Missing one of the following parameters: "path", "text", "language"'}, status_code=400)

@@ -1028,10 +1037,14 @@ def handle_change(path, text, language):
        (text != "" and text is not None) and \
        (language != "" and language is not None):
        default_refer = REF(path, text, language)
        default_refer.set_ref_audio()

    if cut_punc != "" and cut_punc is not None:
        default_cut_punc = cut_punc

    logger.info(f"Current default reference audio path: {default_refer.path}")
    logger.info(f"Current default reference audio text: {default_refer.text}")
    logger.info(f"Current default reference audio language: {default_refer.language}")
    logger.info(f"Current default cut punctuation: {default_cut_punc}")
    logger.info(f"is_ready: {default_refer.is_ready()}")
@@ -1043,6 +1056,8 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
        (prompt_text != default_refer.text) or \
        (prompt_language != default_refer.language):
        ref = REF(refer_wav_path, prompt_text, prompt_language)
        if ref.is_ready():
            ref.set_ref_audio()
    else:
        ref = default_refer
@@ -1055,13 +1070,14 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
    if not default_refer.is_ready():
        return JSONResponse({"code": 400, "message": "No reference audio specified and no default is preset"}, status_code=400)

    if cut_punc == None:
    if cut_punc == "" or cut_punc is None:
        text = cut_text(text, default_cut_punc)
    else:
        text = cut_text(text, cut_punc)

    if is_fast_inference:
        return StreamingResponse(run(ref, text, text_language), media_type="audio/" + media_type)
    else:
@@ -1080,9 +1096,11 @@ async def set_model(request: Request):
    gpt_path = json_post_raw.get("gpt_model_path")
    global sovits_path
    sovits_path = json_post_raw.get("sovits_model_path")
    logger.info("gptpath" + gpt_path + ";vitspath" + sovits_path)
    logger.info("gptpath: " + gpt_path)
    logger.info("vitspath: " + sovits_path)
    change_sovits_weights(sovits_path)
    change_gpt_weights(gpt_path)
    default_refer.set_ref_audio()
    return "ok"
@@ -1103,7 +1121,8 @@ async def change_refer(request: Request):
    return handle_change(
        json_post_raw.get("refer_wav_path"),
        json_post_raw.get("prompt_text"),
        json_post_raw.get("prompt_language")
        json_post_raw.get("prompt_language"),
        json_post_raw.get("cut_punc")
    )
@@ -1111,9 +1130,10 @@ async def change_refer(
async def change_refer(
    refer_wav_path: str = None,
    prompt_text: str = None,
    prompt_language: str = None
    prompt_language: str = None,
    cut_punc: str = None
):
    return handle_change(refer_wav_path, prompt_text, prompt_language)
    return handle_change(refer_wav_path, prompt_text, prompt_language, cut_punc)
@app.post("/")
@@ -1131,12 +1151,12 @@ async def tts_endpoint(request: Request):
@app.get("/")
async def tts_endpoint(
    refer_wav_path: str = "",
    prompt_text: str = "",
    prompt_language: str = "",
    text: str = "",
    text_language: str = "",
    cut_punc: str = "",
    refer_wav_path: str = None,
    prompt_text: str = None,
    prompt_language: str = None,
    text: str = None,
    text_language: str = None,
    cut_punc: str = None,
):
    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc)
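A hypothetical GET request exercising the per-request reference override and the new cut_punc parameter (values are placeholders):

    # hypothetical example, not part of the commit
    import requests

    resp = requests.get("http://127.0.0.1:9880/", params={
        "refer_wav_path": "reference.wav",
        "prompt_text": "Reference transcript.",
        "prompt_language": "en",
        "text": "Hello world",
        "text_language": "en",
        "cut_punc": ",.",
    })
    open("out.wav", "wb").write(resp.content)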