Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-08-10 10:09:51 +08:00)
stream_infer: search for the best correlation offset when stitching audio chunks, to reduce the fundamental-frequency (F0) breaks introduced by concatenation
This commit is contained in:
parent 03f99256c7
commit cfb986a9c8
@@ -183,6 +183,57 @@ class StepVitsModel(nn.Module):
         )[0, 0]
+
+
+@torch.jit.script
+def find_best_audio_offset_fast(reference_audio: Tensor, search_audio: Tensor):
+    ref_len = len(reference_audio)
+    search_len = len(search_audio)
+
+    if search_len < ref_len:
+        raise ValueError(
+            f"search audio length ({search_len}) must be >= reference audio length ({ref_len})"
+        )
+
+    # Compute the raw cross-correlation with F.conv1d
+    reference_flipped = reference_audio.unsqueeze(0).unsqueeze(0)
+    search_padded = search_audio.unsqueeze(0).unsqueeze(0)
+
+    # Dot product at every lag
+    dot_products = F.conv1d(search_padded, reference_flipped).squeeze()
+
+    if len(dot_products.shape) == 0:
+        dot_products = dot_products.unsqueeze(0)
+
+    # Sum of squares of the reference audio
+    ref_squared_sum = torch.sum(reference_audio**2)
+
+    # Sliding-window sum of squares over the search audio
+    search_squared = search_audio**2
+    search_squared_padded = search_squared.unsqueeze(0).unsqueeze(0)
+    ones_kernel = torch.ones(
+        1, 1, ref_len, dtype=search_audio.dtype, device=search_audio.device
+    )
+
+    segment_squared_sums = F.conv1d(search_squared_padded, ones_kernel).squeeze()
+
+    if len(segment_squared_sums.shape) == 0:
+        segment_squared_sums = segment_squared_sums.unsqueeze(0)
+
+    # Normalization factors
+    ref_norm = torch.sqrt(ref_squared_sum)
+    segment_norms = torch.sqrt(segment_squared_sums)
+
+    # Avoid division by zero
+    epsilon = 1e-8
+    normalization_factor = ref_norm * segment_norms + epsilon
+
+    # Normalized cross-correlation
+    correlation_scores = dot_products / normalization_factor
+
+    best_offset = torch.argmax(correlation_scores).item()
+
+    return best_offset, correlation_scores
 
 
 import time
 
 def test_stream(
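Note on the helper above: PyTorch's F.conv1d computes cross-correlation (it does not flip the kernel), so one call yields the dot product at every lag; dividing by the window norms turns that into normalized cross-correlation in [-1, 1], whose argmax is the best alignment. A minimal self-contained check, not part of the commit, assuming find_best_audio_offset_fast from the hunk above is in scope:

    import torch

    torch.manual_seed(0)
    wave = torch.randn(4096)        # stand-in for decoded audio
    reference = wave[1000:1256]     # 256-sample template, like last_audio_ret below
    search = wave[900:2180]         # 1280-sample head of the next chunk

    offset, scores = find_best_audio_offset_fast(reference, search)
    print(offset)                   # 100: the template starts 100 samples in
    print(float(scores[offset]))    # ~1.0 for an exact match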
@@ -213,7 +264,7 @@ def test_stream(
     ref_bert = ref_bert.to(ref_seq.device)
 
     text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
-        "这是一个简单的示例,真没想到这么简单就完成了。真的神奇。接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2"
+        "这是一个简单的示例,真没想到这么简单就完成了,真的神奇,接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2"
     )
     text_seq = torch.LongTensor([text_seq_id]).to(device)
     text_bert = text_bert_T.T
@@ -283,6 +334,9 @@ def test_stream(
     idx = 1
     last_idx = 0
     audios = []
+    raw_audios = []
+    last_audio_ret = None
+    offset_index = []
     full_audios = []
     print("y.shape:", y.shape)
     cut_id = 0
@@ -292,7 +346,7 @@ def test_stream(
         stop = last_token==t2s.EOS
         print('idx:',idx , 'y.shape:', y.shape, y.shape[1]-idx)
 
-        if last_token < 30 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id:
+        if last_token < 50 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id:
             cut_id = idx + 7
             print('trigger:',idx, last_idx, y[:,-idx+last_idx:], y[:,-idx+last_idx:].shape)
             # y = torch.cat([y, y[:,-1:]], dim=1)
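For scale on the spacing condition in that trigger: elsewhere in this file one semantic token decodes to 1280 samples at 32 kHz, i.e. 40 ms, so 25 tokens is one second of audio. A back-of-the-envelope sketch (illustrative, derived from those constants):

    SAMPLES_PER_TOKEN = 1280
    SR = 32000
    seconds_per_token = SAMPLES_PER_TOKEN / SR            # 0.04 s
    for n_chunks_emitted in range(4):
        min_new_tokens = (n_chunks_emitted + 1) * 25
        print(n_chunks_emitted, min_new_tokens * seconds_per_token, "s of new audio required")
    # -> 1.0 s, 2.0 s, 3.0 s, 4.0 s: later cuts are spaced progressively further apart.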
@@ -312,13 +366,24 @@ def test_stream(
             audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
             full_audios.append(audio)
             if last_idx == 0:
+                last_audio_ret = audio[-1280*8:-1280*8+256]
                 audio = audio[:-1280*8]
+                raw_audios.append(audio)
                 et = time.time()
             else:
                 if stop:
-                    audio = audio[last_idx*1280 -1280*8:]
+                    audio_ = audio[last_idx*1280 -1280*8:]
+                    raw_audios.append(audio_)
+                    i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280])
+                    offset_index.append(i)
+                    audio = audio_[i:]
                 else:
-                    audio = audio[last_idx*1280 -1280*8:-1280*8]
+                    audio_ = audio[last_idx*1280 -1280*8:-1280*8]
+                    raw_audios.append(audio_)
+                    i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280])
+                    offset_index.append(i)
+                    last_audio_ret = audio[-1280*8:-1280*8+256]
+                    audio = audio_[i:]
             last_idx = idx
             # print(f'write {output_path}/out_{audio_index}')
             # soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000)
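How the seam repair in this hunk works: when a chunk is emitted, a 256-sample template (last_audio_ret) is captured at the exact point where the emission was truncated (1280*8 samples before the end). On the next decode the seam region is re-synthesized with more context, so the template is searched for within the first 1280 samples of the new chunk and the chunk is trimmed to start at the match, letting the waveforms meet phase-aligned instead of mid-cycle. A condensed sketch of just that step (hypothetical helper name, not in the commit):

    from torch import Tensor

    def align_next_chunk(prev_template: Tensor, new_chunk: Tensor) -> Tensor:
        # prev_template: 256 samples taken right after the previous emission ended
        # new_chunk: re-decoded audio that re-covers the seam with fresh context
        offset, _ = find_best_audio_offset_fast(prev_template, new_chunk[:1280])
        return new_chunk[offset:]   # start exactly where the previous chunk stopped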
@@ -344,11 +409,13 @@ def test_stream(
     soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
     audio = torch.cat(audios, dim=0)
     soundfile.write(f"{output_path}/out.wav", audio.float().detach().cpu().numpy(), 32000)
+    audio_raw = torch.cat(raw_audios, dim=0)
+    soundfile.write(f"{output_path}/out.raw.wav", audio_raw.float().detach().cpu().numpy(), 32000)
 
 
     colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow']
 
-    fig, axes = plt.subplots(len(full_audios)+1, 1, figsize=(10, 6))
+    fig, axes = plt.subplots(len(full_audios)+2, 1, figsize=(10, 6))
 
     max_duration = full_audios[-1].shape[0]
 
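Writing out.raw.wav next to out.wav makes the effect easy to A/B: the raw track concatenates the chunks untrimmed, while the aligned track drops the matched offset at every seam. A quick comparison after one run of test_stream (assumes output_path="streaming" as in the entry point below):

    import soundfile as sf

    aligned, sr = sf.read("streaming/out.wav")
    raw, _ = sf.read("streaming/out.raw.wav")
    # The length difference should equal the sum of the per-seam offsets printed
    # as offset_index; listen near the seams for clicks in the raw version.
    print(len(raw) - len(aligned), "samples trimmed in total")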
@@ -360,18 +427,24 @@ def test_stream(
         ax.axvline(x=last_line, color=colors[i], linestyle='--')
         ax.set_xlim(0, max_duration)
         axes[-1].axvline(x=last_line, color=colors[i], linestyle='--')
+        axes[-2].axvline(x=last_line, color=colors[i], linestyle='--')
+
-    axes[-1].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio')
+    axes[-2].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio')
+    axes[-2].set_xlim(0, max_duration)
+
+    axes[-1].plot(audio_raw.float().detach().cpu().numpy(), color='black', label='Raw Audio')
     axes[-1].set_xlim(0, max_duration)
 
     for i,y in enumerate(y[0][-idx:]):
-        axes[-1].text(i*1280, 0.05, str(int(y)), fontsize=12, ha='center')
+        # axes[-1].text(i*1280, 0.05, str(int(y)), fontsize=12, ha='center')
         axes[-1].axvline(x=i*1280, color='gray', linestyle=':', alpha=0.5)
 
 
     # plt.title('Overlapped Waveform Comparison')
     # plt.xlabel('Sample Number')
     # plt.ylabel('Amplitude')
     # plt.tight_layout()
+    print("offset_index:", offset_index)
     plt.show()
 
 
@@ -509,12 +582,14 @@ def export_prov2(
 
 if __name__ == "__main__":
     with torch.no_grad():
-        export_prov2(
+        test_stream(
            gpt_path="GPT_SoVITS/pretrained_models/s1v3.ckpt",
            vits_path="GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth",
            version="v2Pro",
-           ref_audio_path="/mnt/g/ad_ref.wav",
-           ref_text="你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说.",
+           # ref_audio_path="/mnt/g/ad_ref.wav",
+           # ref_text="你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说.",
+           ref_audio_path="output/denoise_opt/ht/ht.mp4_0000026560_0000147200.wav",
+           ref_text='说真的,这件衣服才配得上本小姐嘛',
            output_path="streaming",
            device="cuda",
            is_half=True,