stream_infer: 更方便找规律的图

This commit is contained in:
csh 2025-06-19 00:36:17 +08:00
parent 920bbafb12
commit 03f99256c7

View File

@@ -121,7 +121,7 @@ class StreamT2SModel(nn.Module):
# if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
# stop = True
if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS:
return y[:,:-1], xy_pos, last_token, k_cache, v_cache
return y[:,:-1], xy_pos, self.t2s.EOS, k_cache, v_cache
# if stop:
# if y.shape[1] == 0:
@@ -192,7 +192,6 @@ def test_stream(
ref_audio_path,
ref_text,
output_path,
export_bert_and_ssl=False,
device="cpu",
is_half=True,
):
@@ -286,13 +285,30 @@ def test_stream(
audios = []
full_audios = []
print("y.shape:", y.shape)
cut_id = 0
while True:
y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
# print("y.shape:", y.shape)
stop = last_token==t2s.EOS
print('idx:',idx , 'y.shape:', y.shape, y.shape[1]-idx)
if last_token < 30 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id:
cut_id = idx + 7
print('trigger:',idx, last_idx, y[:,-idx+last_idx:], y[:,-idx+last_idx:].shape)
# y = torch.cat([y, y[:,-1:]], dim=1)
# idx+=1
if stop :
idx -=1
print('stop')
print(idx, y[:,-idx+last_idx:])
print(idx,last_idx, y.shape)
print(y[:,-idx:-idx+20])
# 玄学这档子事说不清楚
if (y[0,-8] < 30 and idx-last_idx > (len(audios)+1) * 25) or stop:
if idx == cut_id or stop:
print(f"idx: {idx}, last_idx: {last_idx}, cut_id: {cut_id}, stop: {stop}")
audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
full_audios.append(audio)
if last_idx == 0:
@@ -303,13 +319,10 @@ def test_stream(
audio = audio[last_idx*1280 -1280*8:]
else:
audio = audio[last_idx*1280 -1280*8:-1280*8]
print(y[:,-idx+last_idx:])
last_idx = idx
# print(f'write {output_path}/out_{audio_index}')
# soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000)
audios.append(audio)
idx+=1
# print(idx,'/',1500 , y.shape, y[0,-1].item(), stop)
if idx>1500:
break
@@ -317,6 +330,8 @@ def test_stream(
if stop:
break
idx+=1
at = time.time()
for (i,a) in enumerate(audios):
@@ -349,6 +364,10 @@ def test_stream(
axes[-1].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio')
axes[-1].set_xlim(0, max_duration)
for i,y in enumerate(y[0][-idx:]):
axes[-1].text(i*1280, 0.05, str(int(y)), fontsize=12, ha='center')
axes[-1].axvline(x=i*1280, color='gray', linestyle=':', alpha=0.5)
# plt.title('Overlapped Waveform Comparison')
# plt.xlabel('Sample Number')
# plt.ylabel('Amplitude')