From 110cc3560cac746ca8576b4ed71573c7fec4d5de Mon Sep 17 00:00:00 2001
From: XXXXRT666
Date: Wed, 3 Jul 2024 23:17:09 +0800
Subject: [PATCH] timing

---
 GPT_SoVITS/inference_webui.py | 8 +++++++-
 api.py                        | 6 ------
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 44c6d0eb..a6ff042a 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -313,6 +313,7 @@ def merge_short_text_in_array(texts, threshold):
     return result
 
 def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False):
+    t=[]
     if prompt_text is None or len(prompt_text) == 0:
         ref_free = True
     t0 = ttime()
@@ -353,6 +354,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
             prompt_semantic = codes[0, 0]
 
     t1 = ttime()
+    t.append(t1-t0)
 
     if (how_to_cut == i18n("凑四句一切")):
         text = cut1(text)
@@ -376,6 +378,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
 
     for text in texts:
         # 解决输入目标文本的空行导致报错的问题
+        t1 = ttime()
         if (len(text.strip()) == 0):
             continue
         if (text[-1] not in splits): text += "。" if text_language != "en" else "."
@@ -430,7 +433,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         audio_opt.append(audio)
         audio_opt.append(zero_wav)
         t4 = ttime()
-    print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
+        t.extend([t2 - t1,t3 - t2, t4 - t3])
+    print("%.3f\t%.3f\t%.3f\t%.3f" %
+          (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3]))
+          )
     yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
         np.int16
     )
diff --git a/api.py b/api.py
index aa822ca7..d1249ecd 100644
--- a/api.py
+++ b/api.py
@@ -127,7 +127,6 @@ sys.path.append("%s/GPT_SoVITS" % (now_dir))
 
 import signal
 import LangSegment
-from time import time as ttime
 import torch
 import librosa
 import soundfile as sf
@@ -447,7 +446,6 @@ def only_punc(text):
 
 
 def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
-    t0 = ttime()
     prompt_text = prompt_text.strip("\n")
     prompt_language, text = prompt_language, text.strip("\n")
     zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
@@ -465,7 +463,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
         ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
         codes = vq_model.extract_latent(ssl_content)
         prompt_semantic = codes[0, 0]
-    t1 = ttime()
     prompt_language = dict_language[prompt_language.lower()]
     text_language = dict_language[text_language.lower()]
     phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language)
@@ -485,7 +482,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
         bert = bert.to(device).unsqueeze(0)
         all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
         prompt = prompt_semantic.unsqueeze(0).to(device)
-        t2 = ttime()
         with torch.no_grad():
             # pred_semantic = t2s_model.model.infer(
             pred_semantic, idx = t2s_model.model.infer_panel(
@@ -496,7 +492,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
                 # prompt_phone_len=ph_offset,
                 top_k=config['inference']['top_k'],
                 early_stop_num=hz * max_sec)
-        t3 = ttime()
         # print(pred_semantic.shape,idx)
         pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)  # .unsqueeze(0)#mq要多unsqueeze一次
         refer = get_spepc(hps, ref_wav_path)  # .to(device)
@@ -511,7 +506,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
                 0, 0]  ###试试重建不带上prompt部分
         audio_opt.append(audio)
         audio_opt.append(zero_wav)
-        t4 = ttime()
         audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate)
         # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
         if stream_mode == "normal":
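
Note on the aggregation in the new print call: t[0] holds the one-off reference-audio encoding time, and each pass through the text loop appends one per-segment triple, so the strided slices t[1::3], t[2::3] and t[3::3] sum each stage across all segments. A minimal standalone sketch of the pattern (not part of the patch; the segment list and stage comments are illustrative assumptions based on the surrounding code in inference_webui.py):

    from time import time as ttime

    t = []
    t0 = ttime()
    # one-off work: encode the reference audio into prompt_semantic
    t1 = ttime()
    t.append(t1 - t0)                        # t[0]

    for text in ["seg1", "seg2", "seg3"]:    # hypothetical text segments
        t1 = ttime()
        # per-segment stage 1: text frontend (phones/BERT)
        t2 = ttime()
        # per-segment stage 2: GPT semantic-token inference
        t3 = ttime()
        # per-segment stage 3: SoVITS waveform decoding
        t4 = ttime()
        t.extend([t2 - t1, t3 - t2, t4 - t3])  # fills indices 1,2,3 then 4,5,6 ...

    # t[1::3] -> all stage-1 times, t[2::3] -> stage 2, t[3::3] -> stage 3
    print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))

The old single print only reported the final iteration's t2/t3/t4, so per-stage totals over a multi-segment synthesis were lost. In api.py the corresponding t0..t4 timers only fed a logger.info call that was already commented out, so the patch removes them outright along with the ttime import.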