Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git, synced 2025-04-06 03:57:44 +08:00
111

This commit is contained in:
parent 4b4673f66d
commit 80733266f5

api_role.py (93 changed lines)
@@ -182,6 +182,8 @@ import traceback
 from typing import Generator, Optional, List, Dict
 import random
 import glob
+from concurrent.futures import ThreadPoolExecutor
+import asyncio
 
 now_dir = os.getcwd()
 sys.path.append(now_dir)
@@ -239,6 +241,9 @@ if hasattr(tts_config, 'device'):
     tts_config.device = default_device
 tts_pipeline = TTS(tts_config)
 
+# Create a thread pool for running TTS tasks asynchronously
+executor = ThreadPoolExecutor(max_workers=1)
+
 APP = FastAPI()
 
 class TTS_Request(BaseModel):
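
Because tts_pipeline.run is a blocking call, this commit hands it to a single-worker thread pool so the FastAPI event loop stays responsive while synthesis runs. A minimal sketch of that pattern, not part of the commit and using only the standard library (blocking_tts is a hypothetical stand-in for tts_pipeline.run):

# Sketch: a single-worker ThreadPoolExecutor serializes blocking jobs while
# the asyncio event loop keeps serving other requests.
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)  # one synthesis job at a time

def blocking_tts(req: dict) -> str:
    time.sleep(1.0)                 # simulate heavy synthesis work
    return f"audio for {req['text']}"

async def handle(req: dict) -> str:
    loop = asyncio.get_event_loop()
    # The coroutine yields here, so other requests are served in the meantime.
    return await loop.run_in_executor(executor, blocking_tts, req)

async def main() -> None:
    # The second job waits for the first because max_workers=1.
    print(await asyncio.gather(handle({"text": "a"}), handle({"text": "b"})))

asyncio.run(main())
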
@@ -262,6 +267,7 @@ class TTS_Request(BaseModel):
     streaming_mode: Optional[bool] = False
     parallel_infer: Optional[bool] = True
     repetition_penalty: Optional[float] = 1.35
+    device: Optional[str] = None
 
 class TTSRole_Request(BaseModel):
     text: str
@@ -301,17 +307,18 @@ class TTSRole_Request(BaseModel):
 def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
     with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
         audio_file.write(data)
+    io_buffer.seek(0)
     return io_buffer
 
 def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
     io_buffer.write(data.tobytes())
+    io_buffer.seek(0)
     return io_buffer
 
 def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
-    with BytesIO() as wav_buf:
-        sf.write(wav_buf, data, rate, format='wav')
-        wav_buf.seek(0)
-        return wav_buf
+    sf.write(io_buffer, data, rate, format='wav')
+    io_buffer.seek(0)
+    return io_buffer
 
 def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
     process = subprocess.Popen([
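
The pack_wav rewrite above matters because the old version returned wav_buf from inside a `with BytesIO()` block, so the buffer was closed before the caller could read it. A short sketch of the new approach, assuming numpy and soundfile (both already used by api_role.py); the sine tone is only example data:

# Sketch: encode WAV data directly into the caller's BytesIO and rewind it,
# mirroring the updated pack_wav.
from io import BytesIO
import numpy as np
import soundfile as sf

rate = 32000
tone = (0.3 * np.sin(2 * np.pi * 440 * np.arange(rate) / rate)).astype(np.float32)

buf = BytesIO()
sf.write(buf, tone, rate, format='wav')  # encode into the shared buffer
buf.seek(0)                              # rewind so the caller reads from the start
print(len(buf.getvalue()), "bytes of WAV data")
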
@@ -320,6 +327,7 @@ def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
     ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
     out, _ = process.communicate(input=data.tobytes())
     io_buffer.write(out)
+    io_buffer.seek(0)
     return io_buffer
 
 def pack_audio(data: np.ndarray, rate: int, media_type: str) -> BytesIO:
@@ -332,7 +340,6 @@ def pack_audio(data: np.ndarray, rate: int, media_type: str) -> BytesIO:
         io_buffer = pack_wav(io_buffer, data, rate)
     else:
         io_buffer = pack_raw(io_buffer, data, rate)
-    io_buffer.seek(0)
     return io_buffer
 
 def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
@@ -490,6 +497,52 @@ def select_ref_audio(role: str, text_lang: str, emotion: str = None):
 
     return None, None, None
 
+def set_pipeline_device(pipeline: TTS, device: str):
+    """Move all models and related components of the TTS pipeline to the given device; only act when the device changes."""
+    if not torch.cuda.is_available() and device.startswith("cuda"):
+        print(f"Warning: CUDA is not available, forcing CPU")
+        device = "cpu"
+
+    target_device = torch.device(device)
+
+    # Check whether the current device actually needs to change
+    current_device = None
+    if hasattr(pipeline, 't2s_model') and pipeline.t2s_model is not None:
+        current_device = next(pipeline.t2s_model.parameters()).device
+    elif hasattr(pipeline, 'vits_model') and pipeline.vits_model is not None:
+        current_device = next(pipeline.vits_model.parameters()).device
+
+    if current_device == target_device:
+        print(f"Device is already {device}, no switch needed")
+        return
+
+    # Update the device stored in the config
+    if hasattr(pipeline, 'configs') and hasattr(pipeline.configs, 'device'):
+        pipeline.configs.device = device
+
+    # Move every known model to the target device
+    for attr in ['t2s_model', 'vits_model']:
+        if hasattr(pipeline, attr) and getattr(pipeline, attr) is not None:
+            getattr(pipeline, attr).to(target_device)
+
+    for attr in dir(pipeline):
+        if attr.endswith('_model') and getattr(pipeline, attr) is not None:
+            try:
+                getattr(pipeline, attr).to(target_device)
+                print(f"Moved {attr} to {device}")
+            except AttributeError:
+                pass
+
+    # Clear the GPU cache
+    if torch.cuda.is_available() and not device.startswith("cuda"):
+        torch.cuda.empty_cache()
+
+    print(f"TTS pipeline device set to: {device}")
+
+def run_tts_pipeline(req):
+    """Run the TTS task inside the thread pool."""
+    return tts_pipeline.run(req)
+
 async def tts_handle(req: dict, is_ttsrole: bool = False):
     streaming_mode = req.get("streaming_mode", False)
     media_type = req.get("media_type", "wav")
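
The new set_pipeline_device reads the device of an existing model's parameters and only calls .to() when it differs from the target. A toy sketch of that check, assuming only torch (ToyModel is a hypothetical stand-in for t2s_model / vits_model):

# Sketch: read a module's current device from its first parameter and move it
# only when it differs from the requested target, with the same CPU fallback.
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

def move_if_needed(model: nn.Module, device: str) -> nn.Module:
    if not torch.cuda.is_available() and device.startswith("cuda"):
        device = "cpu"                             # fall back to CPU like the diff does
    target = torch.device(device)
    current = next(model.parameters()).device      # device of the first parameter
    if current == target:
        return model                               # already there, nothing to do
    return model.to(target)

model = move_if_needed(ToyModel(), "cuda:0")
print(next(model.parameters()).device)
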
@@ -501,6 +554,15 @@ async def tts_handle(req: dict, is_ttsrole: bool = False):
     if check_res is not None:
         return JSONResponse(status_code=400, content=check_res)
 
+    # If the request specifies a device, override all device-related parameters and update the pipeline device
+    if "device" in req and req["device"] is not None:
+        device = req["device"]
+        req["t2s_model_device"] = device
+        req["vits_model_device"] = device
+        if hasattr(tts_config, 'device'):
+            tts_config.device = device
+        set_pipeline_device(tts_pipeline, device)
+
     if is_ttsrole:
         role_exists = load_role_config(req["role"], req)
 
@@ -546,7 +608,15 @@ async def tts_handle(req: dict, is_ttsrole: bool = False):
         req["return_fragment"] = True
 
     try:
-        tts_generator = tts_pipeline.run(req)
+        print(f"Requested device: {req.get('device')}")
+        if hasattr(tts_pipeline, 't2s_model') and tts_pipeline.t2s_model is not None:
+            print(f"t2s_model device: {next(tts_pipeline.t2s_model.parameters()).device}")
+        if hasattr(tts_pipeline, 'vits_model') and tts_pipeline.vits_model is not None:
+            print(f"vits_model device: {next(tts_pipeline.vits_model.parameters()).device}")
+
+        # Run the TTS task asynchronously
+        loop = asyncio.get_event_loop()
+        tts_generator = await loop.run_in_executor(executor, run_tts_pipeline, req)
 
         if streaming_mode:
             def streaming_generator():
@@ -558,14 +628,11 @@ async def tts_handle(req: dict, is_ttsrole: bool = False):
                 for sr, chunk in tts_generator:
                     buf = pack_audio(chunk, sr, stream_type)
                     yield buf.getvalue()
-                    buf.close()
             return StreamingResponse(streaming_generator(), media_type=f"audio/{media_type}")
         else:
             sr, audio_data = next(tts_generator)
             buf = pack_audio(audio_data, sr, media_type)
-            response = Response(buf.getvalue(), media_type=f"audio/{media_type}")
-            buf.close()
-            return response
+            return Response(buf.getvalue(), media_type=f"audio/{media_type}")
     except Exception as e:
         return JSONResponse(status_code=400, content={"status": "error", "message": "tts failed", "exception": str(e)})
 
@@ -596,7 +663,8 @@ async def tts_get_endpoint(
     media_type: Optional[str] = "wav",
     streaming_mode: Optional[bool] = False,
     parallel_infer: Optional[bool] = True,
-    repetition_penalty: Optional[float] = 1.35
+    repetition_penalty: Optional[float] = 1.35,
+    device: Optional[str] = None
 ):
     req = {
         "text": text,
@@ -618,7 +686,8 @@ async def tts_get_endpoint(
         "media_type": media_type,
         "streaming_mode": streaming_mode,
         "parallel_infer": parallel_infer,
-        "repetition_penalty": repetition_penalty
+        "repetition_penalty": repetition_penalty,
+        "device": device
     }
     return await tts_handle(req)
 
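
With the commit applied, a caller can pin a request to a specific device through the new device query parameter. A usage sketch, assuming the GET route is served at /tts on 127.0.0.1:9880, that the endpoint keeps api_v2-style parameter names, and that requests is installed (none of this appears in the diff itself):

# Hypothetical client call; adjust URL, reference clip and languages to your setup.
import requests

params = {
    "text": "Hello there",
    "text_lang": "en",
    "ref_audio_path": "ref.wav",   # placeholder reference clip
    "prompt_lang": "en",
    "media_type": "wav",
    "device": "cuda:0",            # query parameter added by this commit
}
resp = requests.get("http://127.0.0.1:9880/tts", params=params, timeout=120)
resp.raise_for_status()
with open("out.wav", "wb") as f:
    f.write(resp.content)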