Fix onnx_export to support v2 (#1604)

This commit is contained in:
zzz 2024-09-13 11:27:22 +08:00 committed by GitHub
parent 570da092c9
commit 0c000191b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 31 deletions

View File

@ -13,7 +13,9 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer
from text import symbols
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from torch.cuda.amp import autocast
@ -182,6 +184,7 @@ class TextEncoder(nn.Module):
kernel_size,
p_dropout,
latent_channels=192,
version="v2",
):
super().__init__()
self.out_channels = out_channels
@ -192,6 +195,7 @@ class TextEncoder(nn.Module):
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.latent_channels = latent_channels
self.version = version
self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
@ -207,6 +211,11 @@ class TextEncoder(nn.Module):
self.encoder_text = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
if self.version == "v1":
symbols = symbols_v1.symbols
else:
symbols = symbols_v2.symbols
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
self.mrte = MRTE()
@ -817,6 +826,7 @@ class SynthesizerTrn(nn.Module):
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
version="v2",
**kwargs
):
super().__init__()
@ -837,6 +847,7 @@ class SynthesizerTrn(nn.Module):
self.segment_size = segment_size
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.version = version
self.use_sdp = use_sdp
self.enc_p = TextEncoder(
@ -847,6 +858,7 @@ class SynthesizerTrn(nn.Module):
n_layers,
kernel_size,
p_dropout,
version=version,
)
self.dec = Generator(
inter_channels,
@ -871,9 +883,11 @@ class SynthesizerTrn(nn.Module):
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
self.ref_enc = modules.MelStyleEncoder(
spec_channels, style_vector_dim=gin_channels
)
# self.version=os.environ.get("version","v1")
if self.version == "v1":
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
else:
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
ssl_dim = 768
self.ssl_dim = ssl_dim
@ -894,7 +908,10 @@ class SynthesizerTrn(nn.Module):
def forward(self, codes, text, refer):
refer_mask = torch.ones_like(refer[:1,:1,:])
ge = self.ref_enc(refer * refer_mask, refer_mask)
if (self.version == "v1"):
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":

View File

@ -1,11 +1,12 @@
from module.models_onnx import SynthesizerTrn, symbols
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
import torch
import torchaudio
from torch import nn
from feature_extractor import cnhubert
cnhubert_base_path = "pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path=cnhubert_base_path
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model()
from text import cleaned_text_to_sequence
import soundfile
@ -196,6 +197,11 @@ class VitsModel(nn.Module):
super().__init__()
dict_s2 = torch.load(vits_path,map_location="cpu")
self.hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = "v2"
self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz"
self.vq_model = SynthesizerTrn(
@ -267,13 +273,13 @@ class SSLModel(nn.Module):
return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
def export(vits_path, gpt_path, project_name):
def export(vits_path, gpt_path, project_name, vits_model="v2"):
vits = VitsModel(vits_path)
gpt = T2SModel(gpt_path, vits)
gpt_sovits = GptSoVits(vits, gpt)
ssl = SSLModel()
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
ref_audio = torch.randn((1, 48000 * 5)).float()
@ -287,34 +293,38 @@ def export(vits_path, gpt_path, project_name):
pass
ssl_content = ssl(ref_audio_16k).float()
debug = False
# debug = False
debug = True
# gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
if debug:
a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
return
a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
else:
a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
if vits_model == "v1":
symbols = symbols_v1
else:
symbols = symbols_v2
MoeVSConf = {
"Folder" : f"{project_name}",
"Name" : f"{project_name}",
"Type" : "GPT-SoVits",
"Rate" : vits.hps.data.sampling_rate,
"NumLayers": gpt.t2s_model.num_layers,
"EmbeddingDim": gpt.t2s_model.embedding_dim,
"Dict": "BasicDict",
"BertPath": "chinese-roberta-wwm-ext-large",
"Symbol": symbols,
"AddBlank": False
}
"Folder": f"{project_name}",
"Name": f"{project_name}",
"Type": "GPT-SoVits",
"Rate": vits.hps.data.sampling_rate,
"NumLayers": gpt.t2s_model.num_layers,
"EmbeddingDim": gpt.t2s_model.embedding_dim,
"Dict": "BasicDict",
"BertPath": "chinese-roberta-wwm-ext-large",
# "Symbol": symbols,
"AddBlank": False,
}
MoeVSConfJson = json.dumps(MoeVSConf)
with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
json.dump(MoeVSConf, MoeVsConfFile, indent = 4)