mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-12 22:29:50 +08:00
support sovits v2Pro v2ProPlus
support sovits v2Pro v2ProPlus
This commit is contained in:
parent
3f46359652
commit
0621259549
@ -21,7 +21,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||||||
3) computes spectrograms from audio files.
|
3) computes spectrograms from audio files.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hparams, val=False):
|
def __init__(self, hparams, version=None,val=False):
|
||||||
exp_dir = hparams.exp_dir
|
exp_dir = hparams.exp_dir
|
||||||
self.path2 = "%s/2-name2text.txt" % exp_dir
|
self.path2 = "%s/2-name2text.txt" % exp_dir
|
||||||
self.path4 = "%s/4-cnhubert" % exp_dir
|
self.path4 = "%s/4-cnhubert" % exp_dir
|
||||||
@ -29,8 +29,14 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||||||
assert os.path.exists(self.path2)
|
assert os.path.exists(self.path2)
|
||||||
assert os.path.exists(self.path4)
|
assert os.path.exists(self.path4)
|
||||||
assert os.path.exists(self.path5)
|
assert os.path.exists(self.path5)
|
||||||
|
self.is_v2Pro=version in {"v2Pro","v2ProPlus"}
|
||||||
|
if self.is_v2Pro:
|
||||||
|
self.path7 = "%s/7-sv_cn" % exp_dir
|
||||||
|
assert os.path.exists(self.path7)
|
||||||
names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀
|
names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀
|
||||||
names5 = set(os.listdir(self.path5))
|
names5 = set(os.listdir(self.path5))
|
||||||
|
if self.is_v2Pro:
|
||||||
|
names6 = set([name[:-3] for name in list(os.listdir(self.path7))]) # 去除.pt后缀
|
||||||
self.phoneme_data = {}
|
self.phoneme_data = {}
|
||||||
with open(self.path2, "r", encoding="utf8") as f:
|
with open(self.path2, "r", encoding="utf8") as f:
|
||||||
lines = f.read().strip("\n").split("\n")
|
lines = f.read().strip("\n").split("\n")
|
||||||
@ -40,8 +46,10 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||||||
if len(tmp) != 4:
|
if len(tmp) != 4:
|
||||||
continue
|
continue
|
||||||
self.phoneme_data[tmp[0]] = [tmp[1]]
|
self.phoneme_data[tmp[0]] = [tmp[1]]
|
||||||
|
if self.is_v2Pro:
|
||||||
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
|
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5 & names6)
|
||||||
|
else:
|
||||||
|
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
|
||||||
tmp = self.audiopaths_sid_text
|
tmp = self.audiopaths_sid_text
|
||||||
leng = len(tmp)
|
leng = len(tmp)
|
||||||
min_num = 100
|
min_num = 100
|
||||||
@ -109,14 +117,21 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||||||
typee = ssl.dtype
|
typee = ssl.dtype
|
||||||
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
|
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
|
||||||
ssl.requires_grad = False
|
ssl.requires_grad = False
|
||||||
|
if self.is_v2Pro:
|
||||||
|
sv_emb=torch.load("%s/%s.pt" % (self.path7, audiopath), map_location="cpu")
|
||||||
except:
|
except:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
spec = torch.zeros(1025, 100)
|
spec = torch.zeros(1025, 100)
|
||||||
wav = torch.zeros(1, 100 * self.hop_length)
|
wav = torch.zeros(1, 100 * self.hop_length)
|
||||||
ssl = torch.zeros(1, 768, 100)
|
ssl = torch.zeros(1, 768, 100)
|
||||||
text = text[-1:]
|
text = text[-1:]
|
||||||
|
if self.is_v2Pro:
|
||||||
|
sv_emb=torch.zeros(1,20480)
|
||||||
print("load audio or ssl error!!!!!!", audiopath)
|
print("load audio or ssl error!!!!!!", audiopath)
|
||||||
return (ssl, spec, wav, text)
|
if self.is_v2Pro:
|
||||||
|
return (ssl, spec, wav, text,sv_emb)
|
||||||
|
else:
|
||||||
|
return (ssl, spec, wav, text)
|
||||||
|
|
||||||
def get_audio(self, filename):
|
def get_audio(self, filename):
|
||||||
audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768
|
audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768
|
||||||
@ -177,8 +192,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||||||
class TextAudioSpeakerCollate:
|
class TextAudioSpeakerCollate:
|
||||||
"""Zero-pads model inputs and targets"""
|
"""Zero-pads model inputs and targets"""
|
||||||
|
|
||||||
def __init__(self, return_ids=False):
|
def __init__(self, return_ids=False,version=None):
|
||||||
self.return_ids = return_ids
|
self.return_ids = return_ids
|
||||||
|
self.is_v2Pro=version in {"v2Pro","v2ProPlus"}
|
||||||
|
|
||||||
def __call__(self, batch):
|
def __call__(self, batch):
|
||||||
"""Collate's training batch from normalized text, audio and speaker identities
|
"""Collate's training batch from normalized text, audio and speaker identities
|
||||||
@ -211,6 +227,9 @@ class TextAudioSpeakerCollate:
|
|||||||
ssl_padded.zero_()
|
ssl_padded.zero_()
|
||||||
text_padded.zero_()
|
text_padded.zero_()
|
||||||
|
|
||||||
|
if self.is_v2Pro:
|
||||||
|
sv_embs=torch.FloatTensor(len(batch),20480)
|
||||||
|
|
||||||
for i in range(len(ids_sorted_decreasing)):
|
for i in range(len(ids_sorted_decreasing)):
|
||||||
row = batch[ids_sorted_decreasing[i]]
|
row = batch[ids_sorted_decreasing[i]]
|
||||||
|
|
||||||
@ -230,7 +249,12 @@ class TextAudioSpeakerCollate:
|
|||||||
text_padded[i, : text.size(0)] = text
|
text_padded[i, : text.size(0)] = text
|
||||||
text_lengths[i] = text.size(0)
|
text_lengths[i] = text.size(0)
|
||||||
|
|
||||||
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
|
if self.is_v2Pro:
|
||||||
|
sv_embs[i]=row[4]
|
||||||
|
if self.is_v2Pro:
|
||||||
|
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths,sv_embs
|
||||||
|
else:
|
||||||
|
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
|
||||||
|
|
||||||
|
|
||||||
class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||||
|
@ -586,11 +586,12 @@ class DiscriminatorS(torch.nn.Module):
|
|||||||
|
|
||||||
return x, fmap
|
return x, fmap
|
||||||
|
|
||||||
|
v2pro_set={"v2Pro","v2ProPlus"}
|
||||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||||
def __init__(self, use_spectral_norm=False):
|
def __init__(self, use_spectral_norm=False,version=None):
|
||||||
super(MultiPeriodDiscriminator, self).__init__()
|
super(MultiPeriodDiscriminator, self).__init__()
|
||||||
periods = [2, 3, 5, 7, 11]
|
if version in v2pro_set:periods = [2, 3, 5, 7, 11,17,23]
|
||||||
|
else:periods = [2, 3, 5, 7, 11]
|
||||||
|
|
||||||
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
||||||
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
|
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
|
||||||
@ -786,7 +787,6 @@ class CodePredictor(nn.Module):
|
|||||||
|
|
||||||
return pred_codes.transpose(0, 1)
|
return pred_codes.transpose(0, 1)
|
||||||
|
|
||||||
|
|
||||||
class SynthesizerTrn(nn.Module):
|
class SynthesizerTrn(nn.Module):
|
||||||
"""
|
"""
|
||||||
Synthesizer for Training
|
Synthesizer for Training
|
||||||
@ -886,12 +886,23 @@ class SynthesizerTrn(nn.Module):
|
|||||||
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
||||||
self.freeze_quantizer = freeze_quantizer
|
self.freeze_quantizer = freeze_quantizer
|
||||||
|
|
||||||
def forward(self, ssl, y, y_lengths, text, text_lengths):
|
self.is_v2pro=self.version in v2pro_set
|
||||||
|
if self.is_v2pro:
|
||||||
|
self.sv_emb = nn.Linear(20480, gin_channels)
|
||||||
|
self.ge_to512 = nn.Linear(gin_channels, 512)
|
||||||
|
self.prelu = nn.PReLU(num_parameters=gin_channels)
|
||||||
|
|
||||||
|
def forward(self, ssl, y, y_lengths, text, text_lengths,sv_emb=None):
|
||||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
||||||
if self.version == "v1":
|
if self.version == "v1":
|
||||||
ge = self.ref_enc(y * y_mask, y_mask)
|
ge = self.ref_enc(y * y_mask, y_mask)
|
||||||
else:
|
else:
|
||||||
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
|
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
|
||||||
|
if self.is_v2pro:
|
||||||
|
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
|
||||||
|
ge += sv_emb.unsqueeze(-1)
|
||||||
|
ge = self.prelu(ge)
|
||||||
|
ge512 = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1)
|
||||||
with autocast(enabled=False):
|
with autocast(enabled=False):
|
||||||
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
|
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
|
||||||
with maybe_no_grad:
|
with maybe_no_grad:
|
||||||
@ -904,7 +915,7 @@ class SynthesizerTrn(nn.Module):
|
|||||||
if self.semantic_frame_rate == "25hz":
|
if self.semantic_frame_rate == "25hz":
|
||||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||||
|
|
||||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
|
||||||
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
|
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
|
||||||
z_p = self.flow(z, y_mask, g=ge)
|
z_p = self.flow(z, y_mask, g=ge)
|
||||||
|
|
||||||
@ -941,8 +952,8 @@ class SynthesizerTrn(nn.Module):
|
|||||||
return o, y_mask, (z, z_p, m_p, logs_p)
|
return o, y_mask, (z, z_p, m_p, logs_p)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def decode(self, codes, text, refer, noise_scale=0.5, speed=1):
|
def decode(self, codes, text, refer,noise_scale=0.5, speed=1, sv_emb=None):
|
||||||
def get_ge(refer):
|
def get_ge(refer, sv_emb):
|
||||||
ge = None
|
ge = None
|
||||||
if refer is not None:
|
if refer is not None:
|
||||||
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
|
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
|
||||||
@ -951,16 +962,20 @@ class SynthesizerTrn(nn.Module):
|
|||||||
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||||
else:
|
else:
|
||||||
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||||
|
if self.is_v2pro:
|
||||||
|
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
|
||||||
|
ge += sv_emb.unsqueeze(-1)
|
||||||
|
ge = self.prelu(ge)
|
||||||
return ge
|
return ge
|
||||||
|
|
||||||
if type(refer) == list:
|
if type(refer) == list:
|
||||||
ges = []
|
ges = []
|
||||||
for _refer in refer:
|
for idx,_refer in enumerate(refer):
|
||||||
ge = get_ge(_refer)
|
ge = get_ge(_refer, sv_emb[idx]if self.is_v2pro else None)
|
||||||
ges.append(ge)
|
ges.append(ge)
|
||||||
ge = torch.stack(ges, 0).mean(0)
|
ge = torch.stack(ges, 0).mean(0)
|
||||||
else:
|
else:
|
||||||
ge = get_ge(refer)
|
ge = get_ge(refer, sv_emb)
|
||||||
|
|
||||||
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
||||||
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
||||||
@ -968,7 +983,7 @@ class SynthesizerTrn(nn.Module):
|
|||||||
quantized = self.quantizer.decode(codes)
|
quantized = self.quantizer.decode(codes)
|
||||||
if self.semantic_frame_rate == "25hz":
|
if self.semantic_frame_rate == "25hz":
|
||||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
|
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, self.ge_to512(ge.transpose(2,1)).transpose(2,1)if self.is_v2pro else ge, speed)
|
||||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||||
|
|
||||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user