spawner1145 2025-03-04 17:53:25 +08:00
parent 4b4673f66d
commit 80733266f5


@ -182,6 +182,8 @@ import traceback
from typing import Generator, Optional, List, Dict
import random
import glob
from concurrent.futures import ThreadPoolExecutor
import asyncio
now_dir = os.getcwd()
sys.path.append(now_dir)
@ -239,6 +241,9 @@ if hasattr(tts_config, 'device'):
tts_config.device = default_device
tts_pipeline = TTS(tts_config)
# Create a thread pool for running TTS tasks asynchronously
executor = ThreadPoolExecutor(max_workers=1)
APP = FastAPI()
class TTS_Request(BaseModel):
@ -262,6 +267,7 @@ class TTS_Request(BaseModel):
streaming_mode: Optional[bool] = False
parallel_infer: Optional[bool] = True
repetition_penalty: Optional[float] = 1.35
device: Optional[str] = None
class TTSRole_Request(BaseModel):
text: str
@ -301,17 +307,18 @@ class TTSRole_Request(BaseModel):
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
audio_file.write(data)
io_buffer.seek(0)
return io_buffer
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
io_buffer.write(data.tobytes())
io_buffer.seek(0)
return io_buffer
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
with BytesIO() as wav_buf:
sf.write(wav_buf, data, rate, format='wav')
wav_buf.seek(0)
return wav_buf
sf.write(io_buffer, data, rate, format='wav')
io_buffer.seek(0)
return io_buffer
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
process = subprocess.Popen([
@ -320,6 +327,7 @@ def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, _ = process.communicate(input=data.tobytes())
io_buffer.write(out)
io_buffer.seek(0)
return io_buffer
def pack_audio(data: np.ndarray, rate: int, media_type: str) -> BytesIO:
@ -332,7 +340,6 @@ def pack_audio(data: np.ndarray, rate: int, media_type: str) -> BytesIO:
io_buffer = pack_wav(io_buffer, data, rate)
else:
io_buffer = pack_raw(io_buffer, data, rate)
io_buffer.seek(0)
return io_buffer
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
@ -490,6 +497,52 @@ def select_ref_audio(role: str, text_lang: str, emotion: str = None):
return None, None, None
def set_pipeline_device(pipeline: TTS, device: str):
"""将 TTS 管道中的所有模型和相关组件迁移到指定设备,仅在设备变化时执行"""
if not torch.cuda.is_available() and device.startswith("cuda"):
print(f"警告: CUDA 不可用,强制使用 CPU")
device = "cpu"
target_device = torch.device(device)
# Check whether the pipeline already sits on the target device
current_device = None
if hasattr(pipeline, 't2s_model') and pipeline.t2s_model is not None:
current_device = next(pipeline.t2s_model.parameters()).device
elif hasattr(pipeline, 'vits_model') and pipeline.vits_model is not None:
current_device = next(pipeline.vits_model.parameters()).device
if current_device == target_device:
print(f"设备已是 {device},无需切换")
return
# Update the device stored in the pipeline config
if hasattr(pipeline, 'configs') and hasattr(pipeline.configs, 'device'):
pipeline.configs.device = device
# Move every known model to the target device
for attr in ['t2s_model', 'vits_model']:
if hasattr(pipeline, attr) and getattr(pipeline, attr) is not None:
getattr(pipeline, attr).to(target_device)
for attr in dir(pipeline):
if attr.endswith('_model') and getattr(pipeline, attr) is not None:
try:
getattr(pipeline, attr).to(target_device)
print(f"迁移 {attr}{device}")
except AttributeError:
pass
# Free cached GPU memory when moving off CUDA
if torch.cuda.is_available() and not device.startswith("cuda"):
torch.cuda.empty_cache()
print(f"TTS 管道设备已设置为: {device}")
def run_tts_pipeline(req):
"""在线程池中运行 TTS 任务"""
return tts_pipeline.run(req)
async def tts_handle(req: dict, is_ttsrole: bool = False):
streaming_mode = req.get("streaming_mode", False)
media_type = req.get("media_type", "wav")
@ -501,6 +554,15 @@ async def tts_handle(req: dict, is_ttsrole: bool = False):
if check_res is not None:
return JSONResponse(status_code=400, content=check_res)
# If the request specifies a device, override every device-related parameter and update the pipeline device
if "device" in req and req["device"] is not None:
device = req["device"]
req["t2s_model_device"] = device
req["vits_model_device"] = device
if hasattr(tts_config, 'device'):
tts_config.device = device
set_pipeline_device(tts_pipeline, device)
if is_ttsrole:
role_exists = load_role_config(req["role"], req)
@ -546,7 +608,15 @@ async def tts_handle(req: dict, is_ttsrole: bool = False):
req["return_fragment"] = True
try:
tts_generator = tts_pipeline.run(req)
print(f"当前请求设备: {req.get('device')}")
if hasattr(tts_pipeline, 't2s_model') and tts_pipeline.t2s_model is not None:
print(f"t2s_model 设备: {next(tts_pipeline.t2s_model.parameters()).device}")
if hasattr(tts_pipeline, 'vits_model') and tts_pipeline.vits_model is not None:
print(f"vits_model 设备: {next(tts_pipeline.vits_model.parameters()).device}")
# Run the blocking TTS call in the thread pool so the event loop stays responsive
loop = asyncio.get_event_loop()
tts_generator = await loop.run_in_executor(executor, run_tts_pipeline, req)
if streaming_mode:
def streaming_generator():
@ -558,14 +628,11 @@ async def tts_handle(req: dict, is_ttsrole: bool = False):
for sr, chunk in tts_generator:
buf = pack_audio(chunk, sr, stream_type)
yield buf.getvalue()
buf.close()
return StreamingResponse(streaming_generator(), media_type=f"audio/{media_type}")
else:
sr, audio_data = next(tts_generator)
buf = pack_audio(audio_data, sr, media_type)
response = Response(buf.getvalue(), media_type=f"audio/{media_type}")
buf.close()
return response
return Response(buf.getvalue(), media_type=f"audio/{media_type}")
except Exception as e:
return JSONResponse(status_code=400, content={"status": "error", "message": "tts failed", "exception": str(e)})
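The offload above uses a standard asyncio pattern: the blocking tts_pipeline.run call is handed to the single-worker ThreadPoolExecutor via loop.run_in_executor, so TTS jobs are serialized onto one worker thread while the FastAPI event loop keeps serving other connections. A minimal, self-contained sketch of the same pattern (blocking_job is a hypothetical stand-in for the real pipeline call, and get_running_loop is the modern equivalent of the get_event_loop call used above):

import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)  # one worker: jobs run strictly one at a time

def blocking_job(x):
    # stand-in for tts_pipeline.run(req); any synchronous, CPU/GPU-bound work
    return x * 2

async def handler(x):
    loop = asyncio.get_running_loop()
    # the event loop is not blocked while blocking_job runs on the worker thread
    return await loop.run_in_executor(executor, blocking_job, x)

print(asyncio.run(handler(21)))  # prints 42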
@ -596,7 +663,8 @@ async def tts_get_endpoint(
media_type: Optional[str] = "wav",
streaming_mode: Optional[bool] = False,
parallel_infer: Optional[bool] = True,
repetition_penalty: Optional[float] = 1.35
repetition_penalty: Optional[float] = 1.35,
device: Optional[str] = None
):
req = {
"text": text,
@ -618,7 +686,8 @@ async def tts_get_endpoint(
"media_type": media_type,
"streaming_mode": streaming_mode,
"parallel_infer": parallel_infer,
"repetition_penalty": repetition_penalty
"repetition_penalty": repetition_penalty,
"device": device
}
return await tts_handle(req)
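With the new device field wired through both the request model and the GET endpoint, a client can request a per-call device switch. A hypothetical client call follows; the host, port, and "/tts" route are assumptions, and the other required TTS fields (reference audio, languages, and so on) are omitted for brevity:

import requests

resp = requests.get(
    "http://127.0.0.1:9880/tts",  # assumed host/port/route
    params={
        "text": "hello world",
        "media_type": "wav",
        "device": "cuda:0",       # triggers set_pipeline_device before inference
    },
)
with open("out.wav", "wb") as f:
    f.write(resp.content)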
@ -747,7 +816,7 @@ async def set_refer_audio(refer_audio_path: str = None):
if __name__ == "__main__":
try:
if host == 'None': # passing -a None on the command line lets the API listen on both IPv4 and IPv6 (dual stack)
host = None
uvicorn.run(app=APP, host=host, port=port, workers=1)
except Exception as e:
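For reference, the compare-then-move logic in set_pipeline_device hinges on two torch facts: a module's current device can be read from next(module.parameters()).device, and .to(device) migrates the weights in place. A minimal sketch with a throwaway nn.Linear standing in for the real t2s/vits models:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
target = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

current = next(model.parameters()).device
if current != target:
    model.to(target)              # migrate weights only when the device differs
    if current.type == "cuda":
        torch.cuda.empty_cache()  # release cached blocks after leaving the GPU
print(next(model.parameters()).device)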