mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-05 22:08:15 +08:00
feat:Added entry to get value "ge" of class SynthesizerTrn
This commit is contained in:
parent
86ac5555e1
commit
5450922d8d
@ -989,10 +989,8 @@ class SynthesizerTrn(nn.Module):
|
||||
|
||||
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
||||
return o, y_mask, (z, z_p, m_p, logs_p)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
|
||||
def ge_(self, refer, sv_emb, InjectGE=False, GE=None, LoadGE=True):
|
||||
def get_ge(refer, sv_emb):
|
||||
ge = None
|
||||
if refer is not None:
|
||||
@ -1008,14 +1006,27 @@ class SynthesizerTrn(nn.Module):
|
||||
ge = self.prelu(ge)
|
||||
return ge
|
||||
|
||||
if type(refer) == list:
|
||||
ges = []
|
||||
for idx, _refer in enumerate(refer):
|
||||
ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
|
||||
ges.append(ge)
|
||||
ge = torch.stack(ges, 0).mean(0)
|
||||
if LoadGE:
|
||||
if type(refer) == list:
|
||||
ges = []
|
||||
for idx, _refer in enumerate(refer):
|
||||
ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
|
||||
ges.append(ge)
|
||||
ge = torch.stack(ges, 0).mean(0)
|
||||
else:
|
||||
ge = get_ge(refer, sv_emb)
|
||||
else:
|
||||
ge = get_ge(refer, sv_emb)
|
||||
if InjectGE:
|
||||
if type(GE) == list:
|
||||
GE = torch.stack(GE, 0).mean(0)
|
||||
ge = GE
|
||||
else:
|
||||
raise ValueError
|
||||
return ge
|
||||
@torch.no_grad()
|
||||
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None,
|
||||
InjectGE=False,GE=None,LoadGE=True):
|
||||
ge = self.ge_(refer, sv_emb, InjectGE, GE, LoadGE)
|
||||
|
||||
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
||||
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user