3초 미만 제한없음

This commit is contained in:
SanghyeonAn94 2025-10-15 17:06:21 +09:00
parent 11aa78bd9b
commit e8616c87c6
3 changed files with 774 additions and 10 deletions

View File

@ -799,8 +799,8 @@ class TTS:
) )
with torch.no_grad(): with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000) wav16k, sr = librosa.load(ref_wav_path, sr=16000)
if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: # if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
raise OSError(i18n("参考音频在3~10秒范围外请更换")) # raise OSError(i18n("参考音频在3~10秒范围外请更换"))
wav16k = torch.from_numpy(wav16k) wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav) zero_wav_torch = torch.from_numpy(zero_wav)
wav16k = wav16k.to(self.configs.device) wav16k = wav16k.to(self.configs.device)

View File

@ -811,9 +811,6 @@ def get_tts_wav(
if not ref_free: if not ref_free:
with torch.no_grad(): with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000) wav16k, sr = librosa.load(ref_wav_path, sr=16000)
if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
gr.Warning(i18n("参考音频在3~10秒范围外请更换"))
raise OSError(i18n("参考音频在3~10秒范围外请更换"))
wav16k = torch.from_numpy(wav16k) wav16k = torch.from_numpy(wav16k)
if is_half == True: if is_half == True:
wav16k = wav16k.half().to(device) wav16k = wav16k.half().to(device)

777
api_v2.py
View File

@ -101,7 +101,10 @@ RESP:
import os import os
import sys import sys
import traceback import traceback
from typing import Generator from typing import Generator, Dict, Any, Optional
import uuid
import asyncio
from datetime import datetime
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
@ -121,11 +124,47 @@ from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
from pydantic import BaseModel from pydantic import BaseModel
import json
import yaml
# Import config variables (avoiding webui to prevent Gradio loading)
from config import (
exp_root,
python_exec,
is_half,
GPU_INDEX,
infer_device,
SoVITS_weight_version2root,
GPT_weight_version2root,
)
# print(sys.path) # print(sys.path)
i18n = I18nAuto() i18n = I18nAuto()
cut_method_names = get_cut_method_names() cut_method_names = get_cut_method_names()
# GPU helper functions (replicated from webui.py to avoid import)
# Collection of GPU indices regarded as valid; fix_gpu_number() tests membership in it.
set_gpu_numbers = GPU_INDEX
# Fallback index substituted for out-of-range requests; 0 when infer_device has no `index` attribute.
# NOTE(review): if infer_device.index exists but is None, the fallback becomes None — confirm that is intended.
default_gpu_numbers = infer_device.index if hasattr(infer_device, 'index') else 0
def fix_gpu_number(input_val):
    """Clamp a single GPU index to the set of usable GPUs.

    Args:
        input_val: GPU index, typically an int or a numeric string.

    Returns:
        ``default_gpu_numbers`` when ``input_val`` parses as an integer that
        is not in ``set_gpu_numbers``; otherwise ``input_val`` unchanged
        (including when it cannot be parsed as an integer at all).
    """
    try:
        gpu_index = int(input_val)
    except (TypeError, ValueError, OverflowError):
        # Unparseable values are passed through untouched rather than rejected.
        return input_val
    if gpu_index not in set_gpu_numbers:
        return default_gpu_numbers
    return input_val
def fix_gpu_numbers(inputs):
    """Clamp every GPU index in a comma-separated string.

    Args:
        inputs: Comma-separated GPU indices, e.g. ``"0,1"``.

    Returns:
        The comma-joined result of applying ``fix_gpu_number`` to each piece,
        or ``inputs`` unchanged when it cannot be split on ``","`` (for
        example when it is not a string).
    """
    try:
        pieces = inputs.split(",")
    except (AttributeError, TypeError):
        # Not a str-like value: hand it back as-is, as before.
        return inputs
    return ",".join(str(fix_gpu_number(piece)) for piece in pieces)
parser = argparse.ArgumentParser(description="GPT-SoVITS api") parser = argparse.ArgumentParser(description="GPT-SoVITS api")
parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径") parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1") parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
@ -172,6 +211,69 @@ class TTS_Request(BaseModel):
super_sampling: bool = False super_sampling: bool = False
from pydantic import Field  # imported here because only BaseModel is imported at file top


class SpeechSlicingRequest(BaseModel):
    """Request body for POST /preprocessing/speech-slicing.

    Every field is converted with ``str()`` and passed as a positional
    command-line argument to ``tools/slice_audio.py``; ``n_parts`` also
    decides how many slicer subprocesses are started.
    """

    inp: str
    opt_root: str
    threshold: str = "-34"
    min_length: str = "4000"
    min_interval: str = "300"
    hop_size: str = "10"
    max_sil_kept: str = "500"
    # pydantic does not treat names with a leading underscore as model fields,
    # so a plain `_max: float = 0.9` can never be populated from the request
    # body and always stays at its default. Declare a real field and expose it
    # to clients under the original "_max" key through an alias.
    max_: float = Field(0.9, alias="_max")
    alpha: float = 0.25
    n_parts: int = 4

    @property
    def _max(self) -> float:
        """Backward-compatible accessor for code that reads ``request._max``."""
        return self.max_
# Request body for POST /preprocessing/stt.
# The five fields are forwarded, in declaration order, as positional arguments
# to tools.asr.fasterwhisper_asr.execute_asr.
class STTRequest(BaseModel):
    input_folder: str
    output_folder: str
    model_path: str = "tools/asr/models/faster-whisper-large-v3"
    language: str = "auto"
    precision: str = "float32"
# Request body for POST /training/format-dataset.
# The values are handed to the GPT_SoVITS/prepare_datasets scripts through
# environment variables.
class DatasetFormattingRequest(BaseModel):
    inp_text: str
    inp_wav_dir: str
    exp_name: str  # output goes to "<exp_root>/<exp_name>"
    version: str = "v4"  # selects the s2 config file; versions containing "Pro" also run 2-get-sv.py
    gpu_numbers: str = "0-0"  # split on "-": one entry per part, each entry passed through fix_gpu_number
    bert_pretrained_dir: str = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
    ssl_pretrained_dir: str = "GPT_SoVITS/pretrained_models/chinese-hubert-base"  # exported as "cnhubert_base_dir"
    pretrained_s2G_path: str = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"  # exported as "pretrained_s2G"
# Request body for POST /training/fine-tune-sovits.
# The values are copied into the s2 training config that is written to a
# temporary JSON file before the training script is launched.
class FineTuneSoVITSRequest(BaseModel):
    version: str = "v4"  # picks the config template, the weight directory and the training script
    batch_size: int = 2  # halved (minimum 1) by the executor when is_half is False
    total_epoch: int = 2  # stored as train.epochs
    exp_name: str
    text_low_lr_rate: float = 0.4
    if_save_latest: bool = True
    if_save_every_weights: bool = True
    save_every_epoch: int = 1
    gpu_numbers1Ba: str = "0"  # stored as train.gpu_numbers without further normalisation
    pretrained_s2G: str = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
    pretrained_s2D: str = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Dv4.pth"
    if_grad_ckpt: bool = False  # stored as train.grad_ckpt
    lora_rank: str = "32"  # kept as a string when written to the config
# Request body for POST /training/fine-tune-gpt.
# The values are copied into the s1 training config that is written to a
# temporary YAML file before GPT_SoVITS/s1_train.py is launched.
class FineTuneGPTRequest(BaseModel):
    batch_size: int = 8  # halved (minimum 1) by the executor when is_half is False
    total_epoch: int = 15  # stored as train.epochs
    exp_name: str
    if_dpo: bool = False
    if_save_latest: bool = True
    if_save_every_weights: bool = True
    save_every_epoch: int = 5  # stored as train.save_every_n_epoch
    gpu_numbers: str = "0"  # "-" is replaced by "," and the result exported as _CUDA_VISIBLE_DEVICES
    pretrained_s1: str = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
# In-memory job registry: job_id -> record holding "status", "operation",
# timestamps and, once finished, "result" or "error"/"traceback".
# It is a plain module-level dict, so records exist only inside this process.
jobs: Dict[str, Dict[str, Any]] = {}
### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
@ -272,14 +374,14 @@ def check_params(req: dict):
return JSONResponse(status_code=400, content={"message": "text is required"}) return JSONResponse(status_code=400, content={"message": "text is required"})
if text_lang in [None, ""]: if text_lang in [None, ""]:
return JSONResponse(status_code=400, content={"message": "text_lang is required"}) return JSONResponse(status_code=400, content={"message": "text_lang is required"})
elif text_lang.lower() not in tts_config.languages: elif text_lang not in tts_config.languages:
return JSONResponse( return JSONResponse(
status_code=400, status_code=400,
content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"}, content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"},
) )
if prompt_lang in [None, ""]: if prompt_lang in [None, ""]:
return JSONResponse(status_code=400, content={"message": "prompt_lang is required"}) return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
elif prompt_lang.lower() not in tts_config.languages: elif prompt_lang not in tts_config.languages:
return JSONResponse( return JSONResponse(
status_code=400, status_code=400,
content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"}, content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"},
@ -407,11 +509,11 @@ async def tts_get_endpoint(
): ):
req = { req = {
"text": text, "text": text,
"text_lang": text_lang.lower(), "text_lang": text_lang,
"ref_audio_path": ref_audio_path, "ref_audio_path": ref_audio_path,
"aux_ref_audio_paths": aux_ref_audio_paths, "aux_ref_audio_paths": aux_ref_audio_paths,
"prompt_text": prompt_text, "prompt_text": prompt_text,
"prompt_lang": prompt_lang.lower(), "prompt_lang": prompt_lang,
"top_k": top_k, "top_k": top_k,
"top_p": top_p, "top_p": top_p,
"temperature": temperature, "temperature": temperature,
@ -434,6 +536,23 @@ async def tts_get_endpoint(
@APP.post("/tts")
async def tts_post_endpoint(request: TTS_Request):
    """Handle POST /tts: log a summary of the payload, then run the TTS handler.

    Args:
        request: Parsed request body.

    Returns:
        Whatever ``tts_handle`` returns for the request converted to a dict.
    """

    def _preview(value, limit=100):
        # Slicing already returns short strings whole, so no length check is
        # needed; non-string values (e.g. a missing text) are printed as-is
        # instead of raising before parameter validation can answer with 400.
        return value[:limit] if isinstance(value, str) else value

    # DEBUG: Print received payload
    print(f"\n{'='*80}")
    print("[TTS DEBUG] Received request:")
    print(f" text: {_preview(request.text)}")  # Truncate long text
    print(f" text_lang: {request.text_lang}")
    print(f" ref_audio_path: {request.ref_audio_path}")
    print(f" prompt_text: {_preview(request.prompt_text)}")
    print(f" prompt_lang: {request.prompt_lang}")
    print(f" top_k: {request.top_k}")
    print(f" top_p: {request.top_p}")
    print(f" temperature: {request.temperature}")
    print(f" text_split_method: {request.text_split_method}")
    print(f" batch_size: {request.batch_size}")
    print(f" speed_factor: {request.speed_factor}")
    print(f" streaming_mode: {request.streaming_mode}")
    print(f"{'='*80}\n")

    req = request.dict()
    return await tts_handle(req)
@ -489,6 +608,654 @@ async def set_sovits_weights(weights_path: str = None):
return JSONResponse(status_code=200, content={"message": "success"}) return JSONResponse(status_code=200, content={"message": "success"})
async def execute_job_async(job_id: str, operation_func, *args, **kwargs):
    """Run a blocking operation in a worker thread and record the outcome in ``jobs``.

    Args:
        job_id: Key of an existing record in ``jobs``.
        operation_func: Blocking callable to execute.
        *args, **kwargs: Forwarded to ``operation_func``.

    The record is marked "running" first, then either "completed" (with
    ``result``) or "failed" (with ``error`` and ``traceback``). When the
    callable returns an iterable other than ``str``/``dict`` (for example a
    generator), it is consumed and the last item becomes the result
    (``None`` if it yields nothing).
    """

    def _call_and_drain():
        # Consuming the iterable here keeps generator bodies, which only run
        # while being iterated, off the event loop thread.
        outcome = operation_func(*args, **kwargs)
        if hasattr(outcome, '__iter__') and not isinstance(outcome, (str, dict)):
            last_item = None
            for last_item in outcome:
                pass
            return last_item
        return outcome

    jobs[job_id]["status"] = "running"
    jobs[job_id]["started_at"] = datetime.now().isoformat()
    try:
        result = await asyncio.to_thread(_call_and_drain)
    except Exception as e:
        jobs[job_id]["status"] = "failed"
        jobs[job_id]["error"] = str(e)
        jobs[job_id]["traceback"] = traceback.format_exc()
        jobs[job_id]["failed_at"] = datetime.now().isoformat()
    else:
        jobs[job_id]["status"] = "completed"
        jobs[job_id]["result"] = result
        jobs[job_id]["completed_at"] = datetime.now().isoformat()
@APP.get("/jobs/{job_id}")
@APP.get("/job-status/{job_id}")  # Alias for compatibility
async def get_job_status(job_id: str):
    """
    Report the stored record of a job.

    Responds with 404 and {"message": "job not found"} for an unknown id.
    Otherwise responds with 200 and a shallow copy of the record plus
    "job_id". Depending on how far the job got, the record holds:
        "status": "queued" | "running" | "completed" | "failed"
        "result" / "completed_at"  (completed)
        "error" / "failed_at"      (failed)
        "created_at", "started_at"
    """
    try:
        record = jobs[job_id]
    except KeyError:
        return JSONResponse(status_code=404, content={"message": "job not found"})

    # Build a new dict so the stored record itself is never given a "job_id" key.
    payload = {**record, "job_id": job_id}
    return JSONResponse(status_code=200, content=payload)
async def execute_speech_slicing_direct(job_id: str, request: SpeechSlicingRequest):
    """
    Run tools/slice_audio.py as ``request.n_parts`` parallel subprocesses.

    Replaces webui.open_slice() to avoid the Gradio dependency. The job
    record in ``jobs`` is marked "running", then "completed" when every
    subprocess exits with code 0, or "failed" (with error and traceback)
    otherwise.
    """
    jobs[job_id]["status"] = "running"
    jobs[job_id]["started_at"] = datetime.now().isoformat()
    try:
        # Prepare environment with PYTHONPATH
        env = os.environ.copy()
        env["PYTHONPATH"] = os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")])

        # One subprocess per part; they all run concurrently.
        processes = []
        for i_part in range(request.n_parts):
            cmd = [
                python_exec,
                "tools/slice_audio.py",
                request.inp,
                request.opt_root,
                str(request.threshold),
                str(request.min_length),
                str(request.min_interval),
                str(request.hop_size),
                str(request.max_sil_kept),
                str(request._max),
                str(request.alpha),
                str(i_part),
                str(request.n_parts),
            ]
            print(f"[SPEECH SLICING] Executing: {' '.join(cmd)}")
            processes.append(subprocess.Popen(cmd, env=env, cwd=now_dir))

        # Popen.wait() blocks, so do the waiting in a worker thread; waiting
        # directly in this coroutine would stall every other request.
        exit_codes = await asyncio.to_thread(lambda: [p.wait() for p in processes])

        if any(code != 0 for code in exit_codes):
            raise Exception(f"Speech slicing failed with exit codes: {exit_codes}")

        jobs[job_id]["status"] = "completed"
        jobs[job_id]["result"] = {
            "output_dir": request.opt_root,
            # NOTE(review): this is the number of parts, not a count of produced files.
            "file_count": request.n_parts,
        }
        jobs[job_id]["completed_at"] = datetime.now().isoformat()
    except Exception as e:
        jobs[job_id]["status"] = "failed"
        jobs[job_id]["error"] = str(e)
        jobs[job_id]["traceback"] = traceback.format_exc()
        jobs[job_id]["failed_at"] = datetime.now().isoformat()
# Strong references to in-flight speech-slicing tasks. The event loop only
# keeps weak references, so an unreferenced task may be garbage-collected
# before it finishes.
_speech_slicing_tasks = set()


@APP.post("/preprocessing/speech-slicing")
async def speech_slicing_endpoint(request: SpeechSlicingRequest):
    """
    Start a speech slicing job in the background.

    Registers a "queued" record in ``jobs``, schedules
    execute_speech_slicing_direct() and returns immediately.

    Returns:
        200 with {"job_id", "status": "queued"}, or 500 when the background
        task could not be scheduled.
    """
    # DEBUG: Print received payload
    print(f"\n{'='*80}")
    print("[SPEECH SLICING DEBUG] Received request:")
    print(f" inp: {request.inp}")
    print(f" opt_root: {request.opt_root}")
    print(f" threshold: {request.threshold}")
    print(f" min_length: {request.min_length}")
    print(f" min_interval: {request.min_interval}")
    print(f" hop_size: {request.hop_size}")
    print(f" max_sil_kept: {request.max_sil_kept}")
    print(f" _max: {request._max}")
    print(f" alpha: {request.alpha}")
    print(f" n_parts: {request.n_parts}")
    print(f"{'='*80}\n")

    job_id = str(uuid.uuid4())
    jobs[job_id] = {
        "status": "queued",
        "operation": "speech_slicing",
        "created_at": datetime.now().isoformat()
    }
    try:
        task = asyncio.create_task(execute_speech_slicing_direct(job_id, request))
        _speech_slicing_tasks.add(task)
        task.add_done_callback(_speech_slicing_tasks.discard)
        return JSONResponse(status_code=200, content={"job_id": job_id, "status": "queued"})
    except Exception as e:
        jobs[job_id]["status"] = "failed"
        jobs[job_id]["error"] = str(e)
        return JSONResponse(status_code=500, content={"message": "failed to start job", "error": str(e)})
# Strong references to in-flight STT tasks. The event loop only keeps weak
# references, so an unreferenced task may be garbage-collected before it
# finishes.
_stt_tasks = set()


@APP.post("/preprocessing/stt")
async def stt_endpoint(request: STTRequest):
    """
    Start an STT (Speech-to-Text) job in the background.

    Registers a "queued" record in ``jobs`` and schedules
    tools.asr.fasterwhisper_asr.execute_asr() through execute_job_async().

    Returns:
        200 with {"job_id", "status": "queued"}, or 500 when the ASR module
        cannot be imported or the task cannot be scheduled.
    """
    # DEBUG: Print received payload
    print(f"\n{'='*80}")
    print("[STT DEBUG] Received STT request:")
    print(request)
    print(f"{'='*80}\n")

    job_id = str(uuid.uuid4())
    jobs[job_id] = {
        "status": "queued",
        "operation": "stt",
        "created_at": datetime.now().isoformat()
    }
    try:
        # Imported lazily so an import failure is reported as a failed job.
        from tools.asr.fasterwhisper_asr import execute_asr

        task = asyncio.create_task(execute_job_async(
            job_id,
            execute_asr,
            request.input_folder,
            request.output_folder,
            request.model_path,
            request.language,
            request.precision
        ))
        _stt_tasks.add(task)
        task.add_done_callback(_stt_tasks.discard)
        return JSONResponse(status_code=200, content={"job_id": job_id, "status": "queued"})
    except Exception as e:
        print(f"[STT ERROR] Failed to start STT job: {str(e)}")
        jobs[job_id]["status"] = "failed"
        jobs[job_id]["error"] = str(e)
        return JSONResponse(status_code=500, content={"message": "failed to start job", "error": str(e)})
async def _run_dataset_stage(script, label, gpu_names, stage_env):
    """Run one prepare_datasets script once per part, sequentially, off the event loop."""
    all_parts = len(gpu_names)
    for i_part, gpu_name in enumerate(gpu_names):
        env = os.environ.copy()
        env.update(stage_env)
        env.update({
            "PYTHONPATH": os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")]),
            "i_part": str(i_part),
            "all_parts": str(all_parts),
            "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_name)),
            "is_half": str(is_half),
        })
        cmd = [python_exec, script]
        print(f"[DATASET FORMATTING] Executing {label} part {i_part}: {' '.join(cmd)}")
        # check=True turns a non-zero exit code into CalledProcessError.
        await asyncio.to_thread(subprocess.run, cmd, env=env, cwd=now_dir, check=True)


def _collect_dataset_part_lines(part_paths):
    """Return the lines of every existing per-part file, deleting each file once read."""
    lines = []
    for part_path in part_paths:
        if os.path.exists(part_path):
            with open(part_path, "r", encoding="utf8") as f:
                lines += f.read().strip("\n").split("\n")
            os.remove(part_path)
    return lines


async def execute_dataset_formatting(job_id: str, request: DatasetFormattingRequest):
    """
    Execute dataset formatting sequentially: open1a -> open1b -> open1c.

    Each stage runs a GPT_SoVITS/prepare_datasets script as a subprocess, once
    per entry of ``request.gpu_numbers`` split on "-" (no webui dependency).
    Progress is recorded in ``jobs[job_id]["current_stage"]``; the record ends
    as "completed" or, on any exception, "failed" with error and traceback.
    """
    jobs[job_id]["status"] = "running"
    jobs[job_id]["started_at"] = datetime.now().isoformat()
    jobs[job_id]["current_stage"] = "open1a"
    try:
        opt_dir = f"{exp_root}/{request.exp_name}"
        os.makedirs(opt_dir, exist_ok=True)

        # Parse GPU numbers: one part per "-"-separated entry
        gpu_names = request.gpu_numbers.split("-")
        all_parts = len(gpu_names)

        # Stage 1a: Get text features
        print("[DATASET FORMATTING] Starting open1a...")
        await _run_dataset_stage(
            "GPT_SoVITS/prepare_datasets/1-get-text.py",
            "1a",
            gpu_names,
            {
                "inp_text": request.inp_text,
                "inp_wav_dir": request.inp_wav_dir,
                "exp_name": request.exp_name,
                "opt_dir": opt_dir,
                "bert_pretrained_dir": request.bert_pretrained_dir,
            },
        )

        # Merge text files from 1a stage
        text_lines = _collect_dataset_part_lines(
            f"{opt_dir}/2-name2text-{i_part}.txt" for i_part in range(all_parts)
        )
        with open(f"{opt_dir}/2-name2text.txt", "w", encoding="utf8") as f:
            f.write("\n".join(text_lines) + "\n")

        # Stage 1b: Get hubert features
        jobs[job_id]["current_stage"] = "open1b"
        print("[DATASET FORMATTING] Starting open1b...")
        sv_path = "GPT_SoVITS/pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"
        await _run_dataset_stage(
            "GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py",
            "1b",
            gpu_names,
            {
                "inp_text": request.inp_text,
                "inp_wav_dir": request.inp_wav_dir,
                "exp_name": request.exp_name,
                "opt_dir": opt_dir,
                "cnhubert_base_dir": request.ssl_pretrained_dir,
                "sv_path": sv_path,
            },
        )

        # Versions whose name contains "Pro" also run 2-get-sv.py
        if "Pro" in request.version:
            await _run_dataset_stage(
                "GPT_SoVITS/prepare_datasets/2-get-sv.py",
                "2-get-sv",
                gpu_names,
                {
                    "exp_dir": opt_dir,
                    "sv_path": sv_path,
                },
            )

        # Stage 1c: Get semantic features
        jobs[job_id]["current_stage"] = "open1c"
        print("[DATASET FORMATTING] Starting open1c...")
        config_file = (
            "GPT_SoVITS/configs/s2.json"
            if request.version not in {"v2Pro", "v2ProPlus"}
            else f"GPT_SoVITS/configs/s2{request.version}.json"
        )
        await _run_dataset_stage(
            "GPT_SoVITS/prepare_datasets/3-get-semantic.py",
            "1c",
            gpu_names,
            {
                "inp_text": request.inp_text,
                "exp_name": request.exp_name,
                "opt_dir": opt_dir,
                "pretrained_s2G": request.pretrained_s2G_path,
                "s2config_path": config_file,
            },
        )

        # Merge semantic files (from open1c logic in webui.py); header row
        # first, no trailing newline.
        semantic_lines = ["item_name\tsemantic_audio"] + _collect_dataset_part_lines(
            f"{opt_dir}/6-name2semantic-{i_part}.tsv" for i_part in range(all_parts)
        )
        with open(f"{opt_dir}/6-name2semantic.tsv", "w", encoding="utf8") as f:
            f.write("\n".join(semantic_lines))

        jobs[job_id]["status"] = "completed"
        jobs[job_id]["result"] = {
            "exp_name": request.exp_name,
            "stages_completed": ["open1a", "open1b", "open1c"]
        }
        jobs[job_id]["completed_at"] = datetime.now().isoformat()
    except Exception as e:
        jobs[job_id]["status"] = "failed"
        jobs[job_id]["error"] = str(e)
        jobs[job_id]["traceback"] = traceback.format_exc()
        jobs[job_id]["failed_at"] = datetime.now().isoformat()
# Strong references to in-flight dataset-formatting tasks. The event loop only
# keeps weak references, so an unreferenced task may be garbage-collected
# before it finishes.
_format_dataset_tasks = set()


@APP.post("/training/format-dataset")
async def format_dataset_endpoint(request: DatasetFormattingRequest):
    """
    Start a dataset formatting job (open1a -> open1b -> open1c) in the background.

    Registers a "queued" record in ``jobs``, schedules
    execute_dataset_formatting() and returns immediately.

    Returns:
        200 with {"job_id", "status": "queued"}, or 500 when the background
        task could not be scheduled.
    """
    # DEBUG: Print received payload
    print(f"\n{'='*80}")
    print("[DATASET FORMATTING DEBUG] Received request:")
    print(f" version: {request.version}")
    print(f" inp_text: {request.inp_text}")
    print(f" inp_wav_dir: {request.inp_wav_dir}")
    print(f" exp_name: {request.exp_name}")
    print(f" gpu_numbers1a: {request.gpu_numbers}")
    print(f" bert_pretrained_dir: {request.bert_pretrained_dir}")
    print(f" ssl_pretrained_dir: {request.ssl_pretrained_dir}")
    print(f" pretrained_s2G_path: {request.pretrained_s2G_path}")
    print(f"{'='*80}\n")

    job_id = str(uuid.uuid4())
    jobs[job_id] = {
        "status": "queued",
        "operation": "format_dataset",
        "created_at": datetime.now().isoformat()
    }
    try:
        task = asyncio.create_task(execute_dataset_formatting(job_id, request))
        _format_dataset_tasks.add(task)
        task.add_done_callback(_format_dataset_tasks.discard)
        return JSONResponse(status_code=200, content={"job_id": job_id, "status": "queued"})
    except Exception as e:
        jobs[job_id]["status"] = "failed"
        jobs[job_id]["error"] = str(e)
        return JSONResponse(status_code=500, content={"message": "failed to start job", "error": str(e)})
def _latest_sovits_checkpoint(weights_dir, exp_name):
    """Return the path of the highest (epoch, step) checkpoint of exp_name in weights_dir, or None."""
    import re

    if not os.path.exists(weights_dir):
        return None
    # NOTE(review): the "_l<rank>" suffix is optional here because the
    # non-LoRA training script presumably does not append it — confirm
    # against the checkpoint names s2_train.py writes.
    pattern = re.compile(rf"^{re.escape(exp_name)}_e(\d+)_s(\d+)(?:_l\d+)?\.pth$")
    candidates = []
    for filename in os.listdir(weights_dir):
        match = pattern.match(filename)
        if match:
            candidates.append((int(match.group(1)), int(match.group(2)), filename))
    if not candidates:
        return None
    return os.path.join(weights_dir, max(candidates)[2])


async def execute_fine_tune_sovits_direct(job_id: str, request: FineTuneSoVITSRequest):
    """
    Execute SoVITS fine-tuning by running the s2 training script as a subprocess.

    Replaces webui.open1Ba() to avoid the Gradio dependency. Loads the s2
    config template, overrides it with the request values, writes it to a
    per-job file under TEMP/ and runs s2_train.py (v1/v2/v2Pro/v2ProPlus) or
    s2_train_v3_lora.py (any other version). The job record ends as
    "completed" with the config path and newest checkpoint path (None when
    no matching file is found), or "failed" with error and traceback.
    """
    jobs[job_id]["status"] = "running"
    jobs[job_id]["started_at"] = datetime.now().isoformat()
    try:
        s2_dir = f"{exp_root}/{request.exp_name}"
        os.makedirs(f"{s2_dir}/logs_s2_{request.version}", exist_ok=True)

        # Load config template
        config_file = (
            "GPT_SoVITS/configs/s2.json"
            if request.version not in {"v2Pro", "v2ProPlus"}
            else f"GPT_SoVITS/configs/s2{request.version}.json"
        )
        with open(config_file) as f:
            data = json.loads(f.read())

        # Update config with request parameters
        batch_size = request.batch_size
        if not is_half:
            # Full precision: disable fp16 and halve the batch size.
            data["train"]["fp16_run"] = False
            batch_size = max(1, batch_size // 2)
        train_cfg = data["train"]
        train_cfg["batch_size"] = batch_size
        train_cfg["epochs"] = request.total_epoch
        train_cfg["text_low_lr_rate"] = request.text_low_lr_rate
        train_cfg["pretrained_s2G"] = request.pretrained_s2G
        train_cfg["pretrained_s2D"] = request.pretrained_s2D
        train_cfg["if_save_latest"] = request.if_save_latest
        train_cfg["if_save_every_weights"] = request.if_save_every_weights
        train_cfg["save_every_epoch"] = request.save_every_epoch
        train_cfg["gpu_numbers"] = request.gpu_numbers1Ba
        train_cfg["grad_ckpt"] = request.if_grad_ckpt
        train_cfg["lora_rank"] = request.lora_rank
        data["model"]["version"] = request.version
        data["data"]["exp_dir"] = data["s2_ckpt_dir"] = s2_dir
        data["save_weight_dir"] = SoVITS_weight_version2root[request.version]
        data["name"] = request.exp_name
        data["version"] = request.version

        # Write the config to a per-job file: jobs run concurrently, and a
        # shared file name would let a later job overwrite the config before
        # an earlier job's training process has read it.
        os.makedirs(f"{now_dir}/TEMP", exist_ok=True)
        tmp_config_path = f"{now_dir}/TEMP/tmp_s2_{job_id}.json"
        with open(tmp_config_path, "w") as f:
            f.write(json.dumps(data))

        # Prepare environment with PYTHONPATH
        env = os.environ.copy()
        env["PYTHONPATH"] = os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")])

        # Determine training script based on version
        if request.version in ["v1", "v2", "v2Pro", "v2ProPlus"]:
            cmd = [python_exec, "GPT_SoVITS/s2_train.py", "--config", tmp_config_path]
        else:
            cmd = [python_exec, "GPT_SoVITS/s2_train_v3_lora.py", "--config", tmp_config_path]
        print(f"[SOVITS FINE-TUNING] Executing: {' '.join(cmd)}")
        # check=True turns a non-zero exit code into CalledProcessError.
        await asyncio.to_thread(subprocess.run, cmd, env=env, cwd=now_dir, check=True)

        # Find latest SoVITS checkpoint
        latest_sovits_checkpoint = _latest_sovits_checkpoint(data["save_weight_dir"], request.exp_name)
        if latest_sovits_checkpoint is not None:
            print(f"[SOVITS FINE-TUNING] Latest checkpoint: {latest_sovits_checkpoint}")

        jobs[job_id]["status"] = "completed"
        jobs[job_id]["result"] = {
            "exp_name": request.exp_name,
            "config_path": tmp_config_path,
            "checkpoint_path": latest_sovits_checkpoint,
            "sovits_checkpoint_path": latest_sovits_checkpoint
        }
        jobs[job_id]["completed_at"] = datetime.now().isoformat()
    except Exception as e:
        jobs[job_id]["status"] = "failed"
        jobs[job_id]["error"] = str(e)
        jobs[job_id]["traceback"] = traceback.format_exc()
        jobs[job_id]["failed_at"] = datetime.now().isoformat()
# Strong references to in-flight SoVITS fine-tuning tasks. The event loop only
# keeps weak references, so an unreferenced task may be garbage-collected
# before it finishes.
_fine_tune_sovits_tasks = set()


@APP.post("/training/fine-tune-sovits")
async def fine_tune_sovits_endpoint(request: FineTuneSoVITSRequest):
    """
    Start a SoVITS fine-tuning job in the background.

    Registers a "queued" record in ``jobs``, schedules
    execute_fine_tune_sovits_direct() and returns immediately.

    Returns:
        200 with {"job_id", "status": "queued"}, or 500 when the background
        task could not be scheduled.
    """
    # DEBUG: Print received payload
    print(f"\n{'='*80}")
    print("[SOVITS FINE-TUNING DEBUG] Received request:")
    print(f" version: {request.version}")
    print(f" batch_size: {request.batch_size}")
    print(f" total_epoch: {request.total_epoch}")
    print(f" exp_name: {request.exp_name}")
    print(f" text_low_lr_rate: {request.text_low_lr_rate}")
    print(f" if_save_latest: {request.if_save_latest}")
    print(f" if_save_every_weights: {request.if_save_every_weights}")
    print(f" save_every_epoch: {request.save_every_epoch}")
    print(f" gpu_numbers1Ba: {request.gpu_numbers1Ba}")
    print(f" pretrained_s2G: {request.pretrained_s2G}")
    print(f" pretrained_s2D: {request.pretrained_s2D}")
    print(f" if_grad_ckpt: {request.if_grad_ckpt}")
    print(f" lora_rank: {request.lora_rank}")
    print(f"{'='*80}\n")

    job_id = str(uuid.uuid4())
    jobs[job_id] = {
        "status": "queued",
        "operation": "fine_tune_sovits",
        "created_at": datetime.now().isoformat()
    }
    try:
        task = asyncio.create_task(execute_fine_tune_sovits_direct(job_id, request))
        _fine_tune_sovits_tasks.add(task)
        task.add_done_callback(_fine_tune_sovits_tasks.discard)
        return JSONResponse(status_code=200, content={"job_id": job_id, "status": "queued"})
    except Exception as e:
        jobs[job_id]["status"] = "failed"
        jobs[job_id]["error"] = str(e)
        return JSONResponse(status_code=500, content={"message": "failed to start job", "error": str(e)})
# Strong references to in-flight GPT fine-tuning tasks. The event loop only
# keeps weak references, so an unreferenced task may be garbage-collected
# before it finishes.
_fine_tune_gpt_tasks = set()


@APP.post("/training/fine-tune-gpt")
async def fine_tune_gpt_endpoint(request: FineTuneGPTRequest):
    """
    Start a GPT fine-tuning job in the background.

    Registers a "queued" record in ``jobs``, schedules
    execute_fine_tune_gpt_direct() and returns immediately.

    Returns:
        200 with {"job_id", "status": "queued"}, or 500 when the background
        task could not be scheduled.
    """
    # DEBUG: Print received payload
    print(f"\n{'='*80}")
    print("[GPT FINE-TUNING DEBUG] Received request:")
    print(f" batch_size: {request.batch_size}")
    print(f" total_epoch: {request.total_epoch}")
    print(f" exp_name: {request.exp_name}")
    print(f" if_dpo: {request.if_dpo}")
    print(f" if_save_latest: {request.if_save_latest}")
    print(f" if_save_every_weights: {request.if_save_every_weights}")
    print(f" save_every_epoch: {request.save_every_epoch}")
    print(f" gpu_numbers: {request.gpu_numbers}")
    print(f" pretrained_s1: {request.pretrained_s1}")
    print(f"{'='*80}\n")

    job_id = str(uuid.uuid4())
    jobs[job_id] = {
        "status": "queued",
        "operation": "fine_tune_gpt",
        "created_at": datetime.now().isoformat()
    }
    try:
        task = asyncio.create_task(execute_fine_tune_gpt_direct(job_id, request))
        _fine_tune_gpt_tasks.add(task)
        task.add_done_callback(_fine_tune_gpt_tasks.discard)
        return JSONResponse(status_code=200, content={"job_id": job_id, "status": "queued"})
    except Exception as e:
        jobs[job_id]["status"] = "failed"
        jobs[job_id]["error"] = str(e)
        return JSONResponse(status_code=500, content={"message": "failed to start job", "error": str(e)})
def _latest_gpt_checkpoint(weights_dir, exp_name):
    """Return the path of the highest-epoch "<exp_name>-e<epoch>.ckpt" in weights_dir, or None."""
    import re

    if not os.path.exists(weights_dir):
        return None
    pattern = re.compile(rf"^{re.escape(exp_name)}-e(\d+)\.ckpt$")
    candidates = []
    for filename in os.listdir(weights_dir):
        match = pattern.match(filename)
        if match:
            candidates.append((int(match.group(1)), filename))
    if not candidates:
        return None
    return os.path.join(weights_dir, max(candidates)[1])


async def execute_fine_tune_gpt_direct(job_id: str, request: FineTuneGPTRequest):
    """
    Execute GPT fine-tuning by running GPT_SoVITS/s1_train.py as a subprocess.

    Replaces webui.open1Bb() to avoid the Gradio dependency. Loads the s1
    config template, overrides it with the request values, writes it to a
    per-job file under TEMP/ and runs the training script. The job record
    ends as "completed" with the config path and newest checkpoint path
    (None when no matching file is found), or "failed" with error and
    traceback.
    """
    jobs[job_id]["status"] = "running"
    jobs[job_id]["started_at"] = datetime.now().isoformat()
    try:
        s1_dir = f"{exp_root}/{request.exp_name}"
        os.makedirs(f"{s1_dir}/logs_s1", exist_ok=True)

        # The model version comes from the "version" environment variable
        # (as in webui.py), defaulting to "v4".
        version = os.environ.get("version", "v4")

        # Load config template
        config_path = (
            "GPT_SoVITS/configs/s1longer.yaml" if version == "v1"
            else "GPT_SoVITS/configs/s1longer-v2.yaml"
        )
        with open(config_path) as f:
            data = yaml.load(f.read(), Loader=yaml.FullLoader)

        # Update config with request parameters
        batch_size = request.batch_size
        if not is_half:
            # Full precision: switch to 32-bit and halve the batch size.
            data["train"]["precision"] = "32"
            batch_size = max(1, batch_size // 2)
        train_cfg = data["train"]
        train_cfg["batch_size"] = batch_size
        train_cfg["epochs"] = request.total_epoch
        data["pretrained_s1"] = request.pretrained_s1
        train_cfg["save_every_n_epoch"] = request.save_every_epoch
        train_cfg["if_save_every_weights"] = request.if_save_every_weights
        train_cfg["if_save_latest"] = request.if_save_latest
        train_cfg["if_dpo"] = request.if_dpo
        train_cfg["half_weights_save_dir"] = GPT_weight_version2root[version]
        train_cfg["exp_name"] = request.exp_name
        data["train_semantic_path"] = f"{s1_dir}/6-name2semantic.tsv"
        data["train_phoneme_path"] = f"{s1_dir}/2-name2text.txt"
        data["output_dir"] = f"{s1_dir}/logs_s1_{version}"

        # Set environment variables for GPU and PYTHONPATH
        env = os.environ.copy()
        env["PYTHONPATH"] = os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")])
        env["_CUDA_VISIBLE_DEVICES"] = fix_gpu_numbers(request.gpu_numbers.replace("-", ","))
        env["hz"] = "25hz"

        # Write the config to a per-job file: jobs run concurrently, and a
        # shared file name would let a later job overwrite the config before
        # an earlier job's training process has read it.
        os.makedirs(f"{now_dir}/TEMP", exist_ok=True)
        tmp_config_path = f"{now_dir}/TEMP/tmp_s1_{job_id}.yaml"
        with open(tmp_config_path, "w") as f:
            f.write(yaml.dump(data, default_flow_style=False))

        # Execute training; check=True turns a non-zero exit code into
        # CalledProcessError.
        cmd = [python_exec, "GPT_SoVITS/s1_train.py", "--config_file", tmp_config_path]
        print(f"[GPT FINE-TUNING] Executing: {' '.join(cmd)}")
        await asyncio.to_thread(subprocess.run, cmd, env=env, cwd=now_dir, check=True)

        # Find latest GPT checkpoint
        latest_gpt_checkpoint = _latest_gpt_checkpoint(
            data["train"]["half_weights_save_dir"], request.exp_name
        )
        if latest_gpt_checkpoint is not None:
            print(f"[GPT FINE-TUNING] Latest checkpoint: {latest_gpt_checkpoint}")

        jobs[job_id]["status"] = "completed"
        jobs[job_id]["result"] = {
            "exp_name": request.exp_name,
            "config_path": tmp_config_path,
            "checkpoint_path": latest_gpt_checkpoint,
            "gpt_checkpoint_path": latest_gpt_checkpoint
        }
        jobs[job_id]["completed_at"] = datetime.now().isoformat()
    except Exception as e:
        jobs[job_id]["status"] = "failed"
        jobs[job_id]["error"] = str(e)
        jobs[job_id]["traceback"] = traceback.format_exc()
        jobs[job_id]["failed_at"] = datetime.now().isoformat()
if __name__ == "__main__": if __name__ == "__main__":
try: try:
if host == "None": # 在调用时使用 -a None 参数可以让api监听双栈 if host == "None": # 在调用时使用 -a None 参数可以让api监听双栈