chore: make the GPU happy when doing TTS

This commit is contained in:
kevin.zhang 2024-05-09 11:12:44 +08:00
parent 715e4614eb
commit 84f5c1e5d8

View File

@ -102,6 +102,8 @@ import sys
import traceback import traceback
from typing import Generator from typing import Generator
import torch
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir)) sys.path.append("%s/GPT_SoVITS" % (now_dir))
@ -317,6 +319,9 @@ async def tts_handle(req: dict):
try: try:
tts_instance = get_tts_instance(tts_config) tts_instance = get_tts_instance(tts_config)
move_to_gpu(tts_instance)
tts_generator = tts_instance.run(req) tts_generator = tts_instance.run(req)
if streaming_mode: if streaming_mode:
@ -326,6 +331,7 @@ async def tts_handle(req: dict):
media_type = "raw" media_type = "raw"
for sr, chunk in tts_generator: for sr, chunk in tts_generator:
yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue() yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
move_to_cpu(tts_instance)
# _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}" # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}") return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
@ -333,11 +339,24 @@ async def tts_handle(req: dict):
else: else:
sr, audio_data = next(tts_generator) sr, audio_data = next(tts_generator)
audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
move_to_cpu(tts_instance)
return Response(audio_data, media_type=f"audio/{media_type}") return Response(audio_data, media_type=f"audio/{media_type}")
except Exception as e: except Exception as e:
return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)}) return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
def move_to_cpu(tts):
    """Move the TTS models to the CPU so the GPU memory can be reclaimed.

    Args:
        tts: TTS pipeline instance exposing ``set_device`` — presumably the
            GPT-SoVITS ``TTS`` object returned by ``get_tts_instance``.
    """
    tts.set_device(torch.device('cpu'))
    # Relocating tensors to the CPU only frees them from PyTorch's point of
    # view; CUDA's caching allocator still holds the VRAM. Return the cached
    # blocks to the driver so other processes can actually use the memory.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("Moved TTS models to CPU to save GPU memory.")
def move_to_gpu(tts):
    """Place the TTS models on the GPU for inference, if one is present.

    Falls back to the CPU when CUDA is unavailable, so calling this is
    always safe.

    Args:
        tts: TTS pipeline instance exposing ``set_device``.
    """
    if torch.cuda.is_available():
        target = torch.device('cuda')
    else:
        target = torch.device('cpu')
    tts.set_device(target)
    print("Moved TTS models back to GPU for performance.")
@APP.get("/control") @APP.get("/control")
async def control(command: str = None): async def control(command: str = None):
if command is None: if command is None:
@ -390,6 +409,7 @@ async def tts_get_endpoint(
"repetition_penalty": float(repetition_penalty), "repetition_penalty": float(repetition_penalty),
"tts_infer_yaml_path": tts_infer_yaml_path "tts_infer_yaml_path": tts_infer_yaml_path
} }
return await tts_handle(req) return await tts_handle(req)