""" Small profile-based API layer for GPT-SoVITS. Run: python simple_api.py -c simple_api.yaml Then call: POST http://127.0.0.1:9881/speak {"text": "hello", "voice": "default"} """ from __future__ import annotations import argparse import base64 import os import shutil import subprocess import sys import threading import traceback import uuid import wave from copy import deepcopy from io import BytesIO from pathlib import Path from typing import Any, Dict, Generator, List, Optional, Tuple, Union import numpy as np import soundfile as sf import yaml from fastapi import FastAPI, File, Form, HTTPException, Response, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel, Field PROJECT_ROOT = Path(__file__).resolve().parent os.chdir(PROJECT_ROOT) sys.path.append(str(PROJECT_ROOT)) sys.path.append(str(PROJECT_ROOT / "GPT_SoVITS")) from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402 from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import ( # noqa: E402 get_method_names as get_cut_method_names, ) DEFAULT_TTS_PARAMS: Dict[str, Any] = { "text": "", "text_lang": "zh", "ref_audio_path": "", "aux_ref_audio_paths": [], "prompt_text": "", "prompt_lang": "zh", "top_k": 15, "top_p": 1.0, "temperature": 1.0, "text_split_method": "cut5", "batch_size": 1, "batch_threshold": 0.75, "split_bucket": True, "speed_factor": 1.0, "fragment_interval": 0.3, "seed": -1, "media_type": "wav", "streaming_mode": False, "return_fragment": False, "fixed_length_chunk": False, "parallel_infer": True, "repetition_penalty": 1.35, "sample_steps": 32, "super_sampling": False, "overlap_length": 2, "min_chunk_length": 16, } SUPPORTED_MEDIA_TYPES = {"wav", "raw", "ogg", "aac"} SUPPORTED_UPLOAD_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a", ".aac"} VOICE_META_KEYS = {"description"} class SpeakRequest(BaseModel): text: str voice: 
Optional[str] = None text_lang: Optional[str] = None ref_audio_path: Optional[str] = None aux_ref_audio_paths: Optional[List[str]] = None prompt_text: Optional[str] = None prompt_lang: Optional[str] = None format: Optional[str] = Field(default=None, description="wav, raw, ogg, or aac") stream: Optional[bool] = Field(default=None, description="true maps to streaming mode 2") streaming_mode: Optional[Union[bool, int]] = Field(default=None, description="0, 1, 2, 3, true, or false") speed: Optional[float] = Field(default=None, description="Alias of speed_factor") speed_factor: Optional[float] = None top_k: Optional[int] = None top_p: Optional[float] = None temperature: Optional[float] = None text_split_method: Optional[str] = None batch_size: Optional[int] = None batch_threshold: Optional[float] = None split_bucket: Optional[bool] = None fragment_interval: Optional[float] = None seed: Optional[int] = None parallel_infer: Optional[bool] = None repetition_penalty: Optional[float] = None sample_steps: Optional[int] = None super_sampling: Optional[bool] = None overlap_length: Optional[int] = None min_chunk_length: Optional[int] = None class WeightsRequest(BaseModel): gpt_weights_path: Optional[str] = None sovits_weights_path: Optional[str] = None def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="GPT-SoVITS simple API") parser.add_argument("-c", "--config", default="simple_api.yaml", help="simple API config path") parser.add_argument("--tts-config", default=None, help="GPT-SoVITS tts_infer.yaml path") parser.add_argument("-a", "--bind-addr", default=None, help="bind address") parser.add_argument("-p", "--port", type=int, default=None, help="bind port") return parser.parse_args() def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]: path = Path(config_path) if not path.is_absolute(): path = PROJECT_ROOT / path if not path.exists(): raise FileNotFoundError(f"simple API config not found: {path}") with path.open("r", 
encoding="utf-8") as f: data = yaml.safe_load(f) or {} if not isinstance(data, dict): raise ValueError("simple API config must be a YAML object") return data args = parse_args() simple_config = load_yaml_config(args.config) server_config = simple_config.get("server", {}) or {} host = args.bind_addr or server_config.get("host", "127.0.0.1") port = args.port or int(server_config.get("port", 9881)) tts_config_path = args.tts_config or server_config.get("tts_config", "GPT_SoVITS/configs/tts_infer.yaml") cut_method_names = get_cut_method_names() tts_config: Optional[TTS_Config] = None tts_pipeline: Optional[TTS] = None infer_lock = threading.Lock() APP = FastAPI( title="GPT-SoVITS Simple API", description="Profile-based API layer that hides GPT-SoVITS request details.", version="1.0.0", ) APP.add_middleware( CORSMiddleware, allow_origins=simple_config.get("cors_allow_origins", ["*"]), allow_credentials=False, allow_methods=["*"], allow_headers=["*"], ) test_frontend_dir = PROJECT_ROOT / "test_frontend" if test_frontend_dir.exists(): APP.mount("/test", StaticFiles(directory=str(test_frontend_dir), html=True), name="test_frontend") def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO: def handle_pack_ogg() -> None: with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: audio_file.write(data) try: threading.stack_size(4096 * 4096) pack_thread = threading.Thread(target=handle_pack_ogg) pack_thread.start() pack_thread.join() except (RuntimeError, ValueError): handle_pack_ogg() return io_buffer def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO: del rate io_buffer.write(data.tobytes()) return io_buffer def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO: del io_buffer wav_buffer = BytesIO() sf.write(wav_buffer, data, rate, format="wav") return wav_buffer def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO: process = subprocess.Popen( [ "ffmpeg", "-f", 
"s16le", "-ar", str(rate), "-ac", "1", "-i", "pipe:0", "-c:a", "aac", "-b:a", "192k", "-vn", "-f", "adts", "pipe:1", ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) out, _ = process.communicate(input=data.tobytes()) io_buffer.write(out) return io_buffer def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str) -> BytesIO: if media_type == "ogg": io_buffer = pack_ogg(io_buffer, data, rate) elif media_type == "aac": io_buffer = pack_aac(io_buffer, data, rate) elif media_type == "wav": io_buffer = pack_wav(io_buffer, data, rate) else: io_buffer = pack_raw(io_buffer, data, rate) io_buffer.seek(0) return io_buffer def wave_header_chunk(frame_input: bytes = b"", channels: int = 1, sample_width: int = 2, sample_rate: int = 32000) -> bytes: wav_buf = BytesIO() with wave.open(wav_buf, "wb") as vfout: vfout.setnchannels(channels) vfout.setsampwidth(sample_width) vfout.setframerate(sample_rate) vfout.writeframes(frame_input) wav_buf.seek(0) return wav_buf.read() def request_to_dict(request: BaseModel) -> Dict[str, Any]: if hasattr(request, "model_dump"): return request.model_dump(exclude_none=True) return request.dict(exclude_none=True) def get_default_voice_name() -> Optional[str]: default_voice = simple_config.get("default_voice") if default_voice: return str(default_voice) voices = simple_config.get("voices", {}) or {} if "default" in voices: return "default" return next(iter(voices.keys()), None) def resolve_project_path(path_value: Optional[str]) -> Optional[str]: if path_value in [None, ""]: return None path = Path(str(path_value)) if not path.is_absolute(): path = PROJECT_ROOT / path return str(path) def normalize_streaming(streaming_mode: Optional[Union[bool, int]], stream: Optional[bool]) -> Tuple[bool, bool, bool, bool]: if streaming_mode is None: streaming_mode = 2 if stream else 0 elif isinstance(streaming_mode, bool): streaming_mode = 2 if streaming_mode else 0 if streaming_mode == 0: return False, False, False, 
False if streaming_mode == 1: return False, True, False, True if streaming_mode == 2: return True, False, False, True if streaming_mode == 3: return True, False, True, True raise HTTPException(status_code=400, detail="streaming_mode must be 0, 1, 2, 3, true, or false") def get_upload_config() -> Dict[str, Any]: return simple_config.get("upload", {}) or {} def get_upload_root() -> Path: upload_dir = str(get_upload_config().get("dir", "runtime/uploads")) root = Path(upload_dir) if not root.is_absolute(): root = PROJECT_ROOT / root root.mkdir(parents=True, exist_ok=True) return root def get_request_upload_dir() -> Path: request_dir = get_upload_root() / uuid.uuid4().hex request_dir.mkdir(parents=True, exist_ok=True) return request_dir def get_upload_limits() -> Tuple[float, float, int]: upload_config = get_upload_config() min_seconds = float(upload_config.get("min_ref_seconds", 3.0)) max_seconds = float(upload_config.get("max_ref_seconds", 10.0)) max_upload_mb = int(upload_config.get("max_upload_mb", 80)) return min_seconds, max_seconds, max_upload_mb def get_upload_suffix(upload: UploadFile) -> str: suffix = Path(upload.filename or "").suffix.lower() if suffix: return suffix content_type = (upload.content_type or "").lower() if content_type in {"audio/wav", "audio/x-wav", "audio/wave"}: return ".wav" if content_type == "audio/mpeg": return ".mp3" if content_type == "audio/ogg": return ".ogg" if content_type in {"audio/aac", "audio/aacp"}: return ".aac" return ".wav" def validate_audio_upload(upload: UploadFile) -> None: suffix = get_upload_suffix(upload) content_type = (upload.content_type or "").lower() is_audio_type = content_type.startswith("audio/") or content_type in {"application/octet-stream", ""} if suffix not in SUPPORTED_UPLOAD_EXTENSIONS or not is_audio_type: raise HTTPException(status_code=400, detail=f"unsupported audio upload: {upload.filename or content_type}") async def save_upload_file(upload: UploadFile, target_dir: Path, name_prefix: str) -> str: 
validate_audio_upload(upload) _, _, max_upload_mb = get_upload_limits() max_bytes = max_upload_mb * 1024 * 1024 suffix = get_upload_suffix(upload) target_path = target_dir / f"{name_prefix}{suffix}" total_size = 0 try: with target_path.open("wb") as f: while True: chunk = await upload.read(1024 * 1024) if not chunk: break total_size += len(chunk) if total_size > max_bytes: raise HTTPException(status_code=400, detail=f"audio upload exceeds {max_upload_mb} MB") f.write(chunk) if total_size == 0: raise HTTPException(status_code=400, detail=f"audio upload is empty: {upload.filename or name_prefix}") finally: await upload.close() return str(target_path) def get_audio_duration_with_ffprobe(audio_path: str) -> float: result = subprocess.run( [ "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", audio_path, ], capture_output=True, text=True, timeout=15, ) if result.returncode != 0: raise RuntimeError(result.stderr.strip() or "ffprobe failed") return float(result.stdout.strip()) def get_audio_duration_seconds(audio_path: str) -> float: try: info = sf.info(audio_path) if info.samplerate <= 0: raise ValueError("invalid sample rate") return float(info.frames) / float(info.samplerate) except Exception as sf_exc: try: return get_audio_duration_with_ffprobe(audio_path) except Exception as ffprobe_exc: raise HTTPException( status_code=400, detail=f"unable to read audio duration: {audio_path}. 
soundfile={sf_exc}; ffprobe={ffprobe_exc}", ) from ffprobe_exc def validate_main_ref_duration(audio_path: str) -> None: min_seconds, max_seconds, _ = get_upload_limits() duration = get_audio_duration_seconds(audio_path) if duration < min_seconds or duration > max_seconds: raise HTTPException( status_code=400, detail=f"ref_audio duration must be between {min_seconds:g}s and {max_seconds:g}s, got {duration:.2f}s", ) def apply_emotion_preset(payload: Dict[str, Any], emotion: Optional[str]) -> None: if emotion in [None, ""]: return emotion_name = str(emotion).strip().lower() emotion_presets = simple_config.get("emotion_presets", {}) or {} if emotion_name not in emotion_presets: supported = ", ".join(sorted(emotion_presets.keys())) or "none" raise HTTPException(status_code=400, detail=f"emotion is not configured: {emotion_name}. supported: {supported}") payload.update(emotion_presets[emotion_name] or {}) def prompt_text_required_for_current_model() -> bool: if tts_config is None: return False use_vocoder = getattr(tts_config, "use_vocoder", None) if use_vocoder is not None: return bool(use_vocoder) return getattr(tts_config, "version", None) in {"v3", "v4"} def voice_public_info(name: str, profile: Dict[str, Any]) -> Dict[str, Any]: ref_audio_path = resolve_project_path(profile.get("ref_audio_path")) return { "name": name, "description": profile.get("description", ""), "text_lang": profile.get("text_lang"), "prompt_lang": profile.get("prompt_lang"), "ref_audio_path": profile.get("ref_audio_path"), "ready": bool(ref_audio_path and Path(ref_audio_path).exists() and profile.get("prompt_lang")), } def build_tts_request(payload: Dict[str, Any]) -> Tuple[Dict[str, Any], str, bool]: voices = simple_config.get("voices", {}) or {} explicit_ref_audio = bool(payload.get("ref_audio_path")) voice_name = payload.pop("voice", None) if voice_name is None and not explicit_ref_audio: voice_name = get_default_voice_name() voice_profile: Dict[str, Any] = {} if voice_name: if voice_name not 
in voices: raise HTTPException(status_code=404, detail=f"voice not found: {voice_name}") voice_profile = deepcopy(voices[voice_name] or {}) tts_req = deepcopy(DEFAULT_TTS_PARAMS) tts_req.update(simple_config.get("defaults", {}) or {}) tts_req.update({k: v for k, v in voice_profile.items() if k not in VOICE_META_KEYS}) stream = payload.pop("stream", None) media_type = payload.pop("format", None) speed = payload.pop("speed", None) if media_type is not None: tts_req["media_type"] = media_type for key, value in payload.items(): if key in DEFAULT_TTS_PARAMS: tts_req[key] = value if speed is not None: tts_req["speed_factor"] = speed ref_audio_path = resolve_project_path(tts_req.get("ref_audio_path")) if ref_audio_path: tts_req["ref_audio_path"] = ref_audio_path aux_paths = tts_req.get("aux_ref_audio_paths") or [] tts_req["aux_ref_audio_paths"] = [resolve_project_path(item) for item in aux_paths if item] media_type = str(tts_req.get("media_type", "wav")).lower() tts_req["media_type"] = media_type if media_type not in SUPPORTED_MEDIA_TYPES: raise HTTPException(status_code=400, detail=f"format is not supported: {media_type}") text = str(tts_req.get("text") or "").strip() if not text: raise HTTPException(status_code=400, detail="text is required") tts_req["text"] = text if not tts_req.get("ref_audio_path"): raise HTTPException(status_code=400, detail="ref_audio_path is required in voice profile or request") if not Path(tts_req["ref_audio_path"]).exists(): raise HTTPException(status_code=400, detail=f"ref_audio_path does not exist: {tts_req['ref_audio_path']}") for aux_path in tts_req["aux_ref_audio_paths"]: if aux_path and not Path(aux_path).exists(): raise HTTPException(status_code=400, detail=f"aux_ref_audio_path does not exist: {aux_path}") if tts_config is None: raise HTTPException(status_code=503, detail="TTS pipeline is not ready") text_lang = str(tts_req.get("text_lang") or "").lower() prompt_lang = str(tts_req.get("prompt_lang") or "").lower() tts_req["text_lang"] = 
text_lang tts_req["prompt_lang"] = prompt_lang if text_lang not in tts_config.languages: raise HTTPException(status_code=400, detail=f"text_lang is not supported: {text_lang}") if prompt_lang not in tts_config.languages: raise HTTPException(status_code=400, detail=f"prompt_lang is not supported: {prompt_lang}") if not str(tts_req.get("prompt_text") or "").strip() and prompt_text_required_for_current_model(): version = getattr(tts_config, "version", "current") raise HTTPException(status_code=400, detail=f"prompt_text is required when using GPT-SoVITS {version}") split_method = str(tts_req.get("text_split_method") or "cut5") if split_method not in cut_method_names: raise HTTPException(status_code=400, detail=f"text_split_method is not supported: {split_method}") streaming_mode = tts_req.pop("streaming_mode", None) streaming_enabled, return_fragment, fixed_length_chunk, response_stream = normalize_streaming(streaming_mode, stream) tts_req["streaming_mode"] = streaming_enabled tts_req["return_fragment"] = return_fragment tts_req["fixed_length_chunk"] = fixed_length_chunk return tts_req, media_type, response_stream def synthesize_once(payload: Dict[str, Any]) -> Tuple[bytes, str]: tts_req, media_type, response_stream = build_tts_request(payload) if response_stream: raise HTTPException(status_code=400, detail="base64 output does not support streaming") if tts_pipeline is None: raise HTTPException(status_code=503, detail="TTS pipeline is not ready") try: with infer_lock: sr, audio_data = next(tts_pipeline.run(tts_req)) audio_bytes = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() return audio_bytes, media_type except Exception as exc: raise HTTPException(status_code=400, detail={"message": "tts failed", "exception": str(exc)}) from exc def synthesize_response(payload: Dict[str, Any]) -> Response: tts_req, media_type, response_stream = build_tts_request(payload) if tts_pipeline is None: raise HTTPException(status_code=503, detail="TTS pipeline is not ready") 
if response_stream: def streaming_generator() -> Generator[bytes, None, None]: first_chunk = True chunk_media_type = media_type with infer_lock: for sr, chunk in tts_pipeline.run(tts_req): if first_chunk and chunk_media_type == "wav": yield wave_header_chunk(sample_rate=sr) chunk_media_type = "raw" first_chunk = False yield pack_audio(BytesIO(), chunk, sr, chunk_media_type).getvalue() return StreamingResponse(streaming_generator(), media_type=f"audio/{media_type}") try: with infer_lock: sr, audio_data = next(tts_pipeline.run(tts_req)) audio_bytes = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() return Response(audio_bytes, media_type=f"audio/{media_type}") except Exception as exc: return JSONResponse(status_code=400, content={"message": "tts failed", "exception": str(exc)}) @APP.on_event("startup") def startup() -> None: global tts_config, tts_pipeline tts_config = TTS_Config(tts_config_path) print(tts_config) tts_pipeline = TTS(tts_config) @APP.get("/") def index() -> Dict[str, Any]: return { "name": "GPT-SoVITS Simple API", "endpoints": [ "/health", "/voices", "/api/tts", "/speak", "/speak/base64", "/admin/weights", "/admin/reload-config", ], } @APP.get("/health") def health() -> Dict[str, Any]: return { "status": "ok" if tts_pipeline is not None else "starting", "tts_config": tts_config_path, "version": getattr(tts_config, "version", None), "languages": getattr(tts_config, "languages", []), } @APP.get("/voices") def list_voices() -> Dict[str, Any]: voices = simple_config.get("voices", {}) or {} return { "default_voice": get_default_voice_name(), "voices": [voice_public_info(name, profile or {}) for name, profile in voices.items()], } @APP.post("/api/tts") async def mvp_tts( text: str = Form(...), ref_audio: UploadFile = File(...), aux_ref_audio: Optional[List[UploadFile]] = File(default=None), prompt_text: str = Form(default=""), text_lang: str = Form(default="zh"), prompt_lang: str = Form(default="zh"), format: str = Form(default="wav"), emotion: 
Optional[str] = Form(default=None), speed: Optional[float] = Form(default=None), seed: int = Form(default=-1), ) -> Response: request_dir = get_request_upload_dir() try: ref_audio_path = await save_upload_file(ref_audio, request_dir, "ref") validate_main_ref_duration(ref_audio_path) aux_ref_audio_paths = [] for index, upload in enumerate(aux_ref_audio or []): aux_path = await save_upload_file(upload, request_dir, f"aux_{index}") get_audio_duration_seconds(aux_path) aux_ref_audio_paths.append(aux_path) payload: Dict[str, Any] = { "text": text, "text_lang": text_lang, "ref_audio_path": ref_audio_path, "aux_ref_audio_paths": aux_ref_audio_paths, "prompt_text": prompt_text or "", "prompt_lang": prompt_lang, "format": format, "text_split_method": "cut5", "streaming_mode": 0, "seed": seed, } apply_emotion_preset(payload, emotion) payload["text_split_method"] = "cut5" payload["streaming_mode"] = 0 if speed is not None: payload["speed"] = speed return synthesize_response(payload) finally: shutil.rmtree(request_dir, ignore_errors=True) @APP.get("/speak") def speak_get( text: str, voice: Optional[str] = None, text_lang: Optional[str] = None, format: Optional[str] = None, stream: Optional[bool] = None, speed: Optional[float] = None, ) -> Response: payload = { "text": text, "voice": voice, "text_lang": text_lang, "format": format, "stream": stream, "speed": speed, } return synthesize_response({k: v for k, v in payload.items() if v is not None}) @APP.post("/speak") def speak_post(request: SpeakRequest) -> Response: return synthesize_response(request_to_dict(request)) @APP.post("/v1/tts") def openai_style_tts(request: SpeakRequest) -> Response: return synthesize_response(request_to_dict(request)) @APP.post("/speak/base64") def speak_base64(request: SpeakRequest) -> Dict[str, Any]: audio_bytes, media_type = synthesize_once(request_to_dict(request)) return { "media_type": f"audio/{media_type}", "audio_base64": base64.b64encode(audio_bytes).decode("ascii"), } 
# Re-reads the simple API config file given on the command line and swaps it
# in as the active configuration. The new config is loaded first, so a failed
# load leaves the previous configuration in place.
@APP.post("/admin/reload-config")
def reload_config() -> Dict[str, Any]:
    global simple_config
    reloaded = load_yaml_config(args.config)
    simple_config = reloaded
    return {
        "message": "success",
        "default_voice": get_default_voice_name(),
    }


# Replaces the GPT and/or SoVITS weights of the running pipeline. The swap
# happens under the inference lock so it cannot overlap a synthesis call.
@APP.post("/admin/weights")
def set_weights(request: WeightsRequest) -> Dict[str, Any]:
    if tts_pipeline is None:
        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")

    gpt_path = request.gpt_weights_path
    sovits_path = request.sovits_weights_path
    if not (gpt_path or sovits_path):
        raise HTTPException(status_code=400, detail="gpt_weights_path or sovits_weights_path is required")

    # GPT weights are applied before SoVITS weights; each loader is looked up
    # only when its path was supplied.
    pending = (
        ("init_t2s_weights", gpt_path),
        ("init_vits_weights", sovits_path),
    )
    try:
        with infer_lock:
            for loader_name, weights_path in pending:
                if weights_path:
                    getattr(tts_pipeline, loader_name)(weights_path)
    except Exception as exc:
        failure = {"message": "change weights failed", "exception": str(exc)}
        raise HTTPException(status_code=400, detail=failure) from exc

    return {"message": "success"}


if __name__ == "__main__":
    import uvicorn

    try:
        # The literal string "None" is passed to uvicorn as host=None.
        uvicorn.run(
            app=APP,
            host=host if host != "None" else None,
            port=port,
            workers=1,
        )
    except Exception:
        traceback.print_exc()
        sys.exit(1)