mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-11 19:09:51 +08:00
condition cache
This commit is contained in:
parent
8158a97909
commit
88a1f88891
@ -143,7 +143,9 @@ class DiT(nn.Module):
|
|||||||
drop_audio_cond=False, # cfg for cond audio
|
drop_audio_cond=False, # cfg for cond audio
|
||||||
drop_text=False, # cfg for text
|
drop_text=False, # cfg for text
|
||||||
# mask: bool["b n"] | None = None, # noqa: F722
|
# mask: bool["b n"] | None = None, # noqa: F722
|
||||||
|
infer=False, # bool
|
||||||
|
text_cache=None, # torch tensor as text_embed
|
||||||
|
dt_cache=None, # torch tensor as dt
|
||||||
):
|
):
|
||||||
|
|
||||||
x=x0.transpose(2,1)
|
x=x0.transpose(2,1)
|
||||||
@ -157,9 +159,16 @@ class DiT(nn.Module):
|
|||||||
|
|
||||||
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||||
t = self.time_embed(time)
|
t = self.time_embed(time)
|
||||||
|
if infer and dt_cache is not None:
|
||||||
|
dt = dt_cache
|
||||||
|
else:
|
||||||
dt = self.d_embed(dt_base_bootstrap)
|
dt = self.d_embed(dt_base_bootstrap)
|
||||||
t+=dt
|
t += dt
|
||||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)###need to change
|
|
||||||
|
if infer and text_cache is not None:
|
||||||
|
text_embed = text_cache
|
||||||
|
else:
|
||||||
|
text_embed = self.text_embed(text, seq_len, drop_text=drop_text) ###need to change
|
||||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||||
|
|
||||||
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||||
@ -179,4 +188,7 @@ class DiT(nn.Module):
|
|||||||
x = self.norm_out(x, t)
|
x = self.norm_out(x, t)
|
||||||
output = self.proj_out(x)
|
output = self.proj_out(x)
|
||||||
|
|
||||||
|
if infer:
|
||||||
|
return output, text_embed, dt
|
||||||
|
else:
|
||||||
return output
|
return output
|
@ -1059,6 +1059,7 @@ class SynthesizerTrn(nn.Module):
|
|||||||
ssl = self.ssl_proj(x)
|
ssl = self.ssl_proj(x)
|
||||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
||||||
return codes.transpose(0, 1)
|
return codes.transpose(0, 1)
|
||||||
|
|
||||||
class CFM(torch.nn.Module):
|
class CFM(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -1073,6 +1074,8 @@ class CFM(torch.nn.Module):
|
|||||||
|
|
||||||
self.criterion = torch.nn.MSELoss()
|
self.criterion = torch.nn.MSELoss()
|
||||||
|
|
||||||
|
self.use_conditioner_cache = True
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0):
|
def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0):
|
||||||
"""Forward diffusion"""
|
"""Forward diffusion"""
|
||||||
@ -1085,13 +1088,24 @@ class CFM(torch.nn.Module):
|
|||||||
mu=mu.transpose(2,1)
|
mu=mu.transpose(2,1)
|
||||||
t = 0
|
t = 0
|
||||||
d = 1 / n_timesteps
|
d = 1 / n_timesteps
|
||||||
|
text_cache = None
|
||||||
|
text_cfg_cache = None
|
||||||
|
dt_cache = None
|
||||||
|
d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
|
||||||
for j in range(n_timesteps):
|
for j in range(n_timesteps):
|
||||||
t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t
|
t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t
|
||||||
d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
|
# d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
|
||||||
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
|
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
|
||||||
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu, use_grad_ckpt=False,drop_audio_cond=False,drop_text=False).transpose(2, 1)
|
v_pred, text_emb, dt = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu, use_grad_ckpt=False,drop_audio_cond=False,drop_text=False, infer=True, text_cache=text_cache, dt_cache=dt_cache)
|
||||||
|
v_pred = v_pred.transpose(2, 1)
|
||||||
|
if self.use_conditioner_cache:
|
||||||
|
text_cache = text_emb
|
||||||
|
dt_cache = dt
|
||||||
if inference_cfg_rate>1e-5:
|
if inference_cfg_rate>1e-5:
|
||||||
neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
|
neg, text_cfg_emb, _ = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True, infer=True, text_cache=text_cfg_cache, dt_cache=dt_cache)
|
||||||
|
neg = neg.transpose(2, 1)
|
||||||
|
if self.use_conditioner_cache:
|
||||||
|
text_cfg_cache = text_cfg_emb
|
||||||
v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
|
v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
|
||||||
x = x + d * v_pred
|
x = x + d * v_pred
|
||||||
t = t + d
|
t = t + d
|
||||||
|
Loading…
x
Reference in New Issue
Block a user