mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-04-06 03:57:44 +08:00)
support gpt-sovits v2
parent af2d119573
commit 0d037f8915
@@ -1,5 +1,7 @@
 import copy
 import math
+import os
+
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -879,9 +881,11 @@ class SynthesizerTrn(nn.Module):
             inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
         )
 
-        self.ref_enc = modules.MelStyleEncoder(
-            spec_channels, style_vector_dim=gin_channels
-        )
+        self.version=os.environ.get("version","v1")
+        if(self.version=="v1"):
+            self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
+        else:
+            self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
 
         ssl_dim = 768
         assert semantic_frame_rate in ["25hz", "50hz"]
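The hunk above is where the v2 switch enters the model: a "version" environment variable, read when SynthesizerTrn is constructed, decides whether the reference MelStyleEncoder is sized to spec_channels (v1) or to a fixed 704 input channels (v2). Below is a minimal, standalone sketch of that selection logic; the helper name ref_enc_input_dim and the example value 1025 are illustrative and not part of the repo.

import os

# Sketch of the version switch shown in the hunk above (not the real class):
# the reference encoder's input size depends on the "version" env var that is
# read when the model is built.
def ref_enc_input_dim(spec_channels: int) -> int:
    version = os.environ.get("version", "v1")
    if version == "v1":
        return spec_channels  # v1: encoder sized to the full spectrogram
    return 704                # v2: encoder expects a fixed 704-channel input


if __name__ == "__main__":
    os.environ["version"] = "v2"
    print(ref_enc_input_dim(1025))  # -> 704 once the v2 setting is active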
@@ -893,20 +897,15 @@ class SynthesizerTrn(nn.Module):
 
         self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
         self.freeze_quantizer = freeze_quantizer
-        # if freeze_quantizer:
-        #     self.ssl_proj.requires_grad_(False)
-        #     self.quantizer.requires_grad_(False)
-        #     #self.quantizer.eval()
-        #     # self.enc_p.text_embedding.requires_grad_(False)
-        #     # self.enc_p.encoder_text.requires_grad_(False)
-        #     # self.enc_p.mrte.requires_grad_(False)
 
     def forward(self, ssl, y, y_lengths, text, text_lengths):
         y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
             y.dtype
         )
-        ge = self.ref_enc(y * y_mask, y_mask)
-
+        if(self.version=="v1"):
+            ge = self.ref_enc(y * y_mask, y_mask)
+        else:
+            ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
         with autocast(enabled=False):
             maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
             with maybe_no_grad:
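Two things happen in this hunk: the old commented-out requires_grad_ block is dropped, and the reference embedding ge in forward now goes through the same v1/v2 branch as the constructor. The surviving context lines also show the conditional gradient guard used for a frozen quantizer: torch.no_grad() when freeze_quantizer is set, a do-nothing contextlib.nullcontext() otherwise. A self-contained sketch of that pattern, with a stand-in nn.Linear instead of the real ssl_proj/quantizer pair:

import contextlib

import torch
from torch import nn

# Stand-in for the sub-modules that may be frozen; illustrative only.
frozen_part = nn.Linear(8, 8)
freeze_quantizer = True

x = torch.randn(2, 8)

# Pick the context once, then run the forward pass under it: gradients are
# disabled only when the sub-module is meant to stay frozen.
maybe_no_grad = torch.no_grad() if freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad:
    y = frozen_part(x)

print(y.requires_grad)  # False while freeze_quantizer is True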
@@ -947,7 +946,10 @@ class SynthesizerTrn(nn.Module):
         y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
             y.dtype
         )
-        ge = self.ref_enc(y * y_mask, y_mask)
+        if(self.version=="v1"):
+            ge = self.ref_enc(y * y_mask, y_mask)
+        else:
+            ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
 
         ssl = self.ssl_proj(ssl)
         quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
@@ -974,7 +976,10 @@ class SynthesizerTrn(nn.Module):
         refer_mask = torch.unsqueeze(
             commons.sequence_mask(refer_lengths, refer.size(2)), 1
         ).to(refer.dtype)
-        ge = self.ref_enc(refer * refer_mask, refer_mask)
+        if (self.version == "v1"):
+            ge = self.ref_enc(refer * refer_mask, refer_mask)
+        else:
+            ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
 
         y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
         text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
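In the v2 branch every call site in this diff slices the reference input to its first 704 channels before the style encoder (y[:,:704] / refer[:,:704]), matching the fixed 704 input size the v2 MelStyleEncoder is built with in the constructor hunk. A small illustration of what that slice does to a (batch, channels, frames) tensor; the 1025-channel, 200-frame shape is made up for the example:

import torch

# Made-up reference spectrogram: batch of 1, 1025 channels, 200 frames.
refer = torch.randn(1, 1025, 200)
refer_mask = torch.ones(1, 1, 200)

# v2 path from the diff: keep only the first 704 channels, then apply the
# frame mask, before handing the result to the style encoder.
refer_v2 = refer[:, :704] * refer_mask

print(refer_v2.shape)  # torch.Size([1, 704, 200])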