From 5c91e66d2e99146d6650f9dcdb542ef6d70ac69c Mon Sep 17 00:00:00 2001 From: csh <458761603@qq.com> Date: Thu, 12 Jun 2025 21:53:14 +0800 Subject: [PATCH] export_torch_script.py support v2Pro & v2ProPlus --- GPT_SoVITS/export_torch_script.py | 286 +++++++++++++++++++++++++++--- GPT_SoVITS/module/models_onnx.py | 18 +- 2 files changed, 275 insertions(+), 29 deletions(-) diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py index b68dd38..6a13c2d 100644 --- a/GPT_SoVITS/export_torch_script.py +++ b/GPT_SoVITS/export_torch_script.py @@ -1,6 +1,7 @@ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py # reference: https://github.com/lifeiteng/vall-e import argparse +from io import BytesIO from typing import Optional from my_utils import load_audio import torch @@ -17,6 +18,9 @@ from module.models_onnx import SynthesizerTrn from inference_webui import get_phones_and_bert +from sv import SV +import kaldi as Kaldi + import os import soundfile @@ -32,6 +36,22 @@ default_config = { "EOS": 1024, } +sv_cn_model = None +def init_sv_cn(device, is_half): + global sv_cn_model + sv_cn_model = SV(device, is_half) + +def load_sovits_new(sovits_path): + f = open(sovits_path, "rb") + meta = f.read(2) + if meta != b"PK": + data = b"PK" + f.read() + bio = BytesIO() + bio.write(data) + bio.seek(0) + return torch.load(bio, map_location="cpu", weights_only=False) + return torch.load(sovits_path, map_location="cpu", weights_only=False) + def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule: config = dict_s1["config"] @@ -328,15 +348,22 @@ class T2STransformer: class VitsModel(nn.Module): - def __init__(self, vits_path): + def __init__(self, vits_path, version=None): super().__init__() # dict_s2 = torch.load(vits_path,map_location="cpu") - dict_s2 = torch.load(vits_path, weights_only=False) + dict_s2 = load_sovits_new(vits_path) self.hps = dict_s2["config"] - if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: - self.hps["model"]["version"] = "v1" + + if version is None: + if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: + self.hps["model"]["version"] = "v1" + else: + self.hps["model"]["version"] = "v2" else: - self.hps["model"]["version"] = "v2" + if version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"]: + self.hps["model"]["version"] = version + else: + raise ValueError(f"Unsupported version: {version}") self.hps = DictToAttrRecursive(self.hps) self.hps.model.semantic_frame_rate = "25hz" @@ -349,7 +376,7 @@ class VitsModel(nn.Module): self.vq_model.eval() self.vq_model.load_state_dict(dict_s2["weight"], strict=False) - def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0): + def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0, sv_emb=None): refer = spectrogram_torch( ref_audio, self.hps.data.filter_length, @@ -358,7 +385,7 @@ class VitsModel(nn.Module): self.hps.data.win_length, center=False, ) - return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0] + return self.vq_model(pred_semantic, text_seq, refer, speed=speed, sv_emb=sv_emb)[0, 0] class T2SModel(nn.Module): @@ -632,7 +659,7 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be ref_seq = torch.LongTensor([ref_seq_id]).to(device) ref_bert = ref_bert_T.T.to(ref_seq.device) text_seq_id, text_bert_T, norm_text = get_phones_and_bert( - "这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2" + "这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2" ) text_seq = torch.LongTensor([text_seq_id]).to(device) text_bert = text_bert_T.T.to(text_seq.device) @@ -679,6 +706,127 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be print("#### exported gpt_sovits ####") +def export_prov2( + gpt_path, + vits_path, + version, + ref_audio_path, + ref_text, + output_path, + export_bert_and_ssl=False, + device="cpu", + is_half=True, +): + if sv_cn_model == None: + init_sv_cn(device,is_half) + + if not os.path.exists(output_path): + os.makedirs(output_path) + print(f"目录已创建: {output_path}") + else: + print(f"目录已存在: {output_path}") + + ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float() + ssl = SSLModel() + if export_bert_and_ssl: + s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio))) + ssl_path = os.path.join(output_path, "ssl_model.pt") + torch.jit.script(s).save(ssl_path) + print("#### exported ssl ####") + export_bert(output_path) + else: + s = ExportSSLModel(ssl) + + print(f"device: {device}") + + ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert( + ref_text, "all_zh", "v2" + ) + ref_seq = torch.LongTensor([ref_seq_id]).to(device) + ref_bert = ref_bert_T.T + if is_half: + ref_bert = ref_bert.half() + ref_bert = ref_bert.to(ref_seq.device) + + text_seq_id, text_bert_T, norm_text = get_phones_and_bert( + "这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2" + ) + text_seq = torch.LongTensor([text_seq_id]).to(device) + text_bert = text_bert_T.T + if is_half: + text_bert = text_bert.half() + text_bert = text_bert.to(text_seq.device) + + ssl_content = ssl(ref_audio) + if is_half: + ssl_content = ssl_content.half() + ssl_content = ssl_content.to(device) + + sv_model = ExportERes2NetV2(sv_cn_model) + + # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" + vits = VitsModel(vits_path, version) + if is_half: + vits.vq_model = vits.vq_model.half() + vits.to(device) + vits.eval() + + # gpt_path = "GPT_weights_v2/xw-e15.ckpt" + # dict_s1 = torch.load(gpt_path, map_location=device) + dict_s1 = torch.load(gpt_path, weights_only=False) + raw_t2s = get_raw_t2s_model(dict_s1).to(device) + print("#### get_raw_t2s_model ####") + print(raw_t2s.config) + if is_half: + raw_t2s = raw_t2s.half() + t2s_m = T2SModel(raw_t2s) + t2s_m.eval() + t2s = torch.jit.script(t2s_m).to(device) + print("#### script t2s_m ####") + + print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate) + gpt_sovits = GPT_SoVITS_V2Pro(t2s, vits, sv_model).to(device) + gpt_sovits.eval() + + ref_audio_sr = s.resample(ref_audio, 16000, 32000) + if is_half: + ref_audio_sr = ref_audio_sr.half() + ref_audio_sr = ref_audio_sr.to(device) + + torch._dynamo.mark_dynamic(ssl_content, 2) + torch._dynamo.mark_dynamic(ref_audio_sr, 1) + torch._dynamo.mark_dynamic(ref_seq, 1) + torch._dynamo.mark_dynamic(text_seq, 1) + torch._dynamo.mark_dynamic(ref_bert, 0) + torch._dynamo.mark_dynamic(text_bert, 0) + # torch._dynamo.mark_dynamic(sv_emb, 0) + + top_k = torch.LongTensor([5]).to(device) + # 先跑一遍 sv_model 让它加载 cache,详情见 L880 + gpt_sovits.sv_model(ref_audio_sr) + + with torch.no_grad(): + gpt_sovits_export = torch.jit.trace( + gpt_sovits, + example_inputs=( + ssl_content, + ref_audio_sr, + ref_seq, + text_seq, + ref_bert, + text_bert, + top_k, + ), + ) + + gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt") + gpt_sovits_export.save(gpt_sovits_path) + print("#### exported gpt_sovits ####") + audio = gpt_sovits_export(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k) + print("start write wav") + soundfile.write("out.wav", audio.float().detach().cpu().numpy(), 32000) + + @torch.jit.script def parse_audio(ref_audio): ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device) @@ -717,6 +865,66 @@ class GPT_SoVITS(nn.Module): return audio +class ExportERes2NetV2(nn.Module): + def __init__(self, sv_cn_model:SV): + super(ExportERes2NetV2, self).__init__() + self.bn1 = sv_cn_model.embedding_model.bn1 + self.conv1 = sv_cn_model.embedding_model.conv1 + self.layer1 = sv_cn_model.embedding_model.layer1 + self.layer2 = sv_cn_model.embedding_model.layer2 + self.layer3 = sv_cn_model.embedding_model.layer3 + self.layer4 = sv_cn_model.embedding_model.layer4 + self.layer3_ds = sv_cn_model.embedding_model.layer3_ds + self.fuse34 = sv_cn_model.embedding_model.fuse34 + + # audio_16k.shape: [1,N] + def forward(self, audio_16k): + # 这个 fbank 函数有一个 cache, 不过不要紧,它跟 audio_16k 的长度无关 + # 只跟 device 和 dtype 有关 + x = Kaldi.fbank(audio_16k, num_mel_bins=80, sample_frequency=16000, dither=0) + x = torch.stack([x]) + + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out1 = self.layer1(out) + out2 = self.layer2(out1) + out3 = self.layer3(out2) + out4 = self.layer4(out3) + out3_ds = self.layer3_ds(out3) + fuse_out34 = self.fuse34(out4, out3_ds) + return fuse_out34.flatten(start_dim=1,end_dim=2).mean(-1) + + +class GPT_SoVITS_V2Pro(nn.Module): + def __init__(self, t2s: T2SModel, vits: VitsModel,sv_model:ExportERes2NetV2): + super().__init__() + self.t2s = t2s + self.vits = vits + self.sv_model = sv_model + + def forward( + self, + ssl_content: torch.Tensor, + ref_audio_sr: torch.Tensor, + ref_seq: Tensor, + text_seq: Tensor, + ref_bert: Tensor, + text_bert: Tensor, + top_k: LongTensor, + speed=1.0, + ): + codes = self.vits.vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompts = prompt_semantic.unsqueeze(0) + + audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype) + sv_emb = self.sv_model(audio_16k) + + pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k) + audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed, sv_emb) + return audio + def test(): parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool") parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file") @@ -833,29 +1041,53 @@ def export_symbel(version="v2"): def main(): parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool") parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file") - parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file") - parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file") - parser.add_argument("--ref_text", required=True, help="Path to the reference text file") - parser.add_argument("--output_path", required=True, help="Path to the output directory") - parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model") + parser.add_argument( + "--sovits_model", required=True, help="Path to the SoVITS model file" + ) + parser.add_argument( + "--ref_audio", required=True, help="Path to the reference audio file" + ) + parser.add_argument( + "--ref_text", required=True, help="Path to the reference text file" + ) + parser.add_argument( + "--output_path", required=True, help="Path to the output directory" + ) + parser.add_argument( + "--export_common_model", action="store_true", help="Export Bert and SSL model" + ) parser.add_argument("--device", help="Device to use") + parser.add_argument("--version", help="version of the model", default="v2") + parser.add_argument("--no-half", action="store_true", help = "Do not use half precision for model weights") args = parser.parse_args() - export( - gpt_path=args.gpt_model, - vits_path=args.sovits_model, - ref_audio_path=args.ref_audio, - ref_text=args.ref_text, - output_path=args.output_path, - device=args.device, - export_bert_and_ssl=args.export_common_model, - ) + if args.version in ["v2Pro", "v2ProPlus"]: + is_half = not args.no_half + print(f"Using half precision: {is_half}") + export_prov2( + gpt_path=args.gpt_model, + vits_path=args.sovits_model, + version=args.version, + ref_audio_path=args.ref_audio, + ref_text=args.ref_text, + output_path=args.output_path, + export_bert_and_ssl=args.export_common_model, + device=args.device, + is_half=is_half, + ) + else: + export( + gpt_path=args.gpt_model, + vits_path=args.sovits_model, + ref_audio_path=args.ref_audio, + ref_text=args.ref_text, + output_path=args.output_path, + device=args.device, + export_bert_and_ssl=args.export_common_model, + ) -import inference_webui - if __name__ == "__main__": - inference_webui.is_half = False - inference_webui.dtype = torch.float32 - main() + with torch.no_grad(): + main() # test() diff --git a/GPT_SoVITS/module/models_onnx.py b/GPT_SoVITS/module/models_onnx.py index 028db5f..525273f 100644 --- a/GPT_SoVITS/module/models_onnx.py +++ b/GPT_SoVITS/module/models_onnx.py @@ -762,6 +762,7 @@ class CodePredictor(nn.Module): return pred_codes.transpose(0, 1) +v2pro_set={"v2Pro","v2ProPlus"} class SynthesizerTrn(nn.Module): """ @@ -867,20 +868,33 @@ class SynthesizerTrn(nn.Module): # self.enc_p.text_embedding.requires_grad_(False) # self.enc_p.encoder_text.requires_grad_(False) # self.enc_p.mrte.requires_grad_(False) + self.is_v2pro=self.version in v2pro_set + if self.is_v2pro: + self.sv_emb = nn.Linear(20480, gin_channels) + self.ge_to512 = nn.Linear(gin_channels, 512) + self.prelu = nn.PReLU(num_parameters=gin_channels) - def forward(self, codes, text, refer, noise_scale=0.5, speed=1): + def forward(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None): refer_mask = torch.ones_like(refer[:1, :1, :]) if self.version == "v1": ge = self.ref_enc(refer * refer_mask, refer_mask) else: ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) + if self.is_v2pro: + sv_emb = self.sv_emb(sv_emb) + ge += sv_emb.unsqueeze(-1) + ge = self.prelu(ge) quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == "25hz": dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0) quantized = dquantized.contiguous().view(1, self.ssl_dim, -1) - x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed) + if self.is_v2pro: + ge_ = self.ge_to512(ge.transpose(2,1)).transpose(2,1) + x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge_, speed) + else: + x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed) z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale