GPT-SoVITS (mirror of https://github.com/RVC-Boss/GPT-SoVITS.git)

Commit 920bbafb12 (parent: 6fe3861e73)
stream_infer: add the export part.
```diff
@@ -256,41 +256,21 @@ class T2SBlock:
         attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
 
-        attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
-        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        # attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
+        # attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
         attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
 
-        if padding_mask is not None:
-            for i in range(batch_size):
-                # mask = padding_mask[i,:,0]
-                if self.false.device != padding_mask.device:
-                    self.false = self.false.to(padding_mask.device)
-                idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
-                x_item = x[i, idx, :].unsqueeze(0)
-                attn_item = attn[i, idx, :].unsqueeze(0)
-                x_item = x_item + attn_item
-                x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
-                x_item = x_item + self.mlp.forward(x_item)
-                x_item = F.layer_norm(
-                    x_item,
-                    [self.hidden_dim],
-                    self.norm_w2,
-                    self.norm_b2,
-                    self.norm_eps2,
-                )
-                x[i, idx, :] = x_item.squeeze(0)
-            x = self.to_mask(x, padding_mask)
-        else:
-            x = x + attn
-            x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
-            x = x + self.mlp.forward(x)
-            x = F.layer_norm(
-                x,
-                [self.hidden_dim],
-                self.norm_w2,
-                self.norm_b2,
-                self.norm_eps2,
-            )
+        x = x + attn
+        x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
+        x = x + self.mlp.forward(x)
+        x = F.layer_norm(
+            x,
+            [self.hidden_dim],
+            self.norm_w2,
+            self.norm_b2,
+            self.norm_eps2,
+        )
         return x, k_cache, v_cache
 
     def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
```
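This hunk swaps the old permute/view round-trip for a single transpose-and-reshape; both take the `scaled_dot_product_attention` output of shape `[batch, heads, seq, head_dim]` to `[batch, seq, hidden_dim]`. A minimal standalone check (toy shapes assumed; not part of the commit) that the two paths produce identical tensors:

```python
import torch

# Stand-in for SDPA output: [batch, heads, seq, head_dim].
B, H, L, E = 2, 4, 5, 8
attn = torch.randn(B, H, L, E)

# Old path: permute to [L, B, H, E], flatten, re-view, then swap axes back.
old = attn.permute(2, 0, 1, 3).reshape(B * L, H * E)
old = old.view(L, B, H * E).transpose(1, 0)

# New path: one transpose and one reshape, straight to [B, L, H*E].
new = attn.transpose(1, 2).reshape(B, L, -1)

assert torch.equal(old, new)  # same values, simpler data movement
```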
```diff
@@ -12,13 +12,10 @@ from inference_webui import get_phones_and_bert
 import matplotlib.pyplot as plt
 
 
 class StreamT2SModel(nn.Module):
     def __init__(self, t2s: T2SModel):
         super(StreamT2SModel, self).__init__()
         self.t2s = t2s
-        self.k_cache: list[torch.Tensor] = [torch.zeros([1])]
-        self.v_cache: list[torch.Tensor] = [torch.zeros([1])]
 
     @torch.jit.export
     def pre_infer(
@@ -29,7 +26,7 @@ class StreamT2SModel(nn.Module):
         ref_bert: torch.Tensor,
         text_bert: torch.Tensor,
         top_k: int,
-    ) -> tuple[int, Tensor, Tensor]:
+    ) -> tuple[int, Tensor, Tensor, List[Tensor], List[Tensor]]:
         bert = torch.cat([ref_bert.T, text_bert.T], 1)
         all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
         bert = bert.unsqueeze(0)
@@ -91,9 +88,7 @@ class StreamT2SModel(nn.Module):
             )
         )
 
-        self.k_cache = k_cache
-        self.v_cache = v_cache
-        return y_len, y, xy_pos
+        return y_len, y, xy_pos, k_cache, v_cache
 
     @torch.jit.export
     def decode_next_token(
@@ -103,11 +98,13 @@ class StreamT2SModel(nn.Module):
         y_len: int,
         y: Tensor,
         xy_pos: Tensor,
-    ) -> tuple[Tensor, Tensor, bool]:
+        k_cache: List[Tensor],
+        v_cache: List[Tensor],
+    ) -> tuple[Tensor, Tensor, int, List[Tensor], List[Tensor]]:
         # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
         # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
         xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.decode_next_token(
-            xy_pos, self.k_cache, self.v_cache
+            xy_pos, k_cache, v_cache
         )
         logits = self.t2s.ar_predict_layer(xy_dec[:, -1])
@@ -119,13 +116,12 @@ class StreamT2SModel(nn.Module):
         )[0]
 
         y = torch.concat([y, samples], dim=1)
+        last_token = int(samples[0, 0])
 
         # if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
         #     stop = True
         if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS:
-            self.k_cache = [torch.zeros([1])]
-            self.v_cache = [torch.zeros([1])]
-            return y[:,:-1], xy_pos, True
+            return y[:,:-1], xy_pos, last_token, k_cache, v_cache
 
         # if stop:
         #     if y.shape[1] == 0:
@@ -140,7 +136,7 @@ class StreamT2SModel(nn.Module):
                 dtype=y_emb.dtype, device=y_emb.device
             )
         )
-        return y, xy_pos, False
+        return y, xy_pos, last_token, k_cache, v_cache
 
     def forward(
         self,
```
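Across these hunks the KV cache moves out of module state (`self.k_cache` / `self.v_cache`) and into the function signatures, and the bool stop flag becomes the sampled token id. The scripted `StreamT2SModel` is therefore stateless: the caller owns all decode state and derives the stop condition itself. A condensed sketch of one decode step under the new signature (names taken from the hunks above and from `test_stream` below; input setup omitted):

```python
# One streaming step: all state (y, xy_pos, k_cache, v_cache) is explicit,
# so the scripted module stays stateless and reentrant between calls.
y, xy_pos, last_token, k_cache, v_cache = stream_t2s(
    idx, top_k, y_len, y, xy_pos, k_cache, v_cache
)
stop = last_token == t2s.EOS  # the EOS check now happens on the caller's side
```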
```diff
@@ -149,12 +145,47 @@ class StreamT2SModel(nn.Module):
         y_len: int,
         y: Tensor,
         xy_pos: Tensor,
+        k_cache: List[Tensor],
+        v_cache: List[Tensor],
     ):
-        return self.decode_next_token(idx,top_k,y_len,y,xy_pos)
+        return self.decode_next_token(idx,top_k,y_len,y,xy_pos,k_cache,v_cache)
+
+
+class StepVitsModel(nn.Module):
+    def __init__(self, vits: VitsModel,sv_model:ExportERes2NetV2):
+        super().__init__()
+        self.hps = vits.hps
+        self.vq_model = vits.vq_model
+        self.hann_window = vits.hann_window
+        self.sv = sv_model
+
+    def ref_handle(self, ref_audio_32k):
+        refer = spectrogram_torch(
+            self.hann_window,
+            ref_audio_32k,
+            self.hps.data.filter_length,
+            self.hps.data.sampling_rate,
+            self.hps.data.hop_length,
+            self.hps.data.win_length,
+            center=False,
+        )
+        ref_audio_16k = resamplex(ref_audio_32k, 32000, 16000).to(ref_audio_32k.dtype).to(ref_audio_32k.device)
+        sv_emb = self.sv(ref_audio_16k)
+        return refer, sv_emb
+
+    def extract_latent(self, ssl_content):
+        codes = self.vq_model.extract_latent(ssl_content)
+        return codes[0]
+
+    def forward(self, pred_semantic, text_seq, refer, sv_emb=None):
+        return self.vq_model(
+            pred_semantic, text_seq, refer, speed=1.0, sv_emb=sv_emb
+        )[0, 0]
+
 
 import time
 
-def export_prov2(
+def test_stream(
     gpt_path,
     vits_path,
     version,
```
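The new `StepVitsModel` wrapper bundles the three vocoder-side operations the export needs: `ref_handle` (reference spectrogram plus speaker-verification embedding), `extract_latent` (semantic codes from SSL features), and `forward` (waveform synthesis through `vq_model`). A condensed sketch of the call sequence, reusing the variable names that `export_prov2` below prepares (the inputs themselves are omitted here):

```python
# vits: VitsModel, sv_model: ExportERes2NetV2, as constructed in export_prov2.
step_vits = StepVitsModel(vits, sv_model)

refer, sv_emb = step_vits.ref_handle(ref_audio_sr)  # spectrogram + sv embedding
prompts = step_vits.extract_latent(ssl_content)     # semantic prompt codes
audio = step_vits(pred_semantic, text_seq, refer, sv_emb)  # mono waveform, 32 kHz
```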
```diff
@@ -249,15 +280,16 @@ def export_prov2(
     st = time.time()
     et = time.time()
 
-    y_len, y, xy_pos = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
+    y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
     idx = 1
     last_idx = 0
     audios = []
     full_audios = []
     print("y.shape:", y.shape)
     while True:
-        y, xy_pos, stop = stream_t2s(idx, top_k, y_len, y, xy_pos)
+        y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
         # print("y.shape:", y.shape)
+        stop = last_token==t2s.EOS
 
         # black magic; there is no good way to explain this
         if (y[0,-8] < 30 and idx-last_idx > (len(audios)+1) * 25) or stop:
```
```diff
@@ -324,16 +356,147 @@ def export_prov2(
     plt.show()
 
 
+def export_prov2(
+    gpt_path,
+    vits_path,
+    version,
+    ref_audio_path,
+    ref_text,
+    output_path,
+    device="cpu",
+    is_half=True,
+):
+    if export_torch_script.sv_cn_model == None:
+        init_sv_cn(device,is_half)
+
+    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
+    ssl = SSLModel()
+
+    print(f"device: {device}")
+
+    ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
+        ref_text, "all_zh", "v2"
+    )
+    ref_seq = torch.LongTensor([ref_seq_id]).to(device)
+    ref_bert = ref_bert_T.T
+    if is_half:
+        ref_bert = ref_bert.half()
+    ref_bert = ref_bert.to(ref_seq.device)
+
+    text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
+        "这是一个简单的示例,真没想到这么简单就完成了。真的神奇。接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2"
+    )
+    text_seq = torch.LongTensor([text_seq_id]).to(device)
+    text_bert = text_bert_T.T
+    if is_half:
+        text_bert = text_bert.half()
+    text_bert = text_bert.to(text_seq.device)
+
+    ssl_content = ssl(ref_audio)
+    if is_half:
+        ssl_content = ssl_content.half()
+    ssl_content = ssl_content.to(device)
+
+    sv_model = ExportERes2NetV2(export_torch_script.sv_cn_model)
+
+    # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
+    vits = VitsModel(vits_path, version,is_half=is_half,device=device)
+    vits.eval()
+    vits = StepVitsModel(vits, sv_model)
+
+    # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
+    # dict_s1 = torch.load(gpt_path, map_location=device)
+    dict_s1 = torch.load(gpt_path, weights_only=False)
+    raw_t2s = get_raw_t2s_model(dict_s1).to(device)
+    print("#### get_raw_t2s_model ####")
+    print(raw_t2s.config)
+    if is_half:
+        raw_t2s = raw_t2s.half()
+    t2s_m = T2SModel(raw_t2s)
+    t2s_m.eval()
+    # t2s = torch.jit.script(t2s_m).to(device)
+    t2s = t2s_m
+    print("#### script t2s_m ####")
+
+    print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
+
+    stream_t2s = StreamT2SModel(t2s).to(device)
+    stream_t2s = torch.jit.script(stream_t2s)
+
+    ref_audio_sr = resamplex(ref_audio, 16000, 32000)
+    if is_half:
+        ref_audio_sr = ref_audio_sr.half()
+    ref_audio_sr = ref_audio_sr.to(device)
+
+    top_k = 15
+
+    prompts = vits.extract_latent(ssl_content)
+
+    audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
+    sv_emb = sv_model(audio_16k)
+    print("text_seq",text_seq.shape)
+    # torch.jit.trace()
+
+    refer,sv_emb = vits.ref_handle(ref_audio_sr)
+
+    st = time.time()
+    et = time.time()
+
+    y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
+    idx = 1
+    print("y.shape:", y.shape)
+    while True:
+        y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
+        # print("y.shape:", y.shape)
+
+        idx+=1
+        # print(idx,'/',1500 , y.shape, y[0,-1].item(), stop)
+        if idx>1500:
+            break
+
+        if last_token == t2s.EOS:
+            break
+
+    at = time.time()
+    print("EOS:",t2s.EOS)
+
+    print(f"first token: {et - st:.4f} seconds")
+    print(f"all token: {at - st:.4f} seconds")
+    print("sv_emb", sv_emb.shape)
+    print("refer",refer.shape)
+    y = y[:,-idx:].unsqueeze(0)
+    print("y", y.shape)
+    audio = vits(y, text_seq, refer, sv_emb)
+    soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
+
+    torch._dynamo.mark_dynamic(ssl_content, 2)
+    torch._dynamo.mark_dynamic(ref_audio_sr, 1)
+    torch._dynamo.mark_dynamic(ref_seq, 1)
+    torch._dynamo.mark_dynamic(text_seq, 1)
+    torch._dynamo.mark_dynamic(ref_bert, 0)
+    torch._dynamo.mark_dynamic(text_bert, 0)
+    torch._dynamo.mark_dynamic(refer, 2)
+    torch._dynamo.mark_dynamic(y, 2)
+
+    inputs = {
+        "forward": (y, text_seq, refer, sv_emb),
+        "extract_latent": ssl_content,
+        "ref_handle": ref_audio_sr,
+    }
+
+    stream_t2s.save(f"{output_path}/t2s.pt")
+    torch.jit.trace_module(vits, inputs=inputs, optimize=True).save(f"{output_path}/vits.pt")
+
+
 if __name__ == "__main__":
     with torch.no_grad():
         export_prov2(
             gpt_path="GPT_SoVITS/pretrained_models/s1v3.ckpt",
             vits_path="GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth",
             version="v2Pro",
-            ref_audio_path="output/denoise_opt/ht/ht.mp4_0000026560_0000147200.wav",
-            ref_text="真的,这件衣服才配得上本小姐嘛",
+            ref_audio_path="/mnt/g/ad_ref.wav",
+            ref_text="你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说.",
             output_path="streaming",
-            export_bert_and_ssl=True,
             device="cuda",
             is_half=True,
         )
```
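After `export_prov2` runs, `streaming/t2s.pt` and `streaming/vits.pt` are self-contained TorchScript modules; `pre_infer` and `decode_next_token` survive the save because of `@torch.jit.export`, and `ref_handle` / `extract_latent` because they are listed in the `trace_module` inputs. A hedged sketch of loading them back for inference (tensor preparation is exactly as in `export_prov2` above and is omitted):

```python
import torch

stream_t2s = torch.jit.load("streaming/t2s.pt").eval()
vits = torch.jit.load("streaming/vits.pt").eval()

# ref_audio_sr, ssl_content, ref_seq, text_seq, ref_bert, text_bert:
# prepared the same way export_prov2 prepares them above.
refer, sv_emb = vits.ref_handle(ref_audio_sr)
prompts = vits.extract_latent(ssl_content)
y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(
    prompts, ref_seq, text_seq, ref_bert, text_bert, 15  # top_k = 15, as above
)
```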