diff --git a/GPT_SoVITS/stream_v2pro.py b/GPT_SoVITS/stream_v2pro.py
index e94c7a43..8c9313bb 100644
--- a/GPT_SoVITS/stream_v2pro.py
+++ b/GPT_SoVITS/stream_v2pro.py
@@ -183,6 +183,57 @@ class StepVitsModel(nn.Module):
         )[0, 0]
 
 
+@torch.jit.script
+def find_best_audio_offset_fast(reference_audio: Tensor, search_audio: Tensor):
+    ref_len = len(reference_audio)
+    search_len = len(search_audio)
+
+    if search_len < ref_len:
+        raise ValueError(
+            f"search audio length ({search_len}) must be >= reference audio length ({ref_len})"
+        )
+
+    # Raw cross-correlation via F.conv1d
+    reference_flipped = reference_audio.unsqueeze(0).unsqueeze(0)
+    search_padded = search_audio.unsqueeze(0).unsqueeze(0)
+
+    # Dot product of the reference against every alignment
+    dot_products = F.conv1d(search_padded, reference_flipped).squeeze()
+
+    if len(dot_products.shape) == 0:
+        dot_products = dot_products.unsqueeze(0)
+
+    # Energy of the reference audio
+    ref_squared_sum = torch.sum(reference_audio**2)
+
+    # Sliding-window energy of the search audio
+    search_squared = search_audio**2
+    search_squared_padded = search_squared.unsqueeze(0).unsqueeze(0)
+    ones_kernel = torch.ones(
+        1, 1, ref_len, dtype=search_audio.dtype, device=search_audio.device
+    )
+
+    segment_squared_sums = F.conv1d(search_squared_padded, ones_kernel).squeeze()
+
+    if len(segment_squared_sums.shape) == 0:
+        segment_squared_sums = segment_squared_sums.unsqueeze(0)
+
+    # Normalization factors
+    ref_norm = torch.sqrt(ref_squared_sum)
+    segment_norms = torch.sqrt(segment_squared_sums)
+
+    # Guard against division by zero
+    epsilon = 1e-8
+    normalization_factor = ref_norm * segment_norms + epsilon
+
+    # Normalized cross-correlation
+    correlation_scores = dot_products / normalization_factor
+
+    best_offset = torch.argmax(correlation_scores).item()
+
+    return best_offset, correlation_scores
+
+
 import time
 
 def test_stream(
@@ -213,7 +264,7 @@ def test_stream(
     ref_bert = ref_bert.to(ref_seq.device)
 
     text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
-        "这是一个简单的示例,真没想到这么简单就完成了。真的神奇。接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2"
+        "这是一个简单的示例,真没想到这么简单就完成了,真的神奇,接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2"
     )
     text_seq = torch.LongTensor([text_seq_id]).to(device)
     text_bert = text_bert_T.T
@@ -283,6 +334,9 @@ def test_stream(
     idx = 1
     last_idx = 0
     audios = []
+    raw_audios = []
+    last_audio_ret = None
+    offset_index = []
     full_audios = []
     print("y.shape:", y.shape)
     cut_id = 0
@@ -292,7 +346,7 @@ def test_stream(
         stop = last_token==t2s.EOS
 
         print('idx:',idx , 'y.shape:', y.shape, y.shape[1]-idx)
-        if last_token < 30 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id:
+        if last_token < 50 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id:
             cut_id = idx + 7
             print('trigger:',idx, last_idx, y[:,-idx+last_idx:], y[:,-idx+last_idx:].shape)
             # y = torch.cat([y, y[:,-1:]], dim=1)
@@ -312,13 +366,24 @@ def test_stream(
             audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
             full_audios.append(audio)
             if last_idx == 0:
+                last_audio_ret = audio[-1280*8:-1280*8+256]
                 audio = audio[:-1280*8]
+                raw_audios.append(audio)
                 et = time.time()
             else:
                 if stop:
-                    audio = audio[last_idx*1280 -1280*8:]
+                    audio_ = audio[last_idx*1280 -1280*8:]
+                    raw_audios.append(audio_)
+                    i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280])
+                    offset_index.append(i)
+                    audio = audio_[i:]
                 else:
-                    audio = audio[last_idx*1280 -1280*8:-1280*8]
+                    audio_ = audio[last_idx*1280 -1280*8:-1280*8]
+                    raw_audios.append(audio_)
+                    i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280])
+                    offset_index.append(i)
+                    last_audio_ret = audio[-1280*8:-1280*8+256]
+                    audio = audio_[i:]
             last_idx = idx
             # print(f'write {output_path}/out_{audio_index}')
             # soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000)
@@ -344,11 +409,13 @@ def test_stream(
     soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
     audio = torch.cat(audios, dim=0)
     soundfile.write(f"{output_path}/out.wav", audio.float().detach().cpu().numpy(), 32000)
+    audio_raw = torch.cat(raw_audios, dim=0)
+    soundfile.write(f"{output_path}/out.raw.wav", audio_raw.float().detach().cpu().numpy(), 32000)
 
     colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow']
 
-    fig, axes = plt.subplots(len(full_audios)+1, 1, figsize=(10, 6))
+    fig, axes = plt.subplots(len(full_audios)+2, 1, figsize=(10, 6))
 
     max_duration = full_audios[-1].shape[0]
 
@@ -360,18 +427,24 @@ def test_stream(
         ax.axvline(x=last_line, color=colors[i], linestyle='--')
         ax.set_xlim(0, max_duration)
         axes[-1].axvline(x=last_line, color=colors[i], linestyle='--')
+        axes[-2].axvline(x=last_line, color=colors[i], linestyle='--')
 
-    axes[-1].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio')
+    axes[-2].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio')
+    axes[-2].set_xlim(0, max_duration)
+
+    axes[-1].plot(audio_raw.float().detach().cpu().numpy(), color='black', label='Raw Audio')
     axes[-1].set_xlim(0, max_duration)
 
     for i,y in enumerate(y[0][-idx:]):
-        axes[-1].text(i*1280, 0.05, str(int(y)), fontsize=12, ha='center')
+        # axes[-1].text(i*1280, 0.05, str(int(y)), fontsize=12, ha='center')
        axes[-1].axvline(x=i*1280, color='gray', linestyle=':', alpha=0.5)
+
    # plt.title('Overlapped Waveform Comparison')
     # plt.xlabel('Sample Number')
     # plt.ylabel('Amplitude')
     # plt.tight_layout()
+    print("offset_index:", offset_index)
 
     plt.show()
@@ -509,12 +582,14 @@ def export_prov2(
 
 if __name__ == "__main__":
     with torch.no_grad():
-        export_prov2(
+        test_stream(
             gpt_path="GPT_SoVITS/pretrained_models/s1v3.ckpt",
             vits_path="GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth",
             version="v2Pro",
-            ref_audio_path="/mnt/g/ad_ref.wav",
-            ref_text="你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说.",
+            # ref_audio_path="/mnt/g/ad_ref.wav",
+            # ref_text="你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说.",
+            ref_audio_path="output/denoise_opt/ht/ht.mp4_0000026560_0000147200.wav",
+            ref_text='说真的,这件衣服才配得上本小姐嘛',
             output_path="streaming",
             device="cuda",
             is_half=True,
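
Review note: the new find_best_audio_offset_fast is a normalized cross-correlation search. The first F.conv1d slides the reference over the search window and yields the dot product at every offset; the second F.conv1d, with an all-ones kernel over the squared samples, yields the sliding-window energy, so each dot product is divided by sqrt(ref_energy * window_energy) and the argmax picks the best waveform alignment independent of loudness. Below is a minimal sanity check -- a sketch only, assuming the patched GPT_SoVITS/stream_v2pro.py is importable; the tensors (search, reference, true_offset) are made up here, mirroring the 256-sample reference and 1280-sample (one frame, 40 ms at 32 kHz) search window used in the streaming loop.

    import torch
    # assumption: the patched module is on PYTHONPATH
    from GPT_SoVITS.stream_v2pro import find_best_audio_offset_fast

    torch.manual_seed(0)
    search = torch.randn(1280)                         # one decoded frame: 1280 samples = 40 ms at 32 kHz
    true_offset = 37                                   # plant the reference at a known position
    reference = search[true_offset:true_offset + 256]  # same length as last_audio_ret in the loop

    offset, scores = find_best_audio_offset_fast(reference, search)
    assert offset == true_offset                       # exact slice -> correlation peak ~1.0 at 37
    print("best offset:", offset, "peak score:", scores[offset].item())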
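The splice logic in test_stream leans on that search to hide re-decode drift: each incremental decode keeps an 8-frame (8 * 1280 samples) overlap with the previous cut, last_audio_ret stores the 256 samples just past the cut, and the next chunk is emitted from the matched offset (audio = audio_[i:]) so consecutive chunks butt together without a click. A toy simulation of that rule follows -- again a sketch; decode, emitted, drift, and redecode_tail are illustrative names, not from the patch.

    import torch
    from GPT_SoVITS.stream_v2pro import find_best_audio_offset_fast

    FRAME = 1280
    torch.manual_seed(1)
    decode = torch.randn(FRAME * 16)                      # stand-in for a full vocoder decode
    emitted = decode[:-FRAME * 8]                         # chunk already sent; 8 frames held back
    last_audio_ret = decode[-FRAME * 8:-FRAME * 8 + 256]  # 256 samples just past the cut

    drift = 5                                             # pretend the re-decode is shifted by 5 samples
    redecode_tail = decode[len(emitted) - drift:]         # what the loop slices as audio_

    i, _ = find_best_audio_offset_fast(last_audio_ret, redecode_tail[:FRAME])
    next_chunk = redecode_tail[i:]                        # emission resumes exactly at the old cut
    assert i == drift
    assert torch.equal(next_chunk[:256], last_audio_ret)

With zero drift the match lands at offset 0, which is why the printed offset_index values in the test run are expected to stay small.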