This commit is contained in:
XXXXRT666 2025-10-20 00:36:36 +01:00
parent 620b7810e6
commit 4679ca9fc5
2 changed files with 43 additions and 8 deletions

View File

@ -153,7 +153,7 @@ class T2SEngine(T2SEngineProtocol):
pos_sorted = mx.sort(pos, axis=0)
valid_count = session.bsz - mx.sum(cast(Array, pos_sorted == session.bsz))
pos_final = pos_sorted[: int(valid_count)]
newly_done_indices = mx.expand_dims(newly_done_indices[pos_final], 0)
newly_done_indices = newly_done_indices[pos_final]
mx.set_default_device(self.device)
if debug:

View File

@ -4,6 +4,7 @@ import os
import random
import time
import traceback
import warnings
from copy import deepcopy
from pathlib import Path
from typing import Any
@ -19,7 +20,7 @@ from peft import LoraConfig, get_peft_model
from tqdm import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer
from GPT_SoVITS.Accelerate import MLX, PyTorch, T2SEngineProtocol, T2SRequest, backends
from GPT_SoVITS.Accelerate import MLX, PyTorch, T2SEngineProtocol, T2SRequest, backends, console
from GPT_SoVITS.BigVGAN.bigvgan import BigVGAN
from GPT_SoVITS.feature_extractor.cnhubert import CNHubert
from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch
@ -37,6 +38,14 @@ now_dir = os.getcwd()
resample_transform_dict = {}
v3v4set = {"v3", "v4"}
warnings.filterwarnings(
"ignore", message="MPS: The constant padding of more than 3 dimensions is not currently supported natively."
)
warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
def resample(audio_tensor, sr0, sr1, device):
global resample_transform_dict
@ -963,6 +972,7 @@ class TTS:
Tuple[int, np.ndarray]: sampling rate and audio data.
"""
########## variables initialization ###########
ttfb_time = time.perf_counter()
self.stop_flag: bool = False
text: str = inputs.get("text", "")
text_lang: str = inputs.get("text_lang", "")
@ -1013,6 +1023,7 @@ class TTS:
)
###### setting reference audio and prompt text preprocessing ########
t_34 = t_45 = 0.0
t0 = time.perf_counter()
if (ref_audio_path is not None) and (
ref_audio_path != self.prompt_cache["ref_audio_path"]
@ -1107,6 +1118,9 @@ class TTS:
return batch[0]
t2 = time.perf_counter()
infer_len: list[int] = []
infer_time: list[float] = []
audio_len = [0.0]
try:
print("############ 推理 ############")
###### inference ######
@ -1114,7 +1128,7 @@ class TTS:
t_45 = 0.0
audio = []
output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
for item in data:
for idx, item in enumerate(data):
t3 = time.perf_counter()
if return_fragment:
item = make_batch(item)
@ -1158,6 +1172,8 @@ class TTS:
pred_semantic_list = t2s_result.result
assert pred_semantic_list
pred_semantic_list = [semantic.squeeze(0) for semantic in pred_semantic_list]
infer_len.append(t2s_result.total_tokens)
infer_time.append(t2s_result.infer_speed[-1])
t4 = time.perf_counter()
t_34 += t4 - t3
@ -1246,12 +1262,13 @@ class TTS:
_pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
)
batch_audio_fragment.append(audio_fragment)
if idx == 0:
ttfb_time = time.perf_counter() - ttfb_time
t5 = time.perf_counter()
t_45 += t5 - t4
if return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
yield self.audio_postprocess(
console.print(f">> Time Stamps For Fragment: {t_34:.3f}\t{t_45:.3f}")
tmp = self.audio_postprocess(
[batch_audio_fragment],
output_sr,
None,
@ -1260,19 +1277,23 @@ class TTS:
fragment_interval,
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
)
audio_len.append(len(tmp[-1]) / output_sr)
yield tmp
else:
audio.append(batch_audio_fragment)
if self.stop_flag:
audio_len.append(1)
yield 16000, np.zeros(int(16000), dtype=np.int16)
return
if not return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
if len(audio) == 0:
audio_len.append(1)
yield 16000, np.zeros(int(16000), dtype=np.int16)
return
yield self.audio_postprocess(
tmp = self.audio_postprocess(
audio,
output_sr,
batch_index_list,
@ -1281,10 +1302,14 @@ class TTS:
fragment_interval,
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
)
audio_len.append(len(tmp[-1]) / output_sr)
yield tmp
except Exception as e:
traceback.print_exc()
# 必须返回一个空音频, 否则会导致显存不释放。
audio_len.append(1)
yield 16000, np.zeros(int(16000), dtype=np.int16)
# 重置模型, 否则会导致显存释放不完全。
del self.t2s_model
@ -1295,6 +1320,16 @@ class TTS:
self.init_vits_weights(self.configs.vits_weights_path)
raise e
finally:
infer_speed_avg = sum(infer_len) / sum(infer_time)
rtf_value = sum((t1 - t0, t2 - t1, t_34, t_45)) / sum(audio_len)
console.print(f">> Time Stamps: {t1 - t0:.3f}\t{t2 - t1:.3f}\t{t_34:.3f}\t{t_45:.3f}")
console.print(f">> Infer Speed: {infer_speed_avg:.2f} Token/s")
console.print(f">> RTF: {rtf_value:.2f}")
if ttfb_time > 2:
console.print(f">> TTFB: {ttfb_time:.3f} s")
else:
console.print(f">> TTFB: {ttfb_time * 1000:.3f} ms")
self.empty_cache()
def empty_cache(self):