support gpt-sovits v2

support gpt-sovits v2
This commit is contained in:
RVC-Boss 2024-08-02 15:48:22 +08:00 committed by GitHub
parent af2d119573
commit 0d037f8915

View File

@ -1,5 +1,7 @@
import copy
import math
import os
import torch
from torch import nn
from torch.nn import functional as F
@ -879,9 +881,11 @@ class SynthesizerTrn(nn.Module):
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
self.ref_enc = modules.MelStyleEncoder(
spec_channels, style_vector_dim=gin_channels
)
self.version=os.environ.get("version","v1")
if(self.version=="v1"):
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
else:
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
ssl_dim = 768
assert semantic_frame_rate in ["25hz", "50hz"]
@ -893,20 +897,15 @@ class SynthesizerTrn(nn.Module):
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
self.freeze_quantizer = freeze_quantizer
# if freeze_quantizer:
# self.ssl_proj.requires_grad_(False)
# self.quantizer.requires_grad_(False)
#self.quantizer.eval()
# self.enc_p.text_embedding.requires_grad_(False)
# self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False)
def forward(self, ssl, y, y_lengths, text, text_lengths):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
ge = self.ref_enc(y * y_mask, y_mask)
if(self.version=="v1"):
ge = self.ref_enc(y * y_mask, y_mask)
else:
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
with autocast(enabled=False):
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad:
@ -947,7 +946,10 @@ class SynthesizerTrn(nn.Module):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
ge = self.ref_enc(y * y_mask, y_mask)
if(self.version=="v1"):
ge = self.ref_enc(y * y_mask, y_mask)
else:
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
@ -974,7 +976,10 @@ class SynthesizerTrn(nn.Module):
refer_mask = torch.unsqueeze(
commons.sequence_mask(refer_lengths, refer.size(2)), 1
).to(refer.dtype)
ge = self.ref_enc(refer * refer_mask, refer_mask)
if (self.version == "v1"):
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)