diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 42b7fda..ecbbd2f 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -314,7 +314,7 @@ def merge_short_text_in_array(texts, threshold):
 
 ##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
 # cache_tokens={}#暂未实现清理机制
-cache=None
+cache= {}
 def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False,speed=1,if_freeze=False):
     global cache
     if prompt_text is None or len(prompt_text) == 0:
@@ -380,7 +380,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     if not ref_free:
         phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
 
-    for text in texts:
+    for i_text,text in enumerate(texts):
         # 解决输入目标文本的空行导致报错的问题
         if (len(text.strip()) == 0):
             continue
@@ -400,7 +400,8 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
 
         t2 = ttime()
         # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
-        if(type(cache)!=type(None)and if_freeze==True):pred_semantic=cache
+        # print(cache.keys(),if_freeze)
+        if(i_text in cache and if_freeze==True):pred_semantic=cache[i_text]
         else:
             with torch.no_grad():
                 pred_semantic, idx = t2s_model.model.infer_panel(
@@ -415,7 +416,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                     early_stop_num=hz * max_sec,
                 )
                 pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
-                cache=pred_semantic
+                cache[i_text]=pred_semantic
         t3 = ttime()
         refer = get_spepc(hps, ref_wav_path)  # .to(device)
         if is_half == True:
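
The patch above replaces the single-slot cache (`cache=None`, overwritten on every segment) with a dict keyed by the segment index `i_text`, so that with `if_freeze` enabled each text segment reuses its own previously generated semantic tokens instead of only the last one synthesized. The commented-out `cache_key` line shows a fuller key covering all sampling parameters was once considered; the shipped version keys by position alone and gates reuse on the explicit `if_freeze` flag. Below is a minimal sketch of the resulting pattern; `run_t2s` and `synthesize` are hypothetical stand-ins for the `t2s_model.model.infer_panel` call and the surrounding loop in `get_tts_wav`, not part of the patch:

    cache = {}  # segment index -> cached semantic tokens

    def run_t2s(text):
        # hypothetical stand-in for the real t2s_model.model.infer_panel call
        return "<semantic tokens for %r>" % text

    def synthesize(texts, if_freeze=False):
        outputs = []
        for i_text, text in enumerate(texts):
            if i_text in cache and if_freeze:
                pred_semantic = cache[i_text]  # reuse tokens frozen on an earlier run
            else:
                pred_semantic = run_t2s(text)  # expensive autoregressive step
                cache[i_text] = pred_semantic  # refresh this segment's entry
            outputs.append(pred_semantic)
        return outputs

    synthesize(["hello.", "world."])                    # first run fills cache[0], cache[1]
    synthesize(["hello.", "world."], if_freeze=True)    # second run reuses both entries

Two caveats are visible in the hunks themselves: the pre-existing note on `cache_tokens` (暂未实现清理机制, "no cleanup mechanism implemented yet") still applies, since nothing shown evicts entries from the new dict; and because entries are keyed only by position, a segment whose text changes while `if_freeze` stays on will be served stale tokens.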