XXXXRT666 2024-07-03 23:17:09 +08:00
parent 582ba7d519
commit 110cc3560c
2 changed files with 7 additions and 7 deletions

@@ -313,6 +313,7 @@ def merge_short_text_in_array(texts, threshold):
     return result
 
 def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False):
+    t = []
     if prompt_text is None or len(prompt_text) == 0:
         ref_free = True
     t0 = ttime()
@@ -353,6 +354,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
             prompt_semantic = codes[0, 0]
     t1 = ttime()
+    t.append(t1-t0)
 
     if (how_to_cut == i18n("凑四句一切")):
         text = cut1(text)
@@ -376,6 +378,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     for text in texts:
         # Fix the error caused by blank lines in the input target text
+        t1 = ttime()
         if (len(text.strip()) == 0):
             continue
         if (text[-1] not in splits): text += "。" if text_language != "en" else "."
@@ -430,7 +433,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         audio_opt.append(audio)
         audio_opt.append(zero_wav)
         t4 = ttime()
-        print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
+        t.extend([t2 - t1,t3 - t2, t4 - t3])
+    print("%.3f\t%.3f\t%.3f\t%.3f" %
+          (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3]))
+          )
     yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
         np.int16
     )
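For readers skimming the diff: the old per-segment print is replaced by a list t that collects one prompt-encoding time up front, then a triple of phase times for each synthesized segment, so the stride-3 slices in the new print sum each phase across all segments. A minimal standalone sketch of that aggregation follows (the numbers and phase labels are illustrative, inferred from the surrounding code, and are not part of this commit):

    # Sketch of the stride-3 timing aggregation introduced above.
    # t[0]    -> one-off reference/prompt encoding time (t1 - t0)
    # t[1::3] -> per-segment frontend time              (t2 - t1)
    # t[2::3] -> per-segment T2S inference time         (t3 - t2)
    # t[3::3] -> per-segment vocoder decode time        (t4 - t3)
    t = [0.50]                        # hypothetical prompt-encoding time
    for _segment in range(3):         # pretend three text segments were synthesized
        t.extend([0.10, 1.20, 0.30])  # hypothetical per-phase timings
    print("%.3f\t%.3f\t%.3f\t%.3f" %
          (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
    # prints: 0.500   0.300   3.600   0.900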

api.py

@@ -127,7 +127,6 @@ sys.path.append("%s/GPT_SoVITS" % (now_dir))
 import signal
 import LangSegment
-from time import time as ttime
 import torch
 import librosa
 import soundfile as sf
@@ -447,7 +446,6 @@ def only_punc(text):
 
 def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
-    t0 = ttime()
     prompt_text = prompt_text.strip("\n")
     prompt_language, text = prompt_language, text.strip("\n")
     zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
@@ -465,7 +463,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
         ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
         codes = vq_model.extract_latent(ssl_content)
         prompt_semantic = codes[0, 0]
-    t1 = ttime()
     prompt_language = dict_language[prompt_language.lower()]
     text_language = dict_language[text_language.lower()]
     phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language)
@@ -485,7 +482,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
         bert = bert.to(device).unsqueeze(0)
         all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
         prompt = prompt_semantic.unsqueeze(0).to(device)
-        t2 = ttime()
         with torch.no_grad():
             # pred_semantic = t2s_model.model.infer(
             pred_semantic, idx = t2s_model.model.infer_panel(
@@ -496,7 +492,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
                 # prompt_phone_len=ph_offset,
                 top_k=config['inference']['top_k'],
                 early_stop_num=hz * max_sec)
-        t3 = ttime()
         # print(pred_semantic.shape,idx)
         pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)  # .unsqueeze(0)  # mq needs one extra unsqueeze
         refer = get_spepc(hps, ref_wav_path)  # .to(device)
@@ -511,7 +506,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
             0, 0]  ### try reconstructing without the prompt part
         audio_opt.append(audio)
         audio_opt.append(zero_wav)
-        t4 = ttime()
         audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate)
         # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
         if stream_mode == "normal":
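Note that with the ttime import and every t1..t4 assignment removed, the commented-out logger.info line above would raise a NameError if it were simply uncommented. Should per-phase timing ever be wanted back in api.py, a small self-contained helper along these lines would do (the helper name and usage points are hypothetical, not part of this commit):

    import time
    from contextlib import contextmanager

    @contextmanager
    def phase_timer(label, sink):
        # Append (label, elapsed_seconds) to `sink` around the wrapped block.
        start = time.perf_counter()
        yield
        sink.append((label, time.perf_counter() - start))

    # Hypothetical usage inside get_tts_wav:
    timings = []
    with phase_timer("t2s_infer", timings):
        pass  # e.g. t2s_model.model.infer_panel(...)
    print("\t".join("%s=%.3f" % (k, v) for k, v in timings))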