Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2026-05-18 00:08:16 +08:00)
Integrate UnifiedTTSEngine into the TTS API for improved audio processing and control. Refactor the tts_handle and control endpoints to use the new engine, tightening error handling and response management. Update the set_refer_audio, set_gpt_weights, and set_sovits_weights endpoints to return the engine's payloads, streamlining audio configuration.
parent 827d6ea47c
commit 69ac7f9027
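
For orientation, here is a minimal client-side sketch of the endpoints this commit touches. It is illustrative only: it assumes the server's usual default bind of http://127.0.0.1:9880, and the file paths are placeholders; the endpoint names and streaming_mode values come from the diff below.

# Hedged client sketch for the refactored api_v2 endpoints.
import requests

BASE = "http://127.0.0.1:9880"  # assumed default bind

# Weight and reference endpoints now return the engine's payload
# instead of a bare {"message": "success"}.
print(requests.get(f"{BASE}/set_gpt_weights", params={"weights_path": "GPT_weights/demo.ckpt"}).json())
print(requests.get(f"{BASE}/set_refer_audio", params={"refer_audio_path": "ref.wav"}).json())

# Direct synthesis; streaming_mode 0-3 selects the delivery strategy
# (see the tts_handle hunk below for the exact mapping).
resp = requests.post(
    f"{BASE}/tts",
    json={
        "text": "Hello world.",
        "text_lang": "en",
        "ref_audio_path": "ref.wav",           # placeholder path
        "prompt_text": "reference transcript",
        "prompt_lang": "en",
        "streaming_mode": 0,                   # 0 = whole clip in one response
        "media_type": "wav",
    },
)
with open("out.wav", "wb") as f:
    f.write(resp.content)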
GPT_SoVITS/TTS_infer_pack/TTS.py

@@ -468,7 +468,26 @@ class TTS:
         )
 
         self._init_models()
+        self.refresh_runtime_components()
+
+        self.prompt_cache: dict = {
+            "ref_audio_path": None,
+            "prompt_semantic": None,
+            "refer_spec": [],
+            "prompt_text": None,
+            "prompt_lang": None,
+            "phones": None,
+            "bert_features": None,
+            "norm_text": None,
+            "aux_ref_audio_paths": [],
+        }
+
+        self.stop_flag: bool = False
+        self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32
+
+    def refresh_runtime_components(self):
         self.prepare_bert_batch_worker = None
         self.prepare_ref_semantic_batch_worker = None
         if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0":
             self.prepare_bert_batch_worker = PrepareBertBatchWorker(
                 bert_model=self.bert_model,
@@ -509,7 +528,7 @@ class TTS:
                 max_batch_samples=int(ref_max_batch_samples),
             )
 
-        self.text_preprocessor: TextPreprocessor = TextPreprocessor(
+        self.text_preprocessor = TextPreprocessor(
             self.bert_model,
             self.bert_tokenizer,
             self.configs.device,
@@ -517,21 +536,6 @@ class TTS:
             bert_batch_worker=self.prepare_bert_batch_worker,
         )
 
-        self.prompt_cache: dict = {
-            "ref_audio_path": None,
-            "prompt_semantic": None,
-            "refer_spec": [],
-            "prompt_text": None,
-            "prompt_lang": None,
-            "phones": None,
-            "bert_features": None,
-            "norm_text": None,
-            "aux_ref_audio_paths": [],
-        }
-
-        self.stop_flag: bool = False
-        self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32
-
     def _init_models(
         self,
     ):
GPT_SoVITS/TTS_infer_pack/unified_engine.py (new file, 1255 lines)

File diff suppressed because it is too large.
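
The new module's diff is suppressed, but the call sites in api_v2.py and api_v3.py below imply roughly the following surface. This is an inferred sketch reconstructed from those call sites only, not the file's actual contents; signatures and return types are guesses.

# Inferred interface sketch for unified_engine.py (names are from the diff;
# everything else is an assumption).
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class RuntimeControlCallbacks:
    restart: Callable[[], None]
    exit: Callable[[], None]


class UnifiedTTSEngine:
    def __init__(self, tts_pipeline, cut_method_names=None, control_callbacks=None): ...

    # /tts: the result carries .streaming, .audio_generator, .audio_bytes, .media_type
    def run_direct_tts(self, req: dict) -> Any: ...

    # api_v3 scheduler endpoints: results carry .payload, .audio_bytes,
    # .media_type, .headers
    async def run_scheduler_debug(self, request_items: list, max_steps: int, seed: int) -> Any: ...
    async def run_scheduler_submit(self, request: dict) -> Any: ...

    # configuration endpoints return JSON-serializable payloads
    def set_refer_audio(self, refer_audio_path: str) -> dict: ...
    def set_gpt_weights(self, weights_path: str) -> dict: ...
    def set_sovits_weights(self, weights_path: str) -> dict: ...

    def handle_control(self, command: str) -> None: ...
    def get_runtime_state(self) -> dict: ...

    scheduler_worker: Any  # api_v3 aliases this as scheduler_debug_worker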
api_v2.py (99 lines changed)
@@ -123,6 +123,7 @@ from io import BytesIO
 from tools.i18n.i18n import I18nAuto
 from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
 from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
+from GPT_SoVITS.TTS_infer_pack.unified_engine import RuntimeControlCallbacks, UnifiedTTSEngine
 from pydantic import BaseModel
 import threading
 
@@ -147,6 +148,14 @@ if config_path in [None, ""]:
 tts_config = TTS_Config(config_path)
 print(tts_config)
 tts_pipeline = TTS(tts_config)
+tts_engine = UnifiedTTSEngine(
+    tts_pipeline,
+    cut_method_names=cut_method_names,
+    control_callbacks=RuntimeControlCallbacks(
+        restart=lambda: os.execl(sys.executable, sys.executable, *argv),
+        exit=lambda: os.kill(os.getpid(), signal.SIGTERM),
+    ),
+)
 
 APP = FastAPI()
 
@@ -377,70 +386,11 @@ async def tts_handle(req: dict):
         StreamingResponse: audio stream response.
     """
 
-    streaming_mode = req.get("streaming_mode", False)
-    return_fragment = req.get("return_fragment", False)
-    media_type = req.get("media_type", "wav")
-
-    check_res = check_params(req)
-    if check_res is not None:
-        return check_res
-
-    if streaming_mode == 0:
-        streaming_mode = False
-        return_fragment = False
-        fixed_length_chunk = False
-    elif streaming_mode == 1:
-        streaming_mode = False
-        return_fragment = True
-        fixed_length_chunk = False
-    elif streaming_mode == 2:
-        streaming_mode = True
-        return_fragment = False
-        fixed_length_chunk = False
-    elif streaming_mode == 3:
-        streaming_mode = True
-        return_fragment = False
-        fixed_length_chunk = True
-
-    else:
-        return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"})
-
-    req["streaming_mode"] = streaming_mode
-    req["return_fragment"] = return_fragment
-    req["fixed_length_chunk"] = fixed_length_chunk
-
-    print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}")
-
-    streaming_mode = streaming_mode or return_fragment
-
-
     try:
-        tts_generator = tts_pipeline.run(req)
-
-        if streaming_mode:
-
-            def streaming_generator(tts_generator: Generator, media_type: str):
-                if_frist_chunk = True
-                for sr, chunk in tts_generator:
-                    if if_frist_chunk and media_type == "wav":
-                        yield wave_header_chunk(sample_rate=sr)
-                        media_type = "raw"
-                        if_frist_chunk = False
-                    yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
-
-            # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
-            return StreamingResponse(
-                streaming_generator(
-                    tts_generator,
-                    media_type,
-                ),
-                media_type=f"audio/{media_type}",
-            )
-
-        else:
-            sr, audio_data = next(tts_generator)
-            audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
-            return Response(audio_data, media_type=f"audio/{media_type}")
+        result = tts_engine.run_direct_tts(req)
+        if result.streaming:
+            return StreamingResponse(result.audio_generator, media_type=f"audio/{result.media_type}")
+        return Response(result.audio_bytes, media_type=f"audio/{result.media_type}")
     except Exception as e:
        return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)})
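
The deleted block above encoded the streaming_mode contract in the handler; presumably run_direct_tts now applies the same normalization inside the engine. For reference, the mapping the old code enforced, restated:

# streaming_mode contract from the removed handler code (assumed to be
# preserved by UnifiedTTSEngine.run_direct_tts):
#   value -> (streaming_mode, return_fragment, fixed_length_chunk)
STREAMING_MODES = {
    0: (False, False, False),  # whole clip in a single response
    1: (False, True, False),   # fragment-by-fragment responses
    2: (True, False, False),   # chunked streaming
    3: (True, False, True),    # streaming with fixed-length chunks
}
# Any other value is rejected with HTTP 400.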
@@ -449,7 +399,11 @@ async def tts_handle(req: dict):
 async def control(command: str = None):
     if command is None:
         return JSONResponse(status_code=400, content={"message": "command is required"})
-    handle_control(command)
+    try:
+        tts_engine.handle_control(command)
+        return JSONResponse(status_code=200, content={"message": "success"})
+    except Exception as e:
+        return JSONResponse(status_code=400, content={"message": "control failed", "Exception": str(e)})
 
 
 @APP.get("/tts")
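
The /control endpoint now routes through the engine, which owns the process-level callbacks wired up at startup. A hedged sketch of the likely dispatch, consistent with the two callbacks registered above (the pre-refactor handle_control supported exactly these commands):

# Hedged sketch of UnifiedTTSEngine.handle_control; the method exists per
# the diff, but this body is an assumption.
def handle_control(self, command: str) -> None:
    if command == "restart":
        self.control_callbacks.restart()  # os.execl(...) per the startup wiring
    elif command == "exit":
        self.control_callbacks.exit()     # SIGTERM to own pid per the startup wiring
    else:
        raise ValueError(f"unknown control command: {command!r}")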
@@ -517,10 +471,10 @@ async def tts_post_endpoint(request: TTS_Request):
 @APP.get("/set_refer_audio")
 async def set_refer_aduio(refer_audio_path: str = None):
     try:
-        tts_pipeline.set_ref_audio(refer_audio_path)
+        payload = tts_engine.set_refer_audio(refer_audio_path)
     except Exception as e:
         return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)})
-    return JSONResponse(status_code=200, content={"message": "success"})
+    return JSONResponse(status_code=200, content=payload)
 
 
 # @APP.post("/set_refer_audio")
@@ -545,24 +499,19 @@ async def set_refer_aduio(refer_audio_path: str = None):
 @APP.get("/set_gpt_weights")
 async def set_gpt_weights(weights_path: str = None):
     try:
         if weights_path in ["", None]:
             return JSONResponse(status_code=400, content={"message": "gpt weight path is required"})
-        tts_pipeline.init_t2s_weights(weights_path)
+        payload = tts_engine.set_gpt_weights(weights_path)
     except Exception as e:
         return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)})
-
-    return JSONResponse(status_code=200, content={"message": "success"})
+    return JSONResponse(status_code=200, content=payload)
 
 
 @APP.get("/set_sovits_weights")
 async def set_sovits_weights(weights_path: str = None):
     try:
         if weights_path in ["", None]:
             return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})
-        tts_pipeline.init_vits_weights(weights_path)
+        payload = tts_engine.set_sovits_weights(weights_path)
     except Exception as e:
         return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)})
-    return JSONResponse(status_code=200, content={"message": "success"})
+    return JSONResponse(status_code=200, content=payload)
 
 
 if __name__ == "__main__":
api_v3.py (323 lines changed)
@@ -144,6 +144,7 @@ from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import (
     run_prefill_active_batch,
     run_scheduler_continuous,
 )
+from GPT_SoVITS.TTS_infer_pack.unified_engine import RuntimeControlCallbacks, UnifiedTTSEngine
 from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
 from pydantic import BaseModel
 import threading
@@ -169,6 +170,14 @@ if config_path in [None, ""]:
 tts_config = TTS_Config(config_path)
 print(tts_config)
 tts_pipeline = TTS(tts_config)
+tts_engine = UnifiedTTSEngine(
+    tts_pipeline,
+    cut_method_names=cut_method_names,
+    control_callbacks=RuntimeControlCallbacks(
+        restart=lambda: os.execl(sys.executable, sys.executable, *argv),
+        exit=lambda: os.kill(os.getpid(), signal.SIGTERM),
+    ),
+)
 
 APP = FastAPI()
 
@@ -805,7 +814,7 @@ class SchedulerDebugWorker:
             time.sleep(self.micro_batch_wait_s)
 
 
-scheduler_debug_worker = SchedulerDebugWorker(tts_pipeline)
+scheduler_debug_worker = tts_engine.scheduler_worker
 
 
 def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
@@ -1116,20 +1125,12 @@ def build_scheduler_submit_spec(request: Scheduler_Submit_Request) -> SchedulerR
 
 async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request):
     try:
-        set_scheduler_seed(request.seed)
-        specs = build_scheduler_request_specs(request.requests)
-        states = await scheduler_debug_worker.prepare_states_batch_async(specs)
-        finished = run_scheduler_continuous(tts_pipeline.t2s_model.model, states, max_steps=int(request.max_steps))
-        return JSONResponse(
-            status_code=200,
-            content={
-                "message": "success",
-                "request_count": len(states),
-                "max_steps": int(request.max_steps),
-                "requests": summarize_scheduler_states(states),
-                "finished": summarize_scheduler_finished(finished),
-            },
+        result = await tts_engine.run_scheduler_debug(
+            request_items=[item.dict() for item in request.requests],
+            max_steps=int(request.max_steps),
+            seed=int(request.seed),
         )
+        return JSONResponse(status_code=200, content=result.payload)
     except Exception as e:
         return JSONResponse(
             status_code=400,
@@ -1139,206 +1140,8 @@ async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request):
 
 async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
     try:
-        request_start = time.perf_counter()
-        prepare_start = request_start
-        spec = build_scheduler_submit_spec(request)
-        spec_ready_at = time.perf_counter()
-        prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0)
-        state, prepare_exec_started_at, prepare_exec_finished_at = await scheduler_debug_worker.prepare_state_profiled_async(
-            spec,
-            spec_ready_at,
-        )
-        prepare_end = time.perf_counter()
-        prepare_wall_ms = (prepare_end - prepare_start) * 1000.0
-        prepare_profile_total_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms))
-        prepare_profile_wall_ms = float(state.prepare_profile.get("wall_total_ms", prepare_profile_total_ms))
-        prepare_executor_queue_ms = float(
-            state.prepare_profile.get("executor_queue_ms", max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0))
-        )
-        prepare_executor_run_ms = float(
-            state.prepare_profile.get(
-                "executor_run_wall_ms",
-                max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0),
-            )
-        )
-        prepare_other_ms = max(
-            0.0,
-            prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_profile_wall_ms,
-        )
-        loop = asyncio.get_running_loop()
-        done_future = loop.create_future()
-        job = scheduler_debug_worker.submit(
-            state,
-            speed_factor=float(request.speed_factor),
-            sample_steps=int(request.sample_steps),
-            media_type=request.media_type,
-            prepare_wall_ms=prepare_wall_ms,
-            prepare_profile_total_ms=prepare_profile_total_ms,
-            done_loop=loop,
-            done_future=done_future,
-        )
-        api_after_prepare_ms = max(0.0, (job.enqueue_time - prepare_end) * 1000.0)
-        timeout_ok = False
-        try:
-            await asyncio.wait_for(asyncio.shield(done_future), timeout=float(request.timeout_sec))
-            timeout_ok = True
-        except asyncio.TimeoutError:
-            timeout_ok = False
-        wait_return_at = time.perf_counter()
-        if not timeout_ok:
-            return JSONResponse(
-                status_code=202,
-                content={
-                    "message": "queued",
-                    "request_id": job.request_id,
-                    "timings": {
-                        "prepare_ms": prepare_wall_ms,
-                        "prepare_wall_ms": prepare_wall_ms,
-                        "prepare_profile_total_ms": prepare_profile_total_ms,
-                        "api_after_prepare_ms": api_after_prepare_ms,
-                        "request_elapsed_ms": max(0.0, (time.perf_counter() - request_start) * 1000.0),
-                    },
-                    "worker_state": scheduler_debug_worker.get_state(),
-                },
-            )
-        if job.error is not None:
-            return JSONResponse(
-                status_code=400,
-                content={"message": "scheduler submit failed", "request_id": job.request_id, "Exception": job.error},
-            )
-        if job.audio_data is None or job.sample_rate is None:
-            return JSONResponse(
-                status_code=500,
-                content={
-                    "message": "scheduler submit failed",
-                    "request_id": job.request_id,
-                    "Exception": "job finished without audio payload",
-                },
-            )
-        pack_start = time.perf_counter()
-        audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue()
-        pack_end = time.perf_counter()
-        pack_ms = (pack_end - pack_start) * 1000.0
-        job.pack_ms = pack_ms
-        api_wait_result_ms = 0.0
-        if job.result_ready_time is not None:
-            api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0)
-        worker_total_ms = float(job.result["worker_total_ms"]) if job.result is not None else 0.0
-        headers = {
-            "X-Request-Id": job.request_id,
-            "X-Semantic-Len": str(job.result["semantic_len"]) if job.result is not None else "0",
-            "X-Finish-Reason": job.result["finish_reason"] if job.result is not None else "unknown",
-            "X-Queue-Wait-Ms": (
-                f"{float(job.result['queue_wait_ms']):.3f}" if job.result is not None else "0.000"
-            ),
-            "X-Prepare-Ms": f"{prepare_wall_ms:.3f}",
-            "X-Prepare-Wall-Ms": f"{prepare_wall_ms:.3f}",
-            "X-Prepare-Spec-Build-Ms": f"{prepare_spec_build_ms:.3f}",
-            "X-Prepare-Executor-Queue-Ms": f"{prepare_executor_queue_ms:.3f}",
-            "X-Prepare-Admission-Wait-Ms": (
-                f"{float(job.result['prepare_profile'].get('prepare_admission_wait_ms', 0.0)):.3f}"
-                if job.result is not None
-                else "0.000"
-            ),
-            "X-Prepare-Executor-Run-Ms": f"{prepare_executor_run_ms:.3f}",
-            "X-Prepare-Profile-Total-Ms": f"{prepare_profile_total_ms:.3f}",
-            "X-Prepare-Profile-Wall-Ms": f"{prepare_profile_wall_ms:.3f}",
-            "X-Prepare-Other-Ms": f"{prepare_other_ms:.3f}",
-            "X-Api-After-Prepare-Ms": f"{api_after_prepare_ms:.3f}",
-            "X-Prefill-Ms": f"{float(job.result['prefill_ms']):.3f}" if job.result is not None else "0.000",
-            "X-Merge-Ms": f"{float(job.result['merge_ms']):.3f}" if job.result is not None else "0.000",
-            "X-Decode-Ms": f"{float(job.result['decode_ms']):.3f}" if job.result is not None else "0.000",
-            "X-Finalize-Wait-Ms": f"{float(job.result['finalize_wait_ms']):.3f}" if job.result is not None else "0.000",
-            "X-Synth-Ms": f"{float(job.result['synth_ms']):.3f}" if job.result is not None else "0.000",
-            "X-Worker-Residual-Ms": f"{float(job.result['worker_residual_ms']):.3f}" if job.result is not None else "0.000",
-            "X-Worker-Other-Ms": f"{float(job.result['worker_other_ms']):.3f}" if job.result is not None else "0.000",
-            "X-Pack-Ms": f"{pack_ms:.3f}",
-            "X-Worker-Total-Ms": (
-                f"{float(job.result['worker_total_ms']):.3f}" if job.result is not None else "0.000"
-            ),
-            "X-Api-Wait-Result-Ms": f"{api_wait_result_ms:.3f}",
-            "X-Decode-Steps": str(job.result["decode_steps"]) if job.result is not None else "0",
-        }
-        if job.result is not None:
-            prepare_profile = job.result.get("prepare_profile", {})
-            headers.update(
-                {
-                    "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}",
-                    "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}",
-                    "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('prompt_text_cpu_preprocess_ms', 0.0)):.3f}",
-                    "X-Prepare-Target-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('text_cpu_preprocess_ms', 0.0)):.3f}",
-                    "X-Prepare-Prompt-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('prompt_text_cpu_queue_ms', 0.0)):.3f}",
-                    "X-Prepare-Target-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('text_cpu_queue_ms', 0.0)):.3f}",
-                    "X-Prepare-Prompt-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('prompt_text_feature_queue_ms', 0.0)):.3f}",
-                    "X-Prepare-Target-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('text_feature_queue_ms', 0.0)):.3f}",
-                    "X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}",
-                    "X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}",
-                    "X-Prepare-Prompt-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_admission_wait_ms', 0.0)):.3f}",
-                    "X-Prepare-Target-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('text_bert_admission_wait_ms', 0.0)):.3f}",
-                    "X-Prepare-Prompt-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_queue_wait_ms', 0.0)):.3f}",
-                    "X-Prepare-Target-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('text_bert_queue_wait_ms', 0.0)):.3f}",
-                    "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_collect_wait_ms', 0.0)):.3f}",
-                    "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('text_bert_batch_collect_wait_ms', 0.0)):.3f}",
-                    "X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}",
-                    "X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}",
-                    "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(
-                        int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))
-                    ),
-                    "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(
-                        int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))
-                    ),
-                    "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(
-                        int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))
-                    ),
-                    "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(
-                        int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))
-                    ),
-                    "X-Prepare-Prompt-Bert-High-Pressure-Peak": str(
-                        int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))
-                    ),
-                    "X-Prepare-Target-Bert-High-Pressure-Peak": str(
-                        int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))
-                    ),
-                    "X-Prepare-Prompt-Bert-Batch-Size-Peak": str(
-                        int(prepare_profile.get("prompt_text_bert_batch_size_peak", 0.0))
-                    ),
-                    "X-Prepare-Target-Bert-Batch-Size-Peak": str(
                        int(prepare_profile.get("text_bert_batch_size_peak", 0.0))
-                    ),
-                    "X-Prepare-Prompt-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_window_ms', 0.0)):.3f}",
-                    "X-Prepare-Target-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('text_bert_batch_window_ms', 0.0)):.3f}",
-                    "X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}",
-                    "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))),
-                    "X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}",
-                    "X-Prepare-Audio-Stage-Wait-Ms": f"{float(prepare_profile.get('audio_stage_wait_ms', 0.0)):.3f}",
-                    "X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}",
-                    "X-Prepare-Prompt-Semantic-Wait-Ms": f"{float(prepare_profile.get('prompt_semantic_wait_ms', 0.0)):.3f}",
-                    "X-Prepare-Prompt-Semantic-CPU-Ms": f"{float(prepare_profile.get('prompt_semantic_cpu_prepare_ms', 0.0)):.3f}",
-                    "X-Prepare-Prompt-Semantic-Forward-Ms": f"{float(prepare_profile.get('prompt_semantic_forward_ms', 0.0)):.3f}",
-                    "X-Prepare-Prompt-Semantic-Batch-Size": str(
-                        int(prepare_profile.get("prompt_semantic_batch_size", 0.0))
-                    ),
-                    "X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}",
-                    "X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}",
-                    "X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}",
-                    "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}",
-                    "X-Prepare-Inflight-On-Enter": str(
-                        int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0))
-                    ),
-                    "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))),
-                }
-            )
-        response_ready_at = time.perf_counter()
-        response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0)
-        request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0)
-        request_other_ms = max(
-            0.0,
-            request_total_ms - prepare_wall_ms - api_after_prepare_ms - worker_total_ms - api_wait_result_ms - pack_ms,
-        )
-        headers["X-Response-Overhead-Ms"] = f"{response_overhead_ms:.3f}"
-        headers["X-Request-Other-Ms"] = f"{request_other_ms:.3f}"
-        headers["X-Request-Total-Ms"] = f"{request_total_ms:.3f}"
-        return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers)
+        result = await tts_engine.run_scheduler_submit(request.dict())
+        return Response(result.audio_bytes, media_type=result.media_type, headers=result.headers)
     except Exception as e:
         return JSONResponse(
             status_code=400,
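
All of the per-request profiling removed above now lives behind run_scheduler_submit, which hands back the X-*-Ms timing headers pre-assembled. A client can still read them; a minimal sketch, assuming the /tts_scheduler_submit route and illustrative request fields:

# Hedged client sketch: route path and body fields are assumptions based on
# this file's endpoint naming; the header names are from the removed code.
import requests

resp = requests.post(
    "http://127.0.0.1:9880/tts_scheduler_submit",
    json={"text": "Hello.", "text_lang": "en", "media_type": "wav"},
)
for name in ("X-Request-Id", "X-Prepare-Ms", "X-Decode-Ms", "X-Worker-Total-Ms", "X-Request-Total-Ms"):
    print(name, resp.headers.get(name))
with open("out.wav", "wb") as f:
    f.write(resp.content)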
@@ -1381,70 +1184,11 @@ async def tts_handle(req: dict):
         StreamingResponse: audio stream response.
     """
 
-    streaming_mode = req.get("streaming_mode", False)
-    return_fragment = req.get("return_fragment", False)
-    media_type = req.get("media_type", "wav")
-
-    check_res = check_params(req)
-    if check_res is not None:
-        return check_res
-
-    if streaming_mode == 0:
-        streaming_mode = False
-        return_fragment = False
-        fixed_length_chunk = False
-    elif streaming_mode == 1:
-        streaming_mode = False
-        return_fragment = True
-        fixed_length_chunk = False
-    elif streaming_mode == 2:
-        streaming_mode = True
-        return_fragment = False
-        fixed_length_chunk = False
-    elif streaming_mode == 3:
-        streaming_mode = True
-        return_fragment = False
-        fixed_length_chunk = True
-
-    else:
-        return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"})
-
-    req["streaming_mode"] = streaming_mode
-    req["return_fragment"] = return_fragment
-    req["fixed_length_chunk"] = fixed_length_chunk
-
-    print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}")
-
-    streaming_mode = streaming_mode or return_fragment
-
-
     try:
-        tts_generator = tts_pipeline.run(req)
-
-        if streaming_mode:
-
-            def streaming_generator(tts_generator: Generator, media_type: str):
-                if_frist_chunk = True
-                for sr, chunk in tts_generator:
-                    if if_frist_chunk and media_type == "wav":
-                        yield wave_header_chunk(sample_rate=sr)
-                        media_type = "raw"
-                        if_frist_chunk = False
-                    yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
-
-            # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
-            return StreamingResponse(
-                streaming_generator(
-                    tts_generator,
-                    media_type,
-                ),
-                media_type=f"audio/{media_type}",
-            )
-
-        else:
-            sr, audio_data = next(tts_generator)
-            audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
-            return Response(audio_data, media_type=f"audio/{media_type}")
+        result = tts_engine.run_direct_tts(req)
+        if result.streaming:
+            return StreamingResponse(result.audio_generator, media_type=f"audio/{result.media_type}")
+        return Response(result.audio_bytes, media_type=f"audio/{result.media_type}")
     except Exception as e:
         return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)})
@@ -1453,7 +1197,11 @@ async def tts_handle(req: dict):
 async def control(command: str = None):
     if command is None:
         return JSONResponse(status_code=400, content={"message": "command is required"})
-    handle_control(command)
+    try:
+        tts_engine.handle_control(command)
+        return JSONResponse(status_code=200, content={"message": "success"})
+    except Exception as e:
+        return JSONResponse(status_code=400, content={"message": "control failed", "Exception": str(e)})
 
 
 @APP.get("/tts")
@@ -1530,16 +1278,16 @@ async def tts_scheduler_submit_endpoint(request: Scheduler_Submit_Request):
 
 @APP.get("/tts_scheduler_state")
 async def tts_scheduler_state_endpoint():
-    return JSONResponse(status_code=200, content={"message": "success", "worker_state": scheduler_debug_worker.get_state()})
+    return JSONResponse(status_code=200, content=tts_engine.get_runtime_state())
 
 
 @APP.get("/set_refer_audio")
 async def set_refer_aduio(refer_audio_path: str = None):
     try:
-        tts_pipeline.set_ref_audio(refer_audio_path)
+        payload = tts_engine.set_refer_audio(refer_audio_path)
     except Exception as e:
         return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)})
-    return JSONResponse(status_code=200, content={"message": "success"})
+    return JSONResponse(status_code=200, content=payload)
 
 
 # @APP.post("/set_refer_audio")
@@ -1564,24 +1312,19 @@ async def set_refer_aduio(refer_audio_path: str = None):
 @APP.get("/set_gpt_weights")
 async def set_gpt_weights(weights_path: str = None):
     try:
         if weights_path in ["", None]:
             return JSONResponse(status_code=400, content={"message": "gpt weight path is required"})
-        tts_pipeline.init_t2s_weights(weights_path)
+        payload = tts_engine.set_gpt_weights(weights_path)
     except Exception as e:
         return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)})
-
-    return JSONResponse(status_code=200, content={"message": "success"})
+    return JSONResponse(status_code=200, content=payload)
 
 
 @APP.get("/set_sovits_weights")
 async def set_sovits_weights(weights_path: str = None):
     try:
         if weights_path in ["", None]:
             return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})
-        tts_pipeline.init_vits_weights(weights_path)
+        payload = tts_engine.set_sovits_weights(weights_path)
     except Exception as e:
         return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)})
-    return JSONResponse(status_code=200, content={"message": "success"})
+    return JSONResponse(status_code=200, content=payload)
 
 
 if __name__ == "__main__":