Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-09-29 17:10:02 +08:00)
feat: combined fsdc and encoder, todo: extract audio pipeline

This commit is contained in:
parent 71cbe28e68
commit da5aa78224
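In short: this commit folds the first-stage decoder (fsdc) into the exported T2S encoder graph, so the encoder ONNX model now emits the decoder state (y, k, v, y_emb, x_example) directly instead of (x, prompts). The separate t2s_fsdec export and its onnxruntime session are commented out, the playground script is pointed at the v2pro export, and the reference-audio zero padding shrinks from 9600 to 4800 samples. The early-stop-by-prefix-length check is dropped along the way, because the combined graph no longer returns the prompts tensor it was derived from.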
@@ -94,6 +94,7 @@ class T2SEncoder(nn.Module):
     def __init__(self, t2s, vits):
         super().__init__()
         self.encoder = t2s.onnx_encoder
+        self.fsdc = t2s.first_stage_decoder
         self.vits = vits
 
     def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
@@ -103,7 +104,8 @@ class T2SEncoder(nn.Module):
         all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
         bert = bert.unsqueeze(0)
         prompt = prompt_semantic.unsqueeze(0)
-        return self.encoder(all_phoneme_ids, bert), prompt
+        return self.fsdc(self.encoder(all_phoneme_ids, bert), prompt)
+
 
 
 class T2SModel(nn.Module):
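Taken together, the two hunks above fuse the first-stage decoder into the same nn.Module that gets exported as the encoder. A minimal sketch of the resulting call pattern (illustrative only: it assumes the attribute names from the diff and takes bert and prompt_semantic as precomputed inputs, whereas the real class derives them from the BERT features and ssl_content):

import torch
import torch.nn as nn


class CombinedT2SEncoderSketch(nn.Module):
    # Illustration of the fused graph: phoneme/BERT encoder followed by the
    # first-stage decoder, exported as a single ONNX model.
    def __init__(self, t2s):
        super().__init__()
        self.encoder = t2s.onnx_encoder        # text/phoneme + BERT front-end
        self.fsdc = t2s.first_stage_decoder    # first-stage decoder, now inside the graph

    def forward(self, all_phoneme_ids, bert, prompt_semantic):
        bert = bert.unsqueeze(0)
        prompt = prompt_semantic.unsqueeze(0)
        # Before this commit the graph stopped at the encoder:
        #     return self.encoder(all_phoneme_ids, bert), prompt
        # Now the first-stage decoder runs inside the exported graph and the
        # model returns its state tuple (y, k, v, y_emb, x_example) directly.
        return self.fsdc(self.encoder(all_phoneme_ids, bert), prompt)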
@@ -130,20 +132,13 @@ class T2SModel(nn.Module):
         early_stop_num = self.t2s_model.early_stop_num
 
         # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
-        x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
-
-        prefix_len = prompts.shape[1]
-
-        # [1,N,512] [1,N]
-        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
+        y, k, v, y_emb, x_example = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
 
         stop = False
         for idx in range(1, 1500):
             # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
             enco = self.stage_decoder(y, k, v, y_emb, x_example)
             y, k, v, y_emb, logits, samples = enco
-            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
-                stop = True
             if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
                 stop = True
             if stop:
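One side effect of this hunk: with prompts gone, prefix_len and the early_stop_num guard disappear from the loop, so only the EOS checks remain. If that guard is still wanted, a hypothetical variant (not part of this commit) could measure the sequence length once at loop entry; note this is only an approximation of the old prompts.shape[1]:

# Hypothetical sketch, not in this commit: keep an early-stop bound without
# the prompts tensor by remembering how long y is when decoding starts.
y, k, v, y_emb, x_example = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
start_len = y.shape[1]  # stands in (approximately) for prefix_len = prompts.shape[1]

stop = False
for idx in range(1, 1500):
    y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
    if early_stop_num != -1 and (y.shape[1] - start_len) > early_stop_num:
        stop = True
    if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
        stop = True
    if stop:
        break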
@@ -167,7 +162,7 @@ class T2SModel(nn.Module):
             (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
             f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
             input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
-            output_names=["x", "prompts"],
+            output_names=["y", "k", "v", "y_emb", "x_example"],
             dynamic_axes={
                 "ref_seq": {1: "ref_length"},
                 "text_seq": {1: "text_length"},
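Since the exported encoder's output names change here, any downstream consumer of the old (x, prompts) graph will break. A small sanity check (assuming onnxruntime is installed; the path below is an example, substitute your own project_name/export directory) that prints what the new graph actually exposes:

import onnxruntime as ort

# Example path only; point this at your exported encoder.
sess = ort.InferenceSession("onnx/my_project/my_project_t2s_encoder.onnx")

# After this change the outputs should be y, k, v, y_emb, x_example
# rather than the old x, prompts pair.
print([o.name for o in sess.get_outputs()])
print([(i.name, i.shape) for i in sess.get_inputs()])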
@@ -177,22 +172,22 @@ class T2SModel(nn.Module):
             },
             opset_version=16,
         )
-        x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
+        y, k, v, y_emb, x_example = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
 
-        torch.onnx.export(
-            self.first_stage_decoder,
-            (x, prompts),
-            f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
-            input_names=["x", "prompts"],
-            output_names=["y", "k", "v", "y_emb", "x_example"],
-            dynamic_axes={
-                "x": {1: "x_length"},
-                "prompts": {1: "prompts_length"},
-            },
-            verbose=False,
-            opset_version=16,
-        )
-        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
+        # torch.onnx.export(
+        #     self.first_stage_decoder,
+        #     (x, prompts),
+        #     f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
+        #     input_names=["x", "prompts"],
+        #     output_names=["y", "k", "v", "y_emb", "x_example"],
+        #     dynamic_axes={
+        #         "x": {1: "x_length"},
+        #         "prompts": {1: "prompts_length"},
+        #     },
+        #     verbose=False,
+        #     opset_version=16,
+        # )
+        # y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
 
         torch.onnx.export(
             self.stage_decoder,
@@ -7,7 +7,7 @@ import torch
 from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
 
 
-MODEL_PATH = "playground/v2proplus_export/v2proplus"
+MODEL_PATH = "playground/v2pro_export/v2pro"
 
 def audio_postprocess(
     audios,
@@ -49,7 +49,7 @@ def load_and_preprocess_audio(audio_path, max_length=160000):
     waveform = waveform[:, :max_length]
 
     # make a zero tensor that has length 3200*0.3
-    zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
+    zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)
 
     # concate zero_tensor with waveform
     waveform = torch.cat([waveform, zero_tensor], dim=1)
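For the arithmetic behind the padding change: 4800 samples is 0.3 s at a 16 kHz input rate (the rate implied by max_length=160000, roughly 10 s of audio), whereas the old 9600 would have been 0.6 s at 16 kHz (or 0.3 s at 32 kHz); the stale "3200*0.3" comment predates both values. A hedged sketch of deriving the pad length instead of hard-coding it (the 16 kHz figure is an assumption inferred from the sample counts, not stated in the diff):

import torch

SAMPLE_RATE = 16000   # assumption inferred from max_length=160000 (~10 s)
PAD_SECONDS = 0.3

pad_samples = int(SAMPLE_RATE * PAD_SECONDS)   # 4800, matching the new value
zero_tensor = torch.zeros((1, pad_samples), dtype=torch.float32)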
@@ -75,7 +75,7 @@ def preprocess_text(text:str):
 
 # input_phones_saved = np.load("playground/ref/input_phones.npy")
 # input_bert_saved = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
-[input_phones, input_bert] = preprocess_text("震撼视角,感受成都世运会,闭幕式烟花")
+[input_phones, input_bert] = preprocess_text("震撼视角,感受开幕式,闭幕式烟花")
 
 
 # ref_phones = np.load("playground/ref/ref_phones.npy")
@@ -88,7 +88,7 @@ audio_prompt_hubert = get_audio_hubert("playground/ref/audio.wav")
 
 encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx")
 
-[x, prompts] = encoder.run(None, {
+[y, k, v, y_emb, x_example] = encoder.run(None, {
     "text_seq": input_phones,
     "text_bert": input_bert,
     "ref_seq": ref_phones,
@@ -96,18 +96,16 @@ encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx")
     "ssl_content": audio_prompt_hubert
 })
 
-fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
+# fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
 sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
 
 # for i in tqdm(range(10000)):
-[y, k, v, y_emb, x_example] = fsdec.run(None, {
-    "x": x,
-    "prompts": prompts
-})
+# [y, k, v, y_emb, x_example] = fsdec.run(None, {
+#     "x": x,
+#     "prompts": prompts
+# })
 
 
-prefix_len = prompts.shape[1]
-
 stop = False
 for idx in tqdm(range(1, 1500)):
     # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
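With the fsdec session commented out, the loop above starts directly from the combined encoder's outputs. A hedged sketch of one way to feed them back through the exported stage decoder without hard-coding its input names (the name order is assumed to align with the y, k, v, y_emb, x_example tuple, and EOS=1024 is an assumption for illustration, not read from the model):

import numpy as np
import onnxruntime as ort
from tqdm import tqdm

# y, k, v, y_emb, x_example come from the combined encoder session above.
sdec = ort.InferenceSession(MODEL_PATH + "_export_t2s_sdec.onnx")
input_names = [i.name for i in sdec.get_inputs()]   # assumed to match the state-tuple order

EOS = 1024   # assumption; take the real value from the torch T2S model
state = [y, k, v, y_emb, x_example]
for idx in tqdm(range(1, 1500)):
    y, k, v, y_emb, logits, samples = sdec.run(None, dict(zip(input_names, state)))
    state = [y, k, v, y_emb, x_example]
    if np.argmax(logits, axis=-1)[0] == EOS or samples[0, 0] == EOS:
        break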
@@ -129,6 +127,7 @@ waveform, sample_rate = torchaudio.load("playground/ref/audio.wav")
 if sample_rate != 32000:
     resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=32000)
     waveform = resampler(waveform)
+print(f"Waveform shape: {waveform.shape}")
 ref_audio = waveform.numpy().astype(np.float32)
 
 vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
|
Loading…
x
Reference in New Issue
Block a user