diff --git a/api.py b/api.py
index 34adfbe9..06350cf1 100644
--- a/api.py
+++ b/api.py
@@ -109,7 +109,7 @@ import sys
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 sys.path.append("%s/GPT_SoVITS" % (now_dir))
-
+import re
 import signal
 from time import time as ttime
 import torch
@@ -402,6 +402,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
         t2 = ttime()
         with torch.no_grad():
             # pred_semantic = t2s_model.model.infer(
+            print("-" * 30)
             pred_semantic, idx = t2s_model.model.infer_panel(
                 all_phoneme_ids,
                 all_phoneme_len,
@@ -411,8 +412,15 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
                 top_k=config['inference']['top_k'],
                 early_stop_num=hz * max_sec)
         t3 = ttime()
-        # print(pred_semantic.shape,idx)
-        pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)  # .unsqueeze(0)  # mq: needs one extra unsqueeze
+        # print(pred_semantic[:,])
+        if isinstance(pred_semantic, list) and isinstance(idx, list):
+            pred_semantic = pred_semantic[0]
+            idx = idx[0]
+            pred_semantic = pred_semantic[-idx:]
+            pred_semantic = pred_semantic.unsqueeze(0).unsqueeze(0)
+        else:
+            pred_semantic = pred_semantic[:, -idx:]
+            pred_semantic = pred_semantic.unsqueeze(0)  # .unsqueeze(0)  # mq: needs one extra unsqueeze
         refer = get_spepc(hps, ref_wav_path)  # .to(device)
         if (is_half == True):
             refer = refer.half().to(device)
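
Note: below is a minimal sketch (not part of the patch) of what the new branch does. Per the diff, infer_panel can apparently return either a batched tensor plus an integer offset or per-sample lists of tensors and offsets; either way the result is normalized to a [1, 1, T] semantic-token tensor. The helper name normalize_pred_semantic and the dummy shapes are illustrative assumptions only.

import torch

def normalize_pred_semantic(pred_semantic, idx):
    # List path: take the first sample and its offset, keep the last
    # `idx` tokens, then add batch and channel dims -> [1, 1, T].
    if isinstance(pred_semantic, list) and isinstance(idx, list):
        sem, offset = pred_semantic[0], idx[0]
        return sem[-offset:].unsqueeze(0).unsqueeze(0)
    # Tensor path: slice the last `idx` tokens along the time axis,
    # then add one extra leading dim -> [1, 1, T].
    return pred_semantic[:, -idx:].unsqueeze(0)

# Both paths yield the same shape on dummy data:
print(normalize_pred_semantic(torch.arange(10).unsqueeze(0), 4).shape)  # torch.Size([1, 1, 4])
print(normalize_pred_semantic([torch.arange(10)], [4]).shape)           # torch.Size([1, 1, 4])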