Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git

Commit e8616c87c6 (parent 11aa78bd9b)
No restriction for reference audio under 3 seconds
@@ -799,8 +799,8 @@ class TTS:
         )
         with torch.no_grad():
             wav16k, sr = librosa.load(ref_wav_path, sr=16000)
-            if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
-                raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
+            # if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
+            #     raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
             wav16k = torch.from_numpy(wav16k)
             zero_wav_torch = torch.from_numpy(zero_wav)
             wav16k = wav16k.to(self.configs.device)
@@ -811,9 +811,6 @@ def get_tts_wav(
     if not ref_free:
         with torch.no_grad():
            wav16k, sr = librosa.load(ref_wav_path, sr=16000)
-            if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
-                gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
-                raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
            wav16k = torch.from_numpy(wav16k)
            if is_half == True:
                wav16k = wav16k.half().to(device)
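Both hunks drop the hard requirement that the reference clip be 3~10 s long (48000~160000 samples at 16 kHz): the API path comments the check out, and the webui path deletes it outright. If a caller still wants a non-fatal guard, a minimal sketch (the function name and warning handling are illustrative, not part of this commit):

import librosa

MIN_SEC, MAX_SEC = 3.0, 10.0  # the former hard limits

def warn_if_ref_out_of_range(ref_wav_path: str, sr: int = 16000) -> None:
    # Load at 16 kHz, mirroring the inference code, and warn instead of raising.
    wav, _ = librosa.load(ref_wav_path, sr=sr)
    duration = wav.shape[0] / sr
    if not (MIN_SEC <= duration <= MAX_SEC):
        print(f"warning: reference audio is {duration:.1f}s, outside {MIN_SEC:g}~{MAX_SEC:g}s")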
api_v2.py
@@ -101,7 +101,10 @@ RESP:
 import os
 import sys
 import traceback
-from typing import Generator
+from typing import Generator, Dict, Any, Optional
+import uuid
+import asyncio
+from datetime import datetime
 
 now_dir = os.getcwd()
 sys.path.append(now_dir)
@@ -121,11 +124,47 @@ from tools.i18n.i18n import I18nAuto
 from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
 from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
 from pydantic import BaseModel
+import json
+import yaml
+
+# Import config variables (avoiding webui to prevent Gradio loading)
+from config import (
+    exp_root,
+    python_exec,
+    is_half,
+    GPU_INDEX,
+    infer_device,
+    SoVITS_weight_version2root,
+    GPT_weight_version2root,
+)
 
 # print(sys.path)
 i18n = I18nAuto()
 cut_method_names = get_cut_method_names()
 
+# GPU helper functions (replicated from webui.py to avoid import)
+set_gpu_numbers = GPU_INDEX
+default_gpu_numbers = infer_device.index if hasattr(infer_device, 'index') else 0
+
+def fix_gpu_number(input_val):
+    """Fix GPU number to be within valid range."""
+    try:
+        if int(input_val) not in set_gpu_numbers:
+            return default_gpu_numbers
+    except:
+        return input_val
+    return input_val
+
+def fix_gpu_numbers(inputs):
+    """Fix multiple GPU numbers separated by comma."""
+    output = []
+    try:
+        for input_val in inputs.split(","):
+            output.append(str(fix_gpu_number(input_val)))
+        return ",".join(output)
+    except:
+        return inputs
+
 parser = argparse.ArgumentParser(description="GPT-SoVITS api")
 parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
 parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
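A quick sanity check of the two helpers above, assuming GPU_INDEX resolves to {0, 1} (illustrative values, not from the commit):

# Assuming set_gpu_numbers == {0, 1} and default_gpu_numbers == 0.
fix_gpu_number("1")       # "1" parses and is configured -> returned unchanged
fix_gpu_number("7")       # parses but is not configured -> 0 (the default device)
fix_gpu_number("mps")     # fails int() -> returned unchanged
fix_gpu_numbers("0,7")    # per-element fix -> "0,0"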
@@ -172,6 +211,69 @@ class TTS_Request(BaseModel):
     super_sampling: bool = False
 
 
+class SpeechSlicingRequest(BaseModel):
+    inp: str
+    opt_root: str
+    threshold: str = "-34"
+    min_length: str = "4000"
+    min_interval: str = "300"
+    hop_size: str = "10"
+    max_sil_kept: str = "500"
+    _max: float = 0.9
+    alpha: float = 0.25
+    n_parts: int = 4
+
+
+class STTRequest(BaseModel):
+    input_folder: str
+    output_folder: str
+    model_path: str = "tools/asr/models/faster-whisper-large-v3"
+    language: str = "auto"
+    precision: str = "float32"
+
+
+class DatasetFormattingRequest(BaseModel):
+    inp_text: str
+    inp_wav_dir: str
+    exp_name: str
+    version: str = "v4"
+    gpu_numbers: str = "0-0"
+    bert_pretrained_dir: str = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
+    ssl_pretrained_dir: str = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
+    pretrained_s2G_path: str = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
+
+
+class FineTuneSoVITSRequest(BaseModel):
+    version: str = "v4"
+    batch_size: int = 2
+    total_epoch: int = 2
+    exp_name: str
+    text_low_lr_rate: float = 0.4
+    if_save_latest: bool = True
+    if_save_every_weights: bool = True
+    save_every_epoch: int = 1
+    gpu_numbers1Ba: str = "0"
+    pretrained_s2G: str = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
+    pretrained_s2D: str = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Dv4.pth"
+    if_grad_ckpt: bool = False
+    lora_rank: str = "32"
+
+
+class FineTuneGPTRequest(BaseModel):
+    batch_size: int = 8
+    total_epoch: int = 15
+    exp_name: str
+    if_dpo: bool = False
+    if_save_latest: bool = True
+    if_save_every_weights: bool = True
+    save_every_epoch: int = 5
+    gpu_numbers: str = "0"
+    pretrained_s1: str = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
+
+
+jobs: Dict[str, Dict[str, Any]] = {}
+
+
 ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
 def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
     with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
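One caveat worth flagging in the models above: pydantic treats underscore-prefixed names such as _max as private attributes rather than fields, so a "_max" key in the request JSON will most likely be ignored and the 0.9 default always used. The slicing thresholds are strings because they are forwarded verbatim as subprocess argv. An illustrative construction (paths hypothetical):

req = SpeechSlicingRequest(inp="raw/talk.wav", opt_root="output/sliced", n_parts=2)
print(req.threshold, req.alpha)  # "-34" 0.25 -- strings pass straight into argv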
@@ -272,14 +374,14 @@ def check_params(req: dict):
         return JSONResponse(status_code=400, content={"message": "text is required"})
     if text_lang in [None, ""]:
         return JSONResponse(status_code=400, content={"message": "text_lang is required"})
-    elif text_lang.lower() not in tts_config.languages:
+    elif text_lang not in tts_config.languages:
         return JSONResponse(
             status_code=400,
             content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"},
         )
     if prompt_lang in [None, ""]:
         return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
-    elif prompt_lang.lower() not in tts_config.languages:
+    elif prompt_lang not in tts_config.languages:
         return JSONResponse(
             status_code=400,
             content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"},
@@ -407,11 +509,11 @@ async def tts_get_endpoint(
 ):
     req = {
         "text": text,
-        "text_lang": text_lang.lower(),
+        "text_lang": text_lang,
         "ref_audio_path": ref_audio_path,
         "aux_ref_audio_paths": aux_ref_audio_paths,
         "prompt_text": prompt_text,
-        "prompt_lang": prompt_lang.lower(),
+        "prompt_lang": prompt_lang,
         "top_k": top_k,
         "top_p": top_p,
         "temperature": temperature,
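Together with the check_params hunk above, dropping .lower() makes text_lang and prompt_lang case-sensitive: clients must now send codes exactly as they appear in tts_config.languages. A defensive client-side normalizer, purely illustrative:

def normalize_lang(code: str, supported: list) -> str:
    # Match case-insensitively against the server's supported language list.
    for lang in supported:
        if lang.lower() == code.lower():
            return lang
    return code  # unknown codes fall through and get the server's 400 response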
@@ -434,6 +536,23 @@ async def tts_get_endpoint(
 
 @APP.post("/tts")
 async def tts_post_endpoint(request: TTS_Request):
+    # DEBUG: Print received payload
+    print(f"\n{'='*80}")
+    print(f"[TTS DEBUG] Received request:")
+    print(f"  text: {request.text[:100] if len(request.text) > 100 else request.text}")  # Truncate long text
+    print(f"  text_lang: {request.text_lang}")
+    print(f"  ref_audio_path: {request.ref_audio_path}")
+    print(f"  prompt_text: {request.prompt_text[:100] if request.prompt_text and len(request.prompt_text) > 100 else request.prompt_text}")
+    print(f"  prompt_lang: {request.prompt_lang}")
+    print(f"  top_k: {request.top_k}")
+    print(f"  top_p: {request.top_p}")
+    print(f"  temperature: {request.temperature}")
+    print(f"  text_split_method: {request.text_split_method}")
+    print(f"  batch_size: {request.batch_size}")
+    print(f"  speed_factor: {request.speed_factor}")
+    print(f"  streaming_mode: {request.streaming_mode}")
+    print(f"{'='*80}\n")
+
     req = request.dict()
     return await tts_handle(req)
 
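For reference, a minimal client call against this endpoint. It assumes the server is running with the repo defaults (127.0.0.1, port 9880); treat host, port, and all paths as assumptions:

import requests

payload = {
    "text": "Hello there.",
    "text_lang": "en",
    "ref_audio_path": "ref/sample.wav",  # hypothetical reference clip
    "prompt_text": "Sample prompt.",
    "prompt_lang": "en",
}
resp = requests.post("http://127.0.0.1:9880/tts", json=payload)
resp.raise_for_status()
with open("out.wav", "wb") as f:  # media_type defaults to wav
    f.write(resp.content)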
@@ -489,6 +608,654 @@ async def set_sovits_weights(weights_path: str = None):
     return JSONResponse(status_code=200, content={"message": "success"})
 
 
+async def execute_job_async(job_id: str, operation_func, *args, **kwargs):
+    """
+    Execute a job asynchronously in background.
+
+    Args:
+        job_id: Unique job identifier
+        operation_func: Function to execute (from webui.py)
+        args, kwargs: Arguments for the operation function
+    """
+    jobs[job_id]["status"] = "running"
+    jobs[job_id]["started_at"] = datetime.now().isoformat()
+
+    try:
+        result = await asyncio.to_thread(operation_func, *args, **kwargs)
+
+        if hasattr(result, '__iter__') and not isinstance(result, (str, dict)):
+            final_result = None
+            for item in result:
+                final_result = item
+            result = final_result
+
+        jobs[job_id]["status"] = "completed"
+        jobs[job_id]["result"] = result
+        jobs[job_id]["completed_at"] = datetime.now().isoformat()
+    except Exception as e:
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+        jobs[job_id]["traceback"] = traceback.format_exc()
+        jobs[job_id]["failed_at"] = datetime.now().isoformat()
+
+
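The __iter__ check exists because webui's long-running helpers are generator functions that yield progress messages; the loop drains one and keeps only its final yield as the job result. A stand-in to illustrate the shape (not a real webui function):

def fake_webui_op():
    yield "progress 1/2"
    yield "progress 2/2"
    yield {"done": True}  # this last value is what lands in jobs[job_id]["result"]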
@APP.get("/jobs/{job_id}")
|
||||
@APP.get("/job-status/{job_id}") # Alias for compatibility
|
||||
async def get_job_status(job_id: str):
|
||||
"""
|
||||
Get job status and result.
|
||||
|
||||
Returns:
|
||||
{
|
||||
"job_id": str,
|
||||
"status": "queued" | "running" | "completed" | "failed",
|
||||
"result": Any (if completed),
|
||||
"error": str (if failed),
|
||||
"created_at": str,
|
||||
"started_at": str (if running/completed/failed),
|
||||
"completed_at": str (if completed),
|
||||
"failed_at": str (if failed)
|
||||
}
|
||||
"""
|
||||
if job_id not in jobs:
|
||||
return JSONResponse(status_code=404, content={"message": "job not found"})
|
||||
|
||||
job_data = jobs[job_id].copy()
|
||||
job_data["job_id"] = job_id
|
||||
return JSONResponse(status_code=200, content=job_data)
|
||||
|
||||
|
||||
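A small polling helper against this endpoint (sketch; host and port are the repo defaults, the interval is arbitrary):

import time
import requests

def wait_for_job(job_id: str, base: str = "http://127.0.0.1:9880") -> dict:
    # Poll until the job reaches a terminal state.
    while True:
        job = requests.get(f"{base}/jobs/{job_id}").json()
        if job["status"] in ("completed", "failed"):
            return job
        time.sleep(5)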
+async def execute_speech_slicing_direct(job_id: str, request: SpeechSlicingRequest):
+    """
+    Execute speech slicing by directly calling slice_audio.py subprocess.
+    Replaces webui.open_slice() to avoid Gradio dependency.
+    """
+    jobs[job_id]["status"] = "running"
+    jobs[job_id]["started_at"] = datetime.now().isoformat()
+
+    try:
+        # Prepare environment with PYTHONPATH
+        env = os.environ.copy()
+        env["PYTHONPATH"] = os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")])
+
+        # Create processes for parallel slicing (n_parts)
+        processes = []
+        for i_part in range(request.n_parts):
+            cmd = [
+                python_exec,
+                "tools/slice_audio.py",
+                request.inp,
+                request.opt_root,
+                str(request.threshold),
+                str(request.min_length),
+                str(request.min_interval),
+                str(request.hop_size),
+                str(request.max_sil_kept),
+                str(request._max),
+                str(request.alpha),
+                str(i_part),
+                str(request.n_parts),
+            ]
+            print(f"[SPEECH SLICING] Executing: {' '.join(cmd)}")
+            p = subprocess.Popen(cmd, env=env, cwd=now_dir)
+            processes.append(p)
+
+        # Wait for all processes to complete
+        for p in processes:
+            p.wait()
+
+        # Check if any process failed
+        exit_codes = [p.returncode for p in processes]
+        if any(code != 0 for code in exit_codes):
+            raise Exception(f"Speech slicing failed with exit codes: {exit_codes}")
+
+        jobs[job_id]["status"] = "completed"
+        jobs[job_id]["result"] = {
+            "output_dir": request.opt_root,
+            "file_count": request.n_parts
+        }
+        jobs[job_id]["completed_at"] = datetime.now().isoformat()
+
+    except Exception as e:
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+        jobs[job_id]["traceback"] = traceback.format_exc()
+        jobs[job_id]["failed_at"] = datetime.now().isoformat()
+
+
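Each of the n_parts workers slices a disjoint shard of the input, so the reported file_count is the worker count, not the number of clips written. The sharding scheme below is an assumption about tools/slice_audio.py (stride over the sorted file list), shown for intuition only:

files = sorted(["a.wav", "b.wav", "c.wav", "d.wav", "e.wav"])  # hypothetical inputs
n_parts = 2
for i_part in range(n_parts):
    print(i_part, files[i_part::n_parts])
# 0 ['a.wav', 'c.wav', 'e.wav']
# 1 ['b.wav', 'd.wav']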
@APP.post("/preprocessing/speech-slicing")
|
||||
async def speech_slicing_endpoint(request: SpeechSlicingRequest):
|
||||
"""
|
||||
Start speech slicing job.
|
||||
|
||||
Directly executes tools/slice_audio.py (no webui dependency).
|
||||
"""
|
||||
# DEBUG: Print received payload
|
||||
print(f"\n{'='*80}")
|
||||
print(f"[SPEECH SLICING DEBUG] Received request:")
|
||||
print(f" inp: {request.inp}")
|
||||
print(f" opt_root: {request.opt_root}")
|
||||
print(f" threshold: {request.threshold}")
|
||||
print(f" min_length: {request.min_length}")
|
||||
print(f" min_interval: {request.min_interval}")
|
||||
print(f" hop_size: {request.hop_size}")
|
||||
print(f" max_sil_kept: {request.max_sil_kept}")
|
||||
print(f" _max: {request._max}")
|
||||
print(f" alpha: {request.alpha}")
|
||||
print(f" n_parts: {request.n_parts}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
job_id = str(uuid.uuid4())
|
||||
jobs[job_id] = {
|
||||
"status": "queued",
|
||||
"operation": "speech_slicing",
|
||||
"created_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
try:
|
||||
asyncio.create_task(execute_speech_slicing_direct(job_id, request))
|
||||
return JSONResponse(status_code=200, content={"job_id": job_id, "status": "queued"})
|
||||
except Exception as e:
|
||||
jobs[job_id]["status"] = "failed"
|
||||
jobs[job_id]["error"] = str(e)
|
||||
return JSONResponse(status_code=500, content={"message": "failed to start job", "error": str(e)})
|
||||
|
||||
|
||||
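End-to-end sketch: start a slicing job and poll it with the wait_for_job helper sketched earlier (requests imported there; paths hypothetical):

resp = requests.post("http://127.0.0.1:9880/preprocessing/speech-slicing", json={
    "inp": "raw/long_recording.wav",
    "opt_root": "output/sliced",
})
job = wait_for_job(resp.json()["job_id"])
print(job["status"], job.get("result"))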
@APP.post("/preprocessing/stt")
|
||||
async def stt_endpoint(request: STTRequest):
|
||||
"""
|
||||
Start STT (Speech-to-Text) job.
|
||||
|
||||
Wraps tools/asr/fasterwhisper_asr.execute_asr()
|
||||
"""
|
||||
# DEBUG: Print received payload
|
||||
print(f"\n{'='*80}")
|
||||
print(f"[STT DEBUG] Received STT request:")
|
||||
print(request)
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
job_id = str(uuid.uuid4())
|
||||
jobs[job_id] = {
|
||||
"status": "queued",
|
||||
"operation": "stt",
|
||||
"created_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
try:
|
||||
from tools.asr.fasterwhisper_asr import execute_asr
|
||||
|
||||
asyncio.create_task(execute_job_async(
|
||||
job_id,
|
||||
execute_asr,
|
||||
request.input_folder,
|
||||
request.output_folder,
|
||||
request.model_path,
|
||||
request.language,
|
||||
request.precision
|
||||
))
|
||||
|
||||
return JSONResponse(status_code=200, content={"job_id": job_id, "status": "queued"})
|
||||
except Exception as e:
|
||||
print(f"[STT ERROR] Failed to start STT job: {str(e)}")
|
||||
jobs[job_id]["status"] = "failed"
|
||||
jobs[job_id]["error"] = str(e)
|
||||
return JSONResponse(status_code=500, content={"message": "failed to start job", "error": str(e)})
|
||||
|
||||
|
||||
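Matching client sketch for the STT job (folders hypothetical; model_path defaults to the bundled faster-whisper-large-v3 path above):

resp = requests.post("http://127.0.0.1:9880/preprocessing/stt", json={
    "input_folder": "output/sliced",
    "output_folder": "output/asr",
})
print(wait_for_job(resp.json()["job_id"])["status"])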
+async def execute_dataset_formatting(job_id: str, request: DatasetFormattingRequest):
+    """
+    Execute dataset formatting sequentially: open1a -> open1b -> open1c
+    Directly executes subprocess (no webui dependency).
+    """
+    jobs[job_id]["status"] = "running"
+    jobs[job_id]["started_at"] = datetime.now().isoformat()
+    jobs[job_id]["current_stage"] = "open1a"
+
+    try:
+        opt_dir = f"{exp_root}/{request.exp_name}"
+        os.makedirs(opt_dir, exist_ok=True)
+
+        # Parse GPU numbers
+        gpu_names = request.gpu_numbers.split("-")
+        all_parts = len(gpu_names)
+
+        # Stage 1a: Get text features
+        print(f"[DATASET FORMATTING] Starting open1a...")
+        for i_part in range(all_parts):
+            env = os.environ.copy()
+            env.update({
+                "PYTHONPATH": os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")]),
+                "inp_text": request.inp_text,
+                "inp_wav_dir": request.inp_wav_dir,
+                "exp_name": request.exp_name,
+                "opt_dir": opt_dir,
+                "bert_pretrained_dir": request.bert_pretrained_dir,
+                "i_part": str(i_part),
+                "all_parts": str(all_parts),
+                "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
+                "is_half": str(is_half),
+            })
+            cmd = [python_exec, "GPT_SoVITS/prepare_datasets/1-get-text.py"]
+            print(f"[DATASET FORMATTING] Executing 1a part {i_part}: {' '.join(cmd)}")
+            await asyncio.to_thread(subprocess.run, cmd, env=env, cwd=now_dir, check=True)
+
+        # Merge text files from 1a stage
+        opt = []
+        path_text = f"{opt_dir}/2-name2text.txt"
+        for i_part in range(all_parts):
+            text_path = f"{opt_dir}/2-name2text-{i_part}.txt"
+            if os.path.exists(text_path):
+                with open(text_path, "r", encoding="utf8") as f:
+                    opt += f.read().strip("\n").split("\n")
+                os.remove(text_path)
+
+        with open(path_text, "w", encoding="utf8") as f:
+            f.write("\n".join(opt) + "\n")
+
+        # Stage 1b: Get hubert features
+        jobs[job_id]["current_stage"] = "open1b"
+        print(f"[DATASET FORMATTING] Starting open1b...")
+        sv_path = "GPT_SoVITS/pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"
+
+        for i_part in range(all_parts):
+            env = os.environ.copy()
+            env.update({
+                "PYTHONPATH": os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")]),
+                "inp_text": request.inp_text,
+                "inp_wav_dir": request.inp_wav_dir,
+                "exp_name": request.exp_name,
+                "opt_dir": opt_dir,
+                "cnhubert_base_dir": request.ssl_pretrained_dir,
+                "sv_path": sv_path,
+                "is_half": str(is_half),
+                "i_part": str(i_part),
+                "all_parts": str(all_parts),
+                "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
+            })
+            cmd = [python_exec, "GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py"]
+            print(f"[DATASET FORMATTING] Executing 1b part {i_part}: {' '.join(cmd)}")
+            await asyncio.to_thread(subprocess.run, cmd, env=env, cwd=now_dir, check=True)
+
+        # For v2Pro version, also run 2-get-sv.py
+        if "Pro" in request.version:
+            for i_part in range(all_parts):
+                env = os.environ.copy()
+                env.update({
+                    "PYTHONPATH": os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")]),
+                    "i_part": str(i_part),
+                    "all_parts": str(all_parts),
+                    "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
+                    "exp_dir": opt_dir,
+                    "sv_path": sv_path,
+                    "is_half": str(is_half),
+                })
+                cmd = [python_exec, "GPT_SoVITS/prepare_datasets/2-get-sv.py"]
+                print(f"[DATASET FORMATTING] Executing 2-get-sv part {i_part}: {' '.join(cmd)}")
+                await asyncio.to_thread(subprocess.run, cmd, env=env, cwd=now_dir, check=True)
+
+        # Stage 1c: Get semantic features
+        jobs[job_id]["current_stage"] = "open1c"
+        print(f"[DATASET FORMATTING] Starting open1c...")
+        config_file = (
+            "GPT_SoVITS/configs/s2.json"
+            if request.version not in {"v2Pro", "v2ProPlus"}
+            else f"GPT_SoVITS/configs/s2{request.version}.json"
+        )
+
+        for i_part in range(all_parts):
+            env = os.environ.copy()
+            env.update({
+                "PYTHONPATH": os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")]),
+                "inp_text": request.inp_text,
+                "exp_name": request.exp_name,
+                "opt_dir": opt_dir,
+                "pretrained_s2G": request.pretrained_s2G_path,
+                "s2config_path": config_file,
+                "is_half": str(is_half),
+                "i_part": str(i_part),
+                "all_parts": str(all_parts),
+                "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
+            })
+            cmd = [python_exec, "GPT_SoVITS/prepare_datasets/3-get-semantic.py"]
+            print(f"[DATASET FORMATTING] Executing 1c part {i_part}: {' '.join(cmd)}")
+            await asyncio.to_thread(subprocess.run, cmd, env=env, cwd=now_dir, check=True)
+
+        # Merge semantic files (from open1c logic in webui.py)
+        opt = ["item_name\tsemantic_audio"]
+        path_semantic = f"{opt_dir}/6-name2semantic.tsv"
+        for i_part in range(all_parts):
+            semantic_path = f"{opt_dir}/6-name2semantic-{i_part}.tsv"
+            if os.path.exists(semantic_path):
+                with open(semantic_path, "r", encoding="utf8") as f:
+                    opt += f.read().strip("\n").split("\n")
+                os.remove(semantic_path)
+
+        with open(path_semantic, "w", encoding="utf8") as f:
+            f.write("\n".join(opt))
+
+        jobs[job_id]["status"] = "completed"
+        jobs[job_id]["result"] = {
+            "exp_name": request.exp_name,
+            "stages_completed": ["open1a", "open1b", "open1c"]
+        }
+        jobs[job_id]["completed_at"] = datetime.now().isoformat()
+
+    except Exception as e:
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+        jobs[job_id]["traceback"] = traceback.format_exc()
+        jobs[job_id]["failed_at"] = datetime.now().isoformat()
+
+
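The gpu_numbers field drives the fan-out in every stage: parts are split on "-", one subprocess per element, each pinned to a device via _CUDA_VISIBLE_DEVICES. For intuition:

gpu_names = "0-0".split("-")  # the default "0-0" -> ['0', '0']
all_parts = len(gpu_names)    # two worker processes per stage, both on GPU 0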
@APP.post("/training/format-dataset")
|
||||
async def format_dataset_endpoint(request: DatasetFormattingRequest):
|
||||
"""
|
||||
Start dataset formatting job (open1a -> open1b -> open1c).
|
||||
|
||||
Wraps webui.open1a(), open1b(), open1c() sequentially.
|
||||
"""
|
||||
# DEBUG: Print received payload
|
||||
print(f"\n{'='*80}")
|
||||
print(f"[DATASET FORMATTING DEBUG] Received request:")
|
||||
print(f" version: {request.version}")
|
||||
print(f" inp_text: {request.inp_text}")
|
||||
print(f" inp_wav_dir: {request.inp_wav_dir}")
|
||||
print(f" exp_name: {request.exp_name}")
|
||||
print(f" gpu_numbers1a: {request.gpu_numbers}")
|
||||
print(f" bert_pretrained_dir: {request.bert_pretrained_dir}")
|
||||
print(f" ssl_pretrained_dir: {request.ssl_pretrained_dir}")
|
||||
print(f" pretrained_s2G_path: {request.pretrained_s2G_path}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
job_id = str(uuid.uuid4())
|
||||
jobs[job_id] = {
|
||||
"status": "queued",
|
||||
"operation": "format_dataset",
|
||||
"created_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
try:
|
||||
asyncio.create_task(execute_dataset_formatting(job_id, request))
|
||||
return JSONResponse(status_code=200, content={"job_id": job_id, "status": "queued"})
|
||||
except Exception as e:
|
||||
jobs[job_id]["status"] = "failed"
|
||||
jobs[job_id]["error"] = str(e)
|
||||
return JSONResponse(status_code=500, content={"message": "failed to start job", "error": str(e)})
|
||||
|
||||
|
||||
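Client sketch for the formatting pipeline, reusing the earlier wait_for_job helper (inp_text points at the .list transcript produced by the ASR step; paths hypothetical):

resp = requests.post("http://127.0.0.1:9880/training/format-dataset", json={
    "inp_text": "output/asr/sliced.list",
    "inp_wav_dir": "output/sliced",
    "exp_name": "myexp",
})
print(wait_for_job(resp.json()["job_id"]).get("result"))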
+async def execute_fine_tune_sovits_direct(job_id: str, request: FineTuneSoVITSRequest):
+    """
+    Execute SoVITS fine-tuning by directly calling s2_train.py subprocess.
+    Replaces webui.open1Ba() to avoid Gradio dependency.
+    """
+    jobs[job_id]["status"] = "running"
+    jobs[job_id]["started_at"] = datetime.now().isoformat()
+
+    try:
+        s2_dir = f"{exp_root}/{request.exp_name}"
+        os.makedirs(f"{s2_dir}/logs_s2_{request.version}", exist_ok=True)
+
+        # Load config template
+        config_file = (
+            "GPT_SoVITS/configs/s2.json"
+            if request.version not in {"v2Pro", "v2ProPlus"}
+            else f"GPT_SoVITS/configs/s2{request.version}.json"
+        )
+        with open(config_file) as f:
+            data = json.loads(f.read())
+
+        # Update config with request parameters
+        batch_size = request.batch_size
+        if is_half == False:
+            data["train"]["fp16_run"] = False
+            batch_size = max(1, batch_size // 2)
+
+        data["train"]["batch_size"] = batch_size
+        data["train"]["epochs"] = request.total_epoch
+        data["train"]["text_low_lr_rate"] = request.text_low_lr_rate
+        data["train"]["pretrained_s2G"] = request.pretrained_s2G
+        data["train"]["pretrained_s2D"] = request.pretrained_s2D
+        data["train"]["if_save_latest"] = request.if_save_latest
+        data["train"]["if_save_every_weights"] = request.if_save_every_weights
+        data["train"]["save_every_epoch"] = request.save_every_epoch
+        data["train"]["gpu_numbers"] = request.gpu_numbers1Ba
+        data["train"]["grad_ckpt"] = request.if_grad_ckpt
+        data["train"]["lora_rank"] = request.lora_rank
+        data["model"]["version"] = request.version
+        data["data"]["exp_dir"] = data["s2_ckpt_dir"] = s2_dir
+        data["save_weight_dir"] = SoVITS_weight_version2root[request.version]
+        data["name"] = request.exp_name
+        data["version"] = request.version
+
+        # Write temporary config
+        tmp_config_path = f"{now_dir}/TEMP/tmp_s2.json"
+        os.makedirs(f"{now_dir}/TEMP", exist_ok=True)
+        with open(tmp_config_path, "w") as f:
+            f.write(json.dumps(data))
+
+        # Prepare environment with PYTHONPATH
+        env = os.environ.copy()
+        env["PYTHONPATH"] = os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")])
+
+        # Determine training script based on version
+        if request.version in ["v1", "v2", "v2Pro", "v2ProPlus"]:
+            cmd = [python_exec, "GPT_SoVITS/s2_train.py", "--config", tmp_config_path]
+        else:
+            cmd = [python_exec, "GPT_SoVITS/s2_train_v3_lora.py", "--config", tmp_config_path]
+
+        print(f"[SOVITS FINE-TUNING] Executing: {' '.join(cmd)}")
+        result = await asyncio.to_thread(subprocess.run, cmd, env=env, cwd=now_dir, check=True)
+
+        # Find latest SoVITS checkpoint
+        sovits_weights_dir = data["save_weight_dir"]
+        latest_sovits_checkpoint = None
+
+        if os.path.exists(sovits_weights_dir):
+            import re
+            pattern = re.compile(rf"^{re.escape(request.exp_name)}_e(\d+)_s(\d+)_l(\d+)\.pth$")
+            checkpoints = []
+            for filename in os.listdir(sovits_weights_dir):
+                match = pattern.match(filename)
+                if match:
+                    epoch = int(match.group(1))
+                    step = int(match.group(2))
+                    checkpoints.append((epoch, step, filename))
+
+            if checkpoints:
+                checkpoints.sort(reverse=True)
+                latest_filename = checkpoints[0][2]
+                latest_sovits_checkpoint = os.path.join(sovits_weights_dir, latest_filename)
+                print(f"[SOVITS FINE-TUNING] Latest checkpoint: {latest_sovits_checkpoint}")
+
+        jobs[job_id]["status"] = "completed"
+        jobs[job_id]["result"] = {
+            "exp_name": request.exp_name,
+            "config_path": tmp_config_path,
+            "checkpoint_path": latest_sovits_checkpoint,
+            "sovits_checkpoint_path": latest_sovits_checkpoint
+        }
+        jobs[job_id]["completed_at"] = datetime.now().isoformat()
+
+    except Exception as e:
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+        jobs[job_id]["traceback"] = traceback.format_exc()
+        jobs[job_id]["failed_at"] = datetime.now().isoformat()
+
+
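The checkpoint scan matches weight files named {exp_name}_e{epoch}_s{step}_l{n}.pth, and sorting the (epoch, step, name) tuples in reverse picks the newest. A standalone illustration (the experiment name is hypothetical):

import re

pattern = re.compile(r"^myexp_e(\d+)_s(\d+)_l(\d+)\.pth$")
m = pattern.match("myexp_e2_s400_l32.pth")
print(m.groups())  # ('2', '400', '32') -> epoch 2, step 400, numeric suffix 32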
@APP.post("/training/fine-tune-sovits")
|
||||
async def fine_tune_sovits_endpoint(request: FineTuneSoVITSRequest):
|
||||
"""
|
||||
Start SoVITS fine-tuning job.
|
||||
|
||||
Directly executes s2_train.py (no webui dependency).
|
||||
"""
|
||||
# DEBUG: Print received payload
|
||||
print(f"\n{'='*80}")
|
||||
print(f"[SOVITS FINE-TUNING DEBUG] Received request:")
|
||||
print(f" version: {request.version}")
|
||||
print(f" batch_size: {request.batch_size}")
|
||||
print(f" total_epoch: {request.total_epoch}")
|
||||
print(f" exp_name: {request.exp_name}")
|
||||
print(f" text_low_lr_rate: {request.text_low_lr_rate}")
|
||||
print(f" if_save_latest: {request.if_save_latest}")
|
||||
print(f" if_save_every_weights: {request.if_save_every_weights}")
|
||||
print(f" save_every_epoch: {request.save_every_epoch}")
|
||||
print(f" gpu_numbers1Ba: {request.gpu_numbers1Ba}")
|
||||
print(f" pretrained_s2G: {request.pretrained_s2G}")
|
||||
print(f" pretrained_s2D: {request.pretrained_s2D}")
|
||||
print(f" if_grad_ckpt: {request.if_grad_ckpt}")
|
||||
print(f" lora_rank: {request.lora_rank}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
job_id = str(uuid.uuid4())
|
||||
jobs[job_id] = {
|
||||
"status": "queued",
|
||||
"operation": "fine_tune_sovits",
|
||||
"created_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
try:
|
||||
asyncio.create_task(execute_fine_tune_sovits_direct(job_id, request))
|
||||
return JSONResponse(status_code=200, content={"job_id": job_id, "status": "queued"})
|
||||
except Exception as e:
|
||||
jobs[job_id]["status"] = "failed"
|
||||
jobs[job_id]["error"] = str(e)
|
||||
return JSONResponse(status_code=500, content={"message": "failed to start job", "error": str(e)})
|
||||
|
||||
|
||||
@APP.post("/training/fine-tune-gpt")
|
||||
async def fine_tune_gpt_endpoint(request: FineTuneGPTRequest):
|
||||
"""
|
||||
Start GPT fine-tuning job.
|
||||
|
||||
Wraps webui.open1Bb()
|
||||
"""
|
||||
# DEBUG: Print received payload
|
||||
print(f"\n{'='*80}")
|
||||
print(f"[GPT FINE-TUNING DEBUG] Received request:")
|
||||
print(f" batch_size: {request.batch_size}")
|
||||
print(f" total_epoch: {request.total_epoch}")
|
||||
print(f" exp_name: {request.exp_name}")
|
||||
print(f" if_dpo: {request.if_dpo}")
|
||||
print(f" if_save_latest: {request.if_save_latest}")
|
||||
print(f" if_save_every_weights: {request.if_save_every_weights}")
|
||||
print(f" save_every_epoch: {request.save_every_epoch}")
|
||||
print(f" gpu_numbers: {request.gpu_numbers}")
|
||||
print(f" pretrained_s1: {request.pretrained_s1}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
job_id = str(uuid.uuid4())
|
||||
jobs[job_id] = {
|
||||
"status": "queued",
|
||||
"operation": "fine_tune_gpt",
|
||||
"created_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
try:
|
||||
asyncio.create_task(execute_fine_tune_gpt_direct(job_id, request))
|
||||
return JSONResponse(status_code=200, content={"job_id": job_id, "status": "queued"})
|
||||
except Exception as e:
|
||||
jobs[job_id]["status"] = "failed"
|
||||
jobs[job_id]["error"] = str(e)
|
||||
return JSONResponse(status_code=500, content={"message": "failed to start job", "error": str(e)})
|
||||
|
||||
|
||||
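A combined training sketch: once format-dataset has completed, kick off both fine-tunes for the same experiment and wait for each in turn (exp_name hypothetical; every other field takes the defaults defined in the request models above):

for endpoint in ("/training/fine-tune-sovits", "/training/fine-tune-gpt"):
    resp = requests.post(f"http://127.0.0.1:9880{endpoint}", json={"exp_name": "myexp"})
    job = wait_for_job(resp.json()["job_id"])
    print(endpoint, job["status"], (job.get("result") or {}).get("checkpoint_path"))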
+async def execute_fine_tune_gpt_direct(job_id: str, request: FineTuneGPTRequest):
+    """
+    Execute GPT fine-tuning by directly calling s1_train.py subprocess.
+    Replaces webui.open1Bb() to avoid Gradio dependency.
+    """
+    jobs[job_id]["status"] = "running"
+    jobs[job_id]["started_at"] = datetime.now().isoformat()
+
+    try:
+        s1_dir = f"{exp_root}/{request.exp_name}"
+        os.makedirs(f"{s1_dir}/logs_s1", exist_ok=True)
+
+        # Determine version (from webui.py line 606)
+        version = os.environ.get("version", "v4")
+
+        # Load config template
+        config_path = (
+            "GPT_SoVITS/configs/s1longer.yaml" if version == "v1"
+            else "GPT_SoVITS/configs/s1longer-v2.yaml"
+        )
+        with open(config_path) as f:
+            data = yaml.load(f.read(), Loader=yaml.FullLoader)
+
+        # Update config with request parameters
+        batch_size = request.batch_size
+        if is_half == False:
+            data["train"]["precision"] = "32"
+            batch_size = max(1, batch_size // 2)
+
+        data["train"]["batch_size"] = batch_size
+        data["train"]["epochs"] = request.total_epoch
+        data["pretrained_s1"] = request.pretrained_s1
+        data["train"]["save_every_n_epoch"] = request.save_every_epoch
+        data["train"]["if_save_every_weights"] = request.if_save_every_weights
+        data["train"]["if_save_latest"] = request.if_save_latest
+        data["train"]["if_dpo"] = request.if_dpo
+        data["train"]["half_weights_save_dir"] = GPT_weight_version2root[version]
+        data["train"]["exp_name"] = request.exp_name
+        data["train_semantic_path"] = f"{s1_dir}/6-name2semantic.tsv"
+        data["train_phoneme_path"] = f"{s1_dir}/2-name2text.txt"
+        data["output_dir"] = f"{s1_dir}/logs_s1_{version}"
+
+        # Set environment variables for GPU and PYTHONPATH
+        env = os.environ.copy()
+        env["PYTHONPATH"] = os.pathsep.join([now_dir, os.path.join(now_dir, "GPT_SoVITS")])
+        env["_CUDA_VISIBLE_DEVICES"] = fix_gpu_numbers(request.gpu_numbers.replace("-", ","))
+        env["hz"] = "25hz"
+
+        # Write temporary config
+        tmp_config_path = f"{now_dir}/TEMP/tmp_s1.yaml"
+        os.makedirs(f"{now_dir}/TEMP", exist_ok=True)
+        with open(tmp_config_path, "w") as f:
+            f.write(yaml.dump(data, default_flow_style=False))
+
+        # Execute training
+        cmd = [python_exec, "GPT_SoVITS/s1_train.py", "--config_file", tmp_config_path]
+        print(f"[GPT FINE-TUNING] Executing: {' '.join(cmd)}")
+        result = await asyncio.to_thread(subprocess.run, cmd, env=env, cwd=now_dir, check=True)
+
+        # Find latest GPT checkpoint
+        gpt_weights_dir = data["train"]["half_weights_save_dir"]
+        latest_gpt_checkpoint = None
+
+        if os.path.exists(gpt_weights_dir):
+            import re
+            pattern = re.compile(rf"^{re.escape(request.exp_name)}-e(\d+)\.ckpt$")
+            checkpoints = []
+            for filename in os.listdir(gpt_weights_dir):
+                match = pattern.match(filename)
+                if match:
+                    epoch = int(match.group(1))
+                    checkpoints.append((epoch, filename))
+
+            if checkpoints:
+                checkpoints.sort(reverse=True)
+                latest_filename = checkpoints[0][1]
+                latest_gpt_checkpoint = os.path.join(gpt_weights_dir, latest_filename)
+                print(f"[GPT FINE-TUNING] Latest checkpoint: {latest_gpt_checkpoint}")
+
+        jobs[job_id]["status"] = "completed"
+        jobs[job_id]["result"] = {
+            "exp_name": request.exp_name,
+            "config_path": tmp_config_path,
+            "checkpoint_path": latest_gpt_checkpoint,
+            "gpt_checkpoint_path": latest_gpt_checkpoint
+        }
+        jobs[job_id]["completed_at"] = datetime.now().isoformat()
+
+    except Exception as e:
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+        jobs[job_id]["traceback"] = traceback.format_exc()
+        jobs[job_id]["failed_at"] = datetime.now().isoformat()
+
+
 if __name__ == "__main__":
     try:
         if host == "None":  # launching with -a None lets the API listen dual-stack (IPv4 and IPv6)