mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-07-03 20:48:14 +08:00
- Add simple_api.py: profile-based API that wraps GPT-SoVITS TTS engine - Add /api/tts endpoint for MVP: accepts ref audio/video, text, optional aux audio - Frontend auto-extracts audio from uploaded video files via Web Audio API - Add emotion presets (neutral/happy/calm/sad/angry) with speed customization - Add test_frontend/index.html with health check, audio playback, and download - Add contract tests (7 tests, all passing) using mock TTS pipeline - Add documentation: simple_api.md (full tutorial), simple_api_quickstart.md - Add startup scripts: go-simple-api.ps1, go-simple-api.bat, open-test-frontend.ps1 - Add soundfile and python-multipart to requirements.txt - Text splitting fixed to cut5 (punctuation-based) per MVP spec
742 lines
25 KiB
Python
742 lines
25 KiB
Python
"""
|
|
Small profile-based API layer for GPT-SoVITS.
|
|
|
|
Run:
|
|
python simple_api.py -c simple_api.yaml
|
|
|
|
Then call:
|
|
POST http://127.0.0.1:9881/speak
|
|
{"text": "hello", "voice": "default"}
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import base64
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import threading
|
|
import traceback
|
|
import uuid
|
|
import wave
|
|
from copy import deepcopy
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
import soundfile as sf
|
|
import yaml
|
|
from fastapi import FastAPI, File, Form, HTTPException, Response, UploadFile
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel, Field
|
|
|
|
PROJECT_ROOT = Path(__file__).resolve().parent
|
|
os.chdir(PROJECT_ROOT)
|
|
sys.path.append(str(PROJECT_ROOT))
|
|
sys.path.append(str(PROJECT_ROOT / "GPT_SoVITS"))
|
|
|
|
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402
|
|
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import ( # noqa: E402
|
|
get_method_names as get_cut_method_names,
|
|
)
|
|
|
|
|
|
# Baseline TTS request parameters. build_tts_request() deep-copies this dict
# and then layers the config "defaults" section, the selected voice profile and
# finally the request fields on top of it. Only request keys that also appear
# here are copied into the TTS request.
DEFAULT_TTS_PARAMS: Dict[str, Any] = {
    "text": "",
    "text_lang": "zh",
    "ref_audio_path": "",
    "aux_ref_audio_paths": [],
    "prompt_text": "",
    "prompt_lang": "zh",
    "top_k": 15,
    "top_p": 1.0,
    "temperature": 1.0,
    "text_split_method": "cut5",
    "batch_size": 1,
    "batch_threshold": 0.75,
    "split_bucket": True,
    "speed_factor": 1.0,
    "fragment_interval": 0.3,
    "seed": -1,
    "media_type": "wav",
    # The three flags below are overwritten by build_tts_request() with the
    # values derived by normalize_streaming().
    "streaming_mode": False,
    "return_fragment": False,
    "fixed_length_chunk": False,
    "parallel_infer": True,
    "repetition_penalty": 1.35,
    "sample_steps": 32,
    "super_sampling": False,
    "overlap_length": 2,
    "min_chunk_length": 16,
}

# Output encodings accepted for the "format" / "media_type" parameter.
SUPPORTED_MEDIA_TYPES = {"wav", "raw", "ogg", "aac"}
# File extensions accepted for uploaded reference audio.
SUPPORTED_UPLOAD_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a", ".aac"}
# Voice-profile keys that are descriptive only and are not copied into the TTS request.
VOICE_META_KEYS = {"description"}
|
|
|
|
|
|
class SpeakRequest(BaseModel):
    """JSON body accepted by the /speak, /v1/tts and /speak/base64 endpoints.

    Only ``text`` is mandatory. Every other field is optional; fields left as
    ``None`` are dropped by request_to_dict() so that the configured defaults
    and the selected voice profile apply instead.
    """

    text: str
    # Name of a profile under "voices" in the simple API config.
    voice: Optional[str] = None
    text_lang: Optional[str] = None
    ref_audio_path: Optional[str] = None
    aux_ref_audio_paths: Optional[List[str]] = None
    prompt_text: Optional[str] = None
    prompt_lang: Optional[str] = None
    # Stored in the TTS request as "media_type".
    format: Optional[str] = Field(default=None, description="wav, raw, ogg, or aac")
    stream: Optional[bool] = Field(default=None, description="true maps to streaming mode 2")
    streaming_mode: Optional[Union[bool, int]] = Field(default=None, description="0, 1, 2, 3, true, or false")
    # When both are given, "speed" is applied last and therefore wins over "speed_factor".
    speed: Optional[float] = Field(default=None, description="Alias of speed_factor")
    speed_factor: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    temperature: Optional[float] = None
    text_split_method: Optional[str] = None
    batch_size: Optional[int] = None
    batch_threshold: Optional[float] = None
    split_bucket: Optional[bool] = None
    fragment_interval: Optional[float] = None
    seed: Optional[int] = None
    parallel_infer: Optional[bool] = None
    repetition_penalty: Optional[float] = None
    sample_steps: Optional[int] = None
    super_sampling: Optional[bool] = None
    overlap_length: Optional[int] = None
    min_chunk_length: Optional[int] = None
|
|
|
|
|
|
class WeightsRequest(BaseModel):
    """JSON body of POST /admin/weights; at least one path must be supplied."""

    # Passed to tts_pipeline.init_t2s_weights() when set.
    gpt_weights_path: Optional[str] = None
    # Passed to tts_pipeline.init_vits_weights() when set.
    sovits_weights_path: Optional[str] = None
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the simple API server.

    Returns:
        Namespace with ``config``, ``tts_config``, ``bind_addr`` and ``port``.
        Everything except ``config`` defaults to ``None``.
    """
    cli = argparse.ArgumentParser(description="GPT-SoVITS simple API")
    option_specs = (
        (("-c", "--config"), {"default": "simple_api.yaml", "help": "simple API config path"}),
        (("--tts-config",), {"default": None, "help": "GPT-SoVITS tts_infer.yaml path"}),
        (("-a", "--bind-addr"), {"default": None, "help": "bind address"}),
        (("-p", "--port"), {"type": int, "default": None, "help": "bind port"}),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
|
|
|
|
|
|
def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
    """Load the simple API YAML configuration.

    Args:
        config_path: Absolute path, or a path relative to ``PROJECT_ROOT``.

    Returns:
        The parsed mapping; an empty (or otherwise falsy) document yields ``{}``.

    Raises:
        FileNotFoundError: The resolved file does not exist.
        ValueError: The document's top level is not a mapping.
    """
    candidate = Path(config_path)
    resolved = candidate if candidate.is_absolute() else PROJECT_ROOT / candidate
    if not resolved.exists():
        raise FileNotFoundError(f"simple API config not found: {resolved}")

    with resolved.open("r", encoding="utf-8") as stream:
        parsed = yaml.safe_load(stream)
    if not parsed:
        parsed = {}

    if isinstance(parsed, dict):
        return parsed
    raise ValueError("simple API config must be a YAML object")
|
|
|
|
|
|
# Configuration is resolved once at import time. Command-line options take
# precedence over the "server" section of the YAML config.
args = parse_args()
simple_config = load_yaml_config(args.config)
server_config = simple_config.get("server", {}) or {}
host = args.bind_addr or server_config.get("host", "127.0.0.1")
# Compare against None so an explicit "-p 0" is honoured, and tolerate
# "port: null" in the YAML instead of crashing on int(None).
_configured_port = server_config.get("port")
port = args.port if args.port is not None else int(9881 if _configured_port is None else _configured_port)
tts_config_path = args.tts_config or server_config.get("tts_config", "GPT_SoVITS/configs/tts_infer.yaml")

cut_method_names = get_cut_method_names()
# Both are populated by the startup hook; None means "not ready yet".
tts_config: Optional[TTS_Config] = None
tts_pipeline: Optional[TTS] = None
# Held around every pipeline call and weight swap in this module.
infer_lock = threading.Lock()
|
|
|
|
APP = FastAPI(
    title="GPT-SoVITS Simple API",
    description="Profile-based API layer that hides GPT-SoVITS request details.",
    version="1.0.0",
)

# Normalise the configured origins before handing them to the middleware:
# a missing or null entry means "allow all", and a bare string is wrapped in a
# list so it is not treated as a sequence of characters / matched by substring.
# An explicit empty list is kept as-is (no origin allowed).
_cors_allow_origins = simple_config.get("cors_allow_origins")
if _cors_allow_origins is None:
    _cors_allow_origins = ["*"]
elif isinstance(_cors_allow_origins, str):
    _cors_allow_origins = [_cors_allow_origins]

APP.add_middleware(
    CORSMiddleware,
    allow_origins=list(_cors_allow_origins),
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the bundled test page under /test when the directory is present.
test_frontend_dir = PROJECT_ROOT / "test_frontend"
if test_frontend_dir.exists():
    APP.mount("/test", StaticFiles(directory=str(test_frontend_dir), html=True), name="test_frontend")
|
|
|
|
|
|
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
    """Encode mono samples as Ogg into ``io_buffer`` using soundfile.

    The encoder runs on a helper thread created with an enlarged stack
    (presumably because the Ogg encoder needs a deep stack -- TODO confirm).
    If the stack size cannot be changed or the thread cannot be started, the
    encoding runs on the calling thread instead.

    Args:
        io_buffer: Destination buffer; written in place.
        data: Samples to encode (written as a single channel).
        rate: Sample rate in Hz.

    Returns:
        ``io_buffer``. The caller is responsible for rewinding it.

    Raises:
        Exception: Whatever the encoder raised, including when it ran on the
            helper thread.
    """

    def handle_pack_ogg() -> None:
        with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
            audio_file.write(data)

    worker_errors: List[BaseException] = []

    def guarded_pack() -> None:
        # Exceptions raised inside a thread are otherwise only printed and
        # lost; capture them so the caller sees the failure instead of an
        # empty buffer.
        try:
            handle_pack_ogg()
        except Exception as exc:
            worker_errors.append(exc)

    try:
        previous_stack_size = threading.stack_size()
        threading.stack_size(4096 * 4096)
        try:
            pack_thread = threading.Thread(target=guarded_pack)
            pack_thread.start()
        finally:
            # The stack size is a process-wide setting that applies to threads
            # created afterwards; restore it so unrelated threads are unaffected.
            threading.stack_size(previous_stack_size)
    except (RuntimeError, ValueError):
        handle_pack_ogg()
        return io_buffer

    pack_thread.join()
    if worker_errors:
        raise worker_errors[0]
    return io_buffer
|
|
|
|
|
|
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
    """Append the raw bytes of ``data`` to ``io_buffer``.

    ``rate`` is accepted only so the signature matches the other packers.
    """
    _ = rate  # intentionally unused
    raw_bytes = data.tobytes()
    io_buffer.write(raw_bytes)
    return io_buffer
|
|
|
|
|
|
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
    """Encode ``data`` as a WAV file in a fresh buffer.

    The ``io_buffer`` argument is ignored; a new ``BytesIO`` is always
    returned. The caller is responsible for rewinding it.
    """
    fresh_buffer = BytesIO()
    sf.write(fresh_buffer, data, rate, format="wav")
    return fresh_buffer
|
|
|
|
|
|
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int) -> BytesIO:
    """Encode samples as ADTS AAC by piping them through ``ffmpeg``.

    Args:
        io_buffer: Destination buffer; the encoded stream is appended to it.
        data: Samples whose raw bytes are fed to ffmpeg as mono ``s16le``
            (assumes ``data`` is int16 -- TODO confirm against the pipeline).
        rate: Sample rate in Hz.

    Returns:
        ``io_buffer``. The caller is responsible for rewinding it.

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status.
        OSError: ffmpeg could not be started (for example, it is not installed).
    """
    command = [
        "ffmpeg",
        "-f", "s16le",
        "-ar", str(rate),
        "-ac", "1",
        "-i", "pipe:0",
        "-c:a", "aac",
        "-b:a", "192k",
        "-vn",
        "-f", "adts",
        "pipe:1",
    ]
    completed = subprocess.run(command, input=data.tobytes(), capture_output=True)
    if completed.returncode != 0:
        # Without this check a failed encode would silently produce an empty
        # (or truncated) audio payload.
        stderr_text = completed.stderr.decode("utf-8", errors="replace").strip()
        raise RuntimeError(f"ffmpeg aac encoding failed (exit code {completed.returncode}): {stderr_text}")
    io_buffer.write(completed.stdout)
    return io_buffer
|
|
|
|
|
|
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str) -> BytesIO:
    """Encode ``data`` in the requested container and rewind the result.

    ``"ogg"``, ``"aac"`` and ``"wav"`` select the matching packer; any other
    value falls back to raw bytes.

    Returns:
        The buffer produced by the packer, positioned at offset 0.
    """
    packers = {
        "ogg": pack_ogg,
        "aac": pack_aac,
        "wav": pack_wav,
    }
    packer = packers.get(media_type, pack_raw)
    packed = packer(io_buffer, data, rate)
    packed.seek(0)
    return packed
|
|
|
|
|
|
def wave_header_chunk(frame_input: bytes = b"", channels: int = 1, sample_width: int = 2, sample_rate: int = 32000) -> bytes:
    """Build an in-memory WAV file and return its bytes.

    With the default empty ``frame_input`` the result is just the WAV header,
    which the streaming response sends before the raw audio chunks.
    """
    header_buffer = BytesIO()
    writer = wave.open(header_buffer, "wb")
    try:
        writer.setnchannels(channels)
        writer.setsampwidth(sample_width)
        writer.setframerate(sample_rate)
        writer.writeframes(frame_input)
    finally:
        # Closing finalises the header sizes; the BytesIO itself stays open.
        writer.close()
    return header_buffer.getvalue()
|
|
|
|
|
|
def request_to_dict(request: BaseModel) -> Dict[str, Any]:
    """Convert a request model to a dict, omitting fields that are ``None``.

    Uses ``model_dump`` when the model provides it and falls back to ``dict``
    otherwise, so the helper works with both model APIs.
    """
    dumper = request.model_dump if hasattr(request, "model_dump") else request.dict
    return dumper(exclude_none=True)
|
|
|
|
|
|
def get_default_voice_name() -> Optional[str]:
    """Pick the voice used when a request names neither a voice nor a reference.

    Preference order: the configured ``default_voice``, then a profile literally
    called ``"default"``, then the first configured profile. Returns ``None``
    when no voice is configured at all.
    """
    configured = simple_config.get("default_voice")
    if configured:
        return str(configured)

    profiles = simple_config.get("voices", {}) or {}
    if "default" in profiles:
        return "default"

    for first_name in profiles.keys():
        return first_name
    return None
|
|
|
|
|
|
def resolve_project_path(path_value: Optional[str]) -> Optional[str]:
    """Turn a configured path into a string path anchored at the project root.

    Returns ``None`` for ``None`` or an empty string. Absolute paths are
    returned unchanged (as ``str``); relative ones are joined onto
    ``PROJECT_ROOT``.
    """
    if path_value is None or path_value == "":
        return None
    candidate = Path(str(path_value))
    anchored = candidate if candidate.is_absolute() else PROJECT_ROOT / candidate
    return str(anchored)
|
|
|
|
|
|
def normalize_streaming(streaming_mode: Optional[Union[bool, int]], stream: Optional[bool]) -> Tuple[bool, bool, bool, bool]:
    """Translate the public streaming options into pipeline flags.

    Args:
        streaming_mode: 0-3, a bool (``True`` means 2, ``False`` means 0), or
            ``None`` to derive the mode from ``stream``.
        stream: Only consulted when ``streaming_mode`` is ``None``; truthy
            selects mode 2, otherwise mode 0.

    Returns:
        ``(streaming_mode, return_fragment, fixed_length_chunk, response_stream)``
        in the order build_tts_request() unpacks them.

    Raises:
        HTTPException: 400 when the mode is not one of the accepted values.
    """
    mode = streaming_mode
    if mode is None:
        mode = 2 if stream else 0
    elif isinstance(mode, bool):
        mode = 2 if mode else 0

    flags_by_mode = (
        (0, (False, False, False, False)),
        (1, (False, True, False, True)),
        (2, (True, False, False, True)),
        (3, (True, False, True, True)),
    )
    for known_mode, flags in flags_by_mode:
        if mode == known_mode:
            return flags
    raise HTTPException(status_code=400, detail="streaming_mode must be 0, 1, 2, 3, true, or false")
|
|
|
|
|
|
def get_upload_config() -> Dict[str, Any]:
    """Return the ``upload`` section of the config, or ``{}`` when absent or empty."""
    upload_section = simple_config.get("upload", {})
    return upload_section or {}
|
|
|
|
|
|
def get_upload_root() -> Path:
    """Return the directory that holds uploaded files, creating it if needed.

    The location comes from ``upload.dir`` (default ``runtime/uploads``);
    relative values are resolved against ``PROJECT_ROOT``.
    """
    configured_dir = Path(str(get_upload_config().get("dir", "runtime/uploads")))
    upload_root = configured_dir if configured_dir.is_absolute() else PROJECT_ROOT / configured_dir
    upload_root.mkdir(parents=True, exist_ok=True)
    return upload_root
|
|
|
|
|
|
def get_request_upload_dir() -> Path:
    """Create and return a fresh, uniquely named directory for one request's uploads."""
    unique_name = uuid.uuid4().hex
    per_request_dir = get_upload_root() / unique_name
    per_request_dir.mkdir(parents=True, exist_ok=True)
    return per_request_dir
|
|
|
|
|
|
def get_upload_limits() -> Tuple[float, float, int]:
    """Read the upload limits from the config.

    Returns:
        ``(min_ref_seconds, max_ref_seconds, max_upload_mb)`` with defaults of
        3.0 seconds, 10.0 seconds and 80 MB respectively.
    """
    limits = get_upload_config()
    shortest = float(limits.get("min_ref_seconds", 3.0))
    longest = float(limits.get("max_ref_seconds", 10.0))
    size_cap_mb = int(limits.get("max_upload_mb", 80))
    return (shortest, longest, size_cap_mb)
|
|
|
|
|
|
def get_upload_suffix(upload: UploadFile) -> str:
    """Choose the file extension under which an upload is stored.

    The lower-cased extension of the client-supplied filename wins. When the
    filename has no extension, the extension is derived from the declared
    content type, and ``".wav"`` is used when the type is unknown.

    Args:
        upload: Object exposing ``filename`` and ``content_type``.

    Returns:
        An extension including the leading dot, e.g. ``".mp3"``.
    """
    suffix = Path(upload.filename or "").suffix.lower()
    if suffix:
        return suffix

    # Drop MIME parameters such as "; codecs=opus" before matching.
    content_type = (upload.content_type or "").split(";", 1)[0].strip().lower()
    suffix_by_content_type = {
        "audio/wav": ".wav",
        "audio/x-wav": ".wav",
        "audio/wave": ".wav",
        "audio/mpeg": ".mp3",
        "audio/mp3": ".mp3",
        "audio/ogg": ".ogg",
        "audio/aac": ".aac",
        "audio/aacp": ".aac",
        # The upload whitelist also accepts .flac and .m4a; without these
        # entries such uploads would be mislabelled as .wav.
        "audio/flac": ".flac",
        "audio/x-flac": ".flac",
        "audio/mp4": ".m4a",
        "audio/m4a": ".m4a",
        "audio/x-m4a": ".m4a",
    }
    return suffix_by_content_type.get(content_type, ".wav")
|
|
|
|
|
|
def validate_audio_upload(upload: UploadFile) -> None:
    """Reject uploads that do not look like a supported audio file.

    An upload passes when its extension is in ``SUPPORTED_UPLOAD_EXTENSIONS``
    and its content type is ``audio/*``, ``application/octet-stream`` or empty.

    Raises:
        HTTPException: 400 when either check fails.
    """
    extension = get_upload_suffix(upload)
    declared_type = (upload.content_type or "").lower()

    extension_ok = extension in SUPPORTED_UPLOAD_EXTENSIONS
    type_ok = declared_type.startswith("audio/") or declared_type in {"application/octet-stream", ""}
    if extension_ok and type_ok:
        return
    raise HTTPException(status_code=400, detail=f"unsupported audio upload: {upload.filename or content_type_or(declared_type)}") if False else HTTPException(status_code=400, detail=f"unsupported audio upload: {upload.filename or declared_type}")
|
|
|
|
|
|
async def save_upload_file(upload: UploadFile, target_dir: Path, name_prefix: str) -> str:
    """Validate an upload and stream it to ``target_dir``.

    The file is written in 1 MiB chunks as ``<name_prefix><suffix>`` and the
    upload is always closed afterwards.

    Returns:
        The path of the stored file as a string.

    Raises:
        HTTPException: 400 when the upload is unsupported, exceeds the
            configured ``max_upload_mb``, or is empty.
    """
    validate_audio_upload(upload)
    size_cap_mb = get_upload_limits()[2]
    size_cap_bytes = size_cap_mb * 1024 * 1024
    destination = target_dir / f"{name_prefix}{get_upload_suffix(upload)}"
    chunk_size = 1024 * 1024
    bytes_written = 0

    try:
        with destination.open("wb") as sink:
            chunk = await upload.read(chunk_size)
            while chunk:
                bytes_written += len(chunk)
                if bytes_written > size_cap_bytes:
                    raise HTTPException(status_code=400, detail=f"audio upload exceeds {size_cap_mb} MB")
                sink.write(chunk)
                chunk = await upload.read(chunk_size)
        if bytes_written == 0:
            raise HTTPException(status_code=400, detail=f"audio upload is empty: {upload.filename or name_prefix}")
    finally:
        await upload.close()

    return str(destination)
|
|
|
|
|
|
def get_audio_duration_with_ffprobe(audio_path: str) -> float:
    """Ask ``ffprobe`` for the duration of a media file, in seconds.

    Raises:
        RuntimeError: ffprobe exited with a non-zero status.
        ValueError: ffprobe's output is not a number.
        subprocess.TimeoutExpired: ffprobe ran for more than 15 seconds.
    """
    command = [
        "ffprobe",
        "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        audio_path,
    ]
    completed = subprocess.run(command, capture_output=True, text=True, timeout=15)
    if completed.returncode == 0:
        return float(completed.stdout.strip())
    raise RuntimeError(completed.stderr.strip() or "ffprobe failed")
|
|
|
|
|
|
def get_audio_duration_seconds(audio_path: str) -> float:
    """Return the duration of an audio file in seconds.

    soundfile is tried first; if it fails for any reason (including a
    non-positive sample rate) ffprobe is used as a fallback.

    Raises:
        HTTPException: 400 when neither reader can determine the duration.
    """
    try:
        info = sf.info(audio_path)
        sample_rate = info.samplerate
        if sample_rate <= 0:
            raise ValueError("invalid sample rate")
        return float(info.frames) / float(sample_rate)
    except Exception as sf_exc:
        try:
            duration = get_audio_duration_with_ffprobe(audio_path)
        except Exception as ffprobe_exc:
            message = f"unable to read audio duration: {audio_path}. soundfile={sf_exc}; ffprobe={ffprobe_exc}"
            raise HTTPException(status_code=400, detail=message) from ffprobe_exc
        return duration
|
|
|
|
|
|
def validate_main_ref_duration(audio_path: str) -> None:
    """Ensure the main reference audio lies within the configured length limits.

    Raises:
        HTTPException: 400 when the duration is outside
            ``[min_ref_seconds, max_ref_seconds]`` or cannot be read.
    """
    limits = get_upload_limits()
    min_seconds, max_seconds = limits[0], limits[1]
    duration = get_audio_duration_seconds(audio_path)

    too_short = duration < min_seconds
    too_long = duration > max_seconds
    if too_short or too_long:
        message = f"ref_audio duration must be between {min_seconds:g}s and {max_seconds:g}s, got {duration:.2f}s"
        raise HTTPException(status_code=400, detail=message)
|
|
|
|
|
|
def apply_emotion_preset(payload: Dict[str, Any], emotion: Optional[str]) -> None:
    """Merge a configured emotion preset into ``payload`` in place.

    Nothing happens when ``emotion`` is ``None`` or empty. The name is matched
    after stripping whitespace and lower-casing it.

    Raises:
        HTTPException: 400 when the emotion is not present in
            ``emotion_presets``.
    """
    if emotion is None or emotion == "":
        return

    preset_name = str(emotion).strip().lower()
    presets = simple_config.get("emotion_presets", {}) or {}
    if preset_name in presets:
        payload.update(presets[preset_name] or {})
        return

    supported = ", ".join(sorted(presets.keys())) or "none"
    raise HTTPException(status_code=400, detail=f"emotion is not configured: {preset_name}. supported: {supported}")
|
|
|
|
|
|
def prompt_text_required_for_current_model() -> bool:
    """Tell whether the loaded model requires a non-empty ``prompt_text``.

    Returns ``False`` while no TTS config is loaded. Otherwise the config's
    ``use_vocoder`` attribute decides when it is set; when it is missing or
    ``None`` the answer is whether ``version`` is ``"v3"`` or ``"v4"``.
    """
    config = tts_config
    if config is None:
        return False

    vocoder_flag = getattr(config, "use_vocoder", None)
    if vocoder_flag is None:
        return getattr(config, "version", None) in {"v3", "v4"}
    return bool(vocoder_flag)
|
|
|
|
|
|
def voice_public_info(name: str, profile: Dict[str, Any]) -> Dict[str, Any]:
    """Build the entry for one voice profile shown by ``GET /voices``.

    ``ready`` is true when the profile's reference audio exists on disk and a
    ``prompt_lang`` is configured. ``ref_audio_path`` is reported exactly as
    written in the profile, not resolved.
    """
    configured_ref = profile.get("ref_audio_path")
    resolved_ref = resolve_project_path(configured_ref)
    is_ready = bool(resolved_ref and Path(resolved_ref).exists() and profile.get("prompt_lang"))

    info: Dict[str, Any] = {"name": name}
    info["description"] = profile.get("description", "")
    info["text_lang"] = profile.get("text_lang")
    info["prompt_lang"] = profile.get("prompt_lang")
    info["ref_audio_path"] = configured_ref
    info["ready"] = is_ready
    return info
|
|
|
|
|
|
def build_tts_request(payload: Dict[str, Any]) -> Tuple[Dict[str, Any], str, bool]:
    """Merge defaults, voice profile and request fields into a pipeline request.

    Layering order (later wins): ``DEFAULT_TTS_PARAMS``, the config
    ``defaults`` section, the selected voice profile, then the request
    ``payload``. The aliases ``format`` (for ``media_type``), ``speed`` (for
    ``speed_factor``) and ``stream`` are resolved here.

    Streaming precedence: a ``streaming_mode`` given in the request wins;
    otherwise a ``stream`` flag given in the request decides; otherwise the
    value from the config / built-in defaults applies.

    Args:
        payload: Request fields. The ``voice``, ``stream``, ``format`` and
            ``speed`` keys are removed from it as a side effect.

    Returns:
        ``(tts_req, media_type, response_stream)`` where ``tts_req`` is the
        dict handed to the pipeline, ``media_type`` the lower-cased output
        format and ``response_stream`` whether the HTTP response should stream.

    Raises:
        HTTPException: 404 for an unknown voice; 400 for invalid or missing
            parameters; 503 when the TTS config has not been loaded yet.
    """
    voices = simple_config.get("voices", {}) or {}

    # --- choose the voice profile -------------------------------------------
    # A request that brings its own reference audio does not fall back to the
    # default voice.
    explicit_ref_audio = bool(payload.get("ref_audio_path"))
    voice_name = payload.pop("voice", None)
    if voice_name is None and not explicit_ref_audio:
        voice_name = get_default_voice_name()

    voice_profile: Dict[str, Any] = {}
    if voice_name:
        if voice_name not in voices:
            raise HTTPException(status_code=404, detail=f"voice not found: {voice_name}")
        voice_profile = deepcopy(voices[voice_name] or {})

    # --- layer the parameter sources ----------------------------------------
    tts_req = deepcopy(DEFAULT_TTS_PARAMS)
    tts_req.update(simple_config.get("defaults", {}) or {})
    tts_req.update({k: v for k, v in voice_profile.items() if k not in VOICE_META_KEYS})

    stream = payload.pop("stream", None)
    media_type = payload.pop("format", None)
    speed = payload.pop("speed", None)
    if media_type is not None:
        tts_req["media_type"] = media_type

    # Remember whether the caller chose a streaming mode explicitly before the
    # request fields are merged into tts_req.
    request_sets_streaming_mode = payload.get("streaming_mode") is not None

    for key, value in payload.items():
        if key in DEFAULT_TTS_PARAMS:
            tts_req[key] = value
    if speed is not None:
        tts_req["speed_factor"] = speed

    # --- resolve and validate paths -----------------------------------------
    ref_audio_path = resolve_project_path(tts_req.get("ref_audio_path"))
    if ref_audio_path:
        tts_req["ref_audio_path"] = ref_audio_path

    aux_paths = tts_req.get("aux_ref_audio_paths") or []
    tts_req["aux_ref_audio_paths"] = [resolve_project_path(item) for item in aux_paths if item]

    media_type = str(tts_req.get("media_type", "wav")).lower()
    tts_req["media_type"] = media_type
    if media_type not in SUPPORTED_MEDIA_TYPES:
        raise HTTPException(status_code=400, detail=f"format is not supported: {media_type}")

    text = str(tts_req.get("text") or "").strip()
    if not text:
        raise HTTPException(status_code=400, detail="text is required")
    tts_req["text"] = text

    if not tts_req.get("ref_audio_path"):
        raise HTTPException(status_code=400, detail="ref_audio_path is required in voice profile or request")
    if not Path(tts_req["ref_audio_path"]).exists():
        raise HTTPException(status_code=400, detail=f"ref_audio_path does not exist: {tts_req['ref_audio_path']}")

    for aux_path in tts_req["aux_ref_audio_paths"]:
        if aux_path and not Path(aux_path).exists():
            raise HTTPException(status_code=400, detail=f"aux_ref_audio_path does not exist: {aux_path}")

    if tts_config is None:
        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")

    # --- validate languages and model-specific requirements -----------------
    text_lang = str(tts_req.get("text_lang") or "").lower()
    prompt_lang = str(tts_req.get("prompt_lang") or "").lower()
    tts_req["text_lang"] = text_lang
    tts_req["prompt_lang"] = prompt_lang

    if text_lang not in tts_config.languages:
        raise HTTPException(status_code=400, detail=f"text_lang is not supported: {text_lang}")
    if prompt_lang not in tts_config.languages:
        raise HTTPException(status_code=400, detail=f"prompt_lang is not supported: {prompt_lang}")
    if not str(tts_req.get("prompt_text") or "").strip() and prompt_text_required_for_current_model():
        version = getattr(tts_config, "version", "current")
        raise HTTPException(status_code=400, detail=f"prompt_text is required when using GPT-SoVITS {version}")

    split_method = str(tts_req.get("text_split_method") or "cut5")
    if split_method not in cut_method_names:
        raise HTTPException(status_code=400, detail=f"text_split_method is not supported: {split_method}")
    # Store the validated value so an empty/None setting really becomes "cut5".
    tts_req["text_split_method"] = split_method

    # --- streaming ------------------------------------------------------------
    streaming_mode = tts_req.pop("streaming_mode", None)
    if stream is not None and not request_sets_streaming_mode:
        # DEFAULT_TTS_PARAMS always seeds streaming_mode, so without this the
        # request's "stream" flag could never take effect.
        streaming_mode = None
    streaming_enabled, return_fragment, fixed_length_chunk, response_stream = normalize_streaming(streaming_mode, stream)
    tts_req["streaming_mode"] = streaming_enabled
    tts_req["return_fragment"] = return_fragment
    tts_req["fixed_length_chunk"] = fixed_length_chunk

    return tts_req, media_type, response_stream
|
|
|
|
|
|
def synthesize_once(payload: Dict[str, Any]) -> Tuple[bytes, str]:
    """Run one non-streaming synthesis and return the encoded audio.

    Returns:
        ``(audio_bytes, media_type)``.

    Raises:
        HTTPException: 400 when streaming was requested or synthesis fails;
            503 when the pipeline is not loaded; plus anything raised by
            build_tts_request().
    """
    tts_req, media_type, wants_stream = build_tts_request(payload)
    if wants_stream:
        raise HTTPException(status_code=400, detail="base64 output does not support streaming")
    if tts_pipeline is None:
        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")

    try:
        # Only the inference itself is serialised; encoding happens outside the lock.
        with infer_lock:
            sample_rate, samples = next(tts_pipeline.run(tts_req))
        encoded = pack_audio(BytesIO(), samples, sample_rate, media_type)
        audio_bytes = encoded.getvalue()
    except Exception as exc:
        raise HTTPException(status_code=400, detail={"message": "tts failed", "exception": str(exc)}) from exc
    else:
        return audio_bytes, media_type
|
|
|
|
|
|
def synthesize_response(payload: Dict[str, Any]) -> Response:
    """Synthesize speech and wrap it in an HTTP response.

    In streaming mode a ``StreamingResponse`` is returned; for ``wav`` output
    a WAV header is sent first and every chunk follows as raw bytes. In
    non-streaming mode the whole audio is returned at once, and a synthesis
    failure yields a JSON 400 response rather than an exception.

    Raises:
        HTTPException: 503 when the pipeline is not loaded, plus anything
            raised by build_tts_request().
    """
    tts_req, media_type, wants_stream = build_tts_request(payload)
    if tts_pipeline is None:
        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")

    content_type = f"audio/{media_type}"

    if not wants_stream:
        try:
            with infer_lock:
                sample_rate, samples = next(tts_pipeline.run(tts_req))
            body = pack_audio(BytesIO(), samples, sample_rate, media_type).getvalue()
            return Response(body, media_type=content_type)
        except Exception as exc:
            return JSONResponse(status_code=400, content={"message": "tts failed", "exception": str(exc)})

    def iter_audio_chunks() -> Generator[bytes, None, None]:
        is_wav = media_type == "wav"
        # After the WAV header every chunk is emitted as raw bytes.
        chunk_format = "raw" if is_wav else media_type
        header_pending = is_wav
        # The lock is held for the whole lifetime of the stream.
        with infer_lock:
            for sample_rate, fragment in tts_pipeline.run(tts_req):
                if header_pending:
                    yield wave_header_chunk(sample_rate=sample_rate)
                    header_pending = False
                yield pack_audio(BytesIO(), fragment, sample_rate, chunk_format).getvalue()

    return StreamingResponse(iter_audio_chunks(), media_type=content_type)
|
|
|
|
|
|
@APP.on_event("startup")
def startup() -> None:
    """Load the TTS configuration and build the pipeline when the app starts."""
    global tts_config, tts_pipeline
    loaded_config = TTS_Config(tts_config_path)
    tts_config = loaded_config
    print(loaded_config)
    tts_pipeline = TTS(loaded_config)
|
|
|
|
|
|
@APP.get("/")
def index() -> Dict[str, Any]:
    """Return the service name and the list of advertised endpoints."""
    advertised_endpoints = [
        "/health",
        "/voices",
        "/api/tts",
        "/speak",
        "/speak/base64",
        "/admin/weights",
        "/admin/reload-config",
    ]
    return {"name": "GPT-SoVITS Simple API", "endpoints": advertised_endpoints}
|
|
|
|
|
|
@APP.get("/health")
def health() -> Dict[str, Any]:
    """Report readiness: ``"ok"`` once the pipeline exists, ``"starting"`` before."""
    pipeline_ready = tts_pipeline is not None
    report: Dict[str, Any] = {"status": "ok" if pipeline_ready else "starting"}
    report["tts_config"] = tts_config_path
    report["version"] = getattr(tts_config, "version", None)
    report["languages"] = getattr(tts_config, "languages", [])
    return report
|
|
|
|
|
|
@APP.get("/voices")
def list_voices() -> Dict[str, Any]:
    """List the configured voice profiles together with the default voice name."""
    profiles = simple_config.get("voices", {}) or {}
    default_voice = get_default_voice_name()
    entries = []
    for voice_name, profile in profiles.items():
        entries.append(voice_public_info(voice_name, profile or {}))
    return {"default_voice": default_voice, "voices": entries}
|
|
|
|
|
|
@APP.post("/api/tts")
async def mvp_tts(
    text: str = Form(...),
    ref_audio: UploadFile = File(...),
    aux_ref_audio: Optional[List[UploadFile]] = File(default=None),
    prompt_text: str = Form(default=""),
    text_lang: str = Form(default="zh"),
    prompt_lang: str = Form(default="zh"),
    format: str = Form(default="wav"),
    emotion: Optional[str] = Form(default=None),
    speed: Optional[float] = Form(default=None),
    seed: int = Form(default=-1),
) -> Response:
    """Multipart endpoint: synthesize ``text`` using uploaded reference audio.

    The uploads are stored in a per-request directory that is removed once the
    response has been built. The main reference audio must satisfy the
    configured duration limits; auxiliary audio only has to be readable.
    Text splitting is fixed to ``cut5`` and streaming is disabled, regardless
    of what an emotion preset contains.

    Raises:
        HTTPException: For invalid uploads or parameters (see
            save_upload_file, validate_main_ref_duration and
            build_tts_request).
    """
    # Imported locally so this handler brings its own dependency; fastapi is
    # already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    request_dir = get_request_upload_dir()
    try:
        ref_audio_path = await save_upload_file(ref_audio, request_dir, "ref")
        # Duration probing may spawn ffprobe; keep it off the event loop.
        await run_in_threadpool(validate_main_ref_duration, ref_audio_path)

        aux_ref_audio_paths: List[str] = []
        for position, upload in enumerate(aux_ref_audio or []):
            aux_path = await save_upload_file(upload, request_dir, f"aux_{position}")
            # Called only for its validation side effect (raises on unreadable audio).
            await run_in_threadpool(get_audio_duration_seconds, aux_path)
            aux_ref_audio_paths.append(aux_path)

        payload: Dict[str, Any] = {
            "text": text,
            "text_lang": text_lang,
            "ref_audio_path": ref_audio_path,
            "aux_ref_audio_paths": aux_ref_audio_paths,
            "prompt_text": prompt_text or "",
            "prompt_lang": prompt_lang,
            "format": format,
            "text_split_method": "cut5",
            "streaming_mode": 0,
            "seed": seed,
        }
        apply_emotion_preset(payload, emotion)
        # Re-assert the fixed settings in case the preset overrode them.
        payload["text_split_method"] = "cut5"
        payload["streaming_mode"] = 0
        # An explicit speed from the client is applied after the preset.
        if speed is not None:
            payload["speed"] = speed

        # Inference is blocking and waits on a threading.Lock; running it in a
        # worker thread keeps the event loop responsive for other requests.
        return await run_in_threadpool(synthesize_response, payload)
    finally:
        shutil.rmtree(request_dir, ignore_errors=True)
|
|
|
|
|
|
@APP.get("/speak")
def speak_get(
    text: str,
    voice: Optional[str] = None,
    text_lang: Optional[str] = None,
    format: Optional[str] = None,
    stream: Optional[bool] = None,
    speed: Optional[float] = None,
) -> Response:
    """Query-string variant of ``POST /speak``; parameters left unset are omitted."""
    supplied = (
        ("text", text),
        ("voice", voice),
        ("text_lang", text_lang),
        ("format", format),
        ("stream", stream),
        ("speed", speed),
    )
    payload = {key: value for key, value in supplied if value is not None}
    return synthesize_response(payload)
|
|
|
|
|
|
@APP.post("/speak")
def speak_post(request: SpeakRequest) -> Response:
    """Synthesize speech from a JSON body and return the audio response."""
    payload = request_to_dict(request)
    return synthesize_response(payload)
|
|
|
|
|
|
@APP.post("/v1/tts")
def openai_style_tts(request: SpeakRequest) -> Response:
    """Same handling as ``POST /speak``, exposed under ``/v1/tts``."""
    payload = request_to_dict(request)
    return synthesize_response(payload)
|
|
|
|
|
|
@APP.post("/speak/base64")
def speak_base64(request: SpeakRequest) -> Dict[str, Any]:
    """Synthesize speech and return it base64-encoded inside a JSON object."""
    audio_bytes, media_type = synthesize_once(request_to_dict(request))
    encoded_audio = base64.b64encode(audio_bytes).decode("ascii")
    return {"media_type": f"audio/{media_type}", "audio_base64": encoded_audio}
|
|
|
|
|
|
@APP.post("/admin/reload-config")
def reload_config() -> Dict[str, Any]:
    """Re-read the simple API config file and replace the in-memory copy.

    Only ``simple_config`` is replaced; values derived from it at import time
    (bind address, port, CORS middleware) are not recomputed here.
    """
    global simple_config
    refreshed = load_yaml_config(args.config)
    simple_config = refreshed
    return {"message": "success", "default_voice": get_default_voice_name()}
|
|
|
|
|
|
@APP.post("/admin/weights")
def set_weights(request: WeightsRequest) -> Dict[str, Any]:
    """Swap the GPT and/or SoVITS weights of the running pipeline.

    Raises:
        HTTPException: 503 when the pipeline is not loaded; 400 when no path
            was given or loading the weights fails.
    """
    if tts_pipeline is None:
        raise HTTPException(status_code=503, detail="TTS pipeline is not ready")

    gpt_path = request.gpt_weights_path
    sovits_path = request.sovits_weights_path
    if not (gpt_path or sovits_path):
        raise HTTPException(status_code=400, detail="gpt_weights_path or sovits_weights_path is required")

    try:
        # Hold the inference lock so weights are not swapped mid-synthesis.
        with infer_lock:
            if gpt_path:
                tts_pipeline.init_t2s_weights(gpt_path)
            if sovits_path:
                tts_pipeline.init_vits_weights(sovits_path)
    except Exception as exc:
        raise HTTPException(status_code=400, detail={"message": "change weights failed", "exception": str(exc)}) from exc

    return {"message": "success"}
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # The literal string "None" is translated to a real None before it is
    # handed to uvicorn.
    bind_host = host if host != "None" else None
    try:
        uvicorn.run(app=APP, host=bind_host, port=port, workers=1)
    except Exception:
        traceback.print_exc()
        sys.exit(1)
|