stream_infer: add the export part.

csh 2025-06-18 01:47:40 +08:00
parent 6fe3861e73
commit 920bbafb12
2 changed files with 196 additions and 53 deletions
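In short: T2SBlock.process_prompt drops its per-sample padding-mask loop and simplifies the attention reshape, while StreamT2SModel stops storing k_cache/v_cache on self. pre_infer now returns the caches, and decode_next_token takes and returns them explicitly, reporting the last sampled token instead of a stop flag so the caller detects EOS itself. That keeps the scripted module free of mutable list state, and the new export_prov2 saves the scripted t2s.pt plus a traced vits.pt. A minimal sketch of the resulting caller loop (mirroring test_stream in the diff; all inputs are assumed to be prepared as in the script):

# Sketch only: stream_t2s, t2s, prompts, ref_seq, text_seq, ref_bert,
# text_bert and top_k are assumed to be prepared as in the script below.
y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(
    prompts, ref_seq, text_seq, ref_bert, text_bert, top_k
)
idx = 1
while idx <= 1500:  # same hard cap as the script
    # KV caches are threaded through the call instead of living on the module.
    y, xy_pos, last_token, k_cache, v_cache = stream_t2s(
        idx, top_k, y_len, y, xy_pos, k_cache, v_cache
    )
    idx += 1
    if last_token == t2s.EOS:  # EOS detection moved to the caller
        break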

View File

@@ -256,41 +256,21 @@ class T2SBlock:
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
# attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
# attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
if padding_mask is not None:
for i in range(batch_size):
# mask = padding_mask[i,:,0]
if self.false.device != padding_mask.device:
self.false = self.false.to(padding_mask.device)
idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
x_item = x[i, idx, :].unsqueeze(0)
attn_item = attn[i, idx, :].unsqueeze(0)
x_item = x_item + attn_item
x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x_item = x_item + self.mlp.forward(x_item)
x_item = F.layer_norm(
x_item,
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
x[i, idx, :] = x_item.squeeze(0)
x = self.to_mask(x, padding_mask)
else:
x = x + attn
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
x = x + attn
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
return x, k_cache, v_cache
def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
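Two notes on the T2SBlock change above. First, the old permute/reshape/view round-trip and the new transpose(1, 2).reshape(...) produce the same head concatenation for an SDPA output shaped [batch, n_head, q_len, head_dim]. Second, the removed padding_mask branch only matters for padded batches, which the batch-size-1 streaming export path never produces. A quick self-contained equivalence check (the sizes are illustrative assumptions):

import torch

batch_size, n_head, q_len, head_dim = 2, 8, 16, 64  # illustrative sizes
hidden_dim = n_head * head_dim
attn = torch.randn(batch_size, n_head, q_len, head_dim)  # SDPA output layout

# Old path: to (q_len, batch, n_head, head_dim), flatten, then back
old = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, hidden_dim)
old = old.view(q_len, batch_size, hidden_dim).transpose(1, 0)

# New path: one transpose plus one reshape
new = attn.transpose(1, 2).reshape(batch_size, q_len, -1)

assert torch.equal(old, new)  # same values, simpler graph for export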

View File

@@ -12,13 +12,10 @@ from inference_webui import get_phones_and_bert
import matplotlib.pyplot as plt
class StreamT2SModel(nn.Module):
def __init__(self, t2s: T2SModel):
super(StreamT2SModel, self).__init__()
self.t2s = t2s
self.k_cache: list[torch.Tensor] = [torch.zeros([1])]
self.v_cache: list[torch.Tensor] = [torch.zeros([1])]
@torch.jit.export
def pre_infer(
@@ -29,7 +26,7 @@ class StreamT2SModel(nn.Module):
ref_bert: torch.Tensor,
text_bert: torch.Tensor,
top_k: int,
) -> tuple[int, Tensor, Tensor]:
) -> tuple[int, Tensor, Tensor, List[Tensor], List[Tensor]]:
bert = torch.cat([ref_bert.T, text_bert.T], 1)
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
bert = bert.unsqueeze(0)
@@ -91,9 +88,7 @@ class StreamT2SModel(nn.Module):
)
)
self.k_cache = k_cache
self.v_cache = v_cache
return y_len, y, xy_pos
return y_len, y, xy_pos, k_cache, v_cache
@torch.jit.export
def decode_next_token(
@@ -103,11 +98,13 @@ class StreamT2SModel(nn.Module):
y_len: int,
y: Tensor,
xy_pos: Tensor,
) -> tuple[Tensor, Tensor, bool]:
k_cache: List[Tensor],
v_cache: List[Tensor],
) -> tuple[Tensor, Tensor, int, List[Tensor], List[Tensor]]:
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
# y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.decode_next_token(
xy_pos, self.k_cache, self.v_cache
xy_pos, k_cache, v_cache
)
logits = self.t2s.ar_predict_layer(xy_dec[:, -1])
@@ -119,13 +116,12 @@ class StreamT2SModel(nn.Module):
)[0]
y = torch.concat([y, samples], dim=1)
last_token = int(samples[0, 0])
# if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
# stop = True
if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS:
self.k_cache = [torch.zeros([1])]
self.v_cache = [torch.zeros([1])]
return y[:,:-1], xy_pos, True
return y[:,:-1], xy_pos, last_token, k_cache, v_cache
# if stop:
# if y.shape[1] == 0:
@@ -140,7 +136,7 @@ class StreamT2SModel(nn.Module):
dtype=y_emb.dtype, device=y_emb.device
)
)
return y, xy_pos, False
return y, xy_pos, last_token, k_cache, v_cache
def forward(
self,
@@ -149,12 +145,47 @@ class StreamT2SModel(nn.Module):
y_len: int,
y: Tensor,
xy_pos: Tensor,
k_cache: List[Tensor],
v_cache: List[Tensor],
):
return self.decode_next_token(idx,top_k,y_len,y,xy_pos)
return self.decode_next_token(idx,top_k,y_len,y,xy_pos,k_cache,v_cache)
class StepVitsModel(nn.Module):
def __init__(self, vits: VitsModel,sv_model:ExportERes2NetV2):
super().__init__()
self.hps = vits.hps
self.vq_model = vits.vq_model
self.hann_window = vits.hann_window
self.sv = sv_model
def ref_handle(self, ref_audio_32k):
refer = spectrogram_torch(
self.hann_window,
ref_audio_32k,
self.hps.data.filter_length,
self.hps.data.sampling_rate,
self.hps.data.hop_length,
self.hps.data.win_length,
center=False,
)
ref_audio_16k = resamplex(ref_audio_32k, 32000, 16000).to(ref_audio_32k.dtype).to(ref_audio_32k.device)
sv_emb = self.sv(ref_audio_16k)
return refer, sv_emb
def extract_latent(self, ssl_content):
codes = self.vq_model.extract_latent(ssl_content)
return codes[0]
def forward(self, pred_semantic, text_seq, refer, sv_emb=None):
return self.vq_model(
pred_semantic, text_seq, refer, speed=1.0, sv_emb=sv_emb
)[0, 0]
import time
def export_prov2(
def test_stream(
gpt_path,
vits_path,
version,
@@ -249,15 +280,16 @@ def export_prov2(
st = time.time()
et = time.time()
y_len, y, xy_pos = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
idx = 1
last_idx = 0
audios = []
full_audios = []
print("y.shape:", y.shape)
while True:
y, xy_pos, stop = stream_t2s(idx, top_k, y_len, y, xy_pos)
y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
# print("y.shape:", y.shape)
stop = last_token==t2s.EOS
# black magic: can't really explain why this threshold works
if (y[0,-8] < 30 and idx-last_idx > (len(audios)+1) * 25) or stop:
@@ -324,16 +356,147 @@ def export_prov2(
plt.show()
def export_prov2(
gpt_path,
vits_path,
version,
ref_audio_path,
ref_text,
output_path,
device="cpu",
is_half=True,
):
if export_torch_script.sv_cn_model is None:
init_sv_cn(device,is_half)
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
ssl = SSLModel()
print(f"device: {device}")
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
ref_text, "all_zh", "v2"
)
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
ref_bert = ref_bert_T.T
if is_half:
ref_bert = ref_bert.half()
ref_bert = ref_bert.to(ref_seq.device)
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
"这是一个简单的示例,真没想到这么简单就完成了。真的神奇。接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴尖尖的耳朵传说中还有九条尾巴。你觉得狐狸神奇吗。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2"
)
text_seq = torch.LongTensor([text_seq_id]).to(device)
text_bert = text_bert_T.T
if is_half:
text_bert = text_bert.half()
text_bert = text_bert.to(text_seq.device)
ssl_content = ssl(ref_audio)
if is_half:
ssl_content = ssl_content.half()
ssl_content = ssl_content.to(device)
sv_model = ExportERes2NetV2(export_torch_script.sv_cn_model)
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path, version,is_half=is_half,device=device)
vits.eval()
vits = StepVitsModel(vits, sv_model)
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
# dict_s1 = torch.load(gpt_path, map_location=device)
dict_s1 = torch.load(gpt_path, weights_only=False)
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
print("#### get_raw_t2s_model ####")
print(raw_t2s.config)
if is_half:
raw_t2s = raw_t2s.half()
t2s_m = T2SModel(raw_t2s)
t2s_m.eval()
# t2s = torch.jit.script(t2s_m).to(device)
t2s = t2s_m
print("#### script t2s_m ####")
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
stream_t2s = StreamT2SModel(t2s).to(device)
stream_t2s = torch.jit.script(stream_t2s)
ref_audio_sr = resamplex(ref_audio, 16000, 32000)
if is_half:
ref_audio_sr = ref_audio_sr.half()
ref_audio_sr = ref_audio_sr.to(device)
top_k = 15
prompts = vits.extract_latent(ssl_content)
audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
sv_emb = sv_model(audio_16k)
print("text_seq",text_seq.shape)
# torch.jit.trace()
refer,sv_emb = vits.ref_handle(ref_audio_sr)
st = time.time()
et = time.time()
y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
idx = 1
print("y.shape:", y.shape)
while True:
y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
# print("y.shape:", y.shape)
idx+=1
# print(idx,'/',1500 , y.shape, y[0,-1].item(), stop)
if idx>1500:
break
if last_token == t2s.EOS:
break
at = time.time()
print("EOS:",t2s.EOS)
print(f"frist token: {et - st:.4f} seconds")
print(f"all token: {at - st:.4f} seconds")
print("sv_emb", sv_emb.shape)
print("refer",refer.shape)
y = y[:,-idx:].unsqueeze(0)
print("y", y.shape)
audio = vits(y, text_seq, refer, sv_emb)
soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
torch._dynamo.mark_dynamic(ssl_content, 2)
torch._dynamo.mark_dynamic(ref_audio_sr, 1)
torch._dynamo.mark_dynamic(ref_seq, 1)
torch._dynamo.mark_dynamic(text_seq, 1)
torch._dynamo.mark_dynamic(ref_bert, 0)
torch._dynamo.mark_dynamic(text_bert, 0)
torch._dynamo.mark_dynamic(refer, 2)
torch._dynamo.mark_dynamic(y, 2)
inputs = {
"forward": (y, text_seq, refer, sv_emb),
"extract_latent": ssl_content,
"ref_handle": ref_audio_sr,
}
stream_t2s.save(f"{output_path}/t2s.pt")
torch.jit.trace_module(vits, inputs=inputs, optimize=True).save(f"{output_path}/vits.pt")
if __name__ == "__main__":
with torch.no_grad():
export_prov2(
gpt_path="GPT_SoVITS/pretrained_models/s1v3.ckpt",
vits_path="GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth",
version="v2Pro",
ref_audio_path="output/denoise_opt/ht/ht.mp4_0000026560_0000147200.wav",
ref_text="真的,这件衣服才配得上本小姐嘛",
ref_audio_path="/mnt/g/ad_ref.wav",
ref_text="你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说.",
output_path="streaming",
export_bert_and_ssl=True,
device="cuda",
is_half=True,
)
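For completeness, a hedged sketch of consuming the two exported artifacts. The file names come from export_prov2 above; all tensor inputs (prompts, ref_seq, text_seq, ref_bert, text_bert, refer, sv_emb) are assumed to be prepared the same way as in that function, and reading EOS off the wrapped submodule is an assumption about the scripted module's attributes:

import torch

stream_t2s = torch.jit.load("streaming/t2s.pt")
vits = torch.jit.load("streaming/vits.pt")

y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(
    prompts, ref_seq, text_seq, ref_bert, text_bert, 15
)
idx = 1
while idx <= 1500:
    y, xy_pos, last_token, k_cache, v_cache = stream_t2s(
        idx, 15, y_len, y, xy_pos, k_cache, v_cache
    )
    idx += 1
    if last_token == stream_t2s.t2s.EOS:  # assumed attribute on the wrapped T2SModel
        break

# Vocode the generated semantic tokens, as export_prov2 does.
audio = vits(y[:, -idx:].unsqueeze(0), text_seq, refer, sv_emb)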