在 stream_infer 脚本中绘制生成的音频

This commit is contained in:
csh 2025-06-17 13:39:16 +08:00
parent 131e2ffcb7
commit 6fe3861e73

View File

@ -9,6 +9,8 @@ from torch.nn import functional as F
import soundfile
from inference_webui import get_phones_and_bert
import matplotlib.pyplot as plt
class StreamT2SModel(nn.Module):
@ -181,7 +183,7 @@ def export_prov2(
ref_bert = ref_bert.to(ref_seq.device)
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
"这是一个简单的示例,真没想到这么简单就完成了。真的神奇。可能这就是狐狸吧.你觉得狐狸神奇吗?", "auto", "v2"
"这是一个简单的示例,真没想到这么简单就完成了。真的神奇。接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2"
)
text_seq = torch.LongTensor([text_seq_id]).to(device)
text_bert = text_bert_T.T
@ -217,7 +219,7 @@ def export_prov2(
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
stream_t2s = StreamT2SModel(t2s).to(device)
# stream_t2s = torch.jit.script(stream_t2s)
stream_t2s = torch.jit.script(stream_t2s)
ref_audio_sr = resamplex(ref_audio, 16000, 32000)
if is_half:
@ -249,30 +251,30 @@ def export_prov2(
y_len, y, xy_pos = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
idx = 1
audio_index = 0
last_idx = 0
audios = []
full_audios = []
print("y.shape:", y.shape)
while True:
y, xy_pos, stop = stream_t2s(idx, top_k, y_len, y, xy_pos)
# print("y.shape:", y.shape)
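# NOTE(review): 玄学这档子事说不清楚 — i.e. "this is black magic that can't really be explained"; presumably refers to the hand-tuned thresholds in the split condition below — confirm with the author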
if (y[0,-1] < 60 and idx-last_idx > 25) or stop:
audio = vits.vq_model(y[:,-idx:-1].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
if (y[0,-8] < 30 and idx-last_idx > (len(audios)+1) * 25) or stop:
audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
full_audios.append(audio)
if last_idx == 0:
audio = audio[:-640]
audio = audio[:-1280*8]
et = time.time()
else:
if stop:
audio = audio[last_idx*1280 -640:]
audio = audio[last_idx*1280 -1280*8:]
else:
audio = audio[last_idx*1280 -640:-640]
audio = audio[last_idx*1280 -1280*8:-1280*8]
print(y[:,-idx+last_idx:])
last_idx = idx
# print(f'write {output_path}/out_{audio_index}')
# soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000)
audio_index+=1
audios.append(audio)
idx+=1
@ -289,13 +291,37 @@ def export_prov2(
print(f'write {output_path}/out_{i}')
soundfile.write(f"{output_path}/out_{i}.wav", a.float().detach().cpu().numpy(), 32000)
print("final,",audio_index)
print(f"frist token: {et - st:.4f} seconds")
print(f"all token: {at - st:.4f} seconds")
audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
audio = torch.cat(audios, dim=0)
soundfile.write(f"{output_path}/out.wav", audio.float().detach().cpu().numpy(), 32000)
colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow']
fig, axes = plt.subplots(len(full_audios)+1, 1, figsize=(10, 6))
max_duration = full_audios[-1].shape[0]
last_line = 0
for i,(ax,a) in enumerate(zip(axes[:-1],full_audios)):
ax.plot(a.float().detach().cpu().numpy(), color=colors[i], alpha=0.5, label=f"Audio {i}")
ax.axvline(x=last_line, color=colors[i], linestyle='--')
last_line = a.shape[0]-8*1280
ax.axvline(x=last_line, color=colors[i], linestyle='--')
ax.set_xlim(0, max_duration)
axes[-1].axvline(x=last_line, color=colors[i], linestyle='--')
axes[-1].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio')
axes[-1].set_xlim(0, max_duration)
# plt.title('Overlapped Waveform Comparison')
# plt.xlabel('Sample Number')
# plt.ylabel('Amplitude')
# plt.tight_layout()
plt.show()
if __name__ == "__main__":