增加cuda graph支持,普通推理模式推理速度原地翻倍,效果不变。1

增加cuda graph支持,普通推理模式推理速度原地翻倍,效果不变。1
This commit is contained in:
RVC-Boss 2026-04-30 15:01:11 +08:00 committed by GitHub
parent ea2d2a8166
commit 6d95b559e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -137,6 +137,38 @@ if torch.cuda.is_available():
else: else:
device = "cpu" device = "cpu"
def check_cuda_graph_support():
if device != "cuda":
return False
try:
major, _ = torch.cuda.get_device_capability()
if major < 7:
print("CUDA Graph: GPU compute capability < 7.0, disabled")
return False
a = torch.randn(2, 2, device="cuda")
g = torch.cuda.CUDAGraph()
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
b = a * 2
torch.cuda.current_stream().wait_stream(s)
out = torch.empty_like(b)
with torch.cuda.graph(g):
out.copy_(a * 2)
g.replay()
torch.cuda.synchronize()
del a, b, out, g, s
torch.cuda.empty_cache()
print("CUDA Graph: support check passed, auto-enabled")
return True
except Exception as e:
print(f"CUDA Graph: support check failed ({e}), disabled")
return False
cuda_graph_supported = check_cuda_graph_support()
dict_language_v1 = { dict_language_v1 = {
i18n("中文"): "all_zh", # 全部按中文识别 i18n("中文"): "all_zh", # 全部按中文识别
i18n("英文"): "en", # 全部按英文识别#######不变 i18n("英文"): "en", # 全部按英文识别#######不变
@ -229,7 +261,8 @@ v3v4set = {"v3", "v4"}
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None): def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
if "" in sovits_path or "!" in sovits_path: if "" in sovits_path or "!" in sovits_path:
sovits_path = name2sovits_path[sovits_path] sovits_path = name2sovits_path[sovits_path]
global vq_model, hps, version, model_version, dict_language, if_lora_v3 global vq_model, hps, version, model_version, dict_language, if_lora_v3, t2s_model_cudagraph
t2s_model_cudagraph = None
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
print(sovits_path, version, model_version, if_lora_v3) print(sovits_path, version, model_version, if_lora_v3)
is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4 is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
@ -373,10 +406,16 @@ except:
pass pass
t2s_model_cudagraph = None
gpt_path_global = None
def change_gpt_weights(gpt_path): def change_gpt_weights(gpt_path):
if "" in gpt_path or "!" in gpt_path: if "" in gpt_path or "!" in gpt_path:
gpt_path = name2gpt_path[gpt_path] gpt_path = name2gpt_path[gpt_path]
global hz, max_sec, t2s_model, config global hz, max_sec, t2s_model, config, t2s_model_cudagraph, gpt_path_global
t2s_model_cudagraph = None
gpt_path_global = gpt_path
hz = 50 hz = 50
dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False) dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False)
config = dict_s1["config"] config = dict_s1["config"]
@ -765,6 +804,7 @@ def get_tts_wav(
sample_steps=8, sample_steps=8,
if_sr=False, if_sr=False,
pause_second=0.3, pause_second=0.3,
use_cuda_graph=False,
): ):
global cache global cache
if ref_wav_path: if ref_wav_path:
@ -874,20 +914,51 @@ def get_tts_wav(
if i_text in cache and if_freeze == True: if i_text in cache and if_freeze == True:
pred_semantic = cache[i_text] pred_semantic = cache[i_text]
else: else:
with torch.no_grad(): if use_cuda_graph and device == "cuda":
pred_semantic, idx = t2s_model.model.infer_panel( global t2s_model_cudagraph
all_phoneme_ids, if t2s_model_cudagraph is None:
all_phoneme_len, from AR.models.t2s_model_cudagraph import CUDAGraphRunner
None if ref_free else prompt, t2s_model_cudagraph = CUDAGraphRunner(
bert, CUDAGraphRunner.load_decoder(gpt_path_global),
# prompt_phone_len=ph_offset, torch.device(device),
top_k=top_k, torch.float16 if is_half else torch.float32,
top_p=top_p, )
temperature=temperature, from AR.models.structs_cudagraph import T2SRequest
early_stop_num=hz * max_sec, with torch.no_grad():
) t2s_request = T2SRequest(
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) [all_phoneme_ids.squeeze(0)],
cache[i_text] = pred_semantic all_phoneme_len,
all_phoneme_ids.new_zeros((1, 0)) if ref_free else prompt,
[bert.squeeze(0)],
valid_length=1,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=hz * max_sec,
use_cuda_graph=True,
)
t2s_result = t2s_model_cudagraph.generate(t2s_request)
if t2s_result.exception is not None:
print(t2s_result.exception)
print(t2s_result.traceback)
raise RuntimeError("CUDA Graph T2S inference failed")
pred_semantic = t2s_result.result[0].unsqueeze(0).unsqueeze(0)
cache[i_text] = pred_semantic
else:
with torch.no_grad():
pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
None if ref_free else prompt,
bert,
# prompt_phone_len=ph_offset,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=hz * max_sec,
)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
cache[i_text] = pred_semantic
t3 = ttime() t3 = ttime()
is_v2pro = model_version in {"v2Pro", "v2ProPlus"} is_v2pro = model_version in {"v2Pro", "v2ProPlus"}
# print(23333,is_v2pro,model_version) # print(23333,is_v2pro,model_version)
@ -1284,6 +1355,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
# get_phoneme_button = gr.Button(i18n("目标文本转音素"), variant="primary") # get_phoneme_button = gr.Button(i18n("目标文本转音素"), variant="primary")
with gr.Row(): with gr.Row():
inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25) inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25)
use_cuda_graph_checkbox = gr.Checkbox(
label="CUDA Graph " + i18n("加速"),
value=cuda_graph_supported,
interactive=True if torch.cuda.is_available() else False,
show_label=True,
scale=5,
visible=False,
)
output = gr.Audio(label=i18n("输出的语音"), scale=14) output = gr.Audio(label=i18n("输出的语音"), scale=14)
inference_button.click( inference_button.click(
@ -1305,6 +1384,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
sample_steps, sample_steps,
if_sr_Checkbox, if_sr_Checkbox,
pause_second_slider, pause_second_slider,
use_cuda_graph_checkbox,
], ],
[output], [output],
) )