Optimize export_torch_script.py (#1720)

* export_torch_script now takes its parameters from the command line

* export supports a speech-speed setting
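
Example invocation with the new command-line interface (the model and audio paths below are placeholders for your own files, not files shipped with the repo):

    python export_torch_script.py \
        --gpt_model GPT_weights_v2/your-gpt.ckpt \
        --sovits_model SoVITS_weights_v2/your-sovits.pth \
        --ref_audio your_ref.wav \
        --ref_text "<transcript of your_ref.wav>" \
        --output_path exported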
Author: zzz, 2024-10-26 16:14:39 +08:00 (committed by GitHub)
parent 5d126f98b2
commit 98cc47699c
2 changed files with 89 additions and 59 deletions

export_torch_script.py

@@ -1,5 +1,6 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
+ import argparse
from typing import Optional
from my_utils import load_audio
from text import cleaned_text_to_sequence
@@ -15,7 +16,7 @@ from feature_extractor import cnhubert
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from module.models_onnx import SynthesizerTrn
+ from inference_webui import get_phones_and_bert
import os
import soundfile
@@ -351,7 +352,7 @@ class VitsModel(nn.Module):
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
- def forward(self, text_seq, pred_semantic, ref_audio):
+ def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0):
refer = spectrogram_torch(
ref_audio,
self.hps.data.filter_length,
@@ -360,7 +361,7 @@
self.hps.data.win_length,
center=False
)
- return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
+ return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]
class T2SModel(nn.Module):
def __init__(self,raw_t2s:Text2SemanticLightningModule):
@@ -507,6 +508,8 @@ class T2SModel(nn.Module):
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
y[0,-1] = 0
return y[:, -idx:].unsqueeze(0)
bert_path = os.environ.get(
@@ -558,44 +561,48 @@ class ExportSSLModel(torch.nn.Module):
return audio
def export_bert(ref_bert_inputs):
- tokenizer = AutoTokenizer.from_pretrained(bert_path)
- ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
- ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
- bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
- my_bert_model = MyBertModel(bert_model)
+ ref_bert_inputs = {
+ 'input_ids': ref_bert_inputs['input_ids'],
+ 'attention_mask': ref_bert_inputs['attention_mask'],
+ 'token_type_ids': ref_bert_inputs['token_type_ids'],
+ 'word2ph': ref_bert_inputs['word2ph']
+ }
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
+ my_bert_model = MyBertModel(bert_model)
my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs)
my_bert_model.save("onnx/bert_model.pt")
print('#### exported bert ####')
- def export(gpt_path, vits_path):
- tokenizer = AutoTokenizer.from_pretrained(bert_path)
- ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
- ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
- ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
- text_berf_inputs = tokenizer("大家好,我有一个春晚问题.", return_tensors="pt")
- text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
- text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,1]).int()
- bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
- bert = MyBertModel(bert_model)
+ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path):
# export_bert(ref_bert_inputs)
- ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+ print(f"Directory created: {output_path}")
+ else:
+ print(f"Directory already exists: {output_path}")
+ ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
ssl = SSLModel()
s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
- torch.jit.script(s).save("onnx/xw/ssl_model.pt")
+ ssl_path = os.path.join(output_path, "ssl_model.pt")
+ torch.jit.script(s).save(ssl_path)
print('#### exported ssl ####')
- ref_bert = bert(**ref_bert_inputs)
- text_bert = bert(**text_berf_inputs)
+ ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
+ ref_seq = torch.LongTensor([ref_seq_id])
+ ref_bert = ref_bert_T.T.to(ref_seq.device)
+ text_seq_id,text_bert_T,norm_text = get_phones_and_bert("这是一条测试语音,说什么无所谓,只是给它一个例子","all_zh",'v2')
+ text_seq = torch.LongTensor([text_seq_id])
+ text_bert = text_bert_T.T.to(text_seq.device)
ssl_content = ssl(ref_audio)
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
@@ -605,6 +612,8 @@ def export(gpt_path, vits_path):
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
dict_s1 = torch.load(gpt_path, map_location="cpu")
raw_t2s = get_raw_t2s_model(dict_s1)
+ print('#### get_raw_t2s_model ####')
+ print(raw_t2s.config)
t2s_m = T2SModel(raw_t2s)
t2s_m.eval()
t2s = torch.jit.script(t2s_m)
@@ -613,6 +622,10 @@ def export(gpt_path, vits_path):
print("vits.hps.data.sampling_rate:",vits.hps.data.sampling_rate)
gpt_sovits = GPT_SoVITS(t2s,vits)
gpt_sovits.eval()
ref_audio_sr = s.resample(ref_audio,16000,32000)
+ print('ref_audio_sr:',ref_audio_sr.shape)
@@ -624,10 +637,10 @@ def export(gpt_path, vits_path):
ref_seq,
text_seq,
ref_bert,
- text_bert),
- check_trace=False) # default is True, but check may feed randomly generated inputs with odd shapes and raise errors
+ text_bert))
- gpt_sovits_export.save("onnx/xw/gpt_sovits_model.pt")
+ gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
+ gpt_sovits_export.save(gpt_sovits_path)
print('#### exported gpt_sovits ####')
@torch.jit.script
@@ -646,16 +659,28 @@ class GPT_SoVITS(nn.Module):
self.t2s = t2s
self.vits = vits
- def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor):
- codes = self.vits.vq_model.extract_latent(ssl_content.float())
+ def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor, speed=1.0):
+ codes = self.vits.vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
prompts = prompt_semantic.unsqueeze(0)
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert)
- audio = self.vits(text_seq, pred_semantic, ref_audio_sr)
+ audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed)
return audio
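
For context, a minimal sketch of driving the exported artifacts (the paths assume --output_path exported; ref_audio, ref_seq, text_seq, ref_bert, and text_bert are prepared exactly as in test() below, and the loaded SSL wrapper is assumed to expose the same forward/resample as above):

    import torch

    # the two TorchScript files written by export() into --output_path
    ssl = torch.jit.load("exported/ssl_model.pt", map_location="cpu")
    gpt_sovits = torch.jit.load("exported/gpt_sovits_model.pt", map_location="cpu")

    # ref_audio: [1, N] float waveform at 16 kHz
    ssl_content = ssl(ref_audio)
    ref_audio_sr = ssl.resample(ref_audio, 16000, 32000)
    audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert)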
- def test(gpt_path, vits_path):
+ def test():
+ parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
+ parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
+ parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
+ parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
+ parser.add_argument('--ref_text', required=True, help="Reference text (transcript of the reference audio)")
+ args = parser.parse_args()
+ gpt_path = args.gpt_model
+ vits_path = args.sovits_model
+ ref_audio_path = args.ref_audio
+ ref_text = args.ref_text
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
bert = MyBertModel(bert_model)
@@ -680,29 +705,19 @@ def test(gpt_path, vits_path):
# vits = torch.jit.load("onnx/xw/vits_model.pt",map_location='cuda')
# ssl = torch.jit.load("onnx/xw/ssl_model.pt",map_location='cuda')
- ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
- ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
- ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
- text_berf_inputs = tokenizer("大家好,我有一个春晚问题.", return_tensors="pt")
- text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
- text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,1]).int()
- ref_bert = bert(
- ref_bert_inputs['input_ids'],
- ref_bert_inputs['attention_mask'],
- ref_bert_inputs['token_type_ids'],
- ref_bert_inputs['word2ph']
- )
- text_bert = bert(text_berf_inputs['input_ids'],
- text_berf_inputs['attention_mask'],
- text_berf_inputs['token_type_ids'],
- text_berf_inputs['word2ph'])
+ ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
+ ref_seq = torch.LongTensor([ref_seq_id])
+ ref_bert = ref_bert_T.T.to(ref_seq.device)
+ text_seq_id,text_bert_T,norm_text = get_phones_and_bert("问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么。","all_zh",'v2')
+ text_seq = torch.LongTensor([text_seq_id])
+ print('text_seq:',text_seq_id)
+ text_bert = text_bert_T.T.to(text_seq.device)
# text_bert = torch.zeros(text_bert.shape, dtype=text_bert.dtype).to(text_bert.device)
print('text_seq:',text_seq.shape)
print('text_bert:',text_bert.shape)
#[1,N]
- ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
+ ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
print('ref_audio:',ref_audio.shape)
ref_audio_sr = ssl.resample(ref_audio,16000,32000)
@@ -731,7 +746,19 @@ def export_symbel(version='v2'):
with open(f"onnx/symbols_v2.json", "w") as file:
json.dump(symbols, file, indent=4)
+ def main():
+ parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
+ parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
+ parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
+ parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
+ parser.add_argument('--ref_text', required=True, help="Reference text (transcript of the reference audio)")
+ parser.add_argument('--output_path', required=True, help="Path to the output directory")
+ args = parser.parse_args()
+ export(gpt_path=args.gpt_model, vits_path=args.sovits_model, ref_audio_path=args.ref_audio, ref_text=args.ref_text, output_path=args.output_path)
+ import inference_webui
if __name__ == "__main__":
- export(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
- # test(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
# export_symbel()
+ inference_webui.is_half=False
+ inference_webui.dtype=torch.float32
+ main()

module/models_onnx.py

@@ -231,7 +231,7 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
- def forward(self, y, text, ge):
+ def forward(self, y, text, ge, speed=1):
y_mask = torch.ones_like(y[:1,:1,:])
y = self.ssl_proj(y * y_mask) * y_mask
@@ -244,6 +244,9 @@ class TextEncoder(nn.Module):
y = self.mrte(y, y_mask, text, text_mask, ge)
y = self.encoder2(y * y_mask, y_mask)
+ if(speed!=1):
+ y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")
+ y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
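
The speed control is plain linear interpolation of the encoded feature sequence along the time axis; a self-contained illustration with made-up sizes:

    import torch
    import torch.nn.functional as F

    y = torch.randn(1, 192, 200)   # [batch, channels, frames]; sizes are illustrative
    speed = 1.25                   # >1 shortens the sequence, i.e. faster speech
    y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
    print(y.shape)                 # torch.Size([1, 192, 161])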
@@ -887,7 +890,7 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False)
- def forward(self, codes, text, refer):
+ def forward(self, codes, text, refer, noise_scale=0.5, speed=1):
refer_mask = torch.ones_like(refer[:1,:1,:])
if (self.version == "v1"):
ge = self.ref_enc(refer * refer_mask, refer_mask)
@@ -900,10 +903,10 @@ class SynthesizerTrn(nn.Module):
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
x, m_p, logs_p, y_mask = self.enc_p(
- quantized, text, ge
+ quantized, text, ge, speed
)
- z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p)
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
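
Here noise_scale simply scales the standard deviation of the sampled prior (0 would make decoding deterministic at the mean m_p); a toy check with illustrative shapes:

    import torch

    m_p = torch.zeros(1, 192, 100)     # prior mean (shape is illustrative)
    logs_p = torch.zeros(1, 192, 100)  # prior log-std
    for noise_scale in (0.0, 0.5, 1.0):
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        print(noise_scale, round(z_p.std().item(), 2))  # sample std tracks noise_scale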