From 4679ca9fc57b5181d379474978fac1dcbc28d73d Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Mon, 20 Oct 2025 00:36:36 +0100 Subject: [PATCH] . --- GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py | 2 +- GPT_SoVITS/TTS_infer_pack/TTS.py | 49 ++++++++++++++++++--- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py b/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py index 1b603ec1..e6453135 100644 --- a/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py +++ b/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py @@ -153,7 +153,7 @@ class T2SEngine(T2SEngineProtocol): pos_sorted = mx.sort(pos, axis=0) valid_count = session.bsz - mx.sum(cast(Array, pos_sorted == session.bsz)) pos_final = pos_sorted[: int(valid_count)] - newly_done_indices = mx.expand_dims(newly_done_indices[pos_final], 0) + newly_done_indices = newly_done_indices[pos_final] mx.set_default_device(self.device) if debug: diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 350e7ba4..cb1e4b69 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -4,6 +4,7 @@ import os import random import time import traceback +import warnings from copy import deepcopy from pathlib import Path from typing import Any @@ -19,7 +20,7 @@ from peft import LoraConfig, get_peft_model from tqdm import tqdm from transformers import AutoModelForMaskedLM, AutoTokenizer -from GPT_SoVITS.Accelerate import MLX, PyTorch, T2SEngineProtocol, T2SRequest, backends +from GPT_SoVITS.Accelerate import MLX, PyTorch, T2SEngineProtocol, T2SRequest, backends, console from GPT_SoVITS.BigVGAN.bigvgan import BigVGAN from GPT_SoVITS.feature_extractor.cnhubert import CNHubert from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch @@ -37,6 +38,14 @@ now_dir = os.getcwd() resample_transform_dict = {} v3v4set = {"v3", "v4"} +warnings.filterwarnings( + "ignore", message="MPS: The constant padding of more than 3 dimensions is not currently supported natively." +) +warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*") + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + def resample(audio_tensor, sr0, sr1, device): global resample_transform_dict @@ -963,6 +972,7 @@ class TTS: Tuple[int, np.ndarray]: sampling rate and audio data. """ ########## variables initialization ########### + ttfb_time = time.perf_counter() self.stop_flag: bool = False text: str = inputs.get("text", "") text_lang: str = inputs.get("text_lang", "") @@ -1013,6 +1023,7 @@ class TTS: ) ###### setting reference audio and prompt text preprocessing ######## + t_34 = t_45 = 0.0 t0 = time.perf_counter() if (ref_audio_path is not None) and ( ref_audio_path != self.prompt_cache["ref_audio_path"] @@ -1107,6 +1118,9 @@ class TTS: return batch[0] t2 = time.perf_counter() + infer_len: list[int] = [] + infer_time: list[float] = [] + audio_len = [0.0] try: print("############ 推理 ############") ###### inference ###### @@ -1114,7 +1128,7 @@ class TTS: t_45 = 0.0 audio = [] output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"] - for item in data: + for idx, item in enumerate(data): t3 = time.perf_counter() if return_fragment: item = make_batch(item) @@ -1158,6 +1172,8 @@ class TTS: pred_semantic_list = t2s_result.result assert pred_semantic_list pred_semantic_list = [semantic.squeeze(0) for semantic in pred_semantic_list] + infer_len.append(t2s_result.total_tokens) + infer_time.append(t2s_result.infer_speed[-1]) t4 = time.perf_counter() t_34 += t4 - t3 @@ -1246,12 +1262,13 @@ class TTS: _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps ) batch_audio_fragment.append(audio_fragment) - + if idx == 0: + ttfb_time = time.perf_counter() - ttfb_time t5 = time.perf_counter() t_45 += t5 - t4 if return_fragment: - print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4)) - yield self.audio_postprocess( + console.print(f">> Time Stamps For Fragment: {t_34:.3f}\t{t_45:.3f}") + tmp = self.audio_postprocess( [batch_audio_fragment], output_sr, None, @@ -1260,19 +1277,23 @@ class TTS: fragment_interval, super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, ) + audio_len.append(len(tmp[-1]) / output_sr) + yield tmp else: audio.append(batch_audio_fragment) if self.stop_flag: + audio_len.append(1) yield 16000, np.zeros(int(16000), dtype=np.int16) return if not return_fragment: - print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) if len(audio) == 0: + audio_len.append(1) yield 16000, np.zeros(int(16000), dtype=np.int16) return - yield self.audio_postprocess( + + tmp = self.audio_postprocess( audio, output_sr, batch_index_list, @@ -1281,10 +1302,14 @@ class TTS: fragment_interval, super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, ) + audio_len.append(len(tmp[-1]) / output_sr) + + yield tmp except Exception as e: traceback.print_exc() # 必须返回一个空音频, 否则会导致显存不释放。 + audio_len.append(1) yield 16000, np.zeros(int(16000), dtype=np.int16) # 重置模型, 否则会导致显存释放不完全。 del self.t2s_model @@ -1295,6 +1320,16 @@ class TTS: self.init_vits_weights(self.configs.vits_weights_path) raise e finally: + infer_speed_avg = sum(infer_len) / sum(infer_time) + rtf_value = sum((t1 - t0, t2 - t1, t_34, t_45)) / sum(audio_len) + console.print(f">> Time Stamps: {t1 - t0:.3f}\t{t2 - t1:.3f}\t{t_34:.3f}\t{t_45:.3f}") + console.print(f">> Infer Speed: {infer_speed_avg:.2f} Token/s") + console.print(f">> RTF: {rtf_value:.2f}") + if ttfb_time > 2: + console.print(f">> TTFB: {ttfb_time:.3f} s") + else: + console.print(f">> TTFB: {ttfb_time * 1000:.3f} ms") + self.empty_cache() def empty_cache(self):