From 5d126f98b2b1702f7e26f438c1912d13f1832a90 Mon Sep 17 00:00:00 2001
From: YSC-hain <58895782+YSC-hain@users.noreply.github.com>
Date: Fri, 18 Oct 2024 11:40:44 +0800
Subject: [PATCH 1/4] Fix: unable to listen on dual stack (#1621)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Passing the -a None argument at launch lets the API listen on both IPv4 and IPv6 (dual stack).
---
 api_v2.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/api_v2.py b/api_v2.py
index ea1d0c7f..92a18f37 100644
--- a/api_v2.py
+++ b/api_v2.py
@@ -451,6 +451,8 @@ async def set_sovits_weights(weights_path: str = None):

 if __name__ == "__main__":
     try:
+        if host == 'None':  # passing "-a None" on the command line lets the API listen on dual stack
+            host = None
         uvicorn.run(app=APP, host=host, port=port, workers=1)
     except Exception as e:
         traceback.print_exc()
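Usage note: with the change above, a dual-stack launch looks like the following sketch. The `-a` flag is taken from the commit message itself; the `-p` flag and port value are assumptions based on api_v2.py's usual argument parser:

    # bind both IPv4 and IPv6 by passing the literal string "None"
    python api_v2.py -a None -p 9880

This works because uvicorn forwards host=None down to asyncio's create_server, which binds all interfaces, on both stacks where the OS supports it.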
From 98cc47699c97e82473a4948d0c5ffcd7239dc96c Mon Sep 17 00:00:00 2001
From: zzz <458761603@qq.com>
Date: Sat, 26 Oct 2024 16:14:39 +0800
Subject: [PATCH 2/4] Optimize export_torch_script.py (#1720)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* export_torch_script: take its parameters from the command line
* export: support a speech-speed setting
---
 GPT_SoVITS/export_torch_script.py | 137 ++++++++++++++++++------------
 GPT_SoVITS/module/models_onnx.py  |  11 ++-
 2 files changed, 89 insertions(+), 59 deletions(-)

diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py
index c7f1306c..ce8821bd 100644
--- a/GPT_SoVITS/export_torch_script.py
+++ b/GPT_SoVITS/export_torch_script.py
@@ -1,5 +1,6 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
 # reference: https://github.com/lifeiteng/vall-e
+import argparse
 from typing import Optional
 from my_utils import load_audio
 from text import cleaned_text_to_sequence
@@ -15,7 +16,7 @@ from feature_extractor import cnhubert
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from module.models_onnx import SynthesizerTrn
-
+from inference_webui import get_phones_and_bert
 import os
 import soundfile
@@ -351,7 +352,7 @@ class VitsModel(nn.Module):
         self.vq_model.eval()
         self.vq_model.load_state_dict(dict_s2["weight"], strict=False)

-    def forward(self, text_seq, pred_semantic, ref_audio):
+    def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0):
         refer = spectrogram_torch(
             ref_audio,
             self.hps.data.filter_length,
@@ -360,7 +361,7 @@ class VitsModel(nn.Module):
             self.hps.data.win_length,
             center=False
         )
-        return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
+        return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]

 class T2SModel(nn.Module):
     def __init__(self,raw_t2s:Text2SemanticLightningModule):
@@ -507,6 +508,8 @@ class T2SModel(nn.Module):
             y_emb = self.ar_audio_embedding(y[:, -1:])
             xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)

+        y[0,-1] = 0
+
         return y[:, -idx:].unsqueeze(0)

 bert_path = os.environ.get(
@@ -558,44 +561,48 @@ class ExportSSLModel(torch.nn.Module):
         return audio

 def export_bert(ref_bert_inputs):
+    tokenizer = AutoTokenizer.from_pretrained(bert_path)
+
+    ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
+    ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
+
+    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
+    my_bert_model = MyBertModel(bert_model)
+
     ref_bert_inputs = {
         'input_ids': ref_bert_inputs['input_ids'],
         'attention_mask': ref_bert_inputs['attention_mask'],
         'token_type_ids': ref_bert_inputs['token_type_ids'],
         'word2ph': ref_bert_inputs['word2ph']
     }

-    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
-    my_bert_model = MyBertModel(bert_model)
     my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs)
     my_bert_model.save("onnx/bert_model.pt")
     print('#### exported bert ####')

-def export(gpt_path, vits_path):
-    tokenizer = AutoTokenizer.from_pretrained(bert_path)
-
-    ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
-    ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
-    ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
-
-    text_berf_inputs = tokenizer("大家好,我有一个春晚问题.", return_tensors="pt")
-    text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
-    text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,1]).int()
-
-    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
-
-    bert = MyBertModel(bert_model)
-
+def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path):
     # export_bert(ref_bert_inputs)
+
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
+        print(f"Created directory: {output_path}")
+    else:
+        print(f"Directory already exists: {output_path}")

-    ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
+    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
     ssl = SSLModel()
     s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
-    torch.jit.script(s).save("onnx/xw/ssl_model.pt")
+    ssl_path = os.path.join(output_path, "ssl_model.pt")
+    torch.jit.script(s).save(ssl_path)
     print('#### exported ssl ####')

-    ref_bert = bert(**ref_bert_inputs)
-    text_bert = bert(**text_berf_inputs)
+    ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
+    ref_seq = torch.LongTensor([ref_seq_id])
+    ref_bert = ref_bert_T.T.to(ref_seq.device)
+    text_seq_id,text_bert_T,norm_text = get_phones_and_bert("这是一条测试语音,说什么无所谓,只是给它一个例子","all_zh",'v2')
+    text_seq = torch.LongTensor([text_seq_id])
+    text_bert = text_bert_T.T.to(text_seq.device)
+
     ssl_content = ssl(ref_audio)

     # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
     vits = VitsModel(vits_path)
     vits.eval()
@@ -605,6 +612,8 @@ def export(gpt_path, vits_path):
     # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
     dict_s1 = torch.load(gpt_path, map_location="cpu")
     raw_t2s = get_raw_t2s_model(dict_s1)
+    print('#### get_raw_t2s_model ####')
+    print(raw_t2s.config)
     t2s_m = T2SModel(raw_t2s)
     t2s_m.eval()
     t2s = torch.jit.script(t2s_m)
@@ -614,6 +623,10 @@ def export(gpt_path, vits_path):
     gpt_sovits = GPT_SoVITS(t2s,vits)
     gpt_sovits.eval()
     ref_audio_sr = s.resample(ref_audio,16000,32000)
+    ref_audio_sr = s.resample(ref_audio,16000,32000)
+    print('ref_audio_sr:',ref_audio_sr.shape)
+
+    ref_audio_sr = s.resample(ref_audio,16000,32000)
     print('ref_audio_sr:',ref_audio_sr.shape)

     gpt_sovits_export = torch.jit.trace(
@@ -624,10 +637,10 @@ def export(gpt_path, vits_path):
             ref_seq,
             text_seq,
             ref_bert,
-            text_bert),
-        check_trace=False)  # defaults to True, but the check may feed randomly generated tensors with odd shapes and trigger errors
+            text_bert))

-    gpt_sovits_export.save("onnx/xw/gpt_sovits_model.pt")
+    gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
+    gpt_sovits_export.save(gpt_sovits_path)
     print('#### exported gpt_sovits ####')

 @torch.jit.script
 def parse_audio(ref_audio):
@@ -646,16 +659,28 @@ class GPT_SoVITS(nn.Module):
         self.t2s = t2s
         self.vits = vits

-    def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor):
-        codes = self.vits.vq_model.extract_latent(ssl_content.float())
+    def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor, speed=1.0):
+        codes = self.vits.vq_model.extract_latent(ssl_content)
         prompt_semantic = codes[0, 0]
         prompts = prompt_semantic.unsqueeze(0)
         pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert)
-        audio = self.vits(text_seq, pred_semantic, ref_audio_sr)
+        audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed)
         return audio

-def test(gpt_path, vits_path):
+def test():
+    parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
+    parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
+    parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
+    parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
+    parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
+
+    args = parser.parse_args()
+    gpt_path = args.gpt_model
+    vits_path = args.sovits_model
+    ref_audio_path = args.ref_audio
+    ref_text = args.ref_text
+
     tokenizer = AutoTokenizer.from_pretrained(bert_path)
     bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
     bert = MyBertModel(bert_model)
@@ -680,29 +705,19 @@ def test(gpt_path, vits_path):

     # vits = torch.jit.load("onnx/xw/vits_model.pt",map_location='cuda')
     # ssl = torch.jit.load("onnx/xw/ssl_model.pt",map_location='cuda')
-
-    ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
-    ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
-    ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
-
-    text_berf_inputs = tokenizer("大家好,我有一个春晚问题.", return_tensors="pt")
-    text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
-    text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,1]).int()
-
-    ref_bert = bert(
-        ref_bert_inputs['input_ids'],
-        ref_bert_inputs['attention_mask'],
-        ref_bert_inputs['token_type_ids'],
-        ref_bert_inputs['word2ph']
-    )
-
-    text_bert = bert(text_berf_inputs['input_ids'],
-        text_berf_inputs['attention_mask'],
-        text_berf_inputs['token_type_ids'],
-        text_berf_inputs['word2ph'])
+    ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
+    ref_seq = torch.LongTensor([ref_seq_id])
+    ref_bert = ref_bert_T.T.to(ref_seq.device)
+    text_seq_id,text_bert_T,norm_text = get_phones_and_bert("问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么。","all_zh",'v2')
+    text_seq = torch.LongTensor([text_seq_id])
+    print('text_seq:',text_seq_id)
+    text_bert = text_bert_T.T.to(text_seq.device)
+    # text_bert = torch.zeros(text_bert.shape, dtype=text_bert.dtype).to(text_bert.device)
+    print('text_seq:',text_seq.shape)
+    print('text_bert:',text_bert.shape)

     #[1,N]
-    ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
+    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
     print('ref_audio:',ref_audio.shape)
     ref_audio_sr = ssl.resample(ref_audio,16000,32000)
@@ -731,7 +746,19 @@ def export_symbel(version='v2'):
     with open(f"onnx/symbols_v2.json", "w") as file:
         json.dump(symbols, file, indent=4)

+def main():
+    parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
+    parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
+    parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
+    parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
+    parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
+    parser.add_argument('--output_path', required=True, help="Path to the output directory")
+
+    args = parser.parse_args()
+    export(gpt_path=args.gpt_model, vits_path=args.sovits_model, ref_audio_path=args.ref_audio, ref_text=args.ref_text, output_path=args.output_path)
+
+import inference_webui
 if __name__ == "__main__":
-    export(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
-    # test(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
-    # export_symbel()
\ No newline at end of file
+    inference_webui.is_half=False
+    inference_webui.dtype=torch.float32
+    main()
\ No newline at end of file
diff --git a/GPT_SoVITS/module/models_onnx.py b/GPT_SoVITS/module/models_onnx.py
index c5d96d0c..abe2a3c6 100644
--- a/GPT_SoVITS/module/models_onnx.py
+++ b/GPT_SoVITS/module/models_onnx.py
@@ -231,7 +231,7 @@ class TextEncoder(nn.Module):

         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

-    def forward(self, y, text, ge):
+    def forward(self, y, text, ge, speed=1):
         y_mask = torch.ones_like(y[:1,:1,:])

         y = self.ssl_proj(y * y_mask) * y_mask
@@ -244,6 +244,9 @@ class TextEncoder(nn.Module):
         y = self.mrte(y, y_mask, text, text_mask, ge)

         y = self.encoder2(y * y_mask, y_mask)
+        if(speed!=1):
+            y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")
+            y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")

         stats = self.proj(y) * y_mask
         m, logs = torch.split(stats, self.out_channels, dim=1)
@@ -887,7 +890,7 @@ class SynthesizerTrn(nn.Module):
         # self.enc_p.encoder_text.requires_grad_(False)
         # self.enc_p.mrte.requires_grad_(False)

-    def forward(self, codes, text, refer):
+    def forward(self, codes, text, refer, noise_scale=0.5, speed=1):
         refer_mask = torch.ones_like(refer[:1,:1,:])
         if (self.version == "v1"):
             ge = self.ref_enc(refer * refer_mask, refer_mask)
@@ -900,10 +903,10 @@ class SynthesizerTrn(nn.Module):
             quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)

         x, m_p, logs_p, y_mask = self.enc_p(
-            quantized, text, ge
+            quantized, text, ge, speed
         )

-        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p)
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

         z = self.flow(z_p, y_mask, g=ge, reverse=True)
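Usage note: after this patch the exporter is driven entirely from the command line (the flags come from main() above; the model and audio paths below are placeholders):

    python GPT_SoVITS/export_torch_script.py \
        --gpt_model GPT_weights_v2/example-e15.ckpt \
        --sovits_model SoVITS_weights_v2/example_e8_s208.pth \
        --ref_audio ref.wav \
        --ref_text ref.txt \
        --output_path onnx/example

The speed setting works by resampling the text encoder's hidden sequence before the decoder runs: TextEncoder.forward interpolates y down to roughly len/speed frames, so speed > 1 yields shorter, faster audio. A minimal standalone sketch of that mechanism (the shapes are illustrative, not taken from the model):

    import torch
    import torch.nn.functional as F

    y = torch.randn(1, 192, 200)  # (batch, channels, frames) encoder output
    speed = 1.25                  # >1 speaks faster, <1 slower
    y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
    print(y.shape)                # torch.Size([1, 192, 161]) -> fewer frames, faster speech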
From 6d82050b9cce3cd42c573b2179975a25ab5467af Mon Sep 17 00:00:00 2001
From: 刘悦
Date: Wed, 30 Oct 2024 16:38:05 +0800
Subject: [PATCH 3/4] add encoding (#1730)
---
 GPT_SoVITS/text/g2pw/g2pw.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/GPT_SoVITS/text/g2pw/g2pw.py b/GPT_SoVITS/text/g2pw/g2pw.py
index a91bc115..1e3738ab 100644
--- a/GPT_SoVITS/text/g2pw/g2pw.py
+++ b/GPT_SoVITS/text/g2pw/g2pw.py
@@ -127,14 +127,14 @@ def get_dict():

 def read_dict():
     polyphonic_dict = {}
-    with open(PP_DICT_PATH) as f:
+    with open(PP_DICT_PATH,encoding="utf-8") as f:
         line = f.readline()
         while line:
             key, value_str = line.split(':')
             value = eval(value_str.strip())
             polyphonic_dict[key.strip()] = value
             line = f.readline()
-    with open(PP_FIX_DICT_PATH) as f:
+    with open(PP_FIX_DICT_PATH,encoding="utf-8") as f:
         line = f.readline()
         while line:
             key, value_str = line.split(':')
@@ -151,4 +151,4 @@ def correct_pronunciation(word,word_pinyins):
     return word_pinyins


-pp_dict = get_dict()
\ No newline at end of file
+pp_dict = get_dict()
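Note on the fix: open() without an explicit encoding falls back to the locale's preferred encoding (often cp936/GBK on Chinese-locale Windows), so reading the UTF-8 dictionary files can raise UnicodeDecodeError or silently mis-decode entries. A minimal, self-contained illustration (the file name and entry are hypothetical, not the repo's actual dictionary):

    # Write UTF-8, then read it back with the encoding pinned; relying on the
    # locale default instead would break on any system whose locale is not UTF-8.
    with open("polyphonic_demo.txt", "w", encoding="utf-8") as f:
        f.write("传:['chuan2']\n")
    with open("polyphonic_demo.txt", encoding="utf-8") as f:
        for line in f:
            key, value_str = line.split(':')
            print(key, value_str.strip())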
From a70e1ad30c072cdbcfb716962abdc8008fa41cc2 Mon Sep 17 00:00:00 2001
From: zzz <458761603@qq.com>
Date: Thu, 7 Nov 2024 18:19:20 +0800
Subject: [PATCH 4/4] Optimize export_torch_script.py (#1739)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 GPT_SoVITS/export_torch_script.py | 192 ++++++++++++++++++++----------
 1 file changed, 130 insertions(+), 62 deletions(-)

diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py
index ce8821bd..f7bef133 100644
--- a/GPT_SoVITS/export_torch_script.py
+++ b/GPT_SoVITS/export_torch_script.py
@@ -330,11 +330,12 @@ class T2STransformer:
         for i in range(self.num_blocks):
             x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
         return x, k_cache, v_cache
-
+
 class VitsModel(nn.Module):
     def __init__(self, vits_path):
         super().__init__()
-        dict_s2 = torch.load(vits_path,map_location="cpu")
+        # dict_s2 = torch.load(vits_path,map_location="cpu")
+        dict_s2 = torch.load(vits_path)
         self.hps = dict_s2["config"]
         if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
             self.hps["model"]["version"] = "v1"
@@ -527,7 +528,7 @@ def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
     phone_level_feature = torch.cat(phone_level_feature, dim=0)  # [sum(word2ph), 1024]
     return phone_level_feature
-
+
 class MyBertModel(torch.nn.Module):
     def __init__(self, bert_model):
         super(MyBertModel, self).__init__()
@@ -535,7 +536,8 @@ class MyBertModel(torch.nn.Module):
     def forward(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, token_type_ids:torch.Tensor, word2ph:IntTensor):
         outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
-        res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
+        # res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
+        res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1]
         return build_phone_level_feature(res, word2ph)

 class SSLModel(torch.nn.Module):
@@ -560,13 +562,20 @@ class ExportSSLModel(torch.nn.Module):
         audio = resamplex(ref_audio,src_sr,dst_sr).float()
         return audio

-def export_bert(ref_bert_inputs):
+def export_bert(output_path):
     tokenizer = AutoTokenizer.from_pretrained(bert_path)
-
-    ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
-    ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
+    text = "叹息声一声接着一声传出,木兰对着房门织布.听不见织布机织布的声音,只听见木兰在叹息.问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么."
+    ref_bert_inputs = tokenizer(text, return_tensors="pt")
+    word2ph = []
+    for c in text:
+        if c in [',','。',':','?',",",".","?"]:
+            word2ph.append(1)
+        else:
+            word2ph.append(2)
+    ref_bert_inputs['word2ph'] = torch.Tensor(word2ph).int()

-    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
+    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
     my_bert_model = MyBertModel(bert_model)

     ref_bert_inputs = {
@@ -576,13 +585,17 @@ def export_bert(ref_bert_inputs):
         'word2ph': ref_bert_inputs['word2ph']
     }

-    my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs)
-    my_bert_model.save("onnx/bert_model.pt")
-    print('#### exported bert ####')
-
-def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path):
-    # export_bert(ref_bert_inputs)
+    torch._dynamo.mark_dynamic(ref_bert_inputs['input_ids'], 1)
+    torch._dynamo.mark_dynamic(ref_bert_inputs['attention_mask'], 1)
+    torch._dynamo.mark_dynamic(ref_bert_inputs['token_type_ids'], 1)
+    torch._dynamo.mark_dynamic(ref_bert_inputs['word2ph'], 0)
+
+    my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs)
+    output_path = os.path.join(output_path, "bert_model.pt")
+    my_bert_model.save(output_path)
+    print('#### exported bert ####')
+
+def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device='cpu'):
     if not os.path.exists(output_path):
         os.makedirs(output_path)
         print(f"Created directory: {output_path}")
@@ -591,45 +604,57 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path):

     ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
     ssl = SSLModel()
-    s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
-    ssl_path = os.path.join(output_path, "ssl_model.pt")
-    torch.jit.script(s).save(ssl_path)
-    print('#### exported ssl ####')
+    if export_bert_and_ssl:
+        s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
+        ssl_path = os.path.join(output_path, "ssl_model.pt")
+        torch.jit.script(s).save(ssl_path)
+        print('#### exported ssl ####')
+        export_bert(output_path)
+    else:
+        s = ExportSSLModel(ssl)
+
+    print(f"device: {device}")

     ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
-    ref_seq = torch.LongTensor([ref_seq_id])
+    ref_seq = torch.LongTensor([ref_seq_id]).to(device)
     ref_bert = ref_bert_T.T.to(ref_seq.device)
     text_seq_id,text_bert_T,norm_text = get_phones_and_bert("这是一条测试语音,说什么无所谓,只是给它一个例子","all_zh",'v2')
-    text_seq = torch.LongTensor([text_seq_id])
+    text_seq = torch.LongTensor([text_seq_id]).to(device)
     text_bert = text_bert_T.T.to(text_seq.device)

-    ssl_content = ssl(ref_audio)
+    ssl_content = ssl(ref_audio).to(device)

     # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
-    vits = VitsModel(vits_path)
+    vits = VitsModel(vits_path).to(device)
     vits.eval()

     # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
-    dict_s1 = torch.load(gpt_path, map_location="cpu")
-    raw_t2s = get_raw_t2s_model(dict_s1)
+    # dict_s1 = torch.load(gpt_path, map_location=device)
+    dict_s1 = torch.load(gpt_path)
+    raw_t2s = get_raw_t2s_model(dict_s1).to(device)
     print('#### get_raw_t2s_model ####')
     print(raw_t2s.config)
     t2s_m = T2SModel(raw_t2s)
     t2s_m.eval()
-    t2s = torch.jit.script(t2s_m)
+    t2s = torch.jit.script(t2s_m).to(device)
     print('#### script t2s_m ####')

     print("vits.hps.data.sampling_rate:",vits.hps.data.sampling_rate)
-    gpt_sovits = GPT_SoVITS(t2s,vits)
+    gpt_sovits = GPT_SoVITS(t2s,vits).to(device)
     gpt_sovits.eval()
-    ref_audio_sr = s.resample(ref_audio,16000,32000)
-    ref_audio_sr = s.resample(ref_audio,16000,32000)
-    print('ref_audio_sr:',ref_audio_sr.shape)
-
-    ref_audio_sr = s.resample(ref_audio,16000,32000)
-    print('ref_audio_sr:',ref_audio_sr.shape)
-
-    gpt_sovits_export = torch.jit.trace(
+    ref_audio_sr = s.resample(ref_audio,16000,32000).to(device)
+
+    torch._dynamo.mark_dynamic(ssl_content, 2)
+    torch._dynamo.mark_dynamic(ref_audio_sr, 1)
+    torch._dynamo.mark_dynamic(ref_seq, 1)
+    torch._dynamo.mark_dynamic(text_seq, 1)
+    torch._dynamo.mark_dynamic(ref_bert, 0)
+    torch._dynamo.mark_dynamic(text_bert, 0)
+
+    with torch.no_grad():
+        gpt_sovits_export = torch.jit.trace(
             gpt_sovits,
             example_inputs=(
                 ssl_content,
@@ -639,9 +664,9 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device='cpu'):
                 ref_bert,
                 text_bert))

-    gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
-    gpt_sovits_export.save(gpt_sovits_path)
-    print('#### exported gpt_sovits ####')
+        gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
+        gpt_sovits_export.save(gpt_sovits_path)
+        print('#### exported gpt_sovits ####')

 @torch.jit.script
 def parse_audio(ref_audio):
@@ -674,6 +699,8 @@ def test():
     parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
     parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
    parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
+    parser.add_argument('--output_path', required=True, help="Path to the output directory")
+

     args = parser.parse_args()
     gpt_path = args.gpt_model
@@ -682,42 +709,63 @@ def test():
     ref_text = args.ref_text

     tokenizer = AutoTokenizer.from_pretrained(bert_path)
-    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
-    bert = MyBertModel(bert_model)
-    # bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')
+    # bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
+    # bert = MyBertModel(bert_model)
+    my_bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')

-    # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
-    dict_s1 = torch.load(gpt_path, map_location="cpu")
-    raw_t2s = get_raw_t2s_model(dict_s1)
-    t2s = T2SModel(raw_t2s)
-    t2s.eval()
+    # dict_s1 = torch.load(gpt_path, map_location="cuda")
+    # raw_t2s = get_raw_t2s_model(dict_s1)
+    # t2s = T2SModel(raw_t2s)
+    # t2s.eval()
     # t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda')

     # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
-    vits = VitsModel(vits_path)
-    vits.eval()
+    # vits = VitsModel(vits_path)
+    # vits.eval()

-    ssl = ExportSSLModel(SSLModel())
-    ssl.eval()
+    # ssl = ExportSSLModel(SSLModel()).to('cuda')
+    # ssl.eval()
+    ssl = torch.jit.load("onnx/by/ssl_model.pt",map_location='cuda')

-    gpt_sovits = GPT_SoVITS(t2s,vits)
-
-    # vits = torch.jit.load("onnx/xw/vits_model.pt",map_location='cuda')
-    # ssl = torch.jit.load("onnx/xw/ssl_model.pt",map_location='cuda')
+    # gpt_sovits = GPT_SoVITS(t2s,vits)
+    gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt",map_location='cuda')

     ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
     ref_seq = torch.LongTensor([ref_seq_id])
     ref_bert = ref_bert_T.T.to(ref_seq.device)
-    text_seq_id,text_bert_T,norm_text = get_phones_and_bert("问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么。","all_zh",'v2')
+    # text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2')
"昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字." + + text_seq_id,text_bert_T,norm_text = get_phones_and_bert(text,"all_zh",'v2') + + test_bert = tokenizer(text, return_tensors="pt") + word2ph = [] + for c in text: + if c in [',','。',':','?',"?",",","."]: + word2ph.append(1) + else: + word2ph.append(2) + test_bert['word2ph'] = torch.Tensor(word2ph).int() + + test_bert = my_bert( + test_bert['input_ids'].to('cuda'), + test_bert['attention_mask'].to('cuda'), + test_bert['token_type_ids'].to('cuda'), + test_bert['word2ph'].to('cuda') + ) + text_seq = torch.LongTensor([text_seq_id]) - print('text_seq:',text_seq_id) text_bert = text_bert_T.T.to(text_seq.device) - # text_bert = torch.zeros(text_bert.shape, dtype=text_bert.dtype).to(text_bert.device) + + print('text_bert:',text_bert.shape,text_bert) + print('test_bert:',test_bert.shape,test_bert) + print(torch.allclose(text_bert.to('cuda'),test_bert)) + print('text_seq:',text_seq.shape) - print('text_bert:',text_bert.shape) + print('text_bert:',text_bert.shape,text_bert.type()) #[1,N] - ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float() + ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to('cuda') print('ref_audio:',ref_audio.shape) ref_audio_sr = ssl.resample(ref_audio,16000,32000) @@ -725,13 +773,22 @@ def test(): ssl_content = ssl(ref_audio) print('start gpt_sovits:') + print('ssl_content:',ssl_content.shape) + print('ref_audio_sr:',ref_audio_sr.shape) + print('ref_seq:',ref_seq.shape) + ref_seq=ref_seq.to('cuda') + print('text_seq:',text_seq.shape) + text_seq=text_seq.to('cuda') + print('ref_bert:',ref_bert.shape) + ref_bert=ref_bert.to('cuda') + print('text_bert:',text_bert.shape) + text_bert=text_bert.to('cuda') + with torch.no_grad(): - audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert) + audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert) print('start write wav') soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000) - # audio = vits(text_seq, pred_semantic1, ref_audio) - # soundfile.write("out.wav", audio, 32000) import text import json @@ -753,12 +810,23 @@ def main(): parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file") parser.add_argument('--ref_text', required=True, help="Path to the reference text file") parser.add_argument('--output_path', required=True, help="Path to the output directory") + parser.add_argument('--export_common_model', action='store_true', help="Export Bert and SSL model") + parser.add_argument('--device', help="Device to use") args = parser.parse_args() - export(gpt_path=args.gpt_model, vits_path=args.sovits_model, ref_audio_path=args.ref_audio, ref_text=args.ref_text, output_path=args.output_path) + export( + gpt_path=args.gpt_model, + vits_path=args.sovits_model, + ref_audio_path=args.ref_audio, + ref_text=args.ref_text, + output_path=args.output_path, + device=args.device, + export_bert_and_ssl=args.export_common_model, + ) import inference_webui if __name__ == "__main__": inference_webui.is_half=False inference_webui.dtype=torch.float32 - main() \ No newline at end of file + main() + # test()