From 03f99256c78b027d03ac456c80bf822d21df6aab Mon Sep 17 00:00:00 2001 From: csh <458761603@qq.com> Date: Thu, 19 Jun 2025 00:36:17 +0800 Subject: [PATCH] =?UTF-8?q?stream=5Finfer:=20=E6=9B=B4=E6=96=B9=E4=BE=BF?= =?UTF-8?q?=E6=89=BE=E8=A7=84=E5=BE=8B=E7=9A=84=E5=9B=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/stream_v2pro.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/GPT_SoVITS/stream_v2pro.py b/GPT_SoVITS/stream_v2pro.py index 127dbca7..e94c7a43 100644 --- a/GPT_SoVITS/stream_v2pro.py +++ b/GPT_SoVITS/stream_v2pro.py @@ -121,7 +121,7 @@ class StreamT2SModel(nn.Module): # if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: # stop = True if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS: - return y[:,:-1], xy_pos, last_token, k_cache, v_cache + return y[:,:-1], xy_pos, self.t2s.EOS, k_cache, v_cache # if stop: # if y.shape[1] == 0: @@ -192,7 +192,6 @@ def test_stream( ref_audio_path, ref_text, output_path, - export_bert_and_ssl=False, device="cpu", is_half=True, ): @@ -286,13 +285,30 @@ def test_stream( audios = [] full_audios = [] print("y.shape:", y.shape) + cut_id = 0 while True: y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache) # print("y.shape:", y.shape) stop = last_token==t2s.EOS + print('idx:',idx , 'y.shape:', y.shape, y.shape[1]-idx) + + if last_token < 30 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id: + cut_id = idx + 7 + print('trigger:',idx, last_idx, y[:,-idx+last_idx:], y[:,-idx+last_idx:].shape) + # y = torch.cat([y, y[:,-1:]], dim=1) + # idx+=1 + + if stop : + idx -=1 + print('stop') + print(idx, y[:,-idx+last_idx:]) + print(idx,last_idx, y.shape) + print(y[:,-idx:-idx+20]) + # 玄学这档子事说不清楚 - if (y[0,-8] < 30 and idx-last_idx > (len(audios)+1) * 25) or stop: + if idx == cut_id or stop: + print(f"idx: {idx}, last_idx: {last_idx}, cut_id: {cut_id}, stop: {stop}") audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0] full_audios.append(audio) if last_idx == 0: @@ -303,13 +319,10 @@ def test_stream( audio = audio[last_idx*1280 -1280*8:] else: audio = audio[last_idx*1280 -1280*8:-1280*8] - print(y[:,-idx+last_idx:]) last_idx = idx # print(f'write {output_path}/out_{audio_index}') # soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000) audios.append(audio) - - idx+=1 # print(idx,'/',1500 , y.shape, y[0,-1].item(), stop) if idx>1500: break @@ -317,6 +330,8 @@ def test_stream( if stop: break + idx+=1 + at = time.time() for (i,a) in enumerate(audios): @@ -349,6 +364,10 @@ def test_stream( axes[-1].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio') axes[-1].set_xlim(0, max_duration) + for i,y in enumerate(y[0][-idx:]): + axes[-1].text(i*1280, 0.05, str(int(y)), fontsize=12, ha='center') + axes[-1].axvline(x=i*1280, color='gray', linestyle=':', alpha=0.5) + # plt.title('Overlapped Waveform Comparison') # plt.xlabel('Sample Number') # plt.ylabel('Amplitude')