mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-30 18:38:31 +08:00
.
This commit is contained in:
parent
620b7810e6
commit
4679ca9fc5
@ -153,7 +153,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
pos_sorted = mx.sort(pos, axis=0)
|
||||
valid_count = session.bsz - mx.sum(cast(Array, pos_sorted == session.bsz))
|
||||
pos_final = pos_sorted[: int(valid_count)]
|
||||
newly_done_indices = mx.expand_dims(newly_done_indices[pos_final], 0)
|
||||
newly_done_indices = newly_done_indices[pos_final]
|
||||
mx.set_default_device(self.device)
|
||||
|
||||
if debug:
|
||||
|
||||
@ -4,6 +4,7 @@ import os
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@ -19,7 +20,7 @@ from peft import LoraConfig, get_peft_model
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
from GPT_SoVITS.Accelerate import MLX, PyTorch, T2SEngineProtocol, T2SRequest, backends
|
||||
from GPT_SoVITS.Accelerate import MLX, PyTorch, T2SEngineProtocol, T2SRequest, backends, console
|
||||
from GPT_SoVITS.BigVGAN.bigvgan import BigVGAN
|
||||
from GPT_SoVITS.feature_extractor.cnhubert import CNHubert
|
||||
from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch
|
||||
@ -37,6 +38,14 @@ now_dir = os.getcwd()
|
||||
resample_transform_dict = {}
|
||||
v3v4set = {"v3", "v4"}
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore", message="MPS: The constant padding of more than 3 dimensions is not currently supported natively."
|
||||
)
|
||||
warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*")
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
|
||||
|
||||
def resample(audio_tensor, sr0, sr1, device):
|
||||
global resample_transform_dict
|
||||
@ -963,6 +972,7 @@ class TTS:
|
||||
Tuple[int, np.ndarray]: sampling rate and audio data.
|
||||
"""
|
||||
########## variables initialization ###########
|
||||
ttfb_time = time.perf_counter()
|
||||
self.stop_flag: bool = False
|
||||
text: str = inputs.get("text", "")
|
||||
text_lang: str = inputs.get("text_lang", "")
|
||||
@ -1013,6 +1023,7 @@ class TTS:
|
||||
)
|
||||
|
||||
###### setting reference audio and prompt text preprocessing ########
|
||||
t_34 = t_45 = 0.0
|
||||
t0 = time.perf_counter()
|
||||
if (ref_audio_path is not None) and (
|
||||
ref_audio_path != self.prompt_cache["ref_audio_path"]
|
||||
@ -1107,6 +1118,9 @@ class TTS:
|
||||
return batch[0]
|
||||
|
||||
t2 = time.perf_counter()
|
||||
infer_len: list[int] = []
|
||||
infer_time: list[float] = []
|
||||
audio_len = [0.0]
|
||||
try:
|
||||
print("############ 推理 ############")
|
||||
###### inference ######
|
||||
@ -1114,7 +1128,7 @@ class TTS:
|
||||
t_45 = 0.0
|
||||
audio = []
|
||||
output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
|
||||
for item in data:
|
||||
for idx, item in enumerate(data):
|
||||
t3 = time.perf_counter()
|
||||
if return_fragment:
|
||||
item = make_batch(item)
|
||||
@ -1158,6 +1172,8 @@ class TTS:
|
||||
pred_semantic_list = t2s_result.result
|
||||
assert pred_semantic_list
|
||||
pred_semantic_list = [semantic.squeeze(0) for semantic in pred_semantic_list]
|
||||
infer_len.append(t2s_result.total_tokens)
|
||||
infer_time.append(t2s_result.infer_speed[-1])
|
||||
|
||||
t4 = time.perf_counter()
|
||||
t_34 += t4 - t3
|
||||
@ -1246,12 +1262,13 @@ class TTS:
|
||||
_pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
|
||||
)
|
||||
batch_audio_fragment.append(audio_fragment)
|
||||
|
||||
if idx == 0:
|
||||
ttfb_time = time.perf_counter() - ttfb_time
|
||||
t5 = time.perf_counter()
|
||||
t_45 += t5 - t4
|
||||
if return_fragment:
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
|
||||
yield self.audio_postprocess(
|
||||
console.print(f">> Time Stamps For Fragment: {t_34:.3f}\t{t_45:.3f}")
|
||||
tmp = self.audio_postprocess(
|
||||
[batch_audio_fragment],
|
||||
output_sr,
|
||||
None,
|
||||
@ -1260,19 +1277,23 @@ class TTS:
|
||||
fragment_interval,
|
||||
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
|
||||
)
|
||||
audio_len.append(len(tmp[-1]) / output_sr)
|
||||
yield tmp
|
||||
else:
|
||||
audio.append(batch_audio_fragment)
|
||||
|
||||
if self.stop_flag:
|
||||
audio_len.append(1)
|
||||
yield 16000, np.zeros(int(16000), dtype=np.int16)
|
||||
return
|
||||
|
||||
if not return_fragment:
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
|
||||
if len(audio) == 0:
|
||||
audio_len.append(1)
|
||||
yield 16000, np.zeros(int(16000), dtype=np.int16)
|
||||
return
|
||||
yield self.audio_postprocess(
|
||||
|
||||
tmp = self.audio_postprocess(
|
||||
audio,
|
||||
output_sr,
|
||||
batch_index_list,
|
||||
@ -1281,10 +1302,14 @@ class TTS:
|
||||
fragment_interval,
|
||||
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
|
||||
)
|
||||
audio_len.append(len(tmp[-1]) / output_sr)
|
||||
|
||||
yield tmp
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
# 必须返回一个空音频, 否则会导致显存不释放。
|
||||
audio_len.append(1)
|
||||
yield 16000, np.zeros(int(16000), dtype=np.int16)
|
||||
# 重置模型, 否则会导致显存释放不完全。
|
||||
del self.t2s_model
|
||||
@ -1295,6 +1320,16 @@ class TTS:
|
||||
self.init_vits_weights(self.configs.vits_weights_path)
|
||||
raise e
|
||||
finally:
|
||||
infer_speed_avg = sum(infer_len) / sum(infer_time)
|
||||
rtf_value = sum((t1 - t0, t2 - t1, t_34, t_45)) / sum(audio_len)
|
||||
console.print(f">> Time Stamps: {t1 - t0:.3f}\t{t2 - t1:.3f}\t{t_34:.3f}\t{t_45:.3f}")
|
||||
console.print(f">> Infer Speed: {infer_speed_avg:.2f} Token/s")
|
||||
console.print(f">> RTF: {rtf_value:.2f}")
|
||||
if ttfb_time > 2:
|
||||
console.print(f">> TTFB: {ttfb_time:.3f} s")
|
||||
else:
|
||||
console.print(f">> TTFB: {ttfb_time * 1000:.3f} ms")
|
||||
|
||||
self.empty_cache()
|
||||
|
||||
def empty_cache(self):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user