This commit is contained in:
SetoKaiba 2024-06-27 22:53:53 +08:00 committed by GitHub
parent 8dd7cfab93
commit 836bfec1fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 9 deletions

View File

@ -301,7 +301,7 @@ class TTS:
if self.configs.is_half and str(self.configs.device)!="cpu": if self.configs.is_half and str(self.configs.device)!="cpu":
self.t2s_model = self.t2s_model.half() self.t2s_model = self.t2s_model.half()
def enable_half_precision(self, enable: bool = True): def enable_half_precision(self, enable: bool = True, save: bool = True):
''' '''
To enable half precision for the TTS model. To enable half precision for the TTS model.
Args: Args:
@ -314,6 +314,7 @@ class TTS:
self.configs.is_half = enable self.configs.is_half = enable
self.precision = torch.float16 if enable else torch.float32 self.precision = torch.float16 if enable else torch.float32
if save:
self.configs.save_configs() self.configs.save_configs()
if enable: if enable:
if self.t2s_model is not None: if self.t2s_model is not None:
@ -334,13 +335,14 @@ class TTS:
if self.cnhuhbert_model is not None: if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.float() self.cnhuhbert_model = self.cnhuhbert_model.float()
def set_device(self, device: torch.device): def set_device(self, device: torch.device, save: bool = True):
''' '''
To set the device for all models. To set the device for all models.
Args: Args:
device: torch.device, the device to use for all models. device: torch.device, the device to use for all models.
''' '''
self.configs.device = device self.configs.device = device
if save:
self.configs.save_configs() self.configs.save_configs()
if self.t2s_model is not None: if self.t2s_model is not None:
self.t2s_model = self.t2s_model.to(device) self.t2s_model = self.t2s_model.to(device)

View File

@ -320,7 +320,7 @@ async def tts_handle(req: dict):
try: try:
tts_instance = get_tts_instance(tts_config) tts_instance = get_tts_instance(tts_config)
move_to_gpu(tts_instance, tts_config) move_to_original(tts_instance, tts_config)
tts_generator = tts_instance.run(req) tts_generator = tts_instance.run(req)
@ -347,13 +347,15 @@ async def tts_handle(req: dict):
def move_to_cpu(tts): def move_to_cpu(tts):
cpu_device = torch.device('cpu') cpu_device = torch.device('cpu')
tts.set_device(cpu_device) tts.set_device(cpu_device, False)
tts.enable_half_precision(False, False)
print("Moved TTS models to CPU to save GPU memory.") print("Moved TTS models to CPU to save GPU memory.")
def move_to_gpu(tts: TTS, tts_config: TTS_Config): def move_to_original(tts: TTS, tts_config: TTS_Config):
tts.set_device(tts_config.device) tts.set_device(tts_config.device, False)
print("Moved TTS models back to GPU for performance.") tts.enable_half_precision(tts_config.is_half, False)
print("Moved TTS models back to original device for performance.")
@APP.get("/control") @APP.get("/control")