mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-07-03 12:38:12 +08:00
Merge a47b87bb7b1219d133523f88024888aa296594bf into bf81cdb14a38b674b6e9996dabc97340bc9978d2
This commit is contained in:
commit
e5f4f31a1f
39
README.md
39
README.md
@ -27,6 +27,44 @@ A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.<br><br>
|
||||
|
||||
---
|
||||
|
||||
## 简化接口 / 测试前端
|
||||
|
||||
本工作区新增了一个用于 MVP 调用的中间层接口和测试前端:
|
||||
|
||||
| 页面 | 地址 | 说明 |
|
||||
|------|------|------|
|
||||
| 测试前端 | http://127.0.0.1:9881/test/ | 带波形裁剪、视频转音频的测试 UI |
|
||||
| Swagger UI | http://127.0.0.1:9881/docs | 交互式 API 文档 |
|
||||
| ReDoc | http://127.0.0.1:9881/redoc | 可读式 API 文档 |
|
||||
|
||||
文档:
|
||||
|
||||
- 快速启动:[docs/simple_api_quickstart.md](./docs/simple_api_quickstart.md)
|
||||
- 完整教程:[docs/simple_api.md](./docs/simple_api.md)
|
||||
|
||||
核心功能:
|
||||
|
||||
- `/api/tts` — 上传参考音频/视频 + 文字,直接返回生成的音频
|
||||
- 前端支持视频上传,自动提取音频
|
||||
- 前端波形裁剪工具,可选择 3-10 秒片段
|
||||
- 5 种情绪预设(neutral/happy/calm/sad/angry)
|
||||
- 7 个契约测试,无需 GPU 即可运行
|
||||
|
||||
启动命令:
|
||||
|
||||
```
|
||||
cd D:\tts\GPT-SoVITS
|
||||
.\go-simple-api.ps1
|
||||
```
|
||||
|
||||
运行测试:
|
||||
|
||||
```
|
||||
python -m unittest tests.test_simple_api_contract -v
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Features:
|
||||
|
||||
1. **Zero-shot TTS:** Input a 5-second vocal sample and experience instant text-to-speech conversion.
|
||||
@ -478,3 +516,4 @@ Thankful to @Naozumi520 for providing the Cantonese training set and for the gui
|
||||
<a href="https://github.com/RVC-Boss/GPT-SoVITS/graphs/contributors" target="_blank">
|
||||
<img src="https://contrib.rocks/image?repo=RVC-Boss/GPT-SoVITS" />
|
||||
</a>
|
||||
|
||||
|
||||
402
docs/simple_api.md
Normal file
402
docs/simple_api.md
Normal file
@ -0,0 +1,402 @@
|
||||
# GPT-SoVITS 简化接口文档
|
||||
|
||||
本项目新增 `simple_api.py` 作为中间层,封装 GPT-SoVITS 推理引擎,提供更简洁的调用方式。
|
||||
|
||||
## 快速开始
|
||||
|
||||
```bash
|
||||
# 安装依赖
|
||||
python -m pip install -r requirements.txt
|
||||
|
||||
# 启动
|
||||
python simple_api.py -c simple_api.yaml
|
||||
|
||||
# 访问
|
||||
Swagger UI: http://127.0.0.1:9881/docs
|
||||
ReDoc: http://127.0.0.1:9881/redoc
|
||||
测试前端: http://127.0.0.1:9881/test/
|
||||
```
|
||||
|
||||
## 接口总览
|
||||
|
||||
| 方法 | 路径 | 说明 | 标签 |
|
||||
|------|------|------|------|
|
||||
| GET | `/health` | 健康检查(含 GPU 信息) | System |
|
||||
| GET | `/voices` | 列出 voice profiles | System |
|
||||
| **POST** | **`/api/tts`** | **核心 TTS 接口(MVP)** | **MVP** |
|
||||
| GET | `/speak` | voice profile TTS (GET) | Profile |
|
||||
| POST | `/speak` | voice profile TTS (POST) | Profile |
|
||||
| POST | `/v1/tts` | OpenAI 兼容格式 TTS | Profile |
|
||||
| POST | `/speak/base64` | 返回 Base64 音频 | Profile |
|
||||
| POST | `/admin/reload-config` | 热加载配置 | Admin |
|
||||
| POST | `/admin/weights` | 切换模型权重 | Admin |
|
||||
|
||||
---
|
||||
|
||||
## 1. POST /api/tts — 核心 TTS 接口
|
||||
|
||||
**推荐使用此接口**。上传参考音频和文字,直接返回生成的音频。
|
||||
|
||||
### 请求格式
|
||||
|
||||
```
|
||||
Content-Type: multipart/form-data
|
||||
```
|
||||
|
||||
### 字段说明
|
||||
|
||||
| 字段 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|------|------|------|--------|------|
|
||||
| `text` | string | **是** | — | 需要生成的文字 |
|
||||
| `ref_audio` | file | **是** | — | 主参考音频,3-10 秒(支持 wav/flac/ogg/mp3/m4a/aac) |
|
||||
| `aux_ref_audio` | file[] | 否 | — | 辅助参考音频,可上传多个 |
|
||||
| `prompt_text` | string | 否 | `""` | 主参考音频对应文字(v2 可留空;v3/v4 必填) |
|
||||
| `text_lang` | string | 否 | `zh` | 生成文字语言:zh/en/ja/ko/yue/auto |
|
||||
| `prompt_lang` | string | 否 | `zh` | 参考音频语言:zh/en/ja/ko/yue/auto |
|
||||
| `format` | string | 否 | `wav` | 返回格式:wav/ogg/aac/raw |
|
||||
| `emotion` | string | 否 | `neutral` | 情绪预设:neutral/happy/calm/sad/angry |
|
||||
| `speed` | float | 否 | — | 语速(0.5-2.0),覆盖情绪预设中的语速 |
|
||||
| `seed` | int | 否 | `-1` | 随机种子,-1 为随机 |
|
||||
|
||||
### 情绪预设参数映射
|
||||
|
||||
| 情绪 | temperature | top_p | top_k | speed_factor | repetition_penalty |
|
||||
|------|-------------|-------|-------|--------------|-------------------|
|
||||
| neutral | — | — | — | — | — |
|
||||
| happy | 1.1 | 0.95 | — | — | — |
|
||||
| calm | 0.8 | 0.85 | — | 0.92 | — |
|
||||
| sad | 0.75 | 0.85 | — | 0.9 | — |
|
||||
| angry | 1.2 | — | 20 | — | 1.25 |
|
||||
|
||||
> 显式传入 `speed` 会覆盖情绪预设中的 `speed_factor`。
|
||||
|
||||
### curl 示例
|
||||
|
||||
**基础调用:**
|
||||
|
||||
```powershell
|
||||
curl.exe -X POST http://127.0.0.1:9881/api/tts `
|
||||
-F "text=你好,欢迎使用这个声音。" `
|
||||
-F "ref_audio=@D:\audio\ref.wav" `
|
||||
--output output.wav
|
||||
```
|
||||
|
||||
**带辅助参考音频和情绪:**
|
||||
|
||||
```powershell
|
||||
curl.exe -X POST http://127.0.0.1:9881/api/tts `
|
||||
-F "text=你好,欢迎使用这个声音。" `
|
||||
-F "ref_audio=@D:\audio\ref.wav" `
|
||||
-F "aux_ref_audio=@D:\audio\aux1.wav" `
|
||||
-F "emotion=happy" `
|
||||
-F "speed=1.1" `
|
||||
--output output.wav
|
||||
```
|
||||
|
||||
**Linux/macOS:**
|
||||
|
||||
```bash
|
||||
curl -X POST http://127.0.0.1:9881/api/tts \
|
||||
-F "text=你好,欢迎使用这个声音。" \
|
||||
-F "ref_audio=@/path/to/ref.wav" \
|
||||
-F "emotion=calm" \
|
||||
--output output.wav
|
||||
```
|
||||
|
||||
### 返回
|
||||
|
||||
- 成功:音频二进制流(Content-Type: `audio/wav` 等)
|
||||
- 失败:JSON 错误信息
|
||||
|
||||
```json
|
||||
{"message": "tts failed", "exception": "..."}
|
||||
```
|
||||
|
||||
### 常见错误
|
||||
|
||||
| HTTP 状态码 | 原因 |
|
||||
|------------|------|
|
||||
| 400 | text 为空 / ref_audio 缺失 / 音频时长不在 3-10 秒 / 不支持的 format / v3/v4 时 prompt_text 为空 |
|
||||
| 404 | voice profile 不存在(仅 /speak 接口) |
|
||||
| 503 | TTS pipeline 未就绪(模型未加载) |
|
||||
|
||||
---
|
||||
|
||||
## 2. GET /health — 健康检查
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:9881/health
|
||||
```
|
||||
|
||||
返回示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "ok",
|
||||
"tts_config": "GPT_SoVITS/configs/tts_infer.yaml",
|
||||
"version": "v2",
|
||||
"languages": ["auto", "en", "zh"],
|
||||
"pid": 12345,
|
||||
"memory_mb": 2048.5,
|
||||
"gpu": {
|
||||
"name": "NVIDIA GeForce RTX 3080",
|
||||
"memory_used_mb": 4096.2,
|
||||
"memory_total_mb": 10240.0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. GET /voices — 列出 voice profiles
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:9881/voices
|
||||
```
|
||||
|
||||
返回示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"default_voice": "default",
|
||||
"voices": [
|
||||
{
|
||||
"name": "default",
|
||||
"description": "Replace this profile with your reference voice.",
|
||||
"text_lang": "zh",
|
||||
"prompt_lang": "zh",
|
||||
"ref_audio_path": "reference.wav",
|
||||
"ready": true
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. POST /speak — voice profile TTS
|
||||
|
||||
基于 `simple_api.yaml` 中配置的 voice profile 调用 TTS。
|
||||
|
||||
### 请求体(JSON)
|
||||
|
||||
```json
|
||||
{
|
||||
"text": "hello world",
|
||||
"voice": "default",
|
||||
"text_lang": "zh",
|
||||
"format": "wav",
|
||||
"speed": 1.0
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
|------|------|------|------|
|
||||
| `text` | string | **是** | 需要生成的文字 |
|
||||
| `voice` | string | 否 | voice profile 名称,不传则使用 default |
|
||||
| `text_lang` | string | 否 | 生成文字语言 |
|
||||
| `format` | string | 否 | 返回格式 |
|
||||
| `stream` | bool | 否 | 是否流式返回 |
|
||||
| `speed` | float | 否 | 语速 |
|
||||
|
||||
### curl 示例
|
||||
|
||||
```bash
|
||||
curl -X POST http://127.0.0.1:9881/speak \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"text":"你好世界","voice":"default"}' \
|
||||
--output output.wav
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. GET /speak — voice profile TTS (GET)
|
||||
|
||||
与 POST /speak 相同,但通过 URL 参数传递。
|
||||
|
||||
```
|
||||
GET /speak?text=hello&voice=default&format=wav
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. POST /speak/base64 — 返回 Base64 音频
|
||||
|
||||
返回 Base64 编码的音频,适合 Web 前端直接使用。
|
||||
|
||||
```bash
|
||||
curl -X POST http://127.0.0.1:9881/speak/base64 \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"text":"hello","voice":"default"}'
|
||||
```
|
||||
|
||||
返回:
|
||||
|
||||
```json
|
||||
{
|
||||
"media_type": "audio/wav",
|
||||
"audio_base64": "UklGRi..."
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. POST /v1/tts — OpenAI 兼容格式
|
||||
|
||||
请求格式与 POST /speak 相同,路径兼容 OpenAI TTS API 风格。
|
||||
|
||||
---
|
||||
|
||||
## 8. POST /admin/reload-config — 热加载配置
|
||||
|
||||
重新加载 `simple_api.yaml`,无需重启服务。
|
||||
|
||||
```bash
|
||||
curl -X POST http://127.0.0.1:9881/admin/reload-config
|
||||
```
|
||||
|
||||
返回:`{"message": "success", "default_voice": "default"}`
|
||||
|
||||
---
|
||||
|
||||
## 9. POST /admin/weights — 切换模型权重
|
||||
|
||||
运行时切换 GPT-SoVITS 模型权重文件。
|
||||
|
||||
```bash
|
||||
curl -X POST http://127.0.0.1:9881/admin/weights \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"gpt_weights_path":"path/to/gpt.pt","sovits_weights_path":"path/to/sovits.pt"}'
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 配置文件
|
||||
|
||||
`simple_api.yaml`:
|
||||
|
||||
```yaml
|
||||
server:
|
||||
host: 127.0.0.1
|
||||
port: 9881
|
||||
tts_config: GPT_SoVITS/configs/tts_infer.yaml
|
||||
|
||||
cors_allow_origins:
|
||||
- "*"
|
||||
|
||||
upload:
|
||||
dir: runtime/uploads
|
||||
min_ref_seconds: 3
|
||||
max_ref_seconds: 10
|
||||
max_upload_mb: 80
|
||||
|
||||
defaults:
|
||||
text_lang: zh
|
||||
prompt_lang: zh
|
||||
media_type: wav
|
||||
text_split_method: cut5
|
||||
batch_size: 1
|
||||
speed_factor: 1.0
|
||||
seed: -1
|
||||
|
||||
emotion_presets:
|
||||
neutral: {}
|
||||
happy:
|
||||
temperature: 1.1
|
||||
top_p: 0.95
|
||||
calm:
|
||||
temperature: 0.8
|
||||
top_p: 0.85
|
||||
speed_factor: 0.92
|
||||
sad:
|
||||
temperature: 0.75
|
||||
top_p: 0.85
|
||||
speed_factor: 0.9
|
||||
angry:
|
||||
temperature: 1.2
|
||||
top_k: 20
|
||||
repetition_penalty: 1.25
|
||||
|
||||
voices:
|
||||
default:
|
||||
description: Replace this profile with your reference voice.
|
||||
ref_audio_path: reference.wav
|
||||
prompt_text: Replace this with the exact text spoken in reference.wav.
|
||||
prompt_lang: zh
|
||||
text_lang: zh
|
||||
```
|
||||
|
||||
### 配置说明
|
||||
|
||||
| 配置项 | 说明 |
|
||||
|--------|------|
|
||||
| `server.host` | 监听地址 |
|
||||
| `server.port` | 监听端口 |
|
||||
| `server.tts_config` | GPT-SoVITS 推理配置文件路径 |
|
||||
| `upload.dir` | 临时上传目录 |
|
||||
| `upload.min_ref_seconds` | 主参考音频最短秒数 |
|
||||
| `upload.max_ref_seconds` | 主参考音频最长秒数 |
|
||||
| `upload.max_upload_mb` | 单个上传文件最大体积 (MB) |
|
||||
| `defaults.*` | 所有接口的默认参数 |
|
||||
| `emotion_presets.*` | 情绪预设参数映射 |
|
||||
| `voices.*` | 固定音色 profile |
|
||||
|
||||
---
|
||||
|
||||
## 添加自定义音色
|
||||
|
||||
编辑 `simple_api.yaml`,在 `voices` 下添加:
|
||||
|
||||
```yaml
|
||||
voices:
|
||||
narrator:
|
||||
description: "男声旁白"
|
||||
ref_audio_path: voices/narrator.wav
|
||||
prompt_text: "旁白参考音频的逐字稿"
|
||||
prompt_lang: zh
|
||||
text_lang: zh
|
||||
```
|
||||
|
||||
然后热加载:
|
||||
|
||||
```bash
|
||||
curl -X POST http://127.0.0.1:9881/admin/reload-config
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 测试
|
||||
|
||||
### 契约测试(无需 GPU)
|
||||
|
||||
```bash
|
||||
python -m unittest tests.test_simple_api_contract -v
|
||||
```
|
||||
|
||||
覆盖:
|
||||
|
||||
- `/api/tts` 路由注册
|
||||
- 上传接口参数构造
|
||||
- 主参考音频 3-10 秒校验
|
||||
- v2 空 prompt_text 允许 / v3/v4 空 prompt_text 拒绝
|
||||
- 临时上传目录清理
|
||||
- 情绪预设应用与 speed 覆盖
|
||||
|
||||
### 前端测试
|
||||
|
||||
1. 启动后端
|
||||
2. 访问 `http://127.0.0.1:9881/test/`
|
||||
3. 上传音频或视频(视频会自动提取音频)
|
||||
4. 使用波形裁剪工具选择 3-10 秒片段
|
||||
5. 填写文字,选择情绪和语速
|
||||
6. 点击生成
|
||||
|
||||
---
|
||||
|
||||
## 启动脚本
|
||||
|
||||
| 脚本 | 平台 | 说明 |
|
||||
|------|------|------|
|
||||
| `go-simple-api.ps1` | Windows PowerShell | 自动检测 runtime\python.exe |
|
||||
| `go-simple-api.bat` | Windows CMD | 同上 |
|
||||
| `open-test-frontend.ps1` | Windows PowerShell | 直接打开测试前端 HTML |
|
||||
85
docs/simple_api_quickstart.md
Normal file
85
docs/simple_api_quickstart.md
Normal file
@ -0,0 +1,85 @@
|
||||
# 简化 TTS 接口快速启动
|
||||
|
||||
完整教程见:`docs/simple_api.md`
|
||||
|
||||
## 一句话流程
|
||||
|
||||
启动后端,打开测试页,上传 3-10 秒参考音频或视频(视频自动提取音频),填写生成文字,点击生成。
|
||||
|
||||
## 1. 启动后端
|
||||
|
||||
```powershell
|
||||
cd D:\tts\GPT-SoVITS
|
||||
python -m pip install -r requirements.txt
|
||||
.\go-simple-api.ps1
|
||||
```
|
||||
|
||||
也可以运行:
|
||||
|
||||
```bat
|
||||
go-simple-api.bat
|
||||
```
|
||||
|
||||
默认后端地址:
|
||||
|
||||
```text
|
||||
http://127.0.0.1:9881
|
||||
```
|
||||
|
||||
## 2. 打开页面
|
||||
|
||||
| 页面 | 地址 | 说明 |
|
||||
|------|------|------|
|
||||
| 测试前端 | http://127.0.0.1:9881/test/ | 带波形裁剪的测试 UI |
|
||||
| Swagger UI | http://127.0.0.1:9881/docs | 交互式 API 文档 |
|
||||
| ReDoc | http://127.0.0.1:9881/redoc | 可读式 API 文档 |
|
||||
|
||||
## 3. 调用接口
|
||||
|
||||
接口:
|
||||
|
||||
```http
|
||||
POST /api/tts
|
||||
Content-Type: multipart/form-data
|
||||
```
|
||||
|
||||
最小字段:
|
||||
|
||||
```text
|
||||
text 要生成的文字
|
||||
ref_audio 3-10 秒主参考音频(或视频,前端自动提取)
|
||||
```
|
||||
|
||||
常用可选字段:
|
||||
|
||||
```text
|
||||
aux_ref_audio 辅助参考音频,可多个
|
||||
prompt_text 参考音频文本,可留空
|
||||
text_lang 默认 zh
|
||||
prompt_lang 默认 zh
|
||||
emotion neutral / happy / calm / sad / angry
|
||||
speed 语速,默认 1
|
||||
seed 默认 -1
|
||||
```
|
||||
|
||||
## 4. PowerShell 示例
|
||||
|
||||
```powershell
|
||||
curl.exe -X POST http://127.0.0.1:9881/api/tts `
|
||||
-F "text=你好,欢迎使用这个声音。" `
|
||||
-F "ref_audio=@D:\audio\ref.wav" `
|
||||
-F "prompt_text=" `
|
||||
-F "text_lang=zh" `
|
||||
-F "prompt_lang=zh" `
|
||||
--output output.wav
|
||||
```
|
||||
|
||||
## 5. 注意事项
|
||||
|
||||
- 主参考音频必须是 3-10 秒。
|
||||
- 前端支持上传视频文件,会自动提取音频。
|
||||
- 前端提供波形裁剪工具,可直接选择 3-10 秒片段。
|
||||
- `prompt_text` 在当前 v2 配置下可以为空。
|
||||
- 如果切换到 v3/v4,空 `prompt_text` 会返回 400。
|
||||
- 生成文字固定按标点符号切句。
|
||||
- 更详细的配置、profile、base64 接口见 `docs/simple_api.md`。
|
||||
12
go-simple-api.bat
Normal file
12
go-simple-api.bat
Normal file
@ -0,0 +1,12 @@
|
||||
@echo off
|
||||
set "SCRIPT_DIR=%~dp0"
|
||||
set "SCRIPT_DIR=%SCRIPT_DIR:~0,-1%"
|
||||
cd /d "%SCRIPT_DIR%"
|
||||
|
||||
if exist "%SCRIPT_DIR%\runtime\python.exe" (
|
||||
"%SCRIPT_DIR%\runtime\python.exe" "%SCRIPT_DIR%\simple_api.py" -c "%SCRIPT_DIR%\simple_api.yaml" %*
|
||||
) else (
|
||||
python "%SCRIPT_DIR%\simple_api.py" -c "%SCRIPT_DIR%\simple_api.yaml" %*
|
||||
)
|
||||
|
||||
pause
|
||||
13
go-simple-api.ps1
Normal file
13
go-simple-api.ps1
Normal file
@ -0,0 +1,13 @@
|
||||
$ErrorActionPreference = "Stop"
|
||||
chcp 65001
|
||||
Set-Location $PSScriptRoot
|
||||
|
||||
$runtimePython = Join-Path $PSScriptRoot "runtime\python.exe"
|
||||
$apiScript = Join-Path $PSScriptRoot "simple_api.py"
|
||||
$apiConfig = Join-Path $PSScriptRoot "simple_api.yaml"
|
||||
|
||||
if (Test-Path $runtimePython) {
|
||||
& $runtimePython $apiScript -c $apiConfig @args
|
||||
} else {
|
||||
python $apiScript -c $apiConfig @args
|
||||
}
|
||||
3
open-test-frontend.ps1
Normal file
3
open-test-frontend.ps1
Normal file
@ -0,0 +1,3 @@
|
||||
$ErrorActionPreference = "Stop"
|
||||
Set-Location $PSScriptRoot
|
||||
Start-Process (Join-Path $PSScriptRoot "test_frontend\index.html")
|
||||
@ -3,6 +3,7 @@ numpy<2.0
|
||||
scipy
|
||||
tensorboard
|
||||
librosa==0.10.2
|
||||
soundfile
|
||||
numba
|
||||
pytorch-lightning>=2.4
|
||||
gradio<5
|
||||
@ -22,6 +23,7 @@ transformers>=4.43,<=4.50
|
||||
peft<0.18.0
|
||||
chardet
|
||||
PyYAML
|
||||
python-multipart
|
||||
psutil
|
||||
jieba_fast
|
||||
jieba
|
||||
|
||||
787
simple_api.py
Normal file
787
simple_api.py
Normal file
@ -0,0 +1,787 @@
|
||||
"""
|
||||
Small profile-based API layer for GPT-SoVITS.
|
||||
|
||||
Run:
|
||||
python simple_api.py -c simple_api.yaml
|
||||
|
||||
Then call:
|
||||
POST http://127.0.0.1:9881/speak
|
||||
{"text": "hello", "voice": "default"}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
import uuid
|
||||
import wave
|
||||
from copy import deepcopy
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import yaml
|
||||
from fastapi import FastAPI, File, Form, HTTPException, Response, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent
|
||||
os.chdir(PROJECT_ROOT)
|
||||
sys.path.append(str(PROJECT_ROOT))
|
||||
sys.path.append(str(PROJECT_ROOT / "GPT_SoVITS"))
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402
|
||||
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import ( # noqa: E402
|
||||
get_method_names as get_cut_method_names,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_TTS_PARAMS: Dict[str, Any] = {
|
||||
"text": "",
|
||||
"text_lang": "zh",
|
||||
"ref_audio_path": "",
|
||||
"aux_ref_audio_paths": [],
|
||||
"prompt_text": "",
|
||||
"prompt_lang": "zh",
|
||||
"top_k": 15,
|
||||
"top_p": 1.0,
|
||||
"temperature": 1.0,
|
||||
"text_split_method": "cut5",
|
||||
"batch_size": 1,
|
||||
"batch_threshold": 0.75,
|
||||
"split_bucket": True,
|
||||
"speed_factor": 1.0,
|
||||
"fragment_interval": 0.3,
|
||||
"seed": -1,
|
||||
"media_type": "wav",
|
||||
"streaming_mode": False,
|
||||
"return_fragment": False,
|
||||
"fixed_length_chunk": False,
|
||||
"parallel_infer": True,
|
||||
"repetition_penalty": 1.35,
|
||||
"sample_steps": 32,
|
||||
"super_sampling": False,
|
||||
"overlap_length": 2,
|
||||
"min_chunk_length": 16,
|
||||
}
|
||||
|
||||
SUPPORTED_MEDIA_TYPES = {"wav", "raw", "ogg", "aac"}
|
||||
SUPPORTED_UPLOAD_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a", ".aac"}
|
||||
VOICE_META_KEYS = {"description"}
|
||||
|
||||
|
||||
class SpeakRequest(BaseModel):
|
||||
text: str
|
||||
voice: Optional[str] = None
|
||||
text_lang: Optional[str] = None
|
||||
ref_audio_path: Optional[str] = None
|
||||
aux_ref_audio_paths: Optional[List[str]] = None
|
||||
prompt_text: Optional[str] = None
|
||||
prompt_lang: Optional[str] = None
|
||||
format: Optional[str] = Field(default=None, description="wav, raw, ogg, or aac")
|
||||
stream: Optional[bool] = Field(default=None, description="true maps to streaming mode 2")
|
||||
streaming_mode: Optional[Union[bool, int]] = Field(default=None, description="0, 1, 2, 3, true, or false")
|
||||
speed: Optional[float] = Field(default=None, description="Alias of speed_factor")
|
||||
speed_factor: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
temperature: Optional[float] = None
|
||||
text_split_method: Optional[str] = None
|
||||
batch_size: Optional[int] = None
|
||||
batch_threshold: Optional[float] = None
|
||||
split_bucket: Optional[bool] = None
|
||||
fragment_interval: Optional[float] = None
|
||||
seed: Optional[int] = None
|
||||
parallel_infer: Optional[bool] = None
|
||||
repetition_penalty: Optional[float] = None
|
||||
sample_steps: Optional[int] = None
|
||||
super_sampling: Optional[bool] = None
|
||||
overlap_length: Optional[int] = None
|
||||
min_chunk_length: Optional[int] = None
|
||||
|
||||
|
||||
class WeightsRequest(BaseModel):
|
||||
gpt_weights_path: Optional[str] = None
|
||||
sovits_weights_path: Optional[str] = None
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS simple API")
|
||||
parser.add_argument("-c", "--config", default="simple_api.yaml", help="simple API config path")
|
||||
parser.add_argument("--tts-config", default=None, help="GPT-SoVITS tts_infer.yaml path")
|
||||
parser.add_argument("-a", "--bind-addr", default=None, help="bind address")
|
||||
parser.add_argument("-p", "--port", type=int, default=None, help="bind port")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
|
||||
path = Path(config_path)
|
||||
if not path.is_absolute():
|
||||
path = PROJECT_ROOT / path
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"simple API config not found: {path}")
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("simple API config must be a YAML object")
|
||||
return data
|
||||
|
||||
|
||||
args = parse_args()
|
||||
simple_config = load_yaml_config(args.config)
|
||||
server_config = simple_config.get("server", {}) or {}
|
||||
host = args.bind_addr or server_config.get("host", "127.0.0.1")
|
||||
port = args.port or int(server_config.get("port", 9881))
|
||||
tts_config_path = args.tts_config or server_config.get("tts_config", "GPT_SoVITS/configs/tts_infer.yaml")
|
||||
|
||||
cut_method_names = get_cut_method_names()
|
||||
tts_config: Optional[TTS_Config] = None
|
||||
tts_pipeline: Optional[TTS] = None
|
||||
infer_lock = threading.Lock()
|
||||
|
||||
APP = FastAPI(
|
||||
title="GPT-SoVITS Simple API",
|
||||
description=(
|
||||
"简化接口层,封装 GPT-SoVITS 推理引擎。\n\n"
|
||||
"## 核心流程\n"
|
||||
"1. 上传 3-10 秒参考音频(或视频,前端自动提取音频)\n"
|
||||
"2. 填写需要生成的文字\n"
|
||||
"3. 调用 `/api/tts` 获取生成的音频\n\n"
|
||||
"## 其他接口\n"
|
||||
"- `/speak` — 基于 voice profile 的调用方式\n"
|
||||
"- `/v1/tts` — OpenAI 兼容格式\n"
|
||||
"- `/admin/*` — 管理接口(热加载配置、切换模型)"
|
||||
),
|
||||
version="1.1.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_tags=[
|
||||
{"name": "MVP", "description": "核心 TTS 接口,上传参考音频直接生成"},
|
||||
{"name": "Profile", "description": "基于 voice profile 的调用方式"},
|
||||
{"name": "Admin", "description": "管理接口:热加载配置、切换模型权重"},
|
||||
{"name": "System", "description": "健康检查与信息查询"},
|
||||
],
|
||||
)
|
||||
APP.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=simple_config.get("cors_allow_origins", ["*"]),
|
||||
allow_credentials=False,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
test_frontend_dir = PROJECT_ROOT / "test_frontend"
|
||||
if test_frontend_dir.exists():
|
||||
APP.mount("/test", StaticFiles(directory=str(test_frontend_dir), html=True), name="test_frontend")
|
||||
|
||||
|
||||
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
|
||||
def handle_pack_ogg() -> None:
|
||||
with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
|
||||
audio_file.write(data)
|
||||
|
||||
try:
|
||||
threading.stack_size(4096 * 4096)
|
||||
pack_thread = threading.Thread(target=handle_pack_ogg)
|
||||
pack_thread.start()
|
||||
pack_thread.join()
|
||||
except (RuntimeError, ValueError):
|
||||
handle_pack_ogg()
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
|
||||
del rate
|
||||
io_buffer.write(data.tobytes())
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
|
||||
del io_buffer
|
||||
wav_buffer = BytesIO()
|
||||
sf.write(wav_buffer, data, rate, format="wav")
|
||||
return wav_buffer
|
||||
|
||||
|
||||
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
|
||||
process = subprocess.Popen(
|
||||
[
|
||||
"ffmpeg",
|
||||
"-f",
|
||||
"s16le",
|
||||
"-ar",
|
||||
str(rate),
|
||||
"-ac",
|
||||
"1",
|
||||
"-i",
|
||||
"pipe:0",
|
||||
"-c:a",
|
||||
"aac",
|
||||
"-b:a",
|
||||
"192k",
|
||||
"-vn",
|
||||
"-f",
|
||||
"adts",
|
||||
"pipe:1",
|
||||
],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
out, _ = process.communicate(input=data.tobytes())
|
||||
io_buffer.write(out)
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str) -> BytesIO:
|
||||
if media_type == "ogg":
|
||||
io_buffer = pack_ogg(io_buffer, data, rate)
|
||||
elif media_type == "aac":
|
||||
io_buffer = pack_aac(io_buffer, data, rate)
|
||||
elif media_type == "wav":
|
||||
io_buffer = pack_wav(io_buffer, data, rate)
|
||||
else:
|
||||
io_buffer = pack_raw(io_buffer, data, rate)
|
||||
io_buffer.seek(0)
|
||||
return io_buffer
|
||||
|
||||
|
||||
def wave_header_chunk(frame_input: bytes = b"", channels: int = 1, sample_width: int = 2, sample_rate: int = 32000) -> bytes:
|
||||
wav_buf = BytesIO()
|
||||
with wave.open(wav_buf, "wb") as vfout:
|
||||
vfout.setnchannels(channels)
|
||||
vfout.setsampwidth(sample_width)
|
||||
vfout.setframerate(sample_rate)
|
||||
vfout.writeframes(frame_input)
|
||||
wav_buf.seek(0)
|
||||
return wav_buf.read()
|
||||
|
||||
|
||||
def request_to_dict(request: BaseModel) -> Dict[str, Any]:
|
||||
if hasattr(request, "model_dump"):
|
||||
return request.model_dump(exclude_none=True)
|
||||
return request.dict(exclude_none=True)
|
||||
|
||||
|
||||
def get_default_voice_name() -> Optional[str]:
|
||||
default_voice = simple_config.get("default_voice")
|
||||
if default_voice:
|
||||
return str(default_voice)
|
||||
voices = simple_config.get("voices", {}) or {}
|
||||
if "default" in voices:
|
||||
return "default"
|
||||
return next(iter(voices.keys()), None)
|
||||
|
||||
|
||||
def resolve_project_path(path_value: Optional[str]) -> Optional[str]:
|
||||
if path_value in [None, ""]:
|
||||
return None
|
||||
path = Path(str(path_value))
|
||||
if not path.is_absolute():
|
||||
path = PROJECT_ROOT / path
|
||||
return str(path)
|
||||
|
||||
|
||||
def normalize_streaming(streaming_mode: Optional[Union[bool, int]], stream: Optional[bool]) -> Tuple[bool, bool, bool, bool]:
|
||||
if streaming_mode is None:
|
||||
streaming_mode = 2 if stream else 0
|
||||
elif isinstance(streaming_mode, bool):
|
||||
streaming_mode = 2 if streaming_mode else 0
|
||||
|
||||
if streaming_mode == 0:
|
||||
return False, False, False, False
|
||||
if streaming_mode == 1:
|
||||
return False, True, False, True
|
||||
if streaming_mode == 2:
|
||||
return True, False, False, True
|
||||
if streaming_mode == 3:
|
||||
return True, False, True, True
|
||||
raise HTTPException(status_code=400, detail="streaming_mode must be 0, 1, 2, 3, true, or false")
|
||||
|
||||
|
||||
def get_upload_config() -> Dict[str, Any]:
|
||||
return simple_config.get("upload", {}) or {}
|
||||
|
||||
|
||||
def get_upload_root() -> Path:
|
||||
upload_dir = str(get_upload_config().get("dir", "runtime/uploads"))
|
||||
root = Path(upload_dir)
|
||||
if not root.is_absolute():
|
||||
root = PROJECT_ROOT / root
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
def get_request_upload_dir() -> Path:
|
||||
request_dir = get_upload_root() / uuid.uuid4().hex
|
||||
request_dir.mkdir(parents=True, exist_ok=True)
|
||||
return request_dir
|
||||
|
||||
|
||||
def get_upload_limits() -> Tuple[float, float, int]:
|
||||
upload_config = get_upload_config()
|
||||
min_seconds = float(upload_config.get("min_ref_seconds", 3.0))
|
||||
max_seconds = float(upload_config.get("max_ref_seconds", 10.0))
|
||||
max_upload_mb = int(upload_config.get("max_upload_mb", 80))
|
||||
return min_seconds, max_seconds, max_upload_mb
|
||||
|
||||
|
||||
def get_upload_suffix(upload: UploadFile) -> str:
|
||||
suffix = Path(upload.filename or "").suffix.lower()
|
||||
if suffix:
|
||||
return suffix
|
||||
content_type = (upload.content_type or "").lower()
|
||||
if content_type in {"audio/wav", "audio/x-wav", "audio/wave"}:
|
||||
return ".wav"
|
||||
if content_type == "audio/mpeg":
|
||||
return ".mp3"
|
||||
if content_type == "audio/ogg":
|
||||
return ".ogg"
|
||||
if content_type in {"audio/aac", "audio/aacp"}:
|
||||
return ".aac"
|
||||
return ".wav"
|
||||
|
||||
|
||||
def validate_audio_upload(upload: UploadFile) -> None:
|
||||
suffix = get_upload_suffix(upload)
|
||||
content_type = (upload.content_type or "").lower()
|
||||
is_audio_type = content_type.startswith("audio/") or content_type in {"application/octet-stream", ""}
|
||||
if suffix not in SUPPORTED_UPLOAD_EXTENSIONS or not is_audio_type:
|
||||
raise HTTPException(status_code=400, detail=f"unsupported audio upload: {upload.filename or content_type}")
|
||||
|
||||
|
||||
async def save_upload_file(upload: UploadFile, target_dir: Path, name_prefix: str) -> str:
|
||||
validate_audio_upload(upload)
|
||||
_, _, max_upload_mb = get_upload_limits()
|
||||
max_bytes = max_upload_mb * 1024 * 1024
|
||||
suffix = get_upload_suffix(upload)
|
||||
target_path = target_dir / f"{name_prefix}{suffix}"
|
||||
total_size = 0
|
||||
|
||||
try:
|
||||
with target_path.open("wb") as f:
|
||||
while True:
|
||||
chunk = await upload.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
total_size += len(chunk)
|
||||
if total_size > max_bytes:
|
||||
raise HTTPException(status_code=400, detail=f"audio upload exceeds {max_upload_mb} MB")
|
||||
f.write(chunk)
|
||||
if total_size == 0:
|
||||
raise HTTPException(status_code=400, detail=f"audio upload is empty: {upload.filename or name_prefix}")
|
||||
finally:
|
||||
await upload.close()
|
||||
|
||||
return str(target_path)
|
||||
|
||||
|
||||
def get_audio_duration_with_ffprobe(audio_path: str) -> float:
|
||||
result = subprocess.run(
|
||||
[
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-show_entries",
|
||||
"format=duration",
|
||||
"-of",
|
||||
"default=noprint_wrappers=1:nokey=1",
|
||||
audio_path,
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=15,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(result.stderr.strip() or "ffprobe failed")
|
||||
return float(result.stdout.strip())
|
||||
|
||||
|
||||
def get_audio_duration_seconds(audio_path: str) -> float:
|
||||
try:
|
||||
info = sf.info(audio_path)
|
||||
if info.samplerate <= 0:
|
||||
raise ValueError("invalid sample rate")
|
||||
return float(info.frames) / float(info.samplerate)
|
||||
except Exception as sf_exc:
|
||||
try:
|
||||
return get_audio_duration_with_ffprobe(audio_path)
|
||||
except Exception as ffprobe_exc:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"unable to read audio duration: {audio_path}. soundfile={sf_exc}; ffprobe={ffprobe_exc}",
|
||||
) from ffprobe_exc
|
||||
|
||||
|
||||
def validate_main_ref_duration(audio_path: str) -> None:
|
||||
min_seconds, max_seconds, _ = get_upload_limits()
|
||||
duration = get_audio_duration_seconds(audio_path)
|
||||
if duration < min_seconds or duration > max_seconds:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"ref_audio duration must be between {min_seconds:g}s and {max_seconds:g}s, got {duration:.2f}s",
|
||||
)
|
||||
|
||||
|
||||
def apply_emotion_preset(payload: Dict[str, Any], emotion: Optional[str]) -> None:
|
||||
if emotion in [None, ""]:
|
||||
return
|
||||
emotion_name = str(emotion).strip().lower()
|
||||
emotion_presets = simple_config.get("emotion_presets", {}) or {}
|
||||
if emotion_name not in emotion_presets:
|
||||
supported = ", ".join(sorted(emotion_presets.keys())) or "none"
|
||||
raise HTTPException(status_code=400, detail=f"emotion is not configured: {emotion_name}. supported: {supported}")
|
||||
payload.update(emotion_presets[emotion_name] or {})
|
||||
|
||||
|
||||
def prompt_text_required_for_current_model() -> bool:
|
||||
if tts_config is None:
|
||||
return False
|
||||
use_vocoder = getattr(tts_config, "use_vocoder", None)
|
||||
if use_vocoder is not None:
|
||||
return bool(use_vocoder)
|
||||
return getattr(tts_config, "version", None) in {"v3", "v4"}
|
||||
|
||||
|
||||
def voice_public_info(name: str, profile: Dict[str, Any]) -> Dict[str, Any]:
|
||||
ref_audio_path = resolve_project_path(profile.get("ref_audio_path"))
|
||||
return {
|
||||
"name": name,
|
||||
"description": profile.get("description", ""),
|
||||
"text_lang": profile.get("text_lang"),
|
||||
"prompt_lang": profile.get("prompt_lang"),
|
||||
"ref_audio_path": profile.get("ref_audio_path"),
|
||||
"ready": bool(ref_audio_path and Path(ref_audio_path).exists() and profile.get("prompt_lang")),
|
||||
}
|
||||
|
||||
|
||||
def build_tts_request(payload: Dict[str, Any]) -> Tuple[Dict[str, Any], str, bool]:
|
||||
voices = simple_config.get("voices", {}) or {}
|
||||
explicit_ref_audio = bool(payload.get("ref_audio_path"))
|
||||
voice_name = payload.pop("voice", None)
|
||||
if voice_name is None and not explicit_ref_audio:
|
||||
voice_name = get_default_voice_name()
|
||||
voice_profile: Dict[str, Any] = {}
|
||||
|
||||
if voice_name:
|
||||
if voice_name not in voices:
|
||||
raise HTTPException(status_code=404, detail=f"voice not found: {voice_name}")
|
||||
voice_profile = deepcopy(voices[voice_name] or {})
|
||||
|
||||
tts_req = deepcopy(DEFAULT_TTS_PARAMS)
|
||||
tts_req.update(simple_config.get("defaults", {}) or {})
|
||||
tts_req.update({k: v for k, v in voice_profile.items() if k not in VOICE_META_KEYS})
|
||||
|
||||
stream = payload.pop("stream", None)
|
||||
media_type = payload.pop("format", None)
|
||||
speed = payload.pop("speed", None)
|
||||
if media_type is not None:
|
||||
tts_req["media_type"] = media_type
|
||||
|
||||
for key, value in payload.items():
|
||||
if key in DEFAULT_TTS_PARAMS:
|
||||
tts_req[key] = value
|
||||
if speed is not None:
|
||||
tts_req["speed_factor"] = speed
|
||||
|
||||
ref_audio_path = resolve_project_path(tts_req.get("ref_audio_path"))
|
||||
if ref_audio_path:
|
||||
tts_req["ref_audio_path"] = ref_audio_path
|
||||
|
||||
aux_paths = tts_req.get("aux_ref_audio_paths") or []
|
||||
tts_req["aux_ref_audio_paths"] = [resolve_project_path(item) for item in aux_paths if item]
|
||||
|
||||
media_type = str(tts_req.get("media_type", "wav")).lower()
|
||||
tts_req["media_type"] = media_type
|
||||
if media_type not in SUPPORTED_MEDIA_TYPES:
|
||||
raise HTTPException(status_code=400, detail=f"format is not supported: {media_type}")
|
||||
|
||||
text = str(tts_req.get("text") or "").strip()
|
||||
if not text:
|
||||
raise HTTPException(status_code=400, detail="text is required")
|
||||
tts_req["text"] = text
|
||||
|
||||
if not tts_req.get("ref_audio_path"):
|
||||
raise HTTPException(status_code=400, detail="ref_audio_path is required in voice profile or request")
|
||||
if not Path(tts_req["ref_audio_path"]).exists():
|
||||
raise HTTPException(status_code=400, detail=f"ref_audio_path does not exist: {tts_req['ref_audio_path']}")
|
||||
|
||||
for aux_path in tts_req["aux_ref_audio_paths"]:
|
||||
if aux_path and not Path(aux_path).exists():
|
||||
raise HTTPException(status_code=400, detail=f"aux_ref_audio_path does not exist: {aux_path}")
|
||||
|
||||
if tts_config is None:
|
||||
raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
|
||||
|
||||
text_lang = str(tts_req.get("text_lang") or "").lower()
|
||||
prompt_lang = str(tts_req.get("prompt_lang") or "").lower()
|
||||
tts_req["text_lang"] = text_lang
|
||||
tts_req["prompt_lang"] = prompt_lang
|
||||
|
||||
if text_lang not in tts_config.languages:
|
||||
raise HTTPException(status_code=400, detail=f"text_lang is not supported: {text_lang}")
|
||||
if prompt_lang not in tts_config.languages:
|
||||
raise HTTPException(status_code=400, detail=f"prompt_lang is not supported: {prompt_lang}")
|
||||
if not str(tts_req.get("prompt_text") or "").strip() and prompt_text_required_for_current_model():
|
||||
version = getattr(tts_config, "version", "current")
|
||||
raise HTTPException(status_code=400, detail=f"prompt_text is required when using GPT-SoVITS {version}")
|
||||
|
||||
split_method = str(tts_req.get("text_split_method") or "cut5")
|
||||
if split_method not in cut_method_names:
|
||||
raise HTTPException(status_code=400, detail=f"text_split_method is not supported: {split_method}")
|
||||
|
||||
streaming_mode = tts_req.pop("streaming_mode", None)
|
||||
streaming_enabled, return_fragment, fixed_length_chunk, response_stream = normalize_streaming(streaming_mode, stream)
|
||||
tts_req["streaming_mode"] = streaming_enabled
|
||||
tts_req["return_fragment"] = return_fragment
|
||||
tts_req["fixed_length_chunk"] = fixed_length_chunk
|
||||
|
||||
return tts_req, media_type, response_stream
|
||||
|
||||
|
||||
def synthesize_once(payload: Dict[str, Any]) -> Tuple[bytes, str]:
|
||||
tts_req, media_type, response_stream = build_tts_request(payload)
|
||||
if response_stream:
|
||||
raise HTTPException(status_code=400, detail="base64 output does not support streaming")
|
||||
|
||||
if tts_pipeline is None:
|
||||
raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
|
||||
|
||||
try:
|
||||
with infer_lock:
|
||||
sr, audio_data = next(tts_pipeline.run(tts_req))
|
||||
audio_bytes = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
|
||||
return audio_bytes, media_type
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail={"message": "tts failed", "exception": str(exc)}) from exc
|
||||
|
||||
|
||||
def synthesize_response(payload: Dict[str, Any]) -> Response:
|
||||
tts_req, media_type, response_stream = build_tts_request(payload)
|
||||
if tts_pipeline is None:
|
||||
raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
|
||||
|
||||
if response_stream:
|
||||
|
||||
def streaming_generator() -> Generator[bytes, None, None]:
|
||||
first_chunk = True
|
||||
chunk_media_type = media_type
|
||||
with infer_lock:
|
||||
for sr, chunk in tts_pipeline.run(tts_req):
|
||||
if first_chunk and chunk_media_type == "wav":
|
||||
yield wave_header_chunk(sample_rate=sr)
|
||||
chunk_media_type = "raw"
|
||||
first_chunk = False
|
||||
yield pack_audio(BytesIO(), chunk, sr, chunk_media_type).getvalue()
|
||||
|
||||
return StreamingResponse(streaming_generator(), media_type=f"audio/{media_type}")
|
||||
|
||||
try:
|
||||
with infer_lock:
|
||||
sr, audio_data = next(tts_pipeline.run(tts_req))
|
||||
audio_bytes = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
|
||||
return Response(audio_bytes, media_type=f"audio/{media_type}")
|
||||
except Exception as exc:
|
||||
return JSONResponse(status_code=400, content={"message": "tts failed", "exception": str(exc)})
|
||||
|
||||
|
||||
@APP.on_event("startup")
|
||||
def startup() -> None:
|
||||
global tts_config, tts_pipeline
|
||||
tts_config = TTS_Config(tts_config_path)
|
||||
print(tts_config)
|
||||
tts_pipeline = TTS(tts_config)
|
||||
|
||||
|
||||
@APP.get("/", tags=["System"])
|
||||
def index() -> Dict[str, Any]:
|
||||
return {
|
||||
"name": "GPT-SoVITS Simple API",
|
||||
"version": "1.1.0",
|
||||
"docs": "/docs",
|
||||
"endpoints": {
|
||||
"system": ["/health", "/voices"],
|
||||
"mvp": ["/api/tts"],
|
||||
"profile": ["/speak", "/speak/base64", "/v1/tts"],
|
||||
"admin": ["/admin/reload-config", "/admin/weights"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@APP.get("/health", tags=["System"], summary="健康检查")
|
||||
def health() -> Dict[str, Any]:
|
||||
import os
|
||||
|
||||
result: Dict[str, Any] = {
|
||||
"status": "ok" if tts_pipeline is not None else "starting",
|
||||
"tts_config": tts_config_path,
|
||||
"version": getattr(tts_config, "version", None),
|
||||
"languages": getattr(tts_config, "languages", []),
|
||||
"pid": os.getpid(),
|
||||
}
|
||||
try:
|
||||
import psutil
|
||||
result["memory_mb"] = round(psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024, 1)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
result["gpu"] = {
|
||||
"name": torch.cuda.get_device_name(0),
|
||||
"memory_used_mb": round(torch.cuda.memory_allocated(0) / 1024 / 1024, 1),
|
||||
"memory_total_mb": round(torch.cuda.get_device_properties(0).total_mem / 1024 / 1024, 1),
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
@APP.get("/voices", tags=["System"], summary="列出可用 voice profiles")
|
||||
def list_voices() -> Dict[str, Any]:
|
||||
voices = simple_config.get("voices", {}) or {}
|
||||
return {
|
||||
"default_voice": get_default_voice_name(),
|
||||
"voices": [voice_public_info(name, profile or {}) for name, profile in voices.items()],
|
||||
}
|
||||
|
||||
|
||||
@APP.post(
|
||||
"/api/tts",
|
||||
tags=["MVP"],
|
||||
summary="核心 TTS 接口",
|
||||
description=(
|
||||
"上传参考音频和需要生成的文字,返回生成的音频。\n\n"
|
||||
"**主参考音频要求**:3-10 秒,支持 wav/flac/ogg/mp3/m4a/aac 格式。\n\n"
|
||||
"**文字切句**:固定使用 `cut5`(按标点符号切句)。\n\n"
|
||||
"**情绪预设**:neutral / happy / calm / sad / angry,本质是映射到采样和语速参数。"
|
||||
),
|
||||
)
|
||||
async def mvp_tts(
|
||||
text: str = Form(...),
|
||||
ref_audio: UploadFile = File(...),
|
||||
aux_ref_audio: Optional[List[UploadFile]] = File(default=None),
|
||||
prompt_text: str = Form(default=""),
|
||||
text_lang: str = Form(default="zh"),
|
||||
prompt_lang: str = Form(default="zh"),
|
||||
format: str = Form(default="wav"),
|
||||
emotion: Optional[str] = Form(default=None),
|
||||
speed: Optional[float] = Form(default=None),
|
||||
seed: int = Form(default=-1),
|
||||
) -> Response:
|
||||
request_dir = get_request_upload_dir()
|
||||
try:
|
||||
ref_audio_path = await save_upload_file(ref_audio, request_dir, "ref")
|
||||
validate_main_ref_duration(ref_audio_path)
|
||||
|
||||
aux_ref_audio_paths = []
|
||||
for index, upload in enumerate(aux_ref_audio or []):
|
||||
aux_path = await save_upload_file(upload, request_dir, f"aux_{index}")
|
||||
get_audio_duration_seconds(aux_path)
|
||||
aux_ref_audio_paths.append(aux_path)
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"text": text,
|
||||
"text_lang": text_lang,
|
||||
"ref_audio_path": ref_audio_path,
|
||||
"aux_ref_audio_paths": aux_ref_audio_paths,
|
||||
"prompt_text": prompt_text or "",
|
||||
"prompt_lang": prompt_lang,
|
||||
"format": format,
|
||||
"text_split_method": "cut5",
|
||||
"streaming_mode": 0,
|
||||
"seed": seed,
|
||||
}
|
||||
apply_emotion_preset(payload, emotion)
|
||||
payload["text_split_method"] = "cut5"
|
||||
payload["streaming_mode"] = 0
|
||||
if speed is not None:
|
||||
payload["speed"] = speed
|
||||
|
||||
return synthesize_response(payload)
|
||||
finally:
|
||||
shutil.rmtree(request_dir, ignore_errors=True)
|
||||
|
||||
|
||||
@APP.get("/speak", tags=["Profile"], summary="GET 方式调用 voice profile TTS")
|
||||
def speak_get(
|
||||
text: str,
|
||||
voice: Optional[str] = None,
|
||||
text_lang: Optional[str] = None,
|
||||
format: Optional[str] = None,
|
||||
stream: Optional[bool] = None,
|
||||
speed: Optional[float] = None,
|
||||
) -> Response:
|
||||
payload = {
|
||||
"text": text,
|
||||
"voice": voice,
|
||||
"text_lang": text_lang,
|
||||
"format": format,
|
||||
"stream": stream,
|
||||
"speed": speed,
|
||||
}
|
||||
return synthesize_response({k: v for k, v in payload.items() if v is not None})
|
||||
|
||||
|
||||
@APP.post("/speak", tags=["Profile"], summary="POST 方式调用 voice profile TTS")
|
||||
def speak_post(request: SpeakRequest) -> Response:
|
||||
return synthesize_response(request_to_dict(request))
|
||||
|
||||
|
||||
@APP.post("/v1/tts", tags=["Profile"], summary="OpenAI 兼容格式 TTS")
|
||||
def openai_style_tts(request: SpeakRequest) -> Response:
|
||||
return synthesize_response(request_to_dict(request))
|
||||
|
||||
|
||||
@APP.post("/speak/base64", tags=["Profile"], summary="返回 Base64 编码的音频")
|
||||
def speak_base64(request: SpeakRequest) -> Dict[str, Any]:
|
||||
audio_bytes, media_type = synthesize_once(request_to_dict(request))
|
||||
return {
|
||||
"media_type": f"audio/{media_type}",
|
||||
"audio_base64": base64.b64encode(audio_bytes).decode("ascii"),
|
||||
}
|
||||
|
||||
|
||||
@APP.post("/admin/reload-config", tags=["Admin"], summary="热加载 simple_api.yaml 配置")
|
||||
def reload_config() -> Dict[str, Any]:
|
||||
global simple_config
|
||||
simple_config = load_yaml_config(args.config)
|
||||
return {"message": "success", "default_voice": get_default_voice_name()}
|
||||
|
||||
|
||||
@APP.post("/admin/weights", tags=["Admin"], summary="运行时切换 GPT-SoVITS 模型权重")
|
||||
def set_weights(request: WeightsRequest) -> Dict[str, Any]:
|
||||
if tts_pipeline is None:
|
||||
raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
|
||||
if not request.gpt_weights_path and not request.sovits_weights_path:
|
||||
raise HTTPException(status_code=400, detail="gpt_weights_path or sovits_weights_path is required")
|
||||
|
||||
try:
|
||||
with infer_lock:
|
||||
if request.gpt_weights_path:
|
||||
tts_pipeline.init_t2s_weights(request.gpt_weights_path)
|
||||
if request.sovits_weights_path:
|
||||
tts_pipeline.init_vits_weights(request.sovits_weights_path)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail={"message": "change weights failed", "exception": str(exc)}) from exc
|
||||
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
try:
|
||||
uvicorn.run(app=APP, host=None if host == "None" else host, port=port, workers=1)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
59
simple_api.yaml
Normal file
59
simple_api.yaml
Normal file
@ -0,0 +1,59 @@
|
||||
server:
|
||||
host: 127.0.0.1
|
||||
port: 9881
|
||||
tts_config: GPT_SoVITS/configs/tts_infer.yaml
|
||||
|
||||
cors_allow_origins:
|
||||
- "*"
|
||||
|
||||
upload:
|
||||
dir: runtime/uploads
|
||||
min_ref_seconds: 3
|
||||
max_ref_seconds: 10
|
||||
max_upload_mb: 80
|
||||
|
||||
default_voice: default
|
||||
|
||||
defaults:
|
||||
text_lang: zh
|
||||
prompt_lang: zh
|
||||
media_type: wav
|
||||
text_split_method: cut5
|
||||
batch_size: 1
|
||||
batch_threshold: 0.75
|
||||
split_bucket: true
|
||||
speed_factor: 1.0
|
||||
fragment_interval: 0.3
|
||||
seed: -1
|
||||
parallel_infer: true
|
||||
repetition_penalty: 1.35
|
||||
sample_steps: 32
|
||||
super_sampling: false
|
||||
overlap_length: 2
|
||||
min_chunk_length: 16
|
||||
|
||||
emotion_presets:
|
||||
neutral: {}
|
||||
happy:
|
||||
temperature: 1.1
|
||||
top_p: 0.95
|
||||
calm:
|
||||
temperature: 0.8
|
||||
top_p: 0.85
|
||||
speed_factor: 0.92
|
||||
sad:
|
||||
temperature: 0.75
|
||||
top_p: 0.85
|
||||
speed_factor: 0.9
|
||||
angry:
|
||||
temperature: 1.2
|
||||
top_k: 20
|
||||
repetition_penalty: 1.25
|
||||
|
||||
voices:
|
||||
default:
|
||||
description: Replace this profile with your reference voice.
|
||||
ref_audio_path: reference.wav
|
||||
prompt_text: Replace this with the exact text spoken in reference.wav.
|
||||
prompt_lang: zh
|
||||
text_lang: zh
|
||||
1006
test_frontend/index.html
Normal file
1006
test_frontend/index.html
Normal file
File diff suppressed because it is too large
Load Diff
303
tests/test_simple_api_contract.py
Normal file
303
tests/test_simple_api_contract.py
Normal file
@ -0,0 +1,303 @@
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
SIMPLE_API_PATH = PROJECT_ROOT / "simple_api.py"
|
||||
|
||||
|
||||
class DummyHTTPException(Exception):
|
||||
def __init__(self, status_code=500, detail=None):
|
||||
super().__init__(detail)
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
|
||||
|
||||
class DummyFastAPI:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.routes = []
|
||||
|
||||
def add_middleware(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def mount(self, path, app, name=None):
|
||||
self.routes.append(types.SimpleNamespace(path=path, endpoint=app, name=name))
|
||||
|
||||
def get(self, path, **kwargs):
|
||||
return self._route(path)
|
||||
|
||||
def post(self, path, **kwargs):
|
||||
return self._route(path)
|
||||
|
||||
def on_event(self, event):
|
||||
return lambda func: func
|
||||
|
||||
def _route(self, path):
|
||||
def decorator(func):
|
||||
self.routes.append(types.SimpleNamespace(path=path, endpoint=func))
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class DummyBaseModel:
|
||||
def dict(self, exclude_none=False):
|
||||
data = dict(self.__dict__)
|
||||
if exclude_none:
|
||||
data = {key: value for key, value in data.items() if value is not None}
|
||||
return data
|
||||
|
||||
|
||||
class DummyTTSConfig:
|
||||
version = "v2"
|
||||
languages = ["zh", "en", "auto"]
|
||||
use_vocoder = False
|
||||
|
||||
|
||||
class FakeUpload:
|
||||
def __init__(self, filename, chunks, content_type="audio/wav"):
|
||||
self.filename = filename
|
||||
self.content_type = content_type
|
||||
self._chunks = list(chunks)
|
||||
self.closed = False
|
||||
|
||||
async def read(self, size):
|
||||
del size
|
||||
if self._chunks:
|
||||
return self._chunks.pop(0)
|
||||
return b""
|
||||
|
||||
async def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
def install_dummy_modules():
|
||||
fastapi = types.ModuleType("fastapi")
|
||||
fastapi.FastAPI = DummyFastAPI
|
||||
fastapi.File = lambda default=None, **kwargs: default
|
||||
fastapi.Form = lambda default=None, **kwargs: default
|
||||
fastapi.HTTPException = DummyHTTPException
|
||||
fastapi.Response = type("Response", (), {})
|
||||
fastapi.UploadFile = type("UploadFile", (), {})
|
||||
sys.modules["fastapi"] = fastapi
|
||||
|
||||
middleware = types.ModuleType("fastapi.middleware")
|
||||
cors = types.ModuleType("fastapi.middleware.cors")
|
||||
cors.CORSMiddleware = type("CORSMiddleware", (), {})
|
||||
sys.modules["fastapi.middleware"] = middleware
|
||||
sys.modules["fastapi.middleware.cors"] = cors
|
||||
|
||||
responses = types.ModuleType("fastapi.responses")
|
||||
responses.JSONResponse = type("JSONResponse", (), {})
|
||||
responses.StreamingResponse = type("StreamingResponse", (), {})
|
||||
sys.modules["fastapi.responses"] = responses
|
||||
|
||||
staticfiles = types.ModuleType("fastapi.staticfiles")
|
||||
staticfiles.StaticFiles = type("StaticFiles", (), {"__init__": lambda self, *args, **kwargs: None})
|
||||
sys.modules["fastapi.staticfiles"] = staticfiles
|
||||
|
||||
pydantic = types.ModuleType("pydantic")
|
||||
pydantic.BaseModel = DummyBaseModel
|
||||
pydantic.Field = lambda default=None, **kwargs: default
|
||||
sys.modules["pydantic"] = pydantic
|
||||
|
||||
numpy = types.ModuleType("numpy")
|
||||
numpy.ndarray = object
|
||||
sys.modules["numpy"] = numpy
|
||||
|
||||
soundfile = types.ModuleType("soundfile")
|
||||
soundfile.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000)
|
||||
sys.modules["soundfile"] = soundfile
|
||||
|
||||
yaml = types.ModuleType("yaml")
|
||||
yaml.safe_load = lambda stream: {
|
||||
"server": {"host": "127.0.0.1", "port": 9881, "tts_config": "dummy.yaml"},
|
||||
"upload": {"dir": "runtime/test_uploads", "min_ref_seconds": 3, "max_ref_seconds": 10, "max_upload_mb": 1},
|
||||
"defaults": {"text_lang": "zh", "prompt_lang": "zh", "media_type": "wav", "text_split_method": "cut5"},
|
||||
"emotion_presets": {"calm": {"temperature": 0.8, "speed_factor": 0.9}},
|
||||
"voices": {},
|
||||
}
|
||||
sys.modules["yaml"] = yaml
|
||||
|
||||
sys.modules["GPT_SoVITS"] = types.ModuleType("GPT_SoVITS")
|
||||
sys.modules["GPT_SoVITS.TTS_infer_pack"] = types.ModuleType("GPT_SoVITS.TTS_infer_pack")
|
||||
|
||||
tts_module = types.ModuleType("GPT_SoVITS.TTS_infer_pack.TTS")
|
||||
tts_module.TTS = type("DummyTTS", (), {})
|
||||
tts_module.TTS_Config = DummyTTSConfig
|
||||
sys.modules["GPT_SoVITS.TTS_infer_pack.TTS"] = tts_module
|
||||
|
||||
segmentation = types.ModuleType("GPT_SoVITS.TTS_infer_pack.text_segmentation_method")
|
||||
segmentation.get_method_names = lambda: ["cut0", "cut5"]
|
||||
sys.modules["GPT_SoVITS.TTS_infer_pack.text_segmentation_method"] = segmentation
|
||||
|
||||
|
||||
def load_simple_api():
|
||||
install_dummy_modules()
|
||||
sys.argv = ["simple_api.py"]
|
||||
spec = importlib.util.spec_from_file_location("simple_api_under_test", SIMPLE_API_PATH)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
module.tts_config = DummyTTSConfig()
|
||||
return module
|
||||
|
||||
|
||||
class SimpleApiContractTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.simple_api = load_simple_api()
|
||||
|
||||
def test_api_tts_route_is_registered(self):
|
||||
routes = {route.path for route in self.simple_api.APP.routes}
|
||||
self.assertIn("/api/tts", routes)
|
||||
|
||||
def test_build_request_accepts_explicit_ref_without_voice_profile(self):
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav") as ref:
|
||||
req, media_type, response_stream = self.simple_api.build_tts_request(
|
||||
{
|
||||
"text": "Hello, test.",
|
||||
"ref_audio_path": ref.name,
|
||||
"prompt_text": "",
|
||||
"prompt_lang": "zh",
|
||||
"text_lang": "zh",
|
||||
"format": "wav",
|
||||
"text_split_method": "cut5",
|
||||
"streaming_mode": 0,
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual(req["prompt_text"], "")
|
||||
self.assertEqual(req["text_split_method"], "cut5")
|
||||
self.assertEqual(req["ref_audio_path"], ref.name)
|
||||
self.assertEqual(media_type, "wav")
|
||||
self.assertFalse(response_stream)
|
||||
|
||||
def test_explicit_speed_overrides_emotion_speed_preset(self):
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav") as ref:
|
||||
payload = {
|
||||
"text": "Hello.",
|
||||
"ref_audio_path": ref.name,
|
||||
"prompt_lang": "zh",
|
||||
"text_lang": "zh",
|
||||
"format": "wav",
|
||||
"streaming_mode": 0,
|
||||
"speed": 1.05,
|
||||
}
|
||||
self.simple_api.apply_emotion_preset(payload, "calm")
|
||||
payload["speed"] = 1.05
|
||||
req, _, _ = self.simple_api.build_tts_request(payload)
|
||||
|
||||
self.assertEqual(req["temperature"], 0.8)
|
||||
self.assertEqual(req["speed_factor"], 1.05)
|
||||
|
||||
def test_empty_prompt_text_is_rejected_for_vocoder_models(self):
|
||||
original_use_vocoder = self.simple_api.tts_config.use_vocoder
|
||||
original_version = self.simple_api.tts_config.version
|
||||
self.simple_api.tts_config.use_vocoder = True
|
||||
self.simple_api.tts_config.version = "v3"
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav") as ref:
|
||||
with self.assertRaises(DummyHTTPException) as exc:
|
||||
self.simple_api.build_tts_request(
|
||||
{
|
||||
"text": "Hello.",
|
||||
"ref_audio_path": ref.name,
|
||||
"prompt_text": "",
|
||||
"prompt_lang": "zh",
|
||||
"text_lang": "zh",
|
||||
"format": "wav",
|
||||
"text_split_method": "cut5",
|
||||
"streaming_mode": 0,
|
||||
}
|
||||
)
|
||||
self.assertEqual(exc.exception.status_code, 400)
|
||||
self.assertIn("prompt_text is required", str(exc.exception.detail))
|
||||
finally:
|
||||
self.simple_api.tts_config.use_vocoder = original_use_vocoder
|
||||
self.simple_api.tts_config.version = original_version
|
||||
|
||||
def test_ref_audio_duration_limits(self):
|
||||
self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000)
|
||||
self.simple_api.validate_main_ref_duration("ok.wav")
|
||||
|
||||
self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 2, samplerate=16000)
|
||||
with self.assertRaises(DummyHTTPException) as too_short:
|
||||
self.simple_api.validate_main_ref_duration("short.wav")
|
||||
self.assertEqual(too_short.exception.status_code, 400)
|
||||
|
||||
self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 11, samplerate=16000)
|
||||
with self.assertRaises(DummyHTTPException) as too_long:
|
||||
self.simple_api.validate_main_ref_duration("long.wav")
|
||||
self.assertEqual(too_long.exception.status_code, 400)
|
||||
|
||||
def test_save_upload_file_writes_and_closes(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
upload = FakeUpload("ref.wav", [b"abc", b"def"])
|
||||
saved_path = asyncio.run(self.simple_api.save_upload_file(upload, Path(tmp), "ref"))
|
||||
|
||||
self.assertTrue(upload.closed)
|
||||
self.assertEqual(Path(saved_path).read_bytes(), b"abcdef")
|
||||
|
||||
def test_mvp_tts_builds_payload_from_uploads_and_cleans_tmp_dir(self):
|
||||
captured = {}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
request_dir = Path(tmp) / "request"
|
||||
request_dir.mkdir()
|
||||
|
||||
original_get_request_upload_dir = self.simple_api.get_request_upload_dir
|
||||
original_synthesize_response = self.simple_api.synthesize_response
|
||||
original_sf_info = self.simple_api.sf.info
|
||||
|
||||
def fake_synthesize_response(payload):
|
||||
captured.update(payload)
|
||||
return "audio-response"
|
||||
|
||||
try:
|
||||
self.simple_api.get_request_upload_dir = lambda: request_dir
|
||||
self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000)
|
||||
self.simple_api.synthesize_response = fake_synthesize_response
|
||||
|
||||
response = asyncio.run(
|
||||
self.simple_api.mvp_tts(
|
||||
text="Hello, test.",
|
||||
ref_audio=FakeUpload("ref.wav", [b"main"]),
|
||||
aux_ref_audio=[FakeUpload("aux.wav", [b"aux"])],
|
||||
prompt_text="",
|
||||
text_lang="zh",
|
||||
prompt_lang="zh",
|
||||
format="wav",
|
||||
emotion="calm",
|
||||
speed=1.05,
|
||||
seed=123,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
self.simple_api.get_request_upload_dir = original_get_request_upload_dir
|
||||
self.simple_api.synthesize_response = original_synthesize_response
|
||||
self.simple_api.sf.info = original_sf_info
|
||||
|
||||
self.assertEqual(response, "audio-response")
|
||||
self.assertFalse(request_dir.exists())
|
||||
|
||||
self.assertEqual(captured["text"], "Hello, test.")
|
||||
self.assertEqual(captured["prompt_text"], "")
|
||||
self.assertEqual(captured["text_lang"], "zh")
|
||||
self.assertEqual(captured["prompt_lang"], "zh")
|
||||
self.assertEqual(captured["format"], "wav")
|
||||
self.assertEqual(captured["text_split_method"], "cut5")
|
||||
self.assertEqual(captured["streaming_mode"], 0)
|
||||
self.assertEqual(captured["seed"], 123)
|
||||
self.assertEqual(captured["temperature"], 0.8)
|
||||
self.assertEqual(captured["speed"], 1.05)
|
||||
self.assertEqual(len(captured["aux_ref_audio_paths"]), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
x
Reference in New Issue
Block a user