From 80733266f5af546970454951094fa06070c7857d Mon Sep 17 00:00:00 2001 From: spawner1145 Date: Tue, 4 Mar 2025 17:53:25 +0800 Subject: [PATCH] 111 --- api_role.py | 95 +++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 82 insertions(+), 13 deletions(-) diff --git a/api_role.py b/api_role.py index 81015ae..54250f4 100644 --- a/api_role.py +++ b/api_role.py @@ -182,6 +182,8 @@ import traceback from typing import Generator, Optional, List, Dict import random import glob +from concurrent.futures import ThreadPoolExecutor +import asyncio now_dir = os.getcwd() sys.path.append(now_dir) @@ -239,6 +241,9 @@ if hasattr(tts_config, 'device'): tts_config.device = default_device tts_pipeline = TTS(tts_config) +# 创建线程池用于异步执行 TTS 任务 +executor = ThreadPoolExecutor(max_workers=1) + APP = FastAPI() class TTS_Request(BaseModel): @@ -262,6 +267,7 @@ class TTS_Request(BaseModel): streaming_mode: Optional[bool] = False parallel_infer: Optional[bool] = True repetition_penalty: Optional[float] = 1.35 + device: Optional[str] = None class TTSRole_Request(BaseModel): text: str @@ -301,17 +307,18 @@ class TTSRole_Request(BaseModel): def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file: audio_file.write(data) + io_buffer.seek(0) return io_buffer def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): io_buffer.write(data.tobytes()) + io_buffer.seek(0) return io_buffer def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): - with BytesIO() as wav_buf: - sf.write(wav_buf, data, rate, format='wav') - wav_buf.seek(0) - return wav_buf + sf.write(io_buffer, data, rate, format='wav') + io_buffer.seek(0) + return io_buffer def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): process = subprocess.Popen([ @@ -320,6 +327,7 @@ def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) out, _ = process.communicate(input=data.tobytes()) io_buffer.write(out) + io_buffer.seek(0) return io_buffer def pack_audio(data: np.ndarray, rate: int, media_type: str) -> BytesIO: @@ -332,7 +340,6 @@ def pack_audio(data: np.ndarray, rate: int, media_type: str) -> BytesIO: io_buffer = pack_wav(io_buffer, data, rate) else: io_buffer = pack_raw(io_buffer, data, rate) - io_buffer.seek(0) return io_buffer def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): @@ -490,6 +497,52 @@ def select_ref_audio(role: str, text_lang: str, emotion: str = None): return None, None, None +def set_pipeline_device(pipeline: TTS, device: str): + """将 TTS 管道中的所有模型和相关组件迁移到指定设备,仅在设备变化时执行""" + if not torch.cuda.is_available() and device.startswith("cuda"): + print(f"警告: CUDA 不可用,强制使用 CPU") + device = "cpu" + + target_device = torch.device(device) + + # 检查当前设备是否需要切换 + current_device = None + if hasattr(pipeline, 't2s_model') and pipeline.t2s_model is not None: + current_device = next(pipeline.t2s_model.parameters()).device + elif hasattr(pipeline, 'vits_model') and pipeline.vits_model is not None: + current_device = next(pipeline.vits_model.parameters()).device + + if current_device == target_device: + print(f"设备已是 {device},无需切换") + return + + # 更新配置中的设备 + if hasattr(pipeline, 'configs') and hasattr(pipeline.configs, 'device'): + pipeline.configs.device = device + + # 迁移所有可能的模型到指定设备 + for attr in ['t2s_model', 'vits_model']: + if hasattr(pipeline, attr) and getattr(pipeline, attr) is not None: + getattr(pipeline, attr).to(target_device) + + for attr in dir(pipeline): + if attr.endswith('_model') and getattr(pipeline, attr) is not None: + try: + getattr(pipeline, attr).to(target_device) + print(f"迁移 {attr} 到 {device}") + except AttributeError: + pass + + # 清理 GPU 缓存 + if torch.cuda.is_available() and not device.startswith("cuda"): + torch.cuda.empty_cache() + + print(f"TTS 管道设备已设置为: {device}") + +def run_tts_pipeline(req): + """在线程池中运行 TTS 任务""" + return tts_pipeline.run(req) + async def tts_handle(req: dict, is_ttsrole: bool = False): streaming_mode = req.get("streaming_mode", False) media_type = req.get("media_type", "wav") @@ -501,6 +554,15 @@ async def tts_handle(req: dict, is_ttsrole: bool = False): if check_res is not None: return JSONResponse(status_code=400, content=check_res) + # 如果请求中指定了 device,则覆盖所有与设备相关的参数并更新管道设备 + if "device" in req and req["device"] is not None: + device = req["device"] + req["t2s_model_device"] = device + req["vits_model_device"] = device + if hasattr(tts_config, 'device'): + tts_config.device = device + set_pipeline_device(tts_pipeline, device) + if is_ttsrole: role_exists = load_role_config(req["role"], req) @@ -546,7 +608,15 @@ async def tts_handle(req: dict, is_ttsrole: bool = False): req["return_fragment"] = True try: - tts_generator = tts_pipeline.run(req) + print(f"当前请求设备: {req.get('device')}") + if hasattr(tts_pipeline, 't2s_model') and tts_pipeline.t2s_model is not None: + print(f"t2s_model 设备: {next(tts_pipeline.t2s_model.parameters()).device}") + if hasattr(tts_pipeline, 'vits_model') and tts_pipeline.vits_model is not None: + print(f"vits_model 设备: {next(tts_pipeline.vits_model.parameters()).device}") + + # 异步执行 TTS 任务 + loop = asyncio.get_event_loop() + tts_generator = await loop.run_in_executor(executor, run_tts_pipeline, req) if streaming_mode: def streaming_generator(): @@ -558,14 +628,11 @@ async def tts_handle(req: dict, is_ttsrole: bool = False): for sr, chunk in tts_generator: buf = pack_audio(chunk, sr, stream_type) yield buf.getvalue() - buf.close() return StreamingResponse(streaming_generator(), media_type=f"audio/{media_type}") else: sr, audio_data = next(tts_generator) buf = pack_audio(audio_data, sr, media_type) - response = Response(buf.getvalue(), media_type=f"audio/{media_type}") - buf.close() - return response + return Response(buf.getvalue(), media_type=f"audio/{media_type}") except Exception as e: return JSONResponse(status_code=400, content={"status": "error", "message": "tts failed", "exception": str(e)}) @@ -596,7 +663,8 @@ async def tts_get_endpoint( media_type: Optional[str] = "wav", streaming_mode: Optional[bool] = False, parallel_infer: Optional[bool] = True, - repetition_penalty: Optional[float] = 1.35 + repetition_penalty: Optional[float] = 1.35, + device: Optional[str] = None ): req = { "text": text, @@ -618,7 +686,8 @@ async def tts_get_endpoint( "media_type": media_type, "streaming_mode": streaming_mode, "parallel_infer": parallel_infer, - "repetition_penalty": repetition_penalty + "repetition_penalty": repetition_penalty, + "device": device } return await tts_handle(req) @@ -747,7 +816,7 @@ async def set_refer_audio(refer_audio_path: str = None): if __name__ == "__main__": try: - if host == 'None': # 在调用时使用 -a None 参数,可以让api监听双栈 + if host == 'None': # 在调用时使用 -a None 参数,可以让api监听双栈 host = None uvicorn.run(app=APP, host=host, port=port, workers=1) except Exception as e: