mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 23:48:48 +08:00
chore: make gpu happy when do tts
This commit is contained in:
parent
715e4614eb
commit
84f5c1e5d8
20
api_v3.py
20
api_v3.py
@ -102,6 +102,8 @@ import sys
|
||||
import traceback
|
||||
from typing import Generator
|
||||
|
||||
import torch
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
||||
@ -317,6 +319,9 @@ async def tts_handle(req: dict):
|
||||
|
||||
try:
|
||||
tts_instance = get_tts_instance(tts_config)
|
||||
|
||||
move_to_gpu(tts_instance)
|
||||
|
||||
tts_generator = tts_instance.run(req)
|
||||
|
||||
if streaming_mode:
|
||||
@ -326,6 +331,7 @@ async def tts_handle(req: dict):
|
||||
media_type = "raw"
|
||||
for sr, chunk in tts_generator:
|
||||
yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
|
||||
move_to_cpu(tts_instance)
|
||||
|
||||
# _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
|
||||
return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
|
||||
@ -333,11 +339,24 @@ async def tts_handle(req: dict):
|
||||
else:
|
||||
sr, audio_data = next(tts_generator)
|
||||
audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
|
||||
move_to_cpu(tts_instance)
|
||||
return Response(audio_data, media_type=f"audio/{media_type}")
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
|
||||
|
||||
|
||||
def move_to_cpu(tts):
|
||||
cpu_device = torch.device('cpu')
|
||||
tts.set_device(cpu_device)
|
||||
print("Moved TTS models to CPU to save GPU memory.")
|
||||
|
||||
|
||||
def move_to_gpu(tts):
|
||||
gpu_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
tts.set_device(gpu_device)
|
||||
print("Moved TTS models back to GPU for performance.")
|
||||
|
||||
|
||||
@APP.get("/control")
|
||||
async def control(command: str = None):
|
||||
if command is None:
|
||||
@ -390,6 +409,7 @@ async def tts_get_endpoint(
|
||||
"repetition_penalty": float(repetition_penalty),
|
||||
"tts_infer_yaml_path": tts_infer_yaml_path
|
||||
}
|
||||
|
||||
return await tts_handle(req)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user