From 2b142405b81fd2b57b49b0d5a7682b6e46e1db36 Mon Sep 17 00:00:00 2001 From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Date: Wed, 7 Aug 2024 11:28:05 +0800 Subject: [PATCH] support timbre mixing support timbre mixing --- GPT_SoVITS/module/models.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index ec7407e..6bfee08 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -1,8 +1,7 @@ -import warnings -warnings.filterwarnings("ignore") import copy import math import os +import pdb import torch from torch import nn @@ -984,16 +983,26 @@ class SynthesizerTrn(nn.Module): @torch.no_grad() def decode(self, codes, text, refer, noise_scale=0.5,speed=1): - ge = None - if refer is not None: - refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) - refer_mask = torch.unsqueeze( - commons.sequence_mask(refer_lengths, refer.size(2)), 1 - ).to(refer.dtype) - if (self.version == "v1"): - ge = self.ref_enc(refer * refer_mask, refer_mask) - else: - ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) + def get_ge(refer): + ge = None + if refer is not None: + refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) + refer_mask = torch.unsqueeze( + commons.sequence_mask(refer_lengths, refer.size(2)), 1 + ).to(refer.dtype) + if (self.version == "v1"): + ge = self.ref_enc(refer * refer_mask, refer_mask) + else: + ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) + return ge + if(type(refer)==list): + ges=[] + for _refer in refer: + ge=get_ge(_refer) + ges.append(ge) + ge=torch.stack(ges,0).mean(0) + else: + ge=get_ge(refer) y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)