chore: fix

This commit is contained in:
kevin.zhang 2024-05-14 08:49:19 +08:00
parent 648db7e855
commit 70a32d1a3f

View File

@ -320,7 +320,7 @@ async def tts_handle(req: dict):
try: try:
tts_instance = get_tts_instance(tts_config) tts_instance = get_tts_instance(tts_config)
move_to_gpu(tts_instance) move_to_gpu(tts_instance, tts_config)
tts_generator = tts_instance.run(req) tts_generator = tts_instance.run(req)
@ -351,9 +351,8 @@ def move_to_cpu(tts):
print("Moved TTS models to CPU to save GPU memory.") print("Moved TTS models to CPU to save GPU memory.")
def move_to_gpu(tts): def move_to_gpu(tts: TTS, tts_config: TTS_Config):
gpu_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') tts.set_device(tts_config.device)
tts.set_device(gpu_device)
print("Moved TTS models back to GPU for performance.") print("Moved TTS models back to GPU for performance.")