From 920bbafb12ea18c2cd6a5c10f04a2555ca3e977d Mon Sep 17 00:00:00 2001
From: csh <458761603@qq.com>
Date: Wed, 18 Jun 2025 01:47:40 +0800
Subject: [PATCH] stream_infer: add the export part.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 GPT_SoVITS/export_torch_script.py |  46 ++-----
 GPT_SoVITS/stream_v2pro.py        | 203 +++++++++++++++++++++++++++---
 2 files changed, 196 insertions(+), 53 deletions(-)

diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py
index bf32ed6e..c0cf03c2 100644
--- a/GPT_SoVITS/export_torch_script.py
+++ b/GPT_SoVITS/export_torch_script.py
@@ -256,41 +256,21 @@
 
         attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
 
-        attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
-        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        # attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
+        # attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
         attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
 
-        if padding_mask is not None:
-            for i in range(batch_size):
-                # mask = padding_mask[i,:,0]
-                if self.false.device != padding_mask.device:
-                    self.false = self.false.to(padding_mask.device)
-                idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
-                x_item = x[i, idx, :].unsqueeze(0)
-                attn_item = attn[i, idx, :].unsqueeze(0)
-                x_item = x_item + attn_item
-                x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
-                x_item = x_item + self.mlp.forward(x_item)
-                x_item = F.layer_norm(
-                    x_item,
-                    [self.hidden_dim],
-                    self.norm_w2,
-                    self.norm_b2,
-                    self.norm_eps2,
-                )
-                x[i, idx, :] = x_item.squeeze(0)
-            x = self.to_mask(x, padding_mask)
-        else:
-            x = x + attn
-            x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
-            x = x + self.mlp.forward(x)
-            x = F.layer_norm(
-                x,
-                [self.hidden_dim],
-                self.norm_w2,
-                self.norm_b2,
-                self.norm_eps2,
-            )
+        x = x + attn
+        x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
+        x = x + self.mlp.forward(x)
+        x = F.layer_norm(
+            x,
+            [self.hidden_dim],
+            self.norm_w2,
+            self.norm_b2,
+            self.norm_eps2,
+        )
         return x, k_cache, v_cache
 
     def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
diff --git a/GPT_SoVITS/stream_v2pro.py b/GPT_SoVITS/stream_v2pro.py
index ad4ed26e..127dbca7 100644
--- a/GPT_SoVITS/stream_v2pro.py
+++ b/GPT_SoVITS/stream_v2pro.py
@@ -12,13 +12,10 @@ from inference_webui import get_phones_and_bert
 
 import matplotlib.pyplot as plt
 
-
 class StreamT2SModel(nn.Module):
     def __init__(self, t2s: T2SModel):
         super(StreamT2SModel, self).__init__()
         self.t2s = t2s
-        self.k_cache: list[torch.Tensor] = [torch.zeros([1])]
-        self.v_cache: list[torch.Tensor] = [torch.zeros([1])]
 
     @torch.jit.export
     def pre_infer(
@@ -29,7 +26,7 @@ class StreamT2SModel(nn.Module):
         ref_bert: torch.Tensor,
         text_bert: torch.Tensor,
         top_k: int,
-    ) -> tuple[int, Tensor, Tensor]:
+    ) -> tuple[int, Tensor, Tensor, List[Tensor], List[Tensor]]:
         bert = torch.cat([ref_bert.T, text_bert.T], 1)
         all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
         bert = bert.unsqueeze(0)
@@ -91,9 +88,7 @@ class StreamT2SModel(nn.Module):
             )
         )
 
-        self.k_cache = k_cache
-        self.v_cache = v_cache
-        return y_len, y, xy_pos
+        return y_len, y, xy_pos, k_cache, v_cache
 
     @torch.jit.export
     def decode_next_token(
@@ -103,11 +98,13 @@ class StreamT2SModel(nn.Module):
         y_len: int,
         y: Tensor,
         xy_pos: Tensor,
-    ) -> tuple[Tensor, Tensor, bool]:
+        k_cache: List[Tensor],
+        v_cache: List[Tensor],
+    ) -> tuple[Tensor, Tensor, int, List[Tensor], List[Tensor]]:
         # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
         # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
         xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.decode_next_token(
-            xy_pos, self.k_cache, self.v_cache
+            xy_pos, k_cache, v_cache
         )
         logits = self.t2s.ar_predict_layer(xy_dec[:, -1])
 
@@ -119,13 +116,12 @@ class StreamT2SModel(nn.Module):
         )[0]
 
         y = torch.concat([y, samples], dim=1)
+        last_token = int(samples[0, 0])
 
         # if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
        #     stop = True
         if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS:
-            self.k_cache = [torch.zeros([1])]
-            self.v_cache = [torch.zeros([1])]
-            return y[:,:-1], xy_pos, True
+            # return EOS itself so the caller's stop check also covers the argmax-only case
+            return y[:, :-1], xy_pos, self.t2s.EOS, k_cache, v_cache
 
         # if stop:
         #     if y.shape[1] == 0:
@@ -140,7 +136,7 @@ class StreamT2SModel(nn.Module):
                 dtype=y_emb.dtype, device=y_emb.device
             )
         )
-        return y, xy_pos, False
+        return y, xy_pos, last_token, k_cache, v_cache
 
     def forward(
         self,
@@ -149,12 +145,47 @@ class StreamT2SModel(nn.Module):
         y_len: int,
         y: Tensor,
         xy_pos: Tensor,
+        k_cache: List[Tensor],
+        v_cache: List[Tensor],
     ):
-        return self.decode_next_token(idx,top_k,y_len,y,xy_pos)
+        return self.decode_next_token(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
+
+
+class StepVitsModel(nn.Module):
+    def __init__(self, vits: VitsModel, sv_model: ExportERes2NetV2):
+        super().__init__()
+        self.hps = vits.hps
+        self.vq_model = vits.vq_model
+        self.hann_window = vits.hann_window
+        self.sv = sv_model
+
+    def ref_handle(self, ref_audio_32k):
+        refer = spectrogram_torch(
+            self.hann_window,
+            ref_audio_32k,
+            self.hps.data.filter_length,
+            self.hps.data.sampling_rate,
+            self.hps.data.hop_length,
+            self.hps.data.win_length,
+            center=False,
+        )
+        ref_audio_16k = resamplex(ref_audio_32k, 32000, 16000).to(ref_audio_32k.dtype).to(ref_audio_32k.device)
+        sv_emb = self.sv(ref_audio_16k)
+        return refer, sv_emb
+
+    def extract_latent(self, ssl_content):
+        codes = self.vq_model.extract_latent(ssl_content)
+        return codes[0]
+
+    def forward(self, pred_semantic, text_seq, refer, sv_emb=None):
+        return self.vq_model(
+            pred_semantic, text_seq, refer, speed=1.0, sv_emb=sv_emb
+        )[0, 0]
+
 
 import time
 
-def export_prov2(
+def test_stream(
     gpt_path,
     vits_path,
     version,
@@ -249,15 +280,16 @@ def export_prov2(
     st = time.time()
     et = time.time()
 
-    y_len, y, xy_pos = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
+    y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
     idx = 1
     last_idx = 0
     audios = []
     full_audios = []
     print("y.shape:", y.shape)
     while True:
-        y, xy_pos, stop = stream_t2s(idx, top_k, y_len, y, xy_pos)
+        y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
         # print("y.shape:", y.shape)
+        stop = last_token == t2s.EOS
 
         # black magic: hard to explain why this heuristic works, but it does
         if (y[0,-8] < 30 and idx-last_idx > (len(audios)+1) * 25) or stop:
@@ -324,16 +356,147 @@
     plt.show()
 
 
+def export_prov2(
+    gpt_path,
+    vits_path,
+    version,
+    ref_audio_path,
+    ref_text,
+    output_path,
+    device="cpu",
+    is_half=True,
+):
+    if export_torch_script.sv_cn_model is None:
+        init_sv_cn(device, is_half)
+
+    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
+    ssl = SSLModel()
+
+    print(f"device: {device}")
+
+    ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
+        ref_text, "all_zh", "v2"
+    )
+    ref_seq = torch.LongTensor([ref_seq_id]).to(device)
+    ref_bert = ref_bert_T.T
+    if is_half:
+        ref_bert = ref_bert.half()
+    ref_bert = ref_bert.to(ref_seq.device)
+
+    text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
+        "这是一个简单的示例,真没想到这么简单就完成了。真的神奇。接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2"
+    )
+    text_seq = torch.LongTensor([text_seq_id]).to(device)
+    text_bert = text_bert_T.T
+    if is_half:
+        text_bert = text_bert.half()
+    text_bert = text_bert.to(text_seq.device)
+
+    ssl_content = ssl(ref_audio)
+    if is_half:
+        ssl_content = ssl_content.half()
+    ssl_content = ssl_content.to(device)
+
+    sv_model = ExportERes2NetV2(export_torch_script.sv_cn_model)
+
+    # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
+    vits = VitsModel(vits_path, version, is_half=is_half, device=device)
+    vits.eval()
+    vits = StepVitsModel(vits, sv_model)
+
+    # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
+    # dict_s1 = torch.load(gpt_path, map_location=device)
+    dict_s1 = torch.load(gpt_path, weights_only=False)
+    raw_t2s = get_raw_t2s_model(dict_s1).to(device)
+    print("#### get_raw_t2s_model ####")
+    print(raw_t2s.config)
+    if is_half:
+        raw_t2s = raw_t2s.half()
+    t2s_m = T2SModel(raw_t2s)
+    t2s_m.eval()
+    # t2s = torch.jit.script(t2s_m).to(device)
+    t2s = t2s_m
+    print("#### script t2s_m ####")
+
+    print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
+
+    stream_t2s = StreamT2SModel(t2s).to(device)
+    stream_t2s = torch.jit.script(stream_t2s)
+
+    ref_audio_sr = resamplex(ref_audio, 16000, 32000)
+    if is_half:
+        ref_audio_sr = ref_audio_sr.half()
+    ref_audio_sr = ref_audio_sr.to(device)
+
+    top_k = 15
+
+    prompts = vits.extract_latent(ssl_content)
+
+    audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
+    sv_emb = sv_model(audio_16k)
+    print("text_seq", text_seq.shape)
+    # torch.jit.trace()
+
+    refer, sv_emb = vits.ref_handle(ref_audio_sr)
+
+    st = time.time()
+    et = time.time()
+
+    y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
+    idx = 1
+    print("y.shape:", y.shape)
+    while True:
+        y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
+        # print("y.shape:", y.shape)
+        if idx == 1:
+            et = time.time()  # timestamp of the first decoded token
+
+        idx += 1
+        # print(idx,'/',1500 , y.shape, y[0,-1].item(), stop)
+        if idx > 1500:
+            break
+
+        if last_token == t2s.EOS:
+            break
+
+    at = time.time()
+    print("EOS:", t2s.EOS)
+
+    print(f"first token: {et - st:.4f} seconds")
+    print(f"all tokens: {at - st:.4f} seconds")
+    print("sv_emb", sv_emb.shape)
+    print("refer", refer.shape)
+    y = y[:, -idx:].unsqueeze(0)
+    print("y", y.shape)
+    audio = vits(y, text_seq, refer, sv_emb)
+    soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
+
+    torch._dynamo.mark_dynamic(ssl_content, 2)
+    torch._dynamo.mark_dynamic(ref_audio_sr, 1)
+    torch._dynamo.mark_dynamic(ref_seq, 1)
+    torch._dynamo.mark_dynamic(text_seq, 1)
+    torch._dynamo.mark_dynamic(ref_bert, 0)
+    torch._dynamo.mark_dynamic(text_bert, 0)
+    torch._dynamo.mark_dynamic(refer, 2)
+    torch._dynamo.mark_dynamic(y, 2)
+
+    inputs = {
+        "forward": (y, text_seq, refer, sv_emb),
+        "extract_latent": ssl_content,
+        "ref_handle": ref_audio_sr,
+    }
+
+    stream_t2s.save(f"{output_path}/t2s.pt")
+    torch.jit.trace_module(vits, inputs=inputs, optimize=True).save(f"{output_path}/vits.pt")
+
+
 if __name__ == "__main__":
     with torch.no_grad():
         export_prov2(
             gpt_path="GPT_SoVITS/pretrained_models/s1v3.ckpt",
             vits_path="GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth",
             version="v2Pro",
-            ref_audio_path="output/denoise_opt/ht/ht.mp4_0000026560_0000147200.wav",
-            ref_text="真的,这件衣服才配得上本小姐嘛",
+            ref_audio_path="/mnt/g/ad_ref.wav",
+            ref_text="你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说.",
             output_path="streaming",
-            export_bert_and_ssl=True,
             device="cuda",
             is_half=True,
         )
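
Usage note (not part of the patch): a minimal consumer sketch for the two exported modules. It assumes the output layout produced by export_prov2 above ("streaming/t2s.pt" and "streaming/vits.pt") and the method signatures defined in StreamT2SModel/StepVitsModel; prompts, ref_seq, text_seq, ref_bert, text_bert, refer and sv_emb are assumed to come from the same preprocessing pipeline used at export time, and EOS_ID is an assumption that should really be read from t2s.EOS during export rather than hard-coded.

    import torch

    EOS_ID = 1024  # assumption: capture t2s.EOS at export time instead of hard-coding

    # load the scripted T2S decoder and the traced VITS decoder
    stream_t2s = torch.jit.load("streaming/t2s.pt", map_location="cuda")
    vits = torch.jit.load("streaming/vits.pt", map_location="cuda")

    with torch.no_grad():
        # prompt stage: runs the full prefix once and returns the KV cache
        y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(
            prompts, ref_seq, text_seq, ref_bert, text_bert, 15
        )
        idx = 1
        while idx <= 1500:  # same hard cap as export_prov2
            # one autoregressive step; the KV cache is threaded through explicitly
            y, xy_pos, last_token, k_cache, v_cache = stream_t2s(
                idx, 15, y_len, y, xy_pos, k_cache, v_cache
            )
            idx += 1
            if last_token == EOS_ID:
                break
        # decode the accumulated semantic tokens to 32 kHz audio
        audio = vits(y[:, -idx:].unsqueeze(0), text_seq, refer, sv_emb)

Since the KV cache is now an explicit input/output instead of module state, a server can interleave steps from several concurrent sessions on one loaded module, which is the design choice that motivates the signature change.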