mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-06-23 21:05:22 +08:00
export_torch_script.py support v2Pro & v2ProPlus
This commit is contained in:
parent
ed89a02337
commit
5c91e66d2e
@ -1,6 +1,7 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import argparse
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from my_utils import load_audio
|
||||
import torch
|
||||
@ -17,6 +18,9 @@ from module.models_onnx import SynthesizerTrn
|
||||
|
||||
from inference_webui import get_phones_and_bert
|
||||
|
||||
from sv import SV
|
||||
import kaldi as Kaldi
|
||||
|
||||
import os
|
||||
import soundfile
|
||||
|
||||
@ -32,6 +36,22 @@ default_config = {
|
||||
"EOS": 1024,
|
||||
}
|
||||
|
||||
sv_cn_model = None
|
||||
def init_sv_cn(device, is_half):
|
||||
global sv_cn_model
|
||||
sv_cn_model = SV(device, is_half)
|
||||
|
||||
def load_sovits_new(sovits_path):
|
||||
f = open(sovits_path, "rb")
|
||||
meta = f.read(2)
|
||||
if meta != b"PK":
|
||||
data = b"PK" + f.read()
|
||||
bio = BytesIO()
|
||||
bio.write(data)
|
||||
bio.seek(0)
|
||||
return torch.load(bio, map_location="cpu", weights_only=False)
|
||||
return torch.load(sovits_path, map_location="cpu", weights_only=False)
|
||||
|
||||
|
||||
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
|
||||
config = dict_s1["config"]
|
||||
@ -328,15 +348,22 @@ class T2STransformer:
|
||||
|
||||
|
||||
class VitsModel(nn.Module):
|
||||
def __init__(self, vits_path):
|
||||
def __init__(self, vits_path, version=None):
|
||||
super().__init__()
|
||||
# dict_s2 = torch.load(vits_path,map_location="cpu")
|
||||
dict_s2 = torch.load(vits_path, weights_only=False)
|
||||
dict_s2 = load_sovits_new(vits_path)
|
||||
self.hps = dict_s2["config"]
|
||||
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
|
||||
self.hps["model"]["version"] = "v1"
|
||||
|
||||
if version is None:
|
||||
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
|
||||
self.hps["model"]["version"] = "v1"
|
||||
else:
|
||||
self.hps["model"]["version"] = "v2"
|
||||
else:
|
||||
self.hps["model"]["version"] = "v2"
|
||||
if version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"]:
|
||||
self.hps["model"]["version"] = version
|
||||
else:
|
||||
raise ValueError(f"Unsupported version: {version}")
|
||||
|
||||
self.hps = DictToAttrRecursive(self.hps)
|
||||
self.hps.model.semantic_frame_rate = "25hz"
|
||||
@ -349,7 +376,7 @@ class VitsModel(nn.Module):
|
||||
self.vq_model.eval()
|
||||
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
|
||||
def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0):
|
||||
def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0, sv_emb=None):
|
||||
refer = spectrogram_torch(
|
||||
ref_audio,
|
||||
self.hps.data.filter_length,
|
||||
@ -358,7 +385,7 @@ class VitsModel(nn.Module):
|
||||
self.hps.data.win_length,
|
||||
center=False,
|
||||
)
|
||||
return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]
|
||||
return self.vq_model(pred_semantic, text_seq, refer, speed=speed, sv_emb=sv_emb)[0, 0]
|
||||
|
||||
|
||||
class T2SModel(nn.Module):
|
||||
@ -632,7 +659,7 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
||||
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
|
||||
ref_bert = ref_bert_T.T.to(ref_seq.device)
|
||||
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
|
||||
"这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2"
|
||||
"这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2"
|
||||
)
|
||||
text_seq = torch.LongTensor([text_seq_id]).to(device)
|
||||
text_bert = text_bert_T.T.to(text_seq.device)
|
||||
@ -679,6 +706,127 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
||||
print("#### exported gpt_sovits ####")
|
||||
|
||||
|
||||
def export_prov2(
|
||||
gpt_path,
|
||||
vits_path,
|
||||
version,
|
||||
ref_audio_path,
|
||||
ref_text,
|
||||
output_path,
|
||||
export_bert_and_ssl=False,
|
||||
device="cpu",
|
||||
is_half=True,
|
||||
):
|
||||
if sv_cn_model == None:
|
||||
init_sv_cn(device,is_half)
|
||||
|
||||
if not os.path.exists(output_path):
|
||||
os.makedirs(output_path)
|
||||
print(f"目录已创建: {output_path}")
|
||||
else:
|
||||
print(f"目录已存在: {output_path}")
|
||||
|
||||
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
|
||||
ssl = SSLModel()
|
||||
if export_bert_and_ssl:
|
||||
s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
|
||||
ssl_path = os.path.join(output_path, "ssl_model.pt")
|
||||
torch.jit.script(s).save(ssl_path)
|
||||
print("#### exported ssl ####")
|
||||
export_bert(output_path)
|
||||
else:
|
||||
s = ExportSSLModel(ssl)
|
||||
|
||||
print(f"device: {device}")
|
||||
|
||||
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
|
||||
ref_text, "all_zh", "v2"
|
||||
)
|
||||
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
|
||||
ref_bert = ref_bert_T.T
|
||||
if is_half:
|
||||
ref_bert = ref_bert.half()
|
||||
ref_bert = ref_bert.to(ref_seq.device)
|
||||
|
||||
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
|
||||
"这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2"
|
||||
)
|
||||
text_seq = torch.LongTensor([text_seq_id]).to(device)
|
||||
text_bert = text_bert_T.T
|
||||
if is_half:
|
||||
text_bert = text_bert.half()
|
||||
text_bert = text_bert.to(text_seq.device)
|
||||
|
||||
ssl_content = ssl(ref_audio)
|
||||
if is_half:
|
||||
ssl_content = ssl_content.half()
|
||||
ssl_content = ssl_content.to(device)
|
||||
|
||||
sv_model = ExportERes2NetV2(sv_cn_model)
|
||||
|
||||
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||
vits = VitsModel(vits_path, version)
|
||||
if is_half:
|
||||
vits.vq_model = vits.vq_model.half()
|
||||
vits.to(device)
|
||||
vits.eval()
|
||||
|
||||
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
||||
# dict_s1 = torch.load(gpt_path, map_location=device)
|
||||
dict_s1 = torch.load(gpt_path, weights_only=False)
|
||||
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
||||
print("#### get_raw_t2s_model ####")
|
||||
print(raw_t2s.config)
|
||||
if is_half:
|
||||
raw_t2s = raw_t2s.half()
|
||||
t2s_m = T2SModel(raw_t2s)
|
||||
t2s_m.eval()
|
||||
t2s = torch.jit.script(t2s_m).to(device)
|
||||
print("#### script t2s_m ####")
|
||||
|
||||
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
|
||||
gpt_sovits = GPT_SoVITS_V2Pro(t2s, vits, sv_model).to(device)
|
||||
gpt_sovits.eval()
|
||||
|
||||
ref_audio_sr = s.resample(ref_audio, 16000, 32000)
|
||||
if is_half:
|
||||
ref_audio_sr = ref_audio_sr.half()
|
||||
ref_audio_sr = ref_audio_sr.to(device)
|
||||
|
||||
torch._dynamo.mark_dynamic(ssl_content, 2)
|
||||
torch._dynamo.mark_dynamic(ref_audio_sr, 1)
|
||||
torch._dynamo.mark_dynamic(ref_seq, 1)
|
||||
torch._dynamo.mark_dynamic(text_seq, 1)
|
||||
torch._dynamo.mark_dynamic(ref_bert, 0)
|
||||
torch._dynamo.mark_dynamic(text_bert, 0)
|
||||
# torch._dynamo.mark_dynamic(sv_emb, 0)
|
||||
|
||||
top_k = torch.LongTensor([5]).to(device)
|
||||
# 先跑一遍 sv_model 让它加载 cache,详情见 L880
|
||||
gpt_sovits.sv_model(ref_audio_sr)
|
||||
|
||||
with torch.no_grad():
|
||||
gpt_sovits_export = torch.jit.trace(
|
||||
gpt_sovits,
|
||||
example_inputs=(
|
||||
ssl_content,
|
||||
ref_audio_sr,
|
||||
ref_seq,
|
||||
text_seq,
|
||||
ref_bert,
|
||||
text_bert,
|
||||
top_k,
|
||||
),
|
||||
)
|
||||
|
||||
gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
|
||||
gpt_sovits_export.save(gpt_sovits_path)
|
||||
print("#### exported gpt_sovits ####")
|
||||
audio = gpt_sovits_export(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||
print("start write wav")
|
||||
soundfile.write("out.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def parse_audio(ref_audio):
|
||||
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device)
|
||||
@ -717,6 +865,66 @@ class GPT_SoVITS(nn.Module):
|
||||
return audio
|
||||
|
||||
|
||||
class ExportERes2NetV2(nn.Module):
|
||||
def __init__(self, sv_cn_model:SV):
|
||||
super(ExportERes2NetV2, self).__init__()
|
||||
self.bn1 = sv_cn_model.embedding_model.bn1
|
||||
self.conv1 = sv_cn_model.embedding_model.conv1
|
||||
self.layer1 = sv_cn_model.embedding_model.layer1
|
||||
self.layer2 = sv_cn_model.embedding_model.layer2
|
||||
self.layer3 = sv_cn_model.embedding_model.layer3
|
||||
self.layer4 = sv_cn_model.embedding_model.layer4
|
||||
self.layer3_ds = sv_cn_model.embedding_model.layer3_ds
|
||||
self.fuse34 = sv_cn_model.embedding_model.fuse34
|
||||
|
||||
# audio_16k.shape: [1,N]
|
||||
def forward(self, audio_16k):
|
||||
# 这个 fbank 函数有一个 cache, 不过不要紧,它跟 audio_16k 的长度无关
|
||||
# 只跟 device 和 dtype 有关
|
||||
x = Kaldi.fbank(audio_16k, num_mel_bins=80, sample_frequency=16000, dither=0)
|
||||
x = torch.stack([x])
|
||||
|
||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||
x = x.unsqueeze_(1)
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out1 = self.layer1(out)
|
||||
out2 = self.layer2(out1)
|
||||
out3 = self.layer3(out2)
|
||||
out4 = self.layer4(out3)
|
||||
out3_ds = self.layer3_ds(out3)
|
||||
fuse_out34 = self.fuse34(out4, out3_ds)
|
||||
return fuse_out34.flatten(start_dim=1,end_dim=2).mean(-1)
|
||||
|
||||
|
||||
class GPT_SoVITS_V2Pro(nn.Module):
|
||||
def __init__(self, t2s: T2SModel, vits: VitsModel,sv_model:ExportERes2NetV2):
|
||||
super().__init__()
|
||||
self.t2s = t2s
|
||||
self.vits = vits
|
||||
self.sv_model = sv_model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ssl_content: torch.Tensor,
|
||||
ref_audio_sr: torch.Tensor,
|
||||
ref_seq: Tensor,
|
||||
text_seq: Tensor,
|
||||
ref_bert: Tensor,
|
||||
text_bert: Tensor,
|
||||
top_k: LongTensor,
|
||||
speed=1.0,
|
||||
):
|
||||
codes = self.vits.vq_model.extract_latent(ssl_content)
|
||||
prompt_semantic = codes[0, 0]
|
||||
prompts = prompt_semantic.unsqueeze(0)
|
||||
|
||||
audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
|
||||
sv_emb = self.sv_model(audio_16k)
|
||||
|
||||
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||
audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed, sv_emb)
|
||||
return audio
|
||||
|
||||
def test():
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
||||
@ -833,29 +1041,53 @@ def export_symbel(version="v2"):
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
||||
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
|
||||
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
|
||||
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
|
||||
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
||||
parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
|
||||
parser.add_argument(
|
||||
"--sovits_model", required=True, help="Path to the SoVITS model file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ref_audio", required=True, help="Path to the reference audio file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ref_text", required=True, help="Path to the reference text file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path", required=True, help="Path to the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export_common_model", action="store_true", help="Export Bert and SSL model"
|
||||
)
|
||||
parser.add_argument("--device", help="Device to use")
|
||||
parser.add_argument("--version", help="version of the model", default="v2")
|
||||
parser.add_argument("--no-half", action="store_true", help = "Do not use half precision for model weights")
|
||||
|
||||
args = parser.parse_args()
|
||||
export(
|
||||
gpt_path=args.gpt_model,
|
||||
vits_path=args.sovits_model,
|
||||
ref_audio_path=args.ref_audio,
|
||||
ref_text=args.ref_text,
|
||||
output_path=args.output_path,
|
||||
device=args.device,
|
||||
export_bert_and_ssl=args.export_common_model,
|
||||
)
|
||||
if args.version in ["v2Pro", "v2ProPlus"]:
|
||||
is_half = not args.no_half
|
||||
print(f"Using half precision: {is_half}")
|
||||
export_prov2(
|
||||
gpt_path=args.gpt_model,
|
||||
vits_path=args.sovits_model,
|
||||
version=args.version,
|
||||
ref_audio_path=args.ref_audio,
|
||||
ref_text=args.ref_text,
|
||||
output_path=args.output_path,
|
||||
export_bert_and_ssl=args.export_common_model,
|
||||
device=args.device,
|
||||
is_half=is_half,
|
||||
)
|
||||
else:
|
||||
export(
|
||||
gpt_path=args.gpt_model,
|
||||
vits_path=args.sovits_model,
|
||||
ref_audio_path=args.ref_audio,
|
||||
ref_text=args.ref_text,
|
||||
output_path=args.output_path,
|
||||
device=args.device,
|
||||
export_bert_and_ssl=args.export_common_model,
|
||||
)
|
||||
|
||||
|
||||
import inference_webui
|
||||
|
||||
if __name__ == "__main__":
|
||||
inference_webui.is_half = False
|
||||
inference_webui.dtype = torch.float32
|
||||
main()
|
||||
with torch.no_grad():
|
||||
main()
|
||||
# test()
|
||||
|
@ -762,6 +762,7 @@ class CodePredictor(nn.Module):
|
||||
|
||||
return pred_codes.transpose(0, 1)
|
||||
|
||||
v2pro_set={"v2Pro","v2ProPlus"}
|
||||
|
||||
class SynthesizerTrn(nn.Module):
|
||||
"""
|
||||
@ -867,20 +868,33 @@ class SynthesizerTrn(nn.Module):
|
||||
# self.enc_p.text_embedding.requires_grad_(False)
|
||||
# self.enc_p.encoder_text.requires_grad_(False)
|
||||
# self.enc_p.mrte.requires_grad_(False)
|
||||
self.is_v2pro=self.version in v2pro_set
|
||||
if self.is_v2pro:
|
||||
self.sv_emb = nn.Linear(20480, gin_channels)
|
||||
self.ge_to512 = nn.Linear(gin_channels, 512)
|
||||
self.prelu = nn.PReLU(num_parameters=gin_channels)
|
||||
|
||||
def forward(self, codes, text, refer, noise_scale=0.5, speed=1):
|
||||
def forward(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
|
||||
refer_mask = torch.ones_like(refer[:1, :1, :])
|
||||
if self.version == "v1":
|
||||
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||
else:
|
||||
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||
if self.is_v2pro:
|
||||
sv_emb = self.sv_emb(sv_emb)
|
||||
ge += sv_emb.unsqueeze(-1)
|
||||
ge = self.prelu(ge)
|
||||
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
|
||||
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
|
||||
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
|
||||
if self.is_v2pro:
|
||||
ge_ = self.ge_to512(ge.transpose(2,1)).transpose(2,1)
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge_, speed)
|
||||
else:
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
|
||||
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user