Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git, synced 2025-04-06 03:57:44 +08:00

Fix onnx_export to support v2 (#1604)

parent 570da092c9
commit 0c000191b3
GPT_SoVITS/module/models_onnx.py

@@ -13,7 +13,9 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
 from module.commons import init_weights, get_padding
 from module.mrte_model import MRTE
 from module.quantize import ResidualVectorQuantizer
-from text import symbols
+# from text import symbols
+from text import symbols as symbols_v1
+from text import symbols2 as symbols_v2
 from torch.cuda.amp import autocast
 
 
@@ -182,6 +184,7 @@ class TextEncoder(nn.Module):
         kernel_size,
         p_dropout,
         latent_channels=192,
+        version="v2",
     ):
         super().__init__()
         self.out_channels = out_channels
@@ -192,6 +195,7 @@ class TextEncoder(nn.Module):
         self.kernel_size = kernel_size
         self.p_dropout = p_dropout
         self.latent_channels = latent_channels
+        self.version = version
 
         self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
 
@@ -207,6 +211,11 @@ class TextEncoder(nn.Module):
         self.encoder_text = attentions.Encoder(
             hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
         )
 
+        if self.version == "v1":
+            symbols = symbols_v1.symbols
+        else:
+            symbols = symbols_v2.symbols
         self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
 
         self.mrte = MRTE()
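
The embedding table built here must match the symbol inventory the checkpoint was trained with, which is why TextEncoder now selects the v1 or v2 symbol list at construction time. A minimal sketch of that size dependency (the v1 count of 322 comes from the checkpoint probe in onnx_export.py below; other sizes are assumptions):

    import torch.nn as nn
    from text import symbols as symbols_v1
    from text import symbols2 as symbols_v2

    for version, mod in (("v1", symbols_v1), ("v2", symbols_v2)):
        emb = nn.Embedding(len(mod.symbols), 192)
        # v1 yields a 322-row table; the v2 symbol list is larger, so a v2
        # checkpoint cannot be loaded into a v1-sized embedding.
        print(version, emb.weight.shape)
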
@@ -817,6 +826,7 @@ class SynthesizerTrn(nn.Module):
         use_sdp=True,
         semantic_frame_rate=None,
         freeze_quantizer=None,
+        version="v2",
         **kwargs
     ):
         super().__init__()
@@ -837,6 +847,7 @@ class SynthesizerTrn(nn.Module):
         self.segment_size = segment_size
         self.n_speakers = n_speakers
         self.gin_channels = gin_channels
+        self.version = version
 
         self.use_sdp = use_sdp
         self.enc_p = TextEncoder(
@@ -847,6 +858,7 @@ class SynthesizerTrn(nn.Module):
             n_layers,
             kernel_size,
             p_dropout,
+            version=version,
         )
         self.dec = Generator(
             inter_channels,
@@ -871,9 +883,11 @@ class SynthesizerTrn(nn.Module):
             inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
         )
 
-        self.ref_enc = modules.MelStyleEncoder(
-            spec_channels, style_vector_dim=gin_channels
-        )
+        # self.version=os.environ.get("version","v1")
+        if self.version == "v1":
+            self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
+        else:
+            self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
 
         ssl_dim = 768
         self.ssl_dim = ssl_dim
@@ -894,7 +908,10 @@ class SynthesizerTrn(nn.Module):
 
     def forward(self, codes, text, refer):
         refer_mask = torch.ones_like(refer[:1,:1,:])
-        ge = self.ref_enc(refer * refer_mask, refer_mask)
+        if (self.version == "v1"):
+            ge = self.ref_enc(refer * refer_mask, refer_mask)
+        else:
+            ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
 
         quantized = self.quantizer.decode(codes)
         if self.semantic_frame_rate == "25hz":
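
Both hunks above encode the same constraint: the v2 MelStyleEncoder is constructed for exactly 704 spectrogram bins, so forward must trim the reference spectrogram to its first 704 channels before the style vector is computed. A standalone shape sketch (tensor sizes here are illustrative assumptions, not values from the model config):

    import torch

    refer = torch.randn(1, 1025, 200)            # (batch, spec bins, frames), sizes assumed
    refer_mask = torch.ones_like(refer[:1, :1, :])
    v2_input = refer[:, :704] * refer_mask       # v2 path: keep only the first 704 bins
    print(v2_input.shape)                        # torch.Size([1, 704, 200])
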

GPT_SoVITS/onnx_export.py

@@ -1,11 +1,12 @@
-from module.models_onnx import SynthesizerTrn, symbols
+from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
 from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
 import torch
 import torchaudio
 from torch import nn
 from feature_extractor import cnhubert
-cnhubert_base_path = "pretrained_models/chinese-hubert-base"
-cnhubert.cnhubert_base_path=cnhubert_base_path
+
+cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
+cnhubert.cnhubert_base_path = cnhubert_base_path
 ssl_model = cnhubert.get_model()
 from text import cleaned_text_to_sequence
 import soundfile
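
Note that the hubert base path is now rooted at GPT_SoVITS/, which assumes onnx_export.py is launched from the repository root rather than from inside GPT_SoVITS/. A quick guard sketch (standard library only):

    import os

    cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
    # If this fails, run the exporter from the repository root so the
    # relative path resolves.
    assert os.path.isdir(cnhubert_base_path), "chinese-hubert-base not found"
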
@@ -196,6 +197,11 @@ class VitsModel(nn.Module):
         super().__init__()
         dict_s2 = torch.load(vits_path,map_location="cpu")
         self.hps = dict_s2["config"]
+        if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
+            self.hps["model"]["version"] = "v1"
+        else:
+            self.hps["model"]["version"] = "v2"
+
         self.hps = DictToAttrRecursive(self.hps)
         self.hps.model.semantic_frame_rate = "25hz"
         self.vq_model = SynthesizerTrn(
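
Older SoVITS checkpoints carry no explicit version field, so VitsModel infers it from the text-embedding row count: 322 rows means a v1 symbol table. The same probe works on any checkpoint before export; a minimal sketch (the path is a placeholder):

    import torch

    dict_s2 = torch.load("SoVITS_weights/your_model.pth", map_location="cpu")  # placeholder path
    rows = dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0]
    print("v1" if rows == 322 else "v2")
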
@@ -267,13 +273,13 @@ class SSLModel(nn.Module):
         return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
 
 
-def export(vits_path, gpt_path, project_name):
+def export(vits_path, gpt_path, project_name, vits_model="v2"):
     vits = VitsModel(vits_path)
     gpt = T2SModel(gpt_path, vits)
     gpt_sovits = GptSoVits(vits, gpt)
     ssl = SSLModel()
-    ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
-    text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
+    ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
+    text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
     ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
     text_bert = torch.randn((text_seq.shape[1], 1024)).float()
     ref_audio = torch.randn((1, 48000 * 5)).float()
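
cleaned_text_to_sequence resolves phoneme strings to integer IDs against one specific symbol table, so the new version argument has to be threaded through here; otherwise v1 phonemes would be looked up with v2 indices. A small sketch of the difference:

    from text import cleaned_text_to_sequence

    phones = ["n", "i2", "h", "ao3"]
    # The same phonemes can map to different IDs in the two tables.
    print(cleaned_text_to_sequence(phones, version="v1"))
    print(cleaned_text_to_sequence(phones, version="v2"))
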
@@ -288,32 +294,36 @@ def export(vits_path, gpt_path, project_name):
 
     ssl_content = ssl(ref_audio_16k).float()
 
-    debug = False
+    # debug = False
+    debug = True
 
+    # gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
+
     if debug:
         a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
         soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
         soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
-        return
-
-    a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
-    soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
-
-    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
+    else:
+        a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
+        soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
+
+    if vits_model == "v1":
+        symbols = symbols_v1
+    else:
+        symbols = symbols_v2
 
     MoeVSConf = {
-        "Folder" : f"{project_name}",
-        "Name" : f"{project_name}",
-        "Type" : "GPT-SoVits",
-        "Rate" : vits.hps.data.sampling_rate,
+        "Folder": f"{project_name}",
+        "Name": f"{project_name}",
+        "Type": "GPT-SoVits",
+        "Rate": vits.hps.data.sampling_rate,
         "NumLayers": gpt.t2s_model.num_layers,
         "EmbeddingDim": gpt.t2s_model.embedding_dim,
         "Dict": "BasicDict",
         "BertPath": "chinese-roberta-wwm-ext-large",
-        "Symbol": symbols,
-        "AddBlank": False
+        # "Symbol": symbols,
+        "AddBlank": False,
     }
 
     MoeVSConfJson = json.dumps(MoeVSConf)
     with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
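
With the new switch, a single call drives export for either generation of weights; a hypothetical invocation (both checkpoint paths are placeholders):

    # vits_model should match how the SoVITS checkpoint was trained; VitsModel
    # independently re-detects the version from the checkpoint itself.
    export(
        vits_path="SoVITS_weights/my_voice_e8_s200.pth",  # placeholder
        gpt_path="GPT_weights/my_voice-e15.ckpt",         # placeholder
        project_name="my_voice",
        vits_model="v2",                                  # or "v1"
    )
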