feat: 添加导出 v3 的 script (#2208)

* feat: 添加导出 v3 的 script

* Fix: 由于 export_torch_script_v3 的改动,v2 现在需要传入 top_k
This commit is contained in:
zzz 2025-03-26 14:50:55 +08:00 committed by GitHub
parent f1332ff53a
commit b0e465eb72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 1275 additions and 9 deletions

View File

@ -427,7 +427,7 @@ class T2SModel(nn.Module):
self.top_k = int(raw_t2s.config["inference"]["top_k"])
self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
def forward(self,prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor):
def forward(self,prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor,top_k:LongTensor):
bert = torch.cat([ref_bert.T, text_bert.T], 1)
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
bert = bert.unsqueeze(0)
@ -472,12 +472,13 @@ class T2SModel(nn.Module):
.to(device=x.device, dtype=torch.bool)
idx = 0
top_k = int(top_k)
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
logits = self.ar_predict_layer(xy_dec[:, -1])
logits = logits[:, :-1]
samples = sample(logits, y, top_k=self.top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
y = torch.concat([y, samples], dim=1)
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
@ -493,7 +494,7 @@ class T2SModel(nn.Module):
if(idx<11):###至少预测出10个token不然不给停止0.4s
logits = logits[:, :-1]
samples = sample(logits, y, top_k=self.top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
y = torch.concat([y, samples], dim=1)
@ -653,6 +654,8 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
torch._dynamo.mark_dynamic(ref_bert, 0)
torch._dynamo.mark_dynamic(text_bert, 0)
top_k = torch.LongTensor([5]).to(device)
with torch.no_grad():
gpt_sovits_export = torch.jit.trace(
gpt_sovits,
@ -662,7 +665,8 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
ref_seq,
text_seq,
ref_bert,
text_bert))
text_bert,
top_k))
gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
gpt_sovits_export.save(gpt_sovits_path)
@ -684,15 +688,26 @@ class GPT_SoVITS(nn.Module):
self.t2s = t2s
self.vits = vits
def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor, speed=1.0):
def forward(
self,
ssl_content: torch.Tensor,
ref_audio_sr: torch.Tensor,
ref_seq: Tensor,
text_seq: Tensor,
ref_bert: Tensor,
text_bert: Tensor,
top_k: LongTensor,
speed=1.0,
):
codes = self.vits.vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
prompts = prompt_semantic.unsqueeze(0)
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert)
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed)
return audio
def test():
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
@ -784,8 +799,10 @@ def test():
print('text_bert:',text_bert.shape)
text_bert=text_bert.to('cuda')
top_k = torch.LongTensor([5]).to('cuda')
with torch.no_grad():
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert)
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k)
print('start write wav')
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)

File diff suppressed because it is too large Load Diff

View File

@ -138,7 +138,7 @@ class DiT(nn.Module):
time: float["b"] | float[""], # time step # noqa: F821 F722
dt_base_bootstrap,
text0, # : int["b nt"] # noqa: F722#####condition feature
use_grad_ckpt, # bool
use_grad_ckpt=False, # bool
###no-use
drop_audio_cond=False, # cfg for cond audio
drop_text=False, # cfg for text

View File

@ -9,6 +9,8 @@ from module import commons
from module import modules
from module import attentions_onnx as attentions
from f5_tts.model import DiT
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
@ -342,6 +344,37 @@ class PosteriorEncoder(nn.Module):
return z, m, logs, x_mask
class Encoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
def forward(self, x, x_lengths, g=None):
if(g!=None):
g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
return stats, x_mask
class WNEncoder(nn.Module):
def __init__(
self,
@ -916,4 +949,175 @@ class SynthesizerTrn(nn.Module):
def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0, 1)
return codes.transpose(0, 1)
class CFM(torch.nn.Module):
def __init__(
self,
in_channels,dit
):
super().__init__()
# self.sigma_min = 1e-6
self.estimator = dit
self.in_channels = in_channels
# self.criterion = torch.nn.MSELoss()
def forward(self, mu:torch.Tensor, x_lens:torch.LongTensor, prompt:torch.Tensor, n_timesteps:torch.LongTensor, temperature:float=1.0):
"""Forward diffusion"""
B, T = mu.size(0), mu.size(1)
x = torch.randn([B, self.in_channels, T], device=mu.device,dtype=mu.dtype)
ntimesteps = int(n_timesteps)
prompt_len = prompt.size(-1)
prompt_x = torch.zeros_like(x,dtype=mu.dtype)
prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
x[..., :prompt_len] = 0.0
mu=mu.transpose(2,1)
t = torch.tensor(0.0,dtype=x.dtype,device=x.device)
d = torch.tensor(1.0/ntimesteps,dtype=x.dtype,device=x.device)
d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
for j in range(ntimesteps):
t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t
# d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu).transpose(2, 1)
# if inference_cfg_rate>1e-5:
# neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
# v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
x = x + d * v_pred
t = t + d
x[:, :, :prompt_len] = 0.0
return x
def set_no_grad(net_g):
for name, param in net_g.named_parameters():
param.requires_grad=False
@torch.jit.script_if_tracing
def compile_codes_length(codes):
y_lengths1 = torch.LongTensor([codes.size(2)]).to(codes.device)
return y_lengths1 * 2.5 * 1.5
@torch.jit.script_if_tracing
def compile_ref_length(refer):
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
return refer_lengths
class SynthesizerTrnV3(nn.Module):
"""
Synthesizer for Training
"""
def __init__(self,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
n_speakers=0,
gin_channels=0,
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
version="v3",
**kwargs):
super().__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.resblock = resblock
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
self.upsample_rates = upsample_rates
self.upsample_initial_channel = upsample_initial_channel
self.upsample_kernel_sizes = upsample_kernel_sizes
self.segment_size = segment_size
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.version = version
self.model_dim=512
self.use_sdp = use_sdp
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback
# self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
# upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
# self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
# gin_channels=gin_channels)
# self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
ssl_dim = 768
assert semantic_frame_rate in ['25hz', "50hz"]
self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == '25hz':
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
self.quantizer = ResidualVectorQuantizer(
dimension=ssl_dim,
n_q=1,
bins=1024
)
freeze_quantizer
inter_channels2=512
self.bridge=nn.Sequential(
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
nn.LeakyReLU()
)
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
if freeze_quantizer==True:
set_no_grad(self.ssl_proj)
set_no_grad(self.quantizer)
set_no_grad(self.enc_p)
def create_ge(self, refer):
refer_lengths = compile_ref_length(refer)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask)
return ge
def forward(self, codes, text,ge,speed=1):
y_lengths1=compile_codes_length(codes)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge,speed)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
####more wn paramter to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea
def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1)