Merge a47b87bb7b1219d133523f88024888aa296594bf into bf81cdb14a38b674b6e9996dabc97340bc9978d2

qimingnan17 2026-06-20 23:27:35 +08:00 committed by GitHub
commit e5f4f31a1f
GPG Key ID: B5690EEEBB952194
11 changed files with 2711 additions and 0 deletions


@ -27,6 +27,44 @@ A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.<br><br>
---
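## Simplified API / Test Frontend
This workspace adds a middle-layer API for MVP calls and a test frontend:
| Page | URL | Description |
|------|------|------|
| Test frontend | http://127.0.0.1:9881/test/ | Test UI with waveform trimming and video-to-audio extraction |
| Swagger UI | http://127.0.0.1:9881/docs | Interactive API documentation |
| ReDoc | http://127.0.0.1:9881/redoc | Readable API documentation |
Documentation:
- Quick start: [docs/simple_api_quickstart.md](./docs/simple_api_quickstart.md)
- Full guide: [docs/simple_api.md](./docs/simple_api.md)
Core features:
- `/api/tts`: upload a reference audio/video file plus text and receive the generated audio directly
- The frontend accepts video uploads and extracts the audio automatically
- The frontend waveform trimming tool lets you select a 3-10 second clip
- 5 emotion presets: neutral/happy/calm/sad/angry
- 7 contract tests that run without a GPU
Launch command: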
```
cd D:\tts\GPT-SoVITS
.\go-simple-api.ps1
```
Run the tests:
```
python -m unittest tests.test_simple_api_contract -v
```
---
## Features:
1. **Zero-shot TTS:** Input a 5-second vocal sample and experience instant text-to-speech conversion.
@ -478,3 +516,4 @@ Thankful to @Naozumi520 for providing the Cantonese training set and for the gui
<a href="https://github.com/RVC-Boss/GPT-SoVITS/graphs/contributors" target="_blank">
<img src="https://contrib.rocks/image?repo=RVC-Boss/GPT-SoVITS" />
</a>

402
docs/simple_api.md Normal file

@ -0,0 +1,402 @@
# GPT-SoVITS Simplified API Documentation
This project adds `simple_api.py` as a middle layer that wraps the GPT-SoVITS inference engine and provides a simpler way to call it.
## Quick start
```bash
# Install dependencies
python -m pip install -r requirements.txt
# Start the server
python simple_api.py -c simple_api.yaml
# Then open:
#   Swagger UI:    http://127.0.0.1:9881/docs
#   ReDoc:         http://127.0.0.1:9881/redoc
#   Test frontend: http://127.0.0.1:9881/test/
```
## Endpoint overview
| Method | Path | Description | Tag |
|------|------|------|------|
| GET | `/health` | Health check (includes GPU info) | System |
| GET | `/voices` | List voice profiles | System |
| **POST** | **`/api/tts`** | **Core TTS endpoint (MVP)** | **MVP** |
| GET | `/speak` | Voice profile TTS (GET) | Profile |
| POST | `/speak` | Voice profile TTS (POST) | Profile |
| POST | `/v1/tts` | OpenAI-compatible TTS | Profile |
| POST | `/speak/base64` | Returns Base64 audio | Profile |
| POST | `/admin/reload-config` | Hot-reload the configuration | Admin |
| POST | `/admin/weights` | Switch model weights | Admin |
---
## 1. POST /api/tts: core TTS endpoint
**This is the recommended endpoint.** Upload a reference audio file and text, and it returns the generated audio directly.
### Request format
```
Content-Type: multipart/form-data
```
### Fields
| Field | Type | Required | Default | Description |
|------|------|------|--------|------|
| `text` | string | **Yes** | - | Text to synthesize |
| `ref_audio` | file | **Yes** | - | Main reference audio, 3-10 seconds (wav/flac/ogg/mp3/m4a/aac supported) |
| `aux_ref_audio` | file[] | No | - | Auxiliary reference audio; multiple files may be uploaded |
| `prompt_text` | string | No | `""` | Transcript of the main reference audio (may be empty for v2; required for v3/v4) |
| `text_lang` | string | No | `zh` | Language of the text to synthesize: zh/en/ja/ko/yue/auto |
| `prompt_lang` | string | No | `zh` | Language of the reference audio: zh/en/ja/ko/yue/auto |
| `format` | string | No | `wav` | Output format: wav/ogg/aac/raw |
| `emotion` | string | No | `neutral` | Emotion preset: neutral/happy/calm/sad/angry |
| `speed` | float | No | - | Speech rate (0.5-2.0); overrides the rate set by the emotion preset |
| `seed` | int | No | `-1` | Random seed; -1 means random |
### Emotion preset parameter mapping
| Emotion | temperature | top_p | top_k | speed_factor | repetition_penalty |
|------|-------------|-------|-------|--------------|-------------------|
| neutral | - | - | - | - | - |
| happy | 1.1 | 0.95 | - | - | - |
| calm | 0.8 | 0.85 | - | 0.92 | - |
| sad | 0.75 | 0.85 | - | 0.9 | - |
| angry | 1.2 | - | 20 | - | 1.25 |
> An explicit `speed` overrides the `speed_factor` from the emotion preset.
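The override rule can be sketched in a few lines of Python. This is a minimal illustration, not the server code: `EMOTION_PRESETS` copies the values from the preset table, and `resolve_sampling_params` is a hypothetical helper name introduced here.

```python
from typing import Any, Dict, Optional

# Values copied from the preset table; an empty dict means "no override".
EMOTION_PRESETS: Dict[str, Dict[str, Any]] = {
    "neutral": {},
    "happy": {"temperature": 1.1, "top_p": 0.95},
    "calm": {"temperature": 0.8, "top_p": 0.85, "speed_factor": 0.92},
    "sad": {"temperature": 0.75, "top_p": 0.85, "speed_factor": 0.9},
    "angry": {"temperature": 1.2, "top_k": 20, "repetition_penalty": 1.25},
}


def resolve_sampling_params(emotion: Optional[str], speed: Optional[float]) -> Dict[str, Any]:
    """Hypothetical helper: merge a preset with an explicit speed override."""
    params: Dict[str, Any] = {}
    if emotion:
        name = emotion.strip().lower()
        if name not in EMOTION_PRESETS:
            raise ValueError(f"emotion is not configured: {name}")
        params.update(EMOTION_PRESETS[name])
    if speed is not None:
        # An explicit speed always wins over the preset's speed_factor.
        params["speed_factor"] = speed
    return params
```

For example, `resolve_sampling_params("calm", 1.1)` keeps the calm preset's `temperature` and `top_p` but replaces its `speed_factor` of 0.92 with 1.1.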
### curl examples
**Basic call:**
```powershell
curl.exe -X POST http://127.0.0.1:9881/api/tts `
-F "text=你好,欢迎使用这个声音。" `
-F "ref_audio=@D:\audio\ref.wav" `
--output output.wav
```
**With auxiliary reference audio and an emotion:**
```powershell
curl.exe -X POST http://127.0.0.1:9881/api/tts `
-F "text=你好,欢迎使用这个声音。" `
-F "ref_audio=@D:\audio\ref.wav" `
-F "aux_ref_audio=@D:\audio\aux1.wav" `
-F "emotion=happy" `
-F "speed=1.1" `
--output output.wav
```
**Linux/macOS:**
```bash
curl -X POST http://127.0.0.1:9881/api/tts \
-F "text=你好,欢迎使用这个声音。" \
-F "ref_audio=@/path/to/ref.wav" \
-F "emotion=calm" \
--output output.wav
```
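The same request can be sent from Python without third-party packages. The sketch below is an assumption-laden illustration: `encode_multipart` and `call_tts` are hypothetical helper names, the encoder handles only ASCII file names, and `call_tts` needs a running server at the given base URL.

```python
import urllib.request
import uuid
from typing import Dict, Tuple


def encode_multipart(
    fields: Dict[str, str], files: Dict[str, Tuple[str, bytes, str]]
) -> Tuple[bytes, str]:
    """Build a multipart/form-data body; returns (body, Content-Type value)."""
    boundary = uuid.uuid4().hex
    parts = []
    for name, value in fields.items():
        header = f'--{boundary}\r\nContent-Disposition: form-data; name="{name}"\r\n\r\n'
        parts.append(header.encode("utf-8") + value.encode("utf-8") + b"\r\n")
    for name, (filename, content, content_type) in files.items():
        header = (
            f'--{boundary}\r\nContent-Disposition: form-data; name="{name}"; '
            f'filename="{filename}"\r\nContent-Type: {content_type}\r\n\r\n'
        )
        parts.append(header.encode("utf-8") + content + b"\r\n")
    parts.append(f"--{boundary}--\r\n".encode("utf-8"))
    return b"".join(parts), f"multipart/form-data; boundary={boundary}"


def call_tts(base_url: str, text: str, ref_path: str, emotion: str = "neutral") -> bytes:
    """Post text plus a reference clip to /api/tts and return the audio bytes."""
    with open(ref_path, "rb") as f:
        ref_bytes = f.read()
    body, content_type = encode_multipart(
        {"text": text, "emotion": emotion},
        {"ref_audio": ("ref.wav", ref_bytes, "audio/wav")},
    )
    request = urllib.request.Request(
        f"{base_url}/api/tts",
        data=body,
        headers={"Content-Type": content_type},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        return response.read()
```

A call such as `call_tts("http://127.0.0.1:9881", "你好", "ref.wav", "calm")` would return bytes that can be written to `output.wav`; `urlopen` raises `urllib.error.HTTPError` on a 4xx/5xx response.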
### Response
- Success: binary audio stream (Content-Type: `audio/wav`, etc.)
- Failure: JSON error message
```json
{"message": "tts failed", "exception": "..."}
```
### Common errors
| HTTP status | Cause |
|------------|------|
| 400 | Empty text / missing ref_audio / audio duration outside 3-10 seconds / unsupported format / empty prompt_text on v3/v4 |
| 404 | Voice profile not found (`/speak` endpoints only) |
| 503 | TTS pipeline not ready (model not loaded) |
---
## 2. GET /health: health check
```bash
curl http://127.0.0.1:9881/health
```
Example response:
```json
{
  "status": "ok",
  "tts_config": "GPT_SoVITS/configs/tts_infer.yaml",
  "version": "v2",
  "languages": ["auto", "en", "zh"],
  "pid": 12345,
  "memory_mb": 2048.5,
  "gpu": {
    "name": "NVIDIA GeForce RTX 3080",
    "memory_used_mb": 4096.2,
    "memory_total_mb": 10240.0
  }
}
```
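A client might use this payload to decide whether the server is ready. The sketch below assumes two things taken from `simple_api.py` in this change: `status` is `"ok"` once the pipeline is loaded and `"starting"` before that, and the `gpu` block is present only when CUDA is available. `is_ready` and `gpu_usage_ratio` are hypothetical helper names.

```python
from typing import Any, Dict, Optional


def is_ready(health: Dict[str, Any]) -> bool:
    """True once the TTS pipeline reports itself as loaded."""
    return health.get("status") == "ok"


def gpu_usage_ratio(health: Dict[str, Any]) -> Optional[float]:
    """Fraction of GPU memory in use, or None when no GPU block is reported."""
    gpu = health.get("gpu")
    if not gpu or not gpu.get("memory_total_mb"):
        return None
    return gpu["memory_used_mb"] / gpu["memory_total_mb"]
```

Treating a missing `gpu` key as "unknown" rather than as an error keeps the check usable on CPU-only machines.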
---
## 3. GET /voices: list voice profiles
```bash
curl http://127.0.0.1:9881/voices
```
Example response:
```json
{
  "default_voice": "default",
  "voices": [
    {
      "name": "default",
      "description": "Replace this profile with your reference voice.",
      "text_lang": "zh",
      "prompt_lang": "zh",
      "ref_audio_path": "reference.wav",
      "ready": true
    }
  ]
}
```
---
## 4. POST /speak: voice profile TTS
Calls TTS using a voice profile configured in `simple_api.yaml`.
### Request body (JSON)
```json
{
  "text": "hello world",
  "voice": "default",
  "text_lang": "zh",
  "format": "wav",
  "speed": 1.0
}
```
| Field | Type | Required | Description |
|------|------|------|------|
| `text` | string | **Yes** | Text to synthesize |
| `voice` | string | No | Voice profile name; the default profile is used when omitted |
| `text_lang` | string | No | Language of the text to synthesize |
| `format` | string | No | Output format |
| `stream` | bool | No | Whether to stream the response |
| `speed` | float | No | Speech rate |
### curl example
```bash
curl -X POST http://127.0.0.1:9881/speak \
-H "Content-Type: application/json" \
-d '{"text":"你好世界","voice":"default"}' \
--output output.wav
```
---
## 5. GET /speak: voice profile TTS (GET)
Same as POST /speak, but the parameters are passed in the URL query string.
```
GET /speak?text=hello&voice=default&format=wav
```
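Because the text travels in the query string, non-ASCII input has to be percent-encoded. A minimal sketch using only the standard library follows; `build_speak_url` is a hypothetical helper name, and the parameter names are the ones accepted by GET /speak.

```python
from typing import Optional
from urllib.parse import urlencode


def build_speak_url(
    base_url: str,
    text: str,
    voice: Optional[str] = None,
    fmt: Optional[str] = None,
    speed: Optional[float] = None,
) -> str:
    """Build a GET /speak URL, skipping unset parameters and encoding the rest."""
    params = {"text": text, "voice": voice, "format": fmt, "speed": speed}
    query = urlencode({k: v for k, v in params.items() if v is not None})
    return f"{base_url}/speak?{query}"
```

`urlencode` encodes strings as UTF-8 by default, so Chinese text is sent as `%E4%BD%A0...` sequences instead of raw characters.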
---
## 6. POST /speak/base64: returns Base64 audio
Returns Base64-encoded audio, suitable for direct use in a web frontend.
```bash
curl -X POST http://127.0.0.1:9881/speak/base64 \
-H "Content-Type: application/json" \
-d '{"text":"hello","voice":"default"}'
```
Response:
```json
{
  "media_type": "audio/wav",
  "audio_base64": "UklGRi..."
}
```
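Decoding that payload on the client side is short. The sketch below assumes only the two keys shown in the response; `decode_audio_response` and `to_data_uri` are hypothetical helper names.

```python
import base64
from typing import Any, Dict, Tuple


def decode_audio_response(payload: Dict[str, Any]) -> Tuple[bytes, str]:
    """Return (audio bytes, file extension) from a /speak/base64 response body."""
    audio = base64.b64decode(payload["audio_base64"])
    extension = payload.get("media_type", "audio/wav").split("/")[-1]
    return audio, extension


def to_data_uri(payload: Dict[str, Any]) -> str:
    """Build a data: URI that an HTML <audio> element can play directly."""
    return f"data:{payload['media_type']};base64,{payload['audio_base64']}"
```

The bytes can be written to `f"output.{extension}"`, while the data URI form avoids a second request when embedding the audio in a page.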
---
## 7. POST /v1/tts: OpenAI-compatible format
The request format is the same as POST /speak; the path follows the style of the OpenAI TTS API.
---
## 8. POST /admin/reload-config: hot-reload the configuration
Reloads `simple_api.yaml` without restarting the service.
```bash
curl -X POST http://127.0.0.1:9881/admin/reload-config
```
Response: `{"message": "success", "default_voice": "default"}`
---
## 9. POST /admin/weights: switch model weights
Switches the GPT-SoVITS model weight files at runtime.
```bash
curl -X POST http://127.0.0.1:9881/admin/weights \
-H "Content-Type: application/json" \
-d '{"gpt_weights_path":"path/to/gpt.pt","sovits_weights_path":"path/to/sovits.pt"}'
```
---
## Configuration file
`simple_api.yaml`:
```yaml
server:
  host: 127.0.0.1
  port: 9881
  tts_config: GPT_SoVITS/configs/tts_infer.yaml
cors_allow_origins:
  - "*"
upload:
  dir: runtime/uploads
  min_ref_seconds: 3
  max_ref_seconds: 10
  max_upload_mb: 80
defaults:
  text_lang: zh
  prompt_lang: zh
  media_type: wav
  text_split_method: cut5
  batch_size: 1
  speed_factor: 1.0
  seed: -1
emotion_presets:
  neutral: {}
  happy:
    temperature: 1.1
    top_p: 0.95
  calm:
    temperature: 0.8
    top_p: 0.85
    speed_factor: 0.92
  sad:
    temperature: 0.75
    top_p: 0.85
    speed_factor: 0.9
  angry:
    temperature: 1.2
    top_k: 20
    repetition_penalty: 1.25
voices:
  default:
    description: Replace this profile with your reference voice.
    ref_audio_path: reference.wav
    prompt_text: Replace this with the exact text spoken in reference.wav.
    prompt_lang: zh
    text_lang: zh
```
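### Configuration reference
| Key | Description |
|--------|------|
| `server.host` | Listen address |
| `server.port` | Listen port |
| `server.tts_config` | Path to the GPT-SoVITS inference config file |
| `upload.dir` | Temporary upload directory |
| `upload.min_ref_seconds` | Minimum length of the main reference audio, in seconds |
| `upload.max_ref_seconds` | Maximum length of the main reference audio, in seconds |
| `upload.max_upload_mb` | Maximum size of a single uploaded file (MB) |
| `defaults.*` | Default parameters for all endpoints |
| `emotion_presets.*` | Emotion preset parameter mapping |
| `voices.*` | Fixed voice profiles |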
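The configuration layers are applied in a fixed order. The sketch below mirrors the order used by `build_tts_request` in `simple_api.py` as shown in this change (built-in defaults, then `defaults.*`, then the voice profile, then request fields); `merge_params` and the `BUILTIN` subset are illustrative names, not part of the module.

```python
from copy import deepcopy
from typing import Any, Dict

# Illustrative subset of the built-in parameter table.
BUILTIN: Dict[str, Any] = {
    "text_lang": "zh",
    "speed_factor": 1.0,
    "top_k": 15,
    "media_type": "wav",
}


def merge_params(
    builtin: Dict[str, Any],
    defaults: Dict[str, Any],
    voice_profile: Dict[str, Any],
    request: Dict[str, Any],
) -> Dict[str, Any]:
    """Later layers override earlier ones; unknown request keys are ignored."""
    merged = deepcopy(builtin)
    merged.update(defaults)
    # "description" is profile metadata, not a TTS parameter.
    merged.update({k: v for k, v in voice_profile.items() if k != "description"})
    merged.update({k: v for k, v in request.items() if k in builtin})
    return merged
```

A value set in a voice profile therefore beats `defaults.*`, and a value sent with the request beats both.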
---
## Adding a custom voice
Edit `simple_api.yaml` and add an entry under `voices`:
```yaml
voices:
  narrator:
    description: "Male narrator"
    ref_audio_path: voices/narrator.wav
    prompt_text: "Verbatim transcript of the narrator reference audio"
    prompt_lang: zh
    text_lang: zh
```
Then hot-reload the configuration:
```bash
curl -X POST http://127.0.0.1:9881/admin/reload-config
```
---
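## Testing
### Contract tests (no GPU required)
```bash
python -m unittest tests.test_simple_api_contract -v
```
Coverage:
- `/api/tts` route registration
- Parameter construction for the upload endpoint
- 3-10 second validation of the main reference audio
- Empty prompt_text allowed on v2 / rejected on v3/v4
- Cleanup of the temporary upload directory
- Emotion preset application and the speed override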
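To give a feel for what a GPU-free contract test looks like, here is a small standalone sketch. It is not the content of `tests/test_simple_api_contract.py` (that file is not shown here); `is_valid_ref_duration` and `prompt_text_allowed` are hypothetical pure functions that restate two of the rules listed above.

```python
import unittest


def is_valid_ref_duration(duration: float, min_seconds: float = 3.0, max_seconds: float = 10.0) -> bool:
    """The main reference clip must fall inside the configured window."""
    return min_seconds <= duration <= max_seconds


def prompt_text_allowed(version: str, prompt_text: str) -> bool:
    """v3/v4 need a non-empty transcript; other versions accept an empty one."""
    if version in {"v3", "v4"}:
        return bool(prompt_text.strip())
    return True


class RefAudioContractSketch(unittest.TestCase):
    def test_duration_bounds(self) -> None:
        self.assertTrue(is_valid_ref_duration(3.0))
        self.assertTrue(is_valid_ref_duration(10.0))
        self.assertFalse(is_valid_ref_duration(2.9))
        self.assertFalse(is_valid_ref_duration(10.5))

    def test_prompt_text_rule(self) -> None:
        self.assertTrue(prompt_text_allowed("v2", ""))
        self.assertFalse(prompt_text_allowed("v3", "   "))
        self.assertTrue(prompt_text_allowed("v4", "hello"))
```

Tests of this shape exercise only validation logic, which is why they can run without loading any model weights.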
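### Frontend testing
1. Start the backend
2. Open `http://127.0.0.1:9881/test/`
3. Upload an audio or video file (audio is extracted from video automatically)
4. Use the waveform trimming tool to select a 3-10 second clip
5. Enter the text and choose an emotion and speech rate
6. Click Generate
---
## Launch scripts
| Script | Platform | Description |
|------|------|------|
| `go-simple-api.ps1` | Windows PowerShell | Automatically detects runtime\python.exe |
| `go-simple-api.bat` | Windows CMD | Same as above |
| `open-test-frontend.ps1` | Windows PowerShell | Opens the test frontend HTML directly |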


@ -0,0 +1,85 @@
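# Simplified TTS API Quick Start
For the full guide, see `docs/simple_api.md`.
## The workflow in one sentence
Start the backend, open the test page, upload a 3-10 second reference audio or video file (audio is extracted from video automatically), enter the text to synthesize, and click Generate.
## 1. Start the backend
```powershell
cd D:\tts\GPT-SoVITS
python -m pip install -r requirements.txt
.\go-simple-api.ps1
```
Alternatively, run:
```bat
go-simple-api.bat
```
Default backend address:
```text
http://127.0.0.1:9881
```
## 2. Open the pages
| Page | URL | Description |
|------|------|------|
| Test frontend | http://127.0.0.1:9881/test/ | Test UI with waveform trimming |
| Swagger UI | http://127.0.0.1:9881/docs | Interactive API documentation |
| ReDoc | http://127.0.0.1:9881/redoc | Readable API documentation |
## 3. Call the API
Endpoint:
```http
POST /api/tts
Content-Type: multipart/form-data
```
Minimum fields:
```text
text       Text to synthesize
ref_audio  Main reference audio, 3-10 seconds (or a video; the frontend extracts the audio)
```
Common optional fields:
```text
aux_ref_audio  Auxiliary reference audio; multiple files allowed
prompt_text    Transcript of the reference audio; may be empty
text_lang      Default: zh
prompt_lang    Default: zh
emotion        neutral / happy / calm / sad / angry
speed          Speech rate, default 1
seed           Default: -1
```
## 4. PowerShell example
```powershell
curl.exe -X POST http://127.0.0.1:9881/api/tts `
-F "text=你好,欢迎使用这个声音。" `
-F "ref_audio=@D:\audio\ref.wav" `
-F "prompt_text=" `
-F "text_lang=zh" `
-F "prompt_lang=zh" `
--output output.wav
```
## 5. Notes
- The main reference audio must be 3-10 seconds long.
- The frontend accepts video uploads and extracts the audio automatically.
- The frontend provides a waveform trimming tool for selecting a 3-10 second clip directly.
- `prompt_text` may be empty under the current v2 configuration.
- If you switch to v3/v4, an empty `prompt_text` returns a 400 error.
- The text to synthesize is always split into sentences at punctuation marks.
- For more detail on configuration, profiles, and the base64 endpoint, see `docs/simple_api.md`.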

12
go-simple-api.bat Normal file

@ -0,0 +1,12 @@
@echo off
set "SCRIPT_DIR=%~dp0"
set "SCRIPT_DIR=%SCRIPT_DIR:~0,-1%"
cd /d "%SCRIPT_DIR%"
if exist "%SCRIPT_DIR%\runtime\python.exe" (
"%SCRIPT_DIR%\runtime\python.exe" "%SCRIPT_DIR%\simple_api.py" -c "%SCRIPT_DIR%\simple_api.yaml" %*
) else (
python "%SCRIPT_DIR%\simple_api.py" -c "%SCRIPT_DIR%\simple_api.yaml" %*
)
pause

13
go-simple-api.ps1 Normal file

@ -0,0 +1,13 @@
$ErrorActionPreference = "Stop"
chcp 65001
Set-Location $PSScriptRoot
$runtimePython = Join-Path $PSScriptRoot "runtime\python.exe"
$apiScript = Join-Path $PSScriptRoot "simple_api.py"
$apiConfig = Join-Path $PSScriptRoot "simple_api.yaml"
if (Test-Path $runtimePython) {
& $runtimePython $apiScript -c $apiConfig @args
} else {
python $apiScript -c $apiConfig @args
}

3
open-test-frontend.ps1 Normal file

@ -0,0 +1,3 @@
$ErrorActionPreference = "Stop"
Set-Location $PSScriptRoot
Start-Process (Join-Path $PSScriptRoot "test_frontend\index.html")


@ -3,6 +3,7 @@ numpy<2.0
scipy
tensorboard
librosa==0.10.2
soundfile
numba
pytorch-lightning>=2.4
gradio<5
@ -22,6 +23,7 @@ transformers>=4.43,<=4.50
peft<0.18.0
chardet
PyYAML
python-multipart
psutil
jieba_fast
jieba

787
simple_api.py Normal file

@ -0,0 +1,787 @@
"""
Small profile-based API layer for GPT-SoVITS.

Run:
    python simple_api.py -c simple_api.yaml

Then call:
    POST http://127.0.0.1:9881/speak
    {"text": "hello", "voice": "default"}
"""
from __future__ import annotations
import argparse
import base64
import os
import shutil
import subprocess
import sys
import threading
import traceback
import uuid
import wave
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
import numpy as np
import soundfile as sf
import yaml
from fastapi import FastAPI, File, Form, HTTPException, Response, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
PROJECT_ROOT = Path(__file__).resolve().parent
os.chdir(PROJECT_ROOT)
sys.path.append(str(PROJECT_ROOT))
sys.path.append(str(PROJECT_ROOT / "GPT_SoVITS"))
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import ( # noqa: E402
get_method_names as get_cut_method_names,
)
DEFAULT_TTS_PARAMS: Dict[str, Any] = {
    "text": "",
    "text_lang": "zh",
    "ref_audio_path": "",
    "aux_ref_audio_paths": [],
    "prompt_text": "",
    "prompt_lang": "zh",
    "top_k": 15,
    "top_p": 1.0,
    "temperature": 1.0,
    "text_split_method": "cut5",
    "batch_size": 1,
    "batch_threshold": 0.75,
    "split_bucket": True,
    "speed_factor": 1.0,
    "fragment_interval": 0.3,
    "seed": -1,
    "media_type": "wav",
    "streaming_mode": False,
    "return_fragment": False,
    "fixed_length_chunk": False,
    "parallel_infer": True,
    "repetition_penalty": 1.35,
    "sample_steps": 32,
    "super_sampling": False,
    "overlap_length": 2,
    "min_chunk_length": 16,
}
SUPPORTED_MEDIA_TYPES = {"wav", "raw", "ogg", "aac"}
SUPPORTED_UPLOAD_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a", ".aac"}
VOICE_META_KEYS = {"description"}


class SpeakRequest(BaseModel):
    text: str
    voice: Optional[str] = None
    text_lang: Optional[str] = None
    ref_audio_path: Optional[str] = None
    aux_ref_audio_paths: Optional[List[str]] = None
    prompt_text: Optional[str] = None
    prompt_lang: Optional[str] = None
    format: Optional[str] = Field(default=None, description="wav, raw, ogg, or aac")
    stream: Optional[bool] = Field(default=None, description="true maps to streaming mode 2")
    streaming_mode: Optional[Union[bool, int]] = Field(default=None, description="0, 1, 2, 3, true, or false")
    speed: Optional[float] = Field(default=None, description="Alias of speed_factor")
    speed_factor: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    temperature: Optional[float] = None
    text_split_method: Optional[str] = None
    batch_size: Optional[int] = None
    batch_threshold: Optional[float] = None
    split_bucket: Optional[bool] = None
    fragment_interval: Optional[float] = None
    seed: Optional[int] = None
    parallel_infer: Optional[bool] = None
    repetition_penalty: Optional[float] = None
    sample_steps: Optional[int] = None
    super_sampling: Optional[bool] = None
    overlap_length: Optional[int] = None
    min_chunk_length: Optional[int] = None


class WeightsRequest(BaseModel):
    gpt_weights_path: Optional[str] = None
    sovits_weights_path: Optional[str] = None
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="GPT-SoVITS simple API")
    parser.add_argument("-c", "--config", default="simple_api.yaml", help="simple API config path")
    parser.add_argument("--tts-config", default=None, help="GPT-SoVITS tts_infer.yaml path")
    parser.add_argument("-a", "--bind-addr", default=None, help="bind address")
    parser.add_argument("-p", "--port", type=int, default=None, help="bind port")
    return parser.parse_args()


def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
    path = Path(config_path)
    if not path.is_absolute():
        path = PROJECT_ROOT / path
    if not path.exists():
        raise FileNotFoundError(f"simple API config not found: {path}")
    with path.open("r", encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}
    if not isinstance(data, dict):
        raise ValueError("simple API config must be a YAML object")
    return data


args = parse_args()
simple_config = load_yaml_config(args.config)
server_config = simple_config.get("server", {}) or {}
host = args.bind_addr or server_config.get("host", "127.0.0.1")
port = args.port or int(server_config.get("port", 9881))
tts_config_path = args.tts_config or server_config.get("tts_config", "GPT_SoVITS/configs/tts_infer.yaml")
cut_method_names = get_cut_method_names()
tts_config: Optional[TTS_Config] = None
tts_pipeline: Optional[TTS] = None
infer_lock = threading.Lock()
APP = FastAPI(
    title="GPT-SoVITS Simple API",
    description=(
        "Simplified API layer that wraps the GPT-SoVITS inference engine.\n\n"
        "## Core workflow\n"
        "1. Upload 3-10 seconds of reference audio (or a video; the frontend extracts the audio)\n"
        "2. Enter the text to synthesize\n"
        "3. Call `/api/tts` to get the generated audio\n\n"
        "## Other endpoints\n"
        "- `/speak`: voice-profile-based calls\n"
        "- `/v1/tts`: OpenAI-compatible format\n"
        "- `/admin/*`: admin endpoints (hot-reload config, switch models)"
    ),
    version="1.1.0",
    docs_url="/docs",
    redoc_url="/redoc",
    openapi_tags=[
        {"name": "MVP", "description": "Core TTS endpoint: upload reference audio and generate directly"},
        {"name": "Profile", "description": "Voice-profile-based calls"},
        {"name": "Admin", "description": "Admin endpoints: hot-reload config, switch model weights"},
        {"name": "System", "description": "Health check and information queries"},
    ],
)
APP.add_middleware(
    CORSMiddleware,
    allow_origins=simple_config.get("cors_allow_origins", ["*"]),
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
test_frontend_dir = PROJECT_ROOT / "test_frontend"
if test_frontend_dir.exists():
    APP.mount("/test", StaticFiles(directory=str(test_frontend_dir), html=True), name="test_frontend")
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
    def handle_pack_ogg() -> None:
        with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
            audio_file.write(data)

    try:
        threading.stack_size(4096 * 4096)
        pack_thread = threading.Thread(target=handle_pack_ogg)
        pack_thread.start()
        pack_thread.join()
    except (RuntimeError, ValueError):
        handle_pack_ogg()
    return io_buffer


def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
    del rate
    io_buffer.write(data.tobytes())
    return io_buffer


def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
    del io_buffer
    wav_buffer = BytesIO()
    sf.write(wav_buffer, data, rate, format="wav")
    return wav_buffer


def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
    process = subprocess.Popen(
        [
            "ffmpeg",
            "-f",
            "s16le",
            "-ar",
            str(rate),
            "-ac",
            "1",
            "-i",
            "pipe:0",
            "-c:a",
            "aac",
            "-b:a",
            "192k",
            "-vn",
            "-f",
            "adts",
            "pipe:1",
        ],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    out, _ = process.communicate(input=data.tobytes())
    io_buffer.write(out)
    return io_buffer


def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str) -> BytesIO:
    if media_type == "ogg":
        io_buffer = pack_ogg(io_buffer, data, rate)
    elif media_type == "aac":
        io_buffer = pack_aac(io_buffer, data, rate)
    elif media_type == "wav":
        io_buffer = pack_wav(io_buffer, data, rate)
    else:
        io_buffer = pack_raw(io_buffer, data, rate)
    io_buffer.seek(0)
    return io_buffer


def wave_header_chunk(frame_input: bytes = b"", channels: int = 1, sample_width: int = 2, sample_rate: int = 32000) -> bytes:
    wav_buf = BytesIO()
    with wave.open(wav_buf, "wb") as vfout:
        vfout.setnchannels(channels)
        vfout.setsampwidth(sample_width)
        vfout.setframerate(sample_rate)
        vfout.writeframes(frame_input)
    wav_buf.seek(0)
    return wav_buf.read()
def request_to_dict(request: BaseModel) -> Dict[str, Any]:
    if hasattr(request, "model_dump"):
        return request.model_dump(exclude_none=True)
    return request.dict(exclude_none=True)


def get_default_voice_name() -> Optional[str]:
    default_voice = simple_config.get("default_voice")
    if default_voice:
        return str(default_voice)
    voices = simple_config.get("voices", {}) or {}
    if "default" in voices:
        return "default"
    return next(iter(voices.keys()), None)


def resolve_project_path(path_value: Optional[str]) -> Optional[str]:
    if path_value in [None, ""]:
        return None
    path = Path(str(path_value))
    if not path.is_absolute():
        path = PROJECT_ROOT / path
    return str(path)


def normalize_streaming(streaming_mode: Optional[Union[bool, int]], stream: Optional[bool]) -> Tuple[bool, bool, bool, bool]:
    if streaming_mode is None:
        streaming_mode = 2 if stream else 0
    elif isinstance(streaming_mode, bool):
        streaming_mode = 2 if streaming_mode else 0
    if streaming_mode == 0:
        return False, False, False, False
    if streaming_mode == 1:
        return False, True, False, True
    if streaming_mode == 2:
        return True, False, False, True
    if streaming_mode == 3:
        return True, False, True, True
    raise HTTPException(status_code=400, detail="streaming_mode must be 0, 1, 2, 3, true, or false")


def get_upload_config() -> Dict[str, Any]:
    return simple_config.get("upload", {}) or {}


def get_upload_root() -> Path:
    upload_dir = str(get_upload_config().get("dir", "runtime/uploads"))
    root = Path(upload_dir)
    if not root.is_absolute():
        root = PROJECT_ROOT / root
    root.mkdir(parents=True, exist_ok=True)
    return root


def get_request_upload_dir() -> Path:
    request_dir = get_upload_root() / uuid.uuid4().hex
    request_dir.mkdir(parents=True, exist_ok=True)
    return request_dir


def get_upload_limits() -> Tuple[float, float, int]:
    upload_config = get_upload_config()
    min_seconds = float(upload_config.get("min_ref_seconds", 3.0))
    max_seconds = float(upload_config.get("max_ref_seconds", 10.0))
    max_upload_mb = int(upload_config.get("max_upload_mb", 80))
    return min_seconds, max_seconds, max_upload_mb


def get_upload_suffix(upload: UploadFile) -> str:
    suffix = Path(upload.filename or "").suffix.lower()
    if suffix:
        return suffix
    content_type = (upload.content_type or "").lower()
    if content_type in {"audio/wav", "audio/x-wav", "audio/wave"}:
        return ".wav"
    if content_type == "audio/mpeg":
        return ".mp3"
    if content_type == "audio/ogg":
        return ".ogg"
    if content_type in {"audio/aac", "audio/aacp"}:
        return ".aac"
    return ".wav"


def validate_audio_upload(upload: UploadFile) -> None:
    suffix = get_upload_suffix(upload)
    content_type = (upload.content_type or "").lower()
    is_audio_type = content_type.startswith("audio/") or content_type in {"application/octet-stream", ""}
    if suffix not in SUPPORTED_UPLOAD_EXTENSIONS or not is_audio_type:
        raise HTTPException(status_code=400, detail=f"unsupported audio upload: {upload.filename or content_type}")
async def save_upload_file(upload: UploadFile, target_dir: Path, name_prefix: str) -> str:
    validate_audio_upload(upload)
    _, _, max_upload_mb = get_upload_limits()
    max_bytes = max_upload_mb * 1024 * 1024
    suffix = get_upload_suffix(upload)
    target_path = target_dir / f"{name_prefix}{suffix}"
    total_size = 0
    try:
        with target_path.open("wb") as f:
            while True:
                chunk = await upload.read(1024 * 1024)
                if not chunk:
                    break
                total_size += len(chunk)
                if total_size > max_bytes:
                    raise HTTPException(status_code=400, detail=f"audio upload exceeds {max_upload_mb} MB")
                f.write(chunk)
        if total_size == 0:
            raise HTTPException(status_code=400, detail=f"audio upload is empty: {upload.filename or name_prefix}")
    finally:
        await upload.close()
    return str(target_path)


def get_audio_duration_with_ffprobe(audio_path: str) -> float:
    result = subprocess.run(
        [
            "ffprobe",
            "-v",
            "error",
            "-show_entries",
            "format=duration",
            "-of",
            "default=noprint_wrappers=1:nokey=1",
            audio_path,
        ],
        capture_output=True,
        text=True,
        timeout=15,
    )
    if result.returncode != 0:
        raise RuntimeError(result.stderr.strip() or "ffprobe failed")
    return float(result.stdout.strip())


def get_audio_duration_seconds(audio_path: str) -> float:
    try:
        info = sf.info(audio_path)
        if info.samplerate <= 0:
            raise ValueError("invalid sample rate")
        return float(info.frames) / float(info.samplerate)
    except Exception as sf_exc:
        try:
            return get_audio_duration_with_ffprobe(audio_path)
        except Exception as ffprobe_exc:
            raise HTTPException(
                status_code=400,
                detail=f"unable to read audio duration: {audio_path}. soundfile={sf_exc}; ffprobe={ffprobe_exc}",
            ) from ffprobe_exc


def validate_main_ref_duration(audio_path: str) -> None:
    min_seconds, max_seconds, _ = get_upload_limits()
    duration = get_audio_duration_seconds(audio_path)
    if duration < min_seconds or duration > max_seconds:
        raise HTTPException(
            status_code=400,
            detail=f"ref_audio duration must be between {min_seconds:g}s and {max_seconds:g}s, got {duration:.2f}s",
        )
def apply_emotion_preset(payload: Dict[str, Any], emotion: Optional[str]) -> None:
    if emotion in [None, ""]:
        return
    emotion_name = str(emotion).strip().lower()
    emotion_presets = simple_config.get("emotion_presets", {}) or {}
    if emotion_name not in emotion_presets:
        supported = ", ".join(sorted(emotion_presets.keys())) or "none"
        raise HTTPException(status_code=400, detail=f"emotion is not configured: {emotion_name}. supported: {supported}")
    payload.update(emotion_presets[emotion_name] or {})


def prompt_text_required_for_current_model() -> bool:
    if tts_config is None:
        return False
    use_vocoder = getattr(tts_config, "use_vocoder", None)
    if use_vocoder is not None:
        return bool(use_vocoder)
    return getattr(tts_config, "version", None) in {"v3", "v4"}


def voice_public_info(name: str, profile: Dict[str, Any]) -> Dict[str, Any]:
    ref_audio_path = resolve_project_path(profile.get("ref_audio_path"))
    return {
        "name": name,
        "description": profile.get("description", ""),
        "text_lang": profile.get("text_lang"),
        "prompt_lang": profile.get("prompt_lang"),
        "ref_audio_path": profile.get("ref_audio_path"),
        "ready": bool(ref_audio_path and Path(ref_audio_path).exists() and profile.get("prompt_lang")),
    }
def build_tts_request(payload: Dict[str, Any]) -> Tuple[Dict[str, Any], str, bool]:
    voices = simple_config.get("voices", {}) or {}
    explicit_ref_audio = bool(payload.get("ref_audio_path"))
    voice_name = payload.pop("voice", None)
    if voice_name is None and not explicit_ref_audio:
        voice_name = get_default_voice_name()
    voice_profile: Dict[str, Any] = {}
    if voice_name:
        if voice_name not in voices:
            raise HTTPException(status_code=404, detail=f"voice not found: {voice_name}")
        voice_profile = deepcopy(voices[voice_name] or {})

    # Layering order: built-in defaults, config defaults, voice profile, request fields.
    tts_req = deepcopy(DEFAULT_TTS_PARAMS)
    tts_req.update(simple_config.get("defaults", {}) or {})
    tts_req.update({k: v for k, v in voice_profile.items() if k not in VOICE_META_KEYS})
    stream = payload.pop("stream", None)
    media_type = payload.pop("format", None)
    speed = payload.pop("speed", None)
    if media_type is not None:
        tts_req["media_type"] = media_type
    for key, value in payload.items():
        if key in DEFAULT_TTS_PARAMS:
            tts_req[key] = value
    if speed is not None:
        tts_req["speed_factor"] = speed

    ref_audio_path = resolve_project_path(tts_req.get("ref_audio_path"))
    if ref_audio_path:
        tts_req["ref_audio_path"] = ref_audio_path
    aux_paths = tts_req.get("aux_ref_audio_paths") or []
    tts_req["aux_ref_audio_paths"] = [resolve_project_path(item) for item in aux_paths if item]
    media_type = str(tts_req.get("media_type", "wav")).lower()
    tts_req["media_type"] = media_type
    if media_type not in SUPPORTED_MEDIA_TYPES:
        raise HTTPException(status_code=400, detail=f"format is not supported: {media_type}")
    text = str(tts_req.get("text") or "").strip()
    if not text:
        raise HTTPException(status_code=400, detail="text is required")
    tts_req["text"] = text
    if not tts_req.get("ref_audio_path"):
        raise HTTPException(status_code=400, detail="ref_audio_path is required in voice profile or request")
    if not Path(tts_req["ref_audio_path"]).exists():
        raise HTTPException(status_code=400, detail=f"ref_audio_path does not exist: {tts_req['ref_audio_path']}")
    for aux_path in tts_req["aux_ref_audio_paths"]:
        if aux_path and not Path(aux_path).exists():
            raise HTTPException(status_code=400, detail=f"aux_ref_audio_path does not exist: {aux_path}")
    if tts_config is None:
        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
    text_lang = str(tts_req.get("text_lang") or "").lower()
    prompt_lang = str(tts_req.get("prompt_lang") or "").lower()
    tts_req["text_lang"] = text_lang
    tts_req["prompt_lang"] = prompt_lang
    if text_lang not in tts_config.languages:
        raise HTTPException(status_code=400, detail=f"text_lang is not supported: {text_lang}")
    if prompt_lang not in tts_config.languages:
        raise HTTPException(status_code=400, detail=f"prompt_lang is not supported: {prompt_lang}")
    if not str(tts_req.get("prompt_text") or "").strip() and prompt_text_required_for_current_model():
        version = getattr(tts_config, "version", "current")
        raise HTTPException(status_code=400, detail=f"prompt_text is required when using GPT-SoVITS {version}")
    split_method = str(tts_req.get("text_split_method") or "cut5")
    if split_method not in cut_method_names:
        raise HTTPException(status_code=400, detail=f"text_split_method is not supported: {split_method}")
    streaming_mode = tts_req.pop("streaming_mode", None)
    streaming_enabled, return_fragment, fixed_length_chunk, response_stream = normalize_streaming(streaming_mode, stream)
    tts_req["streaming_mode"] = streaming_enabled
    tts_req["return_fragment"] = return_fragment
    tts_req["fixed_length_chunk"] = fixed_length_chunk
    return tts_req, media_type, response_stream
def synthesize_once(payload: Dict[str, Any]) -> Tuple[bytes, str]:
    tts_req, media_type, response_stream = build_tts_request(payload)
    if response_stream:
        raise HTTPException(status_code=400, detail="base64 output does not support streaming")
    if tts_pipeline is None:
        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
    try:
        with infer_lock:
            sr, audio_data = next(tts_pipeline.run(tts_req))
        audio_bytes = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
        return audio_bytes, media_type
    except Exception as exc:
        raise HTTPException(status_code=400, detail={"message": "tts failed", "exception": str(exc)}) from exc


def synthesize_response(payload: Dict[str, Any]) -> Response:
    tts_req, media_type, response_stream = build_tts_request(payload)
    if tts_pipeline is None:
        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
    if response_stream:

        def streaming_generator() -> Generator[bytes, None, None]:
            first_chunk = True
            chunk_media_type = media_type
            with infer_lock:
                for sr, chunk in tts_pipeline.run(tts_req):
                    if first_chunk and chunk_media_type == "wav":
                        # Send the WAV header once, then stream raw PCM chunks.
                        yield wave_header_chunk(sample_rate=sr)
                        chunk_media_type = "raw"
                    first_chunk = False
                    yield pack_audio(BytesIO(), chunk, sr, chunk_media_type).getvalue()

        return StreamingResponse(streaming_generator(), media_type=f"audio/{media_type}")
    try:
        with infer_lock:
            sr, audio_data = next(tts_pipeline.run(tts_req))
        audio_bytes = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
        return Response(audio_bytes, media_type=f"audio/{media_type}")
    except Exception as exc:
        return JSONResponse(status_code=400, content={"message": "tts failed", "exception": str(exc)})
@APP.on_event("startup")
def startup() -> None:
    global tts_config, tts_pipeline
    tts_config = TTS_Config(tts_config_path)
    print(tts_config)
    tts_pipeline = TTS(tts_config)


@APP.get("/", tags=["System"])
def index() -> Dict[str, Any]:
    return {
        "name": "GPT-SoVITS Simple API",
        "version": "1.1.0",
        "docs": "/docs",
        "endpoints": {
            "system": ["/health", "/voices"],
            "mvp": ["/api/tts"],
            "profile": ["/speak", "/speak/base64", "/v1/tts"],
            "admin": ["/admin/reload-config", "/admin/weights"],
        },
    }


@APP.get("/health", tags=["System"], summary="Health check")
def health() -> Dict[str, Any]:
    result: Dict[str, Any] = {
        "status": "ok" if tts_pipeline is not None else "starting",
        "tts_config": tts_config_path,
        "version": getattr(tts_config, "version", None),
        "languages": getattr(tts_config, "languages", []),
        "pid": os.getpid(),
    }
    try:
        import psutil

        result["memory_mb"] = round(psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024, 1)
    except Exception:
        pass
    try:
        import torch

        if torch.cuda.is_available():
            result["gpu"] = {
                "name": torch.cuda.get_device_name(0),
                "memory_used_mb": round(torch.cuda.memory_allocated(0) / 1024 / 1024, 1),
                # The device-properties attribute is total_memory; total_mem raises
                # AttributeError, which the except below would silently swallow.
                "memory_total_mb": round(torch.cuda.get_device_properties(0).total_memory / 1024 / 1024, 1),
            }
    except Exception:
        pass
    return result


@APP.get("/voices", tags=["System"], summary="List available voice profiles")
def list_voices() -> Dict[str, Any]:
    voices = simple_config.get("voices", {}) or {}
    return {
        "default_voice": get_default_voice_name(),
        "voices": [voice_public_info(name, profile or {}) for name, profile in voices.items()],
    }
@APP.post(
    "/api/tts",
    tags=["MVP"],
    summary="Core TTS endpoint",
    description=(
        "Upload a reference audio file and the text to synthesize; returns the generated audio.\n\n"
        "**Main reference audio requirements**: 3-10 seconds; wav/flac/ogg/mp3/m4a/aac supported.\n\n"
        "**Text splitting**: always uses `cut5` (split at punctuation marks).\n\n"
        "**Emotion presets**: neutral / happy / calm / sad / angry; each one maps to sampling and speed parameters."
    ),
)
async def mvp_tts(
    text: str = Form(...),
    ref_audio: UploadFile = File(...),
    aux_ref_audio: Optional[List[UploadFile]] = File(default=None),
    prompt_text: str = Form(default=""),
    text_lang: str = Form(default="zh"),
    prompt_lang: str = Form(default="zh"),
    format: str = Form(default="wav"),
    emotion: Optional[str] = Form(default=None),
    speed: Optional[float] = Form(default=None),
    seed: int = Form(default=-1),
) -> Response:
    request_dir = get_request_upload_dir()
    try:
        ref_audio_path = await save_upload_file(ref_audio, request_dir, "ref")
        validate_main_ref_duration(ref_audio_path)
        aux_ref_audio_paths = []
        for index, upload in enumerate(aux_ref_audio or []):
            aux_path = await save_upload_file(upload, request_dir, f"aux_{index}")
            # Only checks that the file is readable; no length limit for aux clips.
            get_audio_duration_seconds(aux_path)
            aux_ref_audio_paths.append(aux_path)
        payload: Dict[str, Any] = {
            "text": text,
            "text_lang": text_lang,
            "ref_audio_path": ref_audio_path,
            "aux_ref_audio_paths": aux_ref_audio_paths,
            "prompt_text": prompt_text or "",
            "prompt_lang": prompt_lang,
            "format": format,
            "text_split_method": "cut5",
            "streaming_mode": 0,
            "seed": seed,
        }
        apply_emotion_preset(payload, emotion)
        # Re-assert the fixed settings in case a preset tried to override them.
        payload["text_split_method"] = "cut5"
        payload["streaming_mode"] = 0
        if speed is not None:
            payload["speed"] = speed
        return synthesize_response(payload)
    finally:
        shutil.rmtree(request_dir, ignore_errors=True)
@APP.get("/speak", tags=["Profile"], summary="Voice-profile TTS via GET")
def speak_get(
    text: str,
    voice: Optional[str] = None,
    text_lang: Optional[str] = None,
    format: Optional[str] = None,
    stream: Optional[bool] = None,
    speed: Optional[float] = None,
) -> Response:
    payload = {
        "text": text,
        "voice": voice,
        "text_lang": text_lang,
        "format": format,
        "stream": stream,
        "speed": speed,
    }
    # Drop unset query parameters before building the request.
    return synthesize_response({k: v for k, v in payload.items() if v is not None})


@APP.post("/speak", tags=["Profile"], summary="Voice-profile TTS via POST")
def speak_post(request: SpeakRequest) -> Response:
    return synthesize_response(request_to_dict(request))


@APP.post("/v1/tts", tags=["Profile"], summary="OpenAI-compatible TTS")
def openai_style_tts(request: SpeakRequest) -> Response:
    return synthesize_response(request_to_dict(request))


@APP.post("/speak/base64", tags=["Profile"], summary="Return Base64-encoded audio")
def speak_base64(request: SpeakRequest) -> Dict[str, Any]:
    audio_bytes, media_type = synthesize_once(request_to_dict(request))
    return {
        "media_type": f"audio/{media_type}",
        "audio_base64": base64.b64encode(audio_bytes).decode("ascii"),
    }

@APP.post("/admin/reload-config", tags=["Admin"], summary="Hot-reload the simple_api.yaml configuration")
def reload_config() -> Dict[str, Any]:
    global simple_config
    simple_config = load_yaml_config(args.config)
    return {"message": "success", "default_voice": get_default_voice_name()}


@APP.post("/admin/weights", tags=["Admin"], summary="Switch GPT-SoVITS model weights at runtime")
def set_weights(request: WeightsRequest) -> Dict[str, Any]:
    if tts_pipeline is None:
        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
    if not request.gpt_weights_path and not request.sovits_weights_path:
        raise HTTPException(status_code=400, detail="gpt_weights_path or sovits_weights_path is required")
    try:
        # Hold the inference lock while the weights are swapped.
        with infer_lock:
            if request.gpt_weights_path:
                tts_pipeline.init_t2s_weights(request.gpt_weights_path)
            if request.sovits_weights_path:
                tts_pipeline.init_vits_weights(request.sovits_weights_path)
    except Exception as exc:
        raise HTTPException(status_code=400, detail={"message": "change weights failed", "exception": str(exc)}) from exc
    return {"message": "success"}

if __name__ == "__main__":
    import uvicorn

    try:
        uvicorn.run(app=APP, host=None if host == "None" else host, port=port, workers=1)
    except Exception:
        traceback.print_exc()
        sys.exit(1)
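
The `/api/tts` handler above reads multipart form fields (`text`, `ref_audio`, `emotion`, `seed`, and so on). As a rough client-side sketch, the request body could be assembled with only the standard library as shown here. `build_tts_form` is a hypothetical helper that is not part of this commit, and the placeholder bytes are not real audio.

```python
import uuid
from typing import Dict, Tuple


def build_tts_form(fields: Dict[str, str], ref_name: str, ref_bytes: bytes) -> Tuple[bytes, str]:
    """Encode text fields plus one `ref_audio` file part as multipart/form-data.

    Hypothetical helper for illustration; returns (body, content_type).
    """
    boundary = uuid.uuid4().hex
    parts = []
    # One part per plain form field.
    for name, value in fields.items():
        parts.append(
            f'--{boundary}\r\nContent-Disposition: form-data; name="{name}"\r\n\r\n{value}\r\n'.encode("utf-8")
        )
    # The file part uses the field name the handler declares: ref_audio.
    parts.append(
        (
            f'--{boundary}\r\nContent-Disposition: form-data; name="ref_audio"; '
            f'filename="{ref_name}"\r\nContent-Type: application/octet-stream\r\n\r\n'
        ).encode("utf-8")
        + ref_bytes
        + b"\r\n"
    )
    # Closing boundary.
    parts.append(f"--{boundary}--\r\n".encode("ascii"))
    return b"".join(parts), f"multipart/form-data; boundary={boundary}"


body, content_type = build_tts_form(
    {"text": "Hello.", "text_lang": "en", "emotion": "calm", "seed": "123"},
    "ref.wav",
    b"placeholder-audio-bytes",
)
```

The resulting `body` could then be sent with `urllib.request.Request("http://127.0.0.1:9881/api/tts", data=body, headers={"Content-Type": content_type})`, assuming the server runs on the default host and port from `simple_api.yaml`.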

59
simple_api.yaml Normal file

@ -0,0 +1,59 @@
server:
  host: 127.0.0.1
  port: 9881
  tts_config: GPT_SoVITS/configs/tts_infer.yaml
  cors_allow_origins:
    - "*"

upload:
  dir: runtime/uploads
  min_ref_seconds: 3
  max_ref_seconds: 10
  max_upload_mb: 80

default_voice: default

defaults:
  text_lang: zh
  prompt_lang: zh
  media_type: wav
  text_split_method: cut5
  batch_size: 1
  batch_threshold: 0.75
  split_bucket: true
  speed_factor: 1.0
  fragment_interval: 0.3
  seed: -1
  parallel_infer: true
  repetition_penalty: 1.35
  sample_steps: 32
  super_sampling: false
  overlap_length: 2
  min_chunk_length: 16

emotion_presets:
  neutral: {}
  happy:
    temperature: 1.1
    top_p: 0.95
  calm:
    temperature: 0.8
    top_p: 0.85
    speed_factor: 0.92
  sad:
    temperature: 0.75
    top_p: 0.85
    speed_factor: 0.9
  angry:
    temperature: 1.2
    top_k: 20
    repetition_penalty: 1.25

voices:
  default:
    description: Replace this profile with your reference voice.
    ref_audio_path: reference.wav
    prompt_text: Replace this with the exact text spoken in reference.wav.
    prompt_lang: zh
    text_lang: zh
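
The `emotion_presets` section maps each emotion name to a handful of sampling and speed overrides. The real `apply_emotion_preset` is not shown in this part of the diff, so the following is only a sketch of the behavior the contract tests check: preset values are overlaid on the payload, and an explicit speed takes precedence over the preset's `speed_factor`. `merge_preset` is a hypothetical name, and silently ignoring an unknown emotion is an assumption.

```python
from typing import Any, Dict, Optional

# Values copied from the emotion_presets section of simple_api.yaml.
EMOTION_PRESETS: Dict[str, Dict[str, Any]] = {
    "neutral": {},
    "calm": {"temperature": 0.8, "top_p": 0.85, "speed_factor": 0.92},
}


def merge_preset(
    payload: Dict[str, Any], emotion: Optional[str], speed: Optional[float] = None
) -> Dict[str, Any]:
    """Hypothetical helper: overlay a preset, then let an explicit speed win."""
    merged = dict(payload)  # leave the caller's dict untouched
    # Assumption: an unknown or missing emotion contributes no overrides.
    merged.update(EMOTION_PRESETS.get(emotion or "neutral", {}))
    if speed is not None:
        merged["speed_factor"] = speed
    return merged


with_speed = merge_preset({"text": "Hello."}, "calm", speed=1.05)
preset_only = merge_preset({"text": "Hello."}, "calm")
```

With these inputs, `with_speed` keeps the preset's `temperature` but uses the caller's speed, while `preset_only` falls back to the preset's `speed_factor`.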

1006
test_frontend/index.html Normal file

File diff suppressed because it is too large.

@ -0,0 +1,303 @@
import asyncio
import importlib.util
import sys
import tempfile
import types
import unittest
from pathlib import Path


PROJECT_ROOT = Path(__file__).resolve().parents[1]
SIMPLE_API_PATH = PROJECT_ROOT / "simple_api.py"

class DummyHTTPException(Exception):
    def __init__(self, status_code=500, detail=None):
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


class DummyFastAPI:
    def __init__(self, *args, **kwargs):
        self.routes = []

    def add_middleware(self, *args, **kwargs):
        pass

    def mount(self, path, app, name=None):
        self.routes.append(types.SimpleNamespace(path=path, endpoint=app, name=name))

    def get(self, path, **kwargs):
        return self._route(path)

    def post(self, path, **kwargs):
        return self._route(path)

    def on_event(self, event):
        return lambda func: func

    def _route(self, path):
        def decorator(func):
            self.routes.append(types.SimpleNamespace(path=path, endpoint=func))
            return func

        return decorator


class DummyBaseModel:
    def dict(self, exclude_none=False):
        data = dict(self.__dict__)
        if exclude_none:
            data = {key: value for key, value in data.items() if value is not None}
        return data


class DummyTTSConfig:
    version = "v2"
    languages = ["zh", "en", "auto"]
    use_vocoder = False


class FakeUpload:
    def __init__(self, filename, chunks, content_type="audio/wav"):
        self.filename = filename
        self.content_type = content_type
        self._chunks = list(chunks)
        self.closed = False

    async def read(self, size):
        del size
        if self._chunks:
            return self._chunks.pop(0)
        return b""

    async def close(self):
        self.closed = True

def install_dummy_modules():
    # Replace heavy third-party imports with stubs so the module loads without a GPU.
    fastapi = types.ModuleType("fastapi")
    fastapi.FastAPI = DummyFastAPI
    fastapi.File = lambda default=None, **kwargs: default
    fastapi.Form = lambda default=None, **kwargs: default
    fastapi.HTTPException = DummyHTTPException
    fastapi.Response = type("Response", (), {})
    fastapi.UploadFile = type("UploadFile", (), {})
    sys.modules["fastapi"] = fastapi

    middleware = types.ModuleType("fastapi.middleware")
    cors = types.ModuleType("fastapi.middleware.cors")
    cors.CORSMiddleware = type("CORSMiddleware", (), {})
    sys.modules["fastapi.middleware"] = middleware
    sys.modules["fastapi.middleware.cors"] = cors

    responses = types.ModuleType("fastapi.responses")
    responses.JSONResponse = type("JSONResponse", (), {})
    responses.StreamingResponse = type("StreamingResponse", (), {})
    sys.modules["fastapi.responses"] = responses

    staticfiles = types.ModuleType("fastapi.staticfiles")
    staticfiles.StaticFiles = type("StaticFiles", (), {"__init__": lambda self, *args, **kwargs: None})
    sys.modules["fastapi.staticfiles"] = staticfiles

    pydantic = types.ModuleType("pydantic")
    pydantic.BaseModel = DummyBaseModel
    pydantic.Field = lambda default=None, **kwargs: default
    sys.modules["pydantic"] = pydantic

    numpy = types.ModuleType("numpy")
    numpy.ndarray = object
    sys.modules["numpy"] = numpy

    soundfile = types.ModuleType("soundfile")
    soundfile.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000)
    sys.modules["soundfile"] = soundfile

    yaml = types.ModuleType("yaml")
    yaml.safe_load = lambda stream: {
        "server": {"host": "127.0.0.1", "port": 9881, "tts_config": "dummy.yaml"},
        "upload": {"dir": "runtime/test_uploads", "min_ref_seconds": 3, "max_ref_seconds": 10, "max_upload_mb": 1},
        "defaults": {"text_lang": "zh", "prompt_lang": "zh", "media_type": "wav", "text_split_method": "cut5"},
        "emotion_presets": {"calm": {"temperature": 0.8, "speed_factor": 0.9}},
        "voices": {},
    }
    sys.modules["yaml"] = yaml

    sys.modules["GPT_SoVITS"] = types.ModuleType("GPT_SoVITS")
    sys.modules["GPT_SoVITS.TTS_infer_pack"] = types.ModuleType("GPT_SoVITS.TTS_infer_pack")
    tts_module = types.ModuleType("GPT_SoVITS.TTS_infer_pack.TTS")
    tts_module.TTS = type("DummyTTS", (), {})
    tts_module.TTS_Config = DummyTTSConfig
    sys.modules["GPT_SoVITS.TTS_infer_pack.TTS"] = tts_module

    segmentation = types.ModuleType("GPT_SoVITS.TTS_infer_pack.text_segmentation_method")
    segmentation.get_method_names = lambda: ["cut0", "cut5"]
    sys.modules["GPT_SoVITS.TTS_infer_pack.text_segmentation_method"] = segmentation


def load_simple_api():
    install_dummy_modules()
    sys.argv = ["simple_api.py"]
    spec = importlib.util.spec_from_file_location("simple_api_under_test", SIMPLE_API_PATH)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    module.tts_config = DummyTTSConfig()
    return module

class SimpleApiContractTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.simple_api = load_simple_api()

    def test_api_tts_route_is_registered(self):
        routes = {route.path for route in self.simple_api.APP.routes}
        self.assertIn("/api/tts", routes)

    def test_build_request_accepts_explicit_ref_without_voice_profile(self):
        with tempfile.NamedTemporaryFile(suffix=".wav") as ref:
            req, media_type, response_stream = self.simple_api.build_tts_request(
                {
                    "text": "Hello, test.",
                    "ref_audio_path": ref.name,
                    "prompt_text": "",
                    "prompt_lang": "zh",
                    "text_lang": "zh",
                    "format": "wav",
                    "text_split_method": "cut5",
                    "streaming_mode": 0,
                }
            )
        self.assertEqual(req["prompt_text"], "")
        self.assertEqual(req["text_split_method"], "cut5")
        self.assertEqual(req["ref_audio_path"], ref.name)
        self.assertEqual(media_type, "wav")
        self.assertFalse(response_stream)

    def test_explicit_speed_overrides_emotion_speed_preset(self):
        with tempfile.NamedTemporaryFile(suffix=".wav") as ref:
            payload = {
                "text": "Hello.",
                "ref_audio_path": ref.name,
                "prompt_lang": "zh",
                "text_lang": "zh",
                "format": "wav",
                "streaming_mode": 0,
                "speed": 1.05,
            }
            self.simple_api.apply_emotion_preset(payload, "calm")
            payload["speed"] = 1.05
            req, _, _ = self.simple_api.build_tts_request(payload)
        self.assertEqual(req["temperature"], 0.8)
        self.assertEqual(req["speed_factor"], 1.05)

    def test_empty_prompt_text_is_rejected_for_vocoder_models(self):
        original_use_vocoder = self.simple_api.tts_config.use_vocoder
        original_version = self.simple_api.tts_config.version
        self.simple_api.tts_config.use_vocoder = True
        self.simple_api.tts_config.version = "v3"
        try:
            with tempfile.NamedTemporaryFile(suffix=".wav") as ref:
                with self.assertRaises(DummyHTTPException) as exc:
                    self.simple_api.build_tts_request(
                        {
                            "text": "Hello.",
                            "ref_audio_path": ref.name,
                            "prompt_text": "",
                            "prompt_lang": "zh",
                            "text_lang": "zh",
                            "format": "wav",
                            "text_split_method": "cut5",
                            "streaming_mode": 0,
                        }
                    )
            self.assertEqual(exc.exception.status_code, 400)
            self.assertIn("prompt_text is required", str(exc.exception.detail))
        finally:
            self.simple_api.tts_config.use_vocoder = original_use_vocoder
            self.simple_api.tts_config.version = original_version

    def test_ref_audio_duration_limits(self):
        # 5 s is inside the 3-10 s window and must pass.
        self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000)
        self.simple_api.validate_main_ref_duration("ok.wav")

        # 2 s is too short.
        self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 2, samplerate=16000)
        with self.assertRaises(DummyHTTPException) as too_short:
            self.simple_api.validate_main_ref_duration("short.wav")
        self.assertEqual(too_short.exception.status_code, 400)

        # 11 s is too long.
        self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 11, samplerate=16000)
        with self.assertRaises(DummyHTTPException) as too_long:
            self.simple_api.validate_main_ref_duration("long.wav")
        self.assertEqual(too_long.exception.status_code, 400)

    def test_save_upload_file_writes_and_closes(self):
        with tempfile.TemporaryDirectory() as tmp:
            upload = FakeUpload("ref.wav", [b"abc", b"def"])
            saved_path = asyncio.run(self.simple_api.save_upload_file(upload, Path(tmp), "ref"))
            self.assertTrue(upload.closed)
            self.assertEqual(Path(saved_path).read_bytes(), b"abcdef")

    def test_mvp_tts_builds_payload_from_uploads_and_cleans_tmp_dir(self):
        captured = {}
        with tempfile.TemporaryDirectory() as tmp:
            request_dir = Path(tmp) / "request"
            request_dir.mkdir()
            original_get_request_upload_dir = self.simple_api.get_request_upload_dir
            original_synthesize_response = self.simple_api.synthesize_response
            original_sf_info = self.simple_api.sf.info

            def fake_synthesize_response(payload):
                captured.update(payload)
                return "audio-response"

            try:
                self.simple_api.get_request_upload_dir = lambda: request_dir
                self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000)
                self.simple_api.synthesize_response = fake_synthesize_response
                response = asyncio.run(
                    self.simple_api.mvp_tts(
                        text="Hello, test.",
                        ref_audio=FakeUpload("ref.wav", [b"main"]),
                        aux_ref_audio=[FakeUpload("aux.wav", [b"aux"])],
                        prompt_text="",
                        text_lang="zh",
                        prompt_lang="zh",
                        format="wav",
                        emotion="calm",
                        speed=1.05,
                        seed=123,
                    )
                )
            finally:
                self.simple_api.get_request_upload_dir = original_get_request_upload_dir
                self.simple_api.synthesize_response = original_synthesize_response
                self.simple_api.sf.info = original_sf_info

            self.assertEqual(response, "audio-response")
            self.assertFalse(request_dir.exists())
            self.assertEqual(captured["text"], "Hello, test.")
            self.assertEqual(captured["prompt_text"], "")
            self.assertEqual(captured["text_lang"], "zh")
            self.assertEqual(captured["prompt_lang"], "zh")
            self.assertEqual(captured["format"], "wav")
            self.assertEqual(captured["text_split_method"], "cut5")
            self.assertEqual(captured["streaming_mode"], 0)
            self.assertEqual(captured["seed"], 123)
            self.assertEqual(captured["temperature"], 0.8)
            self.assertEqual(captured["speed"], 1.05)
            self.assertEqual(len(captured["aux_ref_audio_paths"]), 1)


if __name__ == "__main__":
    unittest.main()
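
The duration test above stubs `sf.info` with `frames` and `samplerate` values and expects 5 s to pass while 2 s and 11 s are rejected. A minimal sketch of the arithmetic behind such a check follows. `ref_duration_ok` is a hypothetical name; the real `validate_main_ref_duration` is not shown in this part of the diff and raises an HTTP 400 rather than returning a boolean. Treating the 3 s and 10 s bounds as inclusive is an assumption.

```python
def ref_duration_ok(
    frames: int,
    samplerate: int,
    min_seconds: float = 3.0,
    max_seconds: float = 10.0,
) -> bool:
    """Hypothetical check: is the clip length within [min_seconds, max_seconds]?"""
    # Duration in seconds is the frame count divided by the sample rate.
    duration = frames / samplerate
    # Assumption: both bounds are inclusive.
    return min_seconds <= duration <= max_seconds


ok = ref_duration_ok(16000 * 5, 16000)
too_short = ref_duration_ok(16000 * 2, 16000)
too_long = ref_duration_ok(16000 * 11, 16000)
```

The defaults mirror `min_ref_seconds` and `max_ref_seconds` from the `upload` section of `simple_api.yaml`.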