This commit is contained in:
spawner1145 2025-03-04 17:53:25 +08:00
parent 4b4673f66d
commit 80733266f5

View File

@ -182,6 +182,8 @@ import traceback
from typing import Generator, Optional, List, Dict from typing import Generator, Optional, List, Dict
import random import random
import glob import glob
from concurrent.futures import ThreadPoolExecutor
import asyncio
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
@ -239,6 +241,9 @@ if hasattr(tts_config, 'device'):
tts_config.device = default_device tts_config.device = default_device
tts_pipeline = TTS(tts_config) tts_pipeline = TTS(tts_config)
# 创建线程池用于异步执行 TTS 任务
executor = ThreadPoolExecutor(max_workers=1)
APP = FastAPI() APP = FastAPI()
class TTS_Request(BaseModel): class TTS_Request(BaseModel):
@ -262,6 +267,7 @@ class TTS_Request(BaseModel):
streaming_mode: Optional[bool] = False streaming_mode: Optional[bool] = False
parallel_infer: Optional[bool] = True parallel_infer: Optional[bool] = True
repetition_penalty: Optional[float] = 1.35 repetition_penalty: Optional[float] = 1.35
device: Optional[str] = None
class TTSRole_Request(BaseModel): class TTSRole_Request(BaseModel):
text: str text: str
@ -301,17 +307,18 @@ class TTSRole_Request(BaseModel):
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Encode mono audio samples into *io_buffer* as OGG and rewind it.

    Returns the same buffer, positioned at the start so callers can read
    the encoded bytes immediately.
    """
    with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as ogg_file:
        ogg_file.write(data)
    io_buffer.seek(0)
    return io_buffer
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Write *data* to *io_buffer* as headerless PCM bytes and rewind it.

    *rate* is unused here; it is kept so all pack_* helpers share one
    signature. Returns the buffer positioned at the start.
    """
    pcm_bytes = data.tobytes()
    io_buffer.write(pcm_bytes)
    io_buffer.seek(0)
    return io_buffer
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Encode *data* into *io_buffer* as a WAV file and rewind it.

    Writes directly into the caller-supplied buffer (no temporary copy)
    and returns it positioned at the start for immediate reading.
    """
    sf.write(io_buffer, data, rate, format='wav')
    io_buffer.seek(0)
    return io_buffer
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
process = subprocess.Popen([ process = subprocess.Popen([
@ -320,6 +327,7 @@ def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, _ = process.communicate(input=data.tobytes()) out, _ = process.communicate(input=data.tobytes())
io_buffer.write(out) io_buffer.write(out)
io_buffer.seek(0)
return io_buffer return io_buffer
def pack_audio(data: np.ndarray, rate: int, media_type: str) -> BytesIO: def pack_audio(data: np.ndarray, rate: int, media_type: str) -> BytesIO:
@ -332,7 +340,6 @@ def pack_audio(data: np.ndarray, rate: int, media_type: str) -> BytesIO:
io_buffer = pack_wav(io_buffer, data, rate) io_buffer = pack_wav(io_buffer, data, rate)
else: else:
io_buffer = pack_raw(io_buffer, data, rate) io_buffer = pack_raw(io_buffer, data, rate)
io_buffer.seek(0)
return io_buffer return io_buffer
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
@ -490,6 +497,52 @@ def select_ref_audio(role: str, text_lang: str, emotion: str = None):
return None, None, None return None, None, None
def set_pipeline_device(pipeline: "TTS", device: str):
    """Move every model held by the TTS pipeline to *device*.

    Skips the migration entirely when the pipeline's models already live on
    the requested device, and falls back to CPU when CUDA is requested but
    unavailable.

    Args:
        pipeline: TTS pipeline whose ``*_model`` attributes are torch modules.
        device: Target device string, e.g. ``"cpu"``, ``"cuda"``, ``"cuda:0"``.
    """
    # Degrade gracefully when CUDA was requested but is not present.
    if not torch.cuda.is_available() and device.startswith("cuda"):
        print("警告: CUDA 不可用,强制使用 CPU")
        device = "cpu"
    target_device = torch.device(device)
    # Probe the current device from whichever primary model is loaded.
    current_device = None
    if getattr(pipeline, 't2s_model', None) is not None:
        current_device = next(pipeline.t2s_model.parameters()).device
    elif getattr(pipeline, 'vits_model', None) is not None:
        current_device = next(pipeline.vits_model.parameters()).device
    # NOTE: torch.device("cuda") != torch.device("cuda:0"), so a plain
    # equality check misses "already there" and re-migrates needlessly.
    # Compare by type, and by index only when the target pins one.
    if (current_device is not None
            and current_device.type == target_device.type
            and (target_device.index is None
                 or current_device.index == target_device.index)):
        print(f"设备已是 {device},无需切换")
        return
    # Keep the pipeline configuration in sync with the new device.
    if hasattr(pipeline, 'configs') and hasattr(pipeline.configs, 'device'):
        pipeline.configs.device = device
    # Migrate every "*_model" attribute once; this also covers t2s_model
    # and vits_model, so no separate explicit pass is needed.
    for attr in dir(pipeline):
        if attr.endswith('_model') and getattr(pipeline, attr) is not None:
            try:
                getattr(pipeline, attr).to(target_device)
                print(f"迁移 {attr}{device}")
            except AttributeError:
                # Attribute is not a torch module; leave it untouched.
                pass
    # Release cached GPU memory after moving off the GPU.
    if torch.cuda.is_available() and not device.startswith("cuda"):
        torch.cuda.empty_cache()
    print(f"TTS 管道设备已设置为: {device}")
def run_tts_pipeline(req):
    """Run one TTS inference request on the module-level pipeline.

    Executed synchronously inside the thread-pool executor so the async
    request handler does not block the event loop. Returns whatever
    ``tts_pipeline.run`` produces — the callers iterate it as
    ``(sample_rate, audio_chunk)`` pairs.
    """
    return tts_pipeline.run(req)
async def tts_handle(req: dict, is_ttsrole: bool = False): async def tts_handle(req: dict, is_ttsrole: bool = False):
streaming_mode = req.get("streaming_mode", False) streaming_mode = req.get("streaming_mode", False)
media_type = req.get("media_type", "wav") media_type = req.get("media_type", "wav")
@ -501,6 +554,15 @@ async def tts_handle(req: dict, is_ttsrole: bool = False):
if check_res is not None: if check_res is not None:
return JSONResponse(status_code=400, content=check_res) return JSONResponse(status_code=400, content=check_res)
# 如果请求中指定了 device则覆盖所有与设备相关的参数并更新管道设备
if "device" in req and req["device"] is not None:
device = req["device"]
req["t2s_model_device"] = device
req["vits_model_device"] = device
if hasattr(tts_config, 'device'):
tts_config.device = device
set_pipeline_device(tts_pipeline, device)
if is_ttsrole: if is_ttsrole:
role_exists = load_role_config(req["role"], req) role_exists = load_role_config(req["role"], req)
@ -546,7 +608,15 @@ async def tts_handle(req: dict, is_ttsrole: bool = False):
req["return_fragment"] = True req["return_fragment"] = True
try: try:
tts_generator = tts_pipeline.run(req) print(f"当前请求设备: {req.get('device')}")
if hasattr(tts_pipeline, 't2s_model') and tts_pipeline.t2s_model is not None:
print(f"t2s_model 设备: {next(tts_pipeline.t2s_model.parameters()).device}")
if hasattr(tts_pipeline, 'vits_model') and tts_pipeline.vits_model is not None:
print(f"vits_model 设备: {next(tts_pipeline.vits_model.parameters()).device}")
# 异步执行 TTS 任务
loop = asyncio.get_event_loop()
tts_generator = await loop.run_in_executor(executor, run_tts_pipeline, req)
if streaming_mode: if streaming_mode:
def streaming_generator(): def streaming_generator():
@ -558,14 +628,11 @@ async def tts_handle(req: dict, is_ttsrole: bool = False):
for sr, chunk in tts_generator: for sr, chunk in tts_generator:
buf = pack_audio(chunk, sr, stream_type) buf = pack_audio(chunk, sr, stream_type)
yield buf.getvalue() yield buf.getvalue()
buf.close()
return StreamingResponse(streaming_generator(), media_type=f"audio/{media_type}") return StreamingResponse(streaming_generator(), media_type=f"audio/{media_type}")
else: else:
sr, audio_data = next(tts_generator) sr, audio_data = next(tts_generator)
buf = pack_audio(audio_data, sr, media_type) buf = pack_audio(audio_data, sr, media_type)
response = Response(buf.getvalue(), media_type=f"audio/{media_type}") return Response(buf.getvalue(), media_type=f"audio/{media_type}")
buf.close()
return response
except Exception as e: except Exception as e:
return JSONResponse(status_code=400, content={"status": "error", "message": "tts failed", "exception": str(e)}) return JSONResponse(status_code=400, content={"status": "error", "message": "tts failed", "exception": str(e)})
@ -596,7 +663,8 @@ async def tts_get_endpoint(
media_type: Optional[str] = "wav", media_type: Optional[str] = "wav",
streaming_mode: Optional[bool] = False, streaming_mode: Optional[bool] = False,
parallel_infer: Optional[bool] = True, parallel_infer: Optional[bool] = True,
repetition_penalty: Optional[float] = 1.35 repetition_penalty: Optional[float] = 1.35,
device: Optional[str] = None
): ):
req = { req = {
"text": text, "text": text,
@ -618,7 +686,8 @@ async def tts_get_endpoint(
"media_type": media_type, "media_type": media_type,
"streaming_mode": streaming_mode, "streaming_mode": streaming_mode,
"parallel_infer": parallel_infer, "parallel_infer": parallel_infer,
"repetition_penalty": repetition_penalty "repetition_penalty": repetition_penalty,
"device": device
} }
return await tts_handle(req) return await tts_handle(req)