From 8811f721ed07edbbfa199bd9ffb6b00c50d7daf4 Mon Sep 17 00:00:00 2001 From: "kevin.zhang" Date: Tue, 14 May 2024 08:52:11 +0800 Subject: [PATCH] chore: fix --- api_v3.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/api_v3.py b/api_v3.py index 9703637d..43f20438 100644 --- a/api_v3.py +++ b/api_v3.py @@ -320,7 +320,7 @@ async def tts_handle(req: dict): try: tts_instance = get_tts_instance(tts_config) - move_to_gpu(tts_instance) + move_to_gpu(tts_instance, tts_config) tts_generator = tts_instance.run(req) @@ -351,9 +351,8 @@ def move_to_cpu(tts): print("Moved TTS models to CPU to save GPU memory.") -def move_to_gpu(tts): - gpu_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - tts.set_device(gpu_device) +def move_to_gpu(tts: TTS, tts_config: TTS_Config): + tts.set_device(tts_config.device) print("Moved TTS models back to GPU for performance.")