From 5450922d8d03063ce4267b6342897cfbab82cfd5 Mon Sep 17 00:00:00 2001 From: Kaning123 Date: Thu, 19 Mar 2026 17:39:55 +0800 Subject: [PATCH] feat:Added entry to get value "ge" of class SynthesizerTrn --- GPT_SoVITS/module/models.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 348ddb3f..9b47ef90 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -989,10 +989,8 @@ class SynthesizerTrn(nn.Module): o = self.dec((z * y_mask)[:, :, :], g=ge) return o, y_mask, (z, z_p, m_p, logs_p) - - @torch.no_grad() - def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None): + def ge_(self, refer, sv_emb, InjectGE=False, GE=None, LoadGE=True): def get_ge(refer, sv_emb): ge = None if refer is not None: @@ -1007,15 +1005,28 @@ class SynthesizerTrn(nn.Module): ge += sv_emb.unsqueeze(-1) ge = self.prelu(ge) return ge - - if type(refer) == list: - ges = [] - for idx, _refer in enumerate(refer): - ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None) - ges.append(ge) - ge = torch.stack(ges, 0).mean(0) + + if LoadGE: + if type(refer) == list: + ges = [] + for idx, _refer in enumerate(refer): + ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None) + ges.append(ge) + ge = torch.stack(ges, 0).mean(0) + else: + ge = get_ge(refer, sv_emb) else: - ge = get_ge(refer, sv_emb) + if InjectGE: + if type(GE) == list: + GE = torch.stack(GE, 0).mean(0) + ge = GE + else: + raise ValueError + return ge + @torch.no_grad() + def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, + InjectGE=False,GE=None,LoadGE=True): + ge = self.ge_(refer, sv_emb, InjectGE, GE, LoadGE) y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)