From 70f1ec719ed6dbc91b106f9124d2f8a30b07b3da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?= =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?= =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?= <40709280+NaruseMioShirakana@users.noreply.github.com> Date: Thu, 3 Apr 2025 16:41:13 +0800 Subject: [PATCH] onnx export onnx export --- GPT_SoVITS/module/models_onnx.py | 345 +++++++++++++++++++++++++------ 1 file changed, 283 insertions(+), 62 deletions(-) diff --git a/GPT_SoVITS/module/models_onnx.py b/GPT_SoVITS/module/models_onnx.py index 1c24056..6bae60c 100644 --- a/GPT_SoVITS/module/models_onnx.py +++ b/GPT_SoVITS/module/models_onnx.py @@ -1,24 +1,28 @@ +import warnings +warnings.filterwarnings("ignore") import copy import math -from typing import Optional +import os +import pdb + import torch from torch import nn from torch.nn import functional as F from module import commons from module import modules -from module import attentions_onnx as attentions - -from f5_tts.model import DiT - +from module import attentions +#from f5_tts.model.backbones.dit import DiT from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm from module.commons import init_weights, get_padding +from module.mrte_model import MRTE from module.quantize import ResidualVectorQuantizer # from text import symbols from text import symbols as symbols_v1 from text import symbols2 as symbols_v2 from torch.cuda.amp import autocast +import contextlib,random class StochasticDurationPredictor(nn.Module): @@ -186,7 +190,7 @@ class TextEncoder(nn.Module): kernel_size, p_dropout, latent_channels=192, - version="v2", + version = "v2", ): super().__init__() self.out_channels = out_channels @@ -220,7 +224,7 @@ class TextEncoder(nn.Module): symbols = symbols_v2.symbols self.text_embedding = nn.Embedding(len(symbols), hidden_channels) - self.mrte = attentions.MRTE() + self.mrte = MRTE() self.encoder2 = attentions.Encoder( hidden_channels, @@ -233,7 +237,7 @@ class TextEncoder(nn.Module): self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, y, text, ge, speed=1): + def forward(self, y, text, ge, speed=1,test=None): y_mask = torch.ones_like(y[:1,:1,:]) y = self.ssl_proj(y * y_mask) * y_mask @@ -244,16 +248,35 @@ class TextEncoder(nn.Module): text = self.text_embedding(text).transpose(1, 2) text = self.encoder_text(text * text_mask, text_mask) y = self.mrte(y, y_mask, text, text_mask, ge) - + y = self.encoder2(y * y_mask, y_mask) if(speed!=1): y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear") y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest") - + stats = self.proj(y) * y_mask m, logs = torch.split(stats, self.out_channels, dim=1) return y, m, logs, y_mask + def extract_latent(self, x): + x = self.ssl_proj(x) + quantized, codes, commit_loss, quantized_list = self.quantizer(x) + return codes.transpose(0, 1) + + def decode_latent(self, codes, y_mask, refer, refer_mask, ge): + quantized = self.quantizer.decode(codes) + + y = self.vq_proj(quantized) * y_mask + y = self.encoder_ssl(y * y_mask, y_mask) + + y = self.mrte(y, y_mask, refer, refer_mask, ge) + + y = self.encoder2(y * y_mask, y_mask) + + stats = self.proj(y) * y_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + return y, m, logs, y_mask, quantized + class ResidualCouplingBlock(nn.Module): def __init__( @@ -465,7 +488,7 @@ class Generator(torch.nn.Module): if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) - def forward(self, x, g:Optional[torch.Tensor]=None): + def forward(self, x, g=None): x = self.conv_pre(x) if g is not None: x = x + self.cond(g) @@ -923,7 +946,7 @@ class SynthesizerTrn(nn.Module): # self.enc_p.encoder_text.requires_grad_(False) # self.enc_p.mrte.requires_grad_(False) - def forward(self, codes, text, refer,noise_scale=0.5, speed=1): + def forward(self, codes, text, refer, noise_scale=0.5): refer_mask = torch.ones_like(refer[:1,:1,:]) if (self.version == "v1"): ge = self.ref_enc(refer * refer_mask, refer_mask) @@ -935,79 +958,98 @@ class SynthesizerTrn(nn.Module): dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0) quantized = dquantized.contiguous().view(1, self.ssl_dim, -1) - x, m_p, logs_p, y_mask = self.enc_p( - quantized, text, ge, speed + _, m_p, logs_p, y_mask = self.enc_p( + quantized, text, ge ) - z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale - z = self.flow(z_p, y_mask, g=ge, reverse=True) - o = self.dec((z * y_mask)[:, :, :], g=ge) return o def extract_latent(self, x): ssl = self.ssl_proj(x) - quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) + _, codes, _, _ = self.quantizer(ssl) return codes.transpose(0, 1) - + + class CFM(torch.nn.Module): def __init__( self, in_channels,dit ): super().__init__() - # self.sigma_min = 1e-6 + self.sigma_min = 1e-6 self.estimator = dit self.in_channels = in_channels - # self.criterion = torch.nn.MSELoss() + self.criterion = torch.nn.MSELoss() - def forward(self, mu:torch.Tensor, x_lens:torch.LongTensor, prompt:torch.Tensor, n_timesteps:torch.LongTensor, temperature:float=1.0): + @torch.inference_mode() + def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0): """Forward diffusion""" B, T = mu.size(0), mu.size(1) - x = torch.randn([B, self.in_channels, T], device=mu.device,dtype=mu.dtype) - - ntimesteps = int(n_timesteps) - + x = torch.randn([B, self.in_channels, T], device=mu.device,dtype=mu.dtype) * temperature prompt_len = prompt.size(-1) prompt_x = torch.zeros_like(x,dtype=mu.dtype) prompt_x[..., :prompt_len] = prompt[..., :prompt_len] - x[..., :prompt_len] = 0.0 + x[..., :prompt_len] = 0 mu=mu.transpose(2,1) - t = torch.tensor(0.0,dtype=x.dtype,device=x.device) - d = torch.tensor(1.0/ntimesteps,dtype=x.dtype,device=x.device) - d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d - - for j in range(ntimesteps): + t = 0 + d = 1 / n_timesteps + for j in range(n_timesteps): t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t - # d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d + d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d # v_pred = model(x, t_tensor, d_tensor, **extra_args) - v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu).transpose(2, 1) - # if inference_cfg_rate>1e-5: - # neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1) - # v_pred=v_pred+(v_pred-neg)*inference_cfg_rate + v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu, use_grad_ckpt=False,drop_audio_cond=False,drop_text=False).transpose(2, 1) + if inference_cfg_rate>1e-5: + neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1) + v_pred=v_pred+(v_pred-neg)*inference_cfg_rate x = x + d * v_pred t = t + d - x[:, :, :prompt_len] = 0.0 + x[:, :, :prompt_len] = 0 return x + def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt): + b, _, t = x1.shape + t = torch.rand([b], device=mu.device, dtype=x1.dtype) + x0 = torch.randn_like(x1,device=mu.device) + vt = x1 - x0 + xt = x0 + t[:, None, None] * vt + dt = torch.zeros_like(t,device=mu.device) + prompt = torch.zeros_like(x1) + for i in range(b): + prompt[i, :, :prompt_lens[i]] = x1[i, :, :prompt_lens[i]] + xt[i, :, :prompt_lens[i]] = 0 + gailv=0.3# if ttime()>1736250488 else 0.1 + if random.random() < gailv: + base = torch.randint(2, 8, (t.shape[0],), device=mu.device) + d = 1/torch.pow(2, base) + d_input = d.clone() + d_input[d_input < 1e-2] = 0 + # with torch.no_grad(): + v_pred_1 = self.estimator(xt, prompt, x_lens, t, d_input, mu, use_grad_ckpt).transpose(2, 1).detach() + # v_pred_1 = self.diffusion(xt, t, d_input, cond=conditioning).detach() + x_mid = xt + d[:, None, None] * v_pred_1 + # v_pred_2 = self.diffusion(x_mid, t+d, d_input, cond=conditioning).detach() + v_pred_2 = self.estimator(x_mid, prompt, x_lens, t+d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach() + vt = (v_pred_1 + v_pred_2) / 2 + vt = vt.detach() + dt = 2*d + + vt_pred = self.estimator(xt, prompt, x_lens, t,dt, mu, use_grad_ckpt).transpose(2,1) + loss = 0 + for i in range(b): + loss += self.criterion(vt_pred[i, :, prompt_lens[i]:x_lens[i]], vt[i, :, prompt_lens[i]:x_lens[i]]) + loss /= b + + return loss def set_no_grad(net_g): for name, param in net_g.named_parameters(): param.requires_grad=False -@torch.jit.script_if_tracing -def compile_codes_length(codes): - y_lengths1 = torch.LongTensor([codes.size(2)]).to(codes.device) - return y_lengths1 * 2.5 * 1.5 - -@torch.jit.script_if_tracing -def compile_ref_length(refer): - refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) - return refer_lengths class SynthesizerTrnV3(nn.Module): """ @@ -1035,7 +1077,6 @@ class SynthesizerTrnV3(nn.Module): use_sdp=True, semantic_frame_rate=None, freeze_quantizer=None, - version="v3", **kwargs): super().__init__() @@ -1056,7 +1097,6 @@ class SynthesizerTrnV3(nn.Module): self.segment_size = segment_size self.n_speakers = n_speakers self.gin_channels = gin_channels - self.version = version self.model_dim=512 self.use_sdp = use_sdp @@ -1083,7 +1123,7 @@ class SynthesizerTrnV3(nn.Module): n_q=1, bins=1024 ) - freeze_quantizer + self.freeze_quantizer=freeze_quantizer inter_channels2=512 self.bridge=nn.Sequential( nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), @@ -1092,32 +1132,213 @@ class SynthesizerTrnV3(nn.Module): self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels) self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1) self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim - if freeze_quantizer==True: + if self.freeze_quantizer==True: set_no_grad(self.ssl_proj) set_no_grad(self.quantizer) set_no_grad(self.enc_p) - def create_ge(self, refer): - refer_lengths = compile_ref_length(refer) - refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) - ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask) - return ge + def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths, use_grad_ckpt):#ssl_lengths no need now + with autocast(enabled=False): + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) + ge = self.ref_enc(y[:,:704] * y_mask, y_mask) + maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext() + with maybe_no_grad: + if self.freeze_quantizer: + self.ssl_proj.eval()# + self.quantizer.eval() + self.enc_p.eval() + ssl = self.ssl_proj(ssl) + quantized, codes, commit_loss, quantized_list = self.quantizer( + ssl, layers=[0] + ) + quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT + x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + fea=self.bridge(x) + fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT + fea, y_mask_ = self.wns1(fea, mel_lengths, ge)##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate. + B=ssl.shape[0] + prompt_len_max = mel_lengths*2/3 + prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long) + minn=min(mel.shape[-1],fea.shape[-1]) + mel=mel[:,:,:minn] + fea=fea[:,:,:minn] + cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt) + return cfm_loss - def forward(self, codes, text,ge,speed=1): + @torch.no_grad() + def decode_encp(self, codes,text, refer,ge=None,speed=1): + # print(2333333,refer.shape) + # ge=None + if(ge==None): + refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) + refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) + ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask) + y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device) + if speed==1: + sizee=int(codes.size(2)*2.5*1.5) + else: + sizee=int(codes.size(2)*2.5*1.5/speed)+1 + y_lengths1 = torch.LongTensor([sizee]).to(codes.device) + text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) - y_lengths1=compile_codes_length(codes) - quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == '25hz': quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT - x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge,speed) + x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge,speed) fea=self.bridge(x) fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT ####more wn paramter to learn mel fea, y_mask_ = self.wns1(fea, y_lengths1, ge) - return fea + return fea,ge def extract_latent(self, x): ssl = self.ssl_proj(x) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) - return codes.transpose(0,1) \ No newline at end of file + return codes.transpose(0,1) + + +class SynthesizerTrnV3b(nn.Module): + """ + Synthesizer for Training + """ + + def __init__(self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=0, + gin_channels=0, + use_sdp=True, + semantic_frame_rate=None, + freeze_quantizer=None, + **kwargs): + + super().__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + self.gin_channels = gin_channels + + self.model_dim=512 + self.use_sdp = use_sdp + self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout) + # self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback + self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback + self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, + upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) + self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, + gin_channels=gin_channels) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + + + ssl_dim = 768 + assert semantic_frame_rate in ['25hz', "50hz"] + self.semantic_frame_rate = semantic_frame_rate + if semantic_frame_rate == '25hz': + self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2) + else: + self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1) + + self.quantizer = ResidualVectorQuantizer( + dimension=ssl_dim, + n_q=1, + bins=1024 + ) + self.freeze_quantizer=freeze_quantizer + + inter_channels2=512 + self.bridge=nn.Sequential( + nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), + nn.LeakyReLU() + ) + self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels) + self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1) + self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim + + + def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths):#ssl_lengths no need now + with autocast(enabled=False): + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) + ge = self.ref_enc(y[:,:704] * y_mask, y_mask) + # ge = self.ref_enc(y * y_mask, y_mask)#change back, new spec setting is whole 24k + # ge=None + maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext() + with maybe_no_grad: + if self.freeze_quantizer: + self.ssl_proj.eval() + self.quantizer.eval() + ssl = self.ssl_proj(ssl) + quantized, codes, commit_loss, quantized_list = self.quantizer( + ssl, layers=[0] + ) + quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT + x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge) + z_p = self.flow(z, y_mask, g=ge) + z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) + o = self.dec(z_slice, g=ge) + fea=self.bridge(x) + fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT + fea, y_mask_ = self.wns1(fea, mel_lengths, ge) + learned_mel = self.linear_mel(fea) + B=ssl.shape[0] + prompt_len_max = mel_lengths*2/3 + prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)# + minn=min(mel.shape[-1],fea.shape[-1]) + mel=mel[:,:,:minn] + fea=fea[:,:,:minn] + cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea)#fea==cond,y_lengths==target_mel_lengths#ge not need + return commit_loss,cfm_loss,F.mse_loss(learned_mel, mel),o, ids_slice, y_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), quantized + + @torch.no_grad() + def decode_encp(self, codes,text, refer,ge=None): + # print(2333333,refer.shape) + # ge=None + if(ge==None): + refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) + refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) + ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask) + y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device) + y_lengths1 = torch.LongTensor([int(codes.size(2)*2.5*1.5)]).to(codes.device) + text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) + + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == '25hz': + quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT + x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + fea=self.bridge(x) + fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT + ####more wn paramter to learn mel + fea, y_mask_ = self.wns1(fea, y_lengths1, ge) + return fea,ge + + def extract_latent(self, x): + ssl = self.ssl_proj(x) + quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) + return codes.transpose(0,1) +