diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 52402e9..5cd618e 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -3,7 +3,7 @@ import math
 import os, sys, gc
 import random
 import traceback
-
+import time
 import torchaudio
 from tqdm import tqdm
 now_dir = os.getcwd()
@@ -908,11 +908,14 @@ class TTS:
             split_bucket = False
             print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))

-        if split_bucket and speed_factor==1.0:
+        if split_bucket and speed_factor==1.0 and not (self.configs.is_v3_synthesizer and parallel_infer):
             print(i18n("分桶处理模式已开启"))
         elif speed_factor!=1.0:
             print(i18n("语速调节不支持分桶处理,已自动关闭分桶处理"))
             split_bucket = False
+        elif self.configs.is_v3_synthesizer and parallel_infer:
+            print(i18n("当开启并行推理模式时,SoVits V3模型不支持分桶处理,已自动关闭分桶处理"))
+            split_bucket = False
         else:
             print(i18n("分桶处理模式已关闭"))

@@ -936,7 +939,7 @@ class TTS:
             raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")

         ###### setting reference audio and prompt text preprocessing ########
-        t0 = ttime()
+        t0 = time.perf_counter()
         if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
             if not os.path.exists(ref_audio_path):
                 raise ValueError(f"{ref_audio_path} not exists")
@@ -975,7 +978,7 @@ class TTS:


         ###### text preprocessing ########
-        t1 = ttime()
+        t1 = time.perf_counter()
         data:list = None
         if not return_fragment:
             data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
@@ -1027,7 +1030,7 @@ class TTS:
                 return batch[0]


-        t2 = ttime()
+        t2 = time.perf_counter()
         try:
             print("############ 推理 ############")
             ###### inference ######
@@ -1036,7 +1039,7 @@ class TTS:
             audio = []
             output_sr = self.configs.sampling_rate if not self.configs.is_v3_synthesizer else 24000
             for item in data:
-                t3 = ttime()
+                t3 = time.perf_counter()
                 if return_fragment:
                     item = make_batch(item)
                     if item is None:
@@ -1071,7 +1074,7 @@ class TTS:
                         max_len=max_len,
                         repetition_penalty=repetition_penalty,
                     )
-                t4 = ttime()
+                t4 = time.perf_counter()
                 t_34 += t4 - t3

                 refer_audio_spec:torch.Tensor = [item.to(dtype=self.precision, device=self.configs.device) for item in self.prompt_cache["refer_spec"]]
@@ -1094,6 +1097,7 @@ class TTS:
                 print(f"############ {i18n('合成音频')} ############")
                 if not self.configs.is_v3_synthesizer:
                     if speed_factor == 1.0:
+                        print(f"{i18n('并行合成中')}...")
                         # ## vits并行推理 method 2
                         pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
                         upsample_rate = math.prod(self.vits_model.upsample_rates)
@@ -1118,17 +1122,28 @@ class TTS:
                             audio_fragment
                         )  ###试试重建不带上prompt部分
                     else:
-                        for i, idx in enumerate(tqdm(idx_list)):
-                            phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
-                            _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))    # .unsqueeze(0)#mq要多unsqueeze一次
-                            audio_fragment = self.v3_synthesis(
-                                _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
-                            )
-                            batch_audio_fragment.append(
-                                audio_fragment
-                            )
+                        if parallel_infer:
+                            print(f"{i18n('并行合成中')}...")
+                            audio_fragments = self.v3_synthesis_batched_infer(
+                                idx_list,
+                                pred_semantic_list,
+                                batch_phones,
+                                speed=speed_factor,
+                                sample_steps=sample_steps
+                            )
+                            batch_audio_fragment.extend(audio_fragments)
+                        else:
+                            for i, idx in enumerate(tqdm(idx_list)):
+                                phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
+                                _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))    # .unsqueeze(0)#mq要多unsqueeze一次
+                                audio_fragment = self.v3_synthesis(
+                                    _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
+                                )
+                                batch_audio_fragment.append(
+                                    audio_fragment
+                                )

-                t5 = ttime()
+                t5 = time.perf_counter()
                 t_45 += t5 - t4
                 if return_fragment:
                     print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
@@ -1219,13 +1234,13 @@ class TTS:

             if super_sampling:
                 print(f"############ {i18n('音频超采样')} ############")
-                t1 = ttime()
+                t1 = time.perf_counter()
                 self.init_sr_model()
                 if not self.sr_model_not_exist:
                     audio,sr=self.sr_model(audio.unsqueeze(0),sr)
                     max_audio=np.abs(audio).max()
                     if max_audio > 1: audio /= max_audio
-                t2 = ttime()
+                t2 = time.perf_counter()
                 print(f"超采样用时:{t2-t1:.3f}s")
             else:
                 audio = audio.cpu().numpy()
@@ -1260,7 +1275,7 @@ class TTS:
             ref_audio = ref_audio.mean(0).unsqueeze(0)
         if ref_sr!=24000:
             ref_audio=resample(ref_audio, ref_sr, self.configs.device)
-        # print("ref_audio",ref_audio.abs().mean())
+
         mel2 = mel_fn(ref_audio)
         mel2 = norm_spec(mel2)
         T_min = min(mel2.shape[2], fea_ref.shape[2])
@@ -1285,15 +1300,156 @@ class TTS:
             cfm_res = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
             cfm_res = cfm_res[:, :, mel2.shape[2]:]
-            mel2 = cfm_res[:, :, -T_min:]
+            mel2 = cfm_res[:, :, -T_min:]
             fea_ref = fea_todo_chunk[:, :, -T_min:]
+
             cfm_resss.append(cfm_res)
-        cmf_res = torch.cat(cfm_resss, 2)
-        cmf_res = denorm_spec(cmf_res)
+        cfm_res = torch.cat(cfm_resss, 2)
+        cfm_res = denorm_spec(cfm_res)
+
         with torch.inference_mode():
-            wav_gen = self.bigvgan_model(cmf_res)
+            wav_gen = self.bigvgan_model(cfm_res)
             audio=wav_gen[0][0]#.cpu().detach().numpy()
         return audio
+
+
+
+    def v3_synthesis_batched_infer(self,
+                        idx_list:List[int],
+                        semantic_tokens_list:List[torch.Tensor],
+                        batch_phones:List[torch.Tensor],
+                        speed:float=1.0,
+                        sample_steps:int=32
+                        )->List[torch.Tensor]:
+
+        prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
+        prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
+        refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
+
+        fea_ref,ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
+        ref_audio:torch.Tensor = self.prompt_cache["raw_audio"]
+        ref_sr = self.prompt_cache["raw_sr"]
+        ref_audio=ref_audio.to(self.configs.device).float()
+        if (ref_audio.shape[0] == 2):
+            ref_audio = ref_audio.mean(0).unsqueeze(0)
+        if ref_sr!=24000:
+            ref_audio=resample(ref_audio, ref_sr, self.configs.device)
+
+        mel2 = mel_fn(ref_audio)
+        mel2 = norm_spec(mel2)
+        T_min = min(mel2.shape[2], fea_ref.shape[2])
+        mel2 = mel2[:, :, :T_min]
+        fea_ref = fea_ref[:, :, :T_min]
+        if (T_min > 468):
+            mel2 = mel2[:, :, -468:]
+            fea_ref = fea_ref[:, :, -468:]
+            T_min = 468
+        chunk_len = 934 - T_min
+
+        mel2=mel2.to(self.precision)
+
+
+        # #### batched inference
+        overlapped_len = 12
+        feat_chunks = []
+        feat_lens = []
+        feat_list = []
+
+        for i, idx in enumerate(idx_list):
+            phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
+            semantic_tokens = semantic_tokens_list[i][-idx:].unsqueeze(0).unsqueeze(0)    # .unsqueeze(0)#mq要多unsqueeze一次
+            feat, _ = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
+            feat_list.append(feat)
+            feat_lens.append(feat.shape[2])
+
+        feats = torch.cat(feat_list, 2)
+        feats_padded = F.pad(feats, (overlapped_len,0), "constant", 0)
+        pos = 0
+        padding_len = 0
+        while True:
+            if pos ==0:
+                chunk = feats_padded[:, :, pos:pos + chunk_len]
+            else:
+                pos = pos - overlapped_len
+                chunk = feats_padded[:, :, pos:pos + chunk_len]
+            pos += chunk_len
+            if (chunk.shape[-1] == 0): break
+
+            # padding for the last chunk
+            padding_len = chunk_len - chunk.shape[2]
+            if padding_len != 0:
+                chunk = F.pad(chunk, (0,padding_len), "constant", 0)
+            feat_chunks.append(chunk)
+
+
+        feat_chunks = torch.cat(feat_chunks, 0)
+        bs = feat_chunks.shape[0]
+        fea_ref = fea_ref.repeat(bs,1,1)
+        fea = torch.cat([fea_ref, feat_chunks], 2).transpose(2, 1)
+        pred_spec = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
+        pred_spec = pred_spec[:, :, -chunk_len:]
+        dd = pred_spec.shape[1]
+        pred_spec = pred_spec.permute(1, 0, 2).contiguous().view(dd, -1).unsqueeze(0)
+        # pred_spec = pred_spec[..., :-padding_len]
+
+
+        pred_spec = denorm_spec(pred_spec)
+
+        with torch.no_grad():
+            wav_gen = self.bigvgan_model(pred_spec)
+            audio = wav_gen[0][0]#.cpu().detach().numpy()
+
+
+        audio_fragments = []
+        upsample_rate = 256
+        pos = 0
+
+        while pos < audio.shape[-1]:
+            audio_fragment = audio[pos:pos+chunk_len*upsample_rate]
+            audio_fragments.append(audio_fragment)
+            pos += chunk_len*upsample_rate
+
+        audio = self.sola_algorithm(audio_fragments, overlapped_len*upsample_rate)
+        audio = audio[overlapped_len*upsample_rate:-padding_len*upsample_rate]
+
+        audio_fragments = []
+        for feat_len in feat_lens:
+            audio_fragment = audio[:feat_len*upsample_rate]
+            audio_fragments.append(audio_fragment)
+            audio = audio[feat_len*upsample_rate:]
+
+
+        return audio_fragments
+
+
+
+    def sola_algorithm(self,
+                        audio_fragments:List[torch.Tensor],
+                        overlap_len:int,
+                        ):
+
+        for i in range(len(audio_fragments)-1):
+            f1 = audio_fragments[i]
+            f2 = audio_fragments[i+1]
+            w1 = f1[-overlap_len:]
+            w2 = f2[:overlap_len]
+            assert w1.shape == w2.shape
+            corr = F.conv1d(w1.view(1,1,-1), w2.view(1,1,-1),padding=w2.shape[-1]//2).view(-1)[:-1]
+            idx = corr.argmax()
+            f1_ = f1[:-(overlap_len-idx)]
+            audio_fragments[i] = f1_
+
+            f2_ = f2[idx:]
+            window = torch.hann_window((overlap_len-idx)*2, device=f1.device, dtype=f1.dtype)
+            f2_[:(overlap_len-idx)] = window[:(overlap_len-idx)]*f2_[:(overlap_len-idx)] + window[(overlap_len-idx):]*f1[-(overlap_len-idx):]
+            audio_fragments[i+1] = f2_
+
+        return torch.cat(audio_fragments, 0)
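
Note on the chunking scheme in v3_synthesis_batched_infer above: the per-sentence feature maps from decode_encp are concatenated along time, left-padded by overlapped_len frames, and sliced into fixed windows of chunk_len = 934 - T_min frames, where each window after the first re-reads the last overlapped_len frames of the previous one and the final window is zero-padded to chunk_len so all windows stack into one batch for a single cfm.inference call. A minimal standalone sketch of just that slicing step, assuming nothing from the repo (make_overlapped_chunks and the sizes are illustrative, not the module's API):

    import torch
    import torch.nn.functional as F

    def make_overlapped_chunks(feats: torch.Tensor, chunk_len: int, overlap: int):
        # feats: (1, C, T). Returns (N, C, chunk_len) stacked chunks and the zero-pad
        # length of the last chunk. Mirrors the while-loop in the diff: the sequence is
        # left-padded by `overlap` frames, and each later chunk starts `overlap` frames
        # before the previous one ended, so neighbouring chunks share a seam for SOLA.
        feats = F.pad(feats, (overlap, 0))
        chunks, pos, pad_len = [], 0, 0
        while pos < feats.shape[-1]:
            chunk = feats[:, :, pos:pos + chunk_len]
            pos += chunk_len - overlap          # net step per chunk, as in the diff
            pad_len = chunk_len - chunk.shape[-1]
            if pad_len > 0:                     # zero-pad the final, shorter chunk
                chunk = F.pad(chunk, (0, pad_len))
            chunks.append(chunk)
        return torch.cat(chunks, dim=0), pad_len

    chunks, pad_len = make_overlapped_chunks(torch.randn(1, 100, 2500), chunk_len=934, overlap=12)
    print(chunks.shape, pad_len)                # torch.Size([3, 100, 934]) 266

After the batched vocoder pass, the per-chunk waveforms are re-joined with sola_algorithm and the trailing padding_len * upsample_rate samples are cut off again, which is why the pad length of the last chunk is tracked.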
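
Note on sola_algorithm above: adjacent fragments share overlap_len samples; the method slides f2's head over f1's tail with F.conv1d (which does not flip its kernel, so this is a cross-correlation), takes the argmax as the best alignment offset, trims both fragments at that offset, and blends the remaining shared samples with the rising and falling halves of a Hann window. A self-contained sketch of that overlap-add step for two 1-D tensors; sola_join is a hypothetical helper written for illustration, not the method added in this diff, and it also short-circuits the case where the best offset leaves nothing to crossfade:

    import torch
    import torch.nn.functional as F

    def sola_join(f1: torch.Tensor, f2: torch.Tensor, overlap: int) -> torch.Tensor:
        # f1, f2: 1-D waveforms whose last/first `overlap` samples cover the same audio.
        w1, w2 = f1[-overlap:], f2[:overlap]
        # sliding correlation of f2's head over f1's tail; argmax = best alignment offset
        corr = F.conv1d(w1.view(1, 1, -1), w2.view(1, 1, -1), padding=overlap // 2).view(-1)
        lag = int(corr.argmax())
        fade = overlap - lag                    # number of samples to crossfade at the seam
        if fade <= 0:
            return torch.cat([f1, f2[lag:]])    # already aligned end-to-start, no fade needed
        head = f1[:-fade]                       # keep f1 up to the seam
        tail = f2[lag:].clone()                 # drop f2's misaligned head
        win = torch.hann_window(fade * 2, device=f1.device, dtype=f1.dtype)
        # rising half of the window on f2, falling half on f1, summed across the seam
        tail[:fade] = win[:fade] * tail[:fade] + win[fade:] * f1[-fade:]
        return torch.cat([head, tail])

    a, b = torch.randn(24000), torch.randn(24000)
    print(sola_join(a, b, overlap=12 * 256).shape)

In the diff, the overlap handed to sola_algorithm is overlapped_len * upsample_rate samples, i.e. the 12-frame feature overlap scaled by the hard-coded upsample_rate = 256 of the vocoder output.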