stream_infer: 在拼接音频时进行相关性搜索,减少拼接带来基频断裂的情况

This commit is contained in:
csh 2025-07-01 20:58:16 +08:00
parent 03f99256c7
commit cfb986a9c8

View File

@ -183,6 +183,57 @@ class StepVitsModel(nn.Module):
)[0, 0]
@torch.jit.script
def find_best_audio_offset_fast(reference_audio: Tensor, search_audio: Tensor):
ref_len = len(reference_audio)
search_len = len(search_audio)
if search_len < ref_len:
raise ValueError(
f"搜索音频长度 ({search_len}) 必须大于等于参考音频长度 ({ref_len})"
)
# 使用F.conv1d计算原始互相关
reference_flipped = reference_audio.unsqueeze(0).unsqueeze(0)
search_padded = search_audio.unsqueeze(0).unsqueeze(0)
# 计算点积
dot_products = F.conv1d(search_padded, reference_flipped).squeeze()
if len(dot_products.shape) == 0:
dot_products = dot_products.unsqueeze(0)
# 计算参考音频的平方和
ref_squared_sum = torch.sum(reference_audio**2)
# 计算搜索音频每个位置的平方和(滑动窗口)
search_squared = search_audio**2
search_squared_padded = search_squared.unsqueeze(0).unsqueeze(0)
ones_kernel = torch.ones(
1, 1, ref_len, dtype=search_audio.dtype, device=search_audio.device
)
segment_squared_sums = F.conv1d(search_squared_padded, ones_kernel).squeeze()
if len(segment_squared_sums.shape) == 0:
segment_squared_sums = segment_squared_sums.unsqueeze(0)
# 计算归一化因子
ref_norm = torch.sqrt(ref_squared_sum)
segment_norms = torch.sqrt(segment_squared_sums)
# 避免除零
epsilon = 1e-8
normalization_factor = ref_norm * segment_norms + epsilon
# 归一化互相关
correlation_scores = dot_products / normalization_factor
best_offset = torch.argmax(correlation_scores).item()
return best_offset, correlation_scores
import time
def test_stream(
@ -213,7 +264,7 @@ def test_stream(
ref_bert = ref_bert.to(ref_seq.device)
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
"这是一个简单的示例,真没想到这么简单就完成了。真的神奇。接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2"
"这是一个简单的示例,真没想到这么简单就完成了,真的神奇,接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2"
)
text_seq = torch.LongTensor([text_seq_id]).to(device)
text_bert = text_bert_T.T
@ -283,6 +334,9 @@ def test_stream(
idx = 1
last_idx = 0
audios = []
raw_audios = []
last_audio_ret = None
offset_index = []
full_audios = []
print("y.shape:", y.shape)
cut_id = 0
@ -292,7 +346,7 @@ def test_stream(
stop = last_token==t2s.EOS
print('idx:',idx , 'y.shape:', y.shape, y.shape[1]-idx)
if last_token < 30 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id:
if last_token < 50 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id:
cut_id = idx + 7
print('trigger:',idx, last_idx, y[:,-idx+last_idx:], y[:,-idx+last_idx:].shape)
# y = torch.cat([y, y[:,-1:]], dim=1)
@ -312,13 +366,24 @@ def test_stream(
audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
full_audios.append(audio)
if last_idx == 0:
last_audio_ret = audio[-1280*8:-1280*8+256]
audio = audio[:-1280*8]
raw_audios.append(audio)
et = time.time()
else:
if stop:
audio = audio[last_idx*1280 -1280*8:]
audio_ = audio[last_idx*1280 -1280*8:]
raw_audios.append(audio_)
i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280])
offset_index.append(i)
audio = audio_[i:]
else:
audio = audio[last_idx*1280 -1280*8:-1280*8]
audio_ = audio[last_idx*1280 -1280*8:-1280*8]
raw_audios.append(audio_)
i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280])
offset_index.append(i)
last_audio_ret = audio[-1280*8:-1280*8+256]
audio = audio_[i:]
last_idx = idx
# print(f'write {output_path}/out_{audio_index}')
# soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000)
@ -344,11 +409,13 @@ def test_stream(
soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
audio = torch.cat(audios, dim=0)
soundfile.write(f"{output_path}/out.wav", audio.float().detach().cpu().numpy(), 32000)
audio_raw = torch.cat(raw_audios, dim=0)
soundfile.write(f"{output_path}/out.raw.wav", audio_raw.float().detach().cpu().numpy(), 32000)
colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow']
fig, axes = plt.subplots(len(full_audios)+1, 1, figsize=(10, 6))
fig, axes = plt.subplots(len(full_audios)+2, 1, figsize=(10, 6))
max_duration = full_audios[-1].shape[0]
@ -360,18 +427,24 @@ def test_stream(
ax.axvline(x=last_line, color=colors[i], linestyle='--')
ax.set_xlim(0, max_duration)
axes[-1].axvline(x=last_line, color=colors[i], linestyle='--')
axes[-2].axvline(x=last_line, color=colors[i], linestyle='--')
axes[-1].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio')
axes[-2].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio')
axes[-2].set_xlim(0, max_duration)
axes[-1].plot(audio_raw.float().detach().cpu().numpy(), color='black', label='Raw Audio')
axes[-1].set_xlim(0, max_duration)
for i,y in enumerate(y[0][-idx:]):
axes[-1].text(i*1280, 0.05, str(int(y)), fontsize=12, ha='center')
# axes[-1].text(i*1280, 0.05, str(int(y)), fontsize=12, ha='center')
axes[-1].axvline(x=i*1280, color='gray', linestyle=':', alpha=0.5)
# plt.title('Overlapped Waveform Comparison')
# plt.xlabel('Sample Number')
# plt.ylabel('Amplitude')
# plt.tight_layout()
print("offset_index:", offset_index)
plt.show()
@ -509,12 +582,14 @@ def export_prov2(
if __name__ == "__main__":
with torch.no_grad():
export_prov2(
test_stream(
gpt_path="GPT_SoVITS/pretrained_models/s1v3.ckpt",
vits_path="GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth",
version="v2Pro",
ref_audio_path="/mnt/g/ad_ref.wav",
ref_text="你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说.",
# ref_audio_path="/mnt/g/ad_ref.wav",
# ref_text="你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说.",
ref_audio_path="output/denoise_opt/ht/ht.mp4_0000026560_0000147200.wav",
ref_text='说真的,这件衣服才配得上本小姐嘛',
output_path="streaming",
device="cuda",
is_half=True,