diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py
index 195a59f..92e6634 100644
--- a/GPT_SoVITS/module/models.py
+++ b/GPT_SoVITS/module/models.py
@@ -1,5 +1,7 @@
 import copy
 import math
+import os
+
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -879,9 +881,11 @@ class SynthesizerTrn(nn.Module):
             inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
         )
 
-        self.ref_enc = modules.MelStyleEncoder(
-            spec_channels, style_vector_dim=gin_channels
-        )
+        self.version=os.environ.get("version","v1")
+        if(self.version=="v1"):
+            self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
+        else:
+            self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
 
         ssl_dim = 768
         assert semantic_frame_rate in ["25hz", "50hz"]
@@ -893,20 +897,15 @@ class SynthesizerTrn(nn.Module):
         self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
         self.freeze_quantizer = freeze_quantizer
-        # if freeze_quantizer:
-        #     self.ssl_proj.requires_grad_(False)
-        #     self.quantizer.requires_grad_(False)
-            #self.quantizer.eval()
-            # self.enc_p.text_embedding.requires_grad_(False)
-            # self.enc_p.encoder_text.requires_grad_(False)
-            # self.enc_p.mrte.requires_grad_(False)
 
 
     def forward(self, ssl, y, y_lengths, text, text_lengths):
         y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
             y.dtype
         )
-        ge = self.ref_enc(y * y_mask, y_mask)
-
+        if(self.version=="v1"):
+            ge = self.ref_enc(y * y_mask, y_mask)
+        else:
+            ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
         with autocast(enabled=False):
             maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
             with maybe_no_grad:
@@ -947,7 +946,10 @@ class SynthesizerTrn(nn.Module):
         y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
             y.dtype
         )
-        ge = self.ref_enc(y * y_mask, y_mask)
+        if(self.version=="v1"):
+            ge = self.ref_enc(y * y_mask, y_mask)
+        else:
+            ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
 
         ssl = self.ssl_proj(ssl)
         quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
@@ -974,7 +976,10 @@ class SynthesizerTrn(nn.Module):
         refer_mask = torch.unsqueeze(
             commons.sequence_mask(refer_lengths, refer.size(2)), 1
         ).to(refer.dtype)
-        ge = self.ref_enc(refer * refer_mask, refer_mask)
+        if (self.version == "v1"):
+            ge = self.ref_enc(refer * refer_mask, refer_mask)
+        else:
+            ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
 
         y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
         text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
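
Note (not part of the patch): the change above reads a "version" environment variable once in __init__ and uses it to size the reference encoder — under v1 the encoder keeps the full spec_channels input, while any other version builds MelStyleEncoder for 704 channels and slices the spectrogram to its first 704 bins at every call site, presumably so a larger v2 spectrogram still feeds a fixed-width style encoder. The sketch below is a minimal, self-contained illustration of that gating pattern only; DummyRefEnc is a hypothetical stand-in for modules.MelStyleEncoder, and the spec_channels/gin_channels values are assumptions for the example, not the repo's configuration.

import os

import torch
from torch import nn


class DummyRefEnc(nn.Module):
    """Hypothetical stand-in for modules.MelStyleEncoder: (B, C, T) -> (B, gin, 1)."""

    def __init__(self, in_channels: int, style_vector_dim: int):
        super().__init__()
        self.proj = nn.Conv1d(in_channels, style_vector_dim, kernel_size=1)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Project each frame, then take a masked mean over time.
        h = self.proj(x * mask)
        return h.sum(dim=2, keepdim=True) / mask.sum(dim=2, keepdim=True)


spec_channels, gin_channels = 1025, 512  # illustrative values only
version = os.environ.get("version", "v1")

# v1 consumes the full spectrogram; any other version is built for 704
# input channels, mirroring the constructor branch in the diff.
ref_enc = (
    DummyRefEnc(spec_channels, gin_channels)
    if version == "v1"
    else DummyRefEnc(704, gin_channels)
)

y = torch.randn(1, spec_channels, 100)  # (batch, channels, frames)
y_mask = torch.ones(1, 1, 100)          # all frames valid

# Mirrors the call-site branch: slice to the first 704 channels when not v1.
ge = ref_enc(y if version == "v1" else y[:, :704], y_mask)
print(ge.shape)  # torch.Size([1, 512, 1])

Caching self.version once at construction time, rather than re-reading os.environ in each method, is what keeps the three call sites (forward, the second y-masked path, and infer's refer path) consistently on the same branch for the model's lifetime.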