mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-31 19:43:09 +08:00
在 stream_infer 脚本中绘制生成的音频
This commit is contained in:
parent
131e2ffcb7
commit
6fe3861e73
@ -9,6 +9,8 @@ from torch.nn import functional as F
|
||||
|
||||
import soundfile
|
||||
from inference_webui import get_phones_and_bert
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
|
||||
class StreamT2SModel(nn.Module):
|
||||
@ -181,7 +183,7 @@ def export_prov2(
|
||||
ref_bert = ref_bert.to(ref_seq.device)
|
||||
|
||||
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
|
||||
"这是一个简单的示例,真没想到这么简单就完成了。真的神奇。可能这就是狐狸吧.你觉得狐狸神奇吗?", "auto", "v2"
|
||||
"这是一个简单的示例,真没想到这么简单就完成了。真的神奇。接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2"
|
||||
)
|
||||
text_seq = torch.LongTensor([text_seq_id]).to(device)
|
||||
text_bert = text_bert_T.T
|
||||
@ -217,7 +219,7 @@ def export_prov2(
|
||||
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
|
||||
|
||||
stream_t2s = StreamT2SModel(t2s).to(device)
|
||||
# stream_t2s = torch.jit.script(stream_t2s)
|
||||
stream_t2s = torch.jit.script(stream_t2s)
|
||||
|
||||
ref_audio_sr = resamplex(ref_audio, 16000, 32000)
|
||||
if is_half:
|
||||
@ -249,30 +251,30 @@ def export_prov2(
|
||||
|
||||
y_len, y, xy_pos = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||
idx = 1
|
||||
audio_index = 0
|
||||
last_idx = 0
|
||||
audios = []
|
||||
full_audios = []
|
||||
print("y.shape:", y.shape)
|
||||
while True:
|
||||
y, xy_pos, stop = stream_t2s(idx, top_k, y_len, y, xy_pos)
|
||||
# print("y.shape:", y.shape)
|
||||
|
||||
# 玄学这档子事说不清楚
|
||||
if (y[0,-1] < 60 and idx-last_idx > 25) or stop:
|
||||
audio = vits.vq_model(y[:,-idx:-1].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
|
||||
if (y[0,-8] < 30 and idx-last_idx > (len(audios)+1) * 25) or stop:
|
||||
audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
|
||||
full_audios.append(audio)
|
||||
if last_idx == 0:
|
||||
audio = audio[:-640]
|
||||
audio = audio[:-1280*8]
|
||||
et = time.time()
|
||||
else:
|
||||
if stop:
|
||||
audio = audio[last_idx*1280 -640:]
|
||||
audio = audio[last_idx*1280 -1280*8:]
|
||||
else:
|
||||
audio = audio[last_idx*1280 -640:-640]
|
||||
audio = audio[last_idx*1280 -1280*8:-1280*8]
|
||||
print(y[:,-idx+last_idx:])
|
||||
last_idx = idx
|
||||
# print(f'write {output_path}/out_{audio_index}')
|
||||
# soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||
audio_index+=1
|
||||
audios.append(audio)
|
||||
|
||||
idx+=1
|
||||
@ -289,13 +291,37 @@ def export_prov2(
|
||||
print(f'write {output_path}/out_{i}')
|
||||
soundfile.write(f"{output_path}/out_{i}.wav", a.float().detach().cpu().numpy(), 32000)
|
||||
|
||||
print("final,",audio_index)
|
||||
print(f"frist token: {et - st:.4f} seconds")
|
||||
print(f"all token: {at - st:.4f} seconds")
|
||||
audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
|
||||
soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||
audio = torch.cat(audios, dim=0)
|
||||
soundfile.write(f"{output_path}/out.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||
|
||||
|
||||
colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow']
|
||||
|
||||
fig, axes = plt.subplots(len(full_audios)+1, 1, figsize=(10, 6))
|
||||
|
||||
max_duration = full_audios[-1].shape[0]
|
||||
|
||||
last_line = 0
|
||||
for i,(ax,a) in enumerate(zip(axes[:-1],full_audios)):
|
||||
ax.plot(a.float().detach().cpu().numpy(), color=colors[i], alpha=0.5, label=f"Audio {i}")
|
||||
ax.axvline(x=last_line, color=colors[i], linestyle='--')
|
||||
last_line = a.shape[0]-8*1280
|
||||
ax.axvline(x=last_line, color=colors[i], linestyle='--')
|
||||
ax.set_xlim(0, max_duration)
|
||||
axes[-1].axvline(x=last_line, color=colors[i], linestyle='--')
|
||||
|
||||
axes[-1].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio')
|
||||
axes[-1].set_xlim(0, max_duration)
|
||||
|
||||
# plt.title('Overlapped Waveform Comparison')
|
||||
# plt.xlabel('Sample Number')
|
||||
# plt.ylabel('Amplitude')
|
||||
# plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user