support gpt-sovits v4

support gpt-sovits v4
This commit is contained in:
RVC-Boss 2025-04-20 14:53:42 +08:00 committed by GitHub
parent c6cb6b45f3
commit 50e9ba0218
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 255 additions and 62 deletions

View File

@ -470,6 +470,216 @@ class TextAudioSpeakerCollateV3:
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths # return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths
class TextAudioSpeakerLoaderV4(torch.utils.data.Dataset):
"""
1) loads audio, speaker_id, text pairs
2) normalizes text and converts them to sequences of integers
3) computes spectrograms from audio files.
"""
def __init__(self, hparams, val=False):
exp_dir = hparams.exp_dir
self.path2 = "%s/2-name2text.txt" % exp_dir
self.path4 = "%s/4-cnhubert" % exp_dir
self.path5 = "%s/5-wav32k" % exp_dir
assert os.path.exists(self.path2)
assert os.path.exists(self.path4)
assert os.path.exists(self.path5)
names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀
names5 = set(os.listdir(self.path5))
self.phoneme_data = {}
with open(self.path2, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines:
tmp = line.split("\t")
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1]]
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
tmp = self.audiopaths_sid_text
leng = len(tmp)
min_num = 100
if leng < min_num:
self.audiopaths_sid_text = []
for _ in range(max(2, int(min_num / leng))):
self.audiopaths_sid_text += tmp
self.max_wav_value = hparams.max_wav_value
self.sampling_rate = hparams.sampling_rate
self.filter_length = hparams.filter_length
self.hop_length = hparams.hop_length
self.win_length = hparams.win_length
self.sampling_rate = hparams.sampling_rate
self.val = val
random.seed(1234)
random.shuffle(self.audiopaths_sid_text)
print("phoneme_data_len:", len(self.phoneme_data.keys()))
print("wav_data_len:", len(self.audiopaths_sid_text))
audiopaths_sid_text_new = []
lengths = []
skipped_phone = 0
skipped_dur = 0
for audiopath in tqdm(self.audiopaths_sid_text):
try:
phoneme = self.phoneme_data[audiopath][0]
phoneme = phoneme.split(" ")
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
except Exception:
print(f"{audiopath} not in self.phoneme_data !")
skipped_phone += 1
continue
size = os.path.getsize("%s/%s" % (self.path5, audiopath))
duration = size / self.sampling_rate / 2
if duration == 0:
print(f"Zero duration for {audiopath}, skipping...")
skipped_dur += 1
continue
if 54 > duration > 0.6 or self.val:
audiopaths_sid_text_new.append([audiopath, phoneme_ids])
lengths.append(size // (2 * self.hop_length))
else:
skipped_dur += 1
continue
print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
print("total left: ", len(audiopaths_sid_text_new))
assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size这里todo
self.audiopaths_sid_text = audiopaths_sid_text_new
self.lengths = lengths
self.spec_min = -12
self.spec_max = 2
self.filter_length_mel = self.win_length_mel = 1280
self.hop_length_mel = 320
self.n_mel_channels = 100
self.sampling_rate_mel = 32000
self.mel_fmin = 0
self.mel_fmax = None
def norm_spec(self, x):
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
def get_audio_text_speaker_pair(self, audiopath_sid_text):
audiopath, phoneme_ids = audiopath_sid_text
text = torch.FloatTensor(phoneme_ids)
try:
spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath))
with torch.no_grad():
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
if ssl.shape[-1] != spec.shape[-1]:
typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False
except:
traceback.print_exc()
mel = torch.zeros(100, 192)
# wav = torch.zeros(1, 96 * self.hop_length)
spec = torch.zeros(1025, 96)
ssl = torch.zeros(1, 768, 96)
text = text[-1:]
print("load audio or ssl error!!!!!!", audiopath)
return (ssl, spec, mel, text)
def get_audio(self, filename):
audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的不用再/32768
audio = torch.FloatTensor(audio_array) # /32768
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(
audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
)
spec = torch.squeeze(spec, 0)
spec1 = spectrogram_torch(audio_norm, 1280,32000, 320, 1280,center=False)
mel = spec_to_mel_torch(spec1, 1280, 100, 32000, 0, None)
mel = self.norm_spec(torch.squeeze(mel, 0))
return spec, mel
def get_sid(self, sid):
sid = torch.LongTensor([int(sid)])
return sid
def __getitem__(self, index):
# with torch.no_grad():
return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
def __len__(self):
return len(self.audiopaths_sid_text)
class TextAudioSpeakerCollateV4:
"""Zero-pads model inputs and targets"""
def __init__(self, return_ids=False):
self.return_ids = return_ids
def __call__(self, batch):
"""Collate's training batch from normalized text, audio and speaker identities
PARAMS
------
batch: [text_normalized, spec_normalized, wav_normalized, sid]
"""
# ssl, spec, wav,mel, text
# Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
# (ssl, spec,mel, text)
max_ssl_len = max([x[0].size(2) for x in batch])
max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
max_spec_len = max([x[1].size(1) for x in batch])
max_spec_len = int(2 * ((max_spec_len // 2) + 1))
# max_wav_len = max([x[2].size(1) for x in batch])
max_text_len = max([x[3].size(0) for x in batch])
ssl_lengths = torch.LongTensor(len(batch))
spec_lengths = torch.LongTensor(len(batch))
text_lengths = torch.LongTensor(len(batch))
# wav_lengths = torch.LongTensor(len(batch))
mel_lengths = torch.LongTensor(len(batch))
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_spec_len*2)
ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
text_padded = torch.LongTensor(len(batch), max_text_len)
# wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
spec_padded.zero_()
mel_padded.zero_()
ssl_padded.zero_()
text_padded.zero_()
# wav_padded.zero_()
for i in range(len(ids_sorted_decreasing)):
row = batch[ids_sorted_decreasing[i]]
# ssl, spec, wav,mel, text
ssl = row[0]
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
ssl_lengths[i] = ssl.size(2)
spec = row[1]
spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1)
# wav = row[2]
# wav_padded[i, :, :wav.size(1)] = wav
# wav_lengths[i] = wav.size(1)
mel = row[2]
mel_padded[i, :, : mel.size(1)] = mel
mel_lengths[i] = mel.size(1)
text = row[3]
text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0)
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths
class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset): class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
""" """

View File

@ -38,89 +38,72 @@ hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.0: if torch.min(y) < -1.2:
print("min value is ", torch.min(y)) print('min value is ', torch.min(y))
if torch.max(y) > 1.0: if torch.max(y) > 1.2:
print("max value is ", torch.max(y)) print('max value is ', torch.max(y))
global hann_window global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device) dtype_device = str(y.dtype) + '_' + str(y.device)
wnsize_dtype_device = str(win_size) + "_" + dtype_device # wnsize_dtype_device = str(win_size) + '_' + dtype_device
if wnsize_dtype_device not in hann_window: key = "%s-%s-%s-%s-%s" %(dtype_device,n_fft, sampling_rate, hop_size, win_size)
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) # if wnsize_dtype_device not in hann_window:
if key not in hann_window:
# hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
hann_window[key] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad( y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1) y = y.squeeze(1)
spec = torch.stft( # spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
y, spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[key],
n_fft, center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-8)
return spec return spec
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
global mel_basis global mel_basis
dtype_device = str(spec.dtype) + "_" + str(spec.device) dtype_device = str(spec.dtype) + '_' + str(spec.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device # fmax_dtype_device = str(fmax) + '_' + dtype_device
if fmax_dtype_device not in mel_basis: key = "%s-%s-%s-%s-%s-%s"%(dtype_device,n_fft, num_mels, sampling_rate, fmin, fmax)
# if fmax_dtype_device not in mel_basis:
if key not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) # mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec) mel_basis[key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
# spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = torch.matmul(mel_basis[key], spec)
spec = spectral_normalize_torch(spec) spec = spectral_normalize_torch(spec)
return spec return spec
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0: if torch.min(y) < -1.2:
print("min value is ", torch.min(y)) print('min value is ', torch.min(y))
if torch.max(y) > 1.0: if torch.max(y) > 1.2:
print("max value is ", torch.max(y)) print('max value is ', torch.max(y))
global mel_basis, hann_window global mel_basis, hann_window
dtype_device = str(y.dtype) + "_" + str(y.device) dtype_device = str(y.dtype) + '_' + str(y.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device # fmax_dtype_device = str(fmax) + '_' + dtype_device
wnsize_dtype_device = str(win_size) + "_" + dtype_device fmax_dtype_device = "%s-%s-%s-%s-%s-%s-%s-%s"%(dtype_device,n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax)
# wnsize_dtype_device = str(win_size) + '_' + dtype_device
wnsize_dtype_device = fmax_dtype_device
if fmax_dtype_device not in mel_basis: if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
if wnsize_dtype_device not in hann_window: if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad( y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1) y = y.squeeze(1)
spec = torch.stft( spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
y, center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-8)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec) spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec) spec = spectral_normalize_torch(spec)

View File

@ -414,7 +414,7 @@ class Generator(torch.nn.Module):
upsample_rates, upsample_rates,
upsample_initial_channel, upsample_initial_channel,
upsample_kernel_sizes, upsample_kernel_sizes,
gin_channels=0, gin_channels=0,is_bias=False,
): ):
super(Generator, self).__init__() super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes) self.num_kernels = len(resblock_kernel_sizes)
@ -442,7 +442,7 @@ class Generator(torch.nn.Module):
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d)) self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=is_bias)
self.ups.apply(init_weights) self.ups.apply(init_weights)
if gin_channels != 0: if gin_channels != 0:
@ -1173,7 +1173,7 @@ class SynthesizerTrnV3(nn.Module):
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea = self.bridge(x) fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT fea = F.interpolate(fea, scale_factor=(1.875 if self.version=="v3"else 2), mode="nearest") ##BCT
fea, y_mask_ = self.wns1( fea, y_mask_ = self.wns1(
fea, mel_lengths, ge fea, mel_lengths, ge
) ##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate. ) ##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
@ -1196,9 +1196,9 @@ class SynthesizerTrnV3(nn.Module):
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device) y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
if speed == 1: if speed == 1:
sizee = int(codes.size(2) * 2.5 * 1.5) sizee = int(codes.size(2) * (3.875 if self.version=="v3"else 4))
else: else:
sizee = int(codes.size(2) * 2.5 * 1.5 / speed) + 1 sizee = int(codes.size(2) * (3.875 if self.version=="v3"else 4) / speed) + 1
y_lengths1 = torch.LongTensor([sizee]).to(codes.device) y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
@ -1207,7 +1207,7 @@ class SynthesizerTrnV3(nn.Module):
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed) x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
fea = self.bridge(x) fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT fea = F.interpolate(fea, scale_factor=(1.875 if self.version=="v3"else 2), mode="nearest") ##BCT
####more wn paramter to learn mel ####more wn paramter to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge) fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea, ge return fea, ge