diff --git a/api.py b/api.py index 34adfbe9..f94de103 100644 --- a/api.py +++ b/api.py @@ -412,7 +412,14 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language) early_stop_num=hz * max_sec) t3 = ttime() # print(pred_semantic.shape,idx) - pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次 + if isinstance(pred_semantic, list) and isinstance(pred_semantic, list): + pred_semantic = pred_semantic[0] + idx=idx[0] + pred_semantic = pred_semantic[-idx:] + pred_semantic = pred_semantic.unsqueeze(0).unsqueeze(0) + else: + pred_semantic = pred_semantic[:,-idx:] + pred_semantic = pred_semantic.unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次 refer = get_spepc(hps, ref_wav_path) # .to(device) if (is_half == True): refer = refer.half().to(device)