diff --git a/GPT_SoVITS/module/models_onnx.py b/GPT_SoVITS/module/models_onnx.py index 232fd74..77ae307 100644 --- a/GPT_SoVITS/module/models_onnx.py +++ b/GPT_SoVITS/module/models_onnx.py @@ -13,7 +13,9 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm from module.commons import init_weights, get_padding from module.mrte_model import MRTE from module.quantize import ResidualVectorQuantizer -from text import symbols +# from text import symbols +from text import symbols as symbols_v1 +from text import symbols2 as symbols_v2 from torch.cuda.amp import autocast @@ -182,6 +184,7 @@ class TextEncoder(nn.Module): kernel_size, p_dropout, latent_channels=192, + version="v2", ): super().__init__() self.out_channels = out_channels @@ -192,6 +195,7 @@ class TextEncoder(nn.Module): self.kernel_size = kernel_size self.p_dropout = p_dropout self.latent_channels = latent_channels + self.version = version self.ssl_proj = nn.Conv1d(768, hidden_channels, 1) @@ -207,6 +211,11 @@ class TextEncoder(nn.Module): self.encoder_text = attentions.Encoder( hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout ) + + if self.version == "v1": + symbols = symbols_v1.symbols + else: + symbols = symbols_v2.symbols self.text_embedding = nn.Embedding(len(symbols), hidden_channels) self.mrte = MRTE() @@ -817,6 +826,7 @@ class SynthesizerTrn(nn.Module): use_sdp=True, semantic_frame_rate=None, freeze_quantizer=None, + version="v2", **kwargs ): super().__init__() @@ -837,6 +847,7 @@ class SynthesizerTrn(nn.Module): self.segment_size = segment_size self.n_speakers = n_speakers self.gin_channels = gin_channels + self.version = version self.use_sdp = use_sdp self.enc_p = TextEncoder( @@ -847,6 +858,7 @@ class SynthesizerTrn(nn.Module): n_layers, kernel_size, p_dropout, + version=version, ) self.dec = Generator( inter_channels, @@ -871,9 +883,11 @@ class SynthesizerTrn(nn.Module): inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels ) - self.ref_enc = modules.MelStyleEncoder( - spec_channels, style_vector_dim=gin_channels - ) + # self.version=os.environ.get("version","v1") + if self.version == "v1": + self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels) + else: + self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ssl_dim = 768 self.ssl_dim = ssl_dim @@ -894,7 +908,10 @@ class SynthesizerTrn(nn.Module): def forward(self, codes, text, refer): refer_mask = torch.ones_like(refer[:1,:1,:]) - ge = self.ref_enc(refer * refer_mask, refer_mask) + if (self.version == "v1"): + ge = self.ref_enc(refer * refer_mask, refer_mask) + else: + ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == "25hz": diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py index ab457d7..43aac19 100644 --- a/GPT_SoVITS/onnx_export.py +++ b/GPT_SoVITS/onnx_export.py @@ -1,11 +1,12 @@ -from module.models_onnx import SynthesizerTrn, symbols +from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2 from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule import torch import torchaudio from torch import nn from feature_extractor import cnhubert -cnhubert_base_path = "pretrained_models/chinese-hubert-base" -cnhubert.cnhubert_base_path=cnhubert_base_path + +cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base" +cnhubert.cnhubert_base_path = cnhubert_base_path ssl_model = cnhubert.get_model() from text import cleaned_text_to_sequence import soundfile @@ -196,6 +197,11 @@ class VitsModel(nn.Module): super().__init__() dict_s2 = torch.load(vits_path,map_location="cpu") self.hps = dict_s2["config"] + if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: + self.hps["model"]["version"] = "v1" + else: + self.hps["model"]["version"] = "v2" + self.hps = DictToAttrRecursive(self.hps) self.hps.model.semantic_frame_rate = "25hz" self.vq_model = SynthesizerTrn( @@ -267,13 +273,13 @@ class SSLModel(nn.Module): return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2) -def export(vits_path, gpt_path, project_name): +def export(vits_path, gpt_path, project_name, vits_model="v2"): vits = VitsModel(vits_path) gpt = T2SModel(gpt_path, vits) gpt_sovits = GptSoVits(vits, gpt) ssl = SSLModel() - ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])]) - text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])]) + ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)]) + text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)]) ref_bert = torch.randn((ref_seq.shape[1], 1024)).float() text_bert = torch.randn((text_seq.shape[1], 1024)).float() ref_audio = torch.randn((1, 48000 * 5)).float() @@ -287,34 +293,38 @@ def export(vits_path, gpt_path, project_name): pass ssl_content = ssl(ref_audio_16k).float() - - debug = False + + # debug = False + debug = True + + # gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name) if debug: a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug) soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate) soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate) - return - - a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy() + else: + a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy() + soundfile.write("out.wav", a, vits.hps.data.sampling_rate) - soundfile.write("out.wav", a, vits.hps.data.sampling_rate) - - gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name) + if vits_model == "v1": + symbols = symbols_v1 + else: + symbols = symbols_v2 MoeVSConf = { - "Folder" : f"{project_name}", - "Name" : f"{project_name}", - "Type" : "GPT-SoVits", - "Rate" : vits.hps.data.sampling_rate, - "NumLayers": gpt.t2s_model.num_layers, - "EmbeddingDim": gpt.t2s_model.embedding_dim, - "Dict": "BasicDict", - "BertPath": "chinese-roberta-wwm-ext-large", - "Symbol": symbols, - "AddBlank": False - } - + "Folder": f"{project_name}", + "Name": f"{project_name}", + "Type": "GPT-SoVits", + "Rate": vits.hps.data.sampling_rate, + "NumLayers": gpt.t2s_model.num_layers, + "EmbeddingDim": gpt.t2s_model.embedding_dim, + "Dict": "BasicDict", + "BertPath": "chinese-roberta-wwm-ext-large", + # "Symbol": symbols, + "AddBlank": False, + } + MoeVSConfJson = json.dumps(MoeVSConf) with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile: json.dump(MoeVSConf, MoeVsConfFile, indent = 4)