From 836bfec1fbf7356d59bd2dbe3883997646554e20 Mon Sep 17 00:00:00 2001 From: SetoKaiba <61304189@qq.com> Date: Thu, 27 Jun 2024 22:53:53 +0800 Subject: [PATCH] fix #1240 (#1241) --- GPT_SoVITS/TTS_infer_pack/TTS.py | 10 ++++++---- api_v3.py | 12 +++++++----- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index c86a077..10dacca 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -301,7 +301,7 @@ class TTS: if self.configs.is_half and str(self.configs.device)!="cpu": self.t2s_model = self.t2s_model.half() - def enable_half_precision(self, enable: bool = True): + def enable_half_precision(self, enable: bool = True, save: bool = True): ''' To enable half precision for the TTS model. Args: @@ -314,7 +314,8 @@ class TTS: self.configs.is_half = enable self.precision = torch.float16 if enable else torch.float32 - self.configs.save_configs() + if save: + self.configs.save_configs() if enable: if self.t2s_model is not None: self.t2s_model =self.t2s_model.half() @@ -334,14 +335,15 @@ class TTS: if self.cnhuhbert_model is not None: self.cnhuhbert_model = self.cnhuhbert_model.float() - def set_device(self, device: torch.device): + def set_device(self, device: torch.device, save: bool = True): ''' To set the device for all models. Args: device: torch.device, the device to use for all models. ''' self.configs.device = device - self.configs.save_configs() + if save: + self.configs.save_configs() if self.t2s_model is not None: self.t2s_model = self.t2s_model.to(device) if self.vits_model is not None: diff --git a/api_v3.py b/api_v3.py index a121dfc..e830113 100644 --- a/api_v3.py +++ b/api_v3.py @@ -320,7 +320,7 @@ async def tts_handle(req: dict): try: tts_instance = get_tts_instance(tts_config) - move_to_gpu(tts_instance, tts_config) + move_to_original(tts_instance, tts_config) tts_generator = tts_instance.run(req) @@ -347,13 +347,15 @@ async def tts_handle(req: dict): def move_to_cpu(tts): cpu_device = torch.device('cpu') - tts.set_device(cpu_device) + tts.set_device(cpu_device, False) + tts.enable_half_precision(False, False) print("Moved TTS models to CPU to save GPU memory.") -def move_to_gpu(tts: TTS, tts_config: TTS_Config): - tts.set_device(tts_config.device) - print("Moved TTS models back to GPU for performance.") +def move_to_original(tts: TTS, tts_config: TTS_Config): + tts.set_device(tts_config.device, False) + tts.enable_half_precision(tts_config.is_half, False) + print("Moved TTS models back to original device for performance.") @APP.get("/control")