From bd53aa8200cdf02e107abf8a3fb4220840586eaf Mon Sep 17 00:00:00 2001 From: CyberWon Date: Thu, 8 Aug 2024 22:49:57 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9C=AC=E5=9C=B0=E7=9A=84tts.py=E4=B8=8D?= =?UTF-8?q?=E6=98=AF=E6=9C=80=E6=96=B0=E7=9A=84=EF=BC=8C=E5=AF=BC=E8=87=B4?= =?UTF-8?q?=E4=B9=8B=E5=89=8D=E7=9A=84=E4=BF=AE=E5=A4=8D=E8=A2=AB=E6=9B=BF?= =?UTF-8?q?=E6=8D=A2=E4=BA=86=E3=80=82=E5=9B=9E=E6=BB=9A=E5=88=B0=E6=9C=80?= =?UTF-8?q?=E6=96=B0=E7=89=88=E6=9C=AC=E5=B9=B6=E6=B7=BB=E5=8A=A0v2?= =?UTF-8?q?=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/TTS_infer_pack/TTS.py | 36 +++++++++++++++++++------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 893f6524..a4cbb4f7 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1,6 +1,6 @@ from copy import deepcopy import math -import os, sys +import os, sys, gc import random import traceback @@ -39,7 +39,7 @@ default: cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth - + custom: device: cuda is_half: true @@ -209,7 +209,7 @@ class TTS: self.text_preprocessor: TextPreprocessor = \ TextPreprocessor(self.bert_model, self.bert_tokenizer, - self.configs.device, version=self.version) + self.configs.device) self.prompt_cache: dict = { "ref_audio_path": None, @@ -301,12 +301,12 @@ class TTS: if self.configs.is_half and str(self.configs.device) != "cpu": self.t2s_model = self.t2s_model.half() - def enable_half_precision(self, enable: bool = True): + def enable_half_precision(self, enable: bool = True, save: bool = True): ''' To enable half precision for the TTS model. Args: enable: bool, whether to enable half precision. - + ''' if str(self.configs.device) == "cpu" and enable: print("Half precision is not supported on CPU.") @@ -314,7 +314,8 @@ class TTS: self.configs.is_half = enable self.precision = torch.float16 if enable else torch.float32 - self.configs.save_configs() + if save: + self.configs.save_configs() if enable: if self.t2s_model is not None: self.t2s_model = self.t2s_model.half() @@ -334,14 +335,15 @@ class TTS: if self.cnhuhbert_model is not None: self.cnhuhbert_model = self.cnhuhbert_model.float() - def set_device(self, device: torch.device): + def set_device(self, device: torch.device, save: bool = True): ''' To set the device for all models. Args: device: torch.device, the device to use for all models. ''' self.configs.device = device - self.configs.save_configs() + if save: + self.configs.save_configs() if self.t2s_model is not None: self.t2s_model = self.t2s_model.to(device) if self.vits_model is not None: @@ -353,13 +355,17 @@ class TTS: def set_ref_audio(self, ref_audio_path: str): ''' - To set the reference audio for the TTS model, + To set the reference audio for the TTS model, including the prompt_semantic and refer_spepc. Args: ref_audio_path: str, the path of the reference audio. ''' self._set_prompt_semantic(ref_audio_path) self._set_ref_spec(ref_audio_path) + self._set_ref_audio_path(ref_audio_path) + + def _set_ref_audio_path(self, ref_audio_path): + self.prompt_cache["ref_audio_path"] = ref_audio_path def _set_ref_spec(self, ref_audio_path): audio = load_audio(ref_audio_path, int(self.configs.sampling_rate)) @@ -545,11 +551,11 @@ class TTS: def recovery_order(self, data: list, batch_index_list: list) -> list: ''' Recovery the order of the audio according to the batch_index_list. - + Args: data (List[list(np.ndarray)]): the out of order audio . batch_index_list (List[list[int]]): the batch index list. - + Returns: list (List[np.ndarray]): the data in the original order. ''' @@ -570,9 +576,9 @@ class TTS: def run(self, inputs: dict): """ Text to speech inference. - + Args: - inputs (dict): + inputs (dict): { "text": "", # str.(required) text to be synthesized "text_lang: "", # str.(required) language of the text to be synthesized @@ -801,7 +807,6 @@ class TTS: audio_frag_end_idx = [sum(audio_frag_idx[:i + 1]) for i in range(0, len(audio_frag_idx))] all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) - _batch_audio_fragment = (self.vits_model.decode( all_pred_semantic, _batch_phones, refer_audio_spec ).detach()[0, 0, :]) @@ -867,6 +872,7 @@ class TTS: def empty_cache(self): try: + gc.collect() # 触发gc的垃圾回收。避免内存一直增长。 if "cuda" in str(self.configs.device): torch.cuda.empty_cache() elif str(self.configs.device) == "mps": @@ -932,4 +938,4 @@ def speed_change(input_audio: np.ndarray, speed: float, sr: int): # 将管道输出解码为 NumPy 数组 processed_audio = np.frombuffer(out, np.int16) - return processed_audio + return processed_audio \ No newline at end of file