增加cuda graph支持,普通推理模式推理速度原地翻倍,效果不变。1

增加cuda graph支持,普通推理模式推理速度原地翻倍,效果不变。1
This commit is contained in:
RVC-Boss 2026-04-30 15:01:11 +08:00 committed by GitHub
parent ea2d2a8166
commit 6d95b559e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -137,6 +137,38 @@ if torch.cuda.is_available():
else:
device = "cpu"
def check_cuda_graph_support():
if device != "cuda":
return False
try:
major, _ = torch.cuda.get_device_capability()
if major < 7:
print("CUDA Graph: GPU compute capability < 7.0, disabled")
return False
a = torch.randn(2, 2, device="cuda")
g = torch.cuda.CUDAGraph()
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
b = a * 2
torch.cuda.current_stream().wait_stream(s)
out = torch.empty_like(b)
with torch.cuda.graph(g):
out.copy_(a * 2)
g.replay()
torch.cuda.synchronize()
del a, b, out, g, s
torch.cuda.empty_cache()
print("CUDA Graph: support check passed, auto-enabled")
return True
except Exception as e:
print(f"CUDA Graph: support check failed ({e}), disabled")
return False
cuda_graph_supported = check_cuda_graph_support()
dict_language_v1 = {
i18n("中文"): "all_zh", # 全部按中文识别
i18n("英文"): "en", # 全部按英文识别#######不变
@ -229,7 +261,8 @@ v3v4set = {"v3", "v4"}
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
if "" in sovits_path or "!" in sovits_path:
sovits_path = name2sovits_path[sovits_path]
global vq_model, hps, version, model_version, dict_language, if_lora_v3
global vq_model, hps, version, model_version, dict_language, if_lora_v3, t2s_model_cudagraph
t2s_model_cudagraph = None
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
print(sovits_path, version, model_version, if_lora_v3)
is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
@ -373,10 +406,16 @@ except:
pass
t2s_model_cudagraph = None
gpt_path_global = None
def change_gpt_weights(gpt_path):
if "" in gpt_path or "!" in gpt_path:
gpt_path = name2gpt_path[gpt_path]
global hz, max_sec, t2s_model, config
global hz, max_sec, t2s_model, config, t2s_model_cudagraph, gpt_path_global
t2s_model_cudagraph = None
gpt_path_global = gpt_path
hz = 50
dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False)
config = dict_s1["config"]
@ -765,6 +804,7 @@ def get_tts_wav(
sample_steps=8,
if_sr=False,
pause_second=0.3,
use_cuda_graph=False,
):
global cache
if ref_wav_path:
@ -873,6 +913,37 @@ def get_tts_wav(
# print(cache.keys(),if_freeze)
if i_text in cache and if_freeze == True:
pred_semantic = cache[i_text]
else:
if use_cuda_graph and device == "cuda":
global t2s_model_cudagraph
if t2s_model_cudagraph is None:
from AR.models.t2s_model_cudagraph import CUDAGraphRunner
t2s_model_cudagraph = CUDAGraphRunner(
CUDAGraphRunner.load_decoder(gpt_path_global),
torch.device(device),
torch.float16 if is_half else torch.float32,
)
from AR.models.structs_cudagraph import T2SRequest
with torch.no_grad():
t2s_request = T2SRequest(
[all_phoneme_ids.squeeze(0)],
all_phoneme_len,
all_phoneme_ids.new_zeros((1, 0)) if ref_free else prompt,
[bert.squeeze(0)],
valid_length=1,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=hz * max_sec,
use_cuda_graph=True,
)
t2s_result = t2s_model_cudagraph.generate(t2s_request)
if t2s_result.exception is not None:
print(t2s_result.exception)
print(t2s_result.traceback)
raise RuntimeError("CUDA Graph T2S inference failed")
pred_semantic = t2s_result.result[0].unsqueeze(0).unsqueeze(0)
cache[i_text] = pred_semantic
else:
with torch.no_grad():
pred_semantic, idx = t2s_model.model.infer_panel(
@ -1284,6 +1355,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
# get_phoneme_button = gr.Button(i18n("目标文本转音素"), variant="primary")
with gr.Row():
inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25)
use_cuda_graph_checkbox = gr.Checkbox(
label="CUDA Graph " + i18n("加速"),
value=cuda_graph_supported,
interactive=True if torch.cuda.is_available() else False,
show_label=True,
scale=5,
visible=False,
)
output = gr.Audio(label=i18n("输出的语音"), scale=14)
inference_button.click(
@ -1305,6 +1384,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
sample_steps,
if_sr_Checkbox,
pause_second_slider,
use_cuda_graph_checkbox,
],
[output],
)