export 支持语速设置

This commit is contained in:
csh 2024-10-25 23:33:27 +08:00
parent 81413894d0
commit 18f7e8260d

View File

@ -231,7 +231,7 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, text, ge):
def forward(self, y, text, ge, speed=1):
y_mask = torch.ones_like(y[:1,:1,:])
y = self.ssl_proj(y * y_mask) * y_mask
@ -244,6 +244,9 @@ class TextEncoder(nn.Module):
y = self.mrte(y, y_mask, text, text_mask, ge)
y = self.encoder2(y * y_mask, y_mask)
if(speed!=1):
y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
@ -887,7 +890,7 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False)
def forward(self, codes, text, refer):
def forward(self, codes, text, refer,noise_scale=0.5, speed=1):
refer_mask = torch.ones_like(refer[:1,:1,:])
if (self.version == "v1"):
ge = self.ref_enc(refer * refer_mask, refer_mask)
@ -900,10 +903,10 @@ class SynthesizerTrn(nn.Module):
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, text, ge
quantized, text, ge, speed
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)