diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 338e88d..623da80 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -1251,7 +1251,7 @@ class SynthesizerTrnV3(nn.Module): return cfm_loss @torch.no_grad() - def decode_encp(self, codes,text, refer,ge=None): + def decode_encp(self, codes,text, refer,ge=None,speed=1): # print(2333333,refer.shape) # ge=None if(ge==None): @@ -1259,13 +1259,17 @@ class SynthesizerTrnV3(nn.Module): refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask) y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device) - y_lengths1 = torch.LongTensor([int(codes.size(2)*2.5*1.5)]).to(codes.device) + if speed==1: + sizee=int(codes.size(2)*2.5*1.5) + else: + sizee=int(codes.size(2)*2.5*1.5/speed)+1 + y_lengths1 = torch.LongTensor([sizee]).to(codes.device) text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == '25hz': quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge,speed) fea=self.bridge(x) fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT ####more wn paramter to learn mel