diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 893f6524..a4cbb4f7 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1,6 +1,6 @@ from copy import deepcopy import math -import os, sys +import os, sys, gc import random import traceback @@ -39,7 +39,7 @@ default: cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth - + custom: device: cuda is_half: true @@ -209,7 +209,7 @@ class TTS: self.text_preprocessor: TextPreprocessor = \ TextPreprocessor(self.bert_model, self.bert_tokenizer, - self.configs.device, version=self.version) + self.configs.device) self.prompt_cache: dict = { "ref_audio_path": None, @@ -301,12 +301,12 @@ class TTS: if self.configs.is_half and str(self.configs.device) != "cpu": self.t2s_model = self.t2s_model.half() - def enable_half_precision(self, enable: bool = True): + def enable_half_precision(self, enable: bool = True, save: bool = True): ''' To enable half precision for the TTS model. Args: enable: bool, whether to enable half precision. - + ''' if str(self.configs.device) == "cpu" and enable: print("Half precision is not supported on CPU.") @@ -314,7 +314,8 @@ class TTS: self.configs.is_half = enable self.precision = torch.float16 if enable else torch.float32 - self.configs.save_configs() + if save: + self.configs.save_configs() if enable: if self.t2s_model is not None: self.t2s_model = self.t2s_model.half() @@ -334,14 +335,15 @@ class TTS: if self.cnhuhbert_model is not None: self.cnhuhbert_model = self.cnhuhbert_model.float() - def set_device(self, device: torch.device): + def set_device(self, device: torch.device, save: bool = True): ''' To set the device for all models. Args: device: torch.device, the device to use for all models. ''' self.configs.device = device - self.configs.save_configs() + if save: + self.configs.save_configs() if self.t2s_model is not None: self.t2s_model = self.t2s_model.to(device) if self.vits_model is not None: @@ -353,13 +355,17 @@ class TTS: def set_ref_audio(self, ref_audio_path: str): ''' - To set the reference audio for the TTS model, + To set the reference audio for the TTS model, including the prompt_semantic and refer_spepc. Args: ref_audio_path: str, the path of the reference audio. ''' self._set_prompt_semantic(ref_audio_path) self._set_ref_spec(ref_audio_path) + self._set_ref_audio_path(ref_audio_path) + + def _set_ref_audio_path(self, ref_audio_path): + self.prompt_cache["ref_audio_path"] = ref_audio_path def _set_ref_spec(self, ref_audio_path): audio = load_audio(ref_audio_path, int(self.configs.sampling_rate)) @@ -545,11 +551,11 @@ class TTS: def recovery_order(self, data: list, batch_index_list: list) -> list: ''' Recovery the order of the audio according to the batch_index_list. - + Args: data (List[list(np.ndarray)]): the out of order audio . batch_index_list (List[list[int]]): the batch index list. - + Returns: list (List[np.ndarray]): the data in the original order. ''' @@ -570,9 +576,9 @@ class TTS: def run(self, inputs: dict): """ Text to speech inference. - + Args: - inputs (dict): + inputs (dict): { "text": "", # str.(required) text to be synthesized "text_lang: "", # str.(required) language of the text to be synthesized @@ -801,7 +807,6 @@ class TTS: audio_frag_end_idx = [sum(audio_frag_idx[:i + 1]) for i in range(0, len(audio_frag_idx))] all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) - _batch_audio_fragment = (self.vits_model.decode( all_pred_semantic, _batch_phones, refer_audio_spec ).detach()[0, 0, :]) @@ -867,6 +872,7 @@ class TTS: def empty_cache(self): try: + gc.collect() # 触发gc的垃圾回收。避免内存一直增长。 if "cuda" in str(self.configs.device): torch.cuda.empty_cache() elif str(self.configs.device) == "mps": @@ -932,4 +938,4 @@ def speed_change(input_audio: np.ndarray, speed: float, sr: int): # 将管道输出解码为 NumPy 数组 processed_audio = np.frombuffer(out, np.int16) - return processed_audio + return processed_audio \ No newline at end of file