support timebre mixing

support timebre mixing
This commit is contained in:
RVC-Boss 2024-08-07 11:28:05 +08:00 committed by GitHub
parent 20ef716431
commit 2b142405b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,8 +1,7 @@
import warnings
warnings.filterwarnings("ignore")
import copy
import math
import os
import pdb
import torch
from torch import nn
@ -984,16 +983,26 @@ class SynthesizerTrn(nn.Module):
@torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5,speed=1):
ge = None
if refer is not None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(
commons.sequence_mask(refer_lengths, refer.size(2)), 1
).to(refer.dtype)
if (self.version == "v1"):
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
def get_ge(refer):
ge = None
if refer is not None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(
commons.sequence_mask(refer_lengths, refer.size(2)), 1
).to(refer.dtype)
if (self.version == "v1"):
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
return ge
if(type(refer)==list):
ges=[]
for _refer in refer:
ge=get_ge(_refer)
ges.append(ge)
ge=torch.stack(ges,0).mean(0)
else:
ge=get_ge(refer)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)