diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 05ef486..cc79817 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -331,27 +331,29 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32, ) - with torch.no_grad(): - wav16k, sr = librosa.load(ref_wav_path, sr=16000) - if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000): - raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) - wav16k = torch.from_numpy(wav16k) - zero_wav_torch = torch.from_numpy(zero_wav) - if is_half == True: - wav16k = wav16k.half().to(device) - zero_wav_torch = zero_wav_torch.half().to(device) - else: - wav16k = wav16k.to(device) - zero_wav_torch = zero_wav_torch.to(device) - wav16k = torch.cat([wav16k, zero_wav_torch]) - ssl_content = ssl_model.model(wav16k.unsqueeze(0))[ - "last_hidden_state" - ].transpose( - 1, 2 - ) # .float() - codes = vq_model.extract_latent(ssl_content) - - prompt_semantic = codes[0, 0] + if not ref_free: + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000): + raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + if is_half == True: + wav16k = wav16k.half().to(device) + zero_wav_torch = zero_wav_torch.half().to(device) + else: + wav16k = wav16k.to(device) + zero_wav_torch = zero_wav_torch.to(device) + wav16k = torch.cat([wav16k, zero_wav_torch]) + ssl_content = ssl_model.model(wav16k.unsqueeze(0))[ + "last_hidden_state" + ].transpose( + 1, 2 + ) # .float() + codes = vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompt = prompt_semantic.unsqueeze(0).to(device) + t1 = ttime() if (how_to_cut == i18n("凑四句一切")): @@ -391,7 +393,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, bert = bert.to(device).unsqueeze(0) all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) - prompt = prompt_semantic.unsqueeze(0).to(device) + t2 = ttime() with torch.no_grad(): # pred_semantic = t2s_model.model.infer(