mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-07-03 20:48:14 +08:00
feat: add simple API layer with video support and test frontend
- Add simple_api.py: profile-based API that wraps GPT-SoVITS TTS engine - Add /api/tts endpoint for MVP: accepts ref audio/video, text, optional aux audio - Frontend auto-extracts audio from uploaded video files via Web Audio API - Add emotion presets (neutral/happy/calm/sad/angry) with speed customization - Add test_frontend/index.html with health check, audio playback, and download - Add contract tests (7 tests, all passing) using mock TTS pipeline - Add documentation: simple_api.md (full tutorial), simple_api_quickstart.md - Add startup scripts: go-simple-api.ps1, go-simple-api.bat, open-test-frontend.ps1 - Add soundfile and python-multipart to requirements.txt - Text splitting fixed to cut5 (punctuation-based) per MVP spec
This commit is contained in:
parent
08d627c333
commit
735b2e3554
21
README.md
21
README.md
@ -1,4 +1,4 @@
|
||||
<div align="center">
|
||||
<div align="center">
|
||||
|
||||
<h1>GPT-SoVITS-WebUI</h1>
|
||||
A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.<br><br>
|
||||
@ -27,6 +27,24 @@ A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.<br><br>
|
||||
|
||||
---
|
||||
|
||||
## 简化接口 / 测试前端
|
||||
|
||||
本工作区新增了一个用于 MVP 调用的中间层接口和测试前端:
|
||||
|
||||
- 快速启动:[docs/simple_api_quickstart.md](./docs/simple_api_quickstart.md)
|
||||
- 完整教程:[docs/simple_api.md](./docs/simple_api.md)
|
||||
- 后端入口:simple_api.py
|
||||
- 测试前端:启动后访问 http://127.0.0.1:9881/test/
|
||||
|
||||
启动命令:
|
||||
|
||||
`powershell
|
||||
cd D:\tts\GPT-SoVITS
|
||||
.\go-simple-api.ps1
|
||||
`
|
||||
|
||||
---
|
||||
|
||||
## Features:
|
||||
|
||||
1. **Zero-shot TTS:** Input a 5-second vocal sample and experience instant text-to-speech conversion.
|
||||
@ -478,3 +496,4 @@ Thankful to @Naozumi520 for providing the Cantonese training set and for the gui
|
||||
<a href="https://github.com/RVC-Boss/GPT-SoVITS/graphs/contributors" target="_blank">
|
||||
<img src="https://contrib.rocks/image?repo=RVC-Boss/GPT-SoVITS" />
|
||||
</a>
|
||||
|
||||
|
||||
320
docs/simple_api.md
Normal file
320
docs/simple_api.md
Normal file
@ -0,0 +1,320 @@
|
||||
# GPT-SoVITS 简化接口教程
|
||||
|
||||
本项目原本已经有 `api_v2.py`,但调用时需要传很多 GPT-SoVITS 参数。新增的 `simple_api.py` 是一个中间层,目标是让前端或业务方用更简单的方式调用。
|
||||
|
||||
当前 MVP 推荐使用:
|
||||
|
||||
```http
|
||||
POST /api/tts
|
||||
Content-Type: multipart/form-data
|
||||
```
|
||||
|
||||
适合你的流程:
|
||||
|
||||
1. 前端从视频中提取音频(或直接上传已裁剪的 3-10 秒音频)。
|
||||
2. 用户人工裁剪主参考音频到 3-10 秒。
|
||||
3. 前端把主参考音频、要生成的文字、可选辅助音频提交给后端。
|
||||
4. 后端调用 GPT-SoVITS 生成音频并返回。
|
||||
|
||||
## 1. 安装依赖
|
||||
|
||||
在 GPT-SoVITS 使用的 Python 环境里运行:
|
||||
|
||||
```bash
|
||||
python -m pip install -r requirements.txt
|
||||
```
|
||||
|
||||
本中间层额外需要:
|
||||
|
||||
- `python-multipart`:用于接收前端上传的音频文件。
|
||||
- `soundfile`:用于读取音频信息和校验参考音频时长。
|
||||
|
||||
这两个依赖已经写入 `requirements.txt`。
|
||||
|
||||
## 2. 检查配置
|
||||
|
||||
配置文件:
|
||||
|
||||
```text
|
||||
simple_api.yaml
|
||||
```
|
||||
|
||||
常用配置:
|
||||
|
||||
```yaml
|
||||
server:
|
||||
host: 127.0.0.1
|
||||
port: 9881
|
||||
tts_config: GPT_SoVITS/configs/tts_infer.yaml
|
||||
|
||||
upload:
|
||||
dir: runtime/uploads
|
||||
min_ref_seconds: 3
|
||||
max_ref_seconds: 10
|
||||
max_upload_mb: 80
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- `server.host`:后端监听地址。
|
||||
- `server.port`:后端端口。
|
||||
- `server.tts_config`:GPT-SoVITS 推理配置文件。
|
||||
- `upload.dir`:临时上传目录。
|
||||
- `upload.min_ref_seconds`:主参考音频最短秒数。
|
||||
- `upload.max_ref_seconds`:主参考音频最长秒数。
|
||||
- `upload.max_upload_mb`:单个上传音频最大体积。
|
||||
|
||||
当前默认后端地址:
|
||||
|
||||
```text
|
||||
http://127.0.0.1:9881
|
||||
```
|
||||
|
||||
## 3. 启动后端
|
||||
|
||||
进入项目目录:
|
||||
|
||||
```powershell
|
||||
cd D:\tts\GPT-SoVITS
|
||||
```
|
||||
|
||||
方式一:
|
||||
|
||||
```powershell
|
||||
python simple_api.py -c simple_api.yaml
|
||||
```
|
||||
|
||||
方式二,Windows PowerShell:
|
||||
|
||||
```powershell
|
||||
.\go-simple-api.ps1
|
||||
```
|
||||
|
||||
方式三,Windows 批处理:
|
||||
|
||||
```bat
|
||||
go-simple-api.bat
|
||||
```
|
||||
|
||||
启动成功后,后端会监听:
|
||||
|
||||
```text
|
||||
http://127.0.0.1:9881
|
||||
```
|
||||
|
||||
## 4. 健康检查
|
||||
|
||||
后端启动后,可以访问:
|
||||
|
||||
```text
|
||||
http://127.0.0.1:9881/health
|
||||
```
|
||||
|
||||
或者命令行测试:
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:9881/health
|
||||
```
|
||||
|
||||
返回示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "ok",
|
||||
"tts_config": "GPT_SoVITS/configs/tts_infer.yaml",
|
||||
"version": "v2",
|
||||
"languages": ["auto", "auto_yue", "en", "zh"]
|
||||
}
|
||||
```
|
||||
|
||||
## 5. 打开测试前端
|
||||
|
||||
后端启动后,推荐直接打开:
|
||||
|
||||
```text
|
||||
http://127.0.0.1:9881/test/
|
||||
```
|
||||
|
||||
这个页面已经挂载在后端服务里,不需要额外启动前端服务。
|
||||
|
||||
也可以直接打开本地文件:
|
||||
|
||||
```text
|
||||
test_frontend/index.html
|
||||
```
|
||||
|
||||
Windows 下也可以运行:
|
||||
|
||||
```powershell
|
||||
.\open-test-frontend.ps1
|
||||
```
|
||||
|
||||
测试前端里有一个“后端接口地址”输入框。
|
||||
|
||||
默认值:
|
||||
|
||||
```text
|
||||
http://127.0.0.1:9881/api/tts
|
||||
```
|
||||
|
||||
如果你修改了 `server.host` 或 `server.port`,记得同步修改页面里的接口地址。
|
||||
|
||||
## 6. MVP 接口
|
||||
|
||||
接口地址:
|
||||
|
||||
```http
|
||||
POST /api/tts
|
||||
```
|
||||
|
||||
请求类型:
|
||||
|
||||
```http
|
||||
multipart/form-data
|
||||
```
|
||||
|
||||
字段说明:
|
||||
|
||||
```text
|
||||
text 必填。需要生成的文字。
|
||||
ref_audio 必填。主参考音频(支持上传视频,前端会自动提取音频),要求 3-10 秒。
|
||||
aux_ref_audio 可选。辅助参考音频,可以上传多个。
|
||||
prompt_text 可选。主参考音频对应文字,可以留空。
|
||||
text_lang 可选。生成文字语言,默认 zh。
|
||||
prompt_lang 可选。参考音频语言,默认 zh。即使 prompt_text 为空,也需要传给 GPT-SoVITS 内部。
|
||||
format 可选。返回格式,默认 wav。
|
||||
emotion 可选。情绪 preset:neutral、happy、calm、sad、angry。
|
||||
speed 可选。语速,对应 GPT-SoVITS 的 speed_factor。
|
||||
seed 可选。随机种子,默认 -1。
|
||||
```
|
||||
|
||||
## 7. curl 调用示例
|
||||
|
||||
PowerShell 示例:
|
||||
|
||||
```powershell
|
||||
curl.exe -X POST http://127.0.0.1:9881/api/tts `
|
||||
-F "text=你好,欢迎使用这个声音。" `
|
||||
-F "ref_audio=@D:\audio\ref.wav" `
|
||||
-F "prompt_text=" `
|
||||
-F "text_lang=zh" `
|
||||
-F "prompt_lang=zh" `
|
||||
-F "emotion=neutral" `
|
||||
--output output.wav
|
||||
```
|
||||
|
||||
如果有辅助参考音频:
|
||||
|
||||
```powershell
|
||||
curl.exe -X POST http://127.0.0.1:9881/api/tts `
|
||||
-F "text=你好,欢迎使用这个声音。" `
|
||||
-F "ref_audio=@D:\audio\ref.wav" `
|
||||
-F "aux_ref_audio=@D:\audio\aux1.wav" `
|
||||
-F "aux_ref_audio=@D:\audio\aux2.wav" `
|
||||
-F "prompt_text=" `
|
||||
-F "text_lang=zh" `
|
||||
-F "prompt_lang=zh" `
|
||||
--output output.wav
|
||||
```
|
||||
|
||||
## 8. 重要规则
|
||||
|
||||
- 主参考音频必须是 3-10 秒。
|
||||
- `aux_ref_audio` 是可选项。
|
||||
- `prompt_text` 可以为空,但当前主要针对 GPT-SoVITS v2。
|
||||
- 如果切到 GPT-SoVITS v3/v4,空 `prompt_text` 会被中间层直接返回 400。
|
||||
- 生成文字固定使用 `cut5`,也就是按照标点符号切句。
|
||||
- `emotion` 目前是轻量 preset,本质是映射到采样和语速参数;更稳定的情绪控制仍然依赖带情绪的参考音频。
|
||||
|
||||
## 9. 测试前端使用步骤
|
||||
|
||||
1. 启动后端。
|
||||
2. 打开 `http://127.0.0.1:9881/test/`。
|
||||
3. 检查“后端接口地址”是否为:
|
||||
|
||||
```text
|
||||
http://127.0.0.1:9881/api/tts
|
||||
```
|
||||
|
||||
4. 填写“需要生成的文字”。
|
||||
5. 上传主参考音频。
|
||||
6. 可选上传辅助参考音频。
|
||||
7. 可选填写参考音频文字。
|
||||
8. 选择语言、情绪、语速。
|
||||
9. 点击“生成音频”。
|
||||
10. 页面右侧会显示返回结果,可以在线播放和下载。
|
||||
|
||||
## 10. 其他接口
|
||||
|
||||
除了 MVP 上传接口,还保留了 profile 调用接口:
|
||||
|
||||
```http
|
||||
GET /voices
|
||||
GET /speak?text=hello&voice=default
|
||||
POST /speak
|
||||
POST /speak/base64
|
||||
POST /v1/tts
|
||||
POST /admin/reload-config
|
||||
POST /admin/weights
|
||||
```
|
||||
|
||||
这些接口适合后续做固定音色 profile,不是当前 MVP 的主流程。
|
||||
|
||||
## 11. Base64 返回示例
|
||||
|
||||
```bash
|
||||
curl -X POST http://127.0.0.1:9881/speak/base64 ^
|
||||
-H "Content-Type: application/json" ^
|
||||
-d "{\"text\":\"hello\",\"voice\":\"default\"}"
|
||||
```
|
||||
|
||||
返回格式:
|
||||
|
||||
```json
|
||||
{
|
||||
"media_type": "audio/wav",
|
||||
"audio_base64": "..."
|
||||
}
|
||||
```
|
||||
|
||||
## 12. 添加固定音色 profile
|
||||
|
||||
如果后续要做固定音色,可以编辑 `simple_api.yaml`:
|
||||
|
||||
```yaml
|
||||
voices:
|
||||
default:
|
||||
ref_audio_path: reference.wav
|
||||
prompt_text: exact transcript
|
||||
prompt_lang: zh
|
||||
text_lang: zh
|
||||
|
||||
narrator:
|
||||
ref_audio_path: voices/narrator.wav
|
||||
prompt_text: exact transcript of narrator.wav
|
||||
prompt_lang: zh
|
||||
text_lang: zh
|
||||
```
|
||||
|
||||
编辑后调用:
|
||||
|
||||
```bash
|
||||
curl -X POST http://127.0.0.1:9881/admin/reload-config
|
||||
```
|
||||
|
||||
## 13. 契约测试
|
||||
|
||||
这个测试使用 mock,不会加载 GPT-SoVITS 模型:
|
||||
|
||||
```bash
|
||||
python -m unittest tests.test_simple_api_contract
|
||||
```
|
||||
|
||||
用于确认:
|
||||
|
||||
- `/api/tts` 路由存在。
|
||||
- 上传接口能构造正确参数。
|
||||
- 主参考音频 3-10 秒校验正常。
|
||||
- 空 `prompt_text` 在 v2 可用。
|
||||
- v3/v4 空 `prompt_text` 会返回 400。
|
||||
- 临时上传目录会被清理。
|
||||
89
docs/simple_api_quickstart.md
Normal file
89
docs/simple_api_quickstart.md
Normal file
@ -0,0 +1,89 @@
|
||||
# 简化 TTS 接口快速启动
|
||||
|
||||
完整教程见:`docs/simple_api.md`
|
||||
|
||||
## 一句话流程
|
||||
|
||||
启动后端,打开测试页,上传 3-10 秒参考音频或视频(视频会自动提取音频),填写生成文字,点击生成。
|
||||
|
||||
## 1. 启动后端
|
||||
|
||||
```powershell
|
||||
cd D:\tts\GPT-SoVITS
|
||||
python -m pip install -r requirements.txt
|
||||
.\go-simple-api.ps1
|
||||
```
|
||||
|
||||
也可以运行:
|
||||
|
||||
```bat
|
||||
go-simple-api.bat
|
||||
```
|
||||
|
||||
默认后端地址:
|
||||
|
||||
```text
|
||||
http://127.0.0.1:9881
|
||||
```
|
||||
|
||||
## 2. 打开测试前端
|
||||
|
||||
后端启动后访问:
|
||||
|
||||
```text
|
||||
http://127.0.0.1:9881/test/
|
||||
```
|
||||
|
||||
测试页面里的默认接口地址是:
|
||||
|
||||
```text
|
||||
http://127.0.0.1:9881/api/tts
|
||||
```
|
||||
|
||||
## 3. 调用接口
|
||||
|
||||
接口:
|
||||
|
||||
```http
|
||||
POST /api/tts
|
||||
Content-Type: multipart/form-data
|
||||
```
|
||||
|
||||
最小字段:
|
||||
|
||||
```text
|
||||
text 要生成的文字
|
||||
ref_audio 3-10 秒主参考音频(支持视频,前端自动提取音频)
|
||||
```
|
||||
|
||||
常用可选字段:
|
||||
|
||||
```text
|
||||
aux_ref_audio 辅助参考音频,可多个
|
||||
prompt_text 参考音频文本,可留空
|
||||
text_lang 默认 zh
|
||||
prompt_lang 默认 zh
|
||||
emotion neutral / happy / calm / sad / angry
|
||||
speed 语速,默认 1
|
||||
seed 默认 -1
|
||||
```
|
||||
|
||||
## 4. PowerShell 示例
|
||||
|
||||
```powershell
|
||||
curl.exe -X POST http://127.0.0.1:9881/api/tts `
|
||||
-F "text=你好,欢迎使用这个声音。" `
|
||||
-F "ref_audio=@D:\audio\ref.wav" `
|
||||
-F "prompt_text=" `
|
||||
-F "text_lang=zh" `
|
||||
-F "prompt_lang=zh" `
|
||||
--output output.wav
|
||||
```
|
||||
|
||||
## 5. 注意事项
|
||||
|
||||
- 主参考音频必须是 3-10 秒。
|
||||
- `prompt_text` 在当前 v2 配置下可以为空。
|
||||
- 如果切换到 v3/v4,空 `prompt_text` 会返回 400。
|
||||
- 生成文字固定按标点符号切句。
|
||||
- 更详细的配置、profile、base64 接口见 `docs/simple_api.md`。
|
||||
12
go-simple-api.bat
Normal file
12
go-simple-api.bat
Normal file
@ -0,0 +1,12 @@
|
||||
@echo off
|
||||
set "SCRIPT_DIR=%~dp0"
|
||||
set "SCRIPT_DIR=%SCRIPT_DIR:~0,-1%"
|
||||
cd /d "%SCRIPT_DIR%"
|
||||
|
||||
if exist "%SCRIPT_DIR%\runtime\python.exe" (
|
||||
"%SCRIPT_DIR%\runtime\python.exe" "%SCRIPT_DIR%\simple_api.py" -c "%SCRIPT_DIR%\simple_api.yaml" %*
|
||||
) else (
|
||||
python "%SCRIPT_DIR%\simple_api.py" -c "%SCRIPT_DIR%\simple_api.yaml" %*
|
||||
)
|
||||
|
||||
pause
|
||||
13
go-simple-api.ps1
Normal file
13
go-simple-api.ps1
Normal file
@ -0,0 +1,13 @@
|
||||
$ErrorActionPreference = "Stop"
|
||||
chcp 65001
|
||||
Set-Location $PSScriptRoot
|
||||
|
||||
$runtimePython = Join-Path $PSScriptRoot "runtime\python.exe"
|
||||
$apiScript = Join-Path $PSScriptRoot "simple_api.py"
|
||||
$apiConfig = Join-Path $PSScriptRoot "simple_api.yaml"
|
||||
|
||||
if (Test-Path $runtimePython) {
|
||||
& $runtimePython $apiScript -c $apiConfig @args
|
||||
} else {
|
||||
python $apiScript -c $apiConfig @args
|
||||
}
|
||||
3
open-test-frontend.ps1
Normal file
3
open-test-frontend.ps1
Normal file
@ -0,0 +1,3 @@
|
||||
$ErrorActionPreference = "Stop"
|
||||
Set-Location $PSScriptRoot
|
||||
Start-Process (Join-Path $PSScriptRoot "test_frontend\index.html")
|
||||
@ -3,6 +3,7 @@ numpy<2.0
|
||||
scipy
|
||||
tensorboard
|
||||
librosa==0.10.2
|
||||
soundfile
|
||||
numba
|
||||
pytorch-lightning>=2.4
|
||||
gradio<5
|
||||
@ -22,6 +23,7 @@ transformers>=4.43,<=4.50
|
||||
peft<0.18.0
|
||||
chardet
|
||||
PyYAML
|
||||
python-multipart
|
||||
psutil
|
||||
jieba_fast
|
||||
jieba
|
||||
|
||||
741
simple_api.py
Normal file
741
simple_api.py
Normal file
@ -0,0 +1,741 @@
|
||||
"""
|
||||
Small profile-based API layer for GPT-SoVITS.
|
||||
|
||||
Run:
|
||||
python simple_api.py -c simple_api.yaml
|
||||
|
||||
Then call:
|
||||
POST http://127.0.0.1:9881/speak
|
||||
{"text": "hello", "voice": "default"}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
import uuid
|
||||
import wave
|
||||
from copy import deepcopy
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import yaml
|
||||
from fastapi import FastAPI, File, Form, HTTPException, Response, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent
|
||||
os.chdir(PROJECT_ROOT)
|
||||
sys.path.append(str(PROJECT_ROOT))
|
||||
sys.path.append(str(PROJECT_ROOT / "GPT_SoVITS"))
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402
|
||||
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import ( # noqa: E402
|
||||
get_method_names as get_cut_method_names,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_TTS_PARAMS: Dict[str, Any] = {
|
||||
"text": "",
|
||||
"text_lang": "zh",
|
||||
"ref_audio_path": "",
|
||||
"aux_ref_audio_paths": [],
|
||||
"prompt_text": "",
|
||||
"prompt_lang": "zh",
|
||||
"top_k": 15,
|
||||
"top_p": 1.0,
|
||||
"temperature": 1.0,
|
||||
"text_split_method": "cut5",
|
||||
"batch_size": 1,
|
||||
"batch_threshold": 0.75,
|
||||
"split_bucket": True,
|
||||
"speed_factor": 1.0,
|
||||
"fragment_interval": 0.3,
|
||||
"seed": -1,
|
||||
"media_type": "wav",
|
||||
"streaming_mode": False,
|
||||
"return_fragment": False,
|
||||
"fixed_length_chunk": False,
|
||||
"parallel_infer": True,
|
||||
"repetition_penalty": 1.35,
|
||||
"sample_steps": 32,
|
||||
"super_sampling": False,
|
||||
"overlap_length": 2,
|
||||
"min_chunk_length": 16,
|
||||
}
|
||||
|
||||
SUPPORTED_MEDIA_TYPES = {"wav", "raw", "ogg", "aac"}
|
||||
SUPPORTED_UPLOAD_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a", ".aac"}
|
||||
VOICE_META_KEYS = {"description"}
|
||||
|
||||
|
||||
class SpeakRequest(BaseModel):
|
||||
text: str
|
||||
voice: Optional[str] = None
|
||||
text_lang: Optional[str] = None
|
||||
ref_audio_path: Optional[str] = None
|
||||
aux_ref_audio_paths: Optional[List[str]] = None
|
||||
prompt_text: Optional[str] = None
|
||||
prompt_lang: Optional[str] = None
|
||||
format: Optional[str] = Field(default=None, description="wav, raw, ogg, or aac")
|
||||
stream: Optional[bool] = Field(default=None, description="true maps to streaming mode 2")
|
||||
streaming_mode: Optional[Union[bool, int]] = Field(default=None, description="0, 1, 2, 3, true, or false")
|
||||
speed: Optional[float] = Field(default=None, description="Alias of speed_factor")
|
||||
speed_factor: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
temperature: Optional[float] = None
|
||||
text_split_method: Optional[str] = None
|
||||
batch_size: Optional[int] = None
|
||||
batch_threshold: Optional[float] = None
|
||||
split_bucket: Optional[bool] = None
|
||||
fragment_interval: Optional[float] = None
|
||||
seed: Optional[int] = None
|
||||
parallel_infer: Optional[bool] = None
|
||||
repetition_penalty: Optional[float] = None
|
||||
sample_steps: Optional[int] = None
|
||||
super_sampling: Optional[bool] = None
|
||||
overlap_length: Optional[int] = None
|
||||
min_chunk_length: Optional[int] = None
|
||||
|
||||
|
||||
class WeightsRequest(BaseModel):
|
||||
gpt_weights_path: Optional[str] = None
|
||||
sovits_weights_path: Optional[str] = None
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS simple API")
|
||||
parser.add_argument("-c", "--config", default="simple_api.yaml", help="simple API config path")
|
||||
parser.add_argument("--tts-config", default=None, help="GPT-SoVITS tts_infer.yaml path")
|
||||
parser.add_argument("-a", "--bind-addr", default=None, help="bind address")
|
||||
parser.add_argument("-p", "--port", type=int, default=None, help="bind port")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
|
||||
path = Path(config_path)
|
||||
if not path.is_absolute():
|
||||
path = PROJECT_ROOT / path
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"simple API config not found: {path}")
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("simple API config must be a YAML object")
|
||||
return data
|
||||
|
||||
|
||||
args = parse_args()
|
||||
simple_config = load_yaml_config(args.config)
|
||||
server_config = simple_config.get("server", {}) or {}
|
||||
host = args.bind_addr or server_config.get("host", "127.0.0.1")
|
||||
port = args.port or int(server_config.get("port", 9881))
|
||||
tts_config_path = args.tts_config or server_config.get("tts_config", "GPT_SoVITS/configs/tts_infer.yaml")
|
||||
|
||||
cut_method_names = get_cut_method_names()
|
||||
tts_config: Optional[TTS_Config] = None
|
||||
tts_pipeline: Optional[TTS] = None
|
||||
infer_lock = threading.Lock()
|
||||
|
||||
APP = FastAPI(
|
||||
title="GPT-SoVITS Simple API",
|
||||
description="Profile-based API layer that hides GPT-SoVITS request details.",
|
||||
version="1.0.0",
|
||||
)
|
||||
APP.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=simple_config.get("cors_allow_origins", ["*"]),
|
||||
allow_credentials=False,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
test_frontend_dir = PROJECT_ROOT / "test_frontend"
|
||||
if test_frontend_dir.exists():
|
||||
APP.mount("/test", StaticFiles(directory=str(test_frontend_dir), html=True), name="test_frontend")
|
||||
|
||||
|
||||
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
|
||||
def handle_pack_ogg() -> None:
|
||||
with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
|
||||
audio_file.write(data)
|
||||
|
||||
try:
|
||||
threading.stack_size(4096 * 4096)
|
||||
pack_thread = threading.Thread(target=handle_pack_ogg)
|
||||
pack_thread.start()
|
||||
pack_thread.join()
|
||||
except (RuntimeError, ValueError):
|
||||
handle_pack_ogg()
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
|
||||
del rate
|
||||
io_buffer.write(data.tobytes())
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
|
||||
del io_buffer
|
||||
wav_buffer = BytesIO()
|
||||
sf.write(wav_buffer, data, rate, format="wav")
|
||||
return wav_buffer
|
||||
|
||||
|
||||
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
|
||||
process = subprocess.Popen(
|
||||
[
|
||||
"ffmpeg",
|
||||
"-f",
|
||||
"s16le",
|
||||
"-ar",
|
||||
str(rate),
|
||||
"-ac",
|
||||
"1",
|
||||
"-i",
|
||||
"pipe:0",
|
||||
"-c:a",
|
||||
"aac",
|
||||
"-b:a",
|
||||
"192k",
|
||||
"-vn",
|
||||
"-f",
|
||||
"adts",
|
||||
"pipe:1",
|
||||
],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
out, _ = process.communicate(input=data.tobytes())
|
||||
io_buffer.write(out)
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str) -> BytesIO:
|
||||
if media_type == "ogg":
|
||||
io_buffer = pack_ogg(io_buffer, data, rate)
|
||||
elif media_type == "aac":
|
||||
io_buffer = pack_aac(io_buffer, data, rate)
|
||||
elif media_type == "wav":
|
||||
io_buffer = pack_wav(io_buffer, data, rate)
|
||||
else:
|
||||
io_buffer = pack_raw(io_buffer, data, rate)
|
||||
io_buffer.seek(0)
|
||||
return io_buffer
|
||||
|
||||
|
||||
def wave_header_chunk(frame_input: bytes = b"", channels: int = 1, sample_width: int = 2, sample_rate: int = 32000) -> bytes:
|
||||
wav_buf = BytesIO()
|
||||
with wave.open(wav_buf, "wb") as vfout:
|
||||
vfout.setnchannels(channels)
|
||||
vfout.setsampwidth(sample_width)
|
||||
vfout.setframerate(sample_rate)
|
||||
vfout.writeframes(frame_input)
|
||||
wav_buf.seek(0)
|
||||
return wav_buf.read()
|
||||
|
||||
|
||||
def request_to_dict(request: BaseModel) -> Dict[str, Any]:
|
||||
if hasattr(request, "model_dump"):
|
||||
return request.model_dump(exclude_none=True)
|
||||
return request.dict(exclude_none=True)
|
||||
|
||||
|
||||
def get_default_voice_name() -> Optional[str]:
|
||||
default_voice = simple_config.get("default_voice")
|
||||
if default_voice:
|
||||
return str(default_voice)
|
||||
voices = simple_config.get("voices", {}) or {}
|
||||
if "default" in voices:
|
||||
return "default"
|
||||
return next(iter(voices.keys()), None)
|
||||
|
||||
|
||||
def resolve_project_path(path_value: Optional[str]) -> Optional[str]:
|
||||
if path_value in [None, ""]:
|
||||
return None
|
||||
path = Path(str(path_value))
|
||||
if not path.is_absolute():
|
||||
path = PROJECT_ROOT / path
|
||||
return str(path)
|
||||
|
||||
|
||||
def normalize_streaming(streaming_mode: Optional[Union[bool, int]], stream: Optional[bool]) -> Tuple[bool, bool, bool, bool]:
|
||||
if streaming_mode is None:
|
||||
streaming_mode = 2 if stream else 0
|
||||
elif isinstance(streaming_mode, bool):
|
||||
streaming_mode = 2 if streaming_mode else 0
|
||||
|
||||
if streaming_mode == 0:
|
||||
return False, False, False, False
|
||||
if streaming_mode == 1:
|
||||
return False, True, False, True
|
||||
if streaming_mode == 2:
|
||||
return True, False, False, True
|
||||
if streaming_mode == 3:
|
||||
return True, False, True, True
|
||||
raise HTTPException(status_code=400, detail="streaming_mode must be 0, 1, 2, 3, true, or false")
|
||||
|
||||
|
||||
def get_upload_config() -> Dict[str, Any]:
|
||||
return simple_config.get("upload", {}) or {}
|
||||
|
||||
|
||||
def get_upload_root() -> Path:
|
||||
upload_dir = str(get_upload_config().get("dir", "runtime/uploads"))
|
||||
root = Path(upload_dir)
|
||||
if not root.is_absolute():
|
||||
root = PROJECT_ROOT / root
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
def get_request_upload_dir() -> Path:
|
||||
request_dir = get_upload_root() / uuid.uuid4().hex
|
||||
request_dir.mkdir(parents=True, exist_ok=True)
|
||||
return request_dir
|
||||
|
||||
|
||||
def get_upload_limits() -> Tuple[float, float, int]:
|
||||
upload_config = get_upload_config()
|
||||
min_seconds = float(upload_config.get("min_ref_seconds", 3.0))
|
||||
max_seconds = float(upload_config.get("max_ref_seconds", 10.0))
|
||||
max_upload_mb = int(upload_config.get("max_upload_mb", 80))
|
||||
return min_seconds, max_seconds, max_upload_mb
|
||||
|
||||
|
||||
def get_upload_suffix(upload: UploadFile) -> str:
|
||||
suffix = Path(upload.filename or "").suffix.lower()
|
||||
if suffix:
|
||||
return suffix
|
||||
content_type = (upload.content_type or "").lower()
|
||||
if content_type in {"audio/wav", "audio/x-wav", "audio/wave"}:
|
||||
return ".wav"
|
||||
if content_type == "audio/mpeg":
|
||||
return ".mp3"
|
||||
if content_type == "audio/ogg":
|
||||
return ".ogg"
|
||||
if content_type in {"audio/aac", "audio/aacp"}:
|
||||
return ".aac"
|
||||
return ".wav"
|
||||
|
||||
|
||||
def validate_audio_upload(upload: UploadFile) -> None:
|
||||
suffix = get_upload_suffix(upload)
|
||||
content_type = (upload.content_type or "").lower()
|
||||
is_audio_type = content_type.startswith("audio/") or content_type in {"application/octet-stream", ""}
|
||||
if suffix not in SUPPORTED_UPLOAD_EXTENSIONS or not is_audio_type:
|
||||
raise HTTPException(status_code=400, detail=f"unsupported audio upload: {upload.filename or content_type}")
|
||||
|
||||
|
||||
async def save_upload_file(upload: UploadFile, target_dir: Path, name_prefix: str) -> str:
|
||||
validate_audio_upload(upload)
|
||||
_, _, max_upload_mb = get_upload_limits()
|
||||
max_bytes = max_upload_mb * 1024 * 1024
|
||||
suffix = get_upload_suffix(upload)
|
||||
target_path = target_dir / f"{name_prefix}{suffix}"
|
||||
total_size = 0
|
||||
|
||||
try:
|
||||
with target_path.open("wb") as f:
|
||||
while True:
|
||||
chunk = await upload.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
total_size += len(chunk)
|
||||
if total_size > max_bytes:
|
||||
raise HTTPException(status_code=400, detail=f"audio upload exceeds {max_upload_mb} MB")
|
||||
f.write(chunk)
|
||||
if total_size == 0:
|
||||
raise HTTPException(status_code=400, detail=f"audio upload is empty: {upload.filename or name_prefix}")
|
||||
finally:
|
||||
await upload.close()
|
||||
|
||||
return str(target_path)
|
||||
|
||||
|
||||
def get_audio_duration_with_ffprobe(audio_path: str) -> float:
|
||||
result = subprocess.run(
|
||||
[
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-show_entries",
|
||||
"format=duration",
|
||||
"-of",
|
||||
"default=noprint_wrappers=1:nokey=1",
|
||||
audio_path,
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=15,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(result.stderr.strip() or "ffprobe failed")
|
||||
return float(result.stdout.strip())
|
||||
|
||||
|
||||
def get_audio_duration_seconds(audio_path: str) -> float:
|
||||
try:
|
||||
info = sf.info(audio_path)
|
||||
if info.samplerate <= 0:
|
||||
raise ValueError("invalid sample rate")
|
||||
return float(info.frames) / float(info.samplerate)
|
||||
except Exception as sf_exc:
|
||||
try:
|
||||
return get_audio_duration_with_ffprobe(audio_path)
|
||||
except Exception as ffprobe_exc:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"unable to read audio duration: {audio_path}. soundfile={sf_exc}; ffprobe={ffprobe_exc}",
|
||||
) from ffprobe_exc
|
||||
|
||||
|
||||
def validate_main_ref_duration(audio_path: str) -> None:
|
||||
min_seconds, max_seconds, _ = get_upload_limits()
|
||||
duration = get_audio_duration_seconds(audio_path)
|
||||
if duration < min_seconds or duration > max_seconds:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"ref_audio duration must be between {min_seconds:g}s and {max_seconds:g}s, got {duration:.2f}s",
|
||||
)
|
||||
|
||||
|
||||
def apply_emotion_preset(payload: Dict[str, Any], emotion: Optional[str]) -> None:
|
||||
if emotion in [None, ""]:
|
||||
return
|
||||
emotion_name = str(emotion).strip().lower()
|
||||
emotion_presets = simple_config.get("emotion_presets", {}) or {}
|
||||
if emotion_name not in emotion_presets:
|
||||
supported = ", ".join(sorted(emotion_presets.keys())) or "none"
|
||||
raise HTTPException(status_code=400, detail=f"emotion is not configured: {emotion_name}. supported: {supported}")
|
||||
payload.update(emotion_presets[emotion_name] or {})
|
||||
|
||||
|
||||
def prompt_text_required_for_current_model() -> bool:
|
||||
if tts_config is None:
|
||||
return False
|
||||
use_vocoder = getattr(tts_config, "use_vocoder", None)
|
||||
if use_vocoder is not None:
|
||||
return bool(use_vocoder)
|
||||
return getattr(tts_config, "version", None) in {"v3", "v4"}
|
||||
|
||||
|
||||
def voice_public_info(name: str, profile: Dict[str, Any]) -> Dict[str, Any]:
|
||||
ref_audio_path = resolve_project_path(profile.get("ref_audio_path"))
|
||||
return {
|
||||
"name": name,
|
||||
"description": profile.get("description", ""),
|
||||
"text_lang": profile.get("text_lang"),
|
||||
"prompt_lang": profile.get("prompt_lang"),
|
||||
"ref_audio_path": profile.get("ref_audio_path"),
|
||||
"ready": bool(ref_audio_path and Path(ref_audio_path).exists() and profile.get("prompt_lang")),
|
||||
}
|
||||
|
||||
|
||||
def build_tts_request(payload: Dict[str, Any]) -> Tuple[Dict[str, Any], str, bool]:
|
||||
voices = simple_config.get("voices", {}) or {}
|
||||
explicit_ref_audio = bool(payload.get("ref_audio_path"))
|
||||
voice_name = payload.pop("voice", None)
|
||||
if voice_name is None and not explicit_ref_audio:
|
||||
voice_name = get_default_voice_name()
|
||||
voice_profile: Dict[str, Any] = {}
|
||||
|
||||
if voice_name:
|
||||
if voice_name not in voices:
|
||||
raise HTTPException(status_code=404, detail=f"voice not found: {voice_name}")
|
||||
voice_profile = deepcopy(voices[voice_name] or {})
|
||||
|
||||
tts_req = deepcopy(DEFAULT_TTS_PARAMS)
|
||||
tts_req.update(simple_config.get("defaults", {}) or {})
|
||||
tts_req.update({k: v for k, v in voice_profile.items() if k not in VOICE_META_KEYS})
|
||||
|
||||
stream = payload.pop("stream", None)
|
||||
media_type = payload.pop("format", None)
|
||||
speed = payload.pop("speed", None)
|
||||
if media_type is not None:
|
||||
tts_req["media_type"] = media_type
|
||||
|
||||
for key, value in payload.items():
|
||||
if key in DEFAULT_TTS_PARAMS:
|
||||
tts_req[key] = value
|
||||
if speed is not None:
|
||||
tts_req["speed_factor"] = speed
|
||||
|
||||
ref_audio_path = resolve_project_path(tts_req.get("ref_audio_path"))
|
||||
if ref_audio_path:
|
||||
tts_req["ref_audio_path"] = ref_audio_path
|
||||
|
||||
aux_paths = tts_req.get("aux_ref_audio_paths") or []
|
||||
tts_req["aux_ref_audio_paths"] = [resolve_project_path(item) for item in aux_paths if item]
|
||||
|
||||
media_type = str(tts_req.get("media_type", "wav")).lower()
|
||||
tts_req["media_type"] = media_type
|
||||
if media_type not in SUPPORTED_MEDIA_TYPES:
|
||||
raise HTTPException(status_code=400, detail=f"format is not supported: {media_type}")
|
||||
|
||||
text = str(tts_req.get("text") or "").strip()
|
||||
if not text:
|
||||
raise HTTPException(status_code=400, detail="text is required")
|
||||
tts_req["text"] = text
|
||||
|
||||
if not tts_req.get("ref_audio_path"):
|
||||
raise HTTPException(status_code=400, detail="ref_audio_path is required in voice profile or request")
|
||||
if not Path(tts_req["ref_audio_path"]).exists():
|
||||
raise HTTPException(status_code=400, detail=f"ref_audio_path does not exist: {tts_req['ref_audio_path']}")
|
||||
|
||||
for aux_path in tts_req["aux_ref_audio_paths"]:
|
||||
if aux_path and not Path(aux_path).exists():
|
||||
raise HTTPException(status_code=400, detail=f"aux_ref_audio_path does not exist: {aux_path}")
|
||||
|
||||
if tts_config is None:
|
||||
raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
|
||||
|
||||
text_lang = str(tts_req.get("text_lang") or "").lower()
|
||||
prompt_lang = str(tts_req.get("prompt_lang") or "").lower()
|
||||
tts_req["text_lang"] = text_lang
|
||||
tts_req["prompt_lang"] = prompt_lang
|
||||
|
||||
if text_lang not in tts_config.languages:
|
||||
raise HTTPException(status_code=400, detail=f"text_lang is not supported: {text_lang}")
|
||||
if prompt_lang not in tts_config.languages:
|
||||
raise HTTPException(status_code=400, detail=f"prompt_lang is not supported: {prompt_lang}")
|
||||
if not str(tts_req.get("prompt_text") or "").strip() and prompt_text_required_for_current_model():
|
||||
version = getattr(tts_config, "version", "current")
|
||||
raise HTTPException(status_code=400, detail=f"prompt_text is required when using GPT-SoVITS {version}")
|
||||
|
||||
split_method = str(tts_req.get("text_split_method") or "cut5")
|
||||
if split_method not in cut_method_names:
|
||||
raise HTTPException(status_code=400, detail=f"text_split_method is not supported: {split_method}")
|
||||
|
||||
streaming_mode = tts_req.pop("streaming_mode", None)
|
||||
streaming_enabled, return_fragment, fixed_length_chunk, response_stream = normalize_streaming(streaming_mode, stream)
|
||||
tts_req["streaming_mode"] = streaming_enabled
|
||||
tts_req["return_fragment"] = return_fragment
|
||||
tts_req["fixed_length_chunk"] = fixed_length_chunk
|
||||
|
||||
return tts_req, media_type, response_stream
|
||||
|
||||
|
||||
def synthesize_once(payload: Dict[str, Any]) -> Tuple[bytes, str]:
|
||||
tts_req, media_type, response_stream = build_tts_request(payload)
|
||||
if response_stream:
|
||||
raise HTTPException(status_code=400, detail="base64 output does not support streaming")
|
||||
|
||||
if tts_pipeline is None:
|
||||
raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
|
||||
|
||||
try:
|
||||
with infer_lock:
|
||||
sr, audio_data = next(tts_pipeline.run(tts_req))
|
||||
audio_bytes = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
|
||||
return audio_bytes, media_type
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail={"message": "tts failed", "exception": str(exc)}) from exc
|
||||
|
||||
|
||||
def synthesize_response(payload: Dict[str, Any]) -> Response:
|
||||
tts_req, media_type, response_stream = build_tts_request(payload)
|
||||
if tts_pipeline is None:
|
||||
raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
|
||||
|
||||
if response_stream:
|
||||
|
||||
def streaming_generator() -> Generator[bytes, None, None]:
|
||||
first_chunk = True
|
||||
chunk_media_type = media_type
|
||||
with infer_lock:
|
||||
for sr, chunk in tts_pipeline.run(tts_req):
|
||||
if first_chunk and chunk_media_type == "wav":
|
||||
yield wave_header_chunk(sample_rate=sr)
|
||||
chunk_media_type = "raw"
|
||||
first_chunk = False
|
||||
yield pack_audio(BytesIO(), chunk, sr, chunk_media_type).getvalue()
|
||||
|
||||
return StreamingResponse(streaming_generator(), media_type=f"audio/{media_type}")
|
||||
|
||||
try:
|
||||
with infer_lock:
|
||||
sr, audio_data = next(tts_pipeline.run(tts_req))
|
||||
audio_bytes = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
|
||||
return Response(audio_bytes, media_type=f"audio/{media_type}")
|
||||
except Exception as exc:
|
||||
return JSONResponse(status_code=400, content={"message": "tts failed", "exception": str(exc)})
|
||||
|
||||
|
||||
@APP.on_event("startup")
|
||||
def startup() -> None:
|
||||
global tts_config, tts_pipeline
|
||||
tts_config = TTS_Config(tts_config_path)
|
||||
print(tts_config)
|
||||
tts_pipeline = TTS(tts_config)
|
||||
|
||||
|
||||
@APP.get("/")
|
||||
def index() -> Dict[str, Any]:
|
||||
return {
|
||||
"name": "GPT-SoVITS Simple API",
|
||||
"endpoints": [
|
||||
"/health",
|
||||
"/voices",
|
||||
"/api/tts",
|
||||
"/speak",
|
||||
"/speak/base64",
|
||||
"/admin/weights",
|
||||
"/admin/reload-config",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@APP.get("/health")
|
||||
def health() -> Dict[str, Any]:
|
||||
return {
|
||||
"status": "ok" if tts_pipeline is not None else "starting",
|
||||
"tts_config": tts_config_path,
|
||||
"version": getattr(tts_config, "version", None),
|
||||
"languages": getattr(tts_config, "languages", []),
|
||||
}
|
||||
|
||||
|
||||
@APP.get("/voices")
|
||||
def list_voices() -> Dict[str, Any]:
|
||||
voices = simple_config.get("voices", {}) or {}
|
||||
return {
|
||||
"default_voice": get_default_voice_name(),
|
||||
"voices": [voice_public_info(name, profile or {}) for name, profile in voices.items()],
|
||||
}
|
||||
|
||||
|
||||
@APP.post("/api/tts")
|
||||
async def mvp_tts(
|
||||
text: str = Form(...),
|
||||
ref_audio: UploadFile = File(...),
|
||||
aux_ref_audio: Optional[List[UploadFile]] = File(default=None),
|
||||
prompt_text: str = Form(default=""),
|
||||
text_lang: str = Form(default="zh"),
|
||||
prompt_lang: str = Form(default="zh"),
|
||||
format: str = Form(default="wav"),
|
||||
emotion: Optional[str] = Form(default=None),
|
||||
speed: Optional[float] = Form(default=None),
|
||||
seed: int = Form(default=-1),
|
||||
) -> Response:
|
||||
request_dir = get_request_upload_dir()
|
||||
try:
|
||||
ref_audio_path = await save_upload_file(ref_audio, request_dir, "ref")
|
||||
validate_main_ref_duration(ref_audio_path)
|
||||
|
||||
aux_ref_audio_paths = []
|
||||
for index, upload in enumerate(aux_ref_audio or []):
|
||||
aux_path = await save_upload_file(upload, request_dir, f"aux_{index}")
|
||||
get_audio_duration_seconds(aux_path)
|
||||
aux_ref_audio_paths.append(aux_path)
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"text": text,
|
||||
"text_lang": text_lang,
|
||||
"ref_audio_path": ref_audio_path,
|
||||
"aux_ref_audio_paths": aux_ref_audio_paths,
|
||||
"prompt_text": prompt_text or "",
|
||||
"prompt_lang": prompt_lang,
|
||||
"format": format,
|
||||
"text_split_method": "cut5",
|
||||
"streaming_mode": 0,
|
||||
"seed": seed,
|
||||
}
|
||||
apply_emotion_preset(payload, emotion)
|
||||
payload["text_split_method"] = "cut5"
|
||||
payload["streaming_mode"] = 0
|
||||
if speed is not None:
|
||||
payload["speed"] = speed
|
||||
|
||||
return synthesize_response(payload)
|
||||
finally:
|
||||
shutil.rmtree(request_dir, ignore_errors=True)
|
||||
|
||||
|
||||
@APP.get("/speak")
|
||||
def speak_get(
|
||||
text: str,
|
||||
voice: Optional[str] = None,
|
||||
text_lang: Optional[str] = None,
|
||||
format: Optional[str] = None,
|
||||
stream: Optional[bool] = None,
|
||||
speed: Optional[float] = None,
|
||||
) -> Response:
|
||||
payload = {
|
||||
"text": text,
|
||||
"voice": voice,
|
||||
"text_lang": text_lang,
|
||||
"format": format,
|
||||
"stream": stream,
|
||||
"speed": speed,
|
||||
}
|
||||
return synthesize_response({k: v for k, v in payload.items() if v is not None})
|
||||
|
||||
|
||||
@APP.post("/speak")
|
||||
def speak_post(request: SpeakRequest) -> Response:
|
||||
return synthesize_response(request_to_dict(request))
|
||||
|
||||
|
||||
@APP.post("/v1/tts")
|
||||
def openai_style_tts(request: SpeakRequest) -> Response:
|
||||
return synthesize_response(request_to_dict(request))
|
||||
|
||||
|
||||
@APP.post("/speak/base64")
|
||||
def speak_base64(request: SpeakRequest) -> Dict[str, Any]:
|
||||
audio_bytes, media_type = synthesize_once(request_to_dict(request))
|
||||
return {
|
||||
"media_type": f"audio/{media_type}",
|
||||
"audio_base64": base64.b64encode(audio_bytes).decode("ascii"),
|
||||
}
|
||||
|
||||
|
||||
@APP.post("/admin/reload-config")
|
||||
def reload_config() -> Dict[str, Any]:
|
||||
global simple_config
|
||||
simple_config = load_yaml_config(args.config)
|
||||
return {"message": "success", "default_voice": get_default_voice_name()}
|
||||
|
||||
|
||||
@APP.post("/admin/weights")
|
||||
def set_weights(request: WeightsRequest) -> Dict[str, Any]:
|
||||
if tts_pipeline is None:
|
||||
raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
|
||||
if not request.gpt_weights_path and not request.sovits_weights_path:
|
||||
raise HTTPException(status_code=400, detail="gpt_weights_path or sovits_weights_path is required")
|
||||
|
||||
try:
|
||||
with infer_lock:
|
||||
if request.gpt_weights_path:
|
||||
tts_pipeline.init_t2s_weights(request.gpt_weights_path)
|
||||
if request.sovits_weights_path:
|
||||
tts_pipeline.init_vits_weights(request.sovits_weights_path)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail={"message": "change weights failed", "exception": str(exc)}) from exc
|
||||
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
try:
|
||||
uvicorn.run(app=APP, host=None if host == "None" else host, port=port, workers=1)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
59
simple_api.yaml
Normal file
59
simple_api.yaml
Normal file
@ -0,0 +1,59 @@
|
||||
server:
|
||||
host: 127.0.0.1
|
||||
port: 9881
|
||||
tts_config: GPT_SoVITS/configs/tts_infer.yaml
|
||||
|
||||
cors_allow_origins:
|
||||
- "*"
|
||||
|
||||
upload:
|
||||
dir: runtime/uploads
|
||||
min_ref_seconds: 3
|
||||
max_ref_seconds: 10
|
||||
max_upload_mb: 80
|
||||
|
||||
default_voice: default
|
||||
|
||||
defaults:
|
||||
text_lang: zh
|
||||
prompt_lang: zh
|
||||
media_type: wav
|
||||
text_split_method: cut5
|
||||
batch_size: 1
|
||||
batch_threshold: 0.75
|
||||
split_bucket: true
|
||||
speed_factor: 1.0
|
||||
fragment_interval: 0.3
|
||||
seed: -1
|
||||
parallel_infer: true
|
||||
repetition_penalty: 1.35
|
||||
sample_steps: 32
|
||||
super_sampling: false
|
||||
overlap_length: 2
|
||||
min_chunk_length: 16
|
||||
|
||||
emotion_presets:
|
||||
neutral: {}
|
||||
happy:
|
||||
temperature: 1.1
|
||||
top_p: 0.95
|
||||
calm:
|
||||
temperature: 0.8
|
||||
top_p: 0.85
|
||||
speed_factor: 0.92
|
||||
sad:
|
||||
temperature: 0.75
|
||||
top_p: 0.85
|
||||
speed_factor: 0.9
|
||||
angry:
|
||||
temperature: 1.2
|
||||
top_k: 20
|
||||
repetition_penalty: 1.25
|
||||
|
||||
voices:
|
||||
default:
|
||||
description: Replace this profile with your reference voice.
|
||||
ref_audio_path: reference.wav
|
||||
prompt_text: Replace this with the exact text spoken in reference.wav.
|
||||
prompt_lang: zh
|
||||
text_lang: zh
|
||||
713
test_frontend/index.html
Normal file
713
test_frontend/index.html
Normal file
@ -0,0 +1,713 @@
|
||||
<!doctype html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>GPT-SoVITS API Test</title>
|
||||
<style>
|
||||
:root {
|
||||
color-scheme: light;
|
||||
--bg: #f5f7f4;
|
||||
--panel: #ffffff;
|
||||
--ink: #18201d;
|
||||
--muted: #64706b;
|
||||
--line: #d8ded9;
|
||||
--accent: #19745f;
|
||||
--accent-strong: #0f5f4c;
|
||||
--warn: #a15d12;
|
||||
--bad: #b42318;
|
||||
--good: #18794e;
|
||||
--shadow: 0 18px 50px rgba(31, 44, 38, 0.12);
|
||||
font-family: Inter, ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
margin: 0;
|
||||
min-height: 100vh;
|
||||
background:
|
||||
linear-gradient(135deg, rgba(25, 116, 95, 0.08), rgba(248, 251, 247, 0) 42%),
|
||||
var(--bg);
|
||||
color: var(--ink);
|
||||
}
|
||||
|
||||
main {
|
||||
width: min(1180px, calc(100vw - 32px));
|
||||
margin: 0 auto;
|
||||
padding: 28px 0 44px;
|
||||
}
|
||||
|
||||
header {
|
||||
display: grid;
|
||||
grid-template-columns: minmax(0, 1fr) auto;
|
||||
align-items: end;
|
||||
gap: 20px;
|
||||
padding: 8px 0 22px;
|
||||
border-bottom: 1px solid var(--line);
|
||||
}
|
||||
|
||||
h1 {
|
||||
margin: 0;
|
||||
font-size: clamp(28px, 4vw, 54px);
|
||||
line-height: 1.02;
|
||||
letter-spacing: 0;
|
||||
}
|
||||
|
||||
.sub {
|
||||
margin: 12px 0 0;
|
||||
color: var(--muted);
|
||||
max-width: 760px;
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.status {
|
||||
min-width: 172px;
|
||||
padding: 12px 14px;
|
||||
border: 1px solid var(--line);
|
||||
border-radius: 8px;
|
||||
background: rgba(255, 255, 255, 0.68);
|
||||
color: var(--muted);
|
||||
text-align: right;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.status strong {
|
||||
display: block;
|
||||
color: var(--ink);
|
||||
font-size: 16px;
|
||||
margin-bottom: 2px;
|
||||
}
|
||||
|
||||
.workspace {
|
||||
display: grid;
|
||||
grid-template-columns: minmax(0, 1.05fr) minmax(320px, 0.65fr);
|
||||
gap: 22px;
|
||||
padding-top: 24px;
|
||||
align-items: start;
|
||||
}
|
||||
|
||||
section {
|
||||
background: var(--panel);
|
||||
border: 1px solid var(--line);
|
||||
border-radius: 8px;
|
||||
box-shadow: var(--shadow);
|
||||
}
|
||||
|
||||
.form {
|
||||
padding: 22px;
|
||||
}
|
||||
|
||||
.result {
|
||||
padding: 20px;
|
||||
position: sticky;
|
||||
top: 18px;
|
||||
}
|
||||
|
||||
.grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(2, minmax(0, 1fr));
|
||||
gap: 16px;
|
||||
}
|
||||
|
||||
.full {
|
||||
grid-column: 1 / -1;
|
||||
}
|
||||
|
||||
label {
|
||||
display: block;
|
||||
color: var(--ink);
|
||||
font-weight: 650;
|
||||
font-size: 14px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
input,
|
||||
textarea,
|
||||
select {
|
||||
width: 100%;
|
||||
border: 1px solid var(--line);
|
||||
border-radius: 8px;
|
||||
background: #fbfcfb;
|
||||
color: var(--ink);
|
||||
font: inherit;
|
||||
padding: 12px;
|
||||
outline: none;
|
||||
transition: border-color 150ms ease, box-shadow 150ms ease, background 150ms ease;
|
||||
}
|
||||
|
||||
input:focus,
|
||||
textarea:focus,
|
||||
select:focus {
|
||||
border-color: rgba(25, 116, 95, 0.72);
|
||||
box-shadow: 0 0 0 4px rgba(25, 116, 95, 0.12);
|
||||
background: #fff;
|
||||
}
|
||||
|
||||
textarea {
|
||||
min-height: 142px;
|
||||
resize: vertical;
|
||||
line-height: 1.55;
|
||||
}
|
||||
|
||||
.hint {
|
||||
margin: 7px 0 0;
|
||||
color: var(--muted);
|
||||
font-size: 13px;
|
||||
line-height: 1.45;
|
||||
}
|
||||
|
||||
.file-line {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
min-height: 24px;
|
||||
color: var(--muted);
|
||||
font-size: 13px;
|
||||
margin-top: 8px;
|
||||
}
|
||||
|
||||
.duration-ok {
|
||||
color: var(--good);
|
||||
}
|
||||
|
||||
.duration-warn {
|
||||
color: var(--warn);
|
||||
}
|
||||
|
||||
.actions {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
margin-top: 18px;
|
||||
padding-top: 18px;
|
||||
border-top: 1px solid var(--line);
|
||||
}
|
||||
|
||||
button,
|
||||
.download {
|
||||
border: 0;
|
||||
border-radius: 8px;
|
||||
min-height: 44px;
|
||||
padding: 0 16px;
|
||||
background: var(--accent);
|
||||
color: #fff;
|
||||
font-weight: 700;
|
||||
font-size: 14px;
|
||||
cursor: pointer;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
text-decoration: none;
|
||||
transition: transform 140ms ease, background 140ms ease, opacity 140ms ease;
|
||||
}
|
||||
|
||||
button:hover,
|
||||
.download:hover {
|
||||
background: var(--accent-strong);
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
button:disabled,
|
||||
.download[aria-disabled="true"] {
|
||||
opacity: 0.55;
|
||||
cursor: not-allowed;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
.ghost {
|
||||
background: #e9efeb;
|
||||
color: var(--ink);
|
||||
}
|
||||
|
||||
.ghost:hover {
|
||||
background: #dde7e1;
|
||||
}
|
||||
|
||||
.result h2 {
|
||||
margin: 0 0 12px;
|
||||
font-size: 20px;
|
||||
letter-spacing: 0;
|
||||
}
|
||||
|
||||
.log {
|
||||
min-height: 126px;
|
||||
border-radius: 8px;
|
||||
border: 1px solid var(--line);
|
||||
background: #101815;
|
||||
color: #d8f3e9;
|
||||
padding: 13px;
|
||||
font-family: ui-monospace, SFMono-Regular, Consolas, "Liberation Mono", monospace;
|
||||
font-size: 12px;
|
||||
line-height: 1.6;
|
||||
white-space: pre-wrap;
|
||||
overflow-wrap: anywhere;
|
||||
}
|
||||
|
||||
audio {
|
||||
width: 100%;
|
||||
margin-top: 16px;
|
||||
}
|
||||
|
||||
.meta {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(2, minmax(0, 1fr));
|
||||
gap: 10px;
|
||||
margin: 16px 0;
|
||||
}
|
||||
|
||||
.metric {
|
||||
border: 1px solid var(--line);
|
||||
border-radius: 8px;
|
||||
padding: 12px;
|
||||
background: #fbfcfb;
|
||||
}
|
||||
|
||||
.metric span {
|
||||
display: block;
|
||||
color: var(--muted);
|
||||
font-size: 12px;
|
||||
margin-bottom: 3px;
|
||||
}
|
||||
|
||||
.metric strong {
|
||||
font-size: 15px;
|
||||
}
|
||||
|
||||
.danger {
|
||||
color: var(--bad);
|
||||
}
|
||||
|
||||
@media (max-width: 860px) {
|
||||
header,
|
||||
.workspace,
|
||||
.grid {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
|
||||
.status {
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.result {
|
||||
position: static;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<main>
|
||||
<header>
|
||||
<div>
|
||||
<h1>GPT-SoVITS 接口测试台</h1>
|
||||
<p class="sub">选择 3-10 秒参考音频或视频(视频会自动提取音频),填写后端接口地址和生成文本,直接调用中间层 <code>/api/tts</code>。</p>
|
||||
</div>
|
||||
<div class="status" id="statusBox">
|
||||
<strong>未检测</strong>
|
||||
后端连接状态
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<div class="workspace">
|
||||
<section class="form">
|
||||
<form id="ttsForm">
|
||||
<div class="grid">
|
||||
<div class="full">
|
||||
<label for="endpoint">后端接口地址</label>
|
||||
<input id="endpoint" name="endpoint" type="url" value="http://127.0.0.1:9881/api/tts" required>
|
||||
<p class="hint">如果后端端口或主机变了,在这里改完整地址。页面会把表单直接 POST 到这个地址。</p>
|
||||
</div>
|
||||
|
||||
<div class="full">
|
||||
<label for="text">需要生成的文字</label>
|
||||
<textarea id="text" name="text" placeholder="输入要生成的文字,后端固定按标点符号切句。" required></textarea>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label for="refAudio">主参考音频/视频</label>
|
||||
<input id="refAudio" name="ref_audio" type="file" accept="audio/*,video/*" required>
|
||||
<div class="file-line" id="refInfo">请选择 3-10 秒音频或视频(视频会自动提取音频)</div>
|
||||
<div class="file-line" id="extractInfo" style="display:none;"></div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label for="auxAudio">辅助参考音频</label>
|
||||
<input id="auxAudio" name="aux_ref_audio" type="file" accept="audio/*" multiple>
|
||||
<div class="file-line" id="auxInfo">可选,可多选</div>
|
||||
</div>
|
||||
|
||||
<div class="full">
|
||||
<label for="promptText">参考音频文字</label>
|
||||
<textarea id="promptText" name="prompt_text" placeholder="可留空。v2 支持空参考文字;v3/v4 后端会要求填写。"></textarea>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label for="textLang">生成文字语言</label>
|
||||
<select id="textLang" name="text_lang">
|
||||
<option value="zh">zh</option>
|
||||
<option value="en">en</option>
|
||||
<option value="ja">ja</option>
|
||||
<option value="ko">ko</option>
|
||||
<option value="yue">yue</option>
|
||||
<option value="auto">auto</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label for="promptLang">参考音频语言</label>
|
||||
<select id="promptLang" name="prompt_lang">
|
||||
<option value="zh">zh</option>
|
||||
<option value="en">en</option>
|
||||
<option value="ja">ja</option>
|
||||
<option value="ko">ko</option>
|
||||
<option value="yue">yue</option>
|
||||
<option value="auto">auto</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label for="emotion">情绪 preset</label>
|
||||
<select id="emotion" name="emotion">
|
||||
<option value="neutral">neutral</option>
|
||||
<option value="happy">happy</option>
|
||||
<option value="calm">calm</option>
|
||||
<option value="sad">sad</option>
|
||||
<option value="angry">angry</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label for="speed">语速</label>
|
||||
<input id="speed" name="speed" type="number" min="0.5" max="2" step="0.05" value="1">
|
||||
<p class="hint">显式语速会覆盖情绪 preset 中的语速。</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label for="seed">Seed</label>
|
||||
<input id="seed" name="seed" type="number" value="-1">
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label for="format">返回格式</label>
|
||||
<select id="format" name="format">
|
||||
<option value="wav">wav</option>
|
||||
<option value="ogg">ogg</option>
|
||||
<option value="aac">aac</option>
|
||||
<option value="raw">raw</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="actions">
|
||||
<button type="button" class="ghost" id="healthBtn">检测后端</button>
|
||||
<button type="submit" id="submitBtn">生成音频</button>
|
||||
<button type="button" class="ghost" id="resetBtn">清空结果</button>
|
||||
</div>
|
||||
</form>
|
||||
</section>
|
||||
|
||||
<section class="result">
|
||||
<h2>返回结果</h2>
|
||||
<div class="meta">
|
||||
<div class="metric"><span>耗时</span><strong id="elapsed">-</strong></div>
|
||||
<div class="metric"><span>文件大小</span><strong id="fileSize">-</strong></div>
|
||||
</div>
|
||||
<div class="log" id="log">等待请求。</div>
|
||||
<audio id="player" controls hidden></audio>
|
||||
<div class="actions">
|
||||
<a class="download" id="downloadLink" aria-disabled="true">下载音频</a>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
</main>
|
||||
|
||||
<script>
|
||||
const form = document.querySelector("#ttsForm");
|
||||
const endpoint = document.querySelector("#endpoint");
|
||||
const refAudio = document.querySelector("#refAudio");
|
||||
const auxAudio = document.querySelector("#auxAudio");
|
||||
const refInfo = document.querySelector("#refInfo");
|
||||
const auxInfo = document.querySelector("#auxInfo");
|
||||
const logBox = document.querySelector("#log");
|
||||
const player = document.querySelector("#player");
|
||||
const downloadLink = document.querySelector("#downloadLink");
|
||||
const submitBtn = document.querySelector("#submitBtn");
|
||||
const resetBtn = document.querySelector("#resetBtn");
|
||||
const healthBtn = document.querySelector("#healthBtn");
|
||||
const statusBox = document.querySelector("#statusBox");
|
||||
const elapsed = document.querySelector("#elapsed");
|
||||
const fileSize = document.querySelector("#fileSize");
|
||||
let resultUrl = null;
|
||||
|
||||
if (location.protocol === "http:" || location.protocol === "https:") {
|
||||
endpoint.value = new URL("/api/tts", location.origin).toString();
|
||||
}
|
||||
|
||||
function log(message, isError = false) {
|
||||
logBox.textContent = message;
|
||||
logBox.classList.toggle("danger", isError);
|
||||
}
|
||||
|
||||
function bytesLabel(bytes) {
|
||||
if (!bytes) return "-";
|
||||
const units = ["B", "KB", "MB", "GB"];
|
||||
let value = bytes;
|
||||
let index = 0;
|
||||
while (value >= 1024 && index < units.length - 1) {
|
||||
value /= 1024;
|
||||
index += 1;
|
||||
}
|
||||
return `${value.toFixed(index === 0 ? 0 : 2)} ${units[index]}`;
|
||||
}
|
||||
|
||||
function apiBaseUrl() {
|
||||
try {
|
||||
const url = new URL(endpoint.value.trim());
|
||||
url.pathname = url.pathname.replace(/\/api\/tts\/?$/, "/health");
|
||||
url.search = "";
|
||||
url.hash = "";
|
||||
return url.toString();
|
||||
} catch {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
function clearResult() {
|
||||
if (resultUrl) URL.revokeObjectURL(resultUrl);
|
||||
resultUrl = null;
|
||||
player.hidden = true;
|
||||
player.removeAttribute("src");
|
||||
downloadLink.removeAttribute("href");
|
||||
downloadLink.setAttribute("aria-disabled", "true");
|
||||
elapsed.textContent = "-";
|
||||
fileSize.textContent = "-";
|
||||
log("等待请求。");
|
||||
}
|
||||
|
||||
let extractedAudioBlob = null;
|
||||
|
||||
function isVideoFile(file) {
|
||||
return file && file.type && file.type.startsWith("video/");
|
||||
}
|
||||
|
||||
async function extractAudioFromVideo(file) {
|
||||
const extractInfo = document.querySelector("#extractInfo");
|
||||
extractInfo.style.display = "block";
|
||||
extractInfo.textContent = "正在从视频中提取音频...";
|
||||
extractInfo.className = "file-line";
|
||||
|
||||
try {
|
||||
const video = document.createElement("video");
|
||||
video.preload = "auto";
|
||||
const videoUrl = URL.createObjectURL(file);
|
||||
video.src = videoUrl;
|
||||
|
||||
await new Promise((resolve, reject) => {
|
||||
video.onloadeddata = resolve;
|
||||
video.onerror = () => reject(new Error("无法加载视频文件"));
|
||||
});
|
||||
|
||||
const audioCtx = new (window.AudioContext || window.webkitAudioContext)();
|
||||
const response = await fetch(videoUrl);
|
||||
const arrayBuffer = await response.arrayBuffer();
|
||||
const audioBuffer = await audioCtx.decodeAudioData(arrayBuffer);
|
||||
|
||||
const wavBlob = audioBufferToWav(audioBuffer);
|
||||
URL.revokeObjectURL(videoUrl);
|
||||
audioCtx.close();
|
||||
|
||||
extractedAudioBlob = wavBlob;
|
||||
const duration = audioBuffer.duration;
|
||||
const ok = Number.isFinite(duration) && duration >= 3 && duration <= 10;
|
||||
extractInfo.textContent = `已提取音频 · ${duration.toFixed(2)}s · ${bytesLabel(wavBlob.size)}${ok ? " ✓" : " ⚠ 建议裁剪到 3-10 秒"}`;
|
||||
extractInfo.className = `file-line ${ok ? "duration-ok" : "duration-warn"}`;
|
||||
return true;
|
||||
} catch (err) {
|
||||
extractInfo.textContent = `提取失败:${err.message}`;
|
||||
extractInfo.className = "file-line duration-warn";
|
||||
extractedAudioBlob = null;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
function audioBufferToWav(buffer) {
|
||||
const numChannels = buffer.numberOfChannels;
|
||||
const sampleRate = buffer.sampleRate;
|
||||
const format = 1;
|
||||
const bitDepth = 16;
|
||||
const bytesPerSample = bitDepth / 8;
|
||||
const blockAlign = numChannels * bytesPerSample;
|
||||
const dataLength = buffer.length * blockAlign;
|
||||
const headerLength = 44;
|
||||
const totalLength = headerLength + dataLength;
|
||||
const arrayBuffer = new ArrayBuffer(totalLength);
|
||||
const view = new DataView(arrayBuffer);
|
||||
|
||||
function writeString(offset, str) {
|
||||
for (let i = 0; i < str.length; i++) view.setUint8(offset + i, str.charCodeAt(i));
|
||||
}
|
||||
|
||||
writeString(0, "RIFF");
|
||||
view.setUint32(4, totalLength - 8, true);
|
||||
writeString(8, "WAVE");
|
||||
writeString(12, "fmt ");
|
||||
view.setUint32(16, 16, true);
|
||||
view.setUint16(20, format, true);
|
||||
view.setUint16(22, numChannels, true);
|
||||
view.setUint32(24, sampleRate, true);
|
||||
view.setUint32(28, sampleRate * blockAlign, true);
|
||||
view.setUint16(32, blockAlign, true);
|
||||
view.setUint16(34, bitDepth, true);
|
||||
writeString(36, "data");
|
||||
view.setUint32(40, dataLength, true);
|
||||
|
||||
const channels = [];
|
||||
for (let ch = 0; ch < numChannels; ch++) channels.push(buffer.getChannelData(ch));
|
||||
|
||||
let offset = 44;
|
||||
for (let i = 0; i < buffer.length; i++) {
|
||||
for (let ch = 0; ch < numChannels; ch++) {
|
||||
const sample = Math.max(-1, Math.min(1, channels[ch][i]));
|
||||
view.setInt16(offset, sample < 0 ? sample * 0x8000 : sample * 0x7FFF, true);
|
||||
offset += 2;
|
||||
}
|
||||
}
|
||||
|
||||
return new Blob([arrayBuffer], { type: "audio/wav" });
|
||||
}
|
||||
|
||||
function inspectDuration(file, target) {
|
||||
extractedAudioBlob = null;
|
||||
const extractInfo = document.querySelector("#extractInfo");
|
||||
extractInfo.style.display = "none";
|
||||
|
||||
if (!file) {
|
||||
target.textContent = "请选择 3-10 秒音频或视频";
|
||||
target.className = "file-line";
|
||||
return;
|
||||
}
|
||||
|
||||
if (isVideoFile(file)) {
|
||||
target.textContent = `${file.name} · ${bytesLabel(file.size)} · 视频文件`;
|
||||
target.className = "file-line";
|
||||
extractAudioFromVideo(file);
|
||||
return;
|
||||
}
|
||||
|
||||
const url = URL.createObjectURL(file);
|
||||
const audio = new Audio();
|
||||
audio.preload = "metadata";
|
||||
audio.onloadedmetadata = () => {
|
||||
URL.revokeObjectURL(url);
|
||||
const duration = audio.duration;
|
||||
const ok = Number.isFinite(duration) && duration >= 3 && duration <= 10;
|
||||
target.textContent = `${file.name} · ${duration.toFixed(2)}s · ${bytesLabel(file.size)}`;
|
||||
target.className = `file-line ${ok ? "duration-ok" : "duration-warn"}`;
|
||||
};
|
||||
audio.onerror = () => {
|
||||
URL.revokeObjectURL(url);
|
||||
target.textContent = `${file.name} · 无法读取时长 · ${bytesLabel(file.size)}`;
|
||||
target.className = "file-line duration-warn";
|
||||
};
|
||||
audio.src = url;
|
||||
}
|
||||
|
||||
refAudio.addEventListener("change", () => {
|
||||
inspectDuration(refAudio.files[0], refInfo);
|
||||
});
|
||||
|
||||
auxAudio.addEventListener("change", () => {
|
||||
const count = auxAudio.files.length;
|
||||
auxInfo.textContent = count ? `已选择 ${count} 个辅助音频` : "可选,可多选";
|
||||
});
|
||||
|
||||
healthBtn.addEventListener("click", async () => {
|
||||
const healthUrl = apiBaseUrl();
|
||||
if (!healthUrl) {
|
||||
log("后端地址格式不正确。", true);
|
||||
return;
|
||||
}
|
||||
statusBox.innerHTML = "<strong>检测中</strong>正在请求 /health";
|
||||
try {
|
||||
const response = await fetch(healthUrl);
|
||||
const data = await response.json();
|
||||
if (!response.ok) throw new Error(JSON.stringify(data));
|
||||
statusBox.innerHTML = `<strong>可连接</strong>${data.version || "unknown"} · ${data.status || "ok"}`;
|
||||
log(JSON.stringify(data, null, 2));
|
||||
} catch (error) {
|
||||
statusBox.innerHTML = "<strong>连接失败</strong>检查后端是否启动";
|
||||
log(`检测失败:${error.message}`, true);
|
||||
}
|
||||
});
|
||||
|
||||
resetBtn.addEventListener("click", clearResult);
|
||||
|
||||
form.addEventListener("submit", async (event) => {
|
||||
event.preventDefault();
|
||||
clearResult();
|
||||
|
||||
const file = refAudio.files[0];
|
||||
if (!file) {
|
||||
log("请先选择主参考音频或视频。", true);
|
||||
return;
|
||||
}
|
||||
|
||||
if (isVideoFile(file) && !extractedAudioBlob) {
|
||||
log("视频音频提取尚未完成,请稍候再试。", true);
|
||||
return;
|
||||
}
|
||||
|
||||
const started = performance.now();
|
||||
const data = new FormData();
|
||||
data.append("text", document.querySelector("#text").value.trim());
|
||||
|
||||
if (extractedAudioBlob) {
|
||||
data.append("ref_audio", extractedAudioBlob, "extracted_audio.wav");
|
||||
} else {
|
||||
data.append("ref_audio", file);
|
||||
}
|
||||
|
||||
for (const aux of auxAudio.files) data.append("aux_ref_audio", aux);
|
||||
data.append("prompt_text", document.querySelector("#promptText").value);
|
||||
data.append("text_lang", document.querySelector("#textLang").value);
|
||||
data.append("prompt_lang", document.querySelector("#promptLang").value);
|
||||
data.append("format", document.querySelector("#format").value);
|
||||
data.append("emotion", document.querySelector("#emotion").value);
|
||||
data.append("speed", document.querySelector("#speed").value);
|
||||
data.append("seed", document.querySelector("#seed").value);
|
||||
|
||||
submitBtn.disabled = true;
|
||||
log("正在请求后端,请等待模型生成。");
|
||||
|
||||
try {
|
||||
const response = await fetch(endpoint.value.trim(), { method: "POST", body: data });
|
||||
const contentType = response.headers.get("content-type") || "";
|
||||
if (!response.ok) {
|
||||
const detail = contentType.includes("application/json") ? await response.json() : await response.text();
|
||||
throw new Error(typeof detail === "string" ? detail : JSON.stringify(detail, null, 2));
|
||||
}
|
||||
|
||||
const blob = await response.blob();
|
||||
resultUrl = URL.createObjectURL(blob);
|
||||
player.src = resultUrl;
|
||||
player.hidden = false;
|
||||
downloadLink.href = resultUrl;
|
||||
downloadLink.download = `gpt-sovits-${Date.now()}.${document.querySelector("#format").value}`;
|
||||
downloadLink.setAttribute("aria-disabled", "false");
|
||||
|
||||
elapsed.textContent = `${((performance.now() - started) / 1000).toFixed(2)}s`;
|
||||
fileSize.textContent = bytesLabel(blob.size);
|
||||
log(`生成成功。\nContent-Type: ${contentType || "unknown"}\nSize: ${bytesLabel(blob.size)}`);
|
||||
} catch (error) {
|
||||
log(`生成失败:\n${error.message}`, true);
|
||||
} finally {
|
||||
submitBtn.disabled = false;
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
303
tests/test_simple_api_contract.py
Normal file
303
tests/test_simple_api_contract.py
Normal file
@ -0,0 +1,303 @@
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
SIMPLE_API_PATH = PROJECT_ROOT / "simple_api.py"
|
||||
|
||||
|
||||
class DummyHTTPException(Exception):
|
||||
def __init__(self, status_code=500, detail=None):
|
||||
super().__init__(detail)
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
|
||||
|
||||
class DummyFastAPI:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.routes = []
|
||||
|
||||
def add_middleware(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def mount(self, path, app, name=None):
|
||||
self.routes.append(types.SimpleNamespace(path=path, endpoint=app, name=name))
|
||||
|
||||
def get(self, path):
|
||||
return self._route(path)
|
||||
|
||||
def post(self, path):
|
||||
return self._route(path)
|
||||
|
||||
def on_event(self, event):
|
||||
return lambda func: func
|
||||
|
||||
def _route(self, path):
|
||||
def decorator(func):
|
||||
self.routes.append(types.SimpleNamespace(path=path, endpoint=func))
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class DummyBaseModel:
|
||||
def dict(self, exclude_none=False):
|
||||
data = dict(self.__dict__)
|
||||
if exclude_none:
|
||||
data = {key: value for key, value in data.items() if value is not None}
|
||||
return data
|
||||
|
||||
|
||||
class DummyTTSConfig:
|
||||
version = "v2"
|
||||
languages = ["zh", "en", "auto"]
|
||||
use_vocoder = False
|
||||
|
||||
|
||||
class FakeUpload:
|
||||
def __init__(self, filename, chunks, content_type="audio/wav"):
|
||||
self.filename = filename
|
||||
self.content_type = content_type
|
||||
self._chunks = list(chunks)
|
||||
self.closed = False
|
||||
|
||||
async def read(self, size):
|
||||
del size
|
||||
if self._chunks:
|
||||
return self._chunks.pop(0)
|
||||
return b""
|
||||
|
||||
async def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
def install_dummy_modules():
|
||||
fastapi = types.ModuleType("fastapi")
|
||||
fastapi.FastAPI = DummyFastAPI
|
||||
fastapi.File = lambda default=None, **kwargs: default
|
||||
fastapi.Form = lambda default=None, **kwargs: default
|
||||
fastapi.HTTPException = DummyHTTPException
|
||||
fastapi.Response = type("Response", (), {})
|
||||
fastapi.UploadFile = type("UploadFile", (), {})
|
||||
sys.modules["fastapi"] = fastapi
|
||||
|
||||
middleware = types.ModuleType("fastapi.middleware")
|
||||
cors = types.ModuleType("fastapi.middleware.cors")
|
||||
cors.CORSMiddleware = type("CORSMiddleware", (), {})
|
||||
sys.modules["fastapi.middleware"] = middleware
|
||||
sys.modules["fastapi.middleware.cors"] = cors
|
||||
|
||||
responses = types.ModuleType("fastapi.responses")
|
||||
responses.JSONResponse = type("JSONResponse", (), {})
|
||||
responses.StreamingResponse = type("StreamingResponse", (), {})
|
||||
sys.modules["fastapi.responses"] = responses
|
||||
|
||||
staticfiles = types.ModuleType("fastapi.staticfiles")
|
||||
staticfiles.StaticFiles = type("StaticFiles", (), {"__init__": lambda self, *args, **kwargs: None})
|
||||
sys.modules["fastapi.staticfiles"] = staticfiles
|
||||
|
||||
pydantic = types.ModuleType("pydantic")
|
||||
pydantic.BaseModel = DummyBaseModel
|
||||
pydantic.Field = lambda default=None, **kwargs: default
|
||||
sys.modules["pydantic"] = pydantic
|
||||
|
||||
numpy = types.ModuleType("numpy")
|
||||
numpy.ndarray = object
|
||||
sys.modules["numpy"] = numpy
|
||||
|
||||
soundfile = types.ModuleType("soundfile")
|
||||
soundfile.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000)
|
||||
sys.modules["soundfile"] = soundfile
|
||||
|
||||
yaml = types.ModuleType("yaml")
|
||||
yaml.safe_load = lambda stream: {
|
||||
"server": {"host": "127.0.0.1", "port": 9881, "tts_config": "dummy.yaml"},
|
||||
"upload": {"dir": "runtime/test_uploads", "min_ref_seconds": 3, "max_ref_seconds": 10, "max_upload_mb": 1},
|
||||
"defaults": {"text_lang": "zh", "prompt_lang": "zh", "media_type": "wav", "text_split_method": "cut5"},
|
||||
"emotion_presets": {"calm": {"temperature": 0.8, "speed_factor": 0.9}},
|
||||
"voices": {},
|
||||
}
|
||||
sys.modules["yaml"] = yaml
|
||||
|
||||
sys.modules["GPT_SoVITS"] = types.ModuleType("GPT_SoVITS")
|
||||
sys.modules["GPT_SoVITS.TTS_infer_pack"] = types.ModuleType("GPT_SoVITS.TTS_infer_pack")
|
||||
|
||||
tts_module = types.ModuleType("GPT_SoVITS.TTS_infer_pack.TTS")
|
||||
tts_module.TTS = type("DummyTTS", (), {})
|
||||
tts_module.TTS_Config = DummyTTSConfig
|
||||
sys.modules["GPT_SoVITS.TTS_infer_pack.TTS"] = tts_module
|
||||
|
||||
segmentation = types.ModuleType("GPT_SoVITS.TTS_infer_pack.text_segmentation_method")
|
||||
segmentation.get_method_names = lambda: ["cut0", "cut5"]
|
||||
sys.modules["GPT_SoVITS.TTS_infer_pack.text_segmentation_method"] = segmentation
|
||||
|
||||
|
||||
def load_simple_api():
|
||||
install_dummy_modules()
|
||||
sys.argv = ["simple_api.py"]
|
||||
spec = importlib.util.spec_from_file_location("simple_api_under_test", SIMPLE_API_PATH)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
module.tts_config = DummyTTSConfig()
|
||||
return module
|
||||
|
||||
|
||||
class SimpleApiContractTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.simple_api = load_simple_api()
|
||||
|
||||
def test_api_tts_route_is_registered(self):
|
||||
routes = {route.path for route in self.simple_api.APP.routes}
|
||||
self.assertIn("/api/tts", routes)
|
||||
|
||||
def test_build_request_accepts_explicit_ref_without_voice_profile(self):
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav") as ref:
|
||||
req, media_type, response_stream = self.simple_api.build_tts_request(
|
||||
{
|
||||
"text": "Hello, test.",
|
||||
"ref_audio_path": ref.name,
|
||||
"prompt_text": "",
|
||||
"prompt_lang": "zh",
|
||||
"text_lang": "zh",
|
||||
"format": "wav",
|
||||
"text_split_method": "cut5",
|
||||
"streaming_mode": 0,
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual(req["prompt_text"], "")
|
||||
self.assertEqual(req["text_split_method"], "cut5")
|
||||
self.assertEqual(req["ref_audio_path"], ref.name)
|
||||
self.assertEqual(media_type, "wav")
|
||||
self.assertFalse(response_stream)
|
||||
|
||||
def test_explicit_speed_overrides_emotion_speed_preset(self):
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav") as ref:
|
||||
payload = {
|
||||
"text": "Hello.",
|
||||
"ref_audio_path": ref.name,
|
||||
"prompt_lang": "zh",
|
||||
"text_lang": "zh",
|
||||
"format": "wav",
|
||||
"streaming_mode": 0,
|
||||
"speed": 1.05,
|
||||
}
|
||||
self.simple_api.apply_emotion_preset(payload, "calm")
|
||||
payload["speed"] = 1.05
|
||||
req, _, _ = self.simple_api.build_tts_request(payload)
|
||||
|
||||
self.assertEqual(req["temperature"], 0.8)
|
||||
self.assertEqual(req["speed_factor"], 1.05)
|
||||
|
||||
def test_empty_prompt_text_is_rejected_for_vocoder_models(self):
|
||||
original_use_vocoder = self.simple_api.tts_config.use_vocoder
|
||||
original_version = self.simple_api.tts_config.version
|
||||
self.simple_api.tts_config.use_vocoder = True
|
||||
self.simple_api.tts_config.version = "v3"
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav") as ref:
|
||||
with self.assertRaises(DummyHTTPException) as exc:
|
||||
self.simple_api.build_tts_request(
|
||||
{
|
||||
"text": "Hello.",
|
||||
"ref_audio_path": ref.name,
|
||||
"prompt_text": "",
|
||||
"prompt_lang": "zh",
|
||||
"text_lang": "zh",
|
||||
"format": "wav",
|
||||
"text_split_method": "cut5",
|
||||
"streaming_mode": 0,
|
||||
}
|
||||
)
|
||||
self.assertEqual(exc.exception.status_code, 400)
|
||||
self.assertIn("prompt_text is required", str(exc.exception.detail))
|
||||
finally:
|
||||
self.simple_api.tts_config.use_vocoder = original_use_vocoder
|
||||
self.simple_api.tts_config.version = original_version
|
||||
|
||||
def test_ref_audio_duration_limits(self):
|
||||
self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000)
|
||||
self.simple_api.validate_main_ref_duration("ok.wav")
|
||||
|
||||
self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 2, samplerate=16000)
|
||||
with self.assertRaises(DummyHTTPException) as too_short:
|
||||
self.simple_api.validate_main_ref_duration("short.wav")
|
||||
self.assertEqual(too_short.exception.status_code, 400)
|
||||
|
||||
self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 11, samplerate=16000)
|
||||
with self.assertRaises(DummyHTTPException) as too_long:
|
||||
self.simple_api.validate_main_ref_duration("long.wav")
|
||||
self.assertEqual(too_long.exception.status_code, 400)
|
||||
|
||||
def test_save_upload_file_writes_and_closes(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
upload = FakeUpload("ref.wav", [b"abc", b"def"])
|
||||
saved_path = asyncio.run(self.simple_api.save_upload_file(upload, Path(tmp), "ref"))
|
||||
|
||||
self.assertTrue(upload.closed)
|
||||
self.assertEqual(Path(saved_path).read_bytes(), b"abcdef")
|
||||
|
||||
def test_mvp_tts_builds_payload_from_uploads_and_cleans_tmp_dir(self):
|
||||
captured = {}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
request_dir = Path(tmp) / "request"
|
||||
request_dir.mkdir()
|
||||
|
||||
original_get_request_upload_dir = self.simple_api.get_request_upload_dir
|
||||
original_synthesize_response = self.simple_api.synthesize_response
|
||||
original_sf_info = self.simple_api.sf.info
|
||||
|
||||
def fake_synthesize_response(payload):
|
||||
captured.update(payload)
|
||||
return "audio-response"
|
||||
|
||||
try:
|
||||
self.simple_api.get_request_upload_dir = lambda: request_dir
|
||||
self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000)
|
||||
self.simple_api.synthesize_response = fake_synthesize_response
|
||||
|
||||
response = asyncio.run(
|
||||
self.simple_api.mvp_tts(
|
||||
text="Hello, test.",
|
||||
ref_audio=FakeUpload("ref.wav", [b"main"]),
|
||||
aux_ref_audio=[FakeUpload("aux.wav", [b"aux"])],
|
||||
prompt_text="",
|
||||
text_lang="zh",
|
||||
prompt_lang="zh",
|
||||
format="wav",
|
||||
emotion="calm",
|
||||
speed=1.05,
|
||||
seed=123,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
self.simple_api.get_request_upload_dir = original_get_request_upload_dir
|
||||
self.simple_api.synthesize_response = original_synthesize_response
|
||||
self.simple_api.sf.info = original_sf_info
|
||||
|
||||
self.assertEqual(response, "audio-response")
|
||||
self.assertFalse(request_dir.exists())
|
||||
|
||||
self.assertEqual(captured["text"], "Hello, test.")
|
||||
self.assertEqual(captured["prompt_text"], "")
|
||||
self.assertEqual(captured["text_lang"], "zh")
|
||||
self.assertEqual(captured["prompt_lang"], "zh")
|
||||
self.assertEqual(captured["format"], "wav")
|
||||
self.assertEqual(captured["text_split_method"], "cut5")
|
||||
self.assertEqual(captured["streaming_mode"], 0)
|
||||
self.assertEqual(captured["seed"], 123)
|
||||
self.assertEqual(captured["temperature"], 0.8)
|
||||
self.assertEqual(captured["speed"], 1.05)
|
||||
self.assertEqual(len(captured["aux_ref_audio_paths"]), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
x
Reference in New Issue
Block a user