mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-07 15:33:29 +08:00
stream_infer: 在拼接音频时进行相关性搜索,减少拼接带来基频断裂的情况
This commit is contained in:
parent
03f99256c7
commit
cfb986a9c8
@ -183,6 +183,57 @@ class StepVitsModel(nn.Module):
|
||||
)[0, 0]
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def find_best_audio_offset_fast(reference_audio: Tensor, search_audio: Tensor):
|
||||
ref_len = len(reference_audio)
|
||||
search_len = len(search_audio)
|
||||
|
||||
if search_len < ref_len:
|
||||
raise ValueError(
|
||||
f"搜索音频长度 ({search_len}) 必须大于等于参考音频长度 ({ref_len})"
|
||||
)
|
||||
|
||||
# 使用F.conv1d计算原始互相关
|
||||
reference_flipped = reference_audio.unsqueeze(0).unsqueeze(0)
|
||||
search_padded = search_audio.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# 计算点积
|
||||
dot_products = F.conv1d(search_padded, reference_flipped).squeeze()
|
||||
|
||||
if len(dot_products.shape) == 0:
|
||||
dot_products = dot_products.unsqueeze(0)
|
||||
|
||||
# 计算参考音频的平方和
|
||||
ref_squared_sum = torch.sum(reference_audio**2)
|
||||
|
||||
# 计算搜索音频每个位置的平方和(滑动窗口)
|
||||
search_squared = search_audio**2
|
||||
search_squared_padded = search_squared.unsqueeze(0).unsqueeze(0)
|
||||
ones_kernel = torch.ones(
|
||||
1, 1, ref_len, dtype=search_audio.dtype, device=search_audio.device
|
||||
)
|
||||
|
||||
segment_squared_sums = F.conv1d(search_squared_padded, ones_kernel).squeeze()
|
||||
|
||||
if len(segment_squared_sums.shape) == 0:
|
||||
segment_squared_sums = segment_squared_sums.unsqueeze(0)
|
||||
|
||||
# 计算归一化因子
|
||||
ref_norm = torch.sqrt(ref_squared_sum)
|
||||
segment_norms = torch.sqrt(segment_squared_sums)
|
||||
|
||||
# 避免除零
|
||||
epsilon = 1e-8
|
||||
normalization_factor = ref_norm * segment_norms + epsilon
|
||||
|
||||
# 归一化互相关
|
||||
correlation_scores = dot_products / normalization_factor
|
||||
|
||||
best_offset = torch.argmax(correlation_scores).item()
|
||||
|
||||
return best_offset, correlation_scores
|
||||
|
||||
|
||||
import time
|
||||
|
||||
def test_stream(
|
||||
@ -213,7 +264,7 @@ def test_stream(
|
||||
ref_bert = ref_bert.to(ref_seq.device)
|
||||
|
||||
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
|
||||
"这是一个简单的示例,真没想到这么简单就完成了。真的神奇。接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2"
|
||||
"这是一个简单的示例,真没想到这么简单就完成了,真的神奇,接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2"
|
||||
)
|
||||
text_seq = torch.LongTensor([text_seq_id]).to(device)
|
||||
text_bert = text_bert_T.T
|
||||
@ -283,6 +334,9 @@ def test_stream(
|
||||
idx = 1
|
||||
last_idx = 0
|
||||
audios = []
|
||||
raw_audios = []
|
||||
last_audio_ret = None
|
||||
offset_index = []
|
||||
full_audios = []
|
||||
print("y.shape:", y.shape)
|
||||
cut_id = 0
|
||||
@ -292,7 +346,7 @@ def test_stream(
|
||||
stop = last_token==t2s.EOS
|
||||
print('idx:',idx , 'y.shape:', y.shape, y.shape[1]-idx)
|
||||
|
||||
if last_token < 30 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id:
|
||||
if last_token < 50 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id:
|
||||
cut_id = idx + 7
|
||||
print('trigger:',idx, last_idx, y[:,-idx+last_idx:], y[:,-idx+last_idx:].shape)
|
||||
# y = torch.cat([y, y[:,-1:]], dim=1)
|
||||
@ -312,13 +366,24 @@ def test_stream(
|
||||
audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
|
||||
full_audios.append(audio)
|
||||
if last_idx == 0:
|
||||
last_audio_ret = audio[-1280*8:-1280*8+256]
|
||||
audio = audio[:-1280*8]
|
||||
raw_audios.append(audio)
|
||||
et = time.time()
|
||||
else:
|
||||
if stop:
|
||||
audio = audio[last_idx*1280 -1280*8:]
|
||||
audio_ = audio[last_idx*1280 -1280*8:]
|
||||
raw_audios.append(audio_)
|
||||
i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280])
|
||||
offset_index.append(i)
|
||||
audio = audio_[i:]
|
||||
else:
|
||||
audio = audio[last_idx*1280 -1280*8:-1280*8]
|
||||
audio_ = audio[last_idx*1280 -1280*8:-1280*8]
|
||||
raw_audios.append(audio_)
|
||||
i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280])
|
||||
offset_index.append(i)
|
||||
last_audio_ret = audio[-1280*8:-1280*8+256]
|
||||
audio = audio_[i:]
|
||||
last_idx = idx
|
||||
# print(f'write {output_path}/out_{audio_index}')
|
||||
# soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||
@ -344,11 +409,13 @@ def test_stream(
|
||||
soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||
audio = torch.cat(audios, dim=0)
|
||||
soundfile.write(f"{output_path}/out.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||
audio_raw = torch.cat(raw_audios, dim=0)
|
||||
soundfile.write(f"{output_path}/out.raw.wav", audio_raw.float().detach().cpu().numpy(), 32000)
|
||||
|
||||
|
||||
colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow']
|
||||
|
||||
fig, axes = plt.subplots(len(full_audios)+1, 1, figsize=(10, 6))
|
||||
fig, axes = plt.subplots(len(full_audios)+2, 1, figsize=(10, 6))
|
||||
|
||||
max_duration = full_audios[-1].shape[0]
|
||||
|
||||
@ -360,18 +427,24 @@ def test_stream(
|
||||
ax.axvline(x=last_line, color=colors[i], linestyle='--')
|
||||
ax.set_xlim(0, max_duration)
|
||||
axes[-1].axvline(x=last_line, color=colors[i], linestyle='--')
|
||||
axes[-2].axvline(x=last_line, color=colors[i], linestyle='--')
|
||||
|
||||
axes[-1].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio')
|
||||
axes[-2].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio')
|
||||
axes[-2].set_xlim(0, max_duration)
|
||||
|
||||
axes[-1].plot(audio_raw.float().detach().cpu().numpy(), color='black', label='Raw Audio')
|
||||
axes[-1].set_xlim(0, max_duration)
|
||||
|
||||
for i,y in enumerate(y[0][-idx:]):
|
||||
axes[-1].text(i*1280, 0.05, str(int(y)), fontsize=12, ha='center')
|
||||
# axes[-1].text(i*1280, 0.05, str(int(y)), fontsize=12, ha='center')
|
||||
axes[-1].axvline(x=i*1280, color='gray', linestyle=':', alpha=0.5)
|
||||
|
||||
|
||||
# plt.title('Overlapped Waveform Comparison')
|
||||
# plt.xlabel('Sample Number')
|
||||
# plt.ylabel('Amplitude')
|
||||
# plt.tight_layout()
|
||||
print("offset_index:", offset_index)
|
||||
plt.show()
|
||||
|
||||
|
||||
@ -509,12 +582,14 @@ def export_prov2(
|
||||
|
||||
if __name__ == "__main__":
|
||||
with torch.no_grad():
|
||||
export_prov2(
|
||||
test_stream(
|
||||
gpt_path="GPT_SoVITS/pretrained_models/s1v3.ckpt",
|
||||
vits_path="GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth",
|
||||
version="v2Pro",
|
||||
ref_audio_path="/mnt/g/ad_ref.wav",
|
||||
ref_text="你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说.",
|
||||
# ref_audio_path="/mnt/g/ad_ref.wav",
|
||||
# ref_text="你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说.",
|
||||
ref_audio_path="output/denoise_opt/ht/ht.mp4_0000026560_0000147200.wav",
|
||||
ref_text='说真的,这件衣服才配得上本小姐嘛',
|
||||
output_path="streaming",
|
||||
device="cuda",
|
||||
is_half=True,
|
||||
|
Loading…
x
Reference in New Issue
Block a user