From 6fe3861e730646717f72ebfb1f22dd980194a3f2 Mon Sep 17 00:00:00 2001 From: csh <458761603@qq.com> Date: Tue, 17 Jun 2025 13:39:16 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9C=A8=20stream=5Finfer=20=E8=84=9A=E6=9C=AC?= =?UTF-8?q?=E4=B8=AD=E7=BB=98=E5=88=B6=E7=94=9F=E6=88=90=E7=9A=84=E9=9F=B3?= =?UTF-8?q?=E9=A2=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/stream_v2pro.py | 46 +++++++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/GPT_SoVITS/stream_v2pro.py b/GPT_SoVITS/stream_v2pro.py index 3fd6dbe0..ad4ed26e 100644 --- a/GPT_SoVITS/stream_v2pro.py +++ b/GPT_SoVITS/stream_v2pro.py @@ -9,6 +9,8 @@ from torch.nn import functional as F import soundfile from inference_webui import get_phones_and_bert +import matplotlib.pyplot as plt + class StreamT2SModel(nn.Module): @@ -181,7 +183,7 @@ def export_prov2( ref_bert = ref_bert.to(ref_seq.device) text_seq_id, text_bert_T, norm_text = get_phones_and_bert( - "这是一个简单的示例,真没想到这么简单就完成了。真的神奇。可能这就是狐狸吧.你觉得狐狸神奇吗?", "auto", "v2" + "这是一个简单的示例,真没想到这么简单就完成了。真的神奇。接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2" ) text_seq = torch.LongTensor([text_seq_id]).to(device) text_bert = text_bert_T.T @@ -217,7 +219,7 @@ def export_prov2( print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate) stream_t2s = StreamT2SModel(t2s).to(device) - # stream_t2s = torch.jit.script(stream_t2s) + stream_t2s = torch.jit.script(stream_t2s) ref_audio_sr = resamplex(ref_audio, 16000, 32000) if is_half: @@ -249,30 +251,30 @@ def export_prov2( y_len, y, xy_pos = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k) idx = 1 - audio_index = 0 last_idx = 0 audios = [] + full_audios = [] print("y.shape:", y.shape) while True: y, xy_pos, stop = stream_t2s(idx, top_k, y_len, y, xy_pos) # print("y.shape:", y.shape) # 玄学这档子事说不清楚 - if (y[0,-1] < 60 and idx-last_idx > 25) or stop: - audio = vits.vq_model(y[:,-idx:-1].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0] + if (y[0,-8] < 30 and idx-last_idx > (len(audios)+1) * 25) or stop: + audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0] + full_audios.append(audio) if last_idx == 0: - audio = audio[:-640] + audio = audio[:-1280*8] et = time.time() else: if stop: - audio = audio[last_idx*1280 -640:] + audio = audio[last_idx*1280 -1280*8:] else: - audio = audio[last_idx*1280 -640:-640] + audio = audio[last_idx*1280 -1280*8:-1280*8] print(y[:,-idx+last_idx:]) last_idx = idx # print(f'write {output_path}/out_{audio_index}') # soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000) - audio_index+=1 audios.append(audio) idx+=1 @@ -289,13 +291,37 @@ def export_prov2( print(f'write {output_path}/out_{i}') soundfile.write(f"{output_path}/out_{i}.wav", a.float().detach().cpu().numpy(), 32000) - print("final,",audio_index) print(f"frist token: {et - st:.4f} seconds") print(f"all token: {at - st:.4f} seconds") audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0] soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000) audio = torch.cat(audios, dim=0) soundfile.write(f"{output_path}/out.wav", audio.float().detach().cpu().numpy(), 32000) + + + colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow'] + + fig, axes = plt.subplots(len(full_audios)+1, 1, figsize=(10, 6)) + + max_duration = full_audios[-1].shape[0] + + last_line = 0 + for i,(ax,a) in enumerate(zip(axes[:-1],full_audios)): + ax.plot(a.float().detach().cpu().numpy(), color=colors[i], alpha=0.5, label=f"Audio {i}") + ax.axvline(x=last_line, color=colors[i], linestyle='--') + last_line = a.shape[0]-8*1280 + ax.axvline(x=last_line, color=colors[i], linestyle='--') + ax.set_xlim(0, max_duration) + axes[-1].axvline(x=last_line, color=colors[i], linestyle='--') + + axes[-1].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio') + axes[-1].set_xlim(0, max_duration) + + # plt.title('Overlapped Waveform Comparison') + # plt.xlabel('Sample Number') + # plt.ylabel('Amplitude') + # plt.tight_layout() + plt.show() if __name__ == "__main__":