diff --git a/api.py b/api.py
index b8d584e7..720ef529 100644
--- a/api.py
+++ b/api.py
@@ -30,7 +30,7 @@ endpoint: `/`
 使用执行参数指定的参考音频:
 GET:
-    `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
+    `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=中文`
 POST:
 ```json
 {
@@ -41,7 +41,7 @@ POST:
 手动指定当次推理所使用的参考音频:
 GET:
-    `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
+    `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=中文&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=中文`
 POST:
 ```json
 {
@@ -129,6 +129,7 @@ from text.cleaner import clean_text
 from module.mel_processing import spectrogram_torch
 from my_utils import load_audio
 import config as global_config
+from inference_webui import get_tts_wav
 
 g_config = global_config.Config()
 
@@ -316,82 +317,6 @@ dict_language = {
 }
 
 
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
-    t0 = ttime()
-    prompt_text = prompt_text.strip("\n")
-    prompt_language, text = prompt_language, text.strip("\n")
-    zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
-    with torch.no_grad():
-        wav16k, sr = librosa.load(ref_wav_path, sr=16000)
-        wav16k = torch.from_numpy(wav16k)
-        zero_wav_torch = torch.from_numpy(zero_wav)
-        if (is_half == True):
-            wav16k = wav16k.half().to(device)
-            zero_wav_torch = zero_wav_torch.half().to(device)
-        else:
-            wav16k = wav16k.to(device)
-            zero_wav_torch = zero_wav_torch.to(device)
-        wav16k = torch.cat([wav16k, zero_wav_torch])
-        ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
-        codes = vq_model.extract_latent(ssl_content)
-        prompt_semantic = codes[0, 0]
-    t1 = ttime()
-    prompt_language = dict_language[prompt_language]
-    text_language = dict_language[text_language]
-    phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
-    phones1 = cleaned_text_to_sequence(phones1)
-    texts = text.split("\n")
-    audio_opt = []
-
-    for text in texts:
-        phones2, word2ph2, norm_text2 = clean_text(text, text_language)
-        phones2 = cleaned_text_to_sequence(phones2)
-        if (prompt_language == "zh"):
-            bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
-        else:
-            bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to(
-                device)
-        if (text_language == "zh"):
-            bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
-        else:
-            bert2 = torch.zeros((1024, len(phones2))).to(bert1)
-        bert = torch.cat([bert1, bert2], 1)
-
-        all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
-        bert = bert.to(device).unsqueeze(0)
-        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
-        prompt = prompt_semantic.unsqueeze(0).to(device)
-        t2 = ttime()
-        with torch.no_grad():
-            # pred_semantic = t2s_model.model.infer(
-            pred_semantic, idx = t2s_model.model.infer_panel(
-                all_phoneme_ids,
-                all_phoneme_len,
-                prompt,
-                bert,
-                # prompt_phone_len=ph_offset,
-                top_k=config['inference']['top_k'],
-                early_stop_num=hz * max_sec)
-        t3 = ttime()
-        # print(pred_semantic.shape,idx)
-        pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)  # .unsqueeze(0)#mq要多unsqueeze一次
-        refer = get_spepc(hps, ref_wav_path)  # .to(device)
-        if (is_half == True):
-            refer = refer.half().to(device)
-        else:
-            refer = refer.to(device)
-        # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
-        audio = \
-            vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
-                            refer).detach().cpu().numpy()[
-                0, 0]  ###试试重建不带上prompt部分
-        audio_opt.append(audio)
-        audio_opt.append(zero_wav)
-        t4 = ttime()
-    print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
-    yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
-
-
 def handle_control(command):
     if command == "restart":
         os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
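In short, this diff drops api.py's local copy of `get_tts_wav` in favor of the one imported from `inference_webui`, and updates the docstring examples to match: the `text_language`/`prompt_language` values switch from codes like `zh` to full language names like `中文` ("Chinese"), presumably because the imported function keys its language table on the full names. (The two Chinese doc headings read "use the reference audio specified by the startup parameters" and "manually specify the reference audio for this request".) Below is a minimal client sketch matching the documented usage; the `requests` dependency, the output filename, and the assumption that the endpoint returns raw WAV bytes are illustrative and not part of this diff.

```python
# Minimal client sketch for the api.py server documented above.
# Assumptions (not part of this diff): the server runs on the default
# 127.0.0.1:9880 with a default reference audio configured at startup,
# the response body is raw WAV bytes, and `requests` is installed.
import requests

params = {
    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
    "text_language": "中文",  # full language name, per the updated examples
}
resp = requests.get("http://127.0.0.1:9880", params=params)
resp.raise_for_status()

with open("out.wav", "wb") as f:  # "out.wav" is an arbitrary name
    f.write(resp.content)

# Per-request reference audio, mirroring the second GET example; this
# assumes the POST JSON body uses the same keys as the GET query string
# (the JSON bodies themselves are truncated by the hunk context above).
payload = {
    "refer_wav_path": "123.wav",
    "prompt_text": "一二三。",
    "prompt_language": "中文",
    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
    "text_language": "中文",
}
resp = requests.post("http://127.0.0.1:9880", json=payload)
```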