diff --git a/.gitignore b/.gitignore index 6f846a9..28b8a7a 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,8 @@ reference GPT_weights SoVITS_weights TEMP +PortableGit ffmpeg.exe ffprobe.exe - +tmp_audio +trained diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index c5f227c..b875151 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -249,8 +249,6 @@ class TTS: if self.configs.is_half and str(self.configs.device)!="cpu": self.bert_model = self.bert_model.half() - - def init_vits_weights(self, weights_path: str): print(f"Loading VITS weights from {weights_path}") self.configs.vits_weights_path = weights_path @@ -437,7 +435,8 @@ class TTS: device:torch.device=torch.device("cpu"), precision:torch.dtype=torch.float32, ): - + # 但是这里不能套,反而会负优化 + # with torch.no_grad(): _data:list = [] index_and_len_list = [] for idx, item in enumerate(data): @@ -485,6 +484,8 @@ class TTS: norm_text_batch = [] bert_max_len = 0 phones_max_len = 0 + # 但是这里也不能套,反而会负优化 + # with torch.no_grad(): for item in item_list: if prompt_data is not None: all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\ @@ -567,8 +568,9 @@ class TTS: Stop the inference process. ''' self.stop_flag = True - + # 使用装饰器 + @torch.no_grad() def run(self, inputs:dict): """ Text to speech inference. @@ -616,25 +618,25 @@ class TTS: seed = inputs.get("seed", -1) seed = -1 if seed in ["", None] else seed actual_seed = set_seed(seed) - + if return_fragment: # split_bucket = False print(i18n("分段返回模式已开启")) if split_bucket: split_bucket = False print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理")) - + if split_bucket: print(i18n("分桶处理模式已开启")) - + if fragment_interval<0.01: fragment_interval = 0.01 print(i18n("分段间隔过小,已自动设置为0.01")) - + no_prompt_text = False if prompt_text in [None, ""]: no_prompt_text = True - + assert text_lang in self.configs.languages if not no_prompt_text: assert prompt_lang in self.configs.languages @@ -643,11 +645,10 @@ class TTS: ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] is None)): raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()") - ###### setting reference audio and prompt text preprocessing ######## t0 = ttime() if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]): - self.set_ref_audio(ref_audio_path) + self.set_ref_audio(ref_audio_path) if not no_prompt_text: prompt_text = prompt_text.strip("\n") @@ -663,8 +664,7 @@ class TTS: self.prompt_cache["phones"] = phones self.prompt_cache["bert_features"] = bert_features self.prompt_cache["norm_text"] = norm_text - - + ###### text preprocessing ######## t1 = ttime() data:list = None @@ -674,7 +674,7 @@ class TTS: yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), dtype=np.int16) return - + batch_index_list:list = None data, batch_index_list = self.to_batch(data, prompt_data=self.prompt_cache if not no_prompt_text else None, @@ -692,7 +692,7 @@ class TTS: if i%batch_size == 0: data.append([]) data[-1].append(texts[i]) - + def make_batch(batch_texts): batch_data = [] print(i18n("############ 提取文本Bert特征 ############")) @@ -717,7 +717,8 @@ class TTS: precision=self.precision ) return batch[0] - + + t2 = ttime() try: print("############ 推理 ############") @@ -731,40 +732,43 @@ class TTS: item = make_batch(item) if item is None: continue - + batch_phones:List[torch.LongTensor] = item["phones"] batch_phones_len:torch.LongTensor = item["phones_len"] all_phoneme_ids:List[torch.LongTensor] = item["all_phones"] all_phoneme_lens:torch.LongTensor = item["all_phones_len"] all_bert_features:List[torch.LongTensor] = item["all_bert_features"] norm_text:str = item["norm_text"] - + print(i18n("前端处理后的文本(每句):"), norm_text) if no_prompt_text : prompt = None else: prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) - - with torch.no_grad(): - pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( - all_phoneme_ids, - all_phoneme_lens, - prompt, - all_bert_features, - # prompt_phone_len=ph_offset, - top_k=top_k, - top_p=top_p, - temperature=temperature, - early_stop_num=self.configs.hz * self.configs.max_sec, - ) + + + pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_lens, + prompt, + all_bert_features, + # prompt_phone_len=ph_offset, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=self.configs.hz * self.configs.max_sec, + ) t4 = ttime() t_34 += t4 - t3 - + refer_audio_spec:torch.Tensor = self.prompt_cache["refer_spec"]\ .to(dtype=self.precision, device=self.configs.device) - + batch_audio_fragment = [] + # 这里要记得加 torch.no_grad() 不然速度慢一大截 + # with torch.no_grad(): + # ## vits并行推理 method 1 # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device) @@ -777,7 +781,7 @@ class TTS: # batch_audio_fragment = (self.vits_model.batched_decode( # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec # )) - + # ## vits并行推理 method 2 pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] upsample_rate = math.prod(self.vits_model.upsample_rates) @@ -790,7 +794,6 @@ class TTS: ).detach()[0, 0, :]) audio_frag_end_idx.insert(0, 0) batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))] - # ## vits串行推理 # for i, idx in enumerate(idx_list): @@ -802,7 +805,7 @@ class TTS: # batch_audio_fragment.append( # audio_fragment # ) ###试试重建不带上prompt部分 - + t5 = ttime() t_45 += t5 - t4 if return_fragment: @@ -816,7 +819,7 @@ class TTS: ) else: audio.append(batch_audio_fragment) - + if self.stop_flag: yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), dtype=np.int16)