mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-07-03 12:38:12 +08:00
- Frontend: add wavesurfer.js v7 waveform visualization with region-based audio trimming - Frontend: add export trimmed audio button, OfflineAudioContext-based client-side trimming - API: add OpenAPI tags, descriptions, and summaries for all endpoints - API: enhance /health endpoint with PID, memory, and GPU info (optional psutil/torch) - API: bump version to 1.1.0, enable /docs and /redoc - Docs: rewrite simple_api.md as comprehensive API reference - Docs: update simple_api_quickstart.md with Swagger/ReDoc links - Docs: update README with endpoint table and feature list - Tests: fix DummyFastAPI mock to accept **kwargs (tags, summary, etc.) - All 7 tests pass, compile check OK
788 lines
28 KiB
Python
788 lines
28 KiB
Python
"""
|
||
Small profile-based API layer for GPT-SoVITS.
|
||
|
||
Run:
|
||
python simple_api.py -c simple_api.yaml
|
||
|
||
Then call:
|
||
POST http://127.0.0.1:9881/speak
|
||
{"text": "hello", "voice": "default"}
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import base64
|
||
import os
|
||
import shutil
|
||
import subprocess
|
||
import sys
|
||
import threading
|
||
import traceback
|
||
import uuid
|
||
import wave
|
||
from copy import deepcopy
|
||
from io import BytesIO
|
||
from pathlib import Path
|
||
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
|
||
|
||
import numpy as np
|
||
import soundfile as sf
|
||
import yaml
|
||
from fastapi import FastAPI, File, Form, HTTPException, Response, UploadFile
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import JSONResponse, StreamingResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from pydantic import BaseModel, Field
|
||
|
||
# Directory that contains this file; relative paths in this module are
# resolved against it.
PROJECT_ROOT = Path(__file__).resolve().parent
# Switch the working directory to the project root (presumably so that
# relative paths inside the GPT-SoVITS configs resolve -- TODO confirm).
os.chdir(PROJECT_ROOT)
# Extend the import path so the GPT_SoVITS package imports below resolve.
sys.path.append(str(PROJECT_ROOT))
sys.path.append(str(PROJECT_ROOT / "GPT_SoVITS"))
|
||
|
||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402
|
||
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import ( # noqa: E402
|
||
get_method_names as get_cut_method_names,
|
||
)
|
||
|
||
|
||
# Baseline request handed to the TTS pipeline. build_tts_request() copies
# this dict, then layers config defaults, the voice profile and the request
# on top. Only keys present here are accepted from a request payload.
DEFAULT_TTS_PARAMS: Dict[str, Any] = {
    "text": "",
    "text_lang": "zh",
    "ref_audio_path": "",
    "aux_ref_audio_paths": [],
    "prompt_text": "",
    "prompt_lang": "zh",
    "top_k": 15,
    "top_p": 1.0,
    "temperature": 1.0,
    "text_split_method": "cut5",
    "batch_size": 1,
    "batch_threshold": 0.75,
    "split_bucket": True,
    "speed_factor": 1.0,
    "fragment_interval": 0.3,
    "seed": -1,
    "media_type": "wav",
    "streaming_mode": False,
    "return_fragment": False,
    "fixed_length_chunk": False,
    "parallel_infer": True,
    "repetition_penalty": 1.35,
    "sample_steps": 32,
    "super_sampling": False,
    "overlap_length": 2,
    "min_chunk_length": 16,
}
|
||
|
||
# Output formats accepted by build_tts_request() / pack_audio().
SUPPORTED_MEDIA_TYPES = {"wav", "raw", "ogg", "aac"}
# File suffixes accepted for uploaded reference audio.
SUPPORTED_UPLOAD_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a", ".aac"}
# Voice-profile keys that are metadata only and are not copied into the
# TTS request.
VOICE_META_KEYS = {"description"}
|
||
|
||
|
||
# JSON body for the profile-based endpoints (/speak, /v1/tts, /speak/base64).
# Every field except `text` is optional; unset fields are dropped by
# request_to_dict() so the voice profile / defaults apply instead.
class SpeakRequest(BaseModel):
    text: str
    # Name of a voice profile from the config; resolved in build_tts_request().
    voice: Optional[str] = None
    text_lang: Optional[str] = None
    ref_audio_path: Optional[str] = None
    aux_ref_audio_paths: Optional[List[str]] = None
    prompt_text: Optional[str] = None
    prompt_lang: Optional[str] = None
    # Request-level aliases: `format` -> media_type, `speed` -> speed_factor.
    format: Optional[str] = Field(default=None, description="wav, raw, ogg, or aac")
    stream: Optional[bool] = Field(default=None, description="true maps to streaming mode 2")
    streaming_mode: Optional[Union[bool, int]] = Field(default=None, description="0, 1, 2, 3, true, or false")
    speed: Optional[float] = Field(default=None, description="Alias of speed_factor")
    speed_factor: Optional[float] = None
    # The remaining fields mirror keys of DEFAULT_TTS_PARAMS one-to-one.
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    temperature: Optional[float] = None
    text_split_method: Optional[str] = None
    batch_size: Optional[int] = None
    batch_threshold: Optional[float] = None
    split_bucket: Optional[bool] = None
    fragment_interval: Optional[float] = None
    seed: Optional[int] = None
    parallel_infer: Optional[bool] = None
    repetition_penalty: Optional[float] = None
    sample_steps: Optional[int] = None
    super_sampling: Optional[bool] = None
    overlap_length: Optional[int] = None
    min_chunk_length: Optional[int] = None
|
||
|
||
|
||
# JSON body for /admin/weights; set_weights() requires at least one path.
class WeightsRequest(BaseModel):
    gpt_weights_path: Optional[str] = None
    sovits_weights_path: Optional[str] = None
|
||
|
||
|
||
def parse_args() -> argparse.Namespace:
    """Parse the command line options of the simple API server.

    Returns:
        Namespace with ``config``, ``tts_config``, ``bind_addr`` and ``port``.
    """
    cli = argparse.ArgumentParser(description="GPT-SoVITS simple API")
    option_specs = (
        (("-c", "--config"), {"default": "simple_api.yaml", "help": "simple API config path"}),
        (("--tts-config",), {"default": None, "help": "GPT-SoVITS tts_infer.yaml path"}),
        (("-a", "--bind-addr"), {"default": None, "help": "bind address"}),
        (("-p", "--port"), {"type": int, "default": None, "help": "bind port"}),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
|
||
|
||
|
||
def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
    """Load the simple API YAML config.

    Relative paths are resolved against ``PROJECT_ROOT``.

    Raises:
        FileNotFoundError: the resolved path does not exist.
        ValueError: the document is not a mapping.
    """
    resolved = Path(config_path)
    if not resolved.is_absolute():
        resolved = PROJECT_ROOT / resolved
    if not resolved.exists():
        raise FileNotFoundError(f"simple API config not found: {resolved}")

    with resolved.open("r", encoding="utf-8") as handle:
        loaded = yaml.safe_load(handle)

    # An empty (or otherwise falsy) document is treated as an empty mapping.
    document = loaded or {}
    if isinstance(document, dict):
        return document
    raise ValueError("simple API config must be a YAML object")
|
||
|
||
|
||
# Module-level configuration, evaluated once at import time.
args = parse_args()
simple_config = load_yaml_config(args.config)
server_config = simple_config.get("server", {}) or {}
# Command line values win over the "server" section of the YAML config.
host = args.bind_addr or server_config.get("host", "127.0.0.1")
# NOTE(review): `or` treats an explicit `--port 0` as "not given".
port = args.port or int(server_config.get("port", 9881))
tts_config_path = args.tts_config or server_config.get("tts_config", "GPT_SoVITS/configs/tts_infer.yaml")

cut_method_names = get_cut_method_names()
# Both are populated by the startup() hook; None until then.
tts_config: Optional[TTS_Config] = None
tts_pipeline: Optional[TTS] = None
# Held around pipeline inference and weight switching.
infer_lock = threading.Lock()
|
||
|
||
# FastAPI application object; the description and tag texts below are
# user-facing strings rendered on /docs and /redoc.
APP = FastAPI(
    title="GPT-SoVITS Simple API",
    description=(
        "简化接口层,封装 GPT-SoVITS 推理引擎。\n\n"
        "## 核心流程\n"
        "1. 上传 3-10 秒参考音频(或视频,前端自动提取音频)\n"
        "2. 填写需要生成的文字\n"
        "3. 调用 `/api/tts` 获取生成的音频\n\n"
        "## 其他接口\n"
        "- `/speak` — 基于 voice profile 的调用方式\n"
        "- `/v1/tts` — OpenAI 兼容格式\n"
        "- `/admin/*` — 管理接口(热加载配置、切换模型)"
    ),
    version="1.1.0",
    docs_url="/docs",
    redoc_url="/redoc",
    openapi_tags=[
        {"name": "MVP", "description": "核心 TTS 接口,上传参考音频直接生成"},
        {"name": "Profile", "description": "基于 voice profile 的调用方式"},
        {"name": "Admin", "description": "管理接口:热加载配置、切换模型权重"},
        {"name": "System", "description": "健康检查与信息查询"},
    ],
)
# CORS origins come from the config and default to allowing every origin.
APP.add_middleware(
    CORSMiddleware,
    allow_origins=simple_config.get("cors_allow_origins", ["*"]),
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Serve the optional static test frontend under /test when the directory exists.
test_frontend_dir = PROJECT_ROOT / "test_frontend"
if test_frontend_dir.exists():
    APP.mount("/test", StaticFiles(directory=str(test_frontend_dir), html=True), name="test_frontend")
|
||
|
||
|
||
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
    """Encode ``data`` as mono OGG into ``io_buffer`` and return the buffer.

    The encoder runs in a helper thread created with a 16 MiB stack
    (presumably to work around deep recursion in the OGG encoder -- TODO
    confirm). If the stack size cannot be changed or the thread cannot be
    started, encoding runs in the calling thread instead.

    Raises:
        Whatever the encoder raised; an error in the helper thread is
        re-raised in the caller instead of being lost.
    """

    def handle_pack_ogg() -> None:
        with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
            audio_file.write(data)

    worker_errors: List[BaseException] = []

    def run_in_worker() -> None:
        # Exceptions do not cross thread boundaries on their own; capture
        # them so the caller does not receive a silently truncated buffer.
        try:
            handle_pack_ogg()
        except BaseException as exc:  # re-raised in the calling thread below
            worker_errors.append(exc)

    try:
        previous_stack_size = threading.stack_size(4096 * 4096)
    except (RuntimeError, ValueError):
        previous_stack_size = None

    pack_thread = None
    if previous_stack_size is not None:
        try:
            candidate = threading.Thread(target=run_in_worker)
            candidate.start()
            pack_thread = candidate
        except RuntimeError:
            pack_thread = None
        finally:
            # stack_size() is process-wide and only matters when a thread
            # is created; restore it so unrelated threads are unaffected.
            threading.stack_size(previous_stack_size)

    if pack_thread is None:
        handle_pack_ogg()
    else:
        pack_thread.join()
        if worker_errors:
            raise worker_errors[0]
    return io_buffer
|
||
|
||
|
||
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
    """Append the raw bytes of ``data`` to ``io_buffer`` and return it.

    ``rate`` is accepted only for signature parity with the other packers.
    """
    del rate
    raw_bytes = data.tobytes()
    io_buffer.write(raw_bytes)
    return io_buffer
|
||
|
||
|
||
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
    """Encode ``data`` as WAV into a fresh buffer and return that buffer.

    The incoming ``io_buffer`` is discarded; the parameter exists only for
    signature parity with the other packers.
    """
    del io_buffer
    encoded = BytesIO()
    sf.write(encoded, data, rate, format="wav")
    return encoded
|
||
|
||
|
||
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
    """Encode ``data`` as ADTS/AAC via ffmpeg and append it to ``io_buffer``.

    The sample bytes are piped to ffmpeg declared as mono ``s16le`` at
    ``rate`` Hz (assumes ``data`` holds 16-bit samples -- TODO confirm).

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status. Without this
            check a failed encode was returned as empty or partial audio.
    """
    command = [
        "ffmpeg",
        "-f",
        "s16le",
        "-ar",
        str(rate),
        "-ac",
        "1",
        "-i",
        "pipe:0",
        "-c:a",
        "aac",
        "-b:a",
        "192k",
        "-vn",
        "-f",
        "adts",
        "pipe:1",
    ]
    process = subprocess.Popen(
        command,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    out, err = process.communicate(input=data.tobytes())
    if process.returncode != 0:
        message = (err or b"").decode("utf-8", errors="replace").strip()
        raise RuntimeError(f"ffmpeg aac encoding failed ({process.returncode}): {message}")
    io_buffer.write(out)
    return io_buffer
|
||
|
||
|
||
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str) -> BytesIO:
    """Encode ``data`` in the requested format and return a rewound buffer.

    ``ogg``, ``aac`` and ``wav`` use their dedicated packers; any other
    ``media_type`` falls back to raw bytes.
    """
    packers = {"ogg": pack_ogg, "aac": pack_aac, "wav": pack_wav}
    packer = packers.get(media_type, pack_raw)
    packed = packer(io_buffer, data, rate)
    packed.seek(0)
    return packed
|
||
|
||
|
||
def wave_header_chunk(frame_input: bytes = b"", channels: int = 1, sample_width: int = 2, sample_rate: int = 32000) -> bytes:
    """Return WAV bytes (header followed by ``frame_input``) for the given format.

    Called with no frames it yields just a header, which the streaming
    response sends ahead of the raw audio chunks.
    """
    header_buffer = BytesIO()
    with wave.open(header_buffer, "wb") as writer:
        writer.setnchannels(channels)
        writer.setsampwidth(sample_width)
        writer.setframerate(sample_rate)
        writer.writeframes(frame_input)
    return header_buffer.getvalue()
|
||
|
||
|
||
def request_to_dict(request: BaseModel) -> Dict[str, Any]:
    """Dump a request model to a dict without its ``None`` fields.

    Uses ``model_dump`` when the model provides it and falls back to
    ``dict`` otherwise.
    """
    dump = request.model_dump if hasattr(request, "model_dump") else request.dict
    return dump(exclude_none=True)
|
||
|
||
|
||
def get_default_voice_name() -> Optional[str]:
    """Pick the voice used when a request names none.

    Order: the configured ``default_voice``, then a profile literally
    named ``default``, then the first configured profile, else ``None``.
    """
    configured = simple_config.get("default_voice")
    if configured:
        return str(configured)

    profiles = simple_config.get("voices", {}) or {}
    if "default" in profiles:
        return "default"
    for profile_name in profiles.keys():
        return profile_name
    return None
|
||
|
||
|
||
def resolve_project_path(path_value: Optional[str]) -> Optional[str]:
    """Turn a possibly relative path into a string path under the project root.

    ``None`` and the empty string map to ``None``; absolute paths are
    returned unchanged (as strings).
    """
    if path_value is None or path_value == "":
        return None
    candidate = Path(str(path_value))
    if candidate.is_absolute():
        return str(candidate)
    return str(PROJECT_ROOT / candidate)
|
||
|
||
|
||
def normalize_streaming(streaming_mode: Optional[Union[bool, int]], stream: Optional[bool]) -> Tuple[bool, bool, bool, bool]:
    """Map the streaming options to pipeline flags.

    ``streaming_mode`` of ``None`` falls back to ``stream``; booleans (from
    either argument) mean mode 2 when true and mode 0 when false.

    Returns:
        ``(streaming_mode, return_fragment, fixed_length_chunk,
        response_stream)`` for modes 0-3.

    Raises:
        HTTPException: 400 for any other mode value.
    """
    if streaming_mode is None:
        mode = 2 if stream else 0
    elif isinstance(streaming_mode, bool):
        mode = 2 if streaming_mode else 0
    else:
        mode = streaming_mode

    flag_table = (
        (0, (False, False, False, False)),
        (1, (False, True, False, True)),
        (2, (True, False, False, True)),
        (3, (True, False, True, True)),
    )
    for known_mode, flags in flag_table:
        if mode == known_mode:
            return flags
    raise HTTPException(status_code=400, detail="streaming_mode must be 0, 1, 2, 3, true, or false")
|
||
|
||
|
||
def get_upload_config() -> Dict[str, Any]:
    """Return the ``upload`` section of the config, or an empty dict."""
    section = simple_config.get("upload", {})
    return section or {}
|
||
|
||
|
||
def get_upload_root() -> Path:
    """Return the upload root directory, creating it when missing.

    The directory comes from ``upload.dir`` (default ``runtime/uploads``);
    relative values are placed under ``PROJECT_ROOT``.
    """
    configured_dir = get_upload_config().get("dir", "runtime/uploads")
    upload_root = Path(str(configured_dir))
    if not upload_root.is_absolute():
        upload_root = PROJECT_ROOT / upload_root
    upload_root.mkdir(parents=True, exist_ok=True)
    return upload_root
|
||
|
||
|
||
def get_request_upload_dir() -> Path:
    """Create and return a uniquely named directory under the upload root."""
    unique_name = uuid.uuid4().hex
    scratch_dir = get_upload_root() / unique_name
    scratch_dir.mkdir(parents=True, exist_ok=True)
    return scratch_dir
|
||
|
||
|
||
def get_upload_limits() -> Tuple[float, float, int]:
    """Read the upload limits from the config.

    Returns:
        ``(min_ref_seconds, max_ref_seconds, max_upload_mb)`` with defaults
        of 3.0, 10.0 and 80.
    """
    limits = get_upload_config()
    return (
        float(limits.get("min_ref_seconds", 3.0)),
        float(limits.get("max_ref_seconds", 10.0)),
        int(limits.get("max_upload_mb", 80)),
    )
|
||
|
||
|
||
# Content types (without parameters) mapped to the suffix used on disk.
_CONTENT_TYPE_SUFFIXES = {
    "audio/wav": ".wav",
    "audio/x-wav": ".wav",
    "audio/wave": ".wav",
    "audio/mpeg": ".mp3",
    "audio/ogg": ".ogg",
    "audio/aac": ".aac",
    "audio/aacp": ".aac",
    "audio/flac": ".flac",
    "audio/x-flac": ".flac",
    "audio/mp4": ".m4a",
    "audio/x-m4a": ".m4a",
}


def get_upload_suffix(upload: UploadFile) -> str:
    """Choose the file suffix for an uploaded audio file.

    The lower-cased suffix of the upload's filename wins when present.
    Otherwise the suffix is derived from the content type, ignoring any
    parameters (``audio/ogg; codecs=opus`` counts as ``audio/ogg``).
    Unknown or missing content types fall back to ``.wav``.
    """
    suffix = Path(upload.filename or "").suffix.lower()
    if suffix:
        return suffix
    content_type = (upload.content_type or "").lower()
    # Drop parameters such as ";codecs=opus" before the lookup.
    base_type = content_type.split(";", 1)[0].strip()
    return _CONTENT_TYPE_SUFFIXES.get(base_type, ".wav")
|
||
|
||
|
||
def validate_audio_upload(upload: UploadFile) -> None:
    """Reject uploads that do not look like supported audio.

    Raises:
        HTTPException: 400 when the suffix is not in
            ``SUPPORTED_UPLOAD_EXTENSIONS`` or the content type is neither
            ``audio/*``, ``application/octet-stream`` nor empty.
    """
    declared_type = (upload.content_type or "").lower()
    suffix_ok = get_upload_suffix(upload) in SUPPORTED_UPLOAD_EXTENSIONS
    type_ok = declared_type.startswith("audio/") or declared_type in {"application/octet-stream", ""}
    if suffix_ok and type_ok:
        return
    raise HTTPException(status_code=400, detail=f"unsupported audio upload: {upload.filename or declared_type}")
|
||
|
||
|
||
async def save_upload_file(upload: UploadFile, target_dir: Path, name_prefix: str) -> str:
    """Validate an upload and write it to ``target_dir`` in 1 MiB chunks.

    The file is named ``name_prefix`` plus the detected suffix. The upload
    is always closed, also on failure.

    Returns:
        The path of the written file as a string.

    Raises:
        HTTPException: 400 when the upload is unsupported, empty, or larger
            than the configured ``max_upload_mb``.
    """
    validate_audio_upload(upload)
    max_upload_mb = get_upload_limits()[2]
    byte_limit = max_upload_mb * 1024 * 1024
    destination = target_dir / f"{name_prefix}{get_upload_suffix(upload)}"
    chunk_size = 1024 * 1024
    written = 0

    try:
        with destination.open("wb") as sink:
            piece = await upload.read(chunk_size)
            while piece:
                written += len(piece)
                if written > byte_limit:
                    raise HTTPException(status_code=400, detail=f"audio upload exceeds {max_upload_mb} MB")
                sink.write(piece)
                piece = await upload.read(chunk_size)
        if written == 0:
            raise HTTPException(status_code=400, detail=f"audio upload is empty: {upload.filename or name_prefix}")
    finally:
        await upload.close()

    return str(destination)
|
||
|
||
|
||
def get_audio_duration_with_ffprobe(audio_path: str) -> float:
    """Ask ``ffprobe`` for the container duration of ``audio_path`` in seconds.

    Raises:
        RuntimeError: ffprobe exited with a non-zero status.
        ValueError: ffprobe's output is not a number.
    """
    command = [
        "ffprobe",
        "-v",
        "error",
        "-show_entries",
        "format=duration",
        "-of",
        "default=noprint_wrappers=1:nokey=1",
        audio_path,
    ]
    probe = subprocess.run(command, capture_output=True, text=True, timeout=15)
    if probe.returncode == 0:
        return float(probe.stdout.strip())
    raise RuntimeError(probe.stderr.strip() or "ffprobe failed")
|
||
|
||
|
||
def get_audio_duration_seconds(audio_path: str) -> float:
    """Return the duration of an audio file in seconds.

    ``soundfile`` is tried first; on any failure ``ffprobe`` is used.

    Raises:
        HTTPException: 400 when both readers fail; the detail carries both
            error messages.
    """
    try:
        meta = sf.info(audio_path)
        if info_rate_invalid := (meta.samplerate <= 0):
            raise ValueError("invalid sample rate")
        return float(meta.frames) / float(meta.samplerate)
    except Exception as sf_exc:
        try:
            return get_audio_duration_with_ffprobe(audio_path)
        except Exception as ffprobe_exc:
            detail = f"unable to read audio duration: {audio_path}. soundfile={sf_exc}; ffprobe={ffprobe_exc}"
            raise HTTPException(status_code=400, detail=detail) from ffprobe_exc
|
||
|
||
|
||
def validate_main_ref_duration(audio_path: str) -> None:
    """Check that the main reference audio is within the configured length.

    Raises:
        HTTPException: 400 when the duration is shorter than
            ``min_ref_seconds`` or longer than ``max_ref_seconds``.
    """
    lower_bound, upper_bound, _ = get_upload_limits()
    seconds = get_audio_duration_seconds(audio_path)
    too_short = seconds < lower_bound
    too_long = seconds > upper_bound
    if too_short or too_long:
        message = f"ref_audio duration must be between {lower_bound:g}s and {upper_bound:g}s, got {seconds:.2f}s"
        raise HTTPException(status_code=400, detail=message)
|
||
|
||
|
||
def apply_emotion_preset(payload: Dict[str, Any], emotion: Optional[str]) -> None:
    """Merge the configured preset for ``emotion`` into ``payload`` in place.

    ``None`` and the empty string are a no-op. The name is stripped and
    lower-cased before the lookup in ``emotion_presets``.

    Raises:
        HTTPException: 400 when the emotion is not configured.
    """
    if emotion is None or emotion == "":
        return

    presets = simple_config.get("emotion_presets", {}) or {}
    preset_key = str(emotion).strip().lower()
    if preset_key in presets:
        payload.update(presets[preset_key] or {})
        return

    supported = ", ".join(sorted(presets.keys())) or "none"
    raise HTTPException(status_code=400, detail=f"emotion is not configured: {preset_key}. supported: {supported}")
|
||
|
||
|
||
def prompt_text_required_for_current_model() -> bool:
    """Tell whether the loaded model needs a non-empty ``prompt_text``.

    False while no config is loaded. Otherwise the config's ``use_vocoder``
    attribute decides when it is set; without it, versions ``v3`` and
    ``v4`` require the prompt text.
    """
    if tts_config is None:
        return False
    vocoder_flag = getattr(tts_config, "use_vocoder", None)
    if vocoder_flag is None:
        return getattr(tts_config, "version", None) in {"v3", "v4"}
    return bool(vocoder_flag)
|
||
|
||
|
||
def voice_public_info(name: str, profile: Dict[str, Any]) -> Dict[str, Any]:
    """Build the public description of a voice profile for ``/voices``.

    ``ready`` is true when the resolved reference audio exists on disk and
    the profile sets ``prompt_lang``. ``ref_audio_path`` is reported as
    configured, not resolved.
    """
    configured_ref = profile.get("ref_audio_path")
    resolved_ref = resolve_project_path(configured_ref)
    is_ready = bool(resolved_ref and Path(resolved_ref).exists() and profile.get("prompt_lang"))

    info: Dict[str, Any] = {"name": name}
    info["description"] = profile.get("description", "")
    info["text_lang"] = profile.get("text_lang")
    info["prompt_lang"] = profile.get("prompt_lang")
    info["ref_audio_path"] = configured_ref
    info["ready"] = is_ready
    return info
|
||
|
||
|
||
def build_tts_request(payload: Dict[str, Any]) -> Tuple[Dict[str, Any], str, bool]:
    """Merge defaults, voice profile and request into a validated TTS request.

    Precedence, lowest to highest: ``DEFAULT_TTS_PARAMS``, the config's
    ``defaults`` section, the selected voice profile (minus metadata keys),
    then the request payload. ``format`` and ``speed`` are request aliases
    for ``media_type`` and ``speed_factor``. ``payload`` is modified in
    place (``voice``, ``stream``, ``format`` and ``speed`` are popped).

    Args:
        payload: Request fields; only keys of ``DEFAULT_TTS_PARAMS`` (plus
            the aliases above and ``voice``) are used.

    Returns:
        ``(tts_req, media_type, response_stream)`` where ``response_stream``
        says whether the HTTP response should be streamed.

    Raises:
        HTTPException: 404 for an unknown voice; 400 for invalid or missing
            fields; 503 while the TTS config is not loaded.
    """
    voices = simple_config.get("voices", {}) or {}
    explicit_ref_audio = bool(payload.get("ref_audio_path"))
    # Remember whether the caller set a streaming mode. The merged request
    # always has one (the defaults seed it with False), so this must be
    # checked on the payload itself.
    request_sets_streaming_mode = payload.get("streaming_mode") is not None
    voice_name = payload.pop("voice", None)
    # Only fall back to the default voice when the request brings no
    # reference audio of its own.
    if voice_name is None and not explicit_ref_audio:
        voice_name = get_default_voice_name()
    voice_profile: Dict[str, Any] = {}

    if voice_name:
        if voice_name not in voices:
            raise HTTPException(status_code=404, detail=f"voice not found: {voice_name}")
        voice_profile = deepcopy(voices[voice_name] or {})

    tts_req = deepcopy(DEFAULT_TTS_PARAMS)
    tts_req.update(simple_config.get("defaults", {}) or {})
    tts_req.update({k: v for k, v in voice_profile.items() if k not in VOICE_META_KEYS})

    stream = payload.pop("stream", None)
    media_type = payload.pop("format", None)
    speed = payload.pop("speed", None)
    if media_type is not None:
        tts_req["media_type"] = media_type

    for key, value in payload.items():
        if key in DEFAULT_TTS_PARAMS:
            tts_req[key] = value
    if speed is not None:
        tts_req["speed_factor"] = speed

    ref_audio_path = resolve_project_path(tts_req.get("ref_audio_path"))
    if ref_audio_path:
        tts_req["ref_audio_path"] = ref_audio_path

    aux_paths = tts_req.get("aux_ref_audio_paths") or []
    tts_req["aux_ref_audio_paths"] = [resolve_project_path(item) for item in aux_paths if item]

    media_type = str(tts_req.get("media_type", "wav")).lower()
    tts_req["media_type"] = media_type
    if media_type not in SUPPORTED_MEDIA_TYPES:
        raise HTTPException(status_code=400, detail=f"format is not supported: {media_type}")

    text = str(tts_req.get("text") or "").strip()
    if not text:
        raise HTTPException(status_code=400, detail="text is required")
    tts_req["text"] = text

    if not tts_req.get("ref_audio_path"):
        raise HTTPException(status_code=400, detail="ref_audio_path is required in voice profile or request")
    if not Path(tts_req["ref_audio_path"]).exists():
        raise HTTPException(status_code=400, detail=f"ref_audio_path does not exist: {tts_req['ref_audio_path']}")

    for aux_path in tts_req["aux_ref_audio_paths"]:
        if aux_path and not Path(aux_path).exists():
            raise HTTPException(status_code=400, detail=f"aux_ref_audio_path does not exist: {aux_path}")

    if tts_config is None:
        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")

    text_lang = str(tts_req.get("text_lang") or "").lower()
    prompt_lang = str(tts_req.get("prompt_lang") or "").lower()
    tts_req["text_lang"] = text_lang
    tts_req["prompt_lang"] = prompt_lang

    if text_lang not in tts_config.languages:
        raise HTTPException(status_code=400, detail=f"text_lang is not supported: {text_lang}")
    if prompt_lang not in tts_config.languages:
        raise HTTPException(status_code=400, detail=f"prompt_lang is not supported: {prompt_lang}")
    if not str(tts_req.get("prompt_text") or "").strip() and prompt_text_required_for_current_model():
        version = getattr(tts_config, "version", "current")
        raise HTTPException(status_code=400, detail=f"prompt_text is required when using GPT-SoVITS {version}")

    split_method = str(tts_req.get("text_split_method") or "cut5")
    if split_method not in cut_method_names:
        raise HTTPException(status_code=400, detail=f"text_split_method is not supported: {split_method}")

    streaming_mode = tts_req.pop("streaming_mode", None)
    if stream is not None and not request_sets_streaming_mode:
        # An explicit `stream` in the request overrides the streaming mode
        # inherited from defaults/profile; previously the inherited value
        # (never None) always won and `stream` was ignored.
        streaming_mode = None
    streaming_enabled, return_fragment, fixed_length_chunk, response_stream = normalize_streaming(streaming_mode, stream)
    tts_req["streaming_mode"] = streaming_enabled
    tts_req["return_fragment"] = return_fragment
    tts_req["fixed_length_chunk"] = fixed_length_chunk

    return tts_req, media_type, response_stream
|
||
|
||
|
||
def synthesize_once(payload: Dict[str, Any]) -> Tuple[bytes, str]:
    """Run one non-streaming synthesis and return ``(audio_bytes, media_type)``.

    Raises:
        HTTPException: 400 when streaming was requested or synthesis fails;
            503 while the pipeline is not loaded; plus anything raised by
            ``build_tts_request``.
    """
    request_params, audio_format, wants_stream = build_tts_request(payload)
    if wants_stream:
        raise HTTPException(status_code=400, detail="base64 output does not support streaming")

    if tts_pipeline is None:
        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")

    try:
        # The lock covers inference only; only the first result is used.
        with infer_lock:
            sample_rate, samples = next(tts_pipeline.run(request_params))
        encoded = pack_audio(BytesIO(), samples, sample_rate, audio_format)
        return encoded.getvalue(), audio_format
    except Exception as exc:
        failure = {"message": "tts failed", "exception": str(exc)}
        raise HTTPException(status_code=400, detail=failure) from exc
|
||
|
||
|
||
def synthesize_response(payload: Dict[str, Any]) -> Response:
    """Synthesize speech for ``payload`` and wrap it in an HTTP response.

    Streaming requests get a ``StreamingResponse``; for ``wav`` a WAV
    header is sent first and the chunks follow as raw samples. Otherwise
    the whole audio is returned at once, or a 400 JSON error if synthesis
    fails.

    Raises:
        HTTPException: 503 while the pipeline is not loaded, plus anything
            raised by ``build_tts_request``.
    """
    request_params, audio_format, wants_stream = build_tts_request(payload)
    if tts_pipeline is None:
        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")

    content_type = f"audio/{audio_format}"

    if wants_stream:

        def streaming_generator() -> Generator[bytes, None, None]:
            # For wav the header goes out once, ahead of the first chunk,
            # and every chunk is then packed as raw samples.
            header_pending = audio_format == "wav"
            chunk_format = "raw" if header_pending else audio_format
            with infer_lock:
                for sample_rate, chunk in tts_pipeline.run(request_params):
                    if header_pending:
                        yield wave_header_chunk(sample_rate=sample_rate)
                        header_pending = False
                    yield pack_audio(BytesIO(), chunk, sample_rate, chunk_format).getvalue()

        return StreamingResponse(streaming_generator(), media_type=content_type)

    try:
        # The lock covers inference only; only the first result is used.
        with infer_lock:
            sample_rate, samples = next(tts_pipeline.run(request_params))
        body = pack_audio(BytesIO(), samples, sample_rate, audio_format).getvalue()
        return Response(body, media_type=content_type)
    except Exception as exc:
        return JSONResponse(status_code=400, content={"message": "tts failed", "exception": str(exc)})
|
||
|
||
|
||
@APP.on_event("startup")
|
||
def startup() -> None:
|
||
global tts_config, tts_pipeline
|
||
tts_config = TTS_Config(tts_config_path)
|
||
print(tts_config)
|
||
tts_pipeline = TTS(tts_config)
|
||
|
||
|
||
@APP.get("/", tags=["System"])
|
||
def index() -> Dict[str, Any]:
|
||
return {
|
||
"name": "GPT-SoVITS Simple API",
|
||
"version": "1.1.0",
|
||
"docs": "/docs",
|
||
"endpoints": {
|
||
"system": ["/health", "/voices"],
|
||
"mvp": ["/api/tts"],
|
||
"profile": ["/speak", "/speak/base64", "/v1/tts"],
|
||
"admin": ["/admin/reload-config", "/admin/weights"],
|
||
},
|
||
}
|
||
|
||
|
||
@APP.get("/health", tags=["System"], summary="健康检查")
|
||
def health() -> Dict[str, Any]:
|
||
import os
|
||
|
||
result: Dict[str, Any] = {
|
||
"status": "ok" if tts_pipeline is not None else "starting",
|
||
"tts_config": tts_config_path,
|
||
"version": getattr(tts_config, "version", None),
|
||
"languages": getattr(tts_config, "languages", []),
|
||
"pid": os.getpid(),
|
||
}
|
||
try:
|
||
import psutil
|
||
result["memory_mb"] = round(psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024, 1)
|
||
except Exception:
|
||
pass
|
||
try:
|
||
import torch
|
||
if torch.cuda.is_available():
|
||
result["gpu"] = {
|
||
"name": torch.cuda.get_device_name(0),
|
||
"memory_used_mb": round(torch.cuda.memory_allocated(0) / 1024 / 1024, 1),
|
||
"memory_total_mb": round(torch.cuda.get_device_properties(0).total_mem / 1024 / 1024, 1),
|
||
}
|
||
except Exception:
|
||
pass
|
||
return result
|
||
|
||
|
||
@APP.get("/voices", tags=["System"], summary="列出可用 voice profiles")
|
||
def list_voices() -> Dict[str, Any]:
|
||
voices = simple_config.get("voices", {}) or {}
|
||
return {
|
||
"default_voice": get_default_voice_name(),
|
||
"voices": [voice_public_info(name, profile or {}) for name, profile in voices.items()],
|
||
}
|
||
|
||
|
||
# Core upload endpoint: stores the uploaded reference audio in a
# per-request directory, validates it, synthesizes once (never streaming,
# always split with cut5) and removes the directory afterwards.
@APP.post(
    "/api/tts",
    tags=["MVP"],
    summary="核心 TTS 接口",
    description=(
        "上传参考音频和需要生成的文字,返回生成的音频。\n\n"
        "**主参考音频要求**:3-10 秒,支持 wav/flac/ogg/mp3/m4a/aac 格式。\n\n"
        "**文字切句**:固定使用 `cut5`(按标点符号切句)。\n\n"
        "**情绪预设**:neutral / happy / calm / sad / angry,本质是映射到采样和语速参数。"
    ),
)
async def mvp_tts(
    text: str = Form(...),
    ref_audio: UploadFile = File(...),
    aux_ref_audio: Optional[List[UploadFile]] = File(default=None),
    prompt_text: str = Form(default=""),
    text_lang: str = Form(default="zh"),
    prompt_lang: str = Form(default="zh"),
    format: str = Form(default="wav"),
    emotion: Optional[str] = Form(default=None),
    speed: Optional[float] = Form(default=None),
    seed: int = Form(default=-1),
) -> Response:
    import asyncio

    request_dir = get_request_upload_dir()
    try:
        ref_audio_path = await save_upload_file(ref_audio, request_dir, "ref")
        validate_main_ref_duration(ref_audio_path)

        aux_ref_audio_paths = []
        for index, upload in enumerate(aux_ref_audio or []):
            aux_path = await save_upload_file(upload, request_dir, f"aux_{index}")
            # Called only for its validation side effect: it raises 400
            # when the file cannot be read as audio.
            get_audio_duration_seconds(aux_path)
            aux_ref_audio_paths.append(aux_path)

        payload: Dict[str, Any] = {
            "text": text,
            "text_lang": text_lang,
            "ref_audio_path": ref_audio_path,
            "aux_ref_audio_paths": aux_ref_audio_paths,
            "prompt_text": prompt_text or "",
            "prompt_lang": prompt_lang,
            "format": format,
            "text_split_method": "cut5",
            "streaming_mode": 0,
            "seed": seed,
        }
        apply_emotion_preset(payload, emotion)
        # Re-assert the fixed settings in case the preset overrode them.
        payload["text_split_method"] = "cut5"
        payload["streaming_mode"] = 0
        if speed is not None:
            payload["speed"] = speed

        # Inference is blocking; running it directly inside this coroutine
        # would stall the event loop (and every other request) for the
        # whole synthesis. Offload it to the default executor instead. The
        # await completes before `finally` deletes the uploaded files.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, synthesize_response, payload)
    finally:
        shutil.rmtree(request_dir, ignore_errors=True)
|
||
|
||
|
||
@APP.get("/speak", tags=["Profile"], summary="GET 方式调用 voice profile TTS")
|
||
def speak_get(
|
||
text: str,
|
||
voice: Optional[str] = None,
|
||
text_lang: Optional[str] = None,
|
||
format: Optional[str] = None,
|
||
stream: Optional[bool] = None,
|
||
speed: Optional[float] = None,
|
||
) -> Response:
|
||
payload = {
|
||
"text": text,
|
||
"voice": voice,
|
||
"text_lang": text_lang,
|
||
"format": format,
|
||
"stream": stream,
|
||
"speed": speed,
|
||
}
|
||
return synthesize_response({k: v for k, v in payload.items() if v is not None})
|
||
|
||
|
||
@APP.post("/speak", tags=["Profile"], summary="POST 方式调用 voice profile TTS")
|
||
def speak_post(request: SpeakRequest) -> Response:
|
||
return synthesize_response(request_to_dict(request))
|
||
|
||
|
||
@APP.post("/v1/tts", tags=["Profile"], summary="OpenAI 兼容格式 TTS")
|
||
def openai_style_tts(request: SpeakRequest) -> Response:
|
||
return synthesize_response(request_to_dict(request))
|
||
|
||
|
||
@APP.post("/speak/base64", tags=["Profile"], summary="返回 Base64 编码的音频")
|
||
def speak_base64(request: SpeakRequest) -> Dict[str, Any]:
|
||
audio_bytes, media_type = synthesize_once(request_to_dict(request))
|
||
return {
|
||
"media_type": f"audio/{media_type}",
|
||
"audio_base64": base64.b64encode(audio_bytes).decode("ascii"),
|
||
}
|
||
|
||
|
||
@APP.post("/admin/reload-config", tags=["Admin"], summary="热加载 simple_api.yaml 配置")
|
||
def reload_config() -> Dict[str, Any]:
|
||
global simple_config
|
||
simple_config = load_yaml_config(args.config)
|
||
return {"message": "success", "default_voice": get_default_voice_name()}
|
||
|
||
|
||
@APP.post("/admin/weights", tags=["Admin"], summary="运行时切换 GPT-SoVITS 模型权重")
|
||
def set_weights(request: WeightsRequest) -> Dict[str, Any]:
|
||
if tts_pipeline is None:
|
||
raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
|
||
if not request.gpt_weights_path and not request.sovits_weights_path:
|
||
raise HTTPException(status_code=400, detail="gpt_weights_path or sovits_weights_path is required")
|
||
|
||
try:
|
||
with infer_lock:
|
||
if request.gpt_weights_path:
|
||
tts_pipeline.init_t2s_weights(request.gpt_weights_path)
|
||
if request.sovits_weights_path:
|
||
tts_pipeline.init_vits_weights(request.sovits_weights_path)
|
||
except Exception as exc:
|
||
raise HTTPException(status_code=400, detail={"message": "change weights failed", "exception": str(exc)}) from exc
|
||
|
||
return {"message": "success"}
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import uvicorn
|
||
|
||
try:
|
||
uvicorn.run(app=APP, host=None if host == "None" else host, port=port, workers=1)
|
||
except Exception:
|
||
traceback.print_exc()
|
||
sys.exit(1)
|