export_torch_script.py: support v2Pro & v2ProPlus

This commit is contained in:
csh 2025-06-12 21:53:14 +08:00
parent ed89a02337
commit 5c91e66d2e
2 changed files with 275 additions and 29 deletions

export_torch_script.py

@@ -1,6 +1,7 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import argparse
from io import BytesIO
from typing import Optional
from my_utils import load_audio
import torch
@@ -17,6 +18,9 @@ from module.models_onnx import SynthesizerTrn
from inference_webui import get_phones_and_bert
from sv import SV
import kaldi as Kaldi
import os
import soundfile
@@ -32,6 +36,22 @@ default_config = {
"EOS": 1024,
}
sv_cn_model = None
def init_sv_cn(device, is_half):
global sv_cn_model
sv_cn_model = SV(device, is_half)
def load_sovits_new(sovits_path):
    # Stock torch checkpoints are zip archives and begin with the ZIP magic
    # b"PK"; repacked SoVITS weights carry a different 2-byte marker, so we
    # restore b"PK" and load the rebuilt archive from an in-memory buffer.
    with open(sovits_path, "rb") as f:
        meta = f.read(2)
        if meta != b"PK":
            data = b"PK" + f.read()
            bio = BytesIO()
            bio.write(data)
            bio.seek(0)
            return torch.load(bio, map_location="cpu", weights_only=False)
    return torch.load(sovits_path, map_location="cpu", weights_only=False)
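For reference, the same magic-byte check can be used as a standalone probe; the helper below is hypothetical and only illustrates the assumption documented above.
def is_stock_checkpoint(path):  # hypothetical helper, not part of this commit
    # b"PK" is the ZIP local-file-header magic written by torch's zip serializer.
    with open(path, "rb") as f:
        return f.read(2) == b"PK"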
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
config = dict_s1["config"]
@@ -328,15 +348,22 @@ class T2STransformer:
class VitsModel(nn.Module):
def __init__(self, vits_path):
def __init__(self, vits_path, version=None):
super().__init__()
# dict_s2 = torch.load(vits_path,map_location="cpu")
dict_s2 = torch.load(vits_path, weights_only=False)
dict_s2 = load_sovits_new(vits_path)
self.hps = dict_s2["config"]
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
self.hps["model"]["version"] = "v1"
if version is None:
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = "v2"
else:
self.hps["model"]["version"] = "v2"
if version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"]:
self.hps["model"]["version"] = version
else:
raise ValueError(f"Unsupported version: {version}")
self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz"
@@ -349,7 +376,7 @@ class VitsModel(nn.Module):
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0):
def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0, sv_emb=None):
refer = spectrogram_torch(
ref_audio,
self.hps.data.filter_length,
@@ -358,7 +385,7 @@ class VitsModel(nn.Module):
self.hps.data.win_length,
center=False,
)
return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]
return self.vq_model(pred_semantic, text_seq, refer, speed=speed, sv_emb=sv_emb)[0, 0]
class T2SModel(nn.Module):
@@ -632,7 +659,7 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
ref_bert = ref_bert_T.T.to(ref_seq.device)
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
"这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2"
"这是一个简单的示例真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2"
)
text_seq = torch.LongTensor([text_seq_id]).to(device)
text_bert = text_bert_T.T.to(text_seq.device)
@@ -679,6 +706,127 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
print("#### exported gpt_sovits ####")
def export_prov2(
gpt_path,
vits_path,
version,
ref_audio_path,
ref_text,
output_path,
export_bert_and_ssl=False,
device="cpu",
is_half=True,
):
if sv_cn_model is None:
    init_sv_cn(device, is_half)
if not os.path.exists(output_path):
    os.makedirs(output_path)
    print(f"Directory created: {output_path}")
else:
    print(f"Directory already exists: {output_path}")
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
ssl = SSLModel()
if export_bert_and_ssl:
s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio,)))
ssl_path = os.path.join(output_path, "ssl_model.pt")
torch.jit.script(s).save(ssl_path)
print("#### exported ssl ####")
export_bert(output_path)
else:
s = ExportSSLModel(ssl)
print(f"device: {device}")
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
ref_text, "all_zh", "v2"
)
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
ref_bert = ref_bert_T.T
if is_half:
ref_bert = ref_bert.half()
ref_bert = ref_bert.to(ref_seq.device)
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
"这是一个简单的示例真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2"
)
text_seq = torch.LongTensor([text_seq_id]).to(device)
text_bert = text_bert_T.T
if is_half:
text_bert = text_bert.half()
text_bert = text_bert.to(text_seq.device)
ssl_content = ssl(ref_audio)
if is_half:
ssl_content = ssl_content.half()
ssl_content = ssl_content.to(device)
sv_model = ExportERes2NetV2(sv_cn_model)
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path, version)
if is_half:
vits.vq_model = vits.vq_model.half()
vits.to(device)
vits.eval()
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
# dict_s1 = torch.load(gpt_path, map_location=device)
dict_s1 = torch.load(gpt_path, weights_only=False)
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
print("#### get_raw_t2s_model ####")
print(raw_t2s.config)
if is_half:
raw_t2s = raw_t2s.half()
t2s_m = T2SModel(raw_t2s)
t2s_m.eval()
t2s = torch.jit.script(t2s_m).to(device)
print("#### script t2s_m ####")
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
gpt_sovits = GPT_SoVITS_V2Pro(t2s, vits, sv_model).to(device)
gpt_sovits.eval()
ref_audio_sr = s.resample(ref_audio, 16000, 32000)
if is_half:
ref_audio_sr = ref_audio_sr.half()
ref_audio_sr = ref_audio_sr.to(device)
torch._dynamo.mark_dynamic(ssl_content, 2)
torch._dynamo.mark_dynamic(ref_audio_sr, 1)
torch._dynamo.mark_dynamic(ref_seq, 1)
torch._dynamo.mark_dynamic(text_seq, 1)
torch._dynamo.mark_dynamic(ref_bert, 0)
torch._dynamo.mark_dynamic(text_bert, 0)
# torch._dynamo.mark_dynamic(sv_emb, 0)
top_k = torch.LongTensor([5]).to(device)
# Run sv_model once beforehand so its cache gets loaded; see L880 for details
gpt_sovits.sv_model(ref_audio_sr)
with torch.no_grad():
gpt_sovits_export = torch.jit.trace(
gpt_sovits,
example_inputs=(
ssl_content,
ref_audio_sr,
ref_seq,
text_seq,
ref_bert,
text_bert,
top_k,
),
)
gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
gpt_sovits_export.save(gpt_sovits_path)
print("#### exported gpt_sovits ####")
audio = gpt_sovits_export(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
print("start write wav")
soundfile.write("out.wav", audio.float().detach().cpu().numpy(), 32000)
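For illustration, a direct call to export_prov2 might look like the sketch below; all paths are placeholders, mirroring what main() does when --version selects a Pro model.
# Hypothetical paths, shown only to illustrate the expected arguments.
export_prov2(
    gpt_path="GPT_weights_v2Pro/demo-e15.ckpt",
    vits_path="SoVITS_weights_v2Pro/demo_e8_s216.pth",
    version="v2Pro",
    ref_audio_path="ref.wav",
    ref_text="这是一条参考语音。",
    output_path="exported",
    export_bert_and_ssl=True,
    device="cuda",
    is_half=True,
)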
@torch.jit.script
def parse_audio(ref_audio):
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device)
@@ -717,6 +865,66 @@ class GPT_SoVITS(nn.Module):
return audio
class ExportERes2NetV2(nn.Module):
def __init__(self, sv_cn_model: SV):
super(ExportERes2NetV2, self).__init__()
self.bn1 = sv_cn_model.embedding_model.bn1
self.conv1 = sv_cn_model.embedding_model.conv1
self.layer1 = sv_cn_model.embedding_model.layer1
self.layer2 = sv_cn_model.embedding_model.layer2
self.layer3 = sv_cn_model.embedding_model.layer3
self.layer4 = sv_cn_model.embedding_model.layer4
self.layer3_ds = sv_cn_model.embedding_model.layer3_ds
self.fuse34 = sv_cn_model.embedding_model.fuse34
# audio_16k.shape: [1,N]
def forward(self, audio_16k):
# The fbank function keeps an internal cache, but that's fine: the cache
# depends only on device and dtype, not on the length of audio_16k.
x = Kaldi.fbank(audio_16k, num_mel_bins=80, sample_frequency=16000, dither=0)
x = torch.stack([x])
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
out3_ds = self.layer3_ds(out3)
fuse_out34 = self.fuse34(out4, out3_ds)
return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)
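For intuition about the embedding size: flatten(start_dim=1, end_dim=2) merges the channel and frequency axes and mean(-1) pools over time, and the nn.Linear(20480, gin_channels) added to SynthesizerTrn below implies the merged axis is 20480-dimensional here. A shape-only sketch, with channel/frequency/frame counts assumed for illustration:
import torch
fuse_out34 = torch.randn(1, 2048, 10, 98)  # assumed (B, C, F, T) with C * F = 20480
emb = fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)
print(emb.shape)  # torch.Size([1, 20480])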
class GPT_SoVITS_V2Pro(nn.Module):
def __init__(self, t2s: T2SModel, vits: VitsModel, sv_model: ExportERes2NetV2):
super().__init__()
self.t2s = t2s
self.vits = vits
self.sv_model = sv_model
def forward(
self,
ssl_content: torch.Tensor,
ref_audio_sr: torch.Tensor,
ref_seq: Tensor,
text_seq: Tensor,
ref_bert: Tensor,
text_bert: Tensor,
top_k: LongTensor,
speed=1.0,
):
codes = self.vits.vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
prompts = prompt_semantic.unsqueeze(0)
audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
sv_emb = self.sv_model(audio_16k)
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed, sv_emb)
return audio
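Once traced and saved, the module can be reloaded without any of the classes in this file; a minimal consumer sketch, where the path and the shapes in the comments are assumptions:
import torch
m = torch.jit.load("exported/gpt_sovits_model.pt", map_location="cpu")  # placeholder path
# Inputs mirror the trace above: ssl_content (1, 768, T_ssl), ref_audio_sr (1, N)
# at 32 kHz, ref_seq/text_seq (1, L) phoneme ids, ref_bert/text_bert (L, 1024),
# and top_k a one-element LongTensor.
# audio = m(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)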
def test():
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
@@ -833,29 +1041,53 @@ def export_symbel(version="v2"):
def main():
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
parser.add_argument("--output_path", required=True, help="Path to the output directory")
parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
parser.add_argument(
"--sovits_model", required=True, help="Path to the SoVITS model file"
)
parser.add_argument(
"--ref_audio", required=True, help="Path to the reference audio file"
)
parser.add_argument(
"--ref_text", required=True, help="Path to the reference text file"
)
parser.add_argument(
"--output_path", required=True, help="Path to the output directory"
)
parser.add_argument(
"--export_common_model", action="store_true", help="Export Bert and SSL model"
)
parser.add_argument("--device", help="Device to use")
parser.add_argument("--version", help="version of the model", default="v2")
parser.add_argument("--no-half", action="store_true", help = "Do not use half precision for model weights")
args = parser.parse_args()
export(
gpt_path=args.gpt_model,
vits_path=args.sovits_model,
ref_audio_path=args.ref_audio,
ref_text=args.ref_text,
output_path=args.output_path,
device=args.device,
export_bert_and_ssl=args.export_common_model,
)
if args.version in ["v2Pro", "v2ProPlus"]:
is_half = not args.no_half
print(f"Using half precision: {is_half}")
export_prov2(
gpt_path=args.gpt_model,
vits_path=args.sovits_model,
version=args.version,
ref_audio_path=args.ref_audio,
ref_text=args.ref_text,
output_path=args.output_path,
export_bert_and_ssl=args.export_common_model,
device=args.device,
is_half=is_half,
)
else:
export(
gpt_path=args.gpt_model,
vits_path=args.sovits_model,
ref_audio_path=args.ref_audio,
ref_text=args.ref_text,
output_path=args.output_path,
device=args.device,
export_bert_and_ssl=args.export_common_model,
)
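An example invocation (placeholder paths); --version routes v2Pro/v2ProPlus exports to export_prov2 and everything else to the original export:
# python export_torch_script.py --gpt_model GPT_weights_v2Pro/demo-e15.ckpt \
#     --sovits_model SoVITS_weights_v2Pro/demo_e8_s216.pth \
#     --ref_audio ref.wav --ref_text ref.txt \
#     --version v2ProPlus --output_path exported --export_common_model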
import inference_webui
if __name__ == "__main__":
inference_webui.is_half = False
inference_webui.dtype = torch.float32
main()
with torch.no_grad():
main()
# test()

module/models_onnx.py

@@ -762,6 +762,7 @@ class CodePredictor(nn.Module):
return pred_codes.transpose(0, 1)
v2pro_set = {"v2Pro", "v2ProPlus"}
class SynthesizerTrn(nn.Module):
"""
@@ -867,20 +868,33 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.text_embedding.requires_grad_(False)
# self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False)
self.is_v2pro = self.version in v2pro_set
if self.is_v2pro:
self.sv_emb = nn.Linear(20480, gin_channels)
self.ge_to512 = nn.Linear(gin_channels, 512)
self.prelu = nn.PReLU(num_parameters=gin_channels)
def forward(self, codes, text, refer, noise_scale=0.5, speed=1):
def forward(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
refer_mask = torch.ones_like(refer[:1, :1, :])
if self.version == "v1":
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
if self.is_v2pro:
sv_emb = self.sv_emb(sv_emb)
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
if self.is_v2pro:
ge_ = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1)
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge_, speed)
else:
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
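To make the v2Pro conditioning path concrete, here is a self-contained sketch of just the fusion step above, with gin_channels assumed to be 512 for illustration:
import torch
import torch.nn as nn

gin_channels = 512  # assumption for illustration
sv_emb_proj = nn.Linear(20480, gin_channels)  # plays the role of self.sv_emb
ge_to512 = nn.Linear(gin_channels, 512)       # plays the role of self.ge_to512
prelu = nn.PReLU(num_parameters=gin_channels)

ge = torch.randn(1, gin_channels, 1)  # reference-encoder output, (B, gin, 1)
sv_emb = torch.randn(1, 20480)        # ERes2NetV2 utterance embedding, (B, 20480)

ge = prelu(ge + sv_emb_proj(sv_emb).unsqueeze(-1))   # inject the speaker embedding
ge_ = ge_to512(ge.transpose(2, 1)).transpose(2, 1)   # (B, 512, 1), fed to enc_p
print(ge_.shape)  # torch.Size([1, 512, 1])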