diff --git a/README.md b/README.md index 2511c73c..24abdc66 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -
+

GPT-SoVITS-WebUI

A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.

@@ -27,6 +27,24 @@ A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.

--- +## 简化接口 / 测试前端 + +本工作区新增了一个用于 MVP 调用的中间层接口和测试前端: + +- 快速启动:[docs/simple_api_quickstart.md](./docs/simple_api_quickstart.md) +- 完整教程:[docs/simple_api.md](./docs/simple_api.md) +- 后端入口:simple_api.py +- 测试前端:启动后访问 http://127.0.0.1:9881/test/ + +启动命令: + +`powershell +cd D:\tts\GPT-SoVITS +.\go-simple-api.ps1 +` + +--- + ## Features: 1. **Zero-shot TTS:** Input a 5-second vocal sample and experience instant text-to-speech conversion. @@ -478,3 +496,4 @@ Thankful to @Naozumi520 for providing the Cantonese training set and for the gui + diff --git a/docs/simple_api.md b/docs/simple_api.md new file mode 100644 index 00000000..bcdedaf6 --- /dev/null +++ b/docs/simple_api.md @@ -0,0 +1,320 @@ +# GPT-SoVITS 简化接口教程 + +本项目原本已经有 `api_v2.py`,但调用时需要传很多 GPT-SoVITS 参数。新增的 `simple_api.py` 是一个中间层,目标是让前端或业务方用更简单的方式调用。 + +当前 MVP 推荐使用: + +```http +POST /api/tts +Content-Type: multipart/form-data +``` + +适合你的流程: + +1. 前端从视频中提取音频(或直接上传已裁剪的 3-10 秒音频)。 +2. 用户人工裁剪主参考音频到 3-10 秒。 +3. 前端把主参考音频、要生成的文字、可选辅助音频提交给后端。 +4. 后端调用 GPT-SoVITS 生成音频并返回。 + +## 1. 安装依赖 + +在 GPT-SoVITS 使用的 Python 环境里运行: + +```bash +python -m pip install -r requirements.txt +``` + +本中间层额外需要: + +- `python-multipart`:用于接收前端上传的音频文件。 +- `soundfile`:用于读取音频信息和校验参考音频时长。 + +这两个依赖已经写入 `requirements.txt`。 + +## 2. 检查配置 + +配置文件: + +```text +simple_api.yaml +``` + +常用配置: + +```yaml +server: + host: 127.0.0.1 + port: 9881 + tts_config: GPT_SoVITS/configs/tts_infer.yaml + +upload: + dir: runtime/uploads + min_ref_seconds: 3 + max_ref_seconds: 10 + max_upload_mb: 80 +``` + +说明: + +- `server.host`:后端监听地址。 +- `server.port`:后端端口。 +- `server.tts_config`:GPT-SoVITS 推理配置文件。 +- `upload.dir`:临时上传目录。 +- `upload.min_ref_seconds`:主参考音频最短秒数。 +- `upload.max_ref_seconds`:主参考音频最长秒数。 +- `upload.max_upload_mb`:单个上传音频最大体积。 + +当前默认后端地址: + +```text +http://127.0.0.1:9881 +``` + +## 3. 启动后端 + +进入项目目录: + +```powershell +cd D:\tts\GPT-SoVITS +``` + +方式一: + +```powershell +python simple_api.py -c simple_api.yaml +``` + +方式二,Windows PowerShell: + +```powershell +.\go-simple-api.ps1 +``` + +方式三,Windows 批处理: + +```bat +go-simple-api.bat +``` + +启动成功后,后端会监听: + +```text +http://127.0.0.1:9881 +``` + +## 4. 健康检查 + +后端启动后,可以访问: + +```text +http://127.0.0.1:9881/health +``` + +或者命令行测试: + +```bash +curl http://127.0.0.1:9881/health +``` + +返回示例: + +```json +{ + "status": "ok", + "tts_config": "GPT_SoVITS/configs/tts_infer.yaml", + "version": "v2", + "languages": ["auto", "auto_yue", "en", "zh"] +} +``` + +## 5. 打开测试前端 + +后端启动后,推荐直接打开: + +```text +http://127.0.0.1:9881/test/ +``` + +这个页面已经挂载在后端服务里,不需要额外启动前端服务。 + +也可以直接打开本地文件: + +```text +test_frontend/index.html +``` + +Windows 下也可以运行: + +```powershell +.\open-test-frontend.ps1 +``` + +测试前端里有一个“后端接口地址”输入框。 + +默认值: + +```text +http://127.0.0.1:9881/api/tts +``` + +如果你修改了 `server.host` 或 `server.port`,记得同步修改页面里的接口地址。 + +## 6. MVP 接口 + +接口地址: + +```http +POST /api/tts +``` + +请求类型: + +```http +multipart/form-data +``` + +字段说明: + +```text +text 必填。需要生成的文字。 +ref_audio 必填。主参考音频(支持上传视频,前端会自动提取音频),要求 3-10 秒。 +aux_ref_audio 可选。辅助参考音频,可以上传多个。 +prompt_text 可选。主参考音频对应文字,可以留空。 +text_lang 可选。生成文字语言,默认 zh。 +prompt_lang 可选。参考音频语言,默认 zh。即使 prompt_text 为空,也需要传给 GPT-SoVITS 内部。 +format 可选。返回格式,默认 wav。 +emotion 可选。情绪 preset:neutral、happy、calm、sad、angry。 +speed 可选。语速,对应 GPT-SoVITS 的 speed_factor。 +seed 可选。随机种子,默认 -1。 +``` + +## 7. curl 调用示例 + +PowerShell 示例: + +```powershell +curl.exe -X POST http://127.0.0.1:9881/api/tts ` + -F "text=你好,欢迎使用这个声音。" ` + -F "ref_audio=@D:\audio\ref.wav" ` + -F "prompt_text=" ` + -F "text_lang=zh" ` + -F "prompt_lang=zh" ` + -F "emotion=neutral" ` + --output output.wav +``` + +如果有辅助参考音频: + +```powershell +curl.exe -X POST http://127.0.0.1:9881/api/tts ` + -F "text=你好,欢迎使用这个声音。" ` + -F "ref_audio=@D:\audio\ref.wav" ` + -F "aux_ref_audio=@D:\audio\aux1.wav" ` + -F "aux_ref_audio=@D:\audio\aux2.wav" ` + -F "prompt_text=" ` + -F "text_lang=zh" ` + -F "prompt_lang=zh" ` + --output output.wav +``` + +## 8. 重要规则 + +- 主参考音频必须是 3-10 秒。 +- `aux_ref_audio` 是可选项。 +- `prompt_text` 可以为空,但当前主要针对 GPT-SoVITS v2。 +- 如果切到 GPT-SoVITS v3/v4,空 `prompt_text` 会被中间层直接返回 400。 +- 生成文字固定使用 `cut5`,也就是按照标点符号切句。 +- `emotion` 目前是轻量 preset,本质是映射到采样和语速参数;更稳定的情绪控制仍然依赖带情绪的参考音频。 + +## 9. 测试前端使用步骤 + +1. 启动后端。 +2. 打开 `http://127.0.0.1:9881/test/`。 +3. 检查“后端接口地址”是否为: + +```text +http://127.0.0.1:9881/api/tts +``` + +4. 填写“需要生成的文字”。 +5. 上传主参考音频。 +6. 可选上传辅助参考音频。 +7. 可选填写参考音频文字。 +8. 选择语言、情绪、语速。 +9. 点击“生成音频”。 +10. 页面右侧会显示返回结果,可以在线播放和下载。 + +## 10. 其他接口 + +除了 MVP 上传接口,还保留了 profile 调用接口: + +```http +GET /voices +GET /speak?text=hello&voice=default +POST /speak +POST /speak/base64 +POST /v1/tts +POST /admin/reload-config +POST /admin/weights +``` + +这些接口适合后续做固定音色 profile,不是当前 MVP 的主流程。 + +## 11. Base64 返回示例 + +```bash +curl -X POST http://127.0.0.1:9881/speak/base64 ^ + -H "Content-Type: application/json" ^ + -d "{\"text\":\"hello\",\"voice\":\"default\"}" +``` + +返回格式: + +```json +{ + "media_type": "audio/wav", + "audio_base64": "..." +} +``` + +## 12. 添加固定音色 profile + +如果后续要做固定音色,可以编辑 `simple_api.yaml`: + +```yaml +voices: + default: + ref_audio_path: reference.wav + prompt_text: exact transcript + prompt_lang: zh + text_lang: zh + + narrator: + ref_audio_path: voices/narrator.wav + prompt_text: exact transcript of narrator.wav + prompt_lang: zh + text_lang: zh +``` + +编辑后调用: + +```bash +curl -X POST http://127.0.0.1:9881/admin/reload-config +``` + +## 13. 契约测试 + +这个测试使用 mock,不会加载 GPT-SoVITS 模型: + +```bash +python -m unittest tests.test_simple_api_contract +``` + +用于确认: + +- `/api/tts` 路由存在。 +- 上传接口能构造正确参数。 +- 主参考音频 3-10 秒校验正常。 +- 空 `prompt_text` 在 v2 可用。 +- v3/v4 空 `prompt_text` 会返回 400。 +- 临时上传目录会被清理。 diff --git a/docs/simple_api_quickstart.md b/docs/simple_api_quickstart.md new file mode 100644 index 00000000..2f07eb65 --- /dev/null +++ b/docs/simple_api_quickstart.md @@ -0,0 +1,89 @@ +# 简化 TTS 接口快速启动 + +完整教程见:`docs/simple_api.md` + +## 一句话流程 + +启动后端,打开测试页,上传 3-10 秒参考音频或视频(视频会自动提取音频),填写生成文字,点击生成。 + +## 1. 启动后端 + +```powershell +cd D:\tts\GPT-SoVITS +python -m pip install -r requirements.txt +.\go-simple-api.ps1 +``` + +也可以运行: + +```bat +go-simple-api.bat +``` + +默认后端地址: + +```text +http://127.0.0.1:9881 +``` + +## 2. 打开测试前端 + +后端启动后访问: + +```text +http://127.0.0.1:9881/test/ +``` + +测试页面里的默认接口地址是: + +```text +http://127.0.0.1:9881/api/tts +``` + +## 3. 调用接口 + +接口: + +```http +POST /api/tts +Content-Type: multipart/form-data +``` + +最小字段: + +```text +text 要生成的文字 +ref_audio 3-10 秒主参考音频(支持视频,前端自动提取音频) +``` + +常用可选字段: + +```text +aux_ref_audio 辅助参考音频,可多个 +prompt_text 参考音频文本,可留空 +text_lang 默认 zh +prompt_lang 默认 zh +emotion neutral / happy / calm / sad / angry +speed 语速,默认 1 +seed 默认 -1 +``` + +## 4. PowerShell 示例 + +```powershell +curl.exe -X POST http://127.0.0.1:9881/api/tts ` + -F "text=你好,欢迎使用这个声音。" ` + -F "ref_audio=@D:\audio\ref.wav" ` + -F "prompt_text=" ` + -F "text_lang=zh" ` + -F "prompt_lang=zh" ` + --output output.wav +``` + +## 5. 注意事项 + +- 主参考音频必须是 3-10 秒。 +- `prompt_text` 在当前 v2 配置下可以为空。 +- 如果切换到 v3/v4,空 `prompt_text` 会返回 400。 +- 生成文字固定按标点符号切句。 +- 更详细的配置、profile、base64 接口见 `docs/simple_api.md`。 diff --git a/go-simple-api.bat b/go-simple-api.bat new file mode 100644 index 00000000..2c2c7d62 --- /dev/null +++ b/go-simple-api.bat @@ -0,0 +1,12 @@ +@echo off +set "SCRIPT_DIR=%~dp0" +set "SCRIPT_DIR=%SCRIPT_DIR:~0,-1%" +cd /d "%SCRIPT_DIR%" + +if exist "%SCRIPT_DIR%\runtime\python.exe" ( + "%SCRIPT_DIR%\runtime\python.exe" "%SCRIPT_DIR%\simple_api.py" -c "%SCRIPT_DIR%\simple_api.yaml" %* +) else ( + python "%SCRIPT_DIR%\simple_api.py" -c "%SCRIPT_DIR%\simple_api.yaml" %* +) + +pause diff --git a/go-simple-api.ps1 b/go-simple-api.ps1 new file mode 100644 index 00000000..4d1336c5 --- /dev/null +++ b/go-simple-api.ps1 @@ -0,0 +1,13 @@ +$ErrorActionPreference = "Stop" +chcp 65001 +Set-Location $PSScriptRoot + +$runtimePython = Join-Path $PSScriptRoot "runtime\python.exe" +$apiScript = Join-Path $PSScriptRoot "simple_api.py" +$apiConfig = Join-Path $PSScriptRoot "simple_api.yaml" + +if (Test-Path $runtimePython) { + & $runtimePython $apiScript -c $apiConfig @args +} else { + python $apiScript -c $apiConfig @args +} diff --git a/open-test-frontend.ps1 b/open-test-frontend.ps1 new file mode 100644 index 00000000..7e26d5a6 --- /dev/null +++ b/open-test-frontend.ps1 @@ -0,0 +1,3 @@ +$ErrorActionPreference = "Stop" +Set-Location $PSScriptRoot +Start-Process (Join-Path $PSScriptRoot "test_frontend\index.html") diff --git a/requirements.txt b/requirements.txt index 3b7cd898..caa9a448 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ numpy<2.0 scipy tensorboard librosa==0.10.2 +soundfile numba pytorch-lightning>=2.4 gradio<5 @@ -22,6 +23,7 @@ transformers>=4.43,<=4.50 peft<0.18.0 chardet PyYAML +python-multipart psutil jieba_fast jieba diff --git a/simple_api.py b/simple_api.py new file mode 100644 index 00000000..91180548 --- /dev/null +++ b/simple_api.py @@ -0,0 +1,741 @@ +""" +Small profile-based API layer for GPT-SoVITS. + +Run: + python simple_api.py -c simple_api.yaml + +Then call: + POST http://127.0.0.1:9881/speak + {"text": "hello", "voice": "default"} +""" + +from __future__ import annotations + +import argparse +import base64 +import os +import shutil +import subprocess +import sys +import threading +import traceback +import uuid +import wave +from copy import deepcopy +from io import BytesIO +from pathlib import Path +from typing import Any, Dict, Generator, List, Optional, Tuple, Union + +import numpy as np +import soundfile as sf +import yaml +from fastapi import FastAPI, File, Form, HTTPException, Response, UploadFile +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.staticfiles import StaticFiles +from pydantic import BaseModel, Field + +PROJECT_ROOT = Path(__file__).resolve().parent +os.chdir(PROJECT_ROOT) +sys.path.append(str(PROJECT_ROOT)) +sys.path.append(str(PROJECT_ROOT / "GPT_SoVITS")) + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402 +from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import ( # noqa: E402 + get_method_names as get_cut_method_names, +) + + +DEFAULT_TTS_PARAMS: Dict[str, Any] = { + "text": "", + "text_lang": "zh", + "ref_audio_path": "", + "aux_ref_audio_paths": [], + "prompt_text": "", + "prompt_lang": "zh", + "top_k": 15, + "top_p": 1.0, + "temperature": 1.0, + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "split_bucket": True, + "speed_factor": 1.0, + "fragment_interval": 0.3, + "seed": -1, + "media_type": "wav", + "streaming_mode": False, + "return_fragment": False, + "fixed_length_chunk": False, + "parallel_infer": True, + "repetition_penalty": 1.35, + "sample_steps": 32, + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, +} + +SUPPORTED_MEDIA_TYPES = {"wav", "raw", "ogg", "aac"} +SUPPORTED_UPLOAD_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a", ".aac"} +VOICE_META_KEYS = {"description"} + + +class SpeakRequest(BaseModel): + text: str + voice: Optional[str] = None + text_lang: Optional[str] = None + ref_audio_path: Optional[str] = None + aux_ref_audio_paths: Optional[List[str]] = None + prompt_text: Optional[str] = None + prompt_lang: Optional[str] = None + format: Optional[str] = Field(default=None, description="wav, raw, ogg, or aac") + stream: Optional[bool] = Field(default=None, description="true maps to streaming mode 2") + streaming_mode: Optional[Union[bool, int]] = Field(default=None, description="0, 1, 2, 3, true, or false") + speed: Optional[float] = Field(default=None, description="Alias of speed_factor") + speed_factor: Optional[float] = None + top_k: Optional[int] = None + top_p: Optional[float] = None + temperature: Optional[float] = None + text_split_method: Optional[str] = None + batch_size: Optional[int] = None + batch_threshold: Optional[float] = None + split_bucket: Optional[bool] = None + fragment_interval: Optional[float] = None + seed: Optional[int] = None + parallel_infer: Optional[bool] = None + repetition_penalty: Optional[float] = None + sample_steps: Optional[int] = None + super_sampling: Optional[bool] = None + overlap_length: Optional[int] = None + min_chunk_length: Optional[int] = None + + +class WeightsRequest(BaseModel): + gpt_weights_path: Optional[str] = None + sovits_weights_path: Optional[str] = None + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="GPT-SoVITS simple API") + parser.add_argument("-c", "--config", default="simple_api.yaml", help="simple API config path") + parser.add_argument("--tts-config", default=None, help="GPT-SoVITS tts_infer.yaml path") + parser.add_argument("-a", "--bind-addr", default=None, help="bind address") + parser.add_argument("-p", "--port", type=int, default=None, help="bind port") + return parser.parse_args() + + +def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]: + path = Path(config_path) + if not path.is_absolute(): + path = PROJECT_ROOT / path + if not path.exists(): + raise FileNotFoundError(f"simple API config not found: {path}") + with path.open("r", encoding="utf-8") as f: + data = yaml.safe_load(f) or {} + if not isinstance(data, dict): + raise ValueError("simple API config must be a YAML object") + return data + + +args = parse_args() +simple_config = load_yaml_config(args.config) +server_config = simple_config.get("server", {}) or {} +host = args.bind_addr or server_config.get("host", "127.0.0.1") +port = args.port or int(server_config.get("port", 9881)) +tts_config_path = args.tts_config or server_config.get("tts_config", "GPT_SoVITS/configs/tts_infer.yaml") + +cut_method_names = get_cut_method_names() +tts_config: Optional[TTS_Config] = None +tts_pipeline: Optional[TTS] = None +infer_lock = threading.Lock() + +APP = FastAPI( + title="GPT-SoVITS Simple API", + description="Profile-based API layer that hides GPT-SoVITS request details.", + version="1.0.0", +) +APP.add_middleware( + CORSMiddleware, + allow_origins=simple_config.get("cors_allow_origins", ["*"]), + allow_credentials=False, + allow_methods=["*"], + allow_headers=["*"], +) +test_frontend_dir = PROJECT_ROOT / "test_frontend" +if test_frontend_dir.exists(): + APP.mount("/test", StaticFiles(directory=str(test_frontend_dir), html=True), name="test_frontend") + + +def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO: + def handle_pack_ogg() -> None: + with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + + try: + threading.stack_size(4096 * 4096) + pack_thread = threading.Thread(target=handle_pack_ogg) + pack_thread.start() + pack_thread.join() + except (RuntimeError, ValueError): + handle_pack_ogg() + return io_buffer + + +def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO: + del rate + io_buffer.write(data.tobytes()) + return io_buffer + + +def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO: + del io_buffer + wav_buffer = BytesIO() + sf.write(wav_buffer, data, rate, format="wav") + return wav_buffer + + +def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO: + process = subprocess.Popen( + [ + "ffmpeg", + "-f", + "s16le", + "-ar", + str(rate), + "-ac", + "1", + "-i", + "pipe:0", + "-c:a", + "aac", + "-b:a", + "192k", + "-vn", + "-f", + "adts", + "pipe:1", + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, _ = process.communicate(input=data.tobytes()) + io_buffer.write(out) + return io_buffer + + +def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str) -> BytesIO: + if media_type == "ogg": + io_buffer = pack_ogg(io_buffer, data, rate) + elif media_type == "aac": + io_buffer = pack_aac(io_buffer, data, rate) + elif media_type == "wav": + io_buffer = pack_wav(io_buffer, data, rate) + else: + io_buffer = pack_raw(io_buffer, data, rate) + io_buffer.seek(0) + return io_buffer + + +def wave_header_chunk(frame_input: bytes = b"", channels: int = 1, sample_width: int = 2, sample_rate: int = 32000) -> bytes: + wav_buf = BytesIO() + with wave.open(wav_buf, "wb") as vfout: + vfout.setnchannels(channels) + vfout.setsampwidth(sample_width) + vfout.setframerate(sample_rate) + vfout.writeframes(frame_input) + wav_buf.seek(0) + return wav_buf.read() + + +def request_to_dict(request: BaseModel) -> Dict[str, Any]: + if hasattr(request, "model_dump"): + return request.model_dump(exclude_none=True) + return request.dict(exclude_none=True) + + +def get_default_voice_name() -> Optional[str]: + default_voice = simple_config.get("default_voice") + if default_voice: + return str(default_voice) + voices = simple_config.get("voices", {}) or {} + if "default" in voices: + return "default" + return next(iter(voices.keys()), None) + + +def resolve_project_path(path_value: Optional[str]) -> Optional[str]: + if path_value in [None, ""]: + return None + path = Path(str(path_value)) + if not path.is_absolute(): + path = PROJECT_ROOT / path + return str(path) + + +def normalize_streaming(streaming_mode: Optional[Union[bool, int]], stream: Optional[bool]) -> Tuple[bool, bool, bool, bool]: + if streaming_mode is None: + streaming_mode = 2 if stream else 0 + elif isinstance(streaming_mode, bool): + streaming_mode = 2 if streaming_mode else 0 + + if streaming_mode == 0: + return False, False, False, False + if streaming_mode == 1: + return False, True, False, True + if streaming_mode == 2: + return True, False, False, True + if streaming_mode == 3: + return True, False, True, True + raise HTTPException(status_code=400, detail="streaming_mode must be 0, 1, 2, 3, true, or false") + + +def get_upload_config() -> Dict[str, Any]: + return simple_config.get("upload", {}) or {} + + +def get_upload_root() -> Path: + upload_dir = str(get_upload_config().get("dir", "runtime/uploads")) + root = Path(upload_dir) + if not root.is_absolute(): + root = PROJECT_ROOT / root + root.mkdir(parents=True, exist_ok=True) + return root + + +def get_request_upload_dir() -> Path: + request_dir = get_upload_root() / uuid.uuid4().hex + request_dir.mkdir(parents=True, exist_ok=True) + return request_dir + + +def get_upload_limits() -> Tuple[float, float, int]: + upload_config = get_upload_config() + min_seconds = float(upload_config.get("min_ref_seconds", 3.0)) + max_seconds = float(upload_config.get("max_ref_seconds", 10.0)) + max_upload_mb = int(upload_config.get("max_upload_mb", 80)) + return min_seconds, max_seconds, max_upload_mb + + +def get_upload_suffix(upload: UploadFile) -> str: + suffix = Path(upload.filename or "").suffix.lower() + if suffix: + return suffix + content_type = (upload.content_type or "").lower() + if content_type in {"audio/wav", "audio/x-wav", "audio/wave"}: + return ".wav" + if content_type == "audio/mpeg": + return ".mp3" + if content_type == "audio/ogg": + return ".ogg" + if content_type in {"audio/aac", "audio/aacp"}: + return ".aac" + return ".wav" + + +def validate_audio_upload(upload: UploadFile) -> None: + suffix = get_upload_suffix(upload) + content_type = (upload.content_type or "").lower() + is_audio_type = content_type.startswith("audio/") or content_type in {"application/octet-stream", ""} + if suffix not in SUPPORTED_UPLOAD_EXTENSIONS or not is_audio_type: + raise HTTPException(status_code=400, detail=f"unsupported audio upload: {upload.filename or content_type}") + + +async def save_upload_file(upload: UploadFile, target_dir: Path, name_prefix: str) -> str: + validate_audio_upload(upload) + _, _, max_upload_mb = get_upload_limits() + max_bytes = max_upload_mb * 1024 * 1024 + suffix = get_upload_suffix(upload) + target_path = target_dir / f"{name_prefix}{suffix}" + total_size = 0 + + try: + with target_path.open("wb") as f: + while True: + chunk = await upload.read(1024 * 1024) + if not chunk: + break + total_size += len(chunk) + if total_size > max_bytes: + raise HTTPException(status_code=400, detail=f"audio upload exceeds {max_upload_mb} MB") + f.write(chunk) + if total_size == 0: + raise HTTPException(status_code=400, detail=f"audio upload is empty: {upload.filename or name_prefix}") + finally: + await upload.close() + + return str(target_path) + + +def get_audio_duration_with_ffprobe(audio_path: str) -> float: + result = subprocess.run( + [ + "ffprobe", + "-v", + "error", + "-show_entries", + "format=duration", + "-of", + "default=noprint_wrappers=1:nokey=1", + audio_path, + ], + capture_output=True, + text=True, + timeout=15, + ) + if result.returncode != 0: + raise RuntimeError(result.stderr.strip() or "ffprobe failed") + return float(result.stdout.strip()) + + +def get_audio_duration_seconds(audio_path: str) -> float: + try: + info = sf.info(audio_path) + if info.samplerate <= 0: + raise ValueError("invalid sample rate") + return float(info.frames) / float(info.samplerate) + except Exception as sf_exc: + try: + return get_audio_duration_with_ffprobe(audio_path) + except Exception as ffprobe_exc: + raise HTTPException( + status_code=400, + detail=f"unable to read audio duration: {audio_path}. soundfile={sf_exc}; ffprobe={ffprobe_exc}", + ) from ffprobe_exc + + +def validate_main_ref_duration(audio_path: str) -> None: + min_seconds, max_seconds, _ = get_upload_limits() + duration = get_audio_duration_seconds(audio_path) + if duration < min_seconds or duration > max_seconds: + raise HTTPException( + status_code=400, + detail=f"ref_audio duration must be between {min_seconds:g}s and {max_seconds:g}s, got {duration:.2f}s", + ) + + +def apply_emotion_preset(payload: Dict[str, Any], emotion: Optional[str]) -> None: + if emotion in [None, ""]: + return + emotion_name = str(emotion).strip().lower() + emotion_presets = simple_config.get("emotion_presets", {}) or {} + if emotion_name not in emotion_presets: + supported = ", ".join(sorted(emotion_presets.keys())) or "none" + raise HTTPException(status_code=400, detail=f"emotion is not configured: {emotion_name}. supported: {supported}") + payload.update(emotion_presets[emotion_name] or {}) + + +def prompt_text_required_for_current_model() -> bool: + if tts_config is None: + return False + use_vocoder = getattr(tts_config, "use_vocoder", None) + if use_vocoder is not None: + return bool(use_vocoder) + return getattr(tts_config, "version", None) in {"v3", "v4"} + + +def voice_public_info(name: str, profile: Dict[str, Any]) -> Dict[str, Any]: + ref_audio_path = resolve_project_path(profile.get("ref_audio_path")) + return { + "name": name, + "description": profile.get("description", ""), + "text_lang": profile.get("text_lang"), + "prompt_lang": profile.get("prompt_lang"), + "ref_audio_path": profile.get("ref_audio_path"), + "ready": bool(ref_audio_path and Path(ref_audio_path).exists() and profile.get("prompt_lang")), + } + + +def build_tts_request(payload: Dict[str, Any]) -> Tuple[Dict[str, Any], str, bool]: + voices = simple_config.get("voices", {}) or {} + explicit_ref_audio = bool(payload.get("ref_audio_path")) + voice_name = payload.pop("voice", None) + if voice_name is None and not explicit_ref_audio: + voice_name = get_default_voice_name() + voice_profile: Dict[str, Any] = {} + + if voice_name: + if voice_name not in voices: + raise HTTPException(status_code=404, detail=f"voice not found: {voice_name}") + voice_profile = deepcopy(voices[voice_name] or {}) + + tts_req = deepcopy(DEFAULT_TTS_PARAMS) + tts_req.update(simple_config.get("defaults", {}) or {}) + tts_req.update({k: v for k, v in voice_profile.items() if k not in VOICE_META_KEYS}) + + stream = payload.pop("stream", None) + media_type = payload.pop("format", None) + speed = payload.pop("speed", None) + if media_type is not None: + tts_req["media_type"] = media_type + + for key, value in payload.items(): + if key in DEFAULT_TTS_PARAMS: + tts_req[key] = value + if speed is not None: + tts_req["speed_factor"] = speed + + ref_audio_path = resolve_project_path(tts_req.get("ref_audio_path")) + if ref_audio_path: + tts_req["ref_audio_path"] = ref_audio_path + + aux_paths = tts_req.get("aux_ref_audio_paths") or [] + tts_req["aux_ref_audio_paths"] = [resolve_project_path(item) for item in aux_paths if item] + + media_type = str(tts_req.get("media_type", "wav")).lower() + tts_req["media_type"] = media_type + if media_type not in SUPPORTED_MEDIA_TYPES: + raise HTTPException(status_code=400, detail=f"format is not supported: {media_type}") + + text = str(tts_req.get("text") or "").strip() + if not text: + raise HTTPException(status_code=400, detail="text is required") + tts_req["text"] = text + + if not tts_req.get("ref_audio_path"): + raise HTTPException(status_code=400, detail="ref_audio_path is required in voice profile or request") + if not Path(tts_req["ref_audio_path"]).exists(): + raise HTTPException(status_code=400, detail=f"ref_audio_path does not exist: {tts_req['ref_audio_path']}") + + for aux_path in tts_req["aux_ref_audio_paths"]: + if aux_path and not Path(aux_path).exists(): + raise HTTPException(status_code=400, detail=f"aux_ref_audio_path does not exist: {aux_path}") + + if tts_config is None: + raise HTTPException(status_code=503, detail="TTS pipeline is not ready") + + text_lang = str(tts_req.get("text_lang") or "").lower() + prompt_lang = str(tts_req.get("prompt_lang") or "").lower() + tts_req["text_lang"] = text_lang + tts_req["prompt_lang"] = prompt_lang + + if text_lang not in tts_config.languages: + raise HTTPException(status_code=400, detail=f"text_lang is not supported: {text_lang}") + if prompt_lang not in tts_config.languages: + raise HTTPException(status_code=400, detail=f"prompt_lang is not supported: {prompt_lang}") + if not str(tts_req.get("prompt_text") or "").strip() and prompt_text_required_for_current_model(): + version = getattr(tts_config, "version", "current") + raise HTTPException(status_code=400, detail=f"prompt_text is required when using GPT-SoVITS {version}") + + split_method = str(tts_req.get("text_split_method") or "cut5") + if split_method not in cut_method_names: + raise HTTPException(status_code=400, detail=f"text_split_method is not supported: {split_method}") + + streaming_mode = tts_req.pop("streaming_mode", None) + streaming_enabled, return_fragment, fixed_length_chunk, response_stream = normalize_streaming(streaming_mode, stream) + tts_req["streaming_mode"] = streaming_enabled + tts_req["return_fragment"] = return_fragment + tts_req["fixed_length_chunk"] = fixed_length_chunk + + return tts_req, media_type, response_stream + + +def synthesize_once(payload: Dict[str, Any]) -> Tuple[bytes, str]: + tts_req, media_type, response_stream = build_tts_request(payload) + if response_stream: + raise HTTPException(status_code=400, detail="base64 output does not support streaming") + + if tts_pipeline is None: + raise HTTPException(status_code=503, detail="TTS pipeline is not ready") + + try: + with infer_lock: + sr, audio_data = next(tts_pipeline.run(tts_req)) + audio_bytes = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() + return audio_bytes, media_type + except Exception as exc: + raise HTTPException(status_code=400, detail={"message": "tts failed", "exception": str(exc)}) from exc + + +def synthesize_response(payload: Dict[str, Any]) -> Response: + tts_req, media_type, response_stream = build_tts_request(payload) + if tts_pipeline is None: + raise HTTPException(status_code=503, detail="TTS pipeline is not ready") + + if response_stream: + + def streaming_generator() -> Generator[bytes, None, None]: + first_chunk = True + chunk_media_type = media_type + with infer_lock: + for sr, chunk in tts_pipeline.run(tts_req): + if first_chunk and chunk_media_type == "wav": + yield wave_header_chunk(sample_rate=sr) + chunk_media_type = "raw" + first_chunk = False + yield pack_audio(BytesIO(), chunk, sr, chunk_media_type).getvalue() + + return StreamingResponse(streaming_generator(), media_type=f"audio/{media_type}") + + try: + with infer_lock: + sr, audio_data = next(tts_pipeline.run(tts_req)) + audio_bytes = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() + return Response(audio_bytes, media_type=f"audio/{media_type}") + except Exception as exc: + return JSONResponse(status_code=400, content={"message": "tts failed", "exception": str(exc)}) + + +@APP.on_event("startup") +def startup() -> None: + global tts_config, tts_pipeline + tts_config = TTS_Config(tts_config_path) + print(tts_config) + tts_pipeline = TTS(tts_config) + + +@APP.get("/") +def index() -> Dict[str, Any]: + return { + "name": "GPT-SoVITS Simple API", + "endpoints": [ + "/health", + "/voices", + "/api/tts", + "/speak", + "/speak/base64", + "/admin/weights", + "/admin/reload-config", + ], + } + + +@APP.get("/health") +def health() -> Dict[str, Any]: + return { + "status": "ok" if tts_pipeline is not None else "starting", + "tts_config": tts_config_path, + "version": getattr(tts_config, "version", None), + "languages": getattr(tts_config, "languages", []), + } + + +@APP.get("/voices") +def list_voices() -> Dict[str, Any]: + voices = simple_config.get("voices", {}) or {} + return { + "default_voice": get_default_voice_name(), + "voices": [voice_public_info(name, profile or {}) for name, profile in voices.items()], + } + + +@APP.post("/api/tts") +async def mvp_tts( + text: str = Form(...), + ref_audio: UploadFile = File(...), + aux_ref_audio: Optional[List[UploadFile]] = File(default=None), + prompt_text: str = Form(default=""), + text_lang: str = Form(default="zh"), + prompt_lang: str = Form(default="zh"), + format: str = Form(default="wav"), + emotion: Optional[str] = Form(default=None), + speed: Optional[float] = Form(default=None), + seed: int = Form(default=-1), +) -> Response: + request_dir = get_request_upload_dir() + try: + ref_audio_path = await save_upload_file(ref_audio, request_dir, "ref") + validate_main_ref_duration(ref_audio_path) + + aux_ref_audio_paths = [] + for index, upload in enumerate(aux_ref_audio or []): + aux_path = await save_upload_file(upload, request_dir, f"aux_{index}") + get_audio_duration_seconds(aux_path) + aux_ref_audio_paths.append(aux_path) + + payload: Dict[str, Any] = { + "text": text, + "text_lang": text_lang, + "ref_audio_path": ref_audio_path, + "aux_ref_audio_paths": aux_ref_audio_paths, + "prompt_text": prompt_text or "", + "prompt_lang": prompt_lang, + "format": format, + "text_split_method": "cut5", + "streaming_mode": 0, + "seed": seed, + } + apply_emotion_preset(payload, emotion) + payload["text_split_method"] = "cut5" + payload["streaming_mode"] = 0 + if speed is not None: + payload["speed"] = speed + + return synthesize_response(payload) + finally: + shutil.rmtree(request_dir, ignore_errors=True) + + +@APP.get("/speak") +def speak_get( + text: str, + voice: Optional[str] = None, + text_lang: Optional[str] = None, + format: Optional[str] = None, + stream: Optional[bool] = None, + speed: Optional[float] = None, +) -> Response: + payload = { + "text": text, + "voice": voice, + "text_lang": text_lang, + "format": format, + "stream": stream, + "speed": speed, + } + return synthesize_response({k: v for k, v in payload.items() if v is not None}) + + +@APP.post("/speak") +def speak_post(request: SpeakRequest) -> Response: + return synthesize_response(request_to_dict(request)) + + +@APP.post("/v1/tts") +def openai_style_tts(request: SpeakRequest) -> Response: + return synthesize_response(request_to_dict(request)) + + +@APP.post("/speak/base64") +def speak_base64(request: SpeakRequest) -> Dict[str, Any]: + audio_bytes, media_type = synthesize_once(request_to_dict(request)) + return { + "media_type": f"audio/{media_type}", + "audio_base64": base64.b64encode(audio_bytes).decode("ascii"), + } + + +@APP.post("/admin/reload-config") +def reload_config() -> Dict[str, Any]: + global simple_config + simple_config = load_yaml_config(args.config) + return {"message": "success", "default_voice": get_default_voice_name()} + + +@APP.post("/admin/weights") +def set_weights(request: WeightsRequest) -> Dict[str, Any]: + if tts_pipeline is None: + raise HTTPException(status_code=503, detail="TTS pipeline is not ready") + if not request.gpt_weights_path and not request.sovits_weights_path: + raise HTTPException(status_code=400, detail="gpt_weights_path or sovits_weights_path is required") + + try: + with infer_lock: + if request.gpt_weights_path: + tts_pipeline.init_t2s_weights(request.gpt_weights_path) + if request.sovits_weights_path: + tts_pipeline.init_vits_weights(request.sovits_weights_path) + except Exception as exc: + raise HTTPException(status_code=400, detail={"message": "change weights failed", "exception": str(exc)}) from exc + + return {"message": "success"} + + +if __name__ == "__main__": + import uvicorn + + try: + uvicorn.run(app=APP, host=None if host == "None" else host, port=port, workers=1) + except Exception: + traceback.print_exc() + sys.exit(1) diff --git a/simple_api.yaml b/simple_api.yaml new file mode 100644 index 00000000..4c02394f --- /dev/null +++ b/simple_api.yaml @@ -0,0 +1,59 @@ +server: + host: 127.0.0.1 + port: 9881 + tts_config: GPT_SoVITS/configs/tts_infer.yaml + +cors_allow_origins: + - "*" + +upload: + dir: runtime/uploads + min_ref_seconds: 3 + max_ref_seconds: 10 + max_upload_mb: 80 + +default_voice: default + +defaults: + text_lang: zh + prompt_lang: zh + media_type: wav + text_split_method: cut5 + batch_size: 1 + batch_threshold: 0.75 + split_bucket: true + speed_factor: 1.0 + fragment_interval: 0.3 + seed: -1 + parallel_infer: true + repetition_penalty: 1.35 + sample_steps: 32 + super_sampling: false + overlap_length: 2 + min_chunk_length: 16 + +emotion_presets: + neutral: {} + happy: + temperature: 1.1 + top_p: 0.95 + calm: + temperature: 0.8 + top_p: 0.85 + speed_factor: 0.92 + sad: + temperature: 0.75 + top_p: 0.85 + speed_factor: 0.9 + angry: + temperature: 1.2 + top_k: 20 + repetition_penalty: 1.25 + +voices: + default: + description: Replace this profile with your reference voice. + ref_audio_path: reference.wav + prompt_text: Replace this with the exact text spoken in reference.wav. + prompt_lang: zh + text_lang: zh diff --git a/test_frontend/index.html b/test_frontend/index.html new file mode 100644 index 00000000..c7e1b5ba --- /dev/null +++ b/test_frontend/index.html @@ -0,0 +1,713 @@ + + + + + + GPT-SoVITS API Test + + + +
+
+
+

GPT-SoVITS 接口测试台

+

选择 3-10 秒参考音频或视频(视频会自动提取音频),填写后端接口地址和生成文本,直接调用中间层 /api/tts

+
+
+ 未检测 + 后端连接状态 +
+
+ +
+
+
+
+
+ + +

如果后端端口或主机变了,在这里改完整地址。页面会把表单直接 POST 到这个地址。

+
+ +
+ + +
+ +
+ + +
请选择 3-10 秒音频或视频(视频会自动提取音频)
+ +
+ +
+ + +
可选,可多选
+
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +

显式语速会覆盖情绪 preset 中的语速。

+
+ +
+ + +
+ +
+ + +
+
+ +
+ + + +
+
+
+ +
+

返回结果

+
+
耗时-
+
文件大小-
+
+
等待请求。
+ + +
+
+
+ + + + diff --git a/tests/test_simple_api_contract.py b/tests/test_simple_api_contract.py new file mode 100644 index 00000000..4cfa3a3e --- /dev/null +++ b/tests/test_simple_api_contract.py @@ -0,0 +1,303 @@ +import asyncio +import importlib.util +import sys +import tempfile +import types +import unittest +from pathlib import Path + + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +SIMPLE_API_PATH = PROJECT_ROOT / "simple_api.py" + + +class DummyHTTPException(Exception): + def __init__(self, status_code=500, detail=None): + super().__init__(detail) + self.status_code = status_code + self.detail = detail + + +class DummyFastAPI: + def __init__(self, *args, **kwargs): + self.routes = [] + + def add_middleware(self, *args, **kwargs): + pass + + def mount(self, path, app, name=None): + self.routes.append(types.SimpleNamespace(path=path, endpoint=app, name=name)) + + def get(self, path): + return self._route(path) + + def post(self, path): + return self._route(path) + + def on_event(self, event): + return lambda func: func + + def _route(self, path): + def decorator(func): + self.routes.append(types.SimpleNamespace(path=path, endpoint=func)) + return func + + return decorator + + +class DummyBaseModel: + def dict(self, exclude_none=False): + data = dict(self.__dict__) + if exclude_none: + data = {key: value for key, value in data.items() if value is not None} + return data + + +class DummyTTSConfig: + version = "v2" + languages = ["zh", "en", "auto"] + use_vocoder = False + + +class FakeUpload: + def __init__(self, filename, chunks, content_type="audio/wav"): + self.filename = filename + self.content_type = content_type + self._chunks = list(chunks) + self.closed = False + + async def read(self, size): + del size + if self._chunks: + return self._chunks.pop(0) + return b"" + + async def close(self): + self.closed = True + + +def install_dummy_modules(): + fastapi = types.ModuleType("fastapi") + fastapi.FastAPI = DummyFastAPI + fastapi.File = lambda default=None, **kwargs: default + fastapi.Form = lambda default=None, **kwargs: default + fastapi.HTTPException = DummyHTTPException + fastapi.Response = type("Response", (), {}) + fastapi.UploadFile = type("UploadFile", (), {}) + sys.modules["fastapi"] = fastapi + + middleware = types.ModuleType("fastapi.middleware") + cors = types.ModuleType("fastapi.middleware.cors") + cors.CORSMiddleware = type("CORSMiddleware", (), {}) + sys.modules["fastapi.middleware"] = middleware + sys.modules["fastapi.middleware.cors"] = cors + + responses = types.ModuleType("fastapi.responses") + responses.JSONResponse = type("JSONResponse", (), {}) + responses.StreamingResponse = type("StreamingResponse", (), {}) + sys.modules["fastapi.responses"] = responses + + staticfiles = types.ModuleType("fastapi.staticfiles") + staticfiles.StaticFiles = type("StaticFiles", (), {"__init__": lambda self, *args, **kwargs: None}) + sys.modules["fastapi.staticfiles"] = staticfiles + + pydantic = types.ModuleType("pydantic") + pydantic.BaseModel = DummyBaseModel + pydantic.Field = lambda default=None, **kwargs: default + sys.modules["pydantic"] = pydantic + + numpy = types.ModuleType("numpy") + numpy.ndarray = object + sys.modules["numpy"] = numpy + + soundfile = types.ModuleType("soundfile") + soundfile.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000) + sys.modules["soundfile"] = soundfile + + yaml = types.ModuleType("yaml") + yaml.safe_load = lambda stream: { + "server": {"host": "127.0.0.1", "port": 9881, "tts_config": "dummy.yaml"}, + "upload": {"dir": "runtime/test_uploads", "min_ref_seconds": 3, "max_ref_seconds": 10, "max_upload_mb": 1}, + "defaults": {"text_lang": "zh", "prompt_lang": "zh", "media_type": "wav", "text_split_method": "cut5"}, + "emotion_presets": {"calm": {"temperature": 0.8, "speed_factor": 0.9}}, + "voices": {}, + } + sys.modules["yaml"] = yaml + + sys.modules["GPT_SoVITS"] = types.ModuleType("GPT_SoVITS") + sys.modules["GPT_SoVITS.TTS_infer_pack"] = types.ModuleType("GPT_SoVITS.TTS_infer_pack") + + tts_module = types.ModuleType("GPT_SoVITS.TTS_infer_pack.TTS") + tts_module.TTS = type("DummyTTS", (), {}) + tts_module.TTS_Config = DummyTTSConfig + sys.modules["GPT_SoVITS.TTS_infer_pack.TTS"] = tts_module + + segmentation = types.ModuleType("GPT_SoVITS.TTS_infer_pack.text_segmentation_method") + segmentation.get_method_names = lambda: ["cut0", "cut5"] + sys.modules["GPT_SoVITS.TTS_infer_pack.text_segmentation_method"] = segmentation + + +def load_simple_api(): + install_dummy_modules() + sys.argv = ["simple_api.py"] + spec = importlib.util.spec_from_file_location("simple_api_under_test", SIMPLE_API_PATH) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + module.tts_config = DummyTTSConfig() + return module + + +class SimpleApiContractTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.simple_api = load_simple_api() + + def test_api_tts_route_is_registered(self): + routes = {route.path for route in self.simple_api.APP.routes} + self.assertIn("/api/tts", routes) + + def test_build_request_accepts_explicit_ref_without_voice_profile(self): + with tempfile.NamedTemporaryFile(suffix=".wav") as ref: + req, media_type, response_stream = self.simple_api.build_tts_request( + { + "text": "Hello, test.", + "ref_audio_path": ref.name, + "prompt_text": "", + "prompt_lang": "zh", + "text_lang": "zh", + "format": "wav", + "text_split_method": "cut5", + "streaming_mode": 0, + } + ) + + self.assertEqual(req["prompt_text"], "") + self.assertEqual(req["text_split_method"], "cut5") + self.assertEqual(req["ref_audio_path"], ref.name) + self.assertEqual(media_type, "wav") + self.assertFalse(response_stream) + + def test_explicit_speed_overrides_emotion_speed_preset(self): + with tempfile.NamedTemporaryFile(suffix=".wav") as ref: + payload = { + "text": "Hello.", + "ref_audio_path": ref.name, + "prompt_lang": "zh", + "text_lang": "zh", + "format": "wav", + "streaming_mode": 0, + "speed": 1.05, + } + self.simple_api.apply_emotion_preset(payload, "calm") + payload["speed"] = 1.05 + req, _, _ = self.simple_api.build_tts_request(payload) + + self.assertEqual(req["temperature"], 0.8) + self.assertEqual(req["speed_factor"], 1.05) + + def test_empty_prompt_text_is_rejected_for_vocoder_models(self): + original_use_vocoder = self.simple_api.tts_config.use_vocoder + original_version = self.simple_api.tts_config.version + self.simple_api.tts_config.use_vocoder = True + self.simple_api.tts_config.version = "v3" + + try: + with tempfile.NamedTemporaryFile(suffix=".wav") as ref: + with self.assertRaises(DummyHTTPException) as exc: + self.simple_api.build_tts_request( + { + "text": "Hello.", + "ref_audio_path": ref.name, + "prompt_text": "", + "prompt_lang": "zh", + "text_lang": "zh", + "format": "wav", + "text_split_method": "cut5", + "streaming_mode": 0, + } + ) + self.assertEqual(exc.exception.status_code, 400) + self.assertIn("prompt_text is required", str(exc.exception.detail)) + finally: + self.simple_api.tts_config.use_vocoder = original_use_vocoder + self.simple_api.tts_config.version = original_version + + def test_ref_audio_duration_limits(self): + self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000) + self.simple_api.validate_main_ref_duration("ok.wav") + + self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 2, samplerate=16000) + with self.assertRaises(DummyHTTPException) as too_short: + self.simple_api.validate_main_ref_duration("short.wav") + self.assertEqual(too_short.exception.status_code, 400) + + self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 11, samplerate=16000) + with self.assertRaises(DummyHTTPException) as too_long: + self.simple_api.validate_main_ref_duration("long.wav") + self.assertEqual(too_long.exception.status_code, 400) + + def test_save_upload_file_writes_and_closes(self): + with tempfile.TemporaryDirectory() as tmp: + upload = FakeUpload("ref.wav", [b"abc", b"def"]) + saved_path = asyncio.run(self.simple_api.save_upload_file(upload, Path(tmp), "ref")) + + self.assertTrue(upload.closed) + self.assertEqual(Path(saved_path).read_bytes(), b"abcdef") + + def test_mvp_tts_builds_payload_from_uploads_and_cleans_tmp_dir(self): + captured = {} + + with tempfile.TemporaryDirectory() as tmp: + request_dir = Path(tmp) / "request" + request_dir.mkdir() + + original_get_request_upload_dir = self.simple_api.get_request_upload_dir + original_synthesize_response = self.simple_api.synthesize_response + original_sf_info = self.simple_api.sf.info + + def fake_synthesize_response(payload): + captured.update(payload) + return "audio-response" + + try: + self.simple_api.get_request_upload_dir = lambda: request_dir + self.simple_api.sf.info = lambda path: types.SimpleNamespace(frames=16000 * 5, samplerate=16000) + self.simple_api.synthesize_response = fake_synthesize_response + + response = asyncio.run( + self.simple_api.mvp_tts( + text="Hello, test.", + ref_audio=FakeUpload("ref.wav", [b"main"]), + aux_ref_audio=[FakeUpload("aux.wav", [b"aux"])], + prompt_text="", + text_lang="zh", + prompt_lang="zh", + format="wav", + emotion="calm", + speed=1.05, + seed=123, + ) + ) + finally: + self.simple_api.get_request_upload_dir = original_get_request_upload_dir + self.simple_api.synthesize_response = original_synthesize_response + self.simple_api.sf.info = original_sf_info + + self.assertEqual(response, "audio-response") + self.assertFalse(request_dir.exists()) + + self.assertEqual(captured["text"], "Hello, test.") + self.assertEqual(captured["prompt_text"], "") + self.assertEqual(captured["text_lang"], "zh") + self.assertEqual(captured["prompt_lang"], "zh") + self.assertEqual(captured["format"], "wav") + self.assertEqual(captured["text_split_method"], "cut5") + self.assertEqual(captured["streaming_mode"], 0) + self.assertEqual(captured["seed"], 123) + self.assertEqual(captured["temperature"], 0.8) + self.assertEqual(captured["speed"], 1.05) + self.assertEqual(len(captured["aux_ref_audio_paths"]), 1) + + +if __name__ == "__main__": + unittest.main()