support timebre mixing

support timebre mixing
This commit is contained in:
RVC-Boss 2024-08-07 11:28:05 +08:00 committed by GitHub
parent 20ef716431
commit 2b142405b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,8 +1,7 @@
import warnings
warnings.filterwarnings("ignore")
import copy import copy
import math import math
import os import os
import pdb
import torch import torch
from torch import nn from torch import nn
@ -984,16 +983,26 @@ class SynthesizerTrn(nn.Module):
@torch.no_grad() @torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5,speed=1): def decode(self, codes, text, refer, noise_scale=0.5,speed=1):
ge = None def get_ge(refer):
if refer is not None: ge = None
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) if refer is not None:
refer_mask = torch.unsqueeze( refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
commons.sequence_mask(refer_lengths, refer.size(2)), 1 refer_mask = torch.unsqueeze(
).to(refer.dtype) commons.sequence_mask(refer_lengths, refer.size(2)), 1
if (self.version == "v1"): ).to(refer.dtype)
ge = self.ref_enc(refer * refer_mask, refer_mask) if (self.version == "v1"):
else: ge = self.ref_enc(refer * refer_mask, refer_mask)
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
return ge
if(type(refer)==list):
ges=[]
for _refer in refer:
ge=get_ge(_refer)
ges.append(ge)
ge=torch.stack(ges,0).mean(0)
else:
ge=get_ge(refer)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)