diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index e1efd973..bd4953df 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -759,21 +759,35 @@ class TTS: self._set_ref_spec(ref_audio_path) self._set_ref_audio_path(ref_audio_path) - def _set_ref_audio_path(self, ref_audio_path): - self.prompt_cache["ref_audio_path"] = ref_audio_path + def extract_prompt_semantic(self, ref_wav_path: str): + zero_wav = np.zeros( + int(self.configs.sampling_rate * 0.3), + dtype=np.float16 if self.configs.is_half else np.float32, + ) + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: + raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + wav16k = wav16k.to(self.configs.device) + zero_wav_torch = zero_wav_torch.to(self.configs.device) + if self.configs.is_half: + wav16k = wav16k.half() + zero_wav_torch = zero_wav_torch.half() - def _set_ref_spec(self, ref_audio_path): - spec_audio = self._get_ref_spec(ref_audio_path) - if self.prompt_cache["refer_spec"] in [[], None]: - self.prompt_cache["refer_spec"] = [spec_audio] - else: - self.prompt_cache["refer_spec"][0] = spec_audio + wav16k = torch.cat([wav16k, zero_wav_torch]) + hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose( + 1, 2 + ) # .float() + codes = self.vits_model.extract_latent(hubert_feature) - def _get_ref_spec(self, ref_audio_path): + prompt_semantic = codes[0, 0].to(self.configs.device) + return prompt_semantic + + def extract_ref_spec(self, ref_audio_path: str): raw_audio, raw_sr = torchaudio.load(ref_audio_path) raw_audio = raw_audio.to(self.configs.device).float() - self.prompt_cache["raw_audio"] = raw_audio - self.prompt_cache["raw_sr"] = raw_sr if raw_sr != self.configs.sampling_rate: audio = raw_audio.to(self.configs.device) @@ -804,33 +818,30 @@ class TTS: audio = audio.half() else: audio = None + return spec, audio, raw_audio, raw_sr + + def extract_text_features(self, text: str, language: str): + return self.text_preprocessor.segment_and_extract_feature_for_text(text, language, self.configs.version) + + def _set_ref_audio_path(self, ref_audio_path): + self.prompt_cache["ref_audio_path"] = ref_audio_path + + def _set_ref_spec(self, ref_audio_path): + spec_audio = self._get_ref_spec(ref_audio_path) + if self.prompt_cache["refer_spec"] in [[], None]: + self.prompt_cache["refer_spec"] = [spec_audio] + else: + self.prompt_cache["refer_spec"][0] = spec_audio + + def _get_ref_spec(self, ref_audio_path): + spec, audio, raw_audio, raw_sr = self.extract_ref_spec(ref_audio_path) + self.prompt_cache["raw_audio"] = raw_audio + self.prompt_cache["raw_sr"] = raw_sr return spec, audio def _set_prompt_semantic(self, ref_wav_path: str): - zero_wav = np.zeros( - int(self.configs.sampling_rate * 0.3), - dtype=np.float16 if self.configs.is_half else np.float32, - ) - with torch.no_grad(): - wav16k, sr = librosa.load(ref_wav_path, sr=16000) - if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: - raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) - wav16k = torch.from_numpy(wav16k) - zero_wav_torch = torch.from_numpy(zero_wav) - wav16k = wav16k.to(self.configs.device) - zero_wav_torch = zero_wav_torch.to(self.configs.device) - if self.configs.is_half: - wav16k = wav16k.half() - zero_wav_torch = zero_wav_torch.half() - - wav16k = torch.cat([wav16k, zero_wav_torch]) - hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose( - 1, 2 - ) # .float() - codes = self.vits_model.extract_latent(hubert_feature) - - prompt_semantic = codes[0, 0].to(self.configs.device) - self.prompt_cache["prompt_semantic"] = prompt_semantic + prompt_semantic = self.extract_prompt_semantic(ref_wav_path) + self.prompt_cache["prompt_semantic"] = prompt_semantic def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None): seq = sequences[0] @@ -1701,6 +1712,115 @@ class TTS: return audio + def using_vocoder_synthesis_request_local( + self, + semantic_tokens: torch.Tensor, + phones: torch.Tensor, + prompt_semantic: torch.Tensor, + prompt_phones: torch.Tensor, + refer_audio_spec: torch.Tensor, + raw_audio: torch.Tensor, + raw_sr: int, + speed: float = 1.0, + sample_steps: int = 32, + ): + prompt_semantic_tokens = prompt_semantic.unsqueeze(0).unsqueeze(0).to(self.configs.device) + prompt_phones = prompt_phones.unsqueeze(0).to(self.configs.device) + refer_audio_spec = refer_audio_spec.to(dtype=self.precision, device=self.configs.device) + + fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec) + ref_audio = raw_audio.to(self.configs.device).float() + if ref_audio.shape[0] == 2: + ref_audio = ref_audio.mean(0).unsqueeze(0) + + tgt_sr = 24000 if self.configs.version == "v3" else 32000 + if raw_sr != tgt_sr: + ref_audio = resample(ref_audio, raw_sr, tgt_sr, self.configs.device) + + mel_spec_fn = mel_fn if self.configs.version == "v3" else mel_fn_v4 + mel2 = mel_spec_fn(ref_audio) + mel2 = norm_spec(mel2) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + T_ref = self.vocoder_configs["T_ref"] + T_chunk = self.vocoder_configs["T_chunk"] + if T_min > T_ref: + mel2 = mel2[:, :, -T_ref:] + fea_ref = fea_ref[:, :, -T_ref:] + T_min = T_ref + chunk_len = T_chunk - T_min + + mel2 = mel2.to(self.precision) + fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed) + + cfm_resss = [] + idx = 0 + while 1: + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + idx += chunk_len + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + + cfm_res = self.vits_model.cfm.inference( + fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0 + ) + cfm_res = cfm_res[:, :, mel2.shape[2] :] + + mel2 = cfm_res[:, :, -T_min:] + fea_ref = fea_todo_chunk[:, :, -T_min:] + + cfm_resss.append(cfm_res) + cfm_res = torch.cat(cfm_resss, 2) + cfm_res = denorm_spec(cfm_res) + + with torch.inference_mode(): + wav_gen = self.vocoder(cfm_res) + audio = wav_gen[0][0] + + return audio + + def synthesize_audio_request_local( + self, + semantic_tokens: torch.Tensor, + phones: torch.Tensor, + prompt_semantic: torch.Tensor, + prompt_phones: torch.Tensor, + refer_spec: tuple, + raw_audio: torch.Tensor, + raw_sr: int, + speed: float = 1.0, + sample_steps: int = 32, + ): + refer_audio_spec, audio_tensor = refer_spec + if not self.configs.use_vocoder: + refer_audio_spec_list = [refer_audio_spec.to(dtype=self.precision, device=self.configs.device)] + sv_emb = None + if self.is_v2pro: + if audio_tensor is None: + raise ValueError(i18n("v2Pro request-local synthesis 缺少 16k 参考音频")) + sv_emb = self.sv_model.compute_embedding3(audio_tensor).to(self.configs.device) + return self.vits_model.decode( + semantic_tokens, + phones, + refer_audio_spec_list, + speed=speed, + sv_emb=sv_emb, + ).detach()[0, 0, :] + + return self.using_vocoder_synthesis_request_local( + semantic_tokens=semantic_tokens, + phones=phones, + prompt_semantic=prompt_semantic, + prompt_phones=prompt_phones, + refer_audio_spec=refer_audio_spec, + raw_audio=raw_audio, + raw_sr=raw_sr, + speed=speed, + sample_steps=sample_steps, + ) + def using_vocoder_synthesis_batched_infer( self, idx_list: List[int], diff --git a/GPT_SoVITS/TTS_infer_pack/__init__.py b/GPT_SoVITS/TTS_infer_pack/__init__.py index 8579a632..09a257b2 100644 --- a/GPT_SoVITS/TTS_infer_pack/__init__.py +++ b/GPT_SoVITS/TTS_infer_pack/__init__.py @@ -1 +1,11 @@ -from . import TTS, text_segmentation_method +from __future__ import annotations + +import importlib + +__all__ = ["TTS", "TextPreprocessor", "text_segmentation_method", "t2s_scheduler"] + + +def __getattr__(name: str): + if name in __all__: + return importlib.import_module(f"{__name__}.{name}") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py new file mode 100644 index 00000000..e94a72c7 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -0,0 +1,631 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +import time +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import torch +import torch.nn.functional as F + +from AR.models.utils import make_pad_mask_left, sample + + +def _sync_device(device: Any) -> None: + try: + device_str = str(device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + +@dataclass +class SchedulerRequestSpec: + request_id: str + ref_audio_path: Path + prompt_text: str + prompt_lang: str + text: str + text_lang: str + top_k: int + top_p: float + temperature: float + repetition_penalty: float + early_stop_num: int + ready_step: int = 0 + + +@dataclass +class T2SRequestState: + request_id: str + ref_audio_path: Path + prompt_text: str + prompt_lang: str + text: str + text_lang: str + norm_prompt_text: str + norm_text: str + phones: torch.LongTensor + prompt_phones: torch.LongTensor + all_phones: torch.LongTensor + all_bert_features: torch.Tensor + prompt_semantic: torch.LongTensor + refer_spec: Tuple[torch.Tensor, Optional[torch.Tensor]] + raw_audio: torch.Tensor + raw_sr: int + top_k: int + top_p: float + temperature: float + repetition_penalty: float + early_stop_num: int + ready_step: int + prepare_profile: Dict[str, float] + + +@dataclass +class T2SRunningRequest: + state: T2SRequestState + y_sequence: torch.LongTensor + prefix_len: int + decode_attn_mask: torch.Tensor + k_cache: List[torch.Tensor] + v_cache: List[torch.Tensor] + step_idx: int + + +@dataclass +class T2SFinishedItem: + request_id: str + semantic_tokens: torch.LongTensor + finish_idx: int + finish_reason: str + + +@dataclass +class T2SActiveBatch: + request_ids: List[str] + states: List[T2SRequestState] + x: torch.Tensor + x_lens: torch.LongTensor + y_sequences: List[torch.LongTensor] + prefix_lens: torch.LongTensor + xy_pos: torch.Tensor + prefill_attn_mask: torch.Tensor + decode_attn_mask: Optional[torch.Tensor] + k_cache: Optional[List[torch.Tensor]] + v_cache: Optional[List[torch.Tensor]] + step_idx: int + prefill_done: bool + + +def normalize_sentence(text: str, language: str) -> str: + text = text.strip("\n").strip() + if not text: + return text + if text[-1] not in {",", ".", "?", "!", ",", "。", "?", "!", "…", ";", ";", ":"}: + text += "。" if language != "en" else "." + return text + + +def prepare_request_state( + tts: Any, + spec: SchedulerRequestSpec, +) -> T2SRequestState: + device = tts.configs.device + prepare_start = time.perf_counter() + _sync_device(device) + prepare_sync_start = time.perf_counter() + prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) + text = spec.text.strip("\n") + + _sync_device(device) + prompt_text_features_start = time.perf_counter() + prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features(prompt_text, spec.prompt_lang) + _sync_device(device) + prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0 + + _sync_device(device) + text_features_start = time.perf_counter() + phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang) + _sync_device(device) + text_features_ms = (time.perf_counter() - text_features_start) * 1000.0 + if phones is None: + raise ValueError(f"{spec.request_id} text preprocessing returned no phones") + + _sync_device(device) + prompt_semantic_start = time.perf_counter() + prompt_semantic = tts.extract_prompt_semantic(str(spec.ref_audio_path)).long() + _sync_device(device) + prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 + + _sync_device(device) + ref_spec_start = time.perf_counter() + spec_audio, audio_16k, raw_audio, raw_sr = tts.extract_ref_spec(str(spec.ref_audio_path)) + _sync_device(device) + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + + _sync_device(device) + tensorize_start = time.perf_counter() + phones_tensor = torch.LongTensor(phones).to(tts.configs.device) + prompt_phones_tensor = torch.LongTensor(prompt_phones).to(tts.configs.device) + all_phones = torch.LongTensor(prompt_phones + phones).to(tts.configs.device) + all_bert_features = torch.cat([prompt_bert_features, bert_features], dim=1).to( + dtype=tts.precision, device=tts.configs.device + ) + _sync_device(device) + tensorize_ms = (time.perf_counter() - tensorize_start) * 1000.0 + + _sync_device(device) + prepare_profile = { + "prompt_text_features_ms": prompt_text_features_ms, + "text_features_ms": text_features_ms, + "prompt_semantic_ms": prompt_semantic_ms, + "ref_spec_ms": ref_spec_ms, + "tensorize_ms": tensorize_ms, + "total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0, + "wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0, + } + return T2SRequestState( + request_id=spec.request_id, + ref_audio_path=spec.ref_audio_path, + prompt_text=prompt_text, + prompt_lang=spec.prompt_lang, + text=text, + text_lang=spec.text_lang, + norm_prompt_text=prompt_norm_text, + norm_text=norm_text, + phones=phones_tensor, + prompt_phones=prompt_phones_tensor, + all_phones=all_phones, + all_bert_features=all_bert_features, + prompt_semantic=prompt_semantic, + refer_spec=(spec_audio, audio_16k), + raw_audio=raw_audio, + raw_sr=int(raw_sr), + top_k=spec.top_k, + top_p=spec.top_p, + temperature=spec.temperature, + repetition_penalty=spec.repetition_penalty, + early_stop_num=spec.early_stop_num, + ready_step=spec.ready_step, + prepare_profile=prepare_profile, + ) + + +def _left_pad_hidden(hidden: torch.Tensor, target_len: int) -> torch.Tensor: + if hidden.shape[0] >= target_len: + return hidden + return F.pad(hidden, (0, 0, target_len - hidden.shape[0], 0), value=0) + + +def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device: torch.device) -> None: + required_len = max_position + 1 + if model.ar_audio_position.pe is not None and model.ar_audio_position.pe.size(1) >= required_len: + if model.ar_audio_position.pe.dtype != dtype or model.ar_audio_position.pe.device != device: + model.ar_audio_position.pe = model.ar_audio_position.pe.to(dtype=dtype, device=device) + return + model.ar_audio_position.extend_pe( + torch.zeros(1, required_len, model.ar_audio_position.embedding_dim, device=device, dtype=dtype) + ) + + +def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch: + x_items: List[torch.Tensor] = [] + y_pos_items: List[torch.Tensor] = [] + x_lens: List[int] = [] + prefix_lens: List[int] = [] + y_sequences: List[torch.LongTensor] = [] + + for state in states: + text_emb = model.ar_text_embedding(state.all_phones.unsqueeze(0)) + bert_proj = model.bert_proj(state.all_bert_features.transpose(0, 1).unsqueeze(0)) + x_pos = model.ar_text_position(text_emb + bert_proj).squeeze(0) + y_emb = model.ar_audio_embedding(state.prompt_semantic.unsqueeze(0)) + y_pos = model.ar_audio_position(y_emb).squeeze(0) + x_items.append(x_pos) + y_pos_items.append(y_pos) + x_lens.append(x_pos.shape[0]) + prefix_lens.append(y_pos.shape[0]) + y_sequences.append(state.prompt_semantic.clone()) + + max_x_len = max(x_lens) + max_prefix_len = max(prefix_lens) + x_batch = torch.stack([_left_pad_hidden(item, max_x_len) for item in x_items], dim=0) + y_pos_batch = torch.stack([_left_pad_hidden(item, max_prefix_len) for item in y_pos_items], dim=0) + xy_pos = torch.cat([x_batch, y_pos_batch], dim=1) + + device = x_batch.device + x_lens_tensor = torch.LongTensor(x_lens).to(device) + prefix_lens_tensor = torch.LongTensor(prefix_lens).to(device) + batch_size = len(states) + src_len = max_x_len + max_prefix_len + + x_padding_mask = make_pad_mask_left(x_lens_tensor, max_x_len) + y_padding_mask = make_pad_mask_left(prefix_lens_tensor, max_prefix_len) + padding_mask = torch.cat([x_padding_mask, y_padding_mask], dim=1) + x_mask = F.pad(torch.zeros(max_x_len, max_x_len, dtype=torch.bool, device=device), (0, max_prefix_len), value=True) + y_mask = F.pad( + torch.triu(torch.ones(max_prefix_len, max_prefix_len, dtype=torch.bool, device=device), diagonal=1), + (max_x_len, 0), + value=False, + ) + causal_mask = torch.cat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(batch_size, 1, 1) + padding_mask = padding_mask.view(batch_size, 1, src_len).repeat(1, src_len, 1) + attn_mask = causal_mask.logical_or(padding_mask) + attn_mask = attn_mask.unsqueeze(1).expand(-1, model.num_head, -1, -1).bool() + + return T2SActiveBatch( + request_ids=[state.request_id for state in states], + states=list(states), + x=x_batch, + x_lens=x_lens_tensor, + y_sequences=y_sequences, + prefix_lens=prefix_lens_tensor, + xy_pos=xy_pos, + prefill_attn_mask=attn_mask, + decode_attn_mask=None, + k_cache=None, + v_cache=None, + step_idx=0, + prefill_done=False, + ) + + +def build_next_xy_pos(model: Any, y_sequences: Sequence[torch.LongTensor]) -> torch.Tensor: + last_tokens = torch.stack([seq[-1:] for seq in y_sequences], dim=0) + y_emb = model.ar_audio_embedding(last_tokens) + position_ids = torch.LongTensor([int(seq.shape[0] - 1) for seq in y_sequences]).to(y_emb.device) + _ensure_audio_pe(model, int(position_ids.max().item()), y_emb.dtype, y_emb.device) + pos_emb = model.ar_audio_position.pe[0].index_select(0, position_ids).unsqueeze(1) + return y_emb * model.ar_audio_position.x_scale + model.ar_audio_position.alpha * pos_emb.to( + dtype=y_emb.dtype, device=y_emb.device + ) + + +def _sample_per_request( + model: Any, + active_batch: T2SActiveBatch, + logits: torch.Tensor, + max_steps: int, +) -> Tuple[List[T2SFinishedItem], List[int], List[torch.LongTensor]]: + finished_items: List[T2SFinishedItem] = [] + keep_indices: List[int] = [] + updated_sequences: List[torch.LongTensor] = [] + + step_idx = active_batch.step_idx + for batch_index, state in enumerate(active_batch.states): + logits_i = logits[batch_index : batch_index + 1].clone() + current_history = active_batch.y_sequences[batch_index] + sampled = sample( + logits_i, + current_history.unsqueeze(0), + top_k=state.top_k, + top_p=state.top_p, + repetition_penalty=state.repetition_penalty, + temperature=state.temperature, + )[0] + sampled_token = int(sampled[0, 0].item()) + argmax_token = int(torch.argmax(logits[batch_index], dim=-1).item()) + new_history = torch.cat([current_history, sampled.view(-1)], dim=0) + + finish_reason: Optional[str] = None + if state.early_stop_num != -1 and (new_history.shape[0] - int(active_batch.prefix_lens[batch_index].item())) > state.early_stop_num: + finish_reason = "early_stop" + elif step_idx + 1 >= max_steps: + finish_reason = "max_step" + elif sampled_token == model.EOS: + finish_reason = "eos_sample" + elif argmax_token == model.EOS: + finish_reason = "eos_argmax" + + if finish_reason is not None: + finished_items.append( + T2SFinishedItem( + request_id=state.request_id, + semantic_tokens=new_history[:-1].clone(), + finish_idx=step_idx, + finish_reason=finish_reason, + ) + ) + else: + keep_indices.append(batch_index) + updated_sequences.append(new_history) + + return finished_items, keep_indices, updated_sequences + + +def decode_one_step( + model: Any, + active_batch: T2SActiveBatch, + max_steps: int, +) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]: + if not active_batch.prefill_done: + xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.process_prompt( + active_batch.xy_pos, active_batch.prefill_attn_mask, None + ) + active_batch.decode_attn_mask = F.pad( + active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2), + (0, 1), + value=False, + ) + active_batch.prefill_done = True + else: + xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.decode_next_token( + active_batch.xy_pos, + active_batch.k_cache, + active_batch.v_cache, + active_batch.decode_attn_mask, + ) + if active_batch.decode_attn_mask is not None: + active_batch.decode_attn_mask = F.pad(active_batch.decode_attn_mask, (0, 1), value=False) + + logits = model.ar_predict_layer(xy_dec[:, -1]) + if active_batch.step_idx < 11: + logits = logits[:, :-1] + + finished_items, keep_indices, updated_sequences = _sample_per_request(model, active_batch, logits, max_steps=max_steps) + if len(keep_indices) == 0: + return None, finished_items + + device = logits.device + keep_tensor = torch.LongTensor(keep_indices).to(device) + active_batch.request_ids = [active_batch.request_ids[i] for i in keep_indices] + active_batch.states = [active_batch.states[i] for i in keep_indices] + active_batch.y_sequences = updated_sequences + active_batch.prefix_lens = torch.index_select(active_batch.prefix_lens, dim=0, index=keep_tensor) + + if active_batch.decode_attn_mask is not None: + active_batch.decode_attn_mask = torch.index_select(active_batch.decode_attn_mask, dim=0, index=keep_tensor) + if active_batch.k_cache is not None and active_batch.v_cache is not None: + for cache_index in range(len(active_batch.k_cache)): + active_batch.k_cache[cache_index] = torch.index_select(active_batch.k_cache[cache_index], dim=0, index=keep_tensor) + active_batch.v_cache[cache_index] = torch.index_select(active_batch.v_cache[cache_index], dim=0, index=keep_tensor) + + active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences) + active_batch.step_idx += 1 + return active_batch, finished_items + + +def run_scheduler_batch( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> List[T2SFinishedItem]: + return run_scheduler_continuous(model, states, max_steps=max_steps) + + +def _pad_cache_left(cache: torch.Tensor, target_len: int) -> torch.Tensor: + pad_len = target_len - cache.shape[1] + if pad_len <= 0: + return cache + return F.pad(cache, (0, 0, pad_len, 0), value=0) + + +def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor: + pad_len = target_len - mask.shape[-1] + if pad_len <= 0: + return mask + return F.pad(mask, (pad_len, 0), value=True) + + +def run_prefill_step( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> Tuple[List[T2SRunningRequest], List[T2SFinishedItem]]: + if not states: + return [], [] + + active_batch = build_prefill_batch(model, states) + xy_dec, k_cache, v_cache = model.t2s_transformer.process_prompt(active_batch.xy_pos, active_batch.prefill_attn_mask, None) + decode_attn_mask = F.pad( + active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2), + (0, 1), + value=False, + ) + logits = model.ar_predict_layer(xy_dec[:, -1]) + + running_requests: List[T2SRunningRequest] = [] + finished_items: List[T2SFinishedItem] = [] + + for batch_index, state in enumerate(states): + logits_i = logits[batch_index : batch_index + 1].clone() + if 0 < 11: + logits_i = logits_i[:, :-1] + current_history = active_batch.y_sequences[batch_index] + sampled = sample( + logits_i, + current_history.unsqueeze(0), + top_k=state.top_k, + top_p=state.top_p, + repetition_penalty=state.repetition_penalty, + temperature=state.temperature, + )[0] + sampled_token = int(sampled[0, 0].item()) + argmax_token = int(torch.argmax(logits_i[0], dim=-1).item()) + new_history = torch.cat([current_history, sampled.view(-1)], dim=0) + prefix_len = int(active_batch.prefix_lens[batch_index].item()) + + finish_reason: Optional[str] = None + if state.early_stop_num != -1 and (new_history.shape[0] - prefix_len) > state.early_stop_num: + finish_reason = "early_stop" + elif 1 >= max_steps: + finish_reason = "max_step" + elif sampled_token == model.EOS: + finish_reason = "eos_sample" + elif argmax_token == model.EOS: + finish_reason = "eos_argmax" + + if finish_reason is not None: + finished_items.append( + T2SFinishedItem( + request_id=state.request_id, + semantic_tokens=new_history[:-1].clone(), + finish_idx=0, + finish_reason=finish_reason, + ) + ) + continue + + real_kv_len = int(active_batch.x_lens[batch_index].item()) + prefix_len + request_k_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in k_cache] + request_v_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in v_cache] + + running_requests.append( + T2SRunningRequest( + state=state, + y_sequence=new_history, + prefix_len=prefix_len, + decode_attn_mask=decode_attn_mask[batch_index : batch_index + 1].clone(), + k_cache=request_k_cache, + v_cache=request_v_cache, + step_idx=1, + ) + ) + + return running_requests, finished_items + + +def _build_decode_batch_from_running( + model: Any, + running_requests: Sequence[T2SRunningRequest], +) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], torch.Tensor]: + xy_pos = build_next_xy_pos(model, [item.y_sequence for item in running_requests]) + max_kv_len = max(item.k_cache[0].shape[1] for item in running_requests) + max_mask_len = max(item.decode_attn_mask.shape[-1] for item in running_requests) + num_layers = len(running_requests[0].k_cache) + + batched_k_cache: List[torch.Tensor] = [] + batched_v_cache: List[torch.Tensor] = [] + for layer_index in range(num_layers): + batched_k_cache.append( + torch.cat([_pad_cache_left(item.k_cache[layer_index], max_kv_len) for item in running_requests], dim=0) + ) + batched_v_cache.append( + torch.cat([_pad_cache_left(item.v_cache[layer_index], max_kv_len) for item in running_requests], dim=0) + ) + + batched_decode_attn_mask = torch.cat( + [_pad_decode_mask_left(item.decode_attn_mask, max_mask_len) for item in running_requests], + dim=0, + ) + return xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask + + +def run_decode_step_for_running( + model: Any, + running_requests: Sequence[T2SRunningRequest], + max_steps: int, +) -> Tuple[List[T2SRunningRequest], List[T2SFinishedItem]]: + if not running_requests: + return [], [] + + xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = _build_decode_batch_from_running( + model, running_requests + ) + xy_dec, next_k_cache, next_v_cache = model.t2s_transformer.decode_next_token( + xy_pos, + batched_k_cache, + batched_v_cache, + batched_decode_attn_mask, + ) + logits = model.ar_predict_layer(xy_dec[:, -1]) + + next_running: List[T2SRunningRequest] = [] + finished_items: List[T2SFinishedItem] = [] + + for batch_index, running_request in enumerate(running_requests): + current_idx = running_request.step_idx + logits_i = logits[batch_index : batch_index + 1].clone() + if current_idx < 11: + logits_i = logits_i[:, :-1] + sampled = sample( + logits_i, + running_request.y_sequence.unsqueeze(0), + top_k=running_request.state.top_k, + top_p=running_request.state.top_p, + repetition_penalty=running_request.state.repetition_penalty, + temperature=running_request.state.temperature, + )[0] + sampled_token = int(sampled[0, 0].item()) + argmax_token = int(torch.argmax(logits_i[0], dim=-1).item()) + new_history = torch.cat([running_request.y_sequence, sampled.view(-1)], dim=0) + + finish_reason: Optional[str] = None + if running_request.state.early_stop_num != -1 and (new_history.shape[0] - running_request.prefix_len) > running_request.state.early_stop_num: + finish_reason = "early_stop" + elif current_idx + 1 >= max_steps: + finish_reason = "max_step" + elif sampled_token == model.EOS: + finish_reason = "eos_sample" + elif argmax_token == model.EOS: + finish_reason = "eos_argmax" + + if finish_reason is not None: + finished_items.append( + T2SFinishedItem( + request_id=running_request.state.request_id, + semantic_tokens=new_history[:-1].clone(), + finish_idx=current_idx, + finish_reason=finish_reason, + ) + ) + continue + + real_next_kv_len = running_request.k_cache[0].shape[1] + 1 + request_k_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_k_cache] + request_v_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_v_cache] + next_running.append( + T2SRunningRequest( + state=running_request.state, + y_sequence=new_history, + prefix_len=running_request.prefix_len, + decode_attn_mask=F.pad(running_request.decode_attn_mask, (0, 1), value=False), + k_cache=request_k_cache, + v_cache=request_v_cache, + step_idx=current_idx + 1, + ) + ) + + return next_running, finished_items + + +def run_scheduler_continuous( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> List[T2SFinishedItem]: + pending = sorted(states, key=lambda item: (item.ready_step, item.request_id)) + running_requests: List[T2SRunningRequest] = [] + finished: List[T2SFinishedItem] = [] + current_tick = 0 + + while pending or running_requests: + admitted: List[T2SRequestState] = [] + while pending and pending[0].ready_step <= current_tick: + admitted.append(pending.pop(0)) + + admitted_running, admitted_finished = run_prefill_step(model, admitted, max_steps=max_steps) + finished.extend(admitted_finished) + + if running_requests: + running_requests, step_finished = run_decode_step_for_running( + model, + running_requests, + max_steps=max_steps, + ) + finished.extend(step_finished) + + running_requests.extend(admitted_running) + + if not running_requests and pending: + current_tick = max(current_tick + 1, pending[0].ready_step) + continue + + current_tick += 1 + + finished.sort(key=lambda item: item.request_id) + return finished diff --git a/api_v3.py b/api_v3.py new file mode 100644 index 00000000..9d250119 --- /dev/null +++ b/api_v3.py @@ -0,0 +1,1170 @@ +""" +# WebAPI文档 + +` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml ` + +## 执行参数: + `-a` - `绑定地址, 默认"127.0.0.1"` + `-p` - `绑定端口, 默认9880` + `-c` - `TTS配置文件路径, 默认"GPT_SoVITS/configs/tts_infer.yaml"` + +## 调用: + +### 推理 + +endpoint: `/tts` +GET: +``` +http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true +``` + +POST: +```json +{ + "text": "", # str.(required) text to be synthesized + "text_lang: "", # str.(required) language of the text to be synthesized + "ref_audio_path": "", # str.(required) reference audio path + "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion + "prompt_text": "", # str.(optional) prompt text for the reference audio + "prompt_lang": "", # str.(required) language of the prompt text for the reference audio + "top_k": 15, # int. top k sampling + "top_p": 1, # float. top p sampling + "temperature": 1, # float. temperature for sampling + "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. + "batch_size": 1, # int. batch size for inference + "batch_threshold": 0.75, # float. threshold for batch splitting. + "split_bucket": True, # bool. whether to split the batch into multiple buckets. + "speed_factor":1.0, # float. control the speed of the synthesized audio. + "fragment_interval":0.3, # float. to control the interval of the audio fragment. + "seed": -1, # int. random seed for reproducibility. + "parallel_infer": True, # bool. whether to use parallel inference. + "repetition_penalty": 1.35, # float. repetition penalty for T2S model. + "sample_steps": 32, # int. number of sampling steps for VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) +} +``` + +RESP: +成功: 直接返回 wav 音频流, http code 200 +失败: 返回包含错误信息的 json, http code 400 + +### 命令控制 + +endpoint: `/control` + +command: +"restart": 重新运行 +"exit": 结束运行 + +GET: +``` +http://127.0.0.1:9880/control?command=restart +``` +POST: +```json +{ + "command": "restart" +} +``` + +RESP: 无 + + +### 切换GPT模型 + +endpoint: `/set_gpt_weights` + +GET: +``` +http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt +``` +RESP: +成功: 返回"success", http code 200 +失败: 返回包含错误信息的 json, http code 400 + + +### 切换Sovits模型 + +endpoint: `/set_sovits_weights` + +GET: +``` +http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth +``` + +RESP: +成功: 返回"success", http code 200 +失败: 返回包含错误信息的 json, http code 400 + +""" + +import asyncio +import os +import sys +import time +import traceback +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Generator, List, Union + +now_dir = os.getcwd() +sys.path.append(now_dir) +sys.path.append("%s/GPT_SoVITS" % (now_dir)) + +import argparse +import subprocess +import wave +import signal +import numpy as np +import soundfile as sf +import torch +from fastapi import FastAPI, Response +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +from io import BytesIO +from tools.i18n.i18n import I18nAuto +from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( + SchedulerRequestSpec, + T2SFinishedItem, + T2SRunningRequest, + T2SRequestState, + prepare_request_state, + run_decode_step_for_running, + run_prefill_step, + run_scheduler_continuous, +) +from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names +from pydantic import BaseModel +import threading + +# print(sys.path) +i18n = I18nAuto() +cut_method_names = get_cut_method_names() + +parser = argparse.ArgumentParser(description="GPT-SoVITS api") +parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径") +parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1") +parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880") +args = parser.parse_args() +config_path = args.tts_config +# device = args.device +port = args.port +host = args.bind_addr +argv = sys.argv + +if config_path in [None, ""]: + config_path = "GPT-SoVITS/configs/tts_infer.yaml" + +tts_config = TTS_Config(config_path) +print(tts_config) +tts_pipeline = TTS(tts_config) + +APP = FastAPI() + + +class TTS_Request(BaseModel): + text: str = None + text_lang: str = None + ref_audio_path: str = None + aux_ref_audio_paths: list = None + prompt_lang: str = None + prompt_text: str = "" + top_k: int = 15 + top_p: float = 1 + temperature: float = 1 + text_split_method: str = "cut5" + batch_size: int = 1 + batch_threshold: float = 0.75 + split_bucket: bool = True + speed_factor: float = 1.0 + fragment_interval: float = 0.3 + seed: int = -1 + media_type: str = "wav" + streaming_mode: Union[bool, int] = False + parallel_infer: bool = True + repetition_penalty: float = 1.35 + sample_steps: int = 32 + super_sampling: bool = False + overlap_length: int = 2 + min_chunk_length: int = 16 + + +class Scheduler_Debug_Request_Item(BaseModel): + request_id: str | None = None + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + top_k: int = 15 + top_p: float = 1 + temperature: float = 1 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + ready_step: int = 0 + + +class Scheduler_Debug_Request(BaseModel): + requests: List[Scheduler_Debug_Request_Item] + max_steps: int = 1500 + seed: int = -1 + + +class Scheduler_Submit_Request(BaseModel): + request_id: str | None = None + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + top_k: int = 15 + top_p: float = 1 + temperature: float = 1 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + speed_factor: float = 1.0 + sample_steps: int = 32 + media_type: str = "wav" + timeout_sec: float = 30.0 + + +@dataclass +class SchedulerPendingJob: + request_id: str + state: T2SRequestState + done_event: threading.Event + enqueue_time: float + speed_factor: float + sample_steps: int + media_type: str + prepare_ms: float = 0.0 + prepare_wall_ms: float = 0.0 + first_schedule_time: float | None = None + prefill_ms: float = 0.0 + decode_ms: float = 0.0 + synth_ms: float = 0.0 + pack_ms: float = 0.0 + decode_steps: int = 0 + result: dict | None = None + sample_rate: int | None = None + audio_data: np.ndarray | None = None + error: str | None = None + + +class SchedulerDebugWorker: + def __init__(self, tts: TTS, max_steps: int = 1500, micro_batch_wait_ms: int = 5): + self.tts = tts + self.max_steps = max_steps + self.micro_batch_wait_s = micro_batch_wait_ms / 1000.0 + self.prepare_lock = threading.Lock() + self.condition = threading.Condition() + self.pending_jobs: List[SchedulerPendingJob] = [] + self.running_requests: List[T2SRunningRequest] = [] + self.job_map: dict[str, SchedulerPendingJob] = {} + self.total_finished = 0 + self.total_submitted = 0 + self.worker_thread = threading.Thread(target=self._run_loop, name="t2s-scheduler-debug-worker", daemon=True) + self.worker_thread.start() + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState: + with self.prepare_lock: + return prepare_request_state(self.tts, spec) + + def submit( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_ms: float, + prepare_wall_ms: float, + ) -> SchedulerPendingJob: + job = SchedulerPendingJob( + request_id=state.request_id, + state=state, + done_event=threading.Event(), + enqueue_time=time.perf_counter(), + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + prepare_ms=float(prepare_ms), + prepare_wall_ms=float(prepare_wall_ms), + ) + with self.condition: + self.pending_jobs.append(job) + self.job_map[job.request_id] = job + self.total_submitted += 1 + self.condition.notify_all() + return job + + def _mark_prefill_started(self, jobs: List[SchedulerPendingJob], started_at: float) -> None: + with self.condition: + for job in jobs: + tracked_job = self.job_map.get(job.request_id) + if tracked_job is not None and tracked_job.first_schedule_time is None: + tracked_job.first_schedule_time = started_at + + def _add_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: + elapsed_ms = elapsed_s * 1000.0 + with self.condition: + for job in jobs: + tracked_job = self.job_map.get(job.request_id) + if tracked_job is not None: + tracked_job.prefill_ms += elapsed_ms + + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + elapsed_ms = elapsed_s * 1000.0 + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.decode_ms += elapsed_ms + job.decode_steps += 1 + + def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: + semantic_tokens = item.semantic_tokens.unsqueeze(0).unsqueeze(0).to(self.tts.configs.device) + phones = job.state.phones.unsqueeze(0).to(self.tts.configs.device) + audio_fragment = self.tts.synthesize_audio_request_local( + semantic_tokens=semantic_tokens, + phones=phones, + prompt_semantic=job.state.prompt_semantic, + prompt_phones=job.state.prompt_phones, + refer_spec=job.state.refer_spec, + raw_audio=job.state.raw_audio, + raw_sr=job.state.raw_sr, + speed=float(job.speed_factor), + sample_steps=int(job.sample_steps), + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + return self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + + def get_state(self) -> dict: + with self.condition: + return { + "pending_jobs": len(self.pending_jobs), + "running_requests": len(self.running_requests), + "tracked_jobs": len(self.job_map), + "total_submitted": self.total_submitted, + "total_finished": self.total_finished, + "max_steps": self.max_steps, + "micro_batch_wait_ms": int(self.micro_batch_wait_s * 1000), + } + + def _finalize_finished(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + jobs_to_finalize: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + with self.condition: + for item in items: + job = self.job_map.get(item.request_id) + if job is not None: + jobs_to_finalize.append((job, item)) + + for job, item in jobs_to_finalize: + try: + self._sync_device() + synth_start = time.perf_counter() + sample_rate, audio_data = self._synthesize_finished_audio(job, item) + self._sync_device() + synth_ms = (time.perf_counter() - synth_start) * 1000.0 + except Exception as exc: + self._finalize_error([item.request_id], str(exc)) + continue + + finished_at = time.perf_counter() + with self.condition: + if self.job_map.get(item.request_id) is not job: + continue + queue_wait_ms = 0.0 + if job.first_schedule_time is not None: + queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) + worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) + job.synth_ms += synth_ms + job.sample_rate = int(sample_rate) + job.audio_data = audio_data + prepare_profile = dict(job.state.prepare_profile) + job.result = { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "prepare_ms": job.prepare_ms, + "prepare_wall_ms": job.prepare_wall_ms, + "prepare_profile": prepare_profile, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": job.prefill_ms, + "decode_ms": job.decode_ms, + "synth_ms": job.synth_ms, + "worker_total_ms": worker_total_ms, + "decode_steps": int(job.decode_steps), + "sample_rate": int(sample_rate), + "media_type": job.media_type, + } + job.done_event.set() + self.job_map.pop(item.request_id, None) + self.total_finished += 1 + + def _finalize_error(self, request_ids: List[str], error: str) -> None: + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is None: + continue + job.error = error + job.done_event.set() + self.job_map.pop(request_id, None) + self.total_finished += 1 + + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs and not self.running_requests: + self.condition.wait(timeout=self.micro_batch_wait_s) + elif wait_for_batch and self.pending_jobs: + self.condition.wait(timeout=self.micro_batch_wait_s) + if not self.pending_jobs: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def _run_loop(self) -> None: + while True: + wait_for_batch = len(self.running_requests) == 0 + pending_jobs = self._take_pending_snapshot(wait_for_batch=wait_for_batch) + + if pending_jobs: + try: + self._sync_device() + prefill_start = time.perf_counter() + self._mark_prefill_started(pending_jobs, prefill_start) + admitted_running, admitted_finished = run_prefill_step( + self.tts.t2s_model.model, + [job.state for job in pending_jobs], + max_steps=self.max_steps, + ) + self._sync_device() + self._add_prefill_time(pending_jobs, time.perf_counter() - prefill_start) + self._finalize_finished(admitted_finished) + self.running_requests.extend(admitted_running) + except Exception as exc: + self._finalize_error([job.request_id for job in pending_jobs], str(exc)) + + if self.running_requests: + try: + active_request_ids = [item.state.request_id for item in self.running_requests] + self._sync_device() + decode_start = time.perf_counter() + self.running_requests, step_finished = run_decode_step_for_running( + self.tts.t2s_model.model, + self.running_requests, + max_steps=self.max_steps, + ) + self._sync_device() + self._add_decode_time(active_request_ids, time.perf_counter() - decode_start) + self._finalize_finished(step_finished) + except Exception as exc: + self._finalize_error(active_request_ids, str(exc)) + self.running_requests = [] + continue + + if not pending_jobs: + time.sleep(self.micro_batch_wait_s) + + +scheduler_debug_worker = SchedulerDebugWorker(tts_pipeline) + + +def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): + # Author: AkagawaTsurunaki + # Issue: + # Stack overflow probabilistically occurs + # when the function `sf_writef_short` of `libsndfile_64bit.dll` is called + # using the Python library `soundfile` + # Note: + # This is an issue related to `libsndfile`, not this project itself. + # It happens when you generate a large audio tensor (about 499804 frames in my PC) + # and try to convert it to an ogg file. + # Related: + # https://github.com/RVC-Boss/GPT-SoVITS/issues/1199 + # https://github.com/libsndfile/libsndfile/issues/1023 + # https://github.com/bastibe/python-soundfile/issues/396 + # Suggestion: + # Or split the whole audio data into smaller audio segment to avoid stack overflow? + + def handle_pack_ogg(): + with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + + + + # See: https://docs.python.org/3/library/threading.html + # The stack size of this thread is at least 32768 + # If stack overflow error still occurs, just modify the `stack_size`. + # stack_size = n * 4096, where n should be a positive integer. + # Here we chose n = 4096. + stack_size = 4096 * 4096 + try: + threading.stack_size(stack_size) + pack_ogg_thread = threading.Thread(target=handle_pack_ogg) + pack_ogg_thread.start() + pack_ogg_thread.join() + except RuntimeError as e: + # If changing the thread stack size is unsupported, a RuntimeError is raised. + print("RuntimeError: {}".format(e)) + print("Changing the thread stack size is unsupported.") + except ValueError as e: + # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified. + print("ValueError: {}".format(e)) + print("The specified stack size is invalid.") + + return io_buffer + + +def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer.write(data.tobytes()) + return io_buffer + + +def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer = BytesIO() + sf.write(io_buffer, data, rate, format="wav") + return io_buffer + + +def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): + process = subprocess.Popen( + [ + "ffmpeg", + "-f", + "s16le", # 输入16位有符号小端整数PCM + "-ar", + str(rate), # 设置采样率 + "-ac", + "1", # 单声道 + "-i", + "pipe:0", # 从管道读取输入 + "-c:a", + "aac", # 音频编码器为AAC + "-b:a", + "192k", # 比特率 + "-vn", # 不包含视频 + "-f", + "adts", # 输出AAC数据流格式 + "pipe:1", # 将输出写入管道 + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, _ = process.communicate(input=data.tobytes()) + io_buffer.write(out) + return io_buffer + + +def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): + if media_type == "ogg": + io_buffer = pack_ogg(io_buffer, data, rate) + elif media_type == "aac": + io_buffer = pack_aac(io_buffer, data, rate) + elif media_type == "wav": + io_buffer = pack_wav(io_buffer, data, rate) + else: + io_buffer = pack_raw(io_buffer, data, rate) + io_buffer.seek(0) + return io_buffer + + +# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py +def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): + # This will create a wave header then append the frame input + # It should be first on a streaming wav file + # Other frames better should not have it (else you will hear some artifacts each chunk start) + wav_buf = BytesIO() + with wave.open(wav_buf, "wb") as vfout: + vfout.setnchannels(channels) + vfout.setsampwidth(sample_width) + vfout.setframerate(sample_rate) + vfout.writeframes(frame_input) + + wav_buf.seek(0) + return wav_buf.read() + + +def handle_control(command: str): + if command == "restart": + os.execl(sys.executable, sys.executable, *argv) + elif command == "exit": + os.kill(os.getpid(), signal.SIGTERM) + exit(0) + + +def check_params(req: dict): + text: str = req.get("text", "") + text_lang: str = req.get("text_lang", "") + ref_audio_path: str = req.get("ref_audio_path", "") + streaming_mode: bool = req.get("streaming_mode", False) + media_type: str = req.get("media_type", "wav") + prompt_lang: str = req.get("prompt_lang", "") + text_split_method: str = req.get("text_split_method", "cut5") + + if ref_audio_path in [None, ""]: + return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"}) + if text in [None, ""]: + return JSONResponse(status_code=400, content={"message": "text is required"}) + if text_lang in [None, ""]: + return JSONResponse(status_code=400, content={"message": "text_lang is required"}) + elif text_lang.lower() not in tts_config.languages: + return JSONResponse( + status_code=400, + content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"}, + ) + if prompt_lang in [None, ""]: + return JSONResponse(status_code=400, content={"message": "prompt_lang is required"}) + elif prompt_lang.lower() not in tts_config.languages: + return JSONResponse( + status_code=400, + content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"}, + ) + if media_type not in ["wav", "raw", "ogg", "aac"]: + return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"}) + # elif media_type == "ogg" and not streaming_mode: + # return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) + + if text_split_method not in cut_method_names: + return JSONResponse( + status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"} + ) + + return None + + +def set_scheduler_seed(seed: int): + if seed in ["", None]: + return + seed = int(seed) + if seed < 0: + return + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def build_scheduler_request_specs(request_items: List[Scheduler_Debug_Request_Item]) -> List[SchedulerRequestSpec]: + specs: List[SchedulerRequestSpec] = [] + for index, item in enumerate(request_items): + payload = item.dict() + req = { + "text": payload["text"], + "text_lang": payload["text_lang"].lower(), + "ref_audio_path": payload["ref_audio_path"], + "aux_ref_audio_paths": None, + "prompt_text": payload["prompt_text"], + "prompt_lang": payload["prompt_lang"].lower(), + "top_k": payload["top_k"], + "top_p": payload["top_p"], + "temperature": payload["temperature"], + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": 1.0, + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": "wav", + "streaming_mode": False, + "parallel_infer": False, + "repetition_penalty": payload["repetition_penalty"], + "sample_steps": 32, + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + } + check_res = check_params(req) + if check_res is not None: + detail = check_res.body.decode("utf-8") if hasattr(check_res, "body") else str(check_res) + raise ValueError(f"request[{index}] 参数非法: {detail}") + specs.append( + SchedulerRequestSpec( + request_id=payload["request_id"] or f"req_{index:03d}", + ref_audio_path=Path(payload["ref_audio_path"]), + prompt_text=payload["prompt_text"], + prompt_lang=payload["prompt_lang"].lower(), + text=payload["text"], + text_lang=payload["text_lang"].lower(), + top_k=int(payload["top_k"]), + top_p=float(payload["top_p"]), + temperature=float(payload["temperature"]), + repetition_penalty=float(payload["repetition_penalty"]), + early_stop_num=int(payload["early_stop_num"]), + ready_step=int(payload["ready_step"]), + ) + ) + return specs + + +def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: + return [ + { + "request_id": state.request_id, + "ready_step": int(state.ready_step), + "ref_audio_path": str(state.ref_audio_path), + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "all_phone_len": int(state.all_phones.shape[0]), + "bert_len": int(state.all_bert_features.shape[-1]), + "norm_text": state.norm_text, + } + for state in states + ] + + +def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: + return [ + { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + } + for item in items + ] + + +def prepare_scheduler_states_batch(specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + return [scheduler_debug_worker.prepare_state(spec) for spec in specs] + + +def build_scheduler_submit_spec(request: Scheduler_Submit_Request) -> SchedulerRequestSpec: + payload = request.dict() + request_id = payload["request_id"] or f"job_{uuid.uuid4().hex[:12]}" + req = { + "text": payload["text"], + "text_lang": payload["text_lang"].lower(), + "ref_audio_path": payload["ref_audio_path"], + "aux_ref_audio_paths": None, + "prompt_text": payload["prompt_text"], + "prompt_lang": payload["prompt_lang"].lower(), + "top_k": payload["top_k"], + "top_p": payload["top_p"], + "temperature": payload["temperature"], + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": float(payload["speed_factor"]), + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": payload["media_type"], + "streaming_mode": False, + "parallel_infer": False, + "repetition_penalty": payload["repetition_penalty"], + "sample_steps": int(payload["sample_steps"]), + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + } + check_res = check_params(req) + if check_res is not None: + detail = check_res.body.decode("utf-8") if hasattr(check_res, "body") else str(check_res) + raise ValueError(f"request 参数非法: {detail}") + return SchedulerRequestSpec( + request_id=request_id, + ref_audio_path=Path(payload["ref_audio_path"]), + prompt_text=payload["prompt_text"], + prompt_lang=payload["prompt_lang"].lower(), + text=payload["text"], + text_lang=payload["text_lang"].lower(), + top_k=int(payload["top_k"]), + top_p=float(payload["top_p"]), + temperature=float(payload["temperature"]), + repetition_penalty=float(payload["repetition_penalty"]), + early_stop_num=int(payload["early_stop_num"]), + ready_step=0, + ) + + +async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request): + try: + set_scheduler_seed(request.seed) + specs = build_scheduler_request_specs(request.requests) + states = await asyncio.to_thread(prepare_scheduler_states_batch, specs) + finished = run_scheduler_continuous(tts_pipeline.t2s_model.model, states, max_steps=int(request.max_steps)) + return JSONResponse( + status_code=200, + content={ + "message": "success", + "request_count": len(states), + "max_steps": int(request.max_steps), + "requests": summarize_scheduler_states(states), + "finished": summarize_scheduler_finished(finished), + }, + ) + except Exception as e: + return JSONResponse( + status_code=400, + content={"message": "scheduler debug failed", "Exception": str(e)}, + ) + + +async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): + try: + request_start = time.perf_counter() + spec = build_scheduler_submit_spec(request) + prepare_start = time.perf_counter() + state = await asyncio.to_thread(scheduler_debug_worker.prepare_state, spec) + prepare_wall_ms = (time.perf_counter() - prepare_start) * 1000.0 + prepare_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms)) + job = scheduler_debug_worker.submit( + state, + speed_factor=float(request.speed_factor), + sample_steps=int(request.sample_steps), + media_type=request.media_type, + prepare_ms=prepare_ms, + prepare_wall_ms=prepare_wall_ms, + ) + timeout_ok = await asyncio.to_thread(job.done_event.wait, float(request.timeout_sec)) + if not timeout_ok: + return JSONResponse( + status_code=202, + content={ + "message": "queued", + "request_id": job.request_id, + "timings": { + "prepare_ms": prepare_ms, + "prepare_wall_ms": prepare_wall_ms, + "request_elapsed_ms": max(0.0, (time.perf_counter() - request_start) * 1000.0), + }, + "worker_state": scheduler_debug_worker.get_state(), + }, + ) + if job.error is not None: + return JSONResponse( + status_code=400, + content={"message": "scheduler submit failed", "request_id": job.request_id, "Exception": job.error}, + ) + if job.audio_data is None or job.sample_rate is None: + return JSONResponse( + status_code=500, + content={ + "message": "scheduler submit failed", + "request_id": job.request_id, + "Exception": "job finished without audio payload", + }, + ) + pack_start = time.perf_counter() + audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() + pack_ms = (time.perf_counter() - pack_start) * 1000.0 + job.pack_ms = pack_ms + request_total_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) + headers = { + "X-Request-Id": job.request_id, + "X-Semantic-Len": str(job.result["semantic_len"]) if job.result is not None else "0", + "X-Finish-Reason": job.result["finish_reason"] if job.result is not None else "unknown", + "X-Queue-Wait-Ms": ( + f"{float(job.result['queue_wait_ms']):.3f}" if job.result is not None else "0.000" + ), + "X-Prepare-Ms": f"{prepare_ms:.3f}", + "X-Prepare-Wall-Ms": f"{prepare_wall_ms:.3f}", + "X-Prefill-Ms": f"{float(job.result['prefill_ms']):.3f}" if job.result is not None else "0.000", + "X-Decode-Ms": f"{float(job.result['decode_ms']):.3f}" if job.result is not None else "0.000", + "X-Synth-Ms": f"{float(job.result['synth_ms']):.3f}" if job.result is not None else "0.000", + "X-Pack-Ms": f"{pack_ms:.3f}", + "X-Worker-Total-Ms": ( + f"{float(job.result['worker_total_ms']):.3f}" if job.result is not None else "0.000" + ), + "X-Request-Total-Ms": f"{request_total_ms:.3f}", + "X-Decode-Steps": str(job.result["decode_steps"]) if job.result is not None else "0", + } + if job.result is not None: + prepare_profile = job.result.get("prepare_profile", {}) + headers.update( + { + "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}", + "X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}", + "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", + "X-Prepare-Profile-Wall-Ms": f"{float(prepare_profile.get('wall_total_ms', 0.0)):.3f}", + } + ) + return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers) + except Exception as e: + return JSONResponse( + status_code=400, + content={"message": "scheduler submit failed", "Exception": str(e)}, + ) + + +async def tts_handle(req: dict): + """ + Text to speech handler. + + Args: + req (dict): + { + "text": "", # str.(required) text to be synthesized + "text_lang: "", # str.(required) language of the text to be synthesized + "ref_audio_path": "", # str.(required) reference audio path + "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion + "prompt_text": "", # str.(optional) prompt text for the reference audio + "prompt_lang": "", # str.(required) language of the prompt text for the reference audio + "top_k": 15, # int. top k sampling + "top_p": 1, # float. top p sampling + "temperature": 1, # float. temperature for sampling + "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. + "batch_size": 1, # int. batch size for inference + "batch_threshold": 0.75, # float. threshold for batch splitting. + "split_bucket": True, # bool. whether to split the batch into multiple buckets. + "speed_factor":1.0, # float. control the speed of the synthesized audio. + "fragment_interval":0.3, # float. to control the interval of the audio fragment. + "seed": -1, # int. random seed for reproducibility. + "parallel_infer": True, # bool. whether to use parallel inference. + "repetition_penalty": 1.35, # float. repetition penalty for T2S model. + "sample_steps": 32, # int. number of sampling steps for VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) + } + returns: + StreamingResponse: audio stream response. + """ + + streaming_mode = req.get("streaming_mode", False) + return_fragment = req.get("return_fragment", False) + media_type = req.get("media_type", "wav") + + check_res = check_params(req) + if check_res is not None: + return check_res + + if streaming_mode == 0: + streaming_mode = False + return_fragment = False + fixed_length_chunk = False + elif streaming_mode == 1: + streaming_mode = False + return_fragment = True + fixed_length_chunk = False + elif streaming_mode == 2: + streaming_mode = True + return_fragment = False + fixed_length_chunk = False + elif streaming_mode == 3: + streaming_mode = True + return_fragment = False + fixed_length_chunk = True + + else: + return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"}) + + req["streaming_mode"] = streaming_mode + req["return_fragment"] = return_fragment + req["fixed_length_chunk"] = fixed_length_chunk + + print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}") + + streaming_mode = streaming_mode or return_fragment + + + try: + tts_generator = tts_pipeline.run(req) + + if streaming_mode: + + def streaming_generator(tts_generator: Generator, media_type: str): + if_frist_chunk = True + for sr, chunk in tts_generator: + if if_frist_chunk and media_type == "wav": + yield wave_header_chunk(sample_rate=sr) + media_type = "raw" + if_frist_chunk = False + yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue() + + # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}" + return StreamingResponse( + streaming_generator( + tts_generator, + media_type, + ), + media_type=f"audio/{media_type}", + ) + + else: + sr, audio_data = next(tts_generator) + audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() + return Response(audio_data, media_type=f"audio/{media_type}") + except Exception as e: + return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)}) + + +@APP.get("/control") +async def control(command: str = None): + if command is None: + return JSONResponse(status_code=400, content={"message": "command is required"}) + handle_control(command) + + +@APP.get("/tts") +async def tts_get_endpoint( + text: str = None, + text_lang: str = None, + ref_audio_path: str = None, + aux_ref_audio_paths: list = None, + prompt_lang: str = None, + prompt_text: str = "", + top_k: int = 15, + top_p: float = 1, + temperature: float = 1, + text_split_method: str = "cut5", + batch_size: int = 1, + batch_threshold: float = 0.75, + split_bucket: bool = True, + speed_factor: float = 1.0, + fragment_interval: float = 0.3, + seed: int = -1, + media_type: str = "wav", + parallel_infer: bool = True, + repetition_penalty: float = 1.35, + sample_steps: int = 32, + super_sampling: bool = False, + streaming_mode: Union[bool, int] = False, + overlap_length: int = 2, + min_chunk_length: int = 16, +): + req = { + "text": text, + "text_lang": text_lang.lower(), + "ref_audio_path": ref_audio_path, + "aux_ref_audio_paths": aux_ref_audio_paths, + "prompt_text": prompt_text, + "prompt_lang": prompt_lang.lower(), + "top_k": top_k, + "top_p": top_p, + "temperature": temperature, + "text_split_method": text_split_method, + "batch_size": int(batch_size), + "batch_threshold": float(batch_threshold), + "speed_factor": float(speed_factor), + "split_bucket": split_bucket, + "fragment_interval": fragment_interval, + "seed": seed, + "media_type": media_type, + "streaming_mode": streaming_mode, + "parallel_infer": parallel_infer, + "repetition_penalty": float(repetition_penalty), + "sample_steps": int(sample_steps), + "super_sampling": super_sampling, + "overlap_length": int(overlap_length), + "min_chunk_length": int(min_chunk_length), + } + return await tts_handle(req) + + +@APP.post("/tts") +async def tts_post_endpoint(request: TTS_Request): + req = request.dict() + return await tts_handle(req) + + +@APP.post("/tts_scheduler_debug") +async def tts_scheduler_debug_endpoint(request: Scheduler_Debug_Request): + return await tts_scheduler_debug_handle(request) + + +@APP.post("/tts_scheduler_submit") +async def tts_scheduler_submit_endpoint(request: Scheduler_Submit_Request): + return await tts_scheduler_submit_handle(request) + + +@APP.get("/tts_scheduler_state") +async def tts_scheduler_state_endpoint(): + return JSONResponse(status_code=200, content={"message": "success", "worker_state": scheduler_debug_worker.get_state()}) + + +@APP.get("/set_refer_audio") +async def set_refer_aduio(refer_audio_path: str = None): + try: + tts_pipeline.set_ref_audio(refer_audio_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)}) + return JSONResponse(status_code=200, content={"message": "success"}) + + +# @APP.post("/set_refer_audio") +# async def set_refer_aduio_post(audio_file: UploadFile = File(...)): +# try: +# # 检查文件类型,确保是音频文件 +# if not audio_file.content_type.startswith("audio/"): +# return JSONResponse(status_code=400, content={"message": "file type is not supported"}) + +# os.makedirs("uploaded_audio", exist_ok=True) +# save_path = os.path.join("uploaded_audio", audio_file.filename) +# # 保存音频文件到服务器上的一个目录 +# with open(save_path , "wb") as buffer: +# buffer.write(await audio_file.read()) + +# tts_pipeline.set_ref_audio(save_path) +# except Exception as e: +# return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)}) +# return JSONResponse(status_code=200, content={"message": "success"}) + + +@APP.get("/set_gpt_weights") +async def set_gpt_weights(weights_path: str = None): + try: + if weights_path in ["", None]: + return JSONResponse(status_code=400, content={"message": "gpt weight path is required"}) + tts_pipeline.init_t2s_weights(weights_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)}) + + return JSONResponse(status_code=200, content={"message": "success"}) + + +@APP.get("/set_sovits_weights") +async def set_sovits_weights(weights_path: str = None): + try: + if weights_path in ["", None]: + return JSONResponse(status_code=400, content={"message": "sovits weight path is required"}) + tts_pipeline.init_vits_weights(weights_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)}) + return JSONResponse(status_code=200, content={"message": "success"}) + + +if __name__ == "__main__": + try: + if host == "None": # 在调用时使用 -a None 参数,可以让api监听双栈 + host = None + uvicorn.run(app=APP, host=host, port=port, workers=1) + except Exception: + traceback.print_exc() + os.kill(os.getpid(), signal.SIGTERM) + exit(0) diff --git a/tools/t2s_scheduler_prototype.py b/tools/t2s_scheduler_prototype.py new file mode 100644 index 00000000..cd4b9c6d --- /dev/null +++ b/tools/t2s_scheduler_prototype.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import argparse +import json +import random +import sys +from pathlib import Path +from typing import Any, Dict, List + +import numpy as np +import torch + +ROOT_DIR = Path(__file__).resolve().parents[1] +if str(ROOT_DIR) not in sys.path: + sys.path.append(str(ROOT_DIR)) +gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS" +if str(gpt_sovits_dir) not in sys.path: + sys.path.append(str(gpt_sovits_dir)) + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402 + SchedulerRequestSpec, + T2SFinishedItem, + T2SRequestState, + prepare_request_state, + run_scheduler_continuous, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="T2S request-local scheduler prototype.") + parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml") + parser.add_argument("--request-manifest", type=Path, default=None) + parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav") + parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。") + parser.add_argument("--prompt-lang", type=str, default="zh") + parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt") + parser.add_argument("--text", type=str, default=None) + parser.add_argument("--text-lang", type=str, default="en") + parser.add_argument("--top-k", type=int, default=15) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--repetition-penalty", type=float, default=1.35) + parser.add_argument("--early-stop-num", type=int, default=-1) + parser.add_argument("--max-steps", type=int, default=1500) + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/t2s_scheduler/output_run") + return parser.parse_args() + + +def set_seed(seed: int, use_cuda: bool) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if use_cuda and torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def load_pipeline(config_path: Path): + try: + from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "缺少运行依赖,请先在 GPT-SoVITS 推理环境中安装 requirements 后再运行该脚本。" + ) from exc + tts_config = TTS_Config(str(config_path)) + print(tts_config) + return TTS(tts_config) + + +def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + if args.request_manifest is not None: + payload = json.loads(args.request_manifest.read_text(encoding="utf-8")) + raw_requests = payload["requests"] if isinstance(payload, dict) else payload + specs: List[SchedulerRequestSpec] = [] + for index, item in enumerate(raw_requests): + text = item.get("text") + text_file = item.get("text_file") + if text is None and text_file is None: + raise ValueError(f"request[{index}] must provide text or text_file") + if text is None: + text = Path(text_file).read_text(encoding="utf-8") + specs.append( + SchedulerRequestSpec( + request_id=item.get("request_id", f"req_{index:03d}"), + ref_audio_path=Path(item["ref_audio_path"]), + prompt_text=item["prompt_text"], + prompt_lang=item.get("prompt_lang", "zh"), + text=text, + text_lang=item.get("text_lang", "zh"), + top_k=int(item.get("top_k", args.top_k)), + top_p=float(item.get("top_p", args.top_p)), + temperature=float(item.get("temperature", args.temperature)), + repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)), + early_stop_num=int(item.get("early_stop_num", args.early_stop_num)), + ready_step=int(item.get("ready_step", 0)), + ) + ) + return specs + + text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8") + return [ + SchedulerRequestSpec( + request_id="req_000", + ref_audio_path=args.ref_audio, + prompt_text=args.prompt_text, + prompt_lang=args.prompt_lang, + text=text, + text_lang=args.text_lang, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + early_stop_num=args.early_stop_num, + ready_step=0, + ) + ] + + +def summarise_requests(states: List[T2SRequestState]) -> List[Dict[str, Any]]: + return [ + { + "request_id": state.request_id, + "ready_step": int(state.ready_step), + "ref_audio_path": str(state.ref_audio_path), + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "all_phone_len": int(state.all_phones.shape[0]), + "bert_len": int(state.all_bert_features.shape[-1]), + "norm_text": state.norm_text, + } + for state in states + ] + + +def summarise_finished(items: List[T2SFinishedItem]) -> List[Dict[str, Any]]: + return [ + { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + } + for item in items + ] + + +def main() -> None: + args = parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + + tts = load_pipeline(args.config) + model = tts.t2s_model.model + use_cuda = str(tts.configs.device).startswith("cuda") + set_seed(args.seed, use_cuda) + + request_specs = load_request_specs(args) + states = [prepare_request_state(tts, spec) for spec in request_specs] + finished = run_scheduler_continuous(model, states, max_steps=args.max_steps) + + summary = { + "request_count": len(states), + "max_steps": args.max_steps, + "requests": summarise_requests(states), + "finished": summarise_finished(finished), + } + output_path = args.output_dir / "scheduler_prototype_summary.json" + output_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") + print(json.dumps(summary, ensure_ascii=False, indent=2)) + print(f"[saved] {output_path}") + + +if __name__ == "__main__": + try: + main() + except ModuleNotFoundError as exc: + print(f"[error] {exc}") + raise SystemExit(1) from None