From 131e2ffcb77436dd5fa5a658ec1ad9796ff40598 Mon Sep 17 00:00:00 2001 From: csh <458761603@qq.com> Date: Tue, 17 Jun 2025 03:40:26 +0800 Subject: [PATCH 1/9] =?UTF-8?q?=E5=B0=9D=E8=AF=95=20stream=20infer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/stream_v2pro.py | 313 +++++++++++++++++++++++++++++++++++++ 1 file changed, 313 insertions(+) create mode 100644 GPT_SoVITS/stream_v2pro.py diff --git a/GPT_SoVITS/stream_v2pro.py b/GPT_SoVITS/stream_v2pro.py new file mode 100644 index 00000000..3fd6dbe0 --- /dev/null +++ b/GPT_SoVITS/stream_v2pro.py @@ -0,0 +1,313 @@ +# 这是一个实验性质的实现,旨在探索 stream infer 的可能性。(xiao hai xie zhe wan de) +from typing import List +from export_torch_script import ExportERes2NetV2, SSLModel, T2SModel, VitsModel, get_raw_t2s_model, init_sv_cn, resamplex, sample, spectrogram_torch +import export_torch_script +from my_utils import load_audio +import torch +from torch import LongTensor, Tensor, nn +from torch.nn import functional as F + +import soundfile +from inference_webui import get_phones_and_bert + + +class StreamT2SModel(nn.Module): + def __init__(self, t2s: T2SModel): + super(StreamT2SModel, self).__init__() + self.t2s = t2s + self.k_cache: list[torch.Tensor] = [torch.zeros([1])] + self.v_cache: list[torch.Tensor] = [torch.zeros([1])] + + @torch.jit.export + def pre_infer( + self, + prompts: LongTensor, + ref_seq: LongTensor, + text_seq: LongTensor, + ref_bert: torch.Tensor, + text_bert: torch.Tensor, + top_k: int, + ) -> tuple[int, Tensor, Tensor]: + bert = torch.cat([ref_bert.T, text_bert.T], 1) + all_phoneme_ids = torch.cat([ref_seq, text_seq], 1) + bert = bert.unsqueeze(0) + + x = self.t2s.ar_text_embedding(all_phoneme_ids) + x = x + self.t2s.bert_proj(bert.transpose(1, 2)) + x: torch.Tensor = self.t2s.ar_text_position(x) + + # [1,N,512] [1,N] + # y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) + y = prompts + # x_example = x[:,:,0] * 0.0 + + x_len = x.shape[1] + x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) + + y_emb = self.t2s.ar_audio_embedding(y) + y_len: int = y_emb.shape[1] + prefix_len = y.shape[1] + y_pos = self.t2s.ar_audio_position(y_emb) + xy_pos = torch.concat([x, y_pos], dim=1) + + bsz = x.shape[0] + src_len = x_len + y_len + x_attn_mask_pad = F.pad( + x_attn_mask, + (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y) + value=True, + ) + y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + (x_len, 0), + value=False, + ) + xy_attn_mask = ( + torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) + .unsqueeze(0) + .expand(bsz * self.t2s.num_head, -1, -1) + .view(bsz, self.t2s.num_head, src_len, src_len) + .to(device=x.device, dtype=torch.bool) + ) + + xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.process_prompt( + xy_pos, xy_attn_mask, None + ) + + logits = self.t2s.ar_predict_layer(xy_dec[:, -1]) + logits = logits[:, :-1] + samples = sample( + logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0 + )[0] + y = torch.concat([y, samples], dim=1) + y_emb: Tensor = self.t2s.ar_audio_embedding(y[:, -1:]) + xy_pos: Tensor = ( + y_emb * self.t2s.ar_audio_position.x_scale + + self.t2s.ar_audio_position.alpha + * self.t2s.ar_audio_position.pe[:, y_len].to( + dtype=y_emb.dtype, device=y_emb.device + ) + ) + + self.k_cache = k_cache + self.v_cache = v_cache + return y_len, y, xy_pos + + @torch.jit.export + def decode_next_token( + self, + idx: int, # 记住从1开始 到1500 + top_k: int, + y_len: int, + y: Tensor, + xy_pos: Tensor, + ) -> tuple[Tensor, Tensor, bool]: + # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N] + # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example) + xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.decode_next_token( + xy_pos, self.k_cache, self.v_cache + ) + logits = self.t2s.ar_predict_layer(xy_dec[:, -1]) + + if idx < 11: ###至少预测出10个token不然不给停止(0.4s) + logits = logits[:, :-1] + + samples = sample( + logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0 + )[0] + + y = torch.concat([y, samples], dim=1) + + # if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + # stop = True + if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS: + self.k_cache = [torch.zeros([1])] + self.v_cache = [torch.zeros([1])] + return y[:,:-1], xy_pos, True + + # if stop: + # if y.shape[1] == 0: + # y = torch.concat([y, torch.zeros_like(samples)], dim=1) + # break + + y_emb = self.t2s.ar_audio_embedding(y[:, -1:]) + xy_pos = ( + y_emb * self.t2s.ar_audio_position.x_scale + + self.t2s.ar_audio_position.alpha + * self.t2s.ar_audio_position.pe[:, y_len + idx].to( + dtype=y_emb.dtype, device=y_emb.device + ) + ) + return y, xy_pos, False + + def forward( + self, + idx: int, # 记住从1开始 到1500 + top_k: int, + y_len: int, + y: Tensor, + xy_pos: Tensor, + ): + return self.decode_next_token(idx,top_k,y_len,y,xy_pos) + +import time + +def export_prov2( + gpt_path, + vits_path, + version, + ref_audio_path, + ref_text, + output_path, + export_bert_and_ssl=False, + device="cpu", + is_half=True, +): + if export_torch_script.sv_cn_model == None: + init_sv_cn(device,is_half) + + ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float() + ssl = SSLModel() + + print(f"device: {device}") + + ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert( + ref_text, "all_zh", "v2" + ) + ref_seq = torch.LongTensor([ref_seq_id]).to(device) + ref_bert = ref_bert_T.T + if is_half: + ref_bert = ref_bert.half() + ref_bert = ref_bert.to(ref_seq.device) + + text_seq_id, text_bert_T, norm_text = get_phones_and_bert( + "这是一个简单的示例,真没想到这么简单就完成了。真的神奇。可能这就是狐狸吧.你觉得狐狸神奇吗?", "auto", "v2" + ) + text_seq = torch.LongTensor([text_seq_id]).to(device) + text_bert = text_bert_T.T + if is_half: + text_bert = text_bert.half() + text_bert = text_bert.to(text_seq.device) + + ssl_content = ssl(ref_audio) + if is_half: + ssl_content = ssl_content.half() + ssl_content = ssl_content.to(device) + + sv_model = ExportERes2NetV2(export_torch_script.sv_cn_model) + + # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" + vits = VitsModel(vits_path, version,is_half=is_half,device=device) + vits.eval() + + # gpt_path = "GPT_weights_v2/xw-e15.ckpt" + # dict_s1 = torch.load(gpt_path, map_location=device) + dict_s1 = torch.load(gpt_path, weights_only=False) + raw_t2s = get_raw_t2s_model(dict_s1).to(device) + print("#### get_raw_t2s_model ####") + print(raw_t2s.config) + if is_half: + raw_t2s = raw_t2s.half() + t2s_m = T2SModel(raw_t2s) + t2s_m.eval() + # t2s = torch.jit.script(t2s_m).to(device) + t2s = t2s_m + print("#### script t2s_m ####") + + print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate) + + stream_t2s = StreamT2SModel(t2s).to(device) + # stream_t2s = torch.jit.script(stream_t2s) + + ref_audio_sr = resamplex(ref_audio, 16000, 32000) + if is_half: + ref_audio_sr = ref_audio_sr.half() + ref_audio_sr = ref_audio_sr.to(device) + + top_k = 15 + + codes = vits.vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompts = prompt_semantic.unsqueeze(0) + + audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype) + sv_emb = sv_model(audio_16k) + print("text_seq",text_seq.shape) + + refer = spectrogram_torch( + vits.hann_window, + ref_audio_sr, + vits.hps.data.filter_length, + vits.hps.data.sampling_rate, + vits.hps.data.hop_length, + vits.hps.data.win_length, + center=False, + ) + + st = time.time() + et = time.time() + + y_len, y, xy_pos = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k) + idx = 1 + audio_index = 0 + last_idx = 0 + audios = [] + print("y.shape:", y.shape) + while True: + y, xy_pos, stop = stream_t2s(idx, top_k, y_len, y, xy_pos) + # print("y.shape:", y.shape) + + # 玄学这档子事说不清楚 + if (y[0,-1] < 60 and idx-last_idx > 25) or stop: + audio = vits.vq_model(y[:,-idx:-1].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0] + if last_idx == 0: + audio = audio[:-640] + et = time.time() + else: + if stop: + audio = audio[last_idx*1280 -640:] + else: + audio = audio[last_idx*1280 -640:-640] + print(y[:,-idx+last_idx:]) + last_idx = idx + # print(f'write {output_path}/out_{audio_index}') + # soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000) + audio_index+=1 + audios.append(audio) + + idx+=1 + # print(idx,'/',1500 , y.shape, y[0,-1].item(), stop) + if idx>1500: + break + + if stop: + break + + at = time.time() + + for (i,a) in enumerate(audios): + print(f'write {output_path}/out_{i}') + soundfile.write(f"{output_path}/out_{i}.wav", a.float().detach().cpu().numpy(), 32000) + + print("final,",audio_index) + print(f"frist token: {et - st:.4f} seconds") + print(f"all token: {at - st:.4f} seconds") + audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0] + soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000) + audio = torch.cat(audios, dim=0) + soundfile.write(f"{output_path}/out.wav", audio.float().detach().cpu().numpy(), 32000) + + +if __name__ == "__main__": + with torch.no_grad(): + export_prov2( + gpt_path="GPT_SoVITS/pretrained_models/s1v3.ckpt", + vits_path="GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth", + version="v2Pro", + ref_audio_path="output/denoise_opt/ht/ht.mp4_0000026560_0000147200.wav", + ref_text="真的,这件衣服才配得上本小姐嘛", + output_path="streaming", + export_bert_and_ssl=True, + device="cuda", + is_half=True, + ) From 6fe3861e730646717f72ebfb1f22dd980194a3f2 Mon Sep 17 00:00:00 2001 From: csh <458761603@qq.com> Date: Tue, 17 Jun 2025 13:39:16 +0800 Subject: [PATCH 2/9] =?UTF-8?q?=E5=9C=A8=20stream=5Finfer=20=E8=84=9A?= =?UTF-8?q?=E6=9C=AC=E4=B8=AD=E7=BB=98=E5=88=B6=E7=94=9F=E6=88=90=E7=9A=84?= =?UTF-8?q?=E9=9F=B3=E9=A2=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/stream_v2pro.py | 46 +++++++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/GPT_SoVITS/stream_v2pro.py b/GPT_SoVITS/stream_v2pro.py index 3fd6dbe0..ad4ed26e 100644 --- a/GPT_SoVITS/stream_v2pro.py +++ b/GPT_SoVITS/stream_v2pro.py @@ -9,6 +9,8 @@ from torch.nn import functional as F import soundfile from inference_webui import get_phones_and_bert +import matplotlib.pyplot as plt + class StreamT2SModel(nn.Module): @@ -181,7 +183,7 @@ def export_prov2( ref_bert = ref_bert.to(ref_seq.device) text_seq_id, text_bert_T, norm_text = get_phones_and_bert( - "这是一个简单的示例,真没想到这么简单就完成了。真的神奇。可能这就是狐狸吧.你觉得狐狸神奇吗?", "auto", "v2" + "这是一个简单的示例,真没想到这么简单就完成了。真的神奇。接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2" ) text_seq = torch.LongTensor([text_seq_id]).to(device) text_bert = text_bert_T.T @@ -217,7 +219,7 @@ def export_prov2( print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate) stream_t2s = StreamT2SModel(t2s).to(device) - # stream_t2s = torch.jit.script(stream_t2s) + stream_t2s = torch.jit.script(stream_t2s) ref_audio_sr = resamplex(ref_audio, 16000, 32000) if is_half: @@ -249,30 +251,30 @@ def export_prov2( y_len, y, xy_pos = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k) idx = 1 - audio_index = 0 last_idx = 0 audios = [] + full_audios = [] print("y.shape:", y.shape) while True: y, xy_pos, stop = stream_t2s(idx, top_k, y_len, y, xy_pos) # print("y.shape:", y.shape) # 玄学这档子事说不清楚 - if (y[0,-1] < 60 and idx-last_idx > 25) or stop: - audio = vits.vq_model(y[:,-idx:-1].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0] + if (y[0,-8] < 30 and idx-last_idx > (len(audios)+1) * 25) or stop: + audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0] + full_audios.append(audio) if last_idx == 0: - audio = audio[:-640] + audio = audio[:-1280*8] et = time.time() else: if stop: - audio = audio[last_idx*1280 -640:] + audio = audio[last_idx*1280 -1280*8:] else: - audio = audio[last_idx*1280 -640:-640] + audio = audio[last_idx*1280 -1280*8:-1280*8] print(y[:,-idx+last_idx:]) last_idx = idx # print(f'write {output_path}/out_{audio_index}') # soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000) - audio_index+=1 audios.append(audio) idx+=1 @@ -289,13 +291,37 @@ def export_prov2( print(f'write {output_path}/out_{i}') soundfile.write(f"{output_path}/out_{i}.wav", a.float().detach().cpu().numpy(), 32000) - print("final,",audio_index) print(f"frist token: {et - st:.4f} seconds") print(f"all token: {at - st:.4f} seconds") audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0] soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000) audio = torch.cat(audios, dim=0) soundfile.write(f"{output_path}/out.wav", audio.float().detach().cpu().numpy(), 32000) + + + colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow'] + + fig, axes = plt.subplots(len(full_audios)+1, 1, figsize=(10, 6)) + + max_duration = full_audios[-1].shape[0] + + last_line = 0 + for i,(ax,a) in enumerate(zip(axes[:-1],full_audios)): + ax.plot(a.float().detach().cpu().numpy(), color=colors[i], alpha=0.5, label=f"Audio {i}") + ax.axvline(x=last_line, color=colors[i], linestyle='--') + last_line = a.shape[0]-8*1280 + ax.axvline(x=last_line, color=colors[i], linestyle='--') + ax.set_xlim(0, max_duration) + axes[-1].axvline(x=last_line, color=colors[i], linestyle='--') + + axes[-1].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio') + axes[-1].set_xlim(0, max_duration) + + # plt.title('Overlapped Waveform Comparison') + # plt.xlabel('Sample Number') + # plt.ylabel('Amplitude') + # plt.tight_layout() + plt.show() if __name__ == "__main__": From 920bbafb12ea18c2cd6a5c10f04a2555ca3e977d Mon Sep 17 00:00:00 2001 From: csh <458761603@qq.com> Date: Wed, 18 Jun 2025 01:47:40 +0800 Subject: [PATCH 3/9] =?UTF-8?q?stream=5Finfer=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E5=AF=BC=E5=87=BA=E9=83=A8=E5=88=86=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/export_torch_script.py | 46 ++----- GPT_SoVITS/stream_v2pro.py | 203 +++++++++++++++++++++++++++--- 2 files changed, 196 insertions(+), 53 deletions(-) diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py index bf32ed6e..c0cf03c2 100644 --- a/GPT_SoVITS/export_torch_script.py +++ b/GPT_SoVITS/export_torch_script.py @@ -256,41 +256,21 @@ class T2SBlock: attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask) - attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim) - attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) + # attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim) + # attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) + attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1) attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b) - if padding_mask is not None: - for i in range(batch_size): - # mask = padding_mask[i,:,0] - if self.false.device != padding_mask.device: - self.false = self.false.to(padding_mask.device) - idx = torch.where(padding_mask[i, :, 0] == self.false)[0] - x_item = x[i, idx, :].unsqueeze(0) - attn_item = attn[i, idx, :].unsqueeze(0) - x_item = x_item + attn_item - x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1) - x_item = x_item + self.mlp.forward(x_item) - x_item = F.layer_norm( - x_item, - [self.hidden_dim], - self.norm_w2, - self.norm_b2, - self.norm_eps2, - ) - x[i, idx, :] = x_item.squeeze(0) - x = self.to_mask(x, padding_mask) - else: - x = x + attn - x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1) - x = x + self.mlp.forward(x) - x = F.layer_norm( - x, - [self.hidden_dim], - self.norm_w2, - self.norm_b2, - self.norm_eps2, - ) + x = x + attn + x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1) + x = x + self.mlp.forward(x) + x = F.layer_norm( + x, + [self.hidden_dim], + self.norm_w2, + self.norm_b2, + self.norm_eps2, + ) return x, k_cache, v_cache def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor): diff --git a/GPT_SoVITS/stream_v2pro.py b/GPT_SoVITS/stream_v2pro.py index ad4ed26e..127dbca7 100644 --- a/GPT_SoVITS/stream_v2pro.py +++ b/GPT_SoVITS/stream_v2pro.py @@ -12,13 +12,10 @@ from inference_webui import get_phones_and_bert import matplotlib.pyplot as plt - class StreamT2SModel(nn.Module): def __init__(self, t2s: T2SModel): super(StreamT2SModel, self).__init__() self.t2s = t2s - self.k_cache: list[torch.Tensor] = [torch.zeros([1])] - self.v_cache: list[torch.Tensor] = [torch.zeros([1])] @torch.jit.export def pre_infer( @@ -29,7 +26,7 @@ class StreamT2SModel(nn.Module): ref_bert: torch.Tensor, text_bert: torch.Tensor, top_k: int, - ) -> tuple[int, Tensor, Tensor]: + ) -> tuple[int, Tensor, Tensor, List[Tensor], List[Tensor]]: bert = torch.cat([ref_bert.T, text_bert.T], 1) all_phoneme_ids = torch.cat([ref_seq, text_seq], 1) bert = bert.unsqueeze(0) @@ -91,9 +88,7 @@ class StreamT2SModel(nn.Module): ) ) - self.k_cache = k_cache - self.v_cache = v_cache - return y_len, y, xy_pos + return y_len, y, xy_pos, k_cache, v_cache @torch.jit.export def decode_next_token( @@ -103,11 +98,13 @@ class StreamT2SModel(nn.Module): y_len: int, y: Tensor, xy_pos: Tensor, - ) -> tuple[Tensor, Tensor, bool]: + k_cache: List[Tensor], + v_cache: List[Tensor], + ) -> tuple[Tensor, Tensor, int, List[Tensor], List[Tensor]]: # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N] # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example) xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.decode_next_token( - xy_pos, self.k_cache, self.v_cache + xy_pos, k_cache, v_cache ) logits = self.t2s.ar_predict_layer(xy_dec[:, -1]) @@ -119,13 +116,12 @@ class StreamT2SModel(nn.Module): )[0] y = torch.concat([y, samples], dim=1) + last_token = int(samples[0, 0]) # if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: # stop = True if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS: - self.k_cache = [torch.zeros([1])] - self.v_cache = [torch.zeros([1])] - return y[:,:-1], xy_pos, True + return y[:,:-1], xy_pos, last_token, k_cache, v_cache # if stop: # if y.shape[1] == 0: @@ -140,7 +136,7 @@ class StreamT2SModel(nn.Module): dtype=y_emb.dtype, device=y_emb.device ) ) - return y, xy_pos, False + return y, xy_pos, last_token, k_cache, v_cache def forward( self, @@ -149,12 +145,47 @@ class StreamT2SModel(nn.Module): y_len: int, y: Tensor, xy_pos: Tensor, + k_cache: List[Tensor], + v_cache: List[Tensor], ): - return self.decode_next_token(idx,top_k,y_len,y,xy_pos) + return self.decode_next_token(idx,top_k,y_len,y,xy_pos,k_cache,v_cache) + + +class StepVitsModel(nn.Module): + def __init__(self, vits: VitsModel,sv_model:ExportERes2NetV2): + super().__init__() + self.hps = vits.hps + self.vq_model = vits.vq_model + self.hann_window = vits.hann_window + self.sv = sv_model + + def ref_handle(self, ref_audio_32k): + refer = spectrogram_torch( + self.hann_window, + ref_audio_32k, + self.hps.data.filter_length, + self.hps.data.sampling_rate, + self.hps.data.hop_length, + self.hps.data.win_length, + center=False, + ) + ref_audio_16k = resamplex(ref_audio_32k, 32000, 16000).to(ref_audio_32k.dtype).to(ref_audio_32k.device) + sv_emb = self.sv(ref_audio_16k) + return refer, sv_emb + + def extract_latent(self, ssl_content): + codes = self.vq_model.extract_latent(ssl_content) + return codes[0] + + def forward(self, pred_semantic, text_seq, refer, sv_emb=None): + return self.vq_model( + pred_semantic, text_seq, refer, speed=1.0, sv_emb=sv_emb + )[0, 0] + import time -def export_prov2( +def test_stream( gpt_path, vits_path, version, @@ -249,15 +280,16 @@ def export_prov2( st = time.time() et = time.time() - y_len, y, xy_pos = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k) + y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k) idx = 1 last_idx = 0 audios = [] full_audios = [] print("y.shape:", y.shape) while True: - y, xy_pos, stop = stream_t2s(idx, top_k, y_len, y, xy_pos) + y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache) # print("y.shape:", y.shape) + stop = last_token==t2s.EOS # 玄学这档子事说不清楚 if (y[0,-8] < 30 and idx-last_idx > (len(audios)+1) * 25) or stop: @@ -324,16 +356,147 @@ def export_prov2( plt.show() +def export_prov2( + gpt_path, + vits_path, + version, + ref_audio_path, + ref_text, + output_path, + device="cpu", + is_half=True, +): + if export_torch_script.sv_cn_model == None: + init_sv_cn(device,is_half) + + ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float() + ssl = SSLModel() + + print(f"device: {device}") + + ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert( + ref_text, "all_zh", "v2" + ) + ref_seq = torch.LongTensor([ref_seq_id]).to(device) + ref_bert = ref_bert_T.T + if is_half: + ref_bert = ref_bert.half() + ref_bert = ref_bert.to(ref_seq.device) + + text_seq_id, text_bert_T, norm_text = get_phones_and_bert( + "这是一个简单的示例,真没想到这么简单就完成了。真的神奇。接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2" + ) + text_seq = torch.LongTensor([text_seq_id]).to(device) + text_bert = text_bert_T.T + if is_half: + text_bert = text_bert.half() + text_bert = text_bert.to(text_seq.device) + + ssl_content = ssl(ref_audio) + if is_half: + ssl_content = ssl_content.half() + ssl_content = ssl_content.to(device) + + sv_model = ExportERes2NetV2(export_torch_script.sv_cn_model) + + # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" + vits = VitsModel(vits_path, version,is_half=is_half,device=device) + vits.eval() + vits = StepVitsModel(vits, sv_model) + + # gpt_path = "GPT_weights_v2/xw-e15.ckpt" + # dict_s1 = torch.load(gpt_path, map_location=device) + dict_s1 = torch.load(gpt_path, weights_only=False) + raw_t2s = get_raw_t2s_model(dict_s1).to(device) + print("#### get_raw_t2s_model ####") + print(raw_t2s.config) + if is_half: + raw_t2s = raw_t2s.half() + t2s_m = T2SModel(raw_t2s) + t2s_m.eval() + # t2s = torch.jit.script(t2s_m).to(device) + t2s = t2s_m + print("#### script t2s_m ####") + + print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate) + + stream_t2s = StreamT2SModel(t2s).to(device) + stream_t2s = torch.jit.script(stream_t2s) + + ref_audio_sr = resamplex(ref_audio, 16000, 32000) + if is_half: + ref_audio_sr = ref_audio_sr.half() + ref_audio_sr = ref_audio_sr.to(device) + + top_k = 15 + + prompts = vits.extract_latent(ssl_content) + + audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype) + sv_emb = sv_model(audio_16k) + print("text_seq",text_seq.shape) + # torch.jit.trace() + + refer,sv_emb = vits.ref_handle(ref_audio_sr) + + st = time.time() + et = time.time() + + y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k) + idx = 1 + print("y.shape:", y.shape) + while True: + y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache) + # print("y.shape:", y.shape) + + idx+=1 + # print(idx,'/',1500 , y.shape, y[0,-1].item(), stop) + if idx>1500: + break + + if last_token == t2s.EOS: + break + + at = time.time() + print("EOS:",t2s.EOS) + + print(f"frist token: {et - st:.4f} seconds") + print(f"all token: {at - st:.4f} seconds") + print("sv_emb", sv_emb.shape) + print("refer",refer.shape) + y = y[:,-idx:].unsqueeze(0) + print("y", y.shape) + audio = vits(y, text_seq, refer, sv_emb) + soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000) + + torch._dynamo.mark_dynamic(ssl_content, 2) + torch._dynamo.mark_dynamic(ref_audio_sr, 1) + torch._dynamo.mark_dynamic(ref_seq, 1) + torch._dynamo.mark_dynamic(text_seq, 1) + torch._dynamo.mark_dynamic(ref_bert, 0) + torch._dynamo.mark_dynamic(text_bert, 0) + torch._dynamo.mark_dynamic(refer, 2) + torch._dynamo.mark_dynamic(y, 2) + + inputs = { + "forward": (y, text_seq, refer, sv_emb), + "extract_latent": ssl_content, + "ref_handle": ref_audio_sr, + } + + stream_t2s.save(f"{output_path}/t2s.pt") + torch.jit.trace_module(vits, inputs=inputs, optimize=True).save(f"{output_path}/vits.pt") + + if __name__ == "__main__": with torch.no_grad(): export_prov2( gpt_path="GPT_SoVITS/pretrained_models/s1v3.ckpt", vits_path="GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth", version="v2Pro", - ref_audio_path="output/denoise_opt/ht/ht.mp4_0000026560_0000147200.wav", - ref_text="真的,这件衣服才配得上本小姐嘛", + ref_audio_path="/mnt/g/ad_ref.wav", + ref_text="你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说.", output_path="streaming", - export_bert_and_ssl=True, device="cuda", is_half=True, ) From 03f99256c78b027d03ac456c80bf822d21df6aab Mon Sep 17 00:00:00 2001 From: csh <458761603@qq.com> Date: Thu, 19 Jun 2025 00:36:17 +0800 Subject: [PATCH 4/9] =?UTF-8?q?stream=5Finfer:=20=E6=9B=B4=E6=96=B9?= =?UTF-8?q?=E4=BE=BF=E6=89=BE=E8=A7=84=E5=BE=8B=E7=9A=84=E5=9B=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/stream_v2pro.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/GPT_SoVITS/stream_v2pro.py b/GPT_SoVITS/stream_v2pro.py index 127dbca7..e94c7a43 100644 --- a/GPT_SoVITS/stream_v2pro.py +++ b/GPT_SoVITS/stream_v2pro.py @@ -121,7 +121,7 @@ class StreamT2SModel(nn.Module): # if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: # stop = True if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS: - return y[:,:-1], xy_pos, last_token, k_cache, v_cache + return y[:,:-1], xy_pos, self.t2s.EOS, k_cache, v_cache # if stop: # if y.shape[1] == 0: @@ -192,7 +192,6 @@ def test_stream( ref_audio_path, ref_text, output_path, - export_bert_and_ssl=False, device="cpu", is_half=True, ): @@ -286,13 +285,30 @@ def test_stream( audios = [] full_audios = [] print("y.shape:", y.shape) + cut_id = 0 while True: y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache) # print("y.shape:", y.shape) stop = last_token==t2s.EOS + print('idx:',idx , 'y.shape:', y.shape, y.shape[1]-idx) + + if last_token < 30 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id: + cut_id = idx + 7 + print('trigger:',idx, last_idx, y[:,-idx+last_idx:], y[:,-idx+last_idx:].shape) + # y = torch.cat([y, y[:,-1:]], dim=1) + # idx+=1 + + if stop : + idx -=1 + print('stop') + print(idx, y[:,-idx+last_idx:]) + print(idx,last_idx, y.shape) + print(y[:,-idx:-idx+20]) + # 玄学这档子事说不清楚 - if (y[0,-8] < 30 and idx-last_idx > (len(audios)+1) * 25) or stop: + if idx == cut_id or stop: + print(f"idx: {idx}, last_idx: {last_idx}, cut_id: {cut_id}, stop: {stop}") audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0] full_audios.append(audio) if last_idx == 0: @@ -303,13 +319,10 @@ def test_stream( audio = audio[last_idx*1280 -1280*8:] else: audio = audio[last_idx*1280 -1280*8:-1280*8] - print(y[:,-idx+last_idx:]) last_idx = idx # print(f'write {output_path}/out_{audio_index}') # soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000) audios.append(audio) - - idx+=1 # print(idx,'/',1500 , y.shape, y[0,-1].item(), stop) if idx>1500: break @@ -317,6 +330,8 @@ def test_stream( if stop: break + idx+=1 + at = time.time() for (i,a) in enumerate(audios): @@ -349,6 +364,10 @@ def test_stream( axes[-1].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio') axes[-1].set_xlim(0, max_duration) + for i,y in enumerate(y[0][-idx:]): + axes[-1].text(i*1280, 0.05, str(int(y)), fontsize=12, ha='center') + axes[-1].axvline(x=i*1280, color='gray', linestyle=':', alpha=0.5) + # plt.title('Overlapped Waveform Comparison') # plt.xlabel('Sample Number') # plt.ylabel('Amplitude') From cfb986a9c8e6d6530f67816583ee3ece365c1a79 Mon Sep 17 00:00:00 2001 From: csh <458761603@qq.com> Date: Tue, 1 Jul 2025 20:58:16 +0800 Subject: [PATCH 5/9] =?UTF-8?q?stream=5Finfer:=20=E5=9C=A8=E6=8B=BC?= =?UTF-8?q?=E6=8E=A5=E9=9F=B3=E9=A2=91=E6=97=B6=E8=BF=9B=E8=A1=8C=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E6=80=A7=E6=90=9C=E7=B4=A2=EF=BC=8C=E5=87=8F=E5=B0=91?= =?UTF-8?q?=E6=8B=BC=E6=8E=A5=E5=B8=A6=E6=9D=A5=E5=9F=BA=E9=A2=91=E6=96=AD?= =?UTF-8?q?=E8=A3=82=E7=9A=84=E6=83=85=E5=86=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/stream_v2pro.py | 95 ++++++++++++++++++++++++++++++++++---- 1 file changed, 85 insertions(+), 10 deletions(-) diff --git a/GPT_SoVITS/stream_v2pro.py b/GPT_SoVITS/stream_v2pro.py index e94c7a43..8c9313bb 100644 --- a/GPT_SoVITS/stream_v2pro.py +++ b/GPT_SoVITS/stream_v2pro.py @@ -183,6 +183,57 @@ class StepVitsModel(nn.Module): )[0, 0] +@torch.jit.script +def find_best_audio_offset_fast(reference_audio: Tensor, search_audio: Tensor): + ref_len = len(reference_audio) + search_len = len(search_audio) + + if search_len < ref_len: + raise ValueError( + f"搜索音频长度 ({search_len}) 必须大于等于参考音频长度 ({ref_len})" + ) + + # 使用F.conv1d计算原始互相关 + reference_flipped = reference_audio.unsqueeze(0).unsqueeze(0) + search_padded = search_audio.unsqueeze(0).unsqueeze(0) + + # 计算点积 + dot_products = F.conv1d(search_padded, reference_flipped).squeeze() + + if len(dot_products.shape) == 0: + dot_products = dot_products.unsqueeze(0) + + # 计算参考音频的平方和 + ref_squared_sum = torch.sum(reference_audio**2) + + # 计算搜索音频每个位置的平方和(滑动窗口) + search_squared = search_audio**2 + search_squared_padded = search_squared.unsqueeze(0).unsqueeze(0) + ones_kernel = torch.ones( + 1, 1, ref_len, dtype=search_audio.dtype, device=search_audio.device + ) + + segment_squared_sums = F.conv1d(search_squared_padded, ones_kernel).squeeze() + + if len(segment_squared_sums.shape) == 0: + segment_squared_sums = segment_squared_sums.unsqueeze(0) + + # 计算归一化因子 + ref_norm = torch.sqrt(ref_squared_sum) + segment_norms = torch.sqrt(segment_squared_sums) + + # 避免除零 + epsilon = 1e-8 + normalization_factor = ref_norm * segment_norms + epsilon + + # 归一化互相关 + correlation_scores = dot_products / normalization_factor + + best_offset = torch.argmax(correlation_scores).item() + + return best_offset, correlation_scores + + import time def test_stream( @@ -213,7 +264,7 @@ def test_stream( ref_bert = ref_bert.to(ref_seq.device) text_seq_id, text_bert_T, norm_text = get_phones_and_bert( - "这是一个简单的示例,真没想到这么简单就完成了。真的神奇。接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2" + "这是一个简单的示例,真没想到这么简单就完成了,真的神奇,接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2" ) text_seq = torch.LongTensor([text_seq_id]).to(device) text_bert = text_bert_T.T @@ -283,6 +334,9 @@ def test_stream( idx = 1 last_idx = 0 audios = [] + raw_audios = [] + last_audio_ret = None + offset_index = [] full_audios = [] print("y.shape:", y.shape) cut_id = 0 @@ -292,7 +346,7 @@ def test_stream( stop = last_token==t2s.EOS print('idx:',idx , 'y.shape:', y.shape, y.shape[1]-idx) - if last_token < 30 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id: + if last_token < 50 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id: cut_id = idx + 7 print('trigger:',idx, last_idx, y[:,-idx+last_idx:], y[:,-idx+last_idx:].shape) # y = torch.cat([y, y[:,-1:]], dim=1) @@ -312,13 +366,24 @@ def test_stream( audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0] full_audios.append(audio) if last_idx == 0: + last_audio_ret = audio[-1280*8:-1280*8+256] audio = audio[:-1280*8] + raw_audios.append(audio) et = time.time() else: if stop: - audio = audio[last_idx*1280 -1280*8:] + audio_ = audio[last_idx*1280 -1280*8:] + raw_audios.append(audio_) + i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280]) + offset_index.append(i) + audio = audio_[i:] else: - audio = audio[last_idx*1280 -1280*8:-1280*8] + audio_ = audio[last_idx*1280 -1280*8:-1280*8] + raw_audios.append(audio_) + i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280]) + offset_index.append(i) + last_audio_ret = audio[-1280*8:-1280*8+256] + audio = audio_[i:] last_idx = idx # print(f'write {output_path}/out_{audio_index}') # soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000) @@ -344,11 +409,13 @@ def test_stream( soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000) audio = torch.cat(audios, dim=0) soundfile.write(f"{output_path}/out.wav", audio.float().detach().cpu().numpy(), 32000) + audio_raw = torch.cat(raw_audios, dim=0) + soundfile.write(f"{output_path}/out.raw.wav", audio_raw.float().detach().cpu().numpy(), 32000) colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow'] - fig, axes = plt.subplots(len(full_audios)+1, 1, figsize=(10, 6)) + fig, axes = plt.subplots(len(full_audios)+2, 1, figsize=(10, 6)) max_duration = full_audios[-1].shape[0] @@ -360,18 +427,24 @@ def test_stream( ax.axvline(x=last_line, color=colors[i], linestyle='--') ax.set_xlim(0, max_duration) axes[-1].axvline(x=last_line, color=colors[i], linestyle='--') + axes[-2].axvline(x=last_line, color=colors[i], linestyle='--') - axes[-1].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio') + axes[-2].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio') + axes[-2].set_xlim(0, max_duration) + + axes[-1].plot(audio_raw.float().detach().cpu().numpy(), color='black', label='Raw Audio') axes[-1].set_xlim(0, max_duration) for i,y in enumerate(y[0][-idx:]): - axes[-1].text(i*1280, 0.05, str(int(y)), fontsize=12, ha='center') + # axes[-1].text(i*1280, 0.05, str(int(y)), fontsize=12, ha='center') axes[-1].axvline(x=i*1280, color='gray', linestyle=':', alpha=0.5) + # plt.title('Overlapped Waveform Comparison') # plt.xlabel('Sample Number') # plt.ylabel('Amplitude') # plt.tight_layout() + print("offset_index:", offset_index) plt.show() @@ -509,12 +582,14 @@ def export_prov2( if __name__ == "__main__": with torch.no_grad(): - export_prov2( + test_stream( gpt_path="GPT_SoVITS/pretrained_models/s1v3.ckpt", vits_path="GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth", version="v2Pro", - ref_audio_path="/mnt/g/ad_ref.wav", - ref_text="你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说.", + # ref_audio_path="/mnt/g/ad_ref.wav", + # ref_text="你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说.", + ref_audio_path="output/denoise_opt/ht/ht.mp4_0000026560_0000147200.wav", + ref_text='说真的,这件衣服才配得上本小姐嘛', output_path="streaming", device="cuda", is_half=True, From 4c091e34021f91e6b4898006ba7eade90d2edfbb Mon Sep 17 00:00:00 2001 From: csh <458761603@qq.com> Date: Tue, 1 Jul 2025 21:04:14 +0800 Subject: [PATCH 6/9] =?UTF-8?q?stream=5Finfer:=20=E5=AF=BC=E5=87=BA=20`fin?= =?UTF-8?q?d=5Fbest=5Faudio=5Foffset=5Ffast`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/stream_v2pro.py | 1 + 1 file changed, 1 insertion(+) diff --git a/GPT_SoVITS/stream_v2pro.py b/GPT_SoVITS/stream_v2pro.py index 8c9313bb..c7b112de 100644 --- a/GPT_SoVITS/stream_v2pro.py +++ b/GPT_SoVITS/stream_v2pro.py @@ -578,6 +578,7 @@ def export_prov2( stream_t2s.save(f"{output_path}/t2s.pt") torch.jit.trace_module(vits, inputs=inputs, optimize=True).save(f"{output_path}/vits.pt") + torch.jit.script(find_best_audio_offset_fast, optimize=True).save(f"{output_path}/find_best_audio_offset_fast.pt") if __name__ == "__main__": From b1c986d10b25e181701b3301e6abafdab4374818 Mon Sep 17 00:00:00 2001 From: csh <458761603@qq.com> Date: Mon, 7 Jul 2025 17:49:07 +0800 Subject: [PATCH 7/9] =?UTF-8?q?stream=5Finfer:=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E6=B3=A2=E5=BD=A2=E6=98=BE=E7=A4=BA=EF=BC=8C=E6=96=B9=E4=BE=BF?= =?UTF-8?q?=E5=AF=B9=E6=AF=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/stream_v2pro.py | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/GPT_SoVITS/stream_v2pro.py b/GPT_SoVITS/stream_v2pro.py index c7b112de..e80293b7 100644 --- a/GPT_SoVITS/stream_v2pro.py +++ b/GPT_SoVITS/stream_v2pro.py @@ -415,35 +415,21 @@ def test_stream( colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow'] - fig, axes = plt.subplots(len(full_audios)+2, 1, figsize=(10, 6)) - max_duration = full_audios[-1].shape[0] + plt.xlim(0, max_duration) last_line = 0 - for i,(ax,a) in enumerate(zip(axes[:-1],full_audios)): - ax.plot(a.float().detach().cpu().numpy(), color=colors[i], alpha=0.5, label=f"Audio {i}") - ax.axvline(x=last_line, color=colors[i], linestyle='--') + + for i,a in enumerate(full_audios): + plt.plot((a+2.0*i).float().detach().cpu().numpy(), color=colors[i], alpha=0.5, label=f"Audio {i}") + # plt.axvline(x=last_line, color=colors[i], linestyle='--') last_line = a.shape[0]-8*1280 - ax.axvline(x=last_line, color=colors[i], linestyle='--') - ax.set_xlim(0, max_duration) - axes[-1].axvline(x=last_line, color=colors[i], linestyle='--') - axes[-2].axvline(x=last_line, color=colors[i], linestyle='--') + plt.axvline(x=last_line, color=colors[i], linestyle='--') - axes[-2].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio') - axes[-2].set_xlim(0, max_duration) + plt.plot((audio-2.0).float().detach().cpu().numpy(), color='black', label='Final Audio') - axes[-1].plot(audio_raw.float().detach().cpu().numpy(), color='black', label='Raw Audio') - axes[-1].set_xlim(0, max_duration) + plt.plot((audio_raw-4.0).float().detach().cpu().numpy(), color='cyan', label='Raw Audio') - for i,y in enumerate(y[0][-idx:]): - # axes[-1].text(i*1280, 0.05, str(int(y)), fontsize=12, ha='center') - axes[-1].axvline(x=i*1280, color='gray', linestyle=':', alpha=0.5) - - - # plt.title('Overlapped Waveform Comparison') - # plt.xlabel('Sample Number') - # plt.ylabel('Amplitude') - # plt.tight_layout() print("offset_index:", offset_index) plt.show() From 5e318e2f352acf0eeebd873aff308a99a2a20e98 Mon Sep 17 00:00:00 2001 From: csh <458761603@qq.com> Date: Wed, 16 Jul 2025 13:51:18 +0800 Subject: [PATCH 8/9] =?UTF-8?q?stream=5Fv2pro.py=20=E4=BB=8E=E5=91=BD?= =?UTF-8?q?=E4=BB=A4=E8=A1=8C=E8=AF=BB=E5=8F=96=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/stream_v2pro.py | 46 +++++++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/GPT_SoVITS/stream_v2pro.py b/GPT_SoVITS/stream_v2pro.py index e80293b7..0e615a7e 100644 --- a/GPT_SoVITS/stream_v2pro.py +++ b/GPT_SoVITS/stream_v2pro.py @@ -566,18 +566,42 @@ def export_prov2( torch.jit.trace_module(vits, inputs=inputs, optimize=True).save(f"{output_path}/vits.pt") torch.jit.script(find_best_audio_offset_fast, optimize=True).save(f"{output_path}/find_best_audio_offset_fast.pt") +import argparse +import os if __name__ == "__main__": + parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool") + parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file") + parser.add_argument( + "--sovits_model", required=True, help="Path to the SoVITS model file" + ) + parser.add_argument( + "--ref_audio", required=True, help="Path to the reference audio file" + ) + parser.add_argument( + "--ref_text", required=True, help="Path to the reference text file" + ) + parser.add_argument( + "--output_path", required=True, help="Path to the output directory" + ) + parser.add_argument("--device", help="Device to use", default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--version", help="version of the model", default="v2Pro") + parser.add_argument("--no-half", action="store_true", help = "Do not use half precision for model weights") + + args = parser.parse_args() + + if not os.path.exists(args.output_path): + os.makedirs(args.output_path) + + is_half = not args.no_half with torch.no_grad(): - test_stream( - gpt_path="GPT_SoVITS/pretrained_models/s1v3.ckpt", - vits_path="GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth", - version="v2Pro", - # ref_audio_path="/mnt/g/ad_ref.wav", - # ref_text="你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说.", - ref_audio_path="output/denoise_opt/ht/ht.mp4_0000026560_0000147200.wav", - ref_text='说真的,这件衣服才配得上本小姐嘛', - output_path="streaming", - device="cuda", - is_half=True, + export_prov2( + gpt_path=args.gpt_model, + vits_path=args.sovits_model, + version=args.version, + ref_audio_path=args.ref_audio, + ref_text=args.ref_text, + output_path=args.output_path, + device=args.device, + is_half=is_half, ) From 6d82af146b133cf92fd5a4d699cceb38dbf82838 Mon Sep 17 00:00:00 2001 From: csh <458761603@qq.com> Date: Wed, 16 Jul 2025 16:00:31 +0800 Subject: [PATCH 9/9] =?UTF-8?q?stream=5Fv2pro.py=20=E5=87=8F=E5=B0=91?= =?UTF-8?q?=E7=94=A8=E4=BA=8E=E5=AF=BC=E5=87=BA=E7=9A=84=E6=96=87=E6=9C=AC?= =?UTF-8?q?=E9=95=BF=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/stream_v2pro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GPT_SoVITS/stream_v2pro.py b/GPT_SoVITS/stream_v2pro.py index 0e615a7e..0a7712ad 100644 --- a/GPT_SoVITS/stream_v2pro.py +++ b/GPT_SoVITS/stream_v2pro.py @@ -462,7 +462,7 @@ def export_prov2( ref_bert = ref_bert.to(ref_seq.device) text_seq_id, text_bert_T, norm_text = get_phones_and_bert( - "这是一个简单的示例,真没想到这么简单就完成了。真的神奇。接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2" + "这是一个简单的示例,真没想到这么简单就完成了.The King and His Stories.Once there was a king.He likes to write stories, but his stories were not good.", "auto", "v2" ) text_seq = torch.LongTensor([text_seq_id]).to(device) text_bert = text_bert_T.T