# Experimental implementation exploring the feasibility of streaming inference.
# (toy code, written for fun)
from typing import List

from export_torch_script import (
    ExportERes2NetV2,
    SSLModel,
    T2SModel,
    VitsModel,
    get_raw_t2s_model,
    init_sv_cn,
    resamplex,
    sample,
    spectrogram_torch,
)
import export_torch_script
from my_utils import load_audio
import torch
from torch import LongTensor, Tensor, nn
from torch.nn import functional as F
import soundfile
from inference_webui import get_phones_and_bert
import matplotlib.pyplot as plt
import time


class StreamT2SModel(nn.Module):
    def __init__(self, t2s: T2SModel):
        super(StreamT2SModel, self).__init__()
        self.t2s = t2s
        # KV caches are kept on the module so the scripted model can stream
        # token by token across calls.
        self.k_cache: list[torch.Tensor] = [torch.zeros([1])]
        self.v_cache: list[torch.Tensor] = [torch.zeros([1])]

    @torch.jit.export
    def pre_infer(
        self,
        prompts: LongTensor,
        ref_seq: LongTensor,
        text_seq: LongTensor,
        ref_bert: torch.Tensor,
        text_bert: torch.Tensor,
        top_k: int,
    ) -> tuple[int, Tensor, Tensor]:
        """Run the prompt pass once and sample the first semantic token."""
        bert = torch.cat([ref_bert.T, text_bert.T], 1)
        all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
        bert = bert.unsqueeze(0)

        x = self.t2s.ar_text_embedding(all_phoneme_ids)
        x = x + self.t2s.bert_proj(bert.transpose(1, 2))
        x: torch.Tensor = self.t2s.ar_text_position(x)

        # [1, N, 512] [1, N]
        # y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
        y = prompts
        # x_example = x[:, :, 0] * 0.0

        x_len = x.shape[1]
        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)

        y_emb = self.t2s.ar_audio_embedding(y)
        y_len: int = y_emb.shape[1]
        prefix_len = y.shape[1]
        y_pos = self.t2s.ar_audio_position(y_emb)
        xy_pos = torch.concat([x, y_pos], dim=1)

        bsz = x.shape[0]
        src_len = x_len + y_len
        x_attn_mask_pad = F.pad(
            x_attn_mask,
            (0, y_len),  # extend the all-zero x→x mask with all-True x→y entries, shape (x, x+y)
            value=True,
        )
        y_attn_mask = F.pad(  # pad the causal (upper-triangular) y→y mask with zeros on the left so y attends to all of x, shape (y, x+y)
            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = (
            torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
            .unsqueeze(0)
            .expand(bsz * self.t2s.num_head, -1, -1)
            .view(bsz, self.t2s.num_head, src_len, src_len)
            .to(device=x.device, dtype=torch.bool)
        )

        xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)

        logits = self.t2s.ar_predict_layer(xy_dec[:, -1])
        logits = logits[:, :-1]
        samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
        y = torch.concat([y, samples], dim=1)
        y_emb: Tensor = self.t2s.ar_audio_embedding(y[:, -1:])
        xy_pos: Tensor = (
            y_emb * self.t2s.ar_audio_position.x_scale
            + self.t2s.ar_audio_position.alpha
            * self.t2s.ar_audio_position.pe[:, y_len].to(dtype=y_emb.dtype, device=y_emb.device)
        )

        self.k_cache = k_cache
        self.v_cache = v_cache
        return y_len, y, xy_pos

    @torch.jit.export
    def decode_next_token(
        self,
        idx: int,  # note: starts at 1, goes up to 1500
        top_k: int,
        y_len: int,
        y: Tensor,
        xy_pos: Tensor,
    ) -> tuple[Tensor, Tensor, bool]:
        """Decode one semantic token using the cached keys/values."""
        # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
        # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
        xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.decode_next_token(xy_pos, self.k_cache, self.v_cache)
        logits = self.t2s.ar_predict_layer(xy_dec[:, -1])

        if idx < 11:  # require at least 10 predicted tokens before EOS is allowed (~0.4 s of audio)
            logits = logits[:, :-1]

        samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]

        y = torch.concat([y, samples], dim=1)

        # if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
        #     stop = True
        if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS:
            # Reset the caches so the module can be reused for the next utterance.
            self.k_cache = [torch.zeros([1])]
            self.v_cache = [torch.zeros([1])]
            return y[:, :-1], xy_pos, True

        # if stop:
        #     if y.shape[1] == 0:
        #         y = torch.concat([y, torch.zeros_like(samples)], dim=1)
        #     break

        y_emb = self.t2s.ar_audio_embedding(y[:, -1:])
        xy_pos = (
            y_emb * self.t2s.ar_audio_position.x_scale
            + self.t2s.ar_audio_position.alpha
            * self.t2s.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype, device=y_emb.device)
        )

        return y, xy_pos, False

    def forward(
        self,
        idx: int,  # note: starts at 1, goes up to 1500
        top_k: int,
        y_len: int,
        y: Tensor,
        xy_pos: Tensor,
    ):
        return self.decode_next_token(idx, top_k, y_len, y, xy_pos)


def export_prov2(
    gpt_path,
    vits_path,
    version,
    ref_audio_path,
    ref_text,
    output_path,
    export_bert_and_ssl=False,
    device="cpu",
    is_half=True,
):
    if export_torch_script.sv_cn_model is None:
        init_sv_cn(device, is_half)

    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
    ssl = SSLModel()
    print(f"device: {device}")

    ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
    ref_seq = torch.LongTensor([ref_seq_id]).to(device)
    ref_bert = ref_bert_T.T
    if is_half:
        ref_bert = ref_bert.half()
    ref_bert = ref_bert.to(ref_seq.device)

    text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
        "这是一个简单的示例,真没想到这么简单就完成了。真的神奇。接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?",
        "auto",
        "v2",
    )
    text_seq = torch.LongTensor([text_seq_id]).to(device)
    text_bert = text_bert_T.T
    if is_half:
        text_bert = text_bert.half()
    text_bert = text_bert.to(text_seq.device)

    ssl_content = ssl(ref_audio)
    if is_half:
        ssl_content = ssl_content.half()
    ssl_content = ssl_content.to(device)

    sv_model = ExportERes2NetV2(export_torch_script.sv_cn_model)

    # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
    vits = VitsModel(vits_path, version, is_half=is_half, device=device)
    vits.eval()

    # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
    # dict_s1 = torch.load(gpt_path, map_location=device)
    dict_s1 = torch.load(gpt_path, weights_only=False)
    raw_t2s = get_raw_t2s_model(dict_s1).to(device)
    print("#### get_raw_t2s_model ####")
    print(raw_t2s.config)

    if is_half:
        raw_t2s = raw_t2s.half()

    t2s_m = T2SModel(raw_t2s)
    t2s_m.eval()
    # t2s = torch.jit.script(t2s_m).to(device)
    t2s = t2s_m
    print("#### script t2s_m ####")

    print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)

    stream_t2s = StreamT2SModel(t2s).to(device)
    stream_t2s = torch.jit.script(stream_t2s)

    ref_audio_sr = resamplex(ref_audio, 16000, 32000)
    if is_half:
        ref_audio_sr = ref_audio_sr.half()
    ref_audio_sr = ref_audio_sr.to(device)

    top_k = 15

    codes = vits.vq_model.extract_latent(ssl_content)
    prompt_semantic = codes[0, 0]
    prompts = prompt_semantic.unsqueeze(0)

    audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
    sv_emb = sv_model(audio_16k)

    print("text_seq", text_seq.shape)

    refer = spectrogram_torch(
        vits.hann_window,
        ref_audio_sr,
        vits.hps.data.filter_length,
        vits.hps.data.sampling_rate,
        vits.hps.data.hop_length,
        vits.hps.data.win_length,
        center=False,
    )

    st = time.time()
    et = time.time()
    y_len, y, xy_pos = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)

    idx = 1
    last_idx = 0
    audios = []
    full_audios = []
    print("y.shape:", y.shape)
    while True:
        y, xy_pos, stop = stream_t2s(idx, top_k, y_len, y, xy_pos)
        # print("y.shape:", y.shape)

        # Heuristic chunk trigger (admittedly hand-tuned and hard to explain rigorously):
        # re-decode a chunk once the token 8 steps back has a small index (< 30) and more
        # than (len(audios) + 1) * 25 tokens have accumulated since the last chunk, or
        # when generation has stopped.
        if (y[0, -8] < 30 and idx - last_idx > (len(audios) + 1) * 25) or stop:
            audio = vits.vq_model(y[:, -idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
            full_audios.append(audio)
            # Each semantic token corresponds to 1280 output samples at 32 kHz; every
            # non-final chunk withholds an 8-token tail so that consecutive chunks
            # concatenate without emitting the overlap twice.
            if last_idx == 0:
                audio = audio[: -1280 * 8]
                et = time.time()
            else:
                if stop:
                    audio = audio[last_idx * 1280 - 1280 * 8 :]
                else:
                    audio = audio[last_idx * 1280 - 1280 * 8 : -1280 * 8]

            print(y[:, -idx + last_idx :])
            last_idx = idx
            # print(f'write {output_path}/out_{audio_index}')
            # soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000)
            audios.append(audio)

        idx += 1
        # print(idx, '/', 1500, y.shape, y[0, -1].item(), stop)
        if idx > 1500:
            break
        if stop:
            break

    at = time.time()

    for i, a in enumerate(audios):
        print(f"write {output_path}/out_{i}")
        soundfile.write(f"{output_path}/out_{i}.wav", a.float().detach().cpu().numpy(), 32000)

    print(f"first token: {et - st:.4f} seconds")
    print(f"all tokens: {at - st:.4f} seconds")

    # Decode the full token sequence once more (non-streaming) for comparison.
    audio = vits.vq_model(y[:, -idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
    soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)

    audio = torch.cat(audios, dim=0)
    soundfile.write(f"{output_path}/out.wav", audio.float().detach().cpu().numpy(), 32000)

    # Plot every intermediate chunk against the concatenated streaming output; dashed
    # lines mark where each chunk's emitted portion begins and where its withheld
    # 8-token tail starts.
    colors = ["red", "green", "blue", "orange", "purple", "cyan", "magenta", "yellow"]
    fig, axes = plt.subplots(len(full_audios) + 1, 1, figsize=(10, 6))
    max_duration = full_audios[-1].shape[0]
    last_line = 0
    for i, (ax, a) in enumerate(zip(axes[:-1], full_audios)):
        ax.plot(a.float().detach().cpu().numpy(), color=colors[i], alpha=0.5, label=f"Audio {i}")
        ax.axvline(x=last_line, color=colors[i], linestyle="--")
        last_line = a.shape[0] - 8 * 1280
        ax.axvline(x=last_line, color=colors[i], linestyle="--")
        ax.set_xlim(0, max_duration)

    axes[-1].axvline(x=last_line, color=colors[i], linestyle="--")
    axes[-1].plot(audio.float().detach().cpu().numpy(), color="black", label="Final Audio")
    axes[-1].set_xlim(0, max_duration)

    # plt.title('Overlapped Waveform Comparison')
    # plt.xlabel('Sample Number')
    # plt.ylabel('Amplitude')
    # plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    with torch.no_grad():
        export_prov2(
            gpt_path="GPT_SoVITS/pretrained_models/s1v3.ckpt",
            vits_path="GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth",
            version="v2Pro",
            ref_audio_path="output/denoise_opt/ht/ht.mp4_0000026560_0000147200.wav",
            ref_text="真的,这件衣服才配得上本小姐嘛",
            output_path="streaming",
            export_bert_and_ssl=True,
            device="cuda",
            is_half=True,
        )
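

# ---------------------------------------------------------------------------
# Illustrative sketch only (not called anywhere above): the chunk-trimming
# arithmetic that the streaming loop in export_prov2 performs inline, pulled
# into one helper so it is easier to read. It assumes, as the slicing in the
# loop does, that each semantic token maps to 1280 output samples at 32 kHz
# and that every non-final chunk withholds an 8-token tail. The helper name
# `trim_chunk` is hypothetical and exists purely for this sketch.
SAMPLES_PER_TOKEN = 1280  # 32000 Hz output / 25 semantic tokens per second
OVERLAP_TOKENS = 8  # tail withheld from every non-final chunk


def trim_chunk(audio: Tensor, last_idx: int, is_first: bool, is_last: bool) -> Tensor:
    """Crop a freshly re-decoded chunk so consecutive chunks concatenate cleanly."""
    overlap = OVERLAP_TOKENS * SAMPLES_PER_TOKEN
    if is_first:
        # First chunk: emit everything except the withheld overlap tail.
        return audio[:-overlap]
    # Later chunks: skip the samples already emitted (minus the withheld overlap),
    # and withhold a fresh tail unless this is the final chunk.
    start = last_idx * SAMPLES_PER_TOKEN - overlap
    return audio[start:] if is_last else audio[start:-overlap]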