From dbb6b42fdb06467c31258189daac64a04e29bb0d Mon Sep 17 00:00:00 2001 From: KamioRinn Date: Mon, 19 Aug 2024 21:26:20 +0800 Subject: [PATCH] Optimize detail --- api.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/api.py b/api.py index 3eb0ee8..92924dc 100644 --- a/api.py +++ b/api.py @@ -221,7 +221,7 @@ def get_sovits_weights(sovits_path): hps.model.version = "v1" else: hps.model.version = "v2" - print("sovits版本:",hps.model.version) + logger.info(f"模型版本: {hps.model.version}") model_params_dict = vars(hps.model) vq_model = SynthesizerTrn( hps.data.filter_length // 2 + 1, @@ -489,8 +489,7 @@ def pack_raw(audio_bytes, data, rate): def pack_wav(audio_bytes, rate): data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16) wav_bytes = BytesIO() - sf.write(wav_bytes, data, rate, format='wav') - + sf.write(wav_bytes, data, rate, format='WAV') return wav_bytes @@ -543,6 +542,7 @@ def only_punc(text): return not any(t.isalnum() or t.isalpha() for t in text) +splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", } def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, spk = "default"): infer_sovits = speaker_list[spk].sovits vq_model = infer_sovits.vq_model @@ -554,6 +554,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, t0 = ttime() prompt_text = prompt_text.strip("\n") + if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "." prompt_language, text = prompt_language, text.strip("\n") dtype = torch.float16 if is_half == True else torch.float32 zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) @@ -599,6 +600,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, continue audio_opt = [] + if (text[-1] not in splits): text += "。" if text_language != "en" else "." phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version) bert = torch.cat([bert1, bert2], 1) @@ -607,7 +609,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) t2 = ttime() with torch.no_grad(): - # pred_semantic = t2s_model.model.infer( pred_semantic, idx = t2s_model.model.infer_panel( all_phoneme_ids, all_phoneme_len, @@ -618,20 +619,20 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_p = top_p, temperature = temperature, early_stop_num=hz * max_sec) + pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) t3 = ttime() - # print(pred_semantic.shape,idx) - pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次 - # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0] audio = \ vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers,speed=speed).detach().cpu().numpy()[ 0, 0] ###试试重建不带上prompt部分 - max_audio=np.abs(audio).max()#简单防止16bit爆音 - if max_audio>1:audio/=max_audio + max_audio=np.abs(audio).max() + if max_audio>1: + audio/=max_audio audio_opt.append(audio) audio_opt.append(zero_wav) t4 = ttime() audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate) + # audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 2147483647).astype(np.int32),hps.data.sampling_rate) # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) if stream_mode == "normal": audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)