From e8616c87c65ed5e5c3091e65464edd020869aa8b Mon Sep 17 00:00:00 2001
From: SanghyeonAn94
Date: Wed, 15 Oct 2025 17:06:21 +0900
Subject: [PATCH] No length limit for reference audio under 3 seconds
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 GPT_SoVITS/TTS_infer_pack/TTS.py |   4 +-
 GPT_SoVITS/inference_webui.py    |   3 -
 api_v2.py                        | 777 ++++++++++++++++++++++++++++++-
 3 files changed, 774 insertions(+), 10 deletions(-)

diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 0c1d2484..8b623aa9 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -799,8 +799,8 @@ class TTS:
             )
             with torch.no_grad():
                 wav16k, sr = librosa.load(ref_wav_path, sr=16000)
-                if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
-                    raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
+                # if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
+                #     raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
                 wav16k = torch.from_numpy(wav16k)
                 zero_wav_torch = torch.from_numpy(zero_wav)
                 wav16k = wav16k.to(self.configs.device)
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index a361ed58..86302cd8 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -811,9 +811,6 @@ def get_tts_wav(
     if not ref_free:
         with torch.no_grad():
             wav16k, sr = librosa.load(ref_wav_path, sr=16000)
-            if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
-                gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
-                raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
             wav16k = torch.from_numpy(wav16k)
             if is_half == True:
                 wav16k = wav16k.half().to(device)
diff --git a/api_v2.py b/api_v2.py
index 5947df53..88b819b9 100644
--- a/api_v2.py
+++ b/api_v2.py
@@ -101,7 +101,10 @@ RESP:
 import os
 import sys
 import traceback
-from typing import Generator
+from typing import Generator, Dict, Any, Optional
+import uuid
+import asyncio
+from datetime import datetime
 
 now_dir = os.getcwd()
 sys.path.append(now_dir)
@@ -121,11 +124,47 @@ from tools.i18n.i18n import I18nAuto
 from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
 from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
 from pydantic import BaseModel
+import json
+import yaml
+
+# Import config variables (avoiding webui to prevent Gradio loading)
+from config import (
+    exp_root,
+    python_exec,
+    is_half,
+    GPU_INDEX,
+    infer_device,
+    SoVITS_weight_version2root,
+    GPT_weight_version2root,
+)
 
 # print(sys.path)
 i18n = I18nAuto()
 cut_method_names = get_cut_method_names()
 
+# GPU helper functions (replicated from webui.py to avoid importing it)
+set_gpu_numbers = GPU_INDEX
+default_gpu_numbers = infer_device.index if hasattr(infer_device, 'index') else 0
+
+def fix_gpu_number(input_val):
+    """Fix GPU number to be within the valid range."""
+    try:
+        if int(input_val) not in set_gpu_numbers:
+            return default_gpu_numbers
+    except (TypeError, ValueError):
+        return input_val
+    return input_val
+
+def fix_gpu_numbers(inputs):
+    """Fix multiple GPU numbers separated by commas."""
+    output = []
+    try:
+        for input_val in inputs.split(","):
+            output.append(str(fix_gpu_number(input_val)))
+        return ",".join(output)
+    except AttributeError:
+        return inputs
+
 parser = argparse.ArgumentParser(description="GPT-SoVITS api")
 parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
 parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
@@ -172,6 +211,69 @@ class TTS_Request(BaseModel):
     super_sampling: bool = False
 
 
+class SpeechSlicingRequest(BaseModel):
+    inp: str
+    opt_root: str
+    threshold: str = "-34"
+    min_length: str = "4000"
+    min_interval: str = "300"
+    hop_size: str = "10"
+    max_sil_kept: str = "500"
+    # Renamed from `_max`: pydantic ignores fields with a leading underscore,
+    # so `_max` could never be populated from a request body.
+    max_value: float = 0.9
+    alpha: float = 0.25
+    n_parts: int = 4
+
+
+class STTRequest(BaseModel):
+    input_folder: str
+    output_folder: str
+    model_path: str = "tools/asr/models/faster-whisper-large-v3"
+    language: str = "auto"
+    precision: str = "float32"
+
+
+class DatasetFormattingRequest(BaseModel):
+    inp_text: str
+    inp_wav_dir: str
+    exp_name: str
+    version: str = "v4"
+    gpu_numbers: str = "0-0"
+    bert_pretrained_dir: str = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
+    ssl_pretrained_dir: str = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
+    pretrained_s2G_path: str = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
+
+
+class FineTuneSoVITSRequest(BaseModel):
+    version: str = "v4"
+    batch_size: int = 2
+    total_epoch: int = 2
+    exp_name: str
+    text_low_lr_rate: float = 0.4
+    if_save_latest: bool = True
+    if_save_every_weights: bool = True
+    save_every_epoch: int = 1
+    gpu_numbers1Ba: str = "0"
+    pretrained_s2G: str = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
+    pretrained_s2D: str = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Dv4.pth"
+    if_grad_ckpt: bool = False
+    lora_rank: str = "32"
+
+
+class FineTuneGPTRequest(BaseModel):
+    batch_size: int = 8
+    total_epoch: int = 15
+    exp_name: str
+    if_dpo: bool = False
+    if_save_latest: bool = True
+    if_save_every_weights: bool = True
+    save_every_epoch: int = 5
+    gpu_numbers: str = "0"
+    pretrained_s1: str = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
+
+
+jobs: Dict[str, Dict[str, Any]] = {}
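+
+# NOTE (illustrative): a record in `jobs` evolves as a job progresses; the
+# execute_* helpers below fill in the remaining keys, roughly:
+#
+#     {"status": "queued" | "running" | "completed" | "failed",
+#      "operation": "stt", "created_at": "...", "started_at": "...",
+#      "result": {...} or "error": "..." with "traceback": "..."}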
+
+
 ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
 def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
     with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
@@ -272,14 +374,14 @@ def check_params(req: dict):
         return JSONResponse(status_code=400, content={"message": "text is required"})
     if text_lang in [None, ""]:
         return JSONResponse(status_code=400, content={"message": "text_lang is required"})
-    elif text_lang.lower() not in tts_config.languages:
+    elif text_lang not in tts_config.languages:
         return JSONResponse(
             status_code=400,
             content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"},
         )
     if prompt_lang in [None, ""]:
         return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
-    elif prompt_lang.lower() not in tts_config.languages:
+    elif prompt_lang not in tts_config.languages:
         return JSONResponse(
             status_code=400,
             content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"},
         )
@@ -407,11 +509,11 @@ async def tts_get_endpoint(
 ):
     req = {
         "text": text,
-        "text_lang": text_lang.lower(),
+        "text_lang": text_lang,
         "ref_audio_path": ref_audio_path,
         "aux_ref_audio_paths": aux_ref_audio_paths,
         "prompt_text": prompt_text,
-        "prompt_lang": prompt_lang.lower(),
+        "prompt_lang": prompt_lang,
         "top_k": top_k,
         "top_p": top_p,
         "temperature": temperature,
@@ -434,6 +536,23 @@
 
 
 @APP.post("/tts")
 async def tts_post_endpoint(request: TTS_Request):
+    # DEBUG: Print received payload
+    print(f"\n{'='*80}")
+    print("[TTS DEBUG] Received request:")
+    print(f"  text: {request.text[:100]}")  # Truncate long text
+    print(f"  text_lang: {request.text_lang}")
+    print(f"  ref_audio_path: {request.ref_audio_path}")
+    print(f"  prompt_text: {request.prompt_text[:100] if request.prompt_text else request.prompt_text}")
+    print(f"  prompt_lang: {request.prompt_lang}")
+    print(f"  top_k: {request.top_k}")
+    print(f"  top_p: {request.top_p}")
+    print(f"  temperature: {request.temperature}")
+    print(f"  text_split_method: {request.text_split_method}")
+    print(f"  batch_size: {request.batch_size}")
+    print(f"  speed_factor: {request.speed_factor}")
+    print(f"  streaming_mode: {request.streaming_mode}")
+    print(f"{'='*80}\n")
+
     req = request.dict()
     return await tts_handle(req)
@@ -489,6 +608,654 @@ async def set_sovits_weights(weights_path: str = None):
     return JSONResponse(status_code=200, content={"message": "success"})
 
 
+async def execute_job_async(job_id: str, operation_func, *args, **kwargs):
+    """
+    Execute a job asynchronously in the background.
+
+    Args:
+        job_id: Unique job identifier
+        operation_func: Function to execute (e.g. tools/asr/fasterwhisper_asr.execute_asr)
+        args, kwargs: Arguments for the operation function
+    """
+    jobs[job_id]["status"] = "running"
+    jobs[job_id]["started_at"] = datetime.now().isoformat()
+
+    try:
+        result = await asyncio.to_thread(operation_func, *args, **kwargs)
+
+        # Webui-style functions may return a generator of progress updates;
+        # drain it and keep the final item as the job result.
+        if hasattr(result, '__iter__') and not isinstance(result, (str, dict)):
+            final_result = None
+            for item in result:
+                final_result = item
+            result = final_result
+
+        jobs[job_id]["status"] = "completed"
+        jobs[job_id]["result"] = result
+        jobs[job_id]["completed_at"] = datetime.now().isoformat()
+    except Exception as e:
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+        jobs[job_id]["traceback"] = traceback.format_exc()
+        jobs[job_id]["failed_at"] = datetime.now().isoformat()
+
+
+@APP.get("/jobs/{job_id}")
+@APP.get("/job-status/{job_id}")  # Alias for compatibility
+async def get_job_status(job_id: str):
+    """
+    Get job status and result.
+
+    Returns:
+        {
+            "job_id": str,
+            "status": "queued" | "running" | "completed" | "failed",
+            "result": Any (if completed),
+            "error": str (if failed),
+            "created_at": str,
+            "started_at": str (if running/completed/failed),
+            "completed_at": str (if completed),
+            "failed_at": str (if failed)
+        }
+    """
+    if job_id not in jobs:
+        return JSONResponse(status_code=404, content={"message": "job not found"})
+
+    job_data = jobs[job_id].copy()
+    job_data["job_id"] = job_id
+    return JSONResponse(status_code=200, content=job_data)
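+
+
+# Illustrative polling pattern for the job endpoints below (assumes the
+# default 127.0.0.1:9880 bind address; `requests` is not imported by this
+# module, the snippet is client-side):
+#
+#     import time, requests
+#     job_id = "..."  # value returned by a job-starting endpoint
+#     while True:
+#         job = requests.get(f"http://127.0.0.1:9880/jobs/{job_id}").json()
+#         if job["status"] in ("completed", "failed"):
+#             break
+#         time.sleep(5)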
+
+
+async def execute_speech_slicing_direct(job_id: str, request: SpeechSlicingRequest):
+    """
+    Execute speech slicing by directly calling the tools/slice_audio.py subprocess.
+    Replaces webui.open_slice() to avoid the Gradio dependency.
+    """
+    jobs[job_id]["status"] = "running"
+    jobs[job_id]["started_at"] = datetime.now().isoformat()
+
+    try:
+        # Prepare environment with PYTHONPATH
+        env = os.environ.copy()
+        env["PYTHONPATH"] = os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")])
+
+        # Create processes for parallel slicing (n_parts workers)
+        processes = []
+        for i_part in range(request.n_parts):
+            cmd = [
+                python_exec,
+                "tools/slice_audio.py",
+                request.inp,
+                request.opt_root,
+                str(request.threshold),
+                str(request.min_length),
+                str(request.min_interval),
+                str(request.hop_size),
+                str(request.max_sil_kept),
+                str(request.max_value),
+                str(request.alpha),
+                str(i_part),
+                str(request.n_parts),
+            ]
+            print(f"[SPEECH SLICING] Executing: {' '.join(cmd)}")
+            p = subprocess.Popen(cmd, env=env, cwd=now_dir)
+            processes.append(p)
+
+        # Wait for all processes to complete
+        for p in processes:
+            p.wait()
+
+        # Check if any process failed
+        exit_codes = [p.returncode for p in processes]
+        if any(code != 0 for code in exit_codes):
+            raise Exception(f"Speech slicing failed with exit codes: {exit_codes}")
+
+        jobs[job_id]["status"] = "completed"
+        jobs[job_id]["result"] = {
+            "output_dir": request.opt_root,
+            # Renamed from "file_count": this is the number of slicing workers,
+            # not the number of files produced.
+            "n_parts": request.n_parts,
+        }
+        jobs[job_id]["completed_at"] = datetime.now().isoformat()
+
+    except Exception as e:
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+        jobs[job_id]["traceback"] = traceback.format_exc()
+        jobs[job_id]["failed_at"] = datetime.now().isoformat()
+
+
+@APP.post("/preprocessing/speech-slicing")
+async def speech_slicing_endpoint(request: SpeechSlicingRequest):
+    """
+    Start a speech slicing job.
+
+    Directly executes tools/slice_audio.py (no webui dependency).
+    """
+    # DEBUG: Print received payload
+    print(f"\n{'='*80}")
+    print("[SPEECH SLICING DEBUG] Received request:")
+    print(f"  inp: {request.inp}")
+    print(f"  opt_root: {request.opt_root}")
+    print(f"  threshold: {request.threshold}")
+    print(f"  min_length: {request.min_length}")
+    print(f"  min_interval: {request.min_interval}")
+    print(f"  hop_size: {request.hop_size}")
+    print(f"  max_sil_kept: {request.max_sil_kept}")
+    print(f"  max_value: {request.max_value}")
+    print(f"  alpha: {request.alpha}")
+    print(f"  n_parts: {request.n_parts}")
+    print(f"{'='*80}\n")
+
+    job_id = str(uuid.uuid4())
+    jobs[job_id] = {
+        "status": "queued",
+        "operation": "speech_slicing",
+        "created_at": datetime.now().isoformat(),
+    }
+
+    try:
+        asyncio.create_task(execute_speech_slicing_direct(job_id, request))
+        return JSONResponse(status_code=200, content={"job_id": job_id, "status": "queued"})
+    except Exception as e:
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+        return JSONResponse(status_code=500, content={"message": "failed to start job", "error": str(e)})
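+
+
+# Example client call (illustrative; paths are placeholders and the port is
+# the argparse default 9880):
+#
+#     import requests
+#     r = requests.post("http://127.0.0.1:9880/preprocessing/speech-slicing",
+#                       json={"inp": "output/raw", "opt_root": "output/sliced",
+#                             "n_parts": 4}).json()
+#     # r == {"job_id": "...", "status": "queued"}; poll /jobs/{job_id}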
+
+
+@APP.post("/preprocessing/stt")
+async def stt_endpoint(request: STTRequest):
+    """
+    Start an STT (Speech-to-Text) job.
+
+    Wraps tools/asr/fasterwhisper_asr.execute_asr().
+    """
+    # DEBUG: Print received payload
+    print(f"\n{'='*80}")
+    print("[STT DEBUG] Received STT request:")
+    print(request)
+    print(f"{'='*80}\n")
+
+    job_id = str(uuid.uuid4())
+    jobs[job_id] = {
+        "status": "queued",
+        "operation": "stt",
+        "created_at": datetime.now().isoformat(),
+    }
+
+    try:
+        from tools.asr.fasterwhisper_asr import execute_asr
+
+        asyncio.create_task(execute_job_async(
+            job_id,
+            execute_asr,
+            request.input_folder,
+            request.output_folder,
+            request.model_path,
+            request.language,
+            request.precision,
+        ))
+
+        return JSONResponse(status_code=200, content={"job_id": job_id, "status": "queued"})
+    except Exception as e:
+        print(f"[STT ERROR] Failed to start STT job: {str(e)}")
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+        return JSONResponse(status_code=500, content={"message": "failed to start job", "error": str(e)})
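+
+
+# Example client call (illustrative; folder paths are placeholders):
+#
+#     import requests
+#     requests.post("http://127.0.0.1:9880/preprocessing/stt", json={
+#         "input_folder": "output/sliced",
+#         "output_folder": "output/asr",
+#     }).json()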
+
+
+async def execute_dataset_formatting(job_id: str, request: DatasetFormattingRequest):
+    """
+    Execute dataset formatting sequentially: open1a -> open1b -> open1c.
+    Directly executes the prepare_datasets scripts as subprocesses (no webui dependency).
+    """
+    jobs[job_id]["status"] = "running"
+    jobs[job_id]["started_at"] = datetime.now().isoformat()
+    jobs[job_id]["current_stage"] = "open1a"
+
+    try:
+        opt_dir = f"{exp_root}/{request.exp_name}"
+        os.makedirs(opt_dir, exist_ok=True)
+
+        # Parse GPU numbers
+        gpu_names = request.gpu_numbers.split("-")
+        all_parts = len(gpu_names)
+
+        # Stage 1a: Get text features
+        print("[DATASET FORMATTING] Starting open1a...")
+        for i_part in range(all_parts):
+            env = os.environ.copy()
+            env.update({
+                "PYTHONPATH": os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")]),
+                "inp_text": request.inp_text,
+                "inp_wav_dir": request.inp_wav_dir,
+                "exp_name": request.exp_name,
+                "opt_dir": opt_dir,
+                "bert_pretrained_dir": request.bert_pretrained_dir,
+                "i_part": str(i_part),
+                "all_parts": str(all_parts),
+                "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
+                "is_half": str(is_half),
+            })
+            cmd = [python_exec, "GPT_SoVITS/prepare_datasets/1-get-text.py"]
+            print(f"[DATASET FORMATTING] Executing 1a part {i_part}: {' '.join(cmd)}")
+            await asyncio.to_thread(subprocess.run, cmd, env=env, cwd=now_dir, check=True)
+
+        # Merge text files from the 1a stage
+        opt = []
+        path_text = f"{opt_dir}/2-name2text.txt"
+        for i_part in range(all_parts):
+            text_path = f"{opt_dir}/2-name2text-{i_part}.txt"
+            if os.path.exists(text_path):
+                with open(text_path, "r", encoding="utf8") as f:
+                    opt += f.read().strip("\n").split("\n")
+                os.remove(text_path)
+
+        with open(path_text, "w", encoding="utf8") as f:
+            f.write("\n".join(opt) + "\n")
+
+        # Stage 1b: Get hubert features
+        jobs[job_id]["current_stage"] = "open1b"
+        print("[DATASET FORMATTING] Starting open1b...")
+        sv_path = "GPT_SoVITS/pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"
+
+        for i_part in range(all_parts):
+            env = os.environ.copy()
+            env.update({
+                "PYTHONPATH": os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")]),
+                "inp_text": request.inp_text,
+                "inp_wav_dir": request.inp_wav_dir,
+                "exp_name": request.exp_name,
+                "opt_dir": opt_dir,
+                "cnhubert_base_dir": request.ssl_pretrained_dir,
+                "sv_path": sv_path,
+                "is_half": str(is_half),
+                "i_part": str(i_part),
+                "all_parts": str(all_parts),
+                "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
+            })
+            cmd = [python_exec, "GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py"]
+            print(f"[DATASET FORMATTING] Executing 1b part {i_part}: {' '.join(cmd)}")
+            await asyncio.to_thread(subprocess.run, cmd, env=env, cwd=now_dir, check=True)
+
+        # For v2Pro versions, also run 2-get-sv.py
+        if "Pro" in request.version:
+            for i_part in range(all_parts):
+                env = os.environ.copy()
+                env.update({
+                    "PYTHONPATH": os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")]),
+                    "i_part": str(i_part),
+                    "all_parts": str(all_parts),
+                    "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
+                    "exp_dir": opt_dir,
+                    "sv_path": sv_path,
+                    "is_half": str(is_half),
+                })
+                cmd = [python_exec, "GPT_SoVITS/prepare_datasets/2-get-sv.py"]
+                print(f"[DATASET FORMATTING] Executing 2-get-sv part {i_part}: {' '.join(cmd)}")
+                await asyncio.to_thread(subprocess.run, cmd, env=env, cwd=now_dir, check=True)
+
+        # Stage 1c: Get semantic features
+        jobs[job_id]["current_stage"] = "open1c"
+        print("[DATASET FORMATTING] Starting open1c...")
+        config_file = (
+            "GPT_SoVITS/configs/s2.json"
+            if request.version not in {"v2Pro", "v2ProPlus"}
+            else f"GPT_SoVITS/configs/s2{request.version}.json"
+        )
+
+        for i_part in range(all_parts):
+            env = os.environ.copy()
+            env.update({
+                "PYTHONPATH": os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")]),
+                "inp_text": request.inp_text,
+                "exp_name": request.exp_name,
+                "opt_dir": opt_dir,
+                "pretrained_s2G": request.pretrained_s2G_path,
+                "s2config_path": config_file,
+                "is_half": str(is_half),
+                "i_part": str(i_part),
+                "all_parts": str(all_parts),
+                "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
+            })
+            cmd = [python_exec, "GPT_SoVITS/prepare_datasets/3-get-semantic.py"]
+            print(f"[DATASET FORMATTING] Executing 1c part {i_part}: {' '.join(cmd)}")
+            await asyncio.to_thread(subprocess.run, cmd, env=env, cwd=now_dir, check=True)
+
+        # Merge semantic files (from the open1c logic in webui.py)
+        opt = ["item_name\tsemantic_audio"]
+        path_semantic = f"{opt_dir}/6-name2semantic.tsv"
+        for i_part in range(all_parts):
+            semantic_path = f"{opt_dir}/6-name2semantic-{i_part}.tsv"
+            if os.path.exists(semantic_path):
+                with open(semantic_path, "r", encoding="utf8") as f:
+                    opt += f.read().strip("\n").split("\n")
+                os.remove(semantic_path)
+
+        with open(path_semantic, "w", encoding="utf8") as f:
+            f.write("\n".join(opt))
+
+        jobs[job_id]["status"] = "completed"
+        jobs[job_id]["result"] = {
+            "exp_name": request.exp_name,
+            "stages_completed": ["open1a", "open1b", "open1c"],
+        }
+        jobs[job_id]["completed_at"] = datetime.now().isoformat()
+
+    except Exception as e:
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+        jobs[job_id]["traceback"] = traceback.format_exc()
+        jobs[job_id]["failed_at"] = datetime.now().isoformat()
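+
+
+# On success the experiment directory holds the merged artifacts consumed by
+# the fine-tuning endpoints below: 2-name2text.txt (text features, stage 1a)
+# and 6-name2semantic.tsv (semantic tokens, stage 1c); stage 1b writes the
+# hubert/wav32k features into the same directory tree.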
+
+
+@APP.post("/training/format-dataset")
+async def format_dataset_endpoint(request: DatasetFormattingRequest):
+    """
+    Start a dataset formatting job (open1a -> open1b -> open1c).
+
+    Replicates webui.open1a(), open1b() and open1c() sequentially,
+    without importing webui.
+    """
+    # DEBUG: Print received payload
+    print(f"\n{'='*80}")
+    print("[DATASET FORMATTING DEBUG] Received request:")
+    print(f"  version: {request.version}")
+    print(f"  inp_text: {request.inp_text}")
+    print(f"  inp_wav_dir: {request.inp_wav_dir}")
+    print(f"  exp_name: {request.exp_name}")
+    print(f"  gpu_numbers: {request.gpu_numbers}")
+    print(f"  bert_pretrained_dir: {request.bert_pretrained_dir}")
+    print(f"  ssl_pretrained_dir: {request.ssl_pretrained_dir}")
+    print(f"  pretrained_s2G_path: {request.pretrained_s2G_path}")
+    print(f"{'='*80}\n")
+
+    job_id = str(uuid.uuid4())
+    jobs[job_id] = {
+        "status": "queued",
+        "operation": "format_dataset",
+        "created_at": datetime.now().isoformat(),
+    }
+
+    try:
+        asyncio.create_task(execute_dataset_formatting(job_id, request))
+        return JSONResponse(status_code=200, content={"job_id": job_id, "status": "queued"})
+    except Exception as e:
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+        return JSONResponse(status_code=500, content={"message": "failed to start job", "error": str(e)})
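+
+
+# Example client call (illustrative; the .list path is a placeholder for the
+# transcript list produced by the STT step):
+#
+#     import requests
+#     requests.post("http://127.0.0.1:9880/training/format-dataset", json={
+#         "inp_text": "output/asr/sliced.list",
+#         "inp_wav_dir": "output/sliced",
+#         "exp_name": "my_voice",
+#     }).json()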
+
+
+async def execute_fine_tune_sovits_direct(job_id: str, request: FineTuneSoVITSRequest):
+    """
+    Execute SoVITS fine-tuning by directly calling the s2_train.py subprocess.
+    Replaces webui.open1Ba() to avoid the Gradio dependency.
+    """
+    jobs[job_id]["status"] = "running"
+    jobs[job_id]["started_at"] = datetime.now().isoformat()
+
+    try:
+        s2_dir = f"{exp_root}/{request.exp_name}"
+        os.makedirs(f"{s2_dir}/logs_s2_{request.version}", exist_ok=True)
+
+        # Load config template
+        config_file = (
+            "GPT_SoVITS/configs/s2.json"
+            if request.version not in {"v2Pro", "v2ProPlus"}
+            else f"GPT_SoVITS/configs/s2{request.version}.json"
+        )
+        with open(config_file) as f:
+            data = json.loads(f.read())
+
+        # Update config with request parameters
+        batch_size = request.batch_size
+        if not is_half:
+            data["train"]["fp16_run"] = False
+            batch_size = max(1, batch_size // 2)
+
+        data["train"]["batch_size"] = batch_size
+        data["train"]["epochs"] = request.total_epoch
+        data["train"]["text_low_lr_rate"] = request.text_low_lr_rate
+        data["train"]["pretrained_s2G"] = request.pretrained_s2G
+        data["train"]["pretrained_s2D"] = request.pretrained_s2D
+        data["train"]["if_save_latest"] = request.if_save_latest
+        data["train"]["if_save_every_weights"] = request.if_save_every_weights
+        data["train"]["save_every_epoch"] = request.save_every_epoch
+        data["train"]["gpu_numbers"] = request.gpu_numbers1Ba
+        data["train"]["grad_ckpt"] = request.if_grad_ckpt
+        data["train"]["lora_rank"] = request.lora_rank
+        data["model"]["version"] = request.version
+        data["data"]["exp_dir"] = data["s2_ckpt_dir"] = s2_dir
+        data["save_weight_dir"] = SoVITS_weight_version2root[request.version]
+        data["name"] = request.exp_name
+        data["version"] = request.version
+
+        # Write temporary config
+        tmp_config_path = f"{now_dir}/TEMP/tmp_s2.json"
+        os.makedirs(f"{now_dir}/TEMP", exist_ok=True)
+        with open(tmp_config_path, "w") as f:
+            f.write(json.dumps(data))
+
+        # Prepare environment with PYTHONPATH
+        env = os.environ.copy()
+        env["PYTHONPATH"] = os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")])
+
+        # Determine training script based on version
+        if request.version in ["v1", "v2", "v2Pro", "v2ProPlus"]:
+            cmd = [python_exec, "GPT_SoVITS/s2_train.py", "--config", tmp_config_path]
+        else:
+            cmd = [python_exec, "GPT_SoVITS/s2_train_v3_lora.py", "--config", tmp_config_path]
+
+        print(f"[SOVITS FINE-TUNING] Executing: {' '.join(cmd)}")
+        await asyncio.to_thread(subprocess.run, cmd, env=env, cwd=now_dir, check=True)
+
+        # Find the latest SoVITS checkpoint
+        sovits_weights_dir = data["save_weight_dir"]
+        latest_sovits_checkpoint = None
+
+        if os.path.exists(sovits_weights_dir):
+            import re
+            # Match both "{exp}_e{n}_s{n}.pth" (v1/v2) and
+            # "{exp}_e{n}_s{n}_l{rank}.pth" (v3/v4 LoRA) weight names.
+            pattern = re.compile(rf"^{re.escape(request.exp_name)}_e(\d+)_s(\d+)(?:_l(\d+))?\.pth$")
+            checkpoints = []
+            for filename in os.listdir(sovits_weights_dir):
+                match = pattern.match(filename)
+                if match:
+                    epoch = int(match.group(1))
+                    step = int(match.group(2))
+                    checkpoints.append((epoch, step, filename))
+
+            if checkpoints:
+                checkpoints.sort(reverse=True)
+                latest_filename = checkpoints[0][2]
+                latest_sovits_checkpoint = os.path.join(sovits_weights_dir, latest_filename)
+                print(f"[SOVITS FINE-TUNING] Latest checkpoint: {latest_sovits_checkpoint}")
+
+        jobs[job_id]["status"] = "completed"
+        jobs[job_id]["result"] = {
+            "exp_name": request.exp_name,
+            "config_path": tmp_config_path,
+            "checkpoint_path": latest_sovits_checkpoint,
+            "sovits_checkpoint_path": latest_sovits_checkpoint,
+        }
+        jobs[job_id]["completed_at"] = datetime.now().isoformat()
+
+    except Exception as e:
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+        jobs[job_id]["traceback"] = traceback.format_exc()
+        jobs[job_id]["failed_at"] = datetime.now().isoformat()
+
+
+@APP.post("/training/fine-tune-sovits")
+async def fine_tune_sovits_endpoint(request: FineTuneSoVITSRequest):
+    """
+    Start a SoVITS fine-tuning job.
+
+    Directly executes s2_train.py (or s2_train_v3_lora.py for v3/v4),
+    with no webui dependency.
+    """
+    # DEBUG: Print received payload
+    print(f"\n{'='*80}")
+    print("[SOVITS FINE-TUNING DEBUG] Received request:")
+    print(f"  version: {request.version}")
+    print(f"  batch_size: {request.batch_size}")
+    print(f"  total_epoch: {request.total_epoch}")
+    print(f"  exp_name: {request.exp_name}")
+    print(f"  text_low_lr_rate: {request.text_low_lr_rate}")
+    print(f"  if_save_latest: {request.if_save_latest}")
+    print(f"  if_save_every_weights: {request.if_save_every_weights}")
+    print(f"  save_every_epoch: {request.save_every_epoch}")
+    print(f"  gpu_numbers1Ba: {request.gpu_numbers1Ba}")
+    print(f"  pretrained_s2G: {request.pretrained_s2G}")
+    print(f"  pretrained_s2D: {request.pretrained_s2D}")
+    print(f"  if_grad_ckpt: {request.if_grad_ckpt}")
+    print(f"  lora_rank: {request.lora_rank}")
+    print(f"{'='*80}\n")
+
+    job_id = str(uuid.uuid4())
+    jobs[job_id] = {
+        "status": "queued",
+        "operation": "fine_tune_sovits",
+        "created_at": datetime.now().isoformat(),
+    }
+
+    try:
+        asyncio.create_task(execute_fine_tune_sovits_direct(job_id, request))
+        return JSONResponse(status_code=200, content={"job_id": job_id, "status": "queued"})
+    except Exception as e:
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+        return JSONResponse(status_code=500, content={"message": "failed to start job", "error": str(e)})
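+
+
+# Example client call (illustrative; exp_name must match a previously
+# formatted dataset; the completed job's result carries the newest
+# checkpoint under "sovits_checkpoint_path"):
+#
+#     import requests
+#     r = requests.post("http://127.0.0.1:9880/training/fine-tune-sovits",
+#                       json={"exp_name": "my_voice"}).json()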
+
+
+@APP.post("/training/fine-tune-gpt")
+async def fine_tune_gpt_endpoint(request: FineTuneGPTRequest):
+    """
+    Start a GPT fine-tuning job.
+
+    Replicates webui.open1Bb() without importing webui.
+    """
+    # DEBUG: Print received payload
+    print(f"\n{'='*80}")
+    print("[GPT FINE-TUNING DEBUG] Received request:")
+    print(f"  batch_size: {request.batch_size}")
+    print(f"  total_epoch: {request.total_epoch}")
+    print(f"  exp_name: {request.exp_name}")
+    print(f"  if_dpo: {request.if_dpo}")
+    print(f"  if_save_latest: {request.if_save_latest}")
+    print(f"  if_save_every_weights: {request.if_save_every_weights}")
+    print(f"  save_every_epoch: {request.save_every_epoch}")
+    print(f"  gpu_numbers: {request.gpu_numbers}")
+    print(f"  pretrained_s1: {request.pretrained_s1}")
+    print(f"{'='*80}\n")
+
+    job_id = str(uuid.uuid4())
+    jobs[job_id] = {
+        "status": "queued",
+        "operation": "fine_tune_gpt",
+        "created_at": datetime.now().isoformat(),
+    }
+
+    try:
+        asyncio.create_task(execute_fine_tune_gpt_direct(job_id, request))
+        return JSONResponse(status_code=200, content={"job_id": job_id, "status": "queued"})
+    except Exception as e:
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+        return JSONResponse(status_code=500, content={"message": "failed to start job", "error": str(e)})
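+
+
+# Example client call (illustrative; run after dataset formatting for the
+# same exp_name; the completed job's result carries "gpt_checkpoint_path"):
+#
+#     import requests
+#     r = requests.post("http://127.0.0.1:9880/training/fine-tune-gpt",
+#                       json={"exp_name": "my_voice"}).json()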
+
+
+async def execute_fine_tune_gpt_direct(job_id: str, request: FineTuneGPTRequest):
+    """
+    Execute GPT fine-tuning by directly calling the s1_train.py subprocess.
+    Replaces webui.open1Bb() to avoid the Gradio dependency.
+    """
+    jobs[job_id]["status"] = "running"
+    jobs[job_id]["started_at"] = datetime.now().isoformat()
+
+    try:
+        s1_dir = f"{exp_root}/{request.exp_name}"
+        os.makedirs(f"{s1_dir}/logs_s1", exist_ok=True)
+
+        # Determine version (from webui.py line 606)
+        version = os.environ.get("version", "v4")
+
+        # Load config template
+        config_path = (
+            "GPT_SoVITS/configs/s1longer.yaml"
+            if version == "v1"
+            else "GPT_SoVITS/configs/s1longer-v2.yaml"
+        )
+        with open(config_path) as f:
+            data = yaml.load(f.read(), Loader=yaml.FullLoader)
+
+        # Update config with request parameters
+        batch_size = request.batch_size
+        if not is_half:
+            data["train"]["precision"] = "32"
+            batch_size = max(1, batch_size // 2)
+
+        data["train"]["batch_size"] = batch_size
+        data["train"]["epochs"] = request.total_epoch
+        data["pretrained_s1"] = request.pretrained_s1
+        data["train"]["save_every_n_epoch"] = request.save_every_epoch
+        data["train"]["if_save_every_weights"] = request.if_save_every_weights
+        data["train"]["if_save_latest"] = request.if_save_latest
+        data["train"]["if_dpo"] = request.if_dpo
+        data["train"]["half_weights_save_dir"] = GPT_weight_version2root[version]
+        data["train"]["exp_name"] = request.exp_name
+        data["train_semantic_path"] = f"{s1_dir}/6-name2semantic.tsv"
+        data["train_phoneme_path"] = f"{s1_dir}/2-name2text.txt"
+        data["output_dir"] = f"{s1_dir}/logs_s1_{version}"
+
+        # Set environment variables for GPU and PYTHONPATH
+        env = os.environ.copy()
+        env["PYTHONPATH"] = os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")])
+        env["_CUDA_VISIBLE_DEVICES"] = fix_gpu_numbers(request.gpu_numbers.replace("-", ","))
+        env["hz"] = "25hz"
+
+        # Write temporary config
+        tmp_config_path = f"{now_dir}/TEMP/tmp_s1.yaml"
+        os.makedirs(f"{now_dir}/TEMP", exist_ok=True)
+        with open(tmp_config_path, "w") as f:
+            f.write(yaml.dump(data, default_flow_style=False))
+
+        # Execute training
+        cmd = [python_exec, "GPT_SoVITS/s1_train.py", "--config_file", tmp_config_path]
+        print(f"[GPT FINE-TUNING] Executing: {' '.join(cmd)}")
+        await asyncio.to_thread(subprocess.run, cmd, env=env, cwd=now_dir, check=True)
+
+        # Find the latest GPT checkpoint
+        gpt_weights_dir = data["train"]["half_weights_save_dir"]
+        latest_gpt_checkpoint = None
+
+        if os.path.exists(gpt_weights_dir):
+            import re
+            pattern = re.compile(rf"^{re.escape(request.exp_name)}-e(\d+)\.ckpt$")
+            checkpoints = []
+            for filename in os.listdir(gpt_weights_dir):
+                match = pattern.match(filename)
+                if match:
+                    epoch = int(match.group(1))
+                    checkpoints.append((epoch, filename))
+
+            if checkpoints:
+                checkpoints.sort(reverse=True)
+                latest_filename = checkpoints[0][1]
+                latest_gpt_checkpoint = os.path.join(gpt_weights_dir, latest_filename)
+                print(f"[GPT FINE-TUNING] Latest checkpoint: {latest_gpt_checkpoint}")
+
+        jobs[job_id]["status"] = "completed"
+        jobs[job_id]["result"] = {
+            "exp_name": request.exp_name,
+            "config_path": tmp_config_path,
+            "checkpoint_path": latest_gpt_checkpoint,
+            "gpt_checkpoint_path": latest_gpt_checkpoint,
+        }
+        jobs[job_id]["completed_at"] = datetime.now().isoformat()
+
+    except Exception as e:
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+        jobs[job_id]["traceback"] = traceback.format_exc()
+        jobs[job_id]["failed_at"] = datetime.now().isoformat()
+
+
 if __name__ == "__main__":
     try:
         if host == "None":  # 在调用时使用 -a None 参数,可以让api监听双栈