feat:Added entry to get value "ge" of class SynthesizerTrn

This commit is contained in:
Kaning123 2026-03-19 17:39:55 +08:00
parent 86ac5555e1
commit 5450922d8d

View File

@ -989,10 +989,8 @@ class SynthesizerTrn(nn.Module):
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
def ge_(self, refer, sv_emb, InjectGE=False, GE=None, LoadGE=True):
def get_ge(refer, sv_emb):
ge = None
if refer is not None:
@ -1008,6 +1006,7 @@ class SynthesizerTrn(nn.Module):
ge = self.prelu(ge)
return ge
if LoadGE:
if type(refer) == list:
ges = []
for idx, _refer in enumerate(refer):
@ -1016,6 +1015,18 @@ class SynthesizerTrn(nn.Module):
ge = torch.stack(ges, 0).mean(0)
else:
ge = get_ge(refer, sv_emb)
else:
if InjectGE:
if type(GE) == list:
GE = torch.stack(GE, 0).mean(0)
ge = GE
else:
raise ValueError
return ge
@torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None,
InjectGE=False,GE=None,LoadGE=True):
ge = self.ge_(refer, sv_emb, InjectGE, GE, LoadGE)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)