mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-28 04:49:01 +08:00
support gpt-sovits v2
support gpt-sovits v2
This commit is contained in:
parent
af2d119573
commit
0d037f8915
@ -1,5 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
@ -879,9 +881,11 @@ class SynthesizerTrn(nn.Module):
|
|||||||
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
|
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
|
||||||
)
|
)
|
||||||
|
|
||||||
self.ref_enc = modules.MelStyleEncoder(
|
self.version=os.environ.get("version","v1")
|
||||||
spec_channels, style_vector_dim=gin_channels
|
if(self.version=="v1"):
|
||||||
)
|
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
|
||||||
|
else:
|
||||||
|
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
|
||||||
|
|
||||||
ssl_dim = 768
|
ssl_dim = 768
|
||||||
assert semantic_frame_rate in ["25hz", "50hz"]
|
assert semantic_frame_rate in ["25hz", "50hz"]
|
||||||
@ -893,20 +897,15 @@ class SynthesizerTrn(nn.Module):
|
|||||||
|
|
||||||
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
||||||
self.freeze_quantizer = freeze_quantizer
|
self.freeze_quantizer = freeze_quantizer
|
||||||
# if freeze_quantizer:
|
|
||||||
# self.ssl_proj.requires_grad_(False)
|
|
||||||
# self.quantizer.requires_grad_(False)
|
|
||||||
#self.quantizer.eval()
|
|
||||||
# self.enc_p.text_embedding.requires_grad_(False)
|
|
||||||
# self.enc_p.encoder_text.requires_grad_(False)
|
|
||||||
# self.enc_p.mrte.requires_grad_(False)
|
|
||||||
|
|
||||||
def forward(self, ssl, y, y_lengths, text, text_lengths):
|
def forward(self, ssl, y, y_lengths, text, text_lengths):
|
||||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
|
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
|
||||||
y.dtype
|
y.dtype
|
||||||
)
|
)
|
||||||
ge = self.ref_enc(y * y_mask, y_mask)
|
if(self.version=="v1"):
|
||||||
|
ge = self.ref_enc(y * y_mask, y_mask)
|
||||||
|
else:
|
||||||
|
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
|
||||||
with autocast(enabled=False):
|
with autocast(enabled=False):
|
||||||
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
|
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
|
||||||
with maybe_no_grad:
|
with maybe_no_grad:
|
||||||
@ -947,7 +946,10 @@ class SynthesizerTrn(nn.Module):
|
|||||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
|
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
|
||||||
y.dtype
|
y.dtype
|
||||||
)
|
)
|
||||||
ge = self.ref_enc(y * y_mask, y_mask)
|
if(self.version=="v1"):
|
||||||
|
ge = self.ref_enc(y * y_mask, y_mask)
|
||||||
|
else:
|
||||||
|
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
|
||||||
|
|
||||||
ssl = self.ssl_proj(ssl)
|
ssl = self.ssl_proj(ssl)
|
||||||
quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
|
quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
|
||||||
@ -974,7 +976,10 @@ class SynthesizerTrn(nn.Module):
|
|||||||
refer_mask = torch.unsqueeze(
|
refer_mask = torch.unsqueeze(
|
||||||
commons.sequence_mask(refer_lengths, refer.size(2)), 1
|
commons.sequence_mask(refer_lengths, refer.size(2)), 1
|
||||||
).to(refer.dtype)
|
).to(refer.dtype)
|
||||||
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
if (self.version == "v1"):
|
||||||
|
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||||
|
else:
|
||||||
|
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||||
|
|
||||||
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
||||||
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user