mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 19:41:56 +08:00
111
This commit is contained in:
parent
4b4673f66d
commit
80733266f5
95
api_role.py
95
api_role.py
@ -182,6 +182,8 @@ import traceback
|
||||
from typing import Generator, Optional, List, Dict
|
||||
import random
|
||||
import glob
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import asyncio
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
@ -239,6 +241,9 @@ if hasattr(tts_config, 'device'):
|
||||
tts_config.device = default_device
|
||||
tts_pipeline = TTS(tts_config)
|
||||
|
||||
# Create a thread pool used to run TTS tasks asynchronously
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
APP = FastAPI()
|
||||
|
||||
class TTS_Request(BaseModel):
|
||||
@ -262,6 +267,7 @@ class TTS_Request(BaseModel):
|
||||
streaming_mode: Optional[bool] = False
|
||||
parallel_infer: Optional[bool] = True
|
||||
repetition_penalty: Optional[float] = 1.35
|
||||
device: Optional[str] = None
|
||||
|
||||
class TTSRole_Request(BaseModel):
|
||||
text: str
|
||||
@ -301,17 +307,18 @@ class TTSRole_Request(BaseModel):
|
||||
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Write ``data`` into ``io_buffer`` as single-channel OGG audio.

    Args:
        io_buffer: Destination buffer; it is written to and then rewound.
        data: Audio samples handed to ``SoundFile.write``.
        rate: Sample rate passed to the encoder.

    Returns:
        The same ``io_buffer``, positioned at offset 0.
    """
    with sf.SoundFile(
        io_buffer,
        mode='w',
        samplerate=rate,
        channels=1,
        format='ogg',
    ) as ogg_writer:
        ogg_writer.write(data)
    # Rewind so the caller can read the encoded bytes from the start.
    io_buffer.seek(0)
    return io_buffer
|
||||
|
||||
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Dump the raw bytes of ``data`` into ``io_buffer`` without a container.

    Args:
        io_buffer: Destination buffer; it is written to and then rewound.
        data: Array whose in-memory bytes are written as-is.
        rate: Unused here; kept so every ``pack_*`` helper shares a signature.

    Returns:
        The same ``io_buffer``, positioned at offset 0.
    """
    payload = data.tobytes()
    io_buffer.write(payload)
    # Rewind so the caller can read the bytes from the start.
    io_buffer.seek(0)
    return io_buffer
|
||||
|
||||
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Encode ``data`` as WAV into ``io_buffer`` and return it rewound.

    The earlier body wrote into a temporary ``BytesIO`` opened with ``with``
    and returned it from inside that block, so the caller received an
    already-closed buffer. Writing directly into the caller's buffer avoids
    that and matches the other ``pack_*`` helpers.

    Args:
        io_buffer: Destination buffer; it is written to and then rewound.
        data: Audio samples handed to ``sf.write``.
        rate: Sample rate stored in the WAV header.

    Returns:
        The same ``io_buffer``, positioned at offset 0.
    """
    sf.write(io_buffer, data, rate, format='wav')
    # Rewind so the caller can read the encoded bytes from the start.
    io_buffer.seek(0)
    return io_buffer
|
||||
|
||||
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
process = subprocess.Popen([
|
||||
@ -320,6 +327,7 @@ def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
out, _ = process.communicate(input=data.tobytes())
|
||||
io_buffer.write(out)
|
||||
io_buffer.seek(0)
|
||||
return io_buffer
|
||||
|
||||
def pack_audio(data: np.ndarray, rate: int, media_type: str) -> BytesIO:
|
||||
@ -332,7 +340,6 @@ def pack_audio(data: np.ndarray, rate: int, media_type: str) -> BytesIO:
|
||||
io_buffer = pack_wav(io_buffer, data, rate)
|
||||
else:
|
||||
io_buffer = pack_raw(io_buffer, data, rate)
|
||||
io_buffer.seek(0)
|
||||
return io_buffer
|
||||
|
||||
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
|
||||
@ -490,6 +497,52 @@ def select_ref_audio(role: str, text_lang: str, emotion: str = None):
|
||||
|
||||
return None, None, None
|
||||
|
||||
def _first_parameter_device(model):
    """Return the device of ``model``'s first parameter, or None if it has none."""
    parameters = getattr(model, "parameters", None)
    if not callable(parameters):
        return None
    # A default avoids StopIteration on a module without parameters.
    first = next(iter(parameters()), None)
    return None if first is None else first.device


def _devices_match(current, target) -> bool:
    """Return True when two ``torch.device`` objects denote the same device."""
    if current is None or current.type != target.type:
        return False
    if current.type == "cuda":
        # A bare "cuda" has index None and compares unequal to "cuda:0";
        # resolve a missing index to the active CUDA device before comparing.
        fallback = torch.cuda.current_device()
    else:
        fallback = 0
    current_index = fallback if current.index is None else current.index
    target_index = fallback if target.index is None else target.index
    return current_index == target_index


def set_pipeline_device(pipeline: TTS, device: str):
    """Move every model held by ``pipeline`` to ``device`` if it is not there yet.

    Falls back to ``"cpu"`` when a CUDA device is requested but CUDA is not
    available. The current device is read from the first parameter of
    ``t2s_model`` (or ``vits_model`` when the former is missing or has no
    parameters); when it already matches the target nothing is changed.

    Args:
        pipeline: Object exposing ``*_model`` attributes and, optionally, a
            ``configs`` object with a ``device`` attribute.
        device: Device string accepted by ``torch.device`` (e.g. ``"cpu"``,
            ``"cuda"``, ``"cuda:0"``).

    Returns:
        None.
    """
    if not torch.cuda.is_available() and device.startswith("cuda"):
        print("警告: CUDA 不可用,强制使用 CPU")
        device = "cpu"

    target_device = torch.device(device)

    # Determine where the pipeline currently lives.
    current_device = None
    for attr in ('t2s_model', 'vits_model'):
        model = getattr(pipeline, attr, None)
        if model is not None:
            current_device = _first_parameter_device(model)
            if current_device is not None:
                break

    if _devices_match(current_device, target_device):
        print(f"设备已是 {device},无需切换")
        return

    # Keep the pipeline configuration in sync with the new device.
    if hasattr(pipeline, 'configs') and hasattr(pipeline.configs, 'device'):
        pipeline.configs.device = device

    # Migrate each model exactly once: the two known models first, then any
    # other attribute whose name ends in "_model".
    model_attrs = ['t2s_model', 'vits_model']
    model_attrs += [
        name for name in dir(pipeline)
        if name.endswith('_model') and name not in model_attrs
    ]
    for attr in model_attrs:
        model = getattr(pipeline, attr, None)
        move = getattr(model, 'to', None)
        # Skip unset attributes and objects that are not movable, instead of
        # swallowing AttributeError (which would also hide errors inside .to()).
        if model is None or not callable(move):
            continue
        move(target_device)
        print(f"迁移 {attr} 到 {device}")

    # Release cached GPU memory after moving away from CUDA.
    if torch.cuda.is_available() and not device.startswith("cuda"):
        torch.cuda.empty_cache()

    print(f"TTS 管道设备已设置为: {device}")
|
||||
|
||||
def run_tts_pipeline(req):
    """Call the shared ``tts_pipeline`` on ``req`` and return its result.

    Exists as a plain function so the request handler can submit it to the
    thread-pool executor.

    NOTE(review): callers iterate the returned value; if ``TTS.run`` is a
    generator function, this call only creates the generator and the actual
    synthesis happens wherever it is consumed — confirm.
    """
    pipeline_output = tts_pipeline.run(req)
    return pipeline_output
|
||||
|
||||
async def tts_handle(req: dict, is_ttsrole: bool = False):
|
||||
streaming_mode = req.get("streaming_mode", False)
|
||||
media_type = req.get("media_type", "wav")
|
||||
@ -501,6 +554,15 @@ async def tts_handle(req: dict, is_ttsrole: bool = False):
|
||||
if check_res is not None:
|
||||
return JSONResponse(status_code=400, content=check_res)
|
||||
|
||||
# 如果请求中指定了 device,则覆盖所有与设备相关的参数并更新管道设备
|
||||
if "device" in req and req["device"] is not None:
|
||||
device = req["device"]
|
||||
req["t2s_model_device"] = device
|
||||
req["vits_model_device"] = device
|
||||
if hasattr(tts_config, 'device'):
|
||||
tts_config.device = device
|
||||
set_pipeline_device(tts_pipeline, device)
|
||||
|
||||
if is_ttsrole:
|
||||
role_exists = load_role_config(req["role"], req)
|
||||
|
||||
@ -546,7 +608,15 @@ async def tts_handle(req: dict, is_ttsrole: bool = False):
|
||||
req["return_fragment"] = True
|
||||
|
||||
try:
|
||||
tts_generator = tts_pipeline.run(req)
|
||||
print(f"当前请求设备: {req.get('device')}")
|
||||
if hasattr(tts_pipeline, 't2s_model') and tts_pipeline.t2s_model is not None:
|
||||
print(f"t2s_model 设备: {next(tts_pipeline.t2s_model.parameters()).device}")
|
||||
if hasattr(tts_pipeline, 'vits_model') and tts_pipeline.vits_model is not None:
|
||||
print(f"vits_model 设备: {next(tts_pipeline.vits_model.parameters()).device}")
|
||||
|
||||
# 异步执行 TTS 任务
|
||||
loop = asyncio.get_event_loop()
|
||||
tts_generator = await loop.run_in_executor(executor, run_tts_pipeline, req)
|
||||
|
||||
if streaming_mode:
|
||||
def streaming_generator():
|
||||
@ -558,14 +628,11 @@ async def tts_handle(req: dict, is_ttsrole: bool = False):
|
||||
for sr, chunk in tts_generator:
|
||||
buf = pack_audio(chunk, sr, stream_type)
|
||||
yield buf.getvalue()
|
||||
buf.close()
|
||||
return StreamingResponse(streaming_generator(), media_type=f"audio/{media_type}")
|
||||
else:
|
||||
sr, audio_data = next(tts_generator)
|
||||
buf = pack_audio(audio_data, sr, media_type)
|
||||
response = Response(buf.getvalue(), media_type=f"audio/{media_type}")
|
||||
buf.close()
|
||||
return response
|
||||
return Response(buf.getvalue(), media_type=f"audio/{media_type}")
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"status": "error", "message": "tts failed", "exception": str(e)})
|
||||
|
||||
@ -596,7 +663,8 @@ async def tts_get_endpoint(
|
||||
media_type: Optional[str] = "wav",
|
||||
streaming_mode: Optional[bool] = False,
|
||||
parallel_infer: Optional[bool] = True,
|
||||
repetition_penalty: Optional[float] = 1.35
|
||||
repetition_penalty: Optional[float] = 1.35,
|
||||
device: Optional[str] = None
|
||||
):
|
||||
req = {
|
||||
"text": text,
|
||||
@ -618,7 +686,8 @@ async def tts_get_endpoint(
|
||||
"media_type": media_type,
|
||||
"streaming_mode": streaming_mode,
|
||||
"parallel_infer": parallel_infer,
|
||||
"repetition_penalty": repetition_penalty
|
||||
"repetition_penalty": repetition_penalty,
|
||||
"device": device
|
||||
}
|
||||
return await tts_handle(req)
|
||||
|
||||
@ -747,7 +816,7 @@ async def set_refer_audio(refer_audio_path: str = None):
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
if host == 'None': # 在调用时使用 -a None 参数,可以让api监听双栈
|
||||
if host == 'None': # 在调用时使用 -a None 参数,可以让api监听双栈
|
||||
host = None
|
||||
uvicorn.run(app=APP, host=host, port=port, workers=1)
|
||||
except Exception as e:
|
||||
|
Loading…
x
Reference in New Issue
Block a user