diff --git a/GPT_SoVITS/module/data_utils.py b/GPT_SoVITS/module/data_utils.py index 4a9a50c..1bda2b3 100644 --- a/GPT_SoVITS/module/data_utils.py +++ b/GPT_SoVITS/module/data_utils.py @@ -470,6 +470,216 @@ class TextAudioSpeakerCollateV3: # return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths +class TextAudioSpeakerLoaderV4(torch.utils.data.Dataset): + """ + 1) loads audio, speaker_id, text pairs + 2) normalizes text and converts them to sequences of integers + 3) computes spectrograms from audio files. + """ + + def __init__(self, hparams, val=False): + exp_dir = hparams.exp_dir + self.path2 = "%s/2-name2text.txt" % exp_dir + self.path4 = "%s/4-cnhubert" % exp_dir + self.path5 = "%s/5-wav32k" % exp_dir + assert os.path.exists(self.path2) + assert os.path.exists(self.path4) + assert os.path.exists(self.path5) + names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀 + names5 = set(os.listdir(self.path5)) + self.phoneme_data = {} + with open(self.path2, "r", encoding="utf8") as f: + lines = f.read().strip("\n").split("\n") + + for line in lines: + tmp = line.split("\t") + if len(tmp) != 4: + continue + self.phoneme_data[tmp[0]] = [tmp[1]] + + self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5) + tmp = self.audiopaths_sid_text + leng = len(tmp) + min_num = 100 + if leng < min_num: + self.audiopaths_sid_text = [] + for _ in range(max(2, int(min_num / leng))): + self.audiopaths_sid_text += tmp + self.max_wav_value = hparams.max_wav_value + self.sampling_rate = hparams.sampling_rate + self.filter_length = hparams.filter_length + self.hop_length = hparams.hop_length + self.win_length = hparams.win_length + self.sampling_rate = hparams.sampling_rate + self.val = val + + random.seed(1234) + random.shuffle(self.audiopaths_sid_text) + + print("phoneme_data_len:", len(self.phoneme_data.keys())) + print("wav_data_len:", len(self.audiopaths_sid_text)) + + audiopaths_sid_text_new = [] + lengths = [] + skipped_phone = 0 + skipped_dur = 0 + for audiopath in tqdm(self.audiopaths_sid_text): + try: + phoneme = self.phoneme_data[audiopath][0] + phoneme = phoneme.split(" ") + phoneme_ids = cleaned_text_to_sequence(phoneme, version) + except Exception: + print(f"{audiopath} not in self.phoneme_data !") + skipped_phone += 1 + continue + + size = os.path.getsize("%s/%s" % (self.path5, audiopath)) + duration = size / self.sampling_rate / 2 + + if duration == 0: + print(f"Zero duration for {audiopath}, skipping...") + skipped_dur += 1 + continue + + if 54 > duration > 0.6 or self.val: + audiopaths_sid_text_new.append([audiopath, phoneme_ids]) + lengths.append(size // (2 * self.hop_length)) + else: + skipped_dur += 1 + continue + + print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur) + print("total left: ", len(audiopaths_sid_text_new)) + assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size,这里todo + self.audiopaths_sid_text = audiopaths_sid_text_new + self.lengths = lengths + self.spec_min = -12 + self.spec_max = 2 + + self.filter_length_mel = self.win_length_mel = 1280 + self.hop_length_mel = 320 + self.n_mel_channels = 100 + self.sampling_rate_mel = 32000 + self.mel_fmin = 0 + self.mel_fmax = None + + def norm_spec(self, x): + return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 + + def get_audio_text_speaker_pair(self, audiopath_sid_text): + audiopath, phoneme_ids = audiopath_sid_text + text = torch.FloatTensor(phoneme_ids) + try: + spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath)) + with torch.no_grad(): + ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu") + if ssl.shape[-1] != spec.shape[-1]: + typee = ssl.dtype + ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee) + ssl.requires_grad = False + except: + traceback.print_exc() + mel = torch.zeros(100, 192) + # wav = torch.zeros(1, 96 * self.hop_length) + spec = torch.zeros(1025, 96) + ssl = torch.zeros(1, 768, 96) + text = text[-1:] + print("load audio or ssl error!!!!!!", audiopath) + return (ssl, spec, mel, text) + + def get_audio(self, filename): + audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768 + audio = torch.FloatTensor(audio_array) # /32768 + audio_norm = audio + audio_norm = audio_norm.unsqueeze(0) + spec = spectrogram_torch( + audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False + ) + spec = torch.squeeze(spec, 0) + spec1 = spectrogram_torch(audio_norm, 1280,32000, 320, 1280,center=False) + mel = spec_to_mel_torch(spec1, 1280, 100, 32000, 0, None) + mel = self.norm_spec(torch.squeeze(mel, 0)) + return spec, mel + + def get_sid(self, sid): + sid = torch.LongTensor([int(sid)]) + return sid + + def __getitem__(self, index): + # with torch.no_grad(): + return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index]) + + def __len__(self): + return len(self.audiopaths_sid_text) + + +class TextAudioSpeakerCollateV4: + """Zero-pads model inputs and targets""" + + def __init__(self, return_ids=False): + self.return_ids = return_ids + + def __call__(self, batch): + """Collate's training batch from normalized text, audio and speaker identities + PARAMS + ------ + batch: [text_normalized, spec_normalized, wav_normalized, sid] + """ + # ssl, spec, wav,mel, text + # Right zero-pad all one-hot text sequences to max input length + _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True) + # (ssl, spec,mel, text) + max_ssl_len = max([x[0].size(2) for x in batch]) + max_ssl_len = int(2 * ((max_ssl_len // 2) + 1)) + max_spec_len = max([x[1].size(1) for x in batch]) + max_spec_len = int(2 * ((max_spec_len // 2) + 1)) + # max_wav_len = max([x[2].size(1) for x in batch]) + max_text_len = max([x[3].size(0) for x in batch]) + + ssl_lengths = torch.LongTensor(len(batch)) + spec_lengths = torch.LongTensor(len(batch)) + text_lengths = torch.LongTensor(len(batch)) + # wav_lengths = torch.LongTensor(len(batch)) + mel_lengths = torch.LongTensor(len(batch)) + + spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) + mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_spec_len*2) + ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len) + text_padded = torch.LongTensor(len(batch), max_text_len) + # wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) + + spec_padded.zero_() + mel_padded.zero_() + ssl_padded.zero_() + text_padded.zero_() + # wav_padded.zero_() + + for i in range(len(ids_sorted_decreasing)): + row = batch[ids_sorted_decreasing[i]] + # ssl, spec, wav,mel, text + ssl = row[0] + ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :] + ssl_lengths[i] = ssl.size(2) + + spec = row[1] + spec_padded[i, :, : spec.size(1)] = spec + spec_lengths[i] = spec.size(1) + + # wav = row[2] + # wav_padded[i, :, :wav.size(1)] = wav + # wav_lengths[i] = wav.size(1) + + mel = row[2] + mel_padded[i, :, : mel.size(1)] = mel + mel_lengths[i] = mel.size(1) + + text = row[3] + text_padded[i, : text.size(0)] = text + text_lengths[i] = text.size(0) + + # return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths + return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths + class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset): """ diff --git a/GPT_SoVITS/module/mel_processing.py b/GPT_SoVITS/module/mel_processing.py index 7718b4a..7a17c54 100644 --- a/GPT_SoVITS/module/mel_processing.py +++ b/GPT_SoVITS/module/mel_processing.py @@ -38,89 +38,72 @@ hann_window = {} def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): - if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) - if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + if torch.min(y) < -1.2: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.2: + print('max value is ', torch.max(y)) global hann_window - dtype_device = str(y.dtype) + "_" + str(y.device) - wnsize_dtype_device = str(win_size) + "_" + dtype_device - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + dtype_device = str(y.dtype) + '_' + str(y.device) + # wnsize_dtype_device = str(win_size) + '_' + dtype_device + key = "%s-%s-%s-%s-%s" %(dtype_device,n_fft, sampling_rate, hop_size, win_size) + # if wnsize_dtype_device not in hann_window: + if key not in hann_window: + # hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + hann_window[key] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), - mode="reflect", - ) + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') y = y.squeeze(1) - spec = torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, - ) + # spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[key], + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-8) return spec def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): global mel_basis - dtype_device = str(spec.dtype) + "_" + str(spec.device) - fmax_dtype_device = str(fmax) + "_" + dtype_device - if fmax_dtype_device not in mel_basis: + dtype_device = str(spec.dtype) + '_' + str(spec.device) + # fmax_dtype_device = str(fmax) + '_' + dtype_device + key = "%s-%s-%s-%s-%s-%s"%(dtype_device,n_fft, num_mels, sampling_rate, fmin, fmax) + # if fmax_dtype_device not in mel_basis: + if key not in mel_basis: mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) - spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + # mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel_basis[key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + # spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = torch.matmul(mel_basis[key], spec) spec = spectral_normalize_torch(spec) return spec + def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): - if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) - if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + if torch.min(y) < -1.2: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.2: + print('max value is ', torch.max(y)) global mel_basis, hann_window - dtype_device = str(y.dtype) + "_" + str(y.device) - fmax_dtype_device = str(fmax) + "_" + dtype_device - wnsize_dtype_device = str(win_size) + "_" + dtype_device + dtype_device = str(y.dtype) + '_' + str(y.device) + # fmax_dtype_device = str(fmax) + '_' + dtype_device + fmax_dtype_device = "%s-%s-%s-%s-%s-%s-%s-%s"%(dtype_device,n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax) + # wnsize_dtype_device = str(win_size) + '_' + dtype_device + wnsize_dtype_device = fmax_dtype_device if fmax_dtype_device not in mel_basis: mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) if wnsize_dtype_device not in hann_window: hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), - mode="reflect", - ) + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') y = y.squeeze(1) - spec = torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, - ) + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-8) spec = torch.matmul(mel_basis[fmax_dtype_device], spec) spec = spectral_normalize_torch(spec) diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index aac520a..21f60d9 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -414,7 +414,7 @@ class Generator(torch.nn.Module): upsample_rates, upsample_initial_channel, upsample_kernel_sizes, - gin_channels=0, + gin_channels=0,is_bias=False, ): super(Generator, self).__init__() self.num_kernels = len(resblock_kernel_sizes) @@ -442,7 +442,7 @@ class Generator(torch.nn.Module): for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): self.resblocks.append(resblock(ch, k, d)) - self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=is_bias) self.ups.apply(init_weights) if gin_channels != 0: @@ -1173,7 +1173,7 @@ class SynthesizerTrnV3(nn.Module): quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) fea = self.bridge(x) - fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT + fea = F.interpolate(fea, scale_factor=(1.875 if self.version=="v3"else 2), mode="nearest") ##BCT fea, y_mask_ = self.wns1( fea, mel_lengths, ge ) ##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate. @@ -1196,9 +1196,9 @@ class SynthesizerTrnV3(nn.Module): ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device) if speed == 1: - sizee = int(codes.size(2) * 2.5 * 1.5) + sizee = int(codes.size(2) * (3.875 if self.version=="v3"else 4)) else: - sizee = int(codes.size(2) * 2.5 * 1.5 / speed) + 1 + sizee = int(codes.size(2) * (3.875 if self.version=="v3"else 4) / speed) + 1 y_lengths1 = torch.LongTensor([sizee]).to(codes.device) text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) @@ -1207,7 +1207,7 @@ class SynthesizerTrnV3(nn.Module): quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed) fea = self.bridge(x) - fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT + fea = F.interpolate(fea, scale_factor=(1.875 if self.version=="v3"else 2), mode="nearest") ##BCT ####more wn paramter to learn mel fea, y_mask_ = self.wns1(fea, y_lengths1, ge) return fea, ge