feat: add simple API layer with video support and test frontend

- Add simple_api.py: profile-based API that wraps GPT-SoVITS TTS engine
- Add /api/tts endpoint for MVP: accepts ref audio/video, text, optional aux audio
- Frontend auto-extracts audio from uploaded video files via Web Audio API
- Add emotion presets (neutral/happy/calm/sad/angry) with speed customization
- Add test_frontend/index.html with health check, audio playback, and download
- Add contract tests (7 tests, all passing) using mock TTS pipeline
- Add documentation: simple_api.md (full tutorial), simple_api_quickstart.md
- Add startup scripts: go-simple-api.ps1, go-simple-api.bat, open-test-frontend.ps1
- Add soundfile and python-multipart to requirements.txt
- Text splitting fixed to cut5 (punctuation-based) per MVP spec
mangzhnag 2026-06-11 21:06:43 +08:00
parent 08d627c333
commit 735b2e3554
11 changed files with 2275 additions and 1 deletions

View File

@ -1,4 +1,4 @@
<div align="center">
<div align="center">
<h1>GPT-SoVITS-WebUI</h1>
A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.<br><br>
@ -27,6 +27,24 @@ A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.<br><br>
---
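## Simplified API / Test Frontend
This workspace adds a middle-layer API for MVP calls, together with a test frontend:
- Quick start: [docs/simple_api_quickstart.md](./docs/simple_api_quickstart.md)
- Full tutorial: [docs/simple_api.md](./docs/simple_api.md)
- Backend entry point: `simple_api.py`
- Test frontend: after startup, open http://127.0.0.1:9881/test/

Startup command:
```powershell
cd D:\tts\GPT-SoVITS
.\go-simple-api.ps1
```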
---
## Features:
1. **Zero-shot TTS:** Input a 5-second vocal sample and experience instant text-to-speech conversion.
@ -478,3 +496,4 @@ Thankful to @Naozumi520 for providing the Cantonese training set and for the gui
<a href="https://github.com/RVC-Boss/GPT-SoVITS/graphs/contributors" target="_blank">
<img src="https://contrib.rocks/image?repo=RVC-Boss/GPT-SoVITS" />
</a>

docs/simple_api.md Normal file

@ -0,0 +1,320 @@
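# GPT-SoVITS Simplified API Tutorial

This project already ships `api_v2.py`, but calling it requires passing many GPT-SoVITS parameters. The new `simple_api.py` is a middle layer whose goal is to give frontends and business callers a simpler way to call the engine.

The current MVP recommends:

```http
POST /api/tts
Content-Type: multipart/form-data
```

It fits the following workflow:

1. The frontend extracts audio from a video (or the user uploads audio that is already trimmed to 3-10 seconds).
2. The user manually trims the main reference audio to 3-10 seconds.
3. The frontend submits the main reference audio, the text to synthesize, and optional auxiliary audio to the backend.
4. The backend calls GPT-SoVITS to generate the audio and returns it.

## 1. Install dependencies

Run this inside the Python environment used by GPT-SoVITS:

```bash
python -m pip install -r requirements.txt
```

The middle layer additionally needs:

- `python-multipart`: receives audio files uploaded by the frontend.
- `soundfile`: reads audio metadata and validates the reference audio duration.

Both dependencies are already listed in `requirements.txt`.

## 2. Check the configuration

Configuration file:

```text
simple_api.yaml
```

Common settings:

```yaml
server:
  host: 127.0.0.1
  port: 9881
  tts_config: GPT_SoVITS/configs/tts_infer.yaml
upload:
  dir: runtime/uploads
  min_ref_seconds: 3
  max_ref_seconds: 10
  max_upload_mb: 80
```

Notes:

- `server.host`: address the backend listens on.
- `server.port`: backend port.
- `server.tts_config`: GPT-SoVITS inference configuration file.
- `upload.dir`: temporary upload directory.
- `upload.min_ref_seconds`: minimum length of the main reference audio, in seconds.
- `upload.max_ref_seconds`: maximum length of the main reference audio, in seconds.
- `upload.max_upload_mb`: maximum size of a single uploaded audio file, in MB.

Current default backend address:

```text
http://127.0.0.1:9881
```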

## 3. Start the backend

Change to the project directory:

```powershell
cd D:\tts\GPT-SoVITS
```

Option 1:

```powershell
python simple_api.py -c simple_api.yaml
```

Option 2 (Windows PowerShell):

```powershell
.\go-simple-api.ps1
```

Option 3 (Windows batch file):

```bat
go-simple-api.bat
```

After a successful start, the backend listens on:

```text
http://127.0.0.1:9881
```

## 4. Health check

Once the backend is running, you can open:

```text
http://127.0.0.1:9881/health
```

Or test from the command line:

```bash
curl http://127.0.0.1:9881/health
```

Example response:

```json
{
  "status": "ok",
  "tts_config": "GPT_SoVITS/configs/tts_infer.yaml",
  "version": "v2",
  "languages": ["auto", "auto_yue", "en", "zh"]
}
```
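
A client script may want to wait for the model before sending requests. The sketch below only interprets a `/health` response body; it assumes the JSON shape shown above and that `status` is `"ok"` once the pipeline is loaded (the handler in `simple_api.py` reports `"starting"` otherwise). The helper names `is_backend_ready` and `supports_language` are illustrative and not part of the project.

```python
import json
from typing import Any, Dict


def parse_health(body: str) -> Dict[str, Any]:
    """Parse the raw JSON body returned by GET /health."""
    return json.loads(body)


def is_backend_ready(health: Dict[str, Any]) -> bool:
    """True when the backend reports that the TTS pipeline is loaded."""
    return health.get("status") == "ok"


def supports_language(health: Dict[str, Any], lang: str) -> bool:
    """True when `lang` appears in the advertised language list."""
    return lang.lower() in (health.get("languages") or [])
```

Fetching the body itself can be done with any HTTP client, for example `urllib.request.urlopen("http://127.0.0.1:9881/health").read().decode("utf-8")`.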
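
## 5. Open the test frontend

After the backend starts, the recommended way is to open:

```text
http://127.0.0.1:9881/test/
```

This page is mounted on the backend service, so no separate frontend server is needed.

You can also open the local file directly:

```text
test_frontend/index.html
```

On Windows you can also run:

```powershell
.\open-test-frontend.ps1
```

The test frontend has a "后端接口地址" (backend API URL) input box.

Default value:

```text
http://127.0.0.1:9881/api/tts
```

If you change `server.host` or `server.port`, remember to update the API URL on the page as well.

## 6. MVP endpoint

Endpoint:

```http
POST /api/tts
```

Request type:

```http
multipart/form-data
```

Fields:

```text
text           Required. Text to synthesize.
ref_audio      Required. Main reference audio, 3-10 seconds (video uploads are supported; the frontend extracts the audio automatically).
aux_ref_audio  Optional. Auxiliary reference audio; may be repeated.
prompt_text    Optional. Transcript of the main reference audio; may be empty.
text_lang      Optional. Language of the text to synthesize, default zh.
prompt_lang    Optional. Language of the reference audio, default zh. It is still passed to GPT-SoVITS internally even when prompt_text is empty.
format         Optional. Response format, default wav.
emotion        Optional. Emotion preset: neutral, happy, calm, sad, angry.
speed          Optional. Speaking speed, mapped to GPT-SoVITS speed_factor.
seed           Optional. Random seed, default -1.
```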
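
To make the interaction between `emotion` and `speed` concrete, here is a simplified sketch of the merge order: base sampling parameters, then the emotion preset, then an explicit `speed`. The preset values are copied from `simple_api.yaml` in this commit; the function `resolve_sampling` is hypothetical, and the real server also layers in the `defaults` block of the config, which this sketch leaves out.

```python
from typing import Any, Dict, Optional

# Base values mirror DEFAULT_TTS_PARAMS in simple_api.py.
BASE_PARAMS: Dict[str, Any] = {
    "top_k": 15,
    "top_p": 1.0,
    "temperature": 1.0,
    "speed_factor": 1.0,
    "repetition_penalty": 1.35,
}

# Preset values mirror emotion_presets in simple_api.yaml.
EMOTION_PRESETS: Dict[str, Dict[str, Any]] = {
    "neutral": {},
    "happy": {"temperature": 1.1, "top_p": 0.95},
    "calm": {"temperature": 0.8, "top_p": 0.85, "speed_factor": 0.92},
    "sad": {"temperature": 0.75, "top_p": 0.85, "speed_factor": 0.9},
    "angry": {"temperature": 1.2, "top_k": 20, "repetition_penalty": 1.25},
}


def resolve_sampling(emotion: Optional[str], speed: Optional[float]) -> Dict[str, Any]:
    """Hypothetical helper: apply preset first, then an explicit speed override."""
    params = dict(BASE_PARAMS)
    if emotion:
        name = emotion.strip().lower()
        if name not in EMOTION_PRESETS:
            raise ValueError(f"emotion is not configured: {name}")
        params.update(EMOTION_PRESETS[name])
    if speed is not None:
        # An explicit speed wins over the preset's speed_factor.
        params["speed_factor"] = speed
    return params
```

Under these assumptions, `emotion=calm` alone slows speech to 0.92, while `emotion=calm` together with `speed=1.2` keeps the calm sampling values but speaks at 1.2.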

## 7. curl examples

PowerShell example:

```powershell
curl.exe -X POST http://127.0.0.1:9881/api/tts `
-F "text=你好,欢迎使用这个声音。" `
-F "ref_audio=@D:\audio\ref.wav" `
-F "prompt_text=" `
-F "text_lang=zh" `
-F "prompt_lang=zh" `
-F "emotion=neutral" `
--output output.wav
```

With auxiliary reference audio:

```powershell
curl.exe -X POST http://127.0.0.1:9881/api/tts `
-F "text=你好,欢迎使用这个声音。" `
-F "ref_audio=@D:\audio\ref.wav" `
-F "aux_ref_audio=@D:\audio\aux1.wav" `
-F "aux_ref_audio=@D:\audio\aux2.wav" `
-F "prompt_text=" `
-F "text_lang=zh" `
-F "prompt_lang=zh" `
--output output.wav
```
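
For callers that prefer Python over curl, the same request can be assembled with the standard library only. This is a minimal sketch, not the project's client: `encode_multipart` and `post_tts` are illustrative names, filenames are assumed to be plain ASCII, and the URL is the default from this tutorial.

```python
import urllib.request
import uuid
from typing import Dict, List, Tuple

# (field name, filename, file content, content type)
FilePart = Tuple[str, str, bytes, str]


def encode_multipart(fields: Dict[str, str], files: List[FilePart]) -> Tuple[bytes, str]:
    """Build a multipart/form-data body and the matching Content-Type header value."""
    boundary = uuid.uuid4().hex
    chunks: List[bytes] = []
    for name, value in fields.items():
        header = f'--{boundary}\r\nContent-Disposition: form-data; name="{name}"\r\n\r\n'
        chunks.append(header.encode("utf-8") + value.encode("utf-8") + b"\r\n")
    for name, filename, content, content_type in files:
        header = (
            f"--{boundary}\r\n"
            f'Content-Disposition: form-data; name="{name}"; filename="{filename}"\r\n'
            f"Content-Type: {content_type}\r\n\r\n"
        )
        chunks.append(header.encode("utf-8") + content + b"\r\n")
    chunks.append(f"--{boundary}--\r\n".encode("utf-8"))
    return b"".join(chunks), f"multipart/form-data; boundary={boundary}"


def post_tts(url: str, body: bytes, content_type: str, timeout: float = 300.0) -> bytes:
    """POST the body and return the raw response bytes (audio on success)."""
    request = urllib.request.Request(
        url, data=body, headers={"Content-Type": content_type}, method="POST"
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return response.read()
```

A call would then look like `body, ctype = encode_multipart({"text": "...", "text_lang": "zh"}, [("ref_audio", "ref.wav", wav_bytes, "audio/wav")])` followed by `post_tts("http://127.0.0.1:9881/api/tts", body, ctype)`. Repeating the `aux_ref_audio` field name in `files` corresponds to passing `-F "aux_ref_audio=@..."` more than once.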
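
## 8. Important rules

- The main reference audio must be 3-10 seconds long.
- `aux_ref_audio` is optional.
- `prompt_text` may be empty, but this is currently aimed at GPT-SoVITS v2.
- After switching to GPT-SoVITS v3/v4, the middle layer rejects an empty `prompt_text` with HTTP 400.
- The text to synthesize is always split with `cut5`, that is, at punctuation marks.
- `emotion` is currently a lightweight preset that maps to sampling and speed parameters; more stable emotion control still depends on reference audio that carries the emotion.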
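
Because the server rejects reference audio outside the 3-10 second window, a client can check the duration before uploading. The sketch below does this for uncompressed WAV files with the standard `wave` module; the function names are illustrative, and the bounds are the defaults from `simple_api.yaml` (the server treats both ends as inclusive). Other formats would need a different reader.

```python
import io
import wave


def wav_duration_seconds(wav_bytes: bytes) -> float:
    """Duration of a PCM WAV file, computed as frames / sample rate."""
    with wave.open(io.BytesIO(wav_bytes), "rb") as wav_file:
        return wav_file.getnframes() / float(wav_file.getframerate())


def is_valid_ref_duration(seconds: float, min_seconds: float = 3.0, max_seconds: float = 10.0) -> bool:
    """Same inclusive range check the server applies to ref_audio."""
    return min_seconds <= seconds <= max_seconds


def make_silent_wav(seconds: float, sample_rate: int = 16000) -> bytes:
    """Generate a mono 16-bit silent WAV, handy for trying the check."""
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(b"\x00\x00" * int(seconds * sample_rate))
    return buffer.getvalue()
```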
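
## 9. Using the test frontend

1. Start the backend.
2. Open `http://127.0.0.1:9881/test/`.
3. Check that "后端接口地址" (backend API URL) is:
```text
http://127.0.0.1:9881/api/tts
```
4. Fill in "需要生成的文字" (text to synthesize).
5. Upload the main reference audio.
6. Optionally upload auxiliary reference audio.
7. Optionally fill in the transcript of the reference audio.
8. Choose the language, emotion, and speed.
9. Click "生成音频" (generate audio).
10. The result appears on the right side of the page, where it can be played and downloaded.

## 10. Other endpoints

Besides the MVP upload endpoint, the profile-based endpoints are kept:

```http
GET /voices
GET /speak?text=hello&voice=default
POST /speak
POST /speak/base64
POST /v1/tts
POST /admin/reload-config
POST /admin/weights
```

These endpoints are meant for fixed voice profiles later on; they are not part of the current MVP flow.

## 11. Base64 response example

The `^` line continuations below are for the Windows command prompt (cmd):

```bat
curl -X POST http://127.0.0.1:9881/speak/base64 ^
  -H "Content-Type: application/json" ^
  -d "{\"text\":\"hello\",\"voice\":\"default\"}"
```

Response format:

```json
{
  "media_type": "audio/wav",
  "audio_base64": "..."
}
```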
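
On the receiving side, the `audio_base64` field has to be decoded before it can be played or saved. A minimal sketch, assuming the response has already been parsed into a dict with the two keys shown above (`decode_speak_base64` is an illustrative name):

```python
import base64
from typing import Dict, Tuple


def decode_speak_base64(payload: Dict[str, str]) -> Tuple[str, bytes]:
    """Return (media_type, raw audio bytes) from a /speak/base64 response."""
    audio_bytes = base64.b64decode(payload["audio_base64"])
    return payload["media_type"], audio_bytes
```

The returned bytes can be written to disk unchanged, for example with `Path("output.wav").write_bytes(audio_bytes)`.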

## 12. Add a fixed voice profile

To set up fixed voices later, edit `simple_api.yaml`:

```yaml
voices:
  default:
    ref_audio_path: reference.wav
    prompt_text: exact transcript
    prompt_lang: zh
    text_lang: zh
  narrator:
    ref_audio_path: voices/narrator.wav
    prompt_text: exact transcript of narrator.wav
    prompt_lang: zh
    text_lang: zh
```

After editing, call:

```bash
curl -X POST http://127.0.0.1:9881/admin/reload-config
```
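
When a request names no `voice` and supplies no `ref_audio_path`, the server falls back to a default profile. The sketch below restates that lookup order on a plain dict so it can be read in isolation: an explicit `default_voice` key first, then a profile literally named `default`, then the first profile listed. `pick_default_voice` is an illustrative name; the config is assumed to be already parsed from YAML.

```python
from typing import Any, Dict, Optional


def pick_default_voice(config: Dict[str, Any]) -> Optional[str]:
    """Resolve the fallback voice name, or None when no profile exists."""
    explicit = config.get("default_voice")
    if explicit:
        return str(explicit)
    voices = config.get("voices") or {}
    if "default" in voices:
        return "default"
    # Dicts keep insertion order, so this is the first profile in the file.
    return next(iter(voices), None)
```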
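
## 13. Contract tests

These tests use a mock and do not load the GPT-SoVITS model:

```bash
python -m unittest tests.test_simple_api_contract
```

They confirm that:

- The `/api/tts` route exists.
- The upload endpoint builds the correct parameters.
- The 3-10 second check on the main reference audio works.
- An empty `prompt_text` is accepted on v2.
- An empty `prompt_text` returns 400 on v3/v4.
- The temporary upload directory is cleaned up.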
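
The cleanup guarantee in the last item comes from a `try`/`finally` around request handling. The following is a standalone sketch of that pattern, not the project's test code: `handle_with_temp_dir` is a hypothetical helper, and it uses `tempfile.mkdtemp` where the server builds a `uuid`-named directory under `upload.dir`.

```python
import shutil
import tempfile
from pathlib import Path
from typing import Callable


def handle_with_temp_dir(root: Path, handler: Callable[[Path], bytes]) -> bytes:
    """Run `handler` with a fresh per-request directory and always remove it."""
    request_dir = Path(tempfile.mkdtemp(dir=root))
    try:
        return handler(request_dir)
    finally:
        # Runs on success and on failure, so uploads never pile up.
        shutil.rmtree(request_dir, ignore_errors=True)
```

A contract test can pass a handler that records the directory it received, then assert that the path no longer exists after the call, both when the handler returns and when it raises.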

docs/simple_api_quickstart.md Normal file

@ -0,0 +1,89 @@
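# Simplified TTS API Quick Start

Full tutorial: `docs/simple_api.md`.

## The flow in one sentence

Start the backend, open the test page, upload a 3-10 second reference audio or video (audio is extracted from video automatically), enter the text to synthesize, and click generate.

## 1. Start the backend

```powershell
cd D:\tts\GPT-SoVITS
python -m pip install -r requirements.txt
.\go-simple-api.ps1
```

You can also run:

```bat
go-simple-api.bat
```

Default backend address:

```text
http://127.0.0.1:9881
```

## 2. Open the test frontend

After the backend starts, open:

```text
http://127.0.0.1:9881/test/
```

The default API URL on the test page is:

```text
http://127.0.0.1:9881/api/tts
```

## 3. Call the endpoint

Endpoint:

```http
POST /api/tts
Content-Type: multipart/form-data
```

Minimum fields:

```text
text       Text to synthesize
ref_audio  Main reference audio, 3-10 seconds (video is supported; the frontend extracts the audio automatically)
```

Common optional fields:

```text
aux_ref_audio  Auxiliary reference audio; may be repeated
prompt_text    Transcript of the reference audio; may be empty
text_lang      Default zh
prompt_lang    Default zh
emotion        neutral / happy / calm / sad / angry
speed          Speaking speed, default 1
seed           Default -1
```

## 4. PowerShell example

```powershell
curl.exe -X POST http://127.0.0.1:9881/api/tts `
  -F "text=你好,欢迎使用这个声音。" `
  -F "ref_audio=@D:\audio\ref.wav" `
  -F "prompt_text=" `
  -F "text_lang=zh" `
  -F "prompt_lang=zh" `
  --output output.wav
```

## 5. Notes

- The main reference audio must be 3-10 seconds long.
- `prompt_text` may be empty under the current v2 configuration.
- After switching to v3/v4, an empty `prompt_text` returns 400.
- The text to synthesize is always split at punctuation marks.
- For detailed configuration, profiles, and the base64 endpoint, see `docs/simple_api.md`.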

go-simple-api.bat Normal file

@ -0,0 +1,12 @@
@echo off
set "SCRIPT_DIR=%~dp0"
set "SCRIPT_DIR=%SCRIPT_DIR:~0,-1%"
cd /d "%SCRIPT_DIR%"
if exist "%SCRIPT_DIR%\runtime\python.exe" (
    "%SCRIPT_DIR%\runtime\python.exe" "%SCRIPT_DIR%\simple_api.py" -c "%SCRIPT_DIR%\simple_api.yaml" %*
) else (
    python "%SCRIPT_DIR%\simple_api.py" -c "%SCRIPT_DIR%\simple_api.yaml" %*
)
pause

go-simple-api.ps1 Normal file

@ -0,0 +1,13 @@
$ErrorActionPreference = "Stop"
chcp 65001
Set-Location $PSScriptRoot
$runtimePython = Join-Path $PSScriptRoot "runtime\python.exe"
$apiScript = Join-Path $PSScriptRoot "simple_api.py"
$apiConfig = Join-Path $PSScriptRoot "simple_api.yaml"
if (Test-Path $runtimePython) {
    & $runtimePython $apiScript -c $apiConfig @args
} else {
    python $apiScript -c $apiConfig @args
}

open-test-frontend.ps1 Normal file

@ -0,0 +1,3 @@
$ErrorActionPreference = "Stop"
Set-Location $PSScriptRoot
Start-Process (Join-Path $PSScriptRoot "test_frontend\index.html")

requirements.txt

@ -3,6 +3,7 @@ numpy<2.0
scipy
tensorboard
librosa==0.10.2
soundfile
numba
pytorch-lightning>=2.4
gradio<5
@ -22,6 +23,7 @@ transformers>=4.43,<=4.50
peft<0.18.0
chardet
PyYAML
python-multipart
psutil
jieba_fast
jieba

simple_api.py Normal file

@ -0,0 +1,741 @@
"""
Small profile-based API layer for GPT-SoVITS.

Run:
    python simple_api.py -c simple_api.yaml

Then call:
    POST http://127.0.0.1:9881/speak
    {"text": "hello", "voice": "default"}
"""
from __future__ import annotations
import argparse
import base64
import os
import shutil
import subprocess
import sys
import threading
import traceback
import uuid
import wave
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
import numpy as np
import soundfile as sf
import yaml
from fastapi import FastAPI, File, Form, HTTPException, Response, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
PROJECT_ROOT = Path(__file__).resolve().parent
os.chdir(PROJECT_ROOT)
sys.path.append(str(PROJECT_ROOT))
sys.path.append(str(PROJECT_ROOT / "GPT_SoVITS"))
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import (  # noqa: E402
    get_method_names as get_cut_method_names,
)

DEFAULT_TTS_PARAMS: Dict[str, Any] = {
    "text": "",
    "text_lang": "zh",
    "ref_audio_path": "",
    "aux_ref_audio_paths": [],
    "prompt_text": "",
    "prompt_lang": "zh",
    "top_k": 15,
    "top_p": 1.0,
    "temperature": 1.0,
    "text_split_method": "cut5",
    "batch_size": 1,
    "batch_threshold": 0.75,
    "split_bucket": True,
    "speed_factor": 1.0,
    "fragment_interval": 0.3,
    "seed": -1,
    "media_type": "wav",
    "streaming_mode": False,
    "return_fragment": False,
    "fixed_length_chunk": False,
    "parallel_infer": True,
    "repetition_penalty": 1.35,
    "sample_steps": 32,
    "super_sampling": False,
    "overlap_length": 2,
    "min_chunk_length": 16,
}
SUPPORTED_MEDIA_TYPES = {"wav", "raw", "ogg", "aac"}
SUPPORTED_UPLOAD_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a", ".aac"}
VOICE_META_KEYS = {"description"}


class SpeakRequest(BaseModel):
    text: str
    voice: Optional[str] = None
    text_lang: Optional[str] = None
    ref_audio_path: Optional[str] = None
    aux_ref_audio_paths: Optional[List[str]] = None
    prompt_text: Optional[str] = None
    prompt_lang: Optional[str] = None
    format: Optional[str] = Field(default=None, description="wav, raw, ogg, or aac")
    stream: Optional[bool] = Field(default=None, description="true maps to streaming mode 2")
    streaming_mode: Optional[Union[bool, int]] = Field(default=None, description="0, 1, 2, 3, true, or false")
    speed: Optional[float] = Field(default=None, description="Alias of speed_factor")
    speed_factor: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    temperature: Optional[float] = None
    text_split_method: Optional[str] = None
    batch_size: Optional[int] = None
    batch_threshold: Optional[float] = None
    split_bucket: Optional[bool] = None
    fragment_interval: Optional[float] = None
    seed: Optional[int] = None
    parallel_infer: Optional[bool] = None
    repetition_penalty: Optional[float] = None
    sample_steps: Optional[int] = None
    super_sampling: Optional[bool] = None
    overlap_length: Optional[int] = None
    min_chunk_length: Optional[int] = None


class WeightsRequest(BaseModel):
    gpt_weights_path: Optional[str] = None
    sovits_weights_path: Optional[str] = None


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="GPT-SoVITS simple API")
    parser.add_argument("-c", "--config", default="simple_api.yaml", help="simple API config path")
    parser.add_argument("--tts-config", default=None, help="GPT-SoVITS tts_infer.yaml path")
    parser.add_argument("-a", "--bind-addr", default=None, help="bind address")
    parser.add_argument("-p", "--port", type=int, default=None, help="bind port")
    return parser.parse_args()


def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
    path = Path(config_path)
    if not path.is_absolute():
        path = PROJECT_ROOT / path
    if not path.exists():
        raise FileNotFoundError(f"simple API config not found: {path}")
    with path.open("r", encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}
    if not isinstance(data, dict):
        raise ValueError("simple API config must be a YAML object")
    return data


args = parse_args()
simple_config = load_yaml_config(args.config)
server_config = simple_config.get("server", {}) or {}
host = args.bind_addr or server_config.get("host", "127.0.0.1")
port = args.port or int(server_config.get("port", 9881))
tts_config_path = args.tts_config or server_config.get("tts_config", "GPT_SoVITS/configs/tts_infer.yaml")
cut_method_names = get_cut_method_names()
tts_config: Optional[TTS_Config] = None
tts_pipeline: Optional[TTS] = None
infer_lock = threading.Lock()

APP = FastAPI(
    title="GPT-SoVITS Simple API",
    description="Profile-based API layer that hides GPT-SoVITS request details.",
    version="1.0.0",
)
APP.add_middleware(
    CORSMiddleware,
    allow_origins=simple_config.get("cors_allow_origins", ["*"]),
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

test_frontend_dir = PROJECT_ROOT / "test_frontend"
if test_frontend_dir.exists():
    APP.mount("/test", StaticFiles(directory=str(test_frontend_dir), html=True), name="test_frontend")


def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
    def handle_pack_ogg() -> None:
        with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
            audio_file.write(data)

    try:
        threading.stack_size(4096 * 4096)
        pack_thread = threading.Thread(target=handle_pack_ogg)
        pack_thread.start()
        pack_thread.join()
    except (RuntimeError, ValueError):
        handle_pack_ogg()
    return io_buffer


def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
    del rate
    io_buffer.write(data.tobytes())
    return io_buffer


def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
    del io_buffer
    wav_buffer = BytesIO()
    sf.write(wav_buffer, data, rate, format="wav")
    return wav_buffer


def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
    process = subprocess.Popen(
        [
            "ffmpeg",
            "-f",
            "s16le",
            "-ar",
            str(rate),
            "-ac",
            "1",
            "-i",
            "pipe:0",
            "-c:a",
            "aac",
            "-b:a",
            "192k",
            "-vn",
            "-f",
            "adts",
            "pipe:1",
        ],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    out, _ = process.communicate(input=data.tobytes())
    io_buffer.write(out)
    return io_buffer


def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str) -> BytesIO:
    if media_type == "ogg":
        io_buffer = pack_ogg(io_buffer, data, rate)
    elif media_type == "aac":
        io_buffer = pack_aac(io_buffer, data, rate)
    elif media_type == "wav":
        io_buffer = pack_wav(io_buffer, data, rate)
    else:
        io_buffer = pack_raw(io_buffer, data, rate)
    io_buffer.seek(0)
    return io_buffer


def wave_header_chunk(frame_input: bytes = b"", channels: int = 1, sample_width: int = 2, sample_rate: int = 32000) -> bytes:
    wav_buf = BytesIO()
    with wave.open(wav_buf, "wb") as vfout:
        vfout.setnchannels(channels)
        vfout.setsampwidth(sample_width)
        vfout.setframerate(sample_rate)
        vfout.writeframes(frame_input)
    wav_buf.seek(0)
    return wav_buf.read()


def request_to_dict(request: BaseModel) -> Dict[str, Any]:
    # pydantic v2 exposes model_dump(); fall back to dict() on v1.
    if hasattr(request, "model_dump"):
        return request.model_dump(exclude_none=True)
    return request.dict(exclude_none=True)


def get_default_voice_name() -> Optional[str]:
    default_voice = simple_config.get("default_voice")
    if default_voice:
        return str(default_voice)
    voices = simple_config.get("voices", {}) or {}
    if "default" in voices:
        return "default"
    return next(iter(voices.keys()), None)


def resolve_project_path(path_value: Optional[str]) -> Optional[str]:
    if path_value in [None, ""]:
        return None
    path = Path(str(path_value))
    if not path.is_absolute():
        path = PROJECT_ROOT / path
    return str(path)


def normalize_streaming(streaming_mode: Optional[Union[bool, int]], stream: Optional[bool]) -> Tuple[bool, bool, bool, bool]:
    # Returns (streaming_enabled, return_fragment, fixed_length_chunk, response_stream).
    if streaming_mode is None:
        streaming_mode = 2 if stream else 0
    elif isinstance(streaming_mode, bool):
        streaming_mode = 2 if streaming_mode else 0
    if streaming_mode == 0:
        return False, False, False, False
    if streaming_mode == 1:
        return False, True, False, True
    if streaming_mode == 2:
        return True, False, False, True
    if streaming_mode == 3:
        return True, False, True, True
    raise HTTPException(status_code=400, detail="streaming_mode must be 0, 1, 2, 3, true, or false")


def get_upload_config() -> Dict[str, Any]:
    return simple_config.get("upload", {}) or {}


def get_upload_root() -> Path:
    upload_dir = str(get_upload_config().get("dir", "runtime/uploads"))
    root = Path(upload_dir)
    if not root.is_absolute():
        root = PROJECT_ROOT / root
    root.mkdir(parents=True, exist_ok=True)
    return root


def get_request_upload_dir() -> Path:
    request_dir = get_upload_root() / uuid.uuid4().hex
    request_dir.mkdir(parents=True, exist_ok=True)
    return request_dir


def get_upload_limits() -> Tuple[float, float, int]:
    upload_config = get_upload_config()
    min_seconds = float(upload_config.get("min_ref_seconds", 3.0))
    max_seconds = float(upload_config.get("max_ref_seconds", 10.0))
    max_upload_mb = int(upload_config.get("max_upload_mb", 80))
    return min_seconds, max_seconds, max_upload_mb


def get_upload_suffix(upload: UploadFile) -> str:
    suffix = Path(upload.filename or "").suffix.lower()
    if suffix:
        return suffix
    content_type = (upload.content_type or "").lower()
    if content_type in {"audio/wav", "audio/x-wav", "audio/wave"}:
        return ".wav"
    if content_type == "audio/mpeg":
        return ".mp3"
    if content_type == "audio/ogg":
        return ".ogg"
    if content_type in {"audio/aac", "audio/aacp"}:
        return ".aac"
    return ".wav"


def validate_audio_upload(upload: UploadFile) -> None:
    suffix = get_upload_suffix(upload)
    content_type = (upload.content_type or "").lower()
    is_audio_type = content_type.startswith("audio/") or content_type in {"application/octet-stream", ""}
    if suffix not in SUPPORTED_UPLOAD_EXTENSIONS or not is_audio_type:
        raise HTTPException(status_code=400, detail=f"unsupported audio upload: {upload.filename or content_type}")


async def save_upload_file(upload: UploadFile, target_dir: Path, name_prefix: str) -> str:
    validate_audio_upload(upload)
    _, _, max_upload_mb = get_upload_limits()
    max_bytes = max_upload_mb * 1024 * 1024
    suffix = get_upload_suffix(upload)
    target_path = target_dir / f"{name_prefix}{suffix}"
    total_size = 0
    try:
        with target_path.open("wb") as f:
            # Stream in 1 MB chunks so the size limit is enforced while writing.
            while True:
                chunk = await upload.read(1024 * 1024)
                if not chunk:
                    break
                total_size += len(chunk)
                if total_size > max_bytes:
                    raise HTTPException(status_code=400, detail=f"audio upload exceeds {max_upload_mb} MB")
                f.write(chunk)
        if total_size == 0:
            raise HTTPException(status_code=400, detail=f"audio upload is empty: {upload.filename or name_prefix}")
    finally:
        await upload.close()
    return str(target_path)


def get_audio_duration_with_ffprobe(audio_path: str) -> float:
    result = subprocess.run(
        [
            "ffprobe",
            "-v",
            "error",
            "-show_entries",
            "format=duration",
            "-of",
            "default=noprint_wrappers=1:nokey=1",
            audio_path,
        ],
        capture_output=True,
        text=True,
        timeout=15,
    )
    if result.returncode != 0:
        raise RuntimeError(result.stderr.strip() or "ffprobe failed")
    return float(result.stdout.strip())


def get_audio_duration_seconds(audio_path: str) -> float:
    # Try soundfile first; fall back to ffprobe for formats it cannot read.
    try:
        info = sf.info(audio_path)
        if info.samplerate <= 0:
            raise ValueError("invalid sample rate")
        return float(info.frames) / float(info.samplerate)
    except Exception as sf_exc:
        try:
            return get_audio_duration_with_ffprobe(audio_path)
        except Exception as ffprobe_exc:
            raise HTTPException(
                status_code=400,
                detail=f"unable to read audio duration: {audio_path}. soundfile={sf_exc}; ffprobe={ffprobe_exc}",
            ) from ffprobe_exc


def validate_main_ref_duration(audio_path: str) -> None:
    min_seconds, max_seconds, _ = get_upload_limits()
    duration = get_audio_duration_seconds(audio_path)
    if duration < min_seconds or duration > max_seconds:
        raise HTTPException(
            status_code=400,
            detail=f"ref_audio duration must be between {min_seconds:g}s and {max_seconds:g}s, got {duration:.2f}s",
        )


def apply_emotion_preset(payload: Dict[str, Any], emotion: Optional[str]) -> None:
    if emotion in [None, ""]:
        return
    emotion_name = str(emotion).strip().lower()
    emotion_presets = simple_config.get("emotion_presets", {}) or {}
    if emotion_name not in emotion_presets:
        supported = ", ".join(sorted(emotion_presets.keys())) or "none"
        raise HTTPException(status_code=400, detail=f"emotion is not configured: {emotion_name}. supported: {supported}")
    payload.update(emotion_presets[emotion_name] or {})


def prompt_text_required_for_current_model() -> bool:
    if tts_config is None:
        return False
    use_vocoder = getattr(tts_config, "use_vocoder", None)
    if use_vocoder is not None:
        return bool(use_vocoder)
    return getattr(tts_config, "version", None) in {"v3", "v4"}


def voice_public_info(name: str, profile: Dict[str, Any]) -> Dict[str, Any]:
    ref_audio_path = resolve_project_path(profile.get("ref_audio_path"))
    return {
        "name": name,
        "description": profile.get("description", ""),
        "text_lang": profile.get("text_lang"),
        "prompt_lang": profile.get("prompt_lang"),
        "ref_audio_path": profile.get("ref_audio_path"),
        "ready": bool(ref_audio_path and Path(ref_audio_path).exists() and profile.get("prompt_lang")),
    }


def build_tts_request(payload: Dict[str, Any]) -> Tuple[Dict[str, Any], str, bool]:
    voices = simple_config.get("voices", {}) or {}
    explicit_ref_audio = bool(payload.get("ref_audio_path"))
    voice_name = payload.pop("voice", None)
    if voice_name is None and not explicit_ref_audio:
        voice_name = get_default_voice_name()
    voice_profile: Dict[str, Any] = {}
    if voice_name:
        if voice_name not in voices:
            raise HTTPException(status_code=404, detail=f"voice not found: {voice_name}")
        voice_profile = deepcopy(voices[voice_name] or {})

    # Layering order: built-in defaults, config defaults, voice profile, request payload.
    tts_req = deepcopy(DEFAULT_TTS_PARAMS)
    tts_req.update(simple_config.get("defaults", {}) or {})
    tts_req.update({k: v for k, v in voice_profile.items() if k not in VOICE_META_KEYS})
    stream = payload.pop("stream", None)
    media_type = payload.pop("format", None)
    speed = payload.pop("speed", None)
    if media_type is not None:
        tts_req["media_type"] = media_type
    for key, value in payload.items():
        if key in DEFAULT_TTS_PARAMS:
            tts_req[key] = value
    if speed is not None:
        tts_req["speed_factor"] = speed

    ref_audio_path = resolve_project_path(tts_req.get("ref_audio_path"))
    if ref_audio_path:
        tts_req["ref_audio_path"] = ref_audio_path
    aux_paths = tts_req.get("aux_ref_audio_paths") or []
    tts_req["aux_ref_audio_paths"] = [resolve_project_path(item) for item in aux_paths if item]

    media_type = str(tts_req.get("media_type", "wav")).lower()
    tts_req["media_type"] = media_type
    if media_type not in SUPPORTED_MEDIA_TYPES:
        raise HTTPException(status_code=400, detail=f"format is not supported: {media_type}")

    text = str(tts_req.get("text") or "").strip()
    if not text:
        raise HTTPException(status_code=400, detail="text is required")
    tts_req["text"] = text

    if not tts_req.get("ref_audio_path"):
        raise HTTPException(status_code=400, detail="ref_audio_path is required in voice profile or request")
    if not Path(tts_req["ref_audio_path"]).exists():
        raise HTTPException(status_code=400, detail=f"ref_audio_path does not exist: {tts_req['ref_audio_path']}")
    for aux_path in tts_req["aux_ref_audio_paths"]:
        if aux_path and not Path(aux_path).exists():
            raise HTTPException(status_code=400, detail=f"aux_ref_audio_path does not exist: {aux_path}")

    if tts_config is None:
        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
    text_lang = str(tts_req.get("text_lang") or "").lower()
    prompt_lang = str(tts_req.get("prompt_lang") or "").lower()
    tts_req["text_lang"] = text_lang
    tts_req["prompt_lang"] = prompt_lang
    if text_lang not in tts_config.languages:
        raise HTTPException(status_code=400, detail=f"text_lang is not supported: {text_lang}")
    if prompt_lang not in tts_config.languages:
        raise HTTPException(status_code=400, detail=f"prompt_lang is not supported: {prompt_lang}")
    if not str(tts_req.get("prompt_text") or "").strip() and prompt_text_required_for_current_model():
        version = getattr(tts_config, "version", "current")
        raise HTTPException(status_code=400, detail=f"prompt_text is required when using GPT-SoVITS {version}")

    split_method = str(tts_req.get("text_split_method") or "cut5")
    if split_method not in cut_method_names:
        raise HTTPException(status_code=400, detail=f"text_split_method is not supported: {split_method}")

    streaming_mode = tts_req.pop("streaming_mode", None)
    streaming_enabled, return_fragment, fixed_length_chunk, response_stream = normalize_streaming(streaming_mode, stream)
    tts_req["streaming_mode"] = streaming_enabled
    tts_req["return_fragment"] = return_fragment
    tts_req["fixed_length_chunk"] = fixed_length_chunk
    return tts_req, media_type, response_stream


def synthesize_once(payload: Dict[str, Any]) -> Tuple[bytes, str]:
    tts_req, media_type, response_stream = build_tts_request(payload)
    if response_stream:
        raise HTTPException(status_code=400, detail="base64 output does not support streaming")
    if tts_pipeline is None:
        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
    try:
        with infer_lock:
            sr, audio_data = next(tts_pipeline.run(tts_req))
        audio_bytes = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
        return audio_bytes, media_type
    except Exception as exc:
        raise HTTPException(status_code=400, detail={"message": "tts failed", "exception": str(exc)}) from exc


def synthesize_response(payload: Dict[str, Any]) -> Response:
    tts_req, media_type, response_stream = build_tts_request(payload)
    if tts_pipeline is None:
        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
    if response_stream:

        def streaming_generator() -> Generator[bytes, None, None]:
            first_chunk = True
            chunk_media_type = media_type
            with infer_lock:
                for sr, chunk in tts_pipeline.run(tts_req):
                    # For wav, send the header once and stream the rest as raw PCM.
                    if first_chunk and chunk_media_type == "wav":
                        yield wave_header_chunk(sample_rate=sr)
                        chunk_media_type = "raw"
                    first_chunk = False
                    yield pack_audio(BytesIO(), chunk, sr, chunk_media_type).getvalue()

        return StreamingResponse(streaming_generator(), media_type=f"audio/{media_type}")
    try:
        with infer_lock:
            sr, audio_data = next(tts_pipeline.run(tts_req))
        audio_bytes = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
        return Response(audio_bytes, media_type=f"audio/{media_type}")
    except Exception as exc:
        return JSONResponse(status_code=400, content={"message": "tts failed", "exception": str(exc)})


@APP.on_event("startup")
def startup() -> None:
    global tts_config, tts_pipeline
    tts_config = TTS_Config(tts_config_path)
    print(tts_config)
    tts_pipeline = TTS(tts_config)


@APP.get("/")
def index() -> Dict[str, Any]:
    return {
        "name": "GPT-SoVITS Simple API",
        "endpoints": [
            "/health",
            "/voices",
            "/api/tts",
            "/speak",
            "/speak/base64",
            "/admin/weights",
            "/admin/reload-config",
        ],
    }


@APP.get("/health")
def health() -> Dict[str, Any]:
    return {
        "status": "ok" if tts_pipeline is not None else "starting",
        "tts_config": tts_config_path,
        "version": getattr(tts_config, "version", None),
        "languages": getattr(tts_config, "languages", []),
    }


@APP.get("/voices")
def list_voices() -> Dict[str, Any]:
    voices = simple_config.get("voices", {}) or {}
    return {
        "default_voice": get_default_voice_name(),
        "voices": [voice_public_info(name, profile or {}) for name, profile in voices.items()],
    }


@APP.post("/api/tts")
async def mvp_tts(
    text: str = Form(...),
    ref_audio: UploadFile = File(...),
    aux_ref_audio: Optional[List[UploadFile]] = File(default=None),
    prompt_text: str = Form(default=""),
    text_lang: str = Form(default="zh"),
    prompt_lang: str = Form(default="zh"),
    format: str = Form(default="wav"),
    emotion: Optional[str] = Form(default=None),
    speed: Optional[float] = Form(default=None),
    seed: int = Form(default=-1),
) -> Response:
    request_dir = get_request_upload_dir()
    try:
        ref_audio_path = await save_upload_file(ref_audio, request_dir, "ref")
        validate_main_ref_duration(ref_audio_path)
        aux_ref_audio_paths = []
        for index, upload in enumerate(aux_ref_audio or []):
            aux_path = await save_upload_file(upload, request_dir, f"aux_{index}")
            # Aux audio has no length limit, but it must be readable.
            get_audio_duration_seconds(aux_path)
            aux_ref_audio_paths.append(aux_path)
        payload: Dict[str, Any] = {
            "text": text,
            "text_lang": text_lang,
            "ref_audio_path": ref_audio_path,
            "aux_ref_audio_paths": aux_ref_audio_paths,
            "prompt_text": prompt_text or "",
            "prompt_lang": prompt_lang,
            "format": format,
            "text_split_method": "cut5",
            "streaming_mode": 0,
            "seed": seed,
        }
        apply_emotion_preset(payload, emotion)
        # Re-pin these after the preset so a preset cannot override them.
        payload["text_split_method"] = "cut5"
        payload["streaming_mode"] = 0
        if speed is not None:
            payload["speed"] = speed
        return synthesize_response(payload)
    finally:
        shutil.rmtree(request_dir, ignore_errors=True)


@APP.get("/speak")
def speak_get(
    text: str,
    voice: Optional[str] = None,
    text_lang: Optional[str] = None,
    format: Optional[str] = None,
    stream: Optional[bool] = None,
    speed: Optional[float] = None,
) -> Response:
    payload = {
        "text": text,
        "voice": voice,
        "text_lang": text_lang,
        "format": format,
        "stream": stream,
        "speed": speed,
    }
    return synthesize_response({k: v for k, v in payload.items() if v is not None})


@APP.post("/speak")
def speak_post(request: SpeakRequest) -> Response:
    return synthesize_response(request_to_dict(request))


@APP.post("/v1/tts")
def openai_style_tts(request: SpeakRequest) -> Response:
    return synthesize_response(request_to_dict(request))


@APP.post("/speak/base64")
def speak_base64(request: SpeakRequest) -> Dict[str, Any]:
    audio_bytes, media_type = synthesize_once(request_to_dict(request))
    return {
        "media_type": f"audio/{media_type}",
        "audio_base64": base64.b64encode(audio_bytes).decode("ascii"),
    }


@APP.post("/admin/reload-config")
def reload_config() -> Dict[str, Any]:
    global simple_config
    simple_config = load_yaml_config(args.config)
    return {"message": "success", "default_voice": get_default_voice_name()}


@APP.post("/admin/weights")
def set_weights(request: WeightsRequest) -> Dict[str, Any]:
    if tts_pipeline is None:
        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")
    if not request.gpt_weights_path and not request.sovits_weights_path:
        raise HTTPException(status_code=400, detail="gpt_weights_path or sovits_weights_path is required")
    try:
        with infer_lock:
            if request.gpt_weights_path:
                tts_pipeline.init_t2s_weights(request.gpt_weights_path)
            if request.sovits_weights_path:
                tts_pipeline.init_vits_weights(request.sovits_weights_path)
    except Exception as exc:
        raise HTTPException(status_code=400, detail={"message": "change weights failed", "exception": str(exc)}) from exc
    return {"message": "success"}


if __name__ == "__main__":
    import uvicorn

    try:
        uvicorn.run(app=APP, host=None if host == "None" else host, port=port, workers=1)
    except Exception:
        traceback.print_exc()
        sys.exit(1)

simple_api.yaml Normal file

@ -0,0 +1,59 @@
server:
host: 127.0.0.1
port: 9881
tts_config: GPT_SoVITS/configs/tts_infer.yaml
cors_allow_origins:
- "*"
upload:
  dir: runtime/uploads
  min_ref_seconds: 3
  max_ref_seconds: 10
  max_upload_mb: 80

default_voice: default

defaults:
  text_lang: zh
  prompt_lang: zh
  media_type: wav
  text_split_method: cut5
  batch_size: 1
  batch_threshold: 0.75
  split_bucket: true
  speed_factor: 1.0
  fragment_interval: 0.3
  seed: -1
  parallel_infer: true
  repetition_penalty: 1.35
  sample_steps: 32
  super_sampling: false
  overlap_length: 2
  min_chunk_length: 16

emotion_presets:
  neutral: {}
  happy:
    temperature: 1.1
    top_p: 0.95
  calm:
    temperature: 0.8
    top_p: 0.85
    speed_factor: 0.92
  sad:
    temperature: 0.75
    top_p: 0.85
    speed_factor: 0.9
  angry:
    temperature: 1.2
    top_k: 20
    repetition_penalty: 1.25

voices:
  default:
    description: Replace this profile with your reference voice.
    ref_audio_path: reference.wav
    prompt_text: Replace this with the exact text spoken in reference.wav.
    prompt_lang: zh
    text_lang: zh
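
The config above layers three sources of synthesis parameters: `defaults`, the selected entry in `emotion_presets`, and an explicit `speed` from the request (which the frontend hint and the contract tests say overrides the preset speed). A minimal sketch of that precedence follows. The merging code in `simple_api.py` is not shown in this part of the diff, so `resolve_params`, `DEFAULTS`, and `EMOTION_PRESETS` are hypothetical names, and falling back to the defaults for an unknown emotion is an assumption.

```python
from typing import Any, Dict, Optional

# Subset of the values from simple_api.yaml, copied here for illustration.
DEFAULTS: Dict[str, Any] = {"speed_factor": 1.0, "repetition_penalty": 1.35}
EMOTION_PRESETS: Dict[str, Dict[str, Any]] = {
    "neutral": {},
    "calm": {"temperature": 0.8, "top_p": 0.85, "speed_factor": 0.92},
    "angry": {"temperature": 1.2, "top_k": 20, "repetition_penalty": 1.25},
}


def resolve_params(emotion: str, speed: Optional[float] = None) -> Dict[str, Any]:
    """Merge defaults < emotion preset < explicit speed (hypothetical helper)."""
    params = dict(DEFAULTS)
    # Assumption: an unknown emotion name silently falls back to the defaults.
    params.update(EMOTION_PRESETS.get(emotion, {}))
    if speed is not None:
        # An explicit speed wins over any speed_factor set by the preset.
        params["speed_factor"] = speed
    return params


calm = resolve_params("calm")
calm_fast = resolve_params("calm", speed=1.05)
```

Here `calm` keeps the preset's `speed_factor`, while `calm_fast` keeps the preset's `temperature` but takes its speed from the request.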

713
test_frontend/index.html Normal file

@ -0,0 +1,713 @@
<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>GPT-SoVITS API Test</title>
<style>
:root {
color-scheme: light;
--bg: #f5f7f4;
--panel: #ffffff;
--ink: #18201d;
--muted: #64706b;
--line: #d8ded9;
--accent: #19745f;
--accent-strong: #0f5f4c;
--warn: #a15d12;
--bad: #b42318;
--good: #18794e;
--shadow: 0 18px 50px rgba(31, 44, 38, 0.12);
font-family: Inter, ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
}
* {
box-sizing: border-box;
}
body {
margin: 0;
min-height: 100vh;
background:
linear-gradient(135deg, rgba(25, 116, 95, 0.08), rgba(248, 251, 247, 0) 42%),
var(--bg);
color: var(--ink);
}
main {
width: min(1180px, calc(100vw - 32px));
margin: 0 auto;
padding: 28px 0 44px;
}
header {
display: grid;
grid-template-columns: minmax(0, 1fr) auto;
align-items: end;
gap: 20px;
padding: 8px 0 22px;
border-bottom: 1px solid var(--line);
}
h1 {
margin: 0;
font-size: clamp(28px, 4vw, 54px);
line-height: 1.02;
letter-spacing: 0;
}
.sub {
margin: 12px 0 0;
color: var(--muted);
max-width: 760px;
line-height: 1.6;
}
.status {
min-width: 172px;
padding: 12px 14px;
border: 1px solid var(--line);
border-radius: 8px;
background: rgba(255, 255, 255, 0.68);
color: var(--muted);
text-align: right;
font-size: 14px;
}
.status strong {
display: block;
color: var(--ink);
font-size: 16px;
margin-bottom: 2px;
}
.workspace {
display: grid;
grid-template-columns: minmax(0, 1.05fr) minmax(320px, 0.65fr);
gap: 22px;
padding-top: 24px;
align-items: start;
}
section {
background: var(--panel);
border: 1px solid var(--line);
border-radius: 8px;
box-shadow: var(--shadow);
}
.form {
padding: 22px;
}
.result {
padding: 20px;
position: sticky;
top: 18px;
}
.grid {
display: grid;
grid-template-columns: repeat(2, minmax(0, 1fr));
gap: 16px;
}
.full {
grid-column: 1 / -1;
}
label {
display: block;
color: var(--ink);
font-weight: 650;
font-size: 14px;
margin-bottom: 8px;
}
input,
textarea,
select {
width: 100%;
border: 1px solid var(--line);
border-radius: 8px;
background: #fbfcfb;
color: var(--ink);
font: inherit;
padding: 12px;
outline: none;
transition: border-color 150ms ease, box-shadow 150ms ease, background 150ms ease;
}
input:focus,
textarea:focus,
select:focus {
border-color: rgba(25, 116, 95, 0.72);
box-shadow: 0 0 0 4px rgba(25, 116, 95, 0.12);
background: #fff;
}
textarea {
min-height: 142px;
resize: vertical;
line-height: 1.55;
}
.hint {
margin: 7px 0 0;
color: var(--muted);
font-size: 13px;
line-height: 1.45;
}
.file-line {
display: flex;
align-items: center;
gap: 10px;
min-height: 24px;
color: var(--muted);
font-size: 13px;
margin-top: 8px;
}
.duration-ok {
color: var(--good);
}
.duration-warn {
color: var(--warn);
}
.actions {
display: flex;
flex-wrap: wrap;
align-items: center;
gap: 12px;
margin-top: 18px;
padding-top: 18px;
border-top: 1px solid var(--line);
}
button,
.download {
border: 0;
border-radius: 8px;
min-height: 44px;
padding: 0 16px;
background: var(--accent);
color: #fff;
font-weight: 700;
font-size: 14px;
cursor: pointer;
display: inline-flex;
align-items: center;
justify-content: center;
gap: 8px;
text-decoration: none;
transition: transform 140ms ease, background 140ms ease, opacity 140ms ease;
}
button:hover,
.download:hover {
background: var(--accent-strong);
transform: translateY(-1px);
}
button:disabled,
.download[aria-disabled="true"] {
opacity: 0.55;
cursor: not-allowed;
transform: none;
}
.ghost {
background: #e9efeb;
color: var(--ink);
}
.ghost:hover {
background: #dde7e1;
}
.result h2 {
margin: 0 0 12px;
font-size: 20px;
letter-spacing: 0;
}
.log {
min-height: 126px;
border-radius: 8px;
border: 1px solid var(--line);
background: #101815;
color: #d8f3e9;
padding: 13px;
font-family: ui-monospace, SFMono-Regular, Consolas, "Liberation Mono", monospace;
font-size: 12px;
line-height: 1.6;
white-space: pre-wrap;
overflow-wrap: anywhere;
}
audio {
width: 100%;
margin-top: 16px;
}
.meta {
display: grid;
grid-template-columns: repeat(2, minmax(0, 1fr));
gap: 10px;
margin: 16px 0;
}
.metric {
border: 1px solid var(--line);
border-radius: 8px;
padding: 12px;
background: #fbfcfb;
}
.metric span {
display: block;
color: var(--muted);
font-size: 12px;
margin-bottom: 3px;
}
.metric strong {
font-size: 15px;
}
.danger {
color: var(--bad);
}
@media (max-width: 860px) {
header,
.workspace,
.grid {
grid-template-columns: 1fr;
}
.status {
text-align: left;
}
.result {
position: static;
}
}
</style>
</head>
<body>
<main>
<header>
<div>
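<h1>GPT-SoVITS API Test Console</h1>
<p class="sub">Choose a 3-10 second reference audio or video clip (audio is extracted from video automatically), enter the backend endpoint URL and the text to synthesize, then call the middle-layer <code>/api/tts</code> endpoint directly.</p>
</div>
<div class="status" id="statusBox">
<strong>Not checked</strong>
Backend connection status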
</div>
</header>
<div class="workspace">
<section class="form">
<form id="ttsForm">
<div class="grid">
<div class="full">
<label for="endpoint">Backend endpoint URL</label>
<input id="endpoint" name="endpoint" type="url" value="http://127.0.0.1:9881/api/tts" required>
<p class="hint">If the backend host or port changes, edit the full URL here. The page POSTs the form directly to this address.</p>
</div>
<div class="full">
<label for="text">Text to synthesize</label>
<textarea id="text" name="text" placeholder="Enter the text to synthesize. The backend always splits sentences on punctuation." required></textarea>
</div>
<div>
<label for="refAudio">Main reference audio/video</label>
<input id="refAudio" name="ref_audio" type="file" accept="audio/*,video/*" required>
<div class="file-line" id="refInfo">Choose a 3-10 second audio or video file (audio is extracted from video automatically)</div>
<div class="file-line" id="extractInfo" style="display:none;"></div>
</div>
<div>
<label for="auxAudio">Auxiliary reference audio</label>
<input id="auxAudio" name="aux_ref_audio" type="file" accept="audio/*" multiple>
<div class="file-line" id="auxInfo">Optional; multiple files allowed</div>
</div>
<div class="full">
<label for="promptText">Reference audio transcript</label>
<textarea id="promptText" name="prompt_text" placeholder="Optional. v2 accepts an empty reference transcript; v3/v4 backends require it."></textarea>
</div>
<div>
<label for="textLang">Target text language</label>
<select id="textLang" name="text_lang">
<option value="zh">zh</option>
<option value="en">en</option>
<option value="ja">ja</option>
<option value="ko">ko</option>
<option value="yue">yue</option>
<option value="auto">auto</option>
</select>
</div>
<div>
<label for="promptLang">Reference audio language</label>
<select id="promptLang" name="prompt_lang">
<option value="zh">zh</option>
<option value="en">en</option>
<option value="ja">ja</option>
<option value="ko">ko</option>
<option value="yue">yue</option>
<option value="auto">auto</option>
</select>
</div>
<div>
<label for="emotion">Emotion preset</label>
<select id="emotion" name="emotion">
<option value="neutral">neutral</option>
<option value="happy">happy</option>
<option value="calm">calm</option>
<option value="sad">sad</option>
<option value="angry">angry</option>
</select>
</div>
<div>
<label for="speed">Speed</label>
<input id="speed" name="speed" type="number" min="0.5" max="2" step="0.05" value="1">
<p class="hint">An explicit speed overrides the speed set by the emotion preset.</p>
</div>
<div>
<label for="seed">Seed</label>
<input id="seed" name="seed" type="number" value="-1">
</div>
<div>
<label for="format">Output format</label>
<select id="format" name="format">
<option value="wav">wav</option>
<option value="ogg">ogg</option>
<option value="aac">aac</option>
<option value="raw">raw</option>
</select>
</div>
</div>
<div class="actions">
<button type="button" class="ghost" id="healthBtn">Check backend</button>
<button type="submit" id="submitBtn">Generate audio</button>
<button type="button" class="ghost" id="resetBtn">Clear result</button>
</div>
</form>
</section>
<section class="result">
<h2>Result</h2>
<div class="meta">
<div class="metric"><span>Elapsed</span><strong id="elapsed">-</strong></div>
<div class="metric"><span>File size</span><strong id="fileSize">-</strong></div>
</div>
<div class="log" id="log">Waiting for a request.</div>
<audio id="player" controls hidden></audio>
<div class="actions">
<a class="download" id="downloadLink" aria-disabled="true">Download audio</a>
</div>
</section>
</div>
</main>
<script>
const form = document.querySelector("#ttsForm");
const endpoint = document.querySelector("#endpoint");
const refAudio = document.querySelector("#refAudio");
const auxAudio = document.querySelector("#auxAudio");
const refInfo = document.querySelector("#refInfo");
const auxInfo = document.querySelector("#auxInfo");
const logBox = document.querySelector("#log");
const player = document.querySelector("#player");
const downloadLink = document.querySelector("#downloadLink");
const submitBtn = document.querySelector("#submitBtn");
const resetBtn = document.querySelector("#resetBtn");
const healthBtn = document.querySelector("#healthBtn");
const statusBox = document.querySelector("#statusBox");
const elapsed = document.querySelector("#elapsed");
const fileSize = document.querySelector("#fileSize");
let resultUrl = null;
if (location.protocol === "http:" || location.protocol === "https:") {
endpoint.value = new URL("/api/tts", location.origin).toString();
}
function log(message, isError = false) {
logBox.textContent = message;
logBox.classList.toggle("danger", isError);
}
function bytesLabel(bytes) {
if (!bytes) return "-";
const units = ["B", "KB", "MB", "GB"];
let value = bytes;
let index = 0;
while (value >= 1024 && index < units.length - 1) {
value /= 1024;
index += 1;
}
return `${value.toFixed(index === 0 ? 0 : 2)} ${units[index]}`;
}
function apiBaseUrl() {
try {
const url = new URL(endpoint.value.trim());
url.pathname = url.pathname.replace(/\/api\/tts\/?$/, "/health");
url.search = "";
url.hash = "";
return url.toString();
} catch {
return "";
}
}
function clearResult() {
if (resultUrl) URL.revokeObjectURL(resultUrl);
resultUrl = null;
player.hidden = true;
player.removeAttribute("src");
downloadLink.removeAttribute("href");
downloadLink.setAttribute("aria-disabled", "true");
elapsed.textContent = "-";
fileSize.textContent = "-";
log("Waiting for a request.");
}
let extractedAudioBlob = null;
function isVideoFile(file) {
return file && file.type && file.type.startsWith("video/");
}
async function extractAudioFromVideo(file) {
const extractInfo = document.querySelector("#extractInfo");
extractInfo.style.display = "block";
extractInfo.textContent = "Extracting audio from video...";
extractInfo.className = "file-line";
try {
const video = document.createElement("video");
video.preload = "auto";
const videoUrl = URL.createObjectURL(file);
video.src = videoUrl;
await new Promise((resolve, reject) => {
video.onloadeddata = resolve;
video.onerror = () => reject(new Error("Unable to load the video file"));
});
const audioCtx = new (window.AudioContext || window.webkitAudioContext)();
const response = await fetch(videoUrl);
const arrayBuffer = await response.arrayBuffer();
const audioBuffer = await audioCtx.decodeAudioData(arrayBuffer);
const wavBlob = audioBufferToWav(audioBuffer);
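URL.revokeObjectURL(videoUrl);
audioCtx.close();
// Ignore the result if the user picked a different file while decoding was in progress.
if (refAudio.files[0] !== file) return false;
extractedAudioBlob = wavBlob;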
const duration = audioBuffer.duration;
const ok = Number.isFinite(duration) && duration >= 3 && duration <= 10;
extractInfo.textContent = `Audio extracted · ${duration.toFixed(2)}s · ${bytesLabel(wavBlob.size)}${ok ? " ✓" : " ⚠ consider trimming to 3-10 seconds"}`;
extractInfo.className = `file-line ${ok ? "duration-ok" : "duration-warn"}`;
return true;
} catch (err) {
extractInfo.textContent = `Extraction failed: ${err.message}`;
extractInfo.className = "file-line duration-warn";
extractedAudioBlob = null;
return false;
}
}
function audioBufferToWav(buffer) {
const numChannels = buffer.numberOfChannels;
const sampleRate = buffer.sampleRate;
const format = 1;
const bitDepth = 16;
const bytesPerSample = bitDepth / 8;
const blockAlign = numChannels * bytesPerSample;
const dataLength = buffer.length * blockAlign;
const headerLength = 44;
const totalLength = headerLength + dataLength;
const arrayBuffer = new ArrayBuffer(totalLength);
const view = new DataView(arrayBuffer);
function writeString(offset, str) {
for (let i = 0; i < str.length; i++) view.setUint8(offset + i, str.charCodeAt(i));
}
writeString(0, "RIFF");
view.setUint32(4, totalLength - 8, true);
writeString(8, "WAVE");
writeString(12, "fmt ");
view.setUint32(16, 16, true);
view.setUint16(20, format, true);
view.setUint16(22, numChannels, true);
view.setUint32(24, sampleRate, true);
view.setUint32(28, sampleRate * blockAlign, true);
view.setUint16(32, blockAlign, true);
view.setUint16(34, bitDepth, true);
writeString(36, "data");
view.setUint32(40, dataLength, true);
const channels = [];
for (let ch = 0; ch < numChannels; ch++) channels.push(buffer.getChannelData(ch));
let offset = 44;
for (let i = 0; i < buffer.length; i++) {
for (let ch = 0; ch < numChannels; ch++) {
const sample = Math.max(-1, Math.min(1, channels[ch][i]));
view.setInt16(offset, sample < 0 ? sample * 0x8000 : sample * 0x7FFF, true);
offset += 2;
}
}
return new Blob([arrayBuffer], { type: "audio/wav" });
}
function inspectDuration(file, target) {
extractedAudioBlob = null;
const extractInfo = document.querySelector("#extractInfo");
extractInfo.style.display = "none";
if (!file) {
target.textContent = "Choose a 3-10 second audio or video file";
target.className = "file-line";
return;
}
if (isVideoFile(file)) {
target.textContent = `${file.name} · ${bytesLabel(file.size)} · video file`;
target.className = "file-line";
extractAudioFromVideo(file);
return;
}
const url = URL.createObjectURL(file);
const audio = new Audio();
audio.preload = "metadata";
audio.onloadedmetadata = () => {
URL.revokeObjectURL(url);
const duration = audio.duration;
const ok = Number.isFinite(duration) && duration >= 3 && duration <= 10;
target.textContent = `${file.name} · ${duration.toFixed(2)}s · ${bytesLabel(file.size)}`;
target.className = `file-line ${ok ? "duration-ok" : "duration-warn"}`;
};
audio.onerror = () => {
URL.revokeObjectURL(url);
target.textContent = `${file.name} · duration unavailable · ${bytesLabel(file.size)}`;
target.className = "file-line duration-warn";
};
audio.src = url;
}
refAudio.addEventListener("change", () => {
inspectDuration(refAudio.files[0], refInfo);
});
auxAudio.addEventListener("change", () => {
const count = auxAudio.files.length;
auxInfo.textContent = count ? `${count} auxiliary audio file(s) selected` : "Optional; multiple files allowed";
});
healthBtn.addEventListener("click", async () => {
const healthUrl = apiBaseUrl();
if (!healthUrl) {
log("The backend URL is malformed.", true);
return;
}
statusBox.innerHTML = "<strong>Checking</strong>Requesting /health";
try {
const response = await fetch(healthUrl);
const data = await response.json();
if (!response.ok) throw new Error(JSON.stringify(data));
statusBox.innerHTML = `<strong>Reachable</strong>${data.version || "unknown"} · ${data.status || "ok"}`;
log(JSON.stringify(data, null, 2));
} catch (error) {
statusBox.innerHTML = "<strong>Connection failed</strong>Check that the backend is running";
log(`Health check failed: ${error.message}`, true);
}
});
resetBtn.addEventListener("click", clearResult);
form.addEventListener("submit", async (event) => {
event.preventDefault();
clearResult();
const file = refAudio.files[0];
if (!file) {
log("Select a main reference audio or video file first.", true);
return;
}
if (isVideoFile(file) && !extractedAudioBlob) {
log("Audio extraction from the video has not finished yet. Please try again shortly.", true);
return;
}
const started = performance.now();
const data = new FormData();
data.append("text", document.querySelector("#text").value.trim());
if (extractedAudioBlob) {
data.append("ref_audio", extractedAudioBlob, "extracted_audio.wav");
} else {
data.append("ref_audio", file);
}
for (const aux of auxAudio.files) data.append("aux_ref_audio", aux);
data.append("prompt_text", document.querySelector("#promptText").value);
data.append("text_lang", document.querySelector("#textLang").value);
data.append("prompt_lang", document.querySelector("#promptLang").value);
data.append("format", document.querySelector("#format").value);
data.append("emotion", document.querySelector("#emotion").value);
data.append("speed", document.querySelector("#speed").value);
data.append("seed", document.querySelector("#seed").value);
submitBtn.disabled = true;
log("Request sent. Waiting for the model to generate audio.");
try {
const response = await fetch(endpoint.value.trim(), { method: "POST", body: data });
const contentType = response.headers.get("content-type") || "";
if (!response.ok) {
const detail = contentType.includes("application/json") ? await response.json() : await response.text();
throw new Error(typeof detail === "string" ? detail : JSON.stringify(detail, null, 2));
}
const blob = await response.blob();
resultUrl = URL.createObjectURL(blob);
player.src = resultUrl;
player.hidden = false;
downloadLink.href = resultUrl;
downloadLink.download = `gpt-sovits-${Date.now()}.${document.querySelector("#format").value}`;
downloadLink.setAttribute("aria-disabled", "false");
elapsed.textContent = `${((performance.now() - started) / 1000).toFixed(2)}s`;
fileSize.textContent = bytesLabel(blob.size);
log(`Generation succeeded.\nContent-Type: ${contentType || "unknown"}\nSize: ${bytesLabel(blob.size)}`);
} catch (error) {
log(`Generation failed:\n${error.message}`, true);
} finally {
submitBtn.disabled = false;
}
});
</script>
</body>
</html>
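
The `audioBufferToWav` function in the page above writes a 44-byte RIFF/WAVE header followed by interleaved 16-bit PCM samples. As a cross-check of that layout, here is a minimal Python sketch that packs the same header fields with `struct` and reads the result back with the standard-library `wave` module. The function name `pcm16_wav` is hypothetical; it mirrors the JavaScript logic rather than any code in `simple_api.py`.

```python
import io
import struct
import wave


def pcm16_wav(samples, sample_rate, num_channels=1):
    """Pack interleaved float samples in [-1, 1] into a 16-bit PCM WAV file."""
    block_align = num_channels * 2  # bytes per frame at 16 bits per sample
    data = b"".join(
        # Same scaling as the JS code: negatives use 0x8000, positives 0x7FFF.
        struct.pack("<h", int(max(-1.0, min(1.0, s)) * (0x8000 if s < 0 else 0x7FFF)))
        for s in samples
    )
    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF", 36 + len(data), b"WAVE",          # RIFF chunk: total size - 8
        b"fmt ", 16, 1, num_channels, sample_rate,  # fmt chunk, format 1 = PCM
        sample_rate * block_align, block_align, 16,  # byte rate, align, bit depth
        b"data", len(data),
    )
    return header + data


wav_bytes = pcm16_wav([0.0, 1.0, -1.0, 0.5], sample_rate=16000)

# Read the bytes back with the stdlib parser to confirm the header is valid.
with wave.open(io.BytesIO(wav_bytes)) as reader:
    info = (reader.getnchannels(), reader.getsampwidth(), reader.getframerate(), reader.getnframes())
    decoded = struct.unpack("<4h", reader.readframes(4))
```

Because `wave` accepts the output, the same check could be applied server-side to an uploaded `extracted_audio.wav`, for example to compute its duration as `getnframes() / getframerate()`.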


@ -0,0 +1,303 @@
import asyncio
import importlib.util
import sys
import tempfile
import types
import unittest
from pathlib import Path
PROJECT_ROOT = Path(__file__).resolve().parents[1]
SIMPLE_API_PATH = PROJECT_ROOT / "simple_api.py"


class DummyHTTPException(Exception):
    def __init__(self, status_code=500, detail=None):
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


class DummyFastAPI:
    def __init__(self, *args, **kwargs):
        self.routes = []

    def add_middleware(self, *args, **kwargs):
        pass

    def mount(self, path, app, name=None):
        self.routes.append(types.SimpleNamespace(path=path, endpoint=app, name=name))

    def get(self, path):
        return self._route(path)

    def post(self, path):
        return self._route(path)

    def on_event(self, event):
        return lambda func: func

    def _route(self, path):
        def decorator(func):
            self.routes.append(types.SimpleNamespace(path=path, endpoint=func))
            return func

        return decorator


class DummyBaseModel:
    def dict(self, exclude_none=False):
        data = dict(self.__dict__)
        if exclude_none:
            data = {key: value for key, value in data.items() if value is not None}
        return data


class DummyTTSConfig:
    version = "v2"
    languages = ["zh", "en", "auto"]
    use_vocoder = False


class FakeUpload:
    def __init__(self, filename, chunks, content_type="audio/wav"):
        self.filename = filename
        self.content_type = content_type
        self._chunks = list(chunks)
        self.closed = False

    async def read(self, size):
        # The requested size is ignored; each call returns the next queued chunk.
        del size
        if self._chunks:
            return self._chunks.pop(0)
        return b""

    async def close(self):
        self.closed = True


def install_dummy_modules():
    # Replace heavy third-party imports with stubs so simple_api.py can be
    # imported without FastAPI, NumPy, soundfile, PyYAML, or GPT-SoVITS.
    fastapi = types.ModuleType("fastapi")
    fastapi.FastAPI = DummyFastAPI
    fastapi.File = lambda default=None, **kwargs: default
    fastapi.Form = lambda default=None, **kwargs: default
    fastapi.HTTPException = DummyHTTPException
    fastapi.Response = type("Response", (), {})
    fastapi.UploadFile = type("UploadFile", (), {})
    sys.modules["fastapi"] = fastapi

    middleware = types.ModuleType("fastapi.middleware")
    cors = types.ModuleType("fastapi.middleware.cors")
    cors.CORSMiddleware = type("CORSMiddleware", (), {})
    sys.modules["fastapi.middleware"] = middleware
    sys.modules["fastapi.middleware.cors"] = cors

    responses = types.ModuleType("fastapi.responses")
    responses.JSONResponse = type("JSONResponse", (), {})
    responses.StreamingResponse = type("StreamingResponse", (), {})
    sys.modules["fastapi.responses"] = responses

    staticfiles = types.ModuleType("fastapi.staticfiles")
    staticfiles.StaticFiles = type("StaticFiles", (), {"__init__": lambda self, *args, **kwargs: None})
    sys.modules["fastapi.staticfiles"] = staticfiles

    pydantic = types.ModuleType("pydantic")
    pydantic.BaseModel = DummyBaseModel
    pydantic.Field = lambda default=None, **kwargs: default
    sys.modules["pydantic"] = pydantic

    numpy = types.ModuleType("numpy")
    numpy.ndarray = object
    sys.modules["numpy"] = numpy

    soundfile = types.ModuleType("soundfile")
    soundfile.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000)
    sys.modules["soundfile"] = soundfile

    yaml = types.ModuleType("yaml")
    yaml.safe_load = lambda stream: {
        "server": {"host": "127.0.0.1", "port": 9881, "tts_config": "dummy.yaml"},
        "upload": {"dir": "runtime/test_uploads", "min_ref_seconds": 3, "max_ref_seconds": 10, "max_upload_mb": 1},
        "defaults": {"text_lang": "zh", "prompt_lang": "zh", "media_type": "wav", "text_split_method": "cut5"},
        "emotion_presets": {"calm": {"temperature": 0.8, "speed_factor": 0.9}},
        "voices": {},
    }
    sys.modules["yaml"] = yaml

    sys.modules["GPT_SoVITS"] = types.ModuleType("GPT_SoVITS")
    sys.modules["GPT_SoVITS.TTS_infer_pack"] = types.ModuleType("GPT_SoVITS.TTS_infer_pack")
    tts_module = types.ModuleType("GPT_SoVITS.TTS_infer_pack.TTS")
    tts_module.TTS = type("DummyTTS", (), {})
    tts_module.TTS_Config = DummyTTSConfig
    sys.modules["GPT_SoVITS.TTS_infer_pack.TTS"] = tts_module

    segmentation = types.ModuleType("GPT_SoVITS.TTS_infer_pack.text_segmentation_method")
    segmentation.get_method_names = lambda: ["cut0", "cut5"]
    sys.modules["GPT_SoVITS.TTS_infer_pack.text_segmentation_method"] = segmentation


def load_simple_api():
    install_dummy_modules()
    sys.argv = ["simple_api.py"]
    spec = importlib.util.spec_from_file_location("simple_api_under_test", SIMPLE_API_PATH)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    module.tts_config = DummyTTSConfig()
    return module


class SimpleApiContractTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.simple_api = load_simple_api()

    def test_api_tts_route_is_registered(self):
        routes = {route.path for route in self.simple_api.APP.routes}
        self.assertIn("/api/tts", routes)

    def test_build_request_accepts_explicit_ref_without_voice_profile(self):
        with tempfile.NamedTemporaryFile(suffix=".wav") as ref:
            req, media_type, response_stream = self.simple_api.build_tts_request(
                {
                    "text": "Hello, test.",
                    "ref_audio_path": ref.name,
                    "prompt_text": "",
                    "prompt_lang": "zh",
                    "text_lang": "zh",
                    "format": "wav",
                    "text_split_method": "cut5",
                    "streaming_mode": 0,
                }
            )
            self.assertEqual(req["prompt_text"], "")
            self.assertEqual(req["text_split_method"], "cut5")
            self.assertEqual(req["ref_audio_path"], ref.name)
            self.assertEqual(media_type, "wav")
            self.assertFalse(response_stream)

    def test_explicit_speed_overrides_emotion_speed_preset(self):
        with tempfile.NamedTemporaryFile(suffix=".wav") as ref:
            payload = {
                "text": "Hello.",
                "ref_audio_path": ref.name,
                "prompt_lang": "zh",
                "text_lang": "zh",
                "format": "wav",
                "streaming_mode": 0,
                "speed": 1.05,
            }
            self.simple_api.apply_emotion_preset(payload, "calm")
            payload["speed"] = 1.05
            req, _, _ = self.simple_api.build_tts_request(payload)
            self.assertEqual(req["temperature"], 0.8)
            self.assertEqual(req["speed_factor"], 1.05)

    def test_empty_prompt_text_is_rejected_for_vocoder_models(self):
        original_use_vocoder = self.simple_api.tts_config.use_vocoder
        original_version = self.simple_api.tts_config.version
        self.simple_api.tts_config.use_vocoder = True
        self.simple_api.tts_config.version = "v3"
        try:
            with tempfile.NamedTemporaryFile(suffix=".wav") as ref:
                with self.assertRaises(DummyHTTPException) as exc:
                    self.simple_api.build_tts_request(
                        {
                            "text": "Hello.",
                            "ref_audio_path": ref.name,
                            "prompt_text": "",
                            "prompt_lang": "zh",
                            "text_lang": "zh",
                            "format": "wav",
                            "text_split_method": "cut5",
                            "streaming_mode": 0,
                        }
                    )
            self.assertEqual(exc.exception.status_code, 400)
            self.assertIn("prompt_text is required", str(exc.exception.detail))
        finally:
            self.simple_api.tts_config.use_vocoder = original_use_vocoder
            self.simple_api.tts_config.version = original_version

    def test_ref_audio_duration_limits(self):
        self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000)
        self.simple_api.validate_main_ref_duration("ok.wav")

        self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 2, samplerate=16000)
        with self.assertRaises(DummyHTTPException) as too_short:
            self.simple_api.validate_main_ref_duration("short.wav")
        self.assertEqual(too_short.exception.status_code, 400)

        self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 11, samplerate=16000)
        with self.assertRaises(DummyHTTPException) as too_long:
            self.simple_api.validate_main_ref_duration("long.wav")
        self.assertEqual(too_long.exception.status_code, 400)

    def test_save_upload_file_writes_and_closes(self):
        with tempfile.TemporaryDirectory() as tmp:
            upload = FakeUpload("ref.wav", [b"abc", b"def"])
            saved_path = asyncio.run(self.simple_api.save_upload_file(upload, Path(tmp), "ref"))
            self.assertTrue(upload.closed)
            self.assertEqual(Path(saved_path).read_bytes(), b"abcdef")

    def test_mvp_tts_builds_payload_from_uploads_and_cleans_tmp_dir(self):
        captured = {}
        with tempfile.TemporaryDirectory() as tmp:
            request_dir = Path(tmp) / "request"
            request_dir.mkdir()
            original_get_request_upload_dir = self.simple_api.get_request_upload_dir
            original_synthesize_response = self.simple_api.synthesize_response
            original_sf_info = self.simple_api.sf.info

            def fake_synthesize_response(payload):
                captured.update(payload)
                return "audio-response"

            try:
                self.simple_api.get_request_upload_dir = lambda: request_dir
                self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000)
                self.simple_api.synthesize_response = fake_synthesize_response
                response = asyncio.run(
                    self.simple_api.mvp_tts(
                        text="Hello, test.",
                        ref_audio=FakeUpload("ref.wav", [b"main"]),
                        aux_ref_audio=[FakeUpload("aux.wav", [b"aux"])],
                        prompt_text="",
                        text_lang="zh",
                        prompt_lang="zh",
                        format="wav",
                        emotion="calm",
                        speed=1.05,
                        seed=123,
                    )
                )
            finally:
                self.simple_api.get_request_upload_dir = original_get_request_upload_dir
                self.simple_api.synthesize_response = original_synthesize_response
                self.simple_api.sf.info = original_sf_info

            self.assertEqual(response, "audio-response")
            # The per-request upload directory must be removed by mvp_tts itself.
            self.assertFalse(request_dir.exists())
            self.assertEqual(captured["text"], "Hello, test.")
            self.assertEqual(captured["prompt_text"], "")
            self.assertEqual(captured["text_lang"], "zh")
            self.assertEqual(captured["prompt_lang"], "zh")
            self.assertEqual(captured["format"], "wav")
            self.assertEqual(captured["text_split_method"], "cut5")
            self.assertEqual(captured["streaming_mode"], 0)
            self.assertEqual(captured["seed"], 123)
            self.assertEqual(captured["temperature"], 0.8)
            self.assertEqual(captured["speed"], 1.05)
            self.assertEqual(len(captured["aux_ref_audio_paths"]), 1)


if __name__ == "__main__":
    unittest.main()
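
The contract tests above avoid loading FastAPI, soundfile, or the real TTS engine by placing stub modules into `sys.modules` before `simple_api.py` is imported. The sketch below isolates that technique on a single throwaway module. The helper `install_stub` and the module name `fake_soundfile_for_demo` are hypothetical and exist only for this illustration; the 5-second figure mirrors the `soundfile.info` stub used in the tests.

```python
import sys
import types


def install_stub(name: str, **attrs) -> types.ModuleType:
    """Register a fake module under `name` so later imports resolve to it."""
    module = types.ModuleType(name)
    for key, value in attrs.items():
        setattr(module, key, value)
    sys.modules[name] = module
    return module


# Stub that always reports a 5-second, 16 kHz file regardless of the path given.
install_stub(
    "fake_soundfile_for_demo",
    info=lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000),
)

# The import system finds the stub in sys.modules and never touches the disk.
import fake_soundfile_for_demo as sf

meta = sf.info("any.wav")
duration = meta.frames / meta.samplerate

# Remove the stub again so it cannot leak into unrelated code.
del sys.modules["fake_soundfile_for_demo"]
```

One design consequence worth noting: stubs registered this way stay in `sys.modules` for the rest of the process unless they are removed, so a test module that stubs `numpy` or `fastapi` is best run in its own process.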