From 84f5c1e5d824e2ea65d775d76572ca9a67c0bbce Mon Sep 17 00:00:00 2001
From: "kevin.zhang"
Date: Thu, 9 May 2024 11:12:44 +0800
Subject: [PATCH] chore: make gpu happy when do tts

---
 api_v3.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/api_v3.py b/api_v3.py
index a78077a1..9703637d 100644
--- a/api_v3.py
+++ b/api_v3.py
@@ -102,6 +102,8 @@ import sys
 import traceback
 from typing import Generator
 
+import torch
+
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 sys.path.append("%s/GPT_SoVITS" % (now_dir))
@@ -317,6 +319,9 @@ async def tts_handle(req: dict):
 
     try:
         tts_instance = get_tts_instance(tts_config)
+
+        move_to_gpu(tts_instance)
+
         tts_generator = tts_instance.run(req)
 
         if streaming_mode:
@@ -326,6 +331,7 @@
                     media_type = "raw"
                 for sr, chunk in tts_generator:
                     yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
+                move_to_cpu(tts_instance)
             # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
             return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
 
@@ -333,11 +339,24 @@
         else:
             sr, audio_data = next(tts_generator)
             audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
+            move_to_cpu(tts_instance)
             return Response(audio_data, media_type=f"audio/{media_type}")
     except Exception as e:
         return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
 
 
+def move_to_cpu(tts):
+    cpu_device = torch.device('cpu')
+    tts.set_device(cpu_device)
+    print("Moved TTS models to CPU to save GPU memory.")
+
+
+def move_to_gpu(tts):
+    gpu_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    tts.set_device(gpu_device)
+    print("Moved TTS models back to GPU for performance.")
+
+
 @APP.get("/control")
 async def control(command: str = None):
     if command is None:
@@ -390,6 +409,7 @@ async def tts_get_endpoint(
         "repetition_penalty": float(repetition_penalty),
         "tts_infer_yaml_path": tts_infer_yaml_path
     }
+
     return await tts_handle(req)
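
Note on the helpers above: move_to_gpu()/move_to_cpu() rely on the TTS instance
returned by get_tts_instance() exposing a set_device() method that relocates its
model weights. A minimal sketch of that assumption (hypothetical class and
attribute names, not part of this patch):

    import torch
    import torch.nn as nn

    class TTSPipeline:
        # Hypothetical stand-in for the object returned by get_tts_instance().
        def __init__(self, t2s_model: nn.Module, vits_model: nn.Module):
            self.t2s_model = t2s_model
            self.vits_model = vits_model

        def set_device(self, device: torch.device):
            # Move each sub-model's weights to the requested device.
            self.t2s_model = self.t2s_model.to(device)
            self.vits_model = self.vits_model.to(device)
            if device.type == "cpu" and torch.cuda.is_available():
                # .to("cpu") only hands the buffers back to PyTorch's caching
                # allocator; empty_cache() releases them to the driver so the
                # memory shows up as free for other processes.
                torch.cuda.empty_cache()

The cuda-or-cpu fallback in move_to_gpu() keeps the helper safe on CPU-only
hosts, where both calls simply re-apply the same CPU device.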