GPT-SoVITS-WebUI
A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.
@@ -27,6 +27,24 @@ A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.
---
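+## Simplified API / Test Frontend
+
+This workspace adds a middle-layer API for MVP calls, together with a test frontend:
+
+- Quick start: [docs/simple_api_quickstart.md](./docs/simple_api_quickstart.md)
+- Full tutorial: [docs/simple_api.md](./docs/simple_api.md)
+- Backend entry point: `simple_api.py`
+- Test frontend: after starting the backend, open http://127.0.0.1:9881/test/
+
+Start command: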
+
+```powershell
+cd D:\tts\GPT-SoVITS
+.\go-simple-api.ps1
+```
+
+---
+
## Features:
1. **Zero-shot TTS:** Input a 5-second vocal sample and experience instant text-to-speech conversion.
@@ -478,3 +496,4 @@ Thankful to @Naozumi520 for providing the Cantonese training set and for the gui
+
diff --git a/docs/simple_api.md b/docs/simple_api.md
new file mode 100644
index 00000000..bcdedaf6
--- /dev/null
+++ b/docs/simple_api.md
@@ -0,0 +1,320 @@
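+# GPT-SoVITS Simplified API Tutorial
+
+The project already ships `api_v2.py`, but calling it requires passing many GPT-SoVITS parameters. The new `simple_api.py` is a middle layer that gives a frontend or business caller a simpler interface.
+
+The current MVP recommends:
+
+```http
+POST /api/tts
+Content-Type: multipart/form-data
+```
+
+It fits the following workflow:
+
+1. The frontend extracts audio from a video (or the user uploads audio already trimmed to 3-10 seconds).
+2. The user manually trims the main reference audio to 3-10 seconds.
+3. The frontend submits the main reference audio, the text to synthesize, and optional auxiliary audio to the backend.
+4. The backend calls GPT-SoVITS to generate the audio and returns it.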
+
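+## 1. Install dependencies
+
+Run this in the Python environment used by GPT-SoVITS:
+
+```bash
+python -m pip install -r requirements.txt
+```
+
+The middle layer additionally needs:
+
+- `python-multipart`: receives the audio files uploaded by the frontend.
+- `soundfile`: reads audio metadata and validates the reference audio duration.
+
+Both dependencies are already listed in `requirements.txt`.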
+
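+## 2. Check the configuration
+
+Configuration file:
+
+```text
+simple_api.yaml
+```
+
+Common settings:
+
+```yaml
+server:
+  host: 127.0.0.1
+  port: 9881
+  tts_config: GPT_SoVITS/configs/tts_infer.yaml
+
+upload:
+  dir: runtime/uploads
+  min_ref_seconds: 3
+  max_ref_seconds: 10
+  max_upload_mb: 80
+```
+
+Notes:
+
+- `server.host`: address the backend listens on.
+- `server.port`: backend port.
+- `server.tts_config`: GPT-SoVITS inference configuration file.
+- `upload.dir`: temporary upload directory.
+- `upload.min_ref_seconds`: minimum length of the main reference audio, in seconds.
+- `upload.max_ref_seconds`: maximum length of the main reference audio, in seconds.
+- `upload.max_upload_mb`: maximum size of a single uploaded audio file, in MB.
+
+Current default backend address:
+
+```text
+http://127.0.0.1:9881
+```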
+
+## 3. Start the backend
+
+Change to the project directory:
+
+```powershell
+cd D:\tts\GPT-SoVITS
+```
+
+Option 1:
+
+```powershell
+python simple_api.py -c simple_api.yaml
+```
+
+Option 2, Windows PowerShell:
+
+```powershell
+.\go-simple-api.ps1
+```
+
+Option 3, Windows batch file:
+
+```bat
+go-simple-api.bat
+```
+
+Once started, the backend listens on:
+
+```text
+http://127.0.0.1:9881
+```
+
+## 4. Health check
+
+After the backend has started, open:
+
+```text
+http://127.0.0.1:9881/health
+```
+
+Or test from the command line:
+
+```bash
+curl http://127.0.0.1:9881/health
+```
+
+Example response:
+
+```json
+{
+  "status": "ok",
+  "tts_config": "GPT_SoVITS/configs/tts_infer.yaml",
+  "version": "v2",
+  "languages": ["auto", "auto_yue", "en", "zh"]
+}
+```
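
A client may want to wait for `/health` to report `ok` before sending synthesis requests. The following is a minimal sketch using only the Python standard library; the helper names `is_ready` and `wait_until_ready` are illustrative and not part of `simple_api.py`, and it assumes the response shape shown above.

```python
import json
import time
import urllib.error
import urllib.request


def is_ready(payload: dict) -> bool:
    """Return True when a /health payload reports a loaded pipeline."""
    return payload.get("status") == "ok"


def wait_until_ready(
    base_url: str = "http://127.0.0.1:9881",
    timeout: float = 120.0,
    interval: float = 2.0,
) -> bool:
    """Poll GET /health until it reports "ok" or the timeout expires."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(f"{base_url}/health", timeout=5) as resp:
                if is_ready(json.load(resp)):
                    return True
        except (urllib.error.URLError, OSError, ValueError):
            pass  # server not reachable yet, or the body was not valid JSON
        time.sleep(interval)
    return False
```

Calling `wait_until_ready()` right after launching the backend avoids sending requests while the model is still loading.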
+
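+## 5. Open the test frontend
+
+After the backend has started, the recommended way is to open:
+
+```text
+http://127.0.0.1:9881/test/
+```
+
+This page is mounted on the backend service, so no separate frontend server is needed.
+
+You can also open the local file directly:
+
+```text
+test_frontend/index.html
+```
+
+On Windows you can also run:
+
+```powershell
+.\open-test-frontend.ps1
+```
+
+The test frontend has a "后端接口地址" (backend API URL) input box.
+
+Default value:
+
+```text
+http://127.0.0.1:9881/api/tts
+```
+
+If you change `server.host` or `server.port`, update the API URL on the page to match.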
+
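+## 6. MVP endpoint
+
+Endpoint:
+
+```http
+POST /api/tts
+```
+
+Request type:
+
+```http
+multipart/form-data
+```
+
+Fields:
+
+```text
+text           Required. The text to synthesize.
+ref_audio      Required. Main reference audio, 3-10 seconds long (video uploads are supported; the frontend extracts the audio automatically).
+aux_ref_audio  Optional. Auxiliary reference audio; several files may be uploaded.
+prompt_text    Optional. Transcript of the main reference audio; may be left empty.
+text_lang      Optional. Language of the text to synthesize; defaults to zh.
+prompt_lang    Optional. Language of the reference audio; defaults to zh. It is passed to GPT-SoVITS even when prompt_text is empty.
+format         Optional. Output format: wav, raw, ogg, or aac; defaults to wav.
+emotion        Optional. Emotion preset: neutral, happy, calm, sad, angry.
+speed          Optional. Speaking rate; maps to GPT-SoVITS speed_factor.
+seed           Optional. Random seed; defaults to -1.
+```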
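
The interaction between `emotion` and `speed` can be sketched in Python as below. The preset values are copied from the `emotion_presets` section of `simple_api.yaml` in this change, and the baseline values from `DEFAULT_TTS_PARAMS`; the function name `resolve_sampling` is hypothetical and only illustrates the order in which the values are applied (preset first, then an explicit `speed`).

```python
from typing import Any, Dict, Optional

# Values copied from the emotion_presets section of simple_api.yaml.
EMOTION_PRESETS: Dict[str, Dict[str, Any]] = {
    "neutral": {},
    "happy": {"temperature": 1.1, "top_p": 0.95},
    "calm": {"temperature": 0.8, "top_p": 0.85, "speed_factor": 0.92},
    "sad": {"temperature": 0.75, "top_p": 0.85, "speed_factor": 0.9},
    "angry": {"temperature": 1.2, "top_k": 20, "repetition_penalty": 1.25},
}


def resolve_sampling(emotion: Optional[str] = None, speed: Optional[float] = None) -> Dict[str, Any]:
    """Start from the built-in sampling values, overlay the preset, then let `speed` win."""
    params: Dict[str, Any] = {"top_k": 15, "top_p": 1.0, "temperature": 1.0, "speed_factor": 1.0}
    if emotion:
        name = emotion.strip().lower()
        if name not in EMOTION_PRESETS:
            raise ValueError(f"emotion is not configured: {name}")
        params.update(EMOTION_PRESETS[name])
    if speed is not None:
        # An explicit speed overrides any speed_factor set by the preset.
        params["speed_factor"] = speed
    return params
```

For example, `resolve_sampling("calm", speed=1.2)` keeps the calmer `temperature` and `top_p` but uses the requested speed.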
+
+## 7. curl examples
+
+PowerShell example:
+
+```powershell
+curl.exe -X POST http://127.0.0.1:9881/api/tts `
+  -F "text=你好,欢迎使用这个声音。" `
+  -F "ref_audio=@D:\audio\ref.wav" `
+  -F "prompt_text=" `
+  -F "text_lang=zh" `
+  -F "prompt_lang=zh" `
+  -F "emotion=neutral" `
+  --output output.wav
+```
+
+With auxiliary reference audio:
+
+```powershell
+curl.exe -X POST http://127.0.0.1:9881/api/tts `
+  -F "text=你好,欢迎使用这个声音。" `
+  -F "ref_audio=@D:\audio\ref.wav" `
+  -F "aux_ref_audio=@D:\audio\aux1.wav" `
+  -F "aux_ref_audio=@D:\audio\aux2.wav" `
+  -F "prompt_text=" `
+  -F "text_lang=zh" `
+  -F "prompt_lang=zh" `
+  --output output.wav
+```
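
The same request can be sent from Python. The sketch below builds the `multipart/form-data` body by hand so that it needs only the standard library; `build_multipart` and `synthesize` are illustrative helper names, and the field names follow the table in section 6.

```python
import mimetypes
import urllib.request
import uuid
from pathlib import Path
from typing import Dict, List, Tuple


def build_multipart(fields: Dict[str, str], files: List[Tuple[str, str, bytes]]) -> Tuple[bytes, str]:
    """Encode form fields and (field, filename, data) files as multipart/form-data."""
    boundary = uuid.uuid4().hex
    parts: List[bytes] = []
    for name, value in fields.items():
        parts.append(
            f'--{boundary}\r\nContent-Disposition: form-data; name="{name}"\r\n\r\n{value}\r\n'.encode("utf-8")
        )
    for name, filename, data in files:
        content_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
        header = (
            f'--{boundary}\r\nContent-Disposition: form-data; name="{name}"; filename="{filename}"\r\n'
            f"Content-Type: {content_type}\r\n\r\n"
        )
        parts.append(header.encode("utf-8") + data + b"\r\n")
    parts.append(f"--{boundary}--\r\n".encode("utf-8"))
    return b"".join(parts), f"multipart/form-data; boundary={boundary}"


def synthesize(text: str, ref_path: str, out_path: str, url: str = "http://127.0.0.1:9881/api/tts") -> None:
    """POST text plus a reference clip to /api/tts and save the returned audio."""
    ref = Path(ref_path)
    body, content_type = build_multipart(
        {"text": text, "text_lang": "zh", "prompt_lang": "zh"},
        [("ref_audio", ref.name, ref.read_bytes())],
    )
    request = urllib.request.Request(url, data=body, headers={"Content-Type": content_type}, method="POST")
    with urllib.request.urlopen(request, timeout=300) as resp:
        Path(out_path).write_bytes(resp.read())
```

To send auxiliary clips, append further `("aux_ref_audio", filename, data)` tuples to the file list.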
+
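+## 8. Important rules
+
+- The main reference audio must be 3-10 seconds long.
+- `aux_ref_audio` is optional.
+- `prompt_text` may be empty; this is currently aimed at GPT-SoVITS v2.
+- With GPT-SoVITS v3/v4, the middle layer rejects an empty `prompt_text` with HTTP 400.
+- The text is always split with `cut5`, that is, at punctuation marks.
+- `emotion` is currently a lightweight preset that maps to sampling and speed parameters; more stable emotion control still depends on a reference audio that carries the emotion.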
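
A caller can check the 3-10 second rule locally before uploading. This is a minimal sketch for PCM WAV files using the standard `wave` module; the function names are hypothetical, and the backend performs its own check with `soundfile` and `ffprobe` regardless.

```python
import wave


def wav_duration_seconds(path: str) -> float:
    """Duration of a PCM WAV file, computed from its frame count and sample rate."""
    with wave.open(path, "rb") as wav_file:
        return wav_file.getnframes() / float(wav_file.getframerate())


def check_ref_duration(path: str, min_seconds: float = 3.0, max_seconds: float = 10.0) -> float:
    """Return the duration, or raise ValueError if it is outside the allowed window."""
    duration = wav_duration_seconds(path)
    if not min_seconds <= duration <= max_seconds:
        raise ValueError(f"ref_audio must be {min_seconds:g}-{max_seconds:g}s, got {duration:.2f}s")
    return duration
```

The default limits mirror `upload.min_ref_seconds` and `upload.max_ref_seconds`; pass different values if the configuration was changed.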
+
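+## 9. Using the test frontend
+
+1. Start the backend.
+2. Open `http://127.0.0.1:9881/test/`.
+3. Check that "后端接口地址" (backend API URL) is set to:
+
+```text
+http://127.0.0.1:9881/api/tts
+```
+
+4. Fill in "需要生成的文字" (text to synthesize).
+5. Upload the main reference audio.
+6. Optionally upload auxiliary reference audio.
+7. Optionally enter the transcript of the reference audio.
+8. Choose the language, emotion, and speed.
+9. Click "生成音频" (generate audio).
+10. The result appears on the right side of the page, where it can be played and downloaded.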
+
+## 10. Other endpoints
+
+Besides the MVP upload endpoint, the profile-based endpoints are kept:
+
+```http
+GET /voices
+GET /speak?text=hello&voice=default
+POST /speak
+POST /speak/base64
+POST /v1/tts
+POST /admin/reload-config
+POST /admin/weights
+```
+
+These endpoints are intended for fixed-voice profiles later on; they are not part of the current MVP flow.
+
+## 11. Base64 response example
+
+```bash
+curl -X POST http://127.0.0.1:9881/speak/base64 \
+  -H "Content-Type: application/json" \
+  -d '{"text":"hello","voice":"default"}'
+```
+
+Response format:
+
+```json
+{
+  "media_type": "audio/wav",
+  "audio_base64": "..."
+}
+```
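
On the client side the `audio_base64` field has to be decoded before it can be played. A minimal sketch, assuming the response shape above; `save_base64_audio` and the extension table are illustrative choices, not part of the API.

```python
import base64
from pathlib import Path

# Illustrative mapping from the returned media_type to a file extension.
EXTENSIONS = {"audio/wav": ".wav", "audio/ogg": ".ogg", "audio/aac": ".aac", "audio/raw": ".pcm"}


def save_base64_audio(payload: dict, stem: str) -> Path:
    """Decode audio_base64 and write it to `stem` with an extension chosen from media_type."""
    suffix = EXTENSIONS.get(payload["media_type"], ".bin")
    target = Path(stem).with_suffix(suffix)
    target.write_bytes(base64.b64decode(payload["audio_base64"]))
    return target
```

Pass the parsed JSON body of the `/speak/base64` response as `payload`; the function returns the path it wrote.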
+
+## 12. Adding a fixed-voice profile
+
+To add fixed voices later, edit `simple_api.yaml`:
+
+```yaml
+voices:
+  default:
+    ref_audio_path: reference.wav
+    prompt_text: exact transcript
+    prompt_lang: zh
+    text_lang: zh
+
+  narrator:
+    ref_audio_path: voices/narrator.wav
+    prompt_text: exact transcript of narrator.wav
+    prompt_lang: zh
+    text_lang: zh
+```
+
+After editing, call:
+
+```bash
+curl -X POST http://127.0.0.1:9881/admin/reload-config
+```
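
How a profile combines with the other settings can be sketched as follows. The order mirrors `build_tts_request` in `simple_api.py` as shown in this change (built-in defaults, then the `defaults` section, then the voice profile, then request fields); the function name `merge_params` and the reduced parameter set are illustrative only.

```python
from typing import Any, Dict


def merge_params(
    builtin: Dict[str, Any],
    defaults: Dict[str, Any],
    profile: Dict[str, Any],
    request: Dict[str, Any],
) -> Dict[str, Any]:
    """Later sources override earlier ones; unknown request keys are ignored."""
    merged = dict(builtin)
    merged.update(defaults)
    # "description" is profile metadata and is not passed to the TTS pipeline.
    merged.update({k: v for k, v in profile.items() if k != "description"})
    merged.update({k: v for k, v in request.items() if k in builtin})
    return merged


params = merge_params(
    builtin={"text": "", "text_lang": "zh", "ref_audio_path": "", "speed_factor": 1.0},
    defaults={"speed_factor": 1.0},
    profile={"description": "demo", "ref_audio_path": "voices/narrator.wav", "text_lang": "zh"},
    request={"text": "hello", "text_lang": "en", "unknown": 1},
)
```

Here the request's `text_lang` overrides the profile's, while the reference audio still comes from the profile.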
+
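+## 13. Contract tests
+
+This test uses mocks and does not load the GPT-SoVITS model:
+
+```bash
+python -m unittest tests.test_simple_api_contract
+```
+
+It verifies that:
+
+- The `/api/tts` route exists.
+- The upload endpoint builds the correct parameters.
+- The 3-10 second check on the main reference audio works.
+- An empty `prompt_text` is accepted on v2.
+- An empty `prompt_text` on v3/v4 returns 400.
+- The temporary upload directory is cleaned up.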
diff --git a/docs/simple_api_quickstart.md b/docs/simple_api_quickstart.md
new file mode 100644
index 00000000..2f07eb65
--- /dev/null
+++ b/docs/simple_api_quickstart.md
@@ -0,0 +1,89 @@
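+# Simplified TTS API Quick Start
+
+Full tutorial: `docs/simple_api.md`
+
+## The flow in one sentence
+
+Start the backend, open the test page, upload a 3-10 second reference audio or video (audio is extracted from video automatically), enter the text to synthesize, and click generate.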
+
+## 1. Start the backend
+
+```powershell
+cd D:\tts\GPT-SoVITS
+python -m pip install -r requirements.txt
+.\go-simple-api.ps1
+```
+
+You can also run:
+
+```bat
+go-simple-api.bat
+```
+
+Default backend address:
+
+```text
+http://127.0.0.1:9881
+```
+
+## 2. Open the test frontend
+
+After the backend has started, open:
+
+```text
+http://127.0.0.1:9881/test/
+```
+
+The default API URL on the test page is:
+
+```text
+http://127.0.0.1:9881/api/tts
+```
+
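+## 3. Call the API
+
+Endpoint:
+
+```http
+POST /api/tts
+Content-Type: multipart/form-data
+```
+
+Minimum fields:
+
+```text
+text           Text to synthesize
+ref_audio      Main reference audio, 3-10 seconds (video is supported; the frontend extracts the audio automatically)
+```
+
+Common optional fields:
+
+```text
+aux_ref_audio  Auxiliary reference audio; may be repeated
+prompt_text    Transcript of the reference audio; may be left empty
+text_lang      Defaults to zh
+prompt_lang    Defaults to zh
+emotion        neutral / happy / calm / sad / angry
+speed          Speaking rate; defaults to 1
+seed           Defaults to -1
+```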
+
+## 4. PowerShell example
+
+```powershell
+curl.exe -X POST http://127.0.0.1:9881/api/tts `
+  -F "text=你好,欢迎使用这个声音。" `
+  -F "ref_audio=@D:\audio\ref.wav" `
+  -F "prompt_text=" `
+  -F "text_lang=zh" `
+  -F "prompt_lang=zh" `
+  --output output.wav
+```
+
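+## 5. Notes
+
+- The main reference audio must be 3-10 seconds long.
+- `prompt_text` may be empty with the current v2 configuration.
+- After switching to v3/v4, an empty `prompt_text` returns 400.
+- The text is always split at punctuation marks.
+- For detailed configuration, profiles, and the base64 endpoint, see `docs/simple_api.md`.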
diff --git a/go-simple-api.bat b/go-simple-api.bat
new file mode 100644
index 00000000..2c2c7d62
--- /dev/null
+++ b/go-simple-api.bat
@@ -0,0 +1,12 @@
+@echo off
+set "SCRIPT_DIR=%~dp0"
+set "SCRIPT_DIR=%SCRIPT_DIR:~0,-1%"
+cd /d "%SCRIPT_DIR%"
+
+if exist "%SCRIPT_DIR%\runtime\python.exe" (
+    "%SCRIPT_DIR%\runtime\python.exe" "%SCRIPT_DIR%\simple_api.py" -c "%SCRIPT_DIR%\simple_api.yaml" %*
+) else (
+    python "%SCRIPT_DIR%\simple_api.py" -c "%SCRIPT_DIR%\simple_api.yaml" %*
+)
+
+pause
diff --git a/go-simple-api.ps1 b/go-simple-api.ps1
new file mode 100644
index 00000000..4d1336c5
--- /dev/null
+++ b/go-simple-api.ps1
@@ -0,0 +1,13 @@
+$ErrorActionPreference = "Stop"
+chcp 65001
+Set-Location $PSScriptRoot
+
+$runtimePython = Join-Path $PSScriptRoot "runtime\python.exe"
+$apiScript = Join-Path $PSScriptRoot "simple_api.py"
+$apiConfig = Join-Path $PSScriptRoot "simple_api.yaml"
+
+if (Test-Path $runtimePython) {
+    & $runtimePython $apiScript -c $apiConfig @args
+} else {
+    python $apiScript -c $apiConfig @args
+}
diff --git a/open-test-frontend.ps1 b/open-test-frontend.ps1
new file mode 100644
index 00000000..7e26d5a6
--- /dev/null
+++ b/open-test-frontend.ps1
@@ -0,0 +1,3 @@
+$ErrorActionPreference = "Stop"
+Set-Location $PSScriptRoot
+Start-Process (Join-Path $PSScriptRoot "test_frontend\index.html")
diff --git a/requirements.txt b/requirements.txt
index 3b7cd898..caa9a448 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ numpy<2.0
scipy
tensorboard
librosa==0.10.2
+soundfile
numba
pytorch-lightning>=2.4
gradio<5
@@ -22,6 +23,7 @@ transformers>=4.43,<=4.50
peft<0.18.0
chardet
PyYAML
+python-multipart
psutil
jieba_fast
jieba
diff --git a/simple_api.py b/simple_api.py
new file mode 100644
index 00000000..91180548
--- /dev/null
+++ b/simple_api.py
@@ -0,0 +1,741 @@
+"""
+Small profile-based API layer for GPT-SoVITS.
+
+Run:
+    python simple_api.py -c simple_api.yaml
+
+Then call:
+    POST http://127.0.0.1:9881/speak
+    {"text": "hello", "voice": "default"}
+"""
+
+from __future__ import annotations
+
+import argparse
+import base64
+import os
+import shutil
+import subprocess
+import sys
+import threading
+import traceback
+import uuid
+import wave
+from copy import deepcopy
+from io import BytesIO
+from pathlib import Path
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+
+import numpy as np
+import soundfile as sf
+import yaml
+from fastapi import FastAPI, File, Form, HTTPException, Response, UploadFile
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.staticfiles import StaticFiles
+from pydantic import BaseModel, Field
+
+PROJECT_ROOT = Path(__file__).resolve().parent
+os.chdir(PROJECT_ROOT)
+sys.path.append(str(PROJECT_ROOT))
+sys.path.append(str(PROJECT_ROOT / "GPT_SoVITS"))
+
+from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config  # noqa: E402
+from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import (  # noqa: E402
+    get_method_names as get_cut_method_names,
+)
+
+
+DEFAULT_TTS_PARAMS: Dict[str, Any] = {
+    "text": "",
+    "text_lang": "zh",
+    "ref_audio_path": "",
+    "aux_ref_audio_paths": [],
+    "prompt_text": "",
+    "prompt_lang": "zh",
+    "top_k": 15,
+    "top_p": 1.0,
+    "temperature": 1.0,
+    "text_split_method": "cut5",
+    "batch_size": 1,
+    "batch_threshold": 0.75,
+    "split_bucket": True,
+    "speed_factor": 1.0,
+    "fragment_interval": 0.3,
+    "seed": -1,
+    "media_type": "wav",
+    "streaming_mode": False,
+    "return_fragment": False,
+    "fixed_length_chunk": False,
+    "parallel_infer": True,
+    "repetition_penalty": 1.35,
+    "sample_steps": 32,
+    "super_sampling": False,
+    "overlap_length": 2,
+    "min_chunk_length": 16,
+}
+
+SUPPORTED_MEDIA_TYPES = {"wav", "raw", "ogg", "aac"}
+SUPPORTED_UPLOAD_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a", ".aac"}
+VOICE_META_KEYS = {"description"}
+
+
+class SpeakRequest(BaseModel):
+    text: str
+    voice: Optional[str] = None
+    text_lang: Optional[str] = None
+    ref_audio_path: Optional[str] = None
+    aux_ref_audio_paths: Optional[List[str]] = None
+    prompt_text: Optional[str] = None
+    prompt_lang: Optional[str] = None
+    format: Optional[str] = Field(default=None, description="wav, raw, ogg, or aac")
+    stream: Optional[bool] = Field(default=None, description="true maps to streaming mode 2")
+    streaming_mode: Optional[Union[bool, int]] = Field(default=None, description="0, 1, 2, 3, true, or false")
+    speed: Optional[float] = Field(default=None, description="Alias of speed_factor")
+    speed_factor: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    temperature: Optional[float] = None
+    text_split_method: Optional[str] = None
+    batch_size: Optional[int] = None
+    batch_threshold: Optional[float] = None
+    split_bucket: Optional[bool] = None
+    fragment_interval: Optional[float] = None
+    seed: Optional[int] = None
+    parallel_infer: Optional[bool] = None
+    repetition_penalty: Optional[float] = None
+    sample_steps: Optional[int] = None
+    super_sampling: Optional[bool] = None
+    overlap_length: Optional[int] = None
+    min_chunk_length: Optional[int] = None
+
+
+class WeightsRequest(BaseModel):
+    gpt_weights_path: Optional[str] = None
+    sovits_weights_path: Optional[str] = None
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="GPT-SoVITS simple API")
+    parser.add_argument("-c", "--config", default="simple_api.yaml", help="simple API config path")
+    parser.add_argument("--tts-config", default=None, help="GPT-SoVITS tts_infer.yaml path")
+    parser.add_argument("-a", "--bind-addr", default=None, help="bind address")
+    parser.add_argument("-p", "--port", type=int, default=None, help="bind port")
+    return parser.parse_args()
+
+
+def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
+    path = Path(config_path)
+    if not path.is_absolute():
+        path = PROJECT_ROOT / path
+    if not path.exists():
+        raise FileNotFoundError(f"simple API config not found: {path}")
+    with path.open("r", encoding="utf-8") as f:
+        data = yaml.safe_load(f) or {}
+    if not isinstance(data, dict):
+        raise ValueError("simple API config must be a YAML object")
+    return data
+
+
+args = parse_args()
+simple_config = load_yaml_config(args.config)
+server_config = simple_config.get("server", {}) or {}
+host = args.bind_addr or server_config.get("host", "127.0.0.1")
+port = args.port or int(server_config.get("port", 9881))
+tts_config_path = args.tts_config or server_config.get("tts_config", "GPT_SoVITS/configs/tts_infer.yaml")
+
+cut_method_names = get_cut_method_names()
+tts_config: Optional[TTS_Config] = None
+tts_pipeline: Optional[TTS] = None
+infer_lock = threading.Lock()
+
+APP = FastAPI(
+    title="GPT-SoVITS Simple API",
+    description="Profile-based API layer that hides GPT-SoVITS request details.",
+    version="1.0.0",
+)
+APP.add_middleware(
+    CORSMiddleware,
+    allow_origins=simple_config.get("cors_allow_origins", ["*"]),
+    allow_credentials=False,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+test_frontend_dir = PROJECT_ROOT / "test_frontend"
+if test_frontend_dir.exists():
+    APP.mount("/test", StaticFiles(directory=str(test_frontend_dir), html=True), name="test_frontend")
+
+
+def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
+    def handle_pack_ogg() -> None:
+        with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
+            audio_file.write(data)
+
+    try:
+        threading.stack_size(4096 * 4096)
+        pack_thread = threading.Thread(target=handle_pack_ogg)
+        pack_thread.start()
+        pack_thread.join()
+    except (RuntimeError, ValueError):
+        handle_pack_ogg()
+    return io_buffer
+
+
+def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
+    del rate
+    io_buffer.write(data.tobytes())
+    return io_buffer
+
+
+def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
+    del io_buffer
+    wav_buffer = BytesIO()
+    sf.write(wav_buffer, data, rate, format="wav")
+    return wav_buffer
+
+
+def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
+    process = subprocess.Popen(
+        [
+            "ffmpeg",
+            "-f",
+            "s16le",
+            "-ar",
+            str(rate),
+            "-ac",
+            "1",
+            "-i",
+            "pipe:0",
+            "-c:a",
+            "aac",
+            "-b:a",
+            "192k",
+            "-vn",
+            "-f",
+            "adts",
+            "pipe:1",
+        ],
+        stdin=subprocess.PIPE,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+    )
+    out, _ = process.communicate(input=data.tobytes())
+    io_buffer.write(out)
+    return io_buffer
+
+
+def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str) -> BytesIO:
+    if media_type == "ogg":
+        io_buffer = pack_ogg(io_buffer, data, rate)
+    elif media_type == "aac":
+        io_buffer = pack_aac(io_buffer, data, rate)
+    elif media_type == "wav":
+        io_buffer = pack_wav(io_buffer, data, rate)
+    else:
+        io_buffer = pack_raw(io_buffer, data, rate)
+    io_buffer.seek(0)
+    return io_buffer
+
+
+def wave_header_chunk(frame_input: bytes = b"", channels: int = 1, sample_width: int = 2, sample_rate: int = 32000) -> bytes:
+    wav_buf = BytesIO()
+    with wave.open(wav_buf, "wb") as vfout:
+        vfout.setnchannels(channels)
+        vfout.setsampwidth(sample_width)
+        vfout.setframerate(sample_rate)
+        vfout.writeframes(frame_input)
+    wav_buf.seek(0)
+    return wav_buf.read()
+
+
+def request_to_dict(request: BaseModel) -> Dict[str, Any]:
+    if hasattr(request, "model_dump"):
+        return request.model_dump(exclude_none=True)
+    return request.dict(exclude_none=True)
+
+
+def get_default_voice_name() -> Optional[str]:
+    default_voice = simple_config.get("default_voice")
+    if default_voice:
+        return str(default_voice)
+    voices = simple_config.get("voices", {}) or {}
+    if "default" in voices:
+        return "default"
+    return next(iter(voices.keys()), None)
+
+
+def resolve_project_path(path_value: Optional[str]) -> Optional[str]:
+    if path_value in [None, ""]:
+        return None
+    path = Path(str(path_value))
+    if not path.is_absolute():
+        path = PROJECT_ROOT / path
+    return str(path)
+
+
+def normalize_streaming(streaming_mode: Optional[Union[bool, int]], stream: Optional[bool]) -> Tuple[bool, bool, bool, bool]:
+    if streaming_mode is None:
+        streaming_mode = 2 if stream else 0
+    elif isinstance(streaming_mode, bool):
+        streaming_mode = 2 if streaming_mode else 0
+
+    if streaming_mode == 0:
+        return False, False, False, False
+    if streaming_mode == 1:
+        return False, True, False, True
+    if streaming_mode == 2:
+        return True, False, False, True
+    if streaming_mode == 3:
+        return True, False, True, True
+    raise HTTPException(status_code=400, detail="streaming_mode must be 0, 1, 2, 3, true, or false")
+
+
+def get_upload_config() -> Dict[str, Any]:
+    return simple_config.get("upload", {}) or {}
+
+
+def get_upload_root() -> Path:
+    upload_dir = str(get_upload_config().get("dir", "runtime/uploads"))
+    root = Path(upload_dir)
+    if not root.is_absolute():
+        root = PROJECT_ROOT / root
+    root.mkdir(parents=True, exist_ok=True)
+    return root
+
+
+def get_request_upload_dir() -> Path:
+    request_dir = get_upload_root() / uuid.uuid4().hex
+    request_dir.mkdir(parents=True, exist_ok=True)
+    return request_dir
+
+
+def get_upload_limits() -> Tuple[float, float, int]:
+    upload_config = get_upload_config()
+    min_seconds = float(upload_config.get("min_ref_seconds", 3.0))
+    max_seconds = float(upload_config.get("max_ref_seconds", 10.0))
+    max_upload_mb = int(upload_config.get("max_upload_mb", 80))
+    return min_seconds, max_seconds, max_upload_mb
+
+
+def get_upload_suffix(upload: UploadFile) -> str:
+    suffix = Path(upload.filename or "").suffix.lower()
+    if suffix:
+        return suffix
+    content_type = (upload.content_type or "").lower()
+    if content_type in {"audio/wav", "audio/x-wav", "audio/wave"}:
+        return ".wav"
+    if content_type == "audio/mpeg":
+        return ".mp3"
+    if content_type == "audio/ogg":
+        return ".ogg"
+    if content_type in {"audio/aac", "audio/aacp"}:
+        return ".aac"
+    return ".wav"
+
+
+def validate_audio_upload(upload: UploadFile) -> None:
+    suffix = get_upload_suffix(upload)
+    content_type = (upload.content_type or "").lower()
+    is_audio_type = content_type.startswith("audio/") or content_type in {"application/octet-stream", ""}
+    if suffix not in SUPPORTED_UPLOAD_EXTENSIONS or not is_audio_type:
+        raise HTTPException(status_code=400, detail=f"unsupported audio upload: {upload.filename or content_type}")
+
+
+async def save_upload_file(upload: UploadFile, target_dir: Path, name_prefix: str) -> str:
+    validate_audio_upload(upload)
+    _, _, max_upload_mb = get_upload_limits()
+    max_bytes = max_upload_mb * 1024 * 1024
+    suffix = get_upload_suffix(upload)
+    target_path = target_dir / f"{name_prefix}{suffix}"
+    total_size = 0
+
+    try:
+        with target_path.open("wb") as f:
+            while True:
+                chunk = await upload.read(1024 * 1024)
+                if not chunk:
+                    break
+                total_size += len(chunk)
+                if total_size > max_bytes:
+                    raise HTTPException(status_code=400, detail=f"audio upload exceeds {max_upload_mb} MB")
+                f.write(chunk)
+        if total_size == 0:
+            raise HTTPException(status_code=400, detail=f"audio upload is empty: {upload.filename or name_prefix}")
+    finally:
+        await upload.close()
+
+    return str(target_path)
+
+
+def get_audio_duration_with_ffprobe(audio_path: str) -> float:
+    result = subprocess.run(
+        [
+            "ffprobe",
+            "-v",
+            "error",
+            "-show_entries",
+            "format=duration",
+            "-of",
+            "default=noprint_wrappers=1:nokey=1",
+            audio_path,
+        ],
+        capture_output=True,
+        text=True,
+        timeout=15,
+    )
+    if result.returncode != 0:
+        raise RuntimeError(result.stderr.strip() or "ffprobe failed")
+    return float(result.stdout.strip())
+
+
+def get_audio_duration_seconds(audio_path: str) -> float:
+    try:
+        info = sf.info(audio_path)
+        if info.samplerate <= 0:
+            raise ValueError("invalid sample rate")
+        return float(info.frames) / float(info.samplerate)
+    except Exception as sf_exc:
+        try:
+            return get_audio_duration_with_ffprobe(audio_path)
+        except Exception as ffprobe_exc:
+            raise HTTPException(
+                status_code=400,
+                detail=f"unable to read audio duration: {audio_path}. soundfile={sf_exc}; ffprobe={ffprobe_exc}",
+            ) from ffprobe_exc
+
+
+def validate_main_ref_duration(audio_path: str) -> None:
+    min_seconds, max_seconds, _ = get_upload_limits()
+    duration = get_audio_duration_seconds(audio_path)
+    if duration < min_seconds or duration > max_seconds:
+        raise HTTPException(
+            status_code=400,
+            detail=f"ref_audio duration must be between {min_seconds:g}s and {max_seconds:g}s, got {duration:.2f}s",
+        )
+
+
+def apply_emotion_preset(payload: Dict[str, Any], emotion: Optional[str]) -> None:
+    if emotion in [None, ""]:
+        return
+    emotion_name = str(emotion).strip().lower()
+    emotion_presets = simple_config.get("emotion_presets", {}) or {}
+    if emotion_name not in emotion_presets:
+        supported = ", ".join(sorted(emotion_presets.keys())) or "none"
+        raise HTTPException(status_code=400, detail=f"emotion is not configured: {emotion_name}. supported: {supported}")
+    payload.update(emotion_presets[emotion_name] or {})
+
+
+def prompt_text_required_for_current_model() -> bool:
+    if tts_config is None:
+        return False
+    use_vocoder = getattr(tts_config, "use_vocoder", None)
+    if use_vocoder is not None:
+        return bool(use_vocoder)
+    return getattr(tts_config, "version", None) in {"v3", "v4"}
+
+
+def voice_public_info(name: str, profile: Dict[str, Any]) -> Dict[str, Any]:
+    ref_audio_path = resolve_project_path(profile.get("ref_audio_path"))
+    return {
+        "name": name,
+        "description": profile.get("description", ""),
+        "text_lang": profile.get("text_lang"),
+        "prompt_lang": profile.get("prompt_lang"),
+        "ref_audio_path": profile.get("ref_audio_path"),
+        "ready": bool(ref_audio_path and Path(ref_audio_path).exists() and profile.get("prompt_lang")),
+    }
+
+
+def build_tts_request(payload: Dict[str, Any]) -> Tuple[Dict[str, Any], str, bool]:
+    voices = simple_config.get("voices", {}) or {}
+    explicit_ref_audio = bool(payload.get("ref_audio_path"))
+    voice_name = payload.pop("voice", None)
+    if voice_name is None and not explicit_ref_audio:
+        voice_name = get_default_voice_name()
+    voice_profile: Dict[str, Any] = {}
+
+    if voice_name:
+        if voice_name not in voices:
+            raise HTTPException(status_code=404, detail=f"voice not found: {voice_name}")
+        voice_profile = deepcopy(voices[voice_name] or {})
+
+    tts_req = deepcopy(DEFAULT_TTS_PARAMS)
+    tts_req.update(simple_config.get("defaults", {}) or {})
+    tts_req.update({k: v for k, v in voice_profile.items() if k not in VOICE_META_KEYS})
+
+    stream = payload.pop("stream", None)
+    media_type = payload.pop("format", None)
+    speed = payload.pop("speed", None)
+    if media_type is not None:
+        tts_req["media_type"] = media_type
+
+    for key, value in payload.items():
+        if key in DEFAULT_TTS_PARAMS:
+            tts_req[key] = value
+    if speed is not None:
+        tts_req["speed_factor"] = speed
+
+    ref_audio_path = resolve_project_path(tts_req.get("ref_audio_path"))
+    if ref_audio_path:
+        tts_req["ref_audio_path"] = ref_audio_path
+
+    aux_paths = tts_req.get("aux_ref_audio_paths") or []
+    tts_req["aux_ref_audio_paths"] = [resolve_project_path(item) for item in aux_paths if item]
+
+    media_type = str(tts_req.get("media_type", "wav")).lower()
+    tts_req["media_type"] = media_type
+    if media_type not in SUPPORTED_MEDIA_TYPES:
+        raise HTTPException(status_code=400, detail=f"format is not supported: {media_type}")
+
+    text = str(tts_req.get("text") or "").strip()
+    if not text:
+        raise HTTPException(status_code=400, detail="text is required")
+    tts_req["text"] = text
+
+    if not tts_req.get("ref_audio_path"):
+        raise HTTPException(status_code=400, detail="ref_audio_path is required in voice profile or request")
+    if not Path(tts_req["ref_audio_path"]).exists():
+        raise HTTPException(status_code=400, detail=f"ref_audio_path does not exist: {tts_req['ref_audio_path']}")
+
+    for aux_path in tts_req["aux_ref_audio_paths"]:
+        if aux_path and not Path(aux_path).exists():
+            raise HTTPException(status_code=400, detail=f"aux_ref_audio_path does not exist: {aux_path}")
+
+    if tts_config is None:
+        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
+
+    text_lang = str(tts_req.get("text_lang") or "").lower()
+    prompt_lang = str(tts_req.get("prompt_lang") or "").lower()
+    tts_req["text_lang"] = text_lang
+    tts_req["prompt_lang"] = prompt_lang
+
+    if text_lang not in tts_config.languages:
+        raise HTTPException(status_code=400, detail=f"text_lang is not supported: {text_lang}")
+    if prompt_lang not in tts_config.languages:
+        raise HTTPException(status_code=400, detail=f"prompt_lang is not supported: {prompt_lang}")
+    if not str(tts_req.get("prompt_text") or "").strip() and prompt_text_required_for_current_model():
+        version = getattr(tts_config, "version", "current")
+        raise HTTPException(status_code=400, detail=f"prompt_text is required when using GPT-SoVITS {version}")
+
+    split_method = str(tts_req.get("text_split_method") or "cut5")
+    if split_method not in cut_method_names:
+        raise HTTPException(status_code=400, detail=f"text_split_method is not supported: {split_method}")
+
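+    streaming_mode = tts_req.pop("streaming_mode", None)
+    # DEFAULT_TTS_PARAMS always sets streaming_mode=False, which would mask the
+    # `stream` flag. Let `stream` decide unless the request set streaming_mode itself.
+    if stream is not None and payload.get("streaming_mode") is None:
+        streaming_mode = None
+    streaming_enabled, return_fragment, fixed_length_chunk, response_stream = normalize_streaming(streaming_mode, stream)
+    tts_req["streaming_mode"] = streaming_enabled
+    tts_req["return_fragment"] = return_fragment
+    tts_req["fixed_length_chunk"] = fixed_length_chunk
+
+    return tts_req, media_type, response_stream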
+
+
+def synthesize_once(payload: Dict[str, Any]) -> Tuple[bytes, str]:
+    tts_req, media_type, response_stream = build_tts_request(payload)
+    if response_stream:
+        raise HTTPException(status_code=400, detail="base64 output does not support streaming")
+
+    if tts_pipeline is None:
+        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
+
+    try:
+        with infer_lock:
+            sr, audio_data = next(tts_pipeline.run(tts_req))
+        audio_bytes = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
+        return audio_bytes, media_type
+    except Exception as exc:
+        raise HTTPException(status_code=400, detail={"message": "tts failed", "exception": str(exc)}) from exc
+
+
+def synthesize_response(payload: Dict[str, Any]) -> Response:
+    tts_req, media_type, response_stream = build_tts_request(payload)
+    if tts_pipeline is None:
+        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
+
+    if response_stream:
+
+        def streaming_generator() -> Generator[bytes, None, None]:
+            first_chunk = True
+            chunk_media_type = media_type
+            with infer_lock:
+                for sr, chunk in tts_pipeline.run(tts_req):
+                    if first_chunk and chunk_media_type == "wav":
+                        yield wave_header_chunk(sample_rate=sr)
+                        chunk_media_type = "raw"
+                        first_chunk = False
+                    yield pack_audio(BytesIO(), chunk, sr, chunk_media_type).getvalue()
+
+        return StreamingResponse(streaming_generator(), media_type=f"audio/{media_type}")
+
+    try:
+        with infer_lock:
+            sr, audio_data = next(tts_pipeline.run(tts_req))
+        audio_bytes = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
+        return Response(audio_bytes, media_type=f"audio/{media_type}")
+    except Exception as exc:
+        return JSONResponse(status_code=400, content={"message": "tts failed", "exception": str(exc)})
+
+
+@APP.on_event("startup")
+def startup() -> None:
+    global tts_config, tts_pipeline
+    tts_config = TTS_Config(tts_config_path)
+    print(tts_config)
+    tts_pipeline = TTS(tts_config)
+
+
+@APP.get("/")
+def index() -> Dict[str, Any]:
+    return {
+        "name": "GPT-SoVITS Simple API",
+        "endpoints": [
+            "/health",
+            "/voices",
+            "/api/tts",
+            "/speak",
+            "/speak/base64",
+            "/admin/weights",
+            "/admin/reload-config",
+        ],
+    }
+
+
+@APP.get("/health")
+def health() -> Dict[str, Any]:
+    return {
+        "status": "ok" if tts_pipeline is not None else "starting",
+        "tts_config": tts_config_path,
+        "version": getattr(tts_config, "version", None),
+        "languages": getattr(tts_config, "languages", []),
+    }
+
+
+@APP.get("/voices")
+def list_voices() -> Dict[str, Any]:
+    voices = simple_config.get("voices", {}) or {}
+    return {
+        "default_voice": get_default_voice_name(),
+        "voices": [voice_public_info(name, profile or {}) for name, profile in voices.items()],
+    }
+
+
+@APP.post("/api/tts")
+async def mvp_tts(
+    text: str = Form(...),
+    ref_audio: UploadFile = File(...),
+    aux_ref_audio: Optional[List[UploadFile]] = File(default=None),
+    prompt_text: str = Form(default=""),
+    text_lang: str = Form(default="zh"),
+    prompt_lang: str = Form(default="zh"),
+    format: str = Form(default="wav"),
+    emotion: Optional[str] = Form(default=None),
+    speed: Optional[float] = Form(default=None),
+    seed: int = Form(default=-1),
+) -> Response:
+    request_dir = get_request_upload_dir()
+    try:
+        ref_audio_path = await save_upload_file(ref_audio, request_dir, "ref")
+        validate_main_ref_duration(ref_audio_path)
+
+        aux_ref_audio_paths = []
+        for index, upload in enumerate(aux_ref_audio or []):
+            aux_path = await save_upload_file(upload, request_dir, f"aux_{index}")
+            get_audio_duration_seconds(aux_path)
+            aux_ref_audio_paths.append(aux_path)
+
+        payload: Dict[str, Any] = {
+            "text": text,
+            "text_lang": text_lang,
+            "ref_audio_path": ref_audio_path,
+            "aux_ref_audio_paths": aux_ref_audio_paths,
+            "prompt_text": prompt_text or "",
+            "prompt_lang": prompt_lang,
+            "format": format,
+            "text_split_method": "cut5",
+            "streaming_mode": 0,
+            "seed": seed,
+        }
+        apply_emotion_preset(payload, emotion)
+        payload["text_split_method"] = "cut5"
+        payload["streaming_mode"] = 0
+        if speed is not None:
+            payload["speed"] = speed
+
+        return synthesize_response(payload)
+    finally:
+        shutil.rmtree(request_dir, ignore_errors=True)
+
+
+@APP.get("/speak")
+def speak_get(
+    text: str,
+    voice: Optional[str] = None,
+    text_lang: Optional[str] = None,
+    format: Optional[str] = None,
+    stream: Optional[bool] = None,
+    speed: Optional[float] = None,
+) -> Response:
+    payload = {
+        "text": text,
+        "voice": voice,
+        "text_lang": text_lang,
+        "format": format,
+        "stream": stream,
+        "speed": speed,
+    }
+    return synthesize_response({k: v for k, v in payload.items() if v is not None})
+
+
+@APP.post("/speak")
+def speak_post(request: SpeakRequest) -> Response:
+    return synthesize_response(request_to_dict(request))
+
+
+@APP.post("/v1/tts")
+def openai_style_tts(request: SpeakRequest) -> Response:
+    return synthesize_response(request_to_dict(request))
+
+
+@APP.post("/speak/base64")
+def speak_base64(request: SpeakRequest) -> Dict[str, Any]:
+    audio_bytes, media_type = synthesize_once(request_to_dict(request))
+    return {
+        "media_type": f"audio/{media_type}",
+        "audio_base64": base64.b64encode(audio_bytes).decode("ascii"),
+    }
+
+
+@APP.post("/admin/reload-config")
+def reload_config() -> Dict[str, Any]:
+    global simple_config
+    simple_config = load_yaml_config(args.config)
+    return {"message": "success", "default_voice": get_default_voice_name()}
+
+
+@APP.post("/admin/weights")
+def set_weights(request: WeightsRequest) -> Dict[str, Any]:
+    if tts_pipeline is None:
+        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
+    if not request.gpt_weights_path and not request.sovits_weights_path:
+        raise HTTPException(status_code=400, detail="gpt_weights_path or sovits_weights_path is required")
+
+    try:
+        with infer_lock:
+            if request.gpt_weights_path:
+                tts_pipeline.init_t2s_weights(request.gpt_weights_path)
+            if request.sovits_weights_path:
+                tts_pipeline.init_vits_weights(request.sovits_weights_path)
+    except Exception as exc:
+        raise HTTPException(status_code=400, detail={"message": "change weights failed", "exception": str(exc)}) from exc
+
+    return {"message": "success"}
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    try:
+        uvicorn.run(app=APP, host=None if host == "None" else host, port=port, workers=1)
+    except Exception:
+        traceback.print_exc()
+        sys.exit(1)
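Reviewer note: a minimal client-side sketch of how the JSON body returned by `/speak/base64` might be consumed. It relies only on the two response keys visible in the diff (`media_type` and `audio_base64`); the helper name `decode_speak_base64` and the stand-in payload are hypothetical and are not part of `simple_api.py`.

```python
import base64
import json
from typing import Any, Dict, Tuple


def decode_speak_base64(body: Dict[str, Any]) -> Tuple[str, bytes]:
    """Return (media_type, raw_audio_bytes) from a /speak/base64 JSON body."""
    media_type = body["media_type"]
    # validate=True rejects characters outside the base64 alphabet instead of skipping them.
    audio = base64.b64decode(body["audio_base64"], validate=True)
    return media_type, audio


# Round trip with a stand-in payload (placeholder bytes, not real audio).
sample = json.loads(json.dumps({
    "media_type": "audio/wav",
    "audio_base64": base64.b64encode(b"RIFF-demo").decode("ascii"),
}))
media_type, audio = decode_speak_base64(sample)
```

In a real client, `sample` would be the parsed JSON of the HTTP response, and `audio` could be written to a file whose extension is derived from `media_type`.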
diff --git a/simple_api.yaml b/simple_api.yaml
new file mode 100644
index 00000000..4c02394f
--- /dev/null
+++ b/simple_api.yaml
@@ -0,0 +1,59 @@
+server:
+  host: 127.0.0.1
+  port: 9881
+  tts_config: GPT_SoVITS/configs/tts_infer.yaml
+
+cors_allow_origins:
+  - "*"
+
+upload:
+  dir: runtime/uploads
+  min_ref_seconds: 3
+  max_ref_seconds: 10
+  max_upload_mb: 80
+
+default_voice: default
+
+defaults:
+  text_lang: zh
+  prompt_lang: zh
+  media_type: wav
+  text_split_method: cut5
+  batch_size: 1
+  batch_threshold: 0.75
+  split_bucket: true
+  speed_factor: 1.0
+  fragment_interval: 0.3
+  seed: -1
+  parallel_infer: true
+  repetition_penalty: 1.35
+  sample_steps: 32
+  super_sampling: false
+  overlap_length: 2
+  min_chunk_length: 16
+
+emotion_presets:
+  neutral: {}
+  happy:
+    temperature: 1.1
+    top_p: 0.95
+  calm:
+    temperature: 0.8
+    top_p: 0.85
+    speed_factor: 0.92
+  sad:
+    temperature: 0.75
+    top_p: 0.85
+    speed_factor: 0.9
+  angry:
+    temperature: 1.2
+    top_k: 20
+    repetition_penalty: 1.25
+
+voices:
+  default:
+    description: Replace this profile with your reference voice.
+    ref_audio_path: reference.wav
+    prompt_text: Replace this with the exact text spoken in reference.wav.
+    prompt_lang: zh
+    text_lang: zh
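Reviewer note: the layering of `defaults`, `emotion_presets`, and per-request values can be sketched as a plain dictionary merge. This illustrates the precedence the contract tests suggest (explicit request values over preset values, preset values over defaults). It is not the actual `apply_emotion_preset` implementation; `resolve_params` is a hypothetical name, the tables are a subset copied from `simple_api.yaml`, and treating an unknown emotion as "no overrides" is a choice made for this sketch only.

```python
from typing import Any, Dict, Optional

# Subset of the `defaults` and `emotion_presets` sections of simple_api.yaml.
DEFAULTS: Dict[str, Any] = {"speed_factor": 1.0, "repetition_penalty": 1.35, "seed": -1}
EMOTION_PRESETS: Dict[str, Dict[str, Any]] = {
    "neutral": {},
    "calm": {"temperature": 0.8, "top_p": 0.85, "speed_factor": 0.92},
    "angry": {"temperature": 1.2, "top_k": 20, "repetition_penalty": 1.25},
}


def resolve_params(emotion: Optional[str], explicit: Dict[str, Any]) -> Dict[str, Any]:
    """Merge with precedence: defaults < emotion preset < explicit request values."""
    params = dict(DEFAULTS)
    # Sketch-only choice: an unknown or missing emotion contributes no overrides.
    params.update(EMOTION_PRESETS.get(emotion or "neutral", {}))
    # None means "not supplied by the caller", so it must not mask lower layers.
    params.update({key: value for key, value in explicit.items() if value is not None})
    return params


calm = resolve_params("calm", {"speed_factor": 1.05})
angry = resolve_params("angry", {"speed_factor": None})
```

Here `calm` keeps the preset temperature but takes the caller's speed, which mirrors what `test_explicit_speed_overrides_emotion_speed_preset` asserts.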
diff --git a/test_frontend/index.html b/test_frontend/index.html
new file mode 100644
index 00000000..c7e1b5ba
--- /dev/null
+++ b/test_frontend/index.html
@@ -0,0 +1,713 @@
+
+
+
+
+
+
GPT-SoVITS API Test
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/tests/test_simple_api_contract.py b/tests/test_simple_api_contract.py
new file mode 100644
index 00000000..4cfa3a3e
--- /dev/null
+++ b/tests/test_simple_api_contract.py
@@ -0,0 +1,303 @@
+import asyncio
+import importlib.util
+import sys
+import tempfile
+import types
+import unittest
+from pathlib import Path
+
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+SIMPLE_API_PATH = PROJECT_ROOT / "simple_api.py"
+
+
+class DummyHTTPException(Exception):
+    def __init__(self, status_code=500, detail=None):
+        super().__init__(detail)
+        self.status_code = status_code
+        self.detail = detail
+
+
+class DummyFastAPI:
+    def __init__(self, *args, **kwargs):
+        self.routes = []
+
+    def add_middleware(self, *args, **kwargs):
+        pass
+
+    def mount(self, path, app, name=None):
+        self.routes.append(types.SimpleNamespace(path=path, endpoint=app, name=name))
+
+    def get(self, path):
+        return self._route(path)
+
+    def post(self, path):
+        return self._route(path)
+
+    def on_event(self, event):
+        return lambda func: func
+
+    def _route(self, path):
+        def decorator(func):
+            self.routes.append(types.SimpleNamespace(path=path, endpoint=func))
+            return func
+
+        return decorator
+
+
+class DummyBaseModel:
+    def dict(self, exclude_none=False):
+        data = dict(self.__dict__)
+        if exclude_none:
+            data = {key: value for key, value in data.items() if value is not None}
+        return data
+
+
+class DummyTTSConfig:
+    version = "v2"
+    languages = ["zh", "en", "auto"]
+    use_vocoder = False
+
+
+class FakeUpload:
+    def __init__(self, filename, chunks, content_type="audio/wav"):
+        self.filename = filename
+        self.content_type = content_type
+        self._chunks = list(chunks)
+        self.closed = False
+
+    async def read(self, size):
+        del size
+        if self._chunks:
+            return self._chunks.pop(0)
+        return b""
+
+    async def close(self):
+        self.closed = True
+
+
+def install_dummy_modules():
+    fastapi = types.ModuleType("fastapi")
+    fastapi.FastAPI = DummyFastAPI
+    fastapi.File = lambda default=None, **kwargs: default
+    fastapi.Form = lambda default=None, **kwargs: default
+    fastapi.HTTPException = DummyHTTPException
+    fastapi.Response = type("Response", (), {})
+    fastapi.UploadFile = type("UploadFile", (), {})
+    sys.modules["fastapi"] = fastapi
+
+    middleware = types.ModuleType("fastapi.middleware")
+    cors = types.ModuleType("fastapi.middleware.cors")
+    cors.CORSMiddleware = type("CORSMiddleware", (), {})
+    sys.modules["fastapi.middleware"] = middleware
+    sys.modules["fastapi.middleware.cors"] = cors
+
+    responses = types.ModuleType("fastapi.responses")
+    responses.JSONResponse = type("JSONResponse", (), {})
+    responses.StreamingResponse = type("StreamingResponse", (), {})
+    sys.modules["fastapi.responses"] = responses
+
+    staticfiles = types.ModuleType("fastapi.staticfiles")
+    staticfiles.StaticFiles = type("StaticFiles", (), {"__init__": lambda self, *args, **kwargs: None})
+    sys.modules["fastapi.staticfiles"] = staticfiles
+
+    pydantic = types.ModuleType("pydantic")
+    pydantic.BaseModel = DummyBaseModel
+    pydantic.Field = lambda default=None, **kwargs: default
+    sys.modules["pydantic"] = pydantic
+
+    numpy = types.ModuleType("numpy")
+    numpy.ndarray = object
+    sys.modules["numpy"] = numpy
+
+    soundfile = types.ModuleType("soundfile")
+    soundfile.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000)
+    sys.modules["soundfile"] = soundfile
+
+    yaml = types.ModuleType("yaml")
+    yaml.safe_load = lambda stream: {
+        "server": {"host": "127.0.0.1", "port": 9881, "tts_config": "dummy.yaml"},
+        "upload": {"dir": "runtime/test_uploads", "min_ref_seconds": 3, "max_ref_seconds": 10, "max_upload_mb": 1},
+        "defaults": {"text_lang": "zh", "prompt_lang": "zh", "media_type": "wav", "text_split_method": "cut5"},
+        "emotion_presets": {"calm": {"temperature": 0.8, "speed_factor": 0.9}},
+        "voices": {},
+    }
+    sys.modules["yaml"] = yaml
+
+    sys.modules["GPT_SoVITS"] = types.ModuleType("GPT_SoVITS")
+    sys.modules["GPT_SoVITS.TTS_infer_pack"] = types.ModuleType("GPT_SoVITS.TTS_infer_pack")
+
+    tts_module = types.ModuleType("GPT_SoVITS.TTS_infer_pack.TTS")
+    tts_module.TTS = type("DummyTTS", (), {})
+    tts_module.TTS_Config = DummyTTSConfig
+    sys.modules["GPT_SoVITS.TTS_infer_pack.TTS"] = tts_module
+
+    segmentation = types.ModuleType("GPT_SoVITS.TTS_infer_pack.text_segmentation_method")
+    segmentation.get_method_names = lambda: ["cut0", "cut5"]
+    sys.modules["GPT_SoVITS.TTS_infer_pack.text_segmentation_method"] = segmentation
+
+
+def load_simple_api():
+    install_dummy_modules()
+    sys.argv = ["simple_api.py"]
+    spec = importlib.util.spec_from_file_location("simple_api_under_test", SIMPLE_API_PATH)
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    module.tts_config = DummyTTSConfig()
+    return module
+
+
+class SimpleApiContractTest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.simple_api = load_simple_api()
+
+    def test_api_tts_route_is_registered(self):
+        routes = {route.path for route in self.simple_api.APP.routes}
+        self.assertIn("/api/tts", routes)
+
+    def test_build_request_accepts_explicit_ref_without_voice_profile(self):
+        with tempfile.NamedTemporaryFile(suffix=".wav") as ref:
+            req, media_type, response_stream = self.simple_api.build_tts_request(
+                {
+                    "text": "Hello, test.",
+                    "ref_audio_path": ref.name,
+                    "prompt_text": "",
+                    "prompt_lang": "zh",
+                    "text_lang": "zh",
+                    "format": "wav",
+                    "text_split_method": "cut5",
+                    "streaming_mode": 0,
+                }
+            )
+
+        self.assertEqual(req["prompt_text"], "")
+        self.assertEqual(req["text_split_method"], "cut5")
+        self.assertEqual(req["ref_audio_path"], ref.name)
+        self.assertEqual(media_type, "wav")
+        self.assertFalse(response_stream)
+
+    def test_explicit_speed_overrides_emotion_speed_preset(self):
+        with tempfile.NamedTemporaryFile(suffix=".wav") as ref:
+            payload = {
+                "text": "Hello.",
+                "ref_audio_path": ref.name,
+                "prompt_lang": "zh",
+                "text_lang": "zh",
+                "format": "wav",
+                "streaming_mode": 0,
+                "speed": 1.05,
+            }
+            self.simple_api.apply_emotion_preset(payload, "calm")
+            payload["speed"] = 1.05
+            req, _, _ = self.simple_api.build_tts_request(payload)
+
+        self.assertEqual(req["temperature"], 0.8)
+        self.assertEqual(req["speed_factor"], 1.05)
+
+    def test_empty_prompt_text_is_rejected_for_vocoder_models(self):
+        original_use_vocoder = self.simple_api.tts_config.use_vocoder
+        original_version = self.simple_api.tts_config.version
+        self.simple_api.tts_config.use_vocoder = True
+        self.simple_api.tts_config.version = "v3"
+
+        try:
+            with tempfile.NamedTemporaryFile(suffix=".wav") as ref:
+                with self.assertRaises(DummyHTTPException) as exc:
+                    self.simple_api.build_tts_request(
+                        {
+                            "text": "Hello.",
+                            "ref_audio_path": ref.name,
+                            "prompt_text": "",
+                            "prompt_lang": "zh",
+                            "text_lang": "zh",
+                            "format": "wav",
+                            "text_split_method": "cut5",
+                            "streaming_mode": 0,
+                        }
+                    )
+            self.assertEqual(exc.exception.status_code, 400)
+            self.assertIn("prompt_text is required", str(exc.exception.detail))
+        finally:
+            self.simple_api.tts_config.use_vocoder = original_use_vocoder
+            self.simple_api.tts_config.version = original_version
+
+    def test_ref_audio_duration_limits(self):
+        self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000)
+        self.simple_api.validate_main_ref_duration("ok.wav")
+
+        self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 2, samplerate=16000)
+        with self.assertRaises(DummyHTTPException) as too_short:
+            self.simple_api.validate_main_ref_duration("short.wav")
+        self.assertEqual(too_short.exception.status_code, 400)
+
+        self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 11, samplerate=16000)
+        with self.assertRaises(DummyHTTPException) as too_long:
+            self.simple_api.validate_main_ref_duration("long.wav")
+        self.assertEqual(too_long.exception.status_code, 400)
+
+    def test_save_upload_file_writes_and_closes(self):
+        with tempfile.TemporaryDirectory() as tmp:
+            upload = FakeUpload("ref.wav", [b"abc", b"def"])
+            saved_path = asyncio.run(self.simple_api.save_upload_file(upload, Path(tmp), "ref"))
+
+            self.assertTrue(upload.closed)
+            self.assertEqual(Path(saved_path).read_bytes(), b"abcdef")
+
+    def test_mvp_tts_builds_payload_from_uploads_and_cleans_tmp_dir(self):
+        captured = {}
+
+        with tempfile.TemporaryDirectory() as tmp:
+            request_dir = Path(tmp) / "request"
+            request_dir.mkdir()
+
+            original_get_request_upload_dir = self.simple_api.get_request_upload_dir
+            original_synthesize_response = self.simple_api.synthesize_response
+            original_sf_info = self.simple_api.sf.info
+
+            def fake_synthesize_response(payload):
+                captured.update(payload)
+                return "audio-response"
+
+            try:
+                self.simple_api.get_request_upload_dir = lambda: request_dir
+                self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000)
+                self.simple_api.synthesize_response = fake_synthesize_response
+
+                response = asyncio.run(
+                    self.simple_api.mvp_tts(
+                        text="Hello, test.",
+                        ref_audio=FakeUpload("ref.wav", [b"main"]),
+                        aux_ref_audio=[FakeUpload("aux.wav", [b"aux"])],
+                        prompt_text="",
+                        text_lang="zh",
+                        prompt_lang="zh",
+                        format="wav",
+                        emotion="calm",
+                        speed=1.05,
+                        seed=123,
+                    )
+                )
+            finally:
+                self.simple_api.get_request_upload_dir = original_get_request_upload_dir
+                self.simple_api.synthesize_response = original_synthesize_response
+                self.simple_api.sf.info = original_sf_info
+
+            self.assertEqual(response, "audio-response")
+            self.assertFalse(request_dir.exists())
+
+        self.assertEqual(captured["text"], "Hello, test.")
+        self.assertEqual(captured["prompt_text"], "")
+        self.assertEqual(captured["text_lang"], "zh")
+        self.assertEqual(captured["prompt_lang"], "zh")
+        self.assertEqual(captured["format"], "wav")
+        self.assertEqual(captured["text_split_method"], "cut5")
+        self.assertEqual(captured["streaming_mode"], 0)
+        self.assertEqual(captured["seed"], 123)
+        self.assertEqual(captured["temperature"], 0.8)
+        self.assertEqual(captured["speed"], 1.05)
+        self.assertEqual(len(captured["aux_ref_audio_paths"]), 1)
+
+
+if __name__ == "__main__":
+    unittest.main()
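Reviewer note: the duration cases exercised in `test_ref_audio_duration_limits` (5 s accepted, 2 s and 11 s rejected) reduce to `frames / samplerate` compared against the configured bounds. The sketch below shows that arithmetic in isolation. It is not the body of `validate_main_ref_duration`: `check_ref_duration` is a hypothetical name, `ValueError` stands in for the HTTP 400 error, and treating both bounds as inclusive is an assumption the tests do not settle.

```python
def check_ref_duration(
    frames: int,
    samplerate: int,
    min_seconds: float = 3.0,
    max_seconds: float = 10.0,
) -> float:
    """Return the clip duration in seconds, or raise ValueError if it is out of range."""
    if samplerate <= 0:
        raise ValueError("samplerate must be positive")
    duration = frames / samplerate
    # Assumption for this sketch: both bounds are inclusive.
    if not (min_seconds <= duration <= max_seconds):
        raise ValueError(
            f"reference audio must be {min_seconds}-{max_seconds} s long, got {duration:.2f} s"
        )
    return duration


ok_duration = check_ref_duration(frames=16000 * 5, samplerate=16000)
```

The defaults mirror `upload.min_ref_seconds` and `upload.max_ref_seconds` in `simple_api.yaml`; the `frames` and `samplerate` arguments correspond to the attributes the tests stub on `soundfile.info`.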