Added comments to GPT_SoVITS/TTS_infer_pack/TTS.py

chasonjiang 2024-03-10 01:57:04 +08:00
parent ed2ffe1356
commit cae976ef5a


@@ -21,7 +21,7 @@ from .text_segmentation_method import splits
 from .TextPreprocessor import TextPreprocessor
 i18n = I18nAuto()
-# tts_infer.yaml
+# configs/tts_infer.yaml
 """
 default:
   device: cpu
@@ -240,6 +240,12 @@ class TTS:
         self.t2s_model = t2s_model

     def set_ref_audio(self, ref_audio_path:str):
+        '''
+        To set the reference audio for the TTS model,
+        including the prompt_semantic and refer_spepc.
+        Args:
+            ref_audio_path: str, the path of the reference audio.
+        '''
         self._set_prompt_semantic(ref_audio_path)
         self._set_ref_spepc(ref_audio_path)
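For context, the docstring added here describes the intended call pattern: extract and cache the reference features once, then reuse them across synthesis calls. A minimal usage sketch, assuming a constructed TTS instance; the TTS_Config name and the config path are assumptions based on this file, not guaranteed:

from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config

config = TTS_Config("configs/tts_infer.yaml")  # path taken from the comment above
tts = TTS(config)

# Caches prompt_semantic and refer_spepc so later run() calls can omit
# ref_audio_path without re-extracting reference features.
tts.set_ref_audio("path/to/reference.wav")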
@@ -399,6 +405,16 @@ class TTS:
         return _data, batch_index_list

     def recovery_order(self, data:list, batch_index_list:list)->list:
+        '''
+        Recover the order of the audio according to the batch_index_list.
+        Args:
+            data (List[list(np.ndarray)]): the out-of-order audio.
+            batch_index_list (List[list[int]]): the batch index list.
+        Returns:
+            list (List[np.ndarray]): the data in the original order.
+        '''
         lenght = len(sum(batch_index_list, []))
         _data = [None]*lenght
         for i, index_list in enumerate(batch_index_list):
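Since the docstring above is the whole contract, a self-contained sketch of the same reordering idea may help. restore_order below is a hypothetical stand-in for recovery_order, not the method itself:

import numpy as np

def restore_order(data: list, batch_index_list: list) -> list:
    # Total item count is the sum of all batch sizes.
    length = len(sum(batch_index_list, []))
    restored = [None] * length
    # batch_index_list[i][j] holds the original position of data[i][j].
    for i, index_list in enumerate(batch_index_list):
        for j, original_index in enumerate(index_list):
            restored[original_index] = data[i][j]
    return restored

# Items 0..3 were grouped out of order into batches [[0, 2], [1, 3]]:
batches = [[np.zeros(1), np.full(1, 2.0)], [np.ones(1), np.full(1, 3.0)]]
print(restore_order(batches, [[0, 2], [1, 3]]))  # arrays 0.0, 1.0, 2.0, 3.0

The index lists are flattened only to size the output; each batched item is then scattered back to its original slot.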
@@ -407,6 +423,9 @@ class TTS:
         return _data

     def stop(self,):
+        '''
+        Stop the inference process.
+        '''
         self.stop_flag = True
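stop() only raises a flag; the long-running loop is expected to poll it. A minimal sketch of that cooperative-cancellation pattern (the Worker class is illustrative, not the actual run() loop):

class Worker:
    def __init__(self):
        self.stop_flag: bool = False

    def stop(self):
        # Only signals the loop; nothing is interrupted forcibly.
        self.stop_flag = True

    def run(self, items: list) -> list:
        results = []
        for item in items:
            if self.stop_flag:
                break  # polled once per item, so the stop takes effect between items
            results.append(item)
        return results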
@@ -435,8 +454,8 @@ class TTS:
         returns:
             tuple[int, np.ndarray]: sampling rate and audio data.
         """
+        ########## variables initialization ###########
         self.stop_flag:bool = False
         text:str = inputs.get("text", "")
         text_lang:str = inputs.get("text_lang", "")
         ref_audio_path:str = inputs.get("ref_audio_path", "")
@@ -475,6 +494,8 @@ class TTS:
                 ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spepc"] is None)):
             raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")

+        ###### setting reference audio and prompt text preprocessing ########
         t0 = ttime()
         if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
             self.set_ref_audio(ref_audio_path)
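The guard above re-extracts reference features only when the path actually changes. A sketch of that cache-keyed-by-path pattern, where extract_features is a hypothetical stand-in for _set_prompt_semantic/_set_ref_spepc:

prompt_cache = {"ref_audio_path": None, "prompt_semantic": None, "refer_spepc": None}

def extract_features(path):
    # Placeholder for the real semantic/spectrogram extraction.
    return f"semantic({path})", f"spec({path})"

def ensure_ref_audio(path):
    # Skip feature extraction when the same reference audio is requested again.
    if path is not None and path != prompt_cache["ref_audio_path"]:
        semantic, spec = extract_features(path)
        prompt_cache["ref_audio_path"] = path
        prompt_cache["prompt_semantic"] = semantic
        prompt_cache["refer_spepc"] = spec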
@@ -494,12 +515,8 @@ class TTS:
             self.prompt_cache["bert_features"] = bert_features
             self.prompt_cache["norm_text"] = norm_text

-        zero_wav = np.zeros(
-            int(self.configs.sampling_rate * 0.3),
-            dtype=np.float16 if self.configs.is_half else np.float32,
-        )
-
+        ###### text preprocessing ########
         data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
         audio = []
         t1 = ttime()
@@ -516,6 +533,8 @@ class TTS:
                 device=self.configs.device
             )

+        ###### inference ######
         t_34 = 0.0
         t_45 = 0.0
         for item in data:
@@ -525,12 +544,10 @@ class TTS:
             all_bert_features = item["all_bert_features"]
             norm_text = item["norm_text"]

-            # phones = phones.to(self.configs.device)
             all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
             all_bert_features = all_bert_features.to(self.configs.device)
             if self.configs.is_half:
                 all_bert_features = all_bert_features.half()
-            # all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]*all_phoneme_ids.shape[0], device=self.configs.device)

             print(i18n("前端处理后的文本(每句):"), norm_text)
             if no_prompt_text :
@@ -539,7 +556,6 @@ class TTS:
                 prompt = self.prompt_cache["prompt_semantic"].clone().repeat(all_phoneme_ids.shape[0], 1).to(self.configs.device)

             with torch.no_grad():
-                # pred_semantic = t2s_model.model.infer(
                 pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
                     all_phoneme_ids,
                     None,
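Putting it together: a hedged sketch of invoking run() with the input keys read at the top of the method (text, text_lang, ref_audio_path). The concrete values and the tts instance are assumptions, and per the docstring the call produces (sampling_rate, audio_data):

inputs = {
    "text": "...",                # text to synthesize
    "text_lang": "zh",            # language code is an assumption
    "ref_audio_path": "ref.wav",  # optional once set_ref_audio() has been called
}
sampling_rate, audio_data = tts.run(inputs)  # np.ndarray audio at sampling_rate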