onnx export

onnx export
This commit is contained in:
Ναρουσέ·μ·γιουμεμί·Χινακάννα 2025-04-03 16:41:13 +08:00 committed by GitHub
parent e537e129a5
commit 70f1ec719e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,24 +1,28 @@
import warnings
warnings.filterwarnings("ignore")
import copy import copy
import math import math
from typing import Optional import os
import pdb
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from module import commons from module import commons
from module import modules from module import modules
from module import attentions_onnx as attentions from module import attentions
#from f5_tts.model.backbones.dit import DiT
from f5_tts.model import DiT
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer from module.quantize import ResidualVectorQuantizer
# from text import symbols # from text import symbols
from text import symbols as symbols_v1 from text import symbols as symbols_v1
from text import symbols2 as symbols_v2 from text import symbols2 as symbols_v2
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
import contextlib,random
class StochasticDurationPredictor(nn.Module): class StochasticDurationPredictor(nn.Module):
@ -186,7 +190,7 @@ class TextEncoder(nn.Module):
kernel_size, kernel_size,
p_dropout, p_dropout,
latent_channels=192, latent_channels=192,
version="v2", version = "v2",
): ):
super().__init__() super().__init__()
self.out_channels = out_channels self.out_channels = out_channels
@ -220,7 +224,7 @@ class TextEncoder(nn.Module):
symbols = symbols_v2.symbols symbols = symbols_v2.symbols
self.text_embedding = nn.Embedding(len(symbols), hidden_channels) self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
self.mrte = attentions.MRTE() self.mrte = MRTE()
self.encoder2 = attentions.Encoder( self.encoder2 = attentions.Encoder(
hidden_channels, hidden_channels,
@ -233,7 +237,7 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, text, ge, speed=1): def forward(self, y, text, ge, speed=1,test=None):
y_mask = torch.ones_like(y[:1,:1,:]) y_mask = torch.ones_like(y[:1,:1,:])
y = self.ssl_proj(y * y_mask) * y_mask y = self.ssl_proj(y * y_mask) * y_mask
@ -254,6 +258,25 @@ class TextEncoder(nn.Module):
m, logs = torch.split(stats, self.out_channels, dim=1) m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask return y, m, logs, y_mask
def extract_latent(self, x):
x = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(x)
return codes.transpose(0, 1)
def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
quantized = self.quantizer.decode(codes)
y = self.vq_proj(quantized) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask)
y = self.mrte(y, y_mask, refer, refer_mask, ge)
y = self.encoder2(y * y_mask, y_mask)
stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask, quantized
class ResidualCouplingBlock(nn.Module): class ResidualCouplingBlock(nn.Module):
def __init__( def __init__(
@ -465,7 +488,7 @@ class Generator(torch.nn.Module):
if gin_channels != 0: if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
def forward(self, x, g:Optional[torch.Tensor]=None): def forward(self, x, g=None):
x = self.conv_pre(x) x = self.conv_pre(x)
if g is not None: if g is not None:
x = x + self.cond(g) x = x + self.cond(g)
@ -923,7 +946,7 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.encoder_text.requires_grad_(False) # self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False) # self.enc_p.mrte.requires_grad_(False)
def forward(self, codes, text, refer,noise_scale=0.5, speed=1): def forward(self, codes, text, refer, noise_scale=0.5):
refer_mask = torch.ones_like(refer[:1,:1,:]) refer_mask = torch.ones_like(refer[:1,:1,:])
if (self.version == "v1"): if (self.version == "v1"):
ge = self.ref_enc(refer * refer_mask, refer_mask) ge = self.ref_enc(refer * refer_mask, refer_mask)
@ -935,79 +958,98 @@ class SynthesizerTrn(nn.Module):
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0) dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1) quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
x, m_p, logs_p, y_mask = self.enc_p( _, m_p, logs_p, y_mask = self.enc_p(
quantized, text, ge, speed quantized, text, ge
) )
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True) z = self.flow(z_p, y_mask, g=ge, reverse=True)
o = self.dec((z * y_mask)[:, :, :], g=ge) o = self.dec((z * y_mask)[:, :, :], g=ge)
return o return o
def extract_latent(self, x): def extract_latent(self, x):
ssl = self.ssl_proj(x) ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) _, codes, _, _ = self.quantizer(ssl)
return codes.transpose(0, 1) return codes.transpose(0, 1)
class CFM(torch.nn.Module): class CFM(torch.nn.Module):
def __init__( def __init__(
self, self,
in_channels,dit in_channels,dit
): ):
super().__init__() super().__init__()
# self.sigma_min = 1e-6 self.sigma_min = 1e-6
self.estimator = dit self.estimator = dit
self.in_channels = in_channels self.in_channels = in_channels
# self.criterion = torch.nn.MSELoss() self.criterion = torch.nn.MSELoss()
def forward(self, mu:torch.Tensor, x_lens:torch.LongTensor, prompt:torch.Tensor, n_timesteps:torch.LongTensor, temperature:float=1.0): @torch.inference_mode()
def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0):
"""Forward diffusion""" """Forward diffusion"""
B, T = mu.size(0), mu.size(1) B, T = mu.size(0), mu.size(1)
x = torch.randn([B, self.in_channels, T], device=mu.device,dtype=mu.dtype) x = torch.randn([B, self.in_channels, T], device=mu.device,dtype=mu.dtype) * temperature
ntimesteps = int(n_timesteps)
prompt_len = prompt.size(-1) prompt_len = prompt.size(-1)
prompt_x = torch.zeros_like(x,dtype=mu.dtype) prompt_x = torch.zeros_like(x,dtype=mu.dtype)
prompt_x[..., :prompt_len] = prompt[..., :prompt_len] prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
x[..., :prompt_len] = 0.0 x[..., :prompt_len] = 0
mu=mu.transpose(2,1) mu=mu.transpose(2,1)
t = torch.tensor(0.0,dtype=x.dtype,device=x.device) t = 0
d = torch.tensor(1.0/ntimesteps,dtype=x.dtype,device=x.device) d = 1 / n_timesteps
d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d for j in range(n_timesteps):
for j in range(ntimesteps):
t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t
# d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
# v_pred = model(x, t_tensor, d_tensor, **extra_args) # v_pred = model(x, t_tensor, d_tensor, **extra_args)
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu).transpose(2, 1) v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu, use_grad_ckpt=False,drop_audio_cond=False,drop_text=False).transpose(2, 1)
# if inference_cfg_rate>1e-5: if inference_cfg_rate>1e-5:
# neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1) neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
# v_pred=v_pred+(v_pred-neg)*inference_cfg_rate v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
x = x + d * v_pred x = x + d * v_pred
t = t + d t = t + d
x[:, :, :prompt_len] = 0.0 x[:, :, :prompt_len] = 0
return x return x
def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt):
b, _, t = x1.shape
t = torch.rand([b], device=mu.device, dtype=x1.dtype)
x0 = torch.randn_like(x1,device=mu.device)
vt = x1 - x0
xt = x0 + t[:, None, None] * vt
dt = torch.zeros_like(t,device=mu.device)
prompt = torch.zeros_like(x1)
for i in range(b):
prompt[i, :, :prompt_lens[i]] = x1[i, :, :prompt_lens[i]]
xt[i, :, :prompt_lens[i]] = 0
gailv=0.3# if ttime()>1736250488 else 0.1
if random.random() < gailv:
base = torch.randint(2, 8, (t.shape[0],), device=mu.device)
d = 1/torch.pow(2, base)
d_input = d.clone()
d_input[d_input < 1e-2] = 0
# with torch.no_grad():
v_pred_1 = self.estimator(xt, prompt, x_lens, t, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
# v_pred_1 = self.diffusion(xt, t, d_input, cond=conditioning).detach()
x_mid = xt + d[:, None, None] * v_pred_1
# v_pred_2 = self.diffusion(x_mid, t+d, d_input, cond=conditioning).detach()
v_pred_2 = self.estimator(x_mid, prompt, x_lens, t+d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
vt = (v_pred_1 + v_pred_2) / 2
vt = vt.detach()
dt = 2*d
vt_pred = self.estimator(xt, prompt, x_lens, t,dt, mu, use_grad_ckpt).transpose(2,1)
loss = 0
for i in range(b):
loss += self.criterion(vt_pred[i, :, prompt_lens[i]:x_lens[i]], vt[i, :, prompt_lens[i]:x_lens[i]])
loss /= b
return loss
def set_no_grad(net_g): def set_no_grad(net_g):
for name, param in net_g.named_parameters(): for name, param in net_g.named_parameters():
param.requires_grad=False param.requires_grad=False
@torch.jit.script_if_tracing
def compile_codes_length(codes):
y_lengths1 = torch.LongTensor([codes.size(2)]).to(codes.device)
return y_lengths1 * 2.5 * 1.5
@torch.jit.script_if_tracing
def compile_ref_length(refer):
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
return refer_lengths
class SynthesizerTrnV3(nn.Module): class SynthesizerTrnV3(nn.Module):
""" """
@ -1035,7 +1077,6 @@ class SynthesizerTrnV3(nn.Module):
use_sdp=True, use_sdp=True,
semantic_frame_rate=None, semantic_frame_rate=None,
freeze_quantizer=None, freeze_quantizer=None,
version="v3",
**kwargs): **kwargs):
super().__init__() super().__init__()
@ -1056,7 +1097,6 @@ class SynthesizerTrnV3(nn.Module):
self.segment_size = segment_size self.segment_size = segment_size
self.n_speakers = n_speakers self.n_speakers = n_speakers
self.gin_channels = gin_channels self.gin_channels = gin_channels
self.version = version
self.model_dim=512 self.model_dim=512
self.use_sdp = use_sdp self.use_sdp = use_sdp
@ -1083,7 +1123,7 @@ class SynthesizerTrnV3(nn.Module):
n_q=1, n_q=1,
bins=1024 bins=1024
) )
freeze_quantizer self.freeze_quantizer=freeze_quantizer
inter_channels2=512 inter_channels2=512
self.bridge=nn.Sequential( self.bridge=nn.Sequential(
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
@ -1092,32 +1132,213 @@ class SynthesizerTrnV3(nn.Module):
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels) self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1) self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
if freeze_quantizer==True: if self.freeze_quantizer==True:
set_no_grad(self.ssl_proj) set_no_grad(self.ssl_proj)
set_no_grad(self.quantizer) set_no_grad(self.quantizer)
set_no_grad(self.enc_p) set_no_grad(self.enc_p)
def create_ge(self, refer): def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths, use_grad_ckpt):#ssl_lengths no need now
refer_lengths = compile_ref_length(refer) with autocast(enabled=False):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad:
if self.freeze_quantizer:
self.ssl_proj.eval()#
self.quantizer.eval()
self.enc_p.eval()
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
fea, y_mask_ = self.wns1(fea, mel_lengths, ge)##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
B=ssl.shape[0]
prompt_len_max = mel_lengths*2/3
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)
minn=min(mel.shape[-1],fea.shape[-1])
mel=mel[:,:,:minn]
fea=fea[:,:,:minn]
cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
return cfm_loss
@torch.no_grad()
def decode_encp(self, codes,text, refer,ge=None,speed=1):
# print(2333333,refer.shape)
# ge=None
if(ge==None):
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask) ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask)
return ge y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device)
if speed==1:
def forward(self, codes, text,ge,speed=1): sizee=int(codes.size(2)*2.5*1.5)
else:
y_lengths1=compile_codes_length(codes) sizee=int(codes.size(2)*2.5*1.5/speed)+1
y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes) quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz': if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge,speed) x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge,speed)
fea=self.bridge(x) fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
####more wn paramter to learn mel ####more wn paramter to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge) fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea return fea,ge
def extract_latent(self, x): def extract_latent(self, x):
ssl = self.ssl_proj(x) ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1) return codes.transpose(0,1)
class SynthesizerTrnV3b(nn.Module):
"""
Synthesizer for Training
"""
def __init__(self,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
n_speakers=0,
gin_channels=0,
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
**kwargs):
super().__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.resblock = resblock
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
self.upsample_rates = upsample_rates
self.upsample_initial_channel = upsample_initial_channel
self.upsample_kernel_sizes = upsample_kernel_sizes
self.segment_size = segment_size
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.model_dim=512
self.use_sdp = use_sdp
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
gin_channels=gin_channels)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
ssl_dim = 768
assert semantic_frame_rate in ['25hz', "50hz"]
self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == '25hz':
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
self.quantizer = ResidualVectorQuantizer(
dimension=ssl_dim,
n_q=1,
bins=1024
)
self.freeze_quantizer=freeze_quantizer
inter_channels2=512
self.bridge=nn.Sequential(
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
nn.LeakyReLU()
)
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths):#ssl_lengths no need now
with autocast(enabled=False):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
# ge = self.ref_enc(y * y_mask, y_mask)#change back, new spec setting is whole 24k
# ge=None
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad:
if self.freeze_quantizer:
self.ssl_proj.eval()
self.quantizer.eval()
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
o = self.dec(z_slice, g=ge)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
fea, y_mask_ = self.wns1(fea, mel_lengths, ge)
learned_mel = self.linear_mel(fea)
B=ssl.shape[0]
prompt_len_max = mel_lengths*2/3
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)#
minn=min(mel.shape[-1],fea.shape[-1])
mel=mel[:,:,:minn]
fea=fea[:,:,:minn]
cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea)#fea==cond,y_lengths==target_mel_lengths#ge not need
return commit_loss,cfm_loss,F.mse_loss(learned_mel, mel),o, ids_slice, y_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), quantized
@torch.no_grad()
def decode_encp(self, codes,text, refer,ge=None):
# print(2333333,refer.shape)
# ge=None
if(ge==None):
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device)
y_lengths1 = torch.LongTensor([int(codes.size(2)*2.5*1.5)]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
####more wn paramter to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea,ge
def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1)