From cae976ef5af5db171cc1795a765e179954290b57 Mon Sep 17 00:00:00 2001 From: chasonjiang <1440499136@qq.com> Date: Sun, 10 Mar 2024 01:57:04 +0800 Subject: [PATCH] =?UTF-8?q?=20=20=20=20=E5=A2=9E=E5=8A=A0=E4=BA=86?= =?UTF-8?q?=E6=B3=A8=E9=87=8A=20=20=20GPT=5FSoVITS/TTS=5Finfer=5Fpack/TTS.?= =?UTF-8?q?py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/TTS_infer_pack/TTS.py | 36 +++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index b26bb70..cc460b8 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -21,7 +21,7 @@ from .text_segmentation_method import splits from .TextPreprocessor import TextPreprocessor i18n = I18nAuto() -# tts_infer.yaml +# configs/tts_infer.yaml """ default: device: cpu @@ -240,6 +240,12 @@ class TTS: self.t2s_model = t2s_model def set_ref_audio(self, ref_audio_path:str): + ''' + To set the reference audio for the TTS model, + including the prompt_semantic and refer_spepc. + Args: + ref_audio_path: str, the path of the reference audio. + ''' self._set_prompt_semantic(ref_audio_path) self._set_ref_spepc(ref_audio_path) @@ -399,6 +405,16 @@ class TTS: return _data, batch_index_list def recovery_order(self, data:list, batch_index_list:list)->list: + ''' + Recovery the order of the audio according to the batch_index_list. + + Args: + data (List[list(np.ndarray)]): the out of order audio . + batch_index_list (List[list[int]]): the batch index list. + + Returns: + list (List[np.ndarray]): the data in the original order. + ''' lenght = len(sum(batch_index_list, [])) _data = [None]*lenght for i, index_list in enumerate(batch_index_list): @@ -407,6 +423,9 @@ class TTS: return _data def stop(self,): + ''' + Stop the inference process. + ''' self.stop_flag = True @@ -435,8 +454,8 @@ class TTS: returns: tulpe[int, np.ndarray]: sampling rate and audio data. """ + ########## variables initialization ########### self.stop_flag:bool = False - text:str = inputs.get("text", "") text_lang:str = inputs.get("text_lang", "") ref_audio_path:str = inputs.get("ref_audio_path", "") @@ -475,6 +494,8 @@ class TTS: ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spepc"] is None)): raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()") + + ###### setting reference audio and prompt text preprocessing ######## t0 = ttime() if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]): self.set_ref_audio(ref_audio_path) @@ -494,12 +515,8 @@ class TTS: self.prompt_cache["bert_features"] = bert_features self.prompt_cache["norm_text"] = norm_text - zero_wav = np.zeros( - int(self.configs.sampling_rate * 0.3), - dtype=np.float16 if self.configs.is_half else np.float32, - ) - + ###### text preprocessing ######## data = self.text_preprocessor.preprocess(text, text_lang, text_split_method) audio = [] t1 = ttime() @@ -516,6 +533,8 @@ class TTS: device=self.configs.device ) + + ###### inference ###### t_34 = 0.0 t_45 = 0.0 for item in data: @@ -525,12 +544,10 @@ class TTS: all_bert_features = item["all_bert_features"] norm_text = item["norm_text"] - # phones = phones.to(self.configs.device) all_phoneme_ids = all_phoneme_ids.to(self.configs.device) all_bert_features = all_bert_features.to(self.configs.device) if self.configs.is_half: all_bert_features = all_bert_features.half() - # all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]*all_phoneme_ids.shape[0], device=self.configs.device) print(i18n("前端处理后的文本(每句):"), norm_text) if no_prompt_text : @@ -539,7 +556,6 @@ class TTS: prompt = self.prompt_cache["prompt_semantic"].clone().repeat(all_phoneme_ids.shape[0], 1).to(self.configs.device) with torch.no_grad(): - # pred_semantic = t2s_model.model.infer( pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( all_phoneme_ids, None,