mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-03 20:40:30 +08:00
增加cuda graph支持,普通推理模式推理速度原地翻倍,效果不变。1
增加cuda graph支持,普通推理模式推理速度原地翻倍,效果不变。1
This commit is contained in:
parent
ea2d2a8166
commit
6d95b559e8
@ -137,6 +137,38 @@ if torch.cuda.is_available():
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
|
||||
def check_cuda_graph_support():
|
||||
if device != "cuda":
|
||||
return False
|
||||
try:
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
if major < 7:
|
||||
print("CUDA Graph: GPU compute capability < 7.0, disabled")
|
||||
return False
|
||||
a = torch.randn(2, 2, device="cuda")
|
||||
g = torch.cuda.CUDAGraph()
|
||||
s = torch.cuda.Stream()
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(s):
|
||||
b = a * 2
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
out = torch.empty_like(b)
|
||||
with torch.cuda.graph(g):
|
||||
out.copy_(a * 2)
|
||||
g.replay()
|
||||
torch.cuda.synchronize()
|
||||
del a, b, out, g, s
|
||||
torch.cuda.empty_cache()
|
||||
print("CUDA Graph: support check passed, auto-enabled")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"CUDA Graph: support check failed ({e}), disabled")
|
||||
return False
|
||||
|
||||
|
||||
cuda_graph_supported = check_cuda_graph_support()
|
||||
|
||||
dict_language_v1 = {
|
||||
i18n("中文"): "all_zh", # 全部按中文识别
|
||||
i18n("英文"): "en", # 全部按英文识别#######不变
|
||||
@ -229,7 +261,8 @@ v3v4set = {"v3", "v4"}
|
||||
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
|
||||
if "!" in sovits_path or "!" in sovits_path:
|
||||
sovits_path = name2sovits_path[sovits_path]
|
||||
global vq_model, hps, version, model_version, dict_language, if_lora_v3
|
||||
global vq_model, hps, version, model_version, dict_language, if_lora_v3, t2s_model_cudagraph
|
||||
t2s_model_cudagraph = None
|
||||
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
||||
print(sovits_path, version, model_version, if_lora_v3)
|
||||
is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
|
||||
@ -373,10 +406,16 @@ except:
|
||||
pass
|
||||
|
||||
|
||||
t2s_model_cudagraph = None
|
||||
gpt_path_global = None
|
||||
|
||||
|
||||
def change_gpt_weights(gpt_path):
|
||||
if "!" in gpt_path or "!" in gpt_path:
|
||||
gpt_path = name2gpt_path[gpt_path]
|
||||
global hz, max_sec, t2s_model, config
|
||||
global hz, max_sec, t2s_model, config, t2s_model_cudagraph, gpt_path_global
|
||||
t2s_model_cudagraph = None
|
||||
gpt_path_global = gpt_path
|
||||
hz = 50
|
||||
dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False)
|
||||
config = dict_s1["config"]
|
||||
@ -765,6 +804,7 @@ def get_tts_wav(
|
||||
sample_steps=8,
|
||||
if_sr=False,
|
||||
pause_second=0.3,
|
||||
use_cuda_graph=False,
|
||||
):
|
||||
global cache
|
||||
if ref_wav_path:
|
||||
@ -874,20 +914,51 @@ def get_tts_wav(
|
||||
if i_text in cache and if_freeze == True:
|
||||
pred_semantic = cache[i_text]
|
||||
else:
|
||||
with torch.no_grad():
|
||||
pred_semantic, idx = t2s_model.model.infer_panel(
|
||||
all_phoneme_ids,
|
||||
all_phoneme_len,
|
||||
None if ref_free else prompt,
|
||||
bert,
|
||||
# prompt_phone_len=ph_offset,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
early_stop_num=hz * max_sec,
|
||||
)
|
||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
|
||||
cache[i_text] = pred_semantic
|
||||
if use_cuda_graph and device == "cuda":
|
||||
global t2s_model_cudagraph
|
||||
if t2s_model_cudagraph is None:
|
||||
from AR.models.t2s_model_cudagraph import CUDAGraphRunner
|
||||
t2s_model_cudagraph = CUDAGraphRunner(
|
||||
CUDAGraphRunner.load_decoder(gpt_path_global),
|
||||
torch.device(device),
|
||||
torch.float16 if is_half else torch.float32,
|
||||
)
|
||||
from AR.models.structs_cudagraph import T2SRequest
|
||||
with torch.no_grad():
|
||||
t2s_request = T2SRequest(
|
||||
[all_phoneme_ids.squeeze(0)],
|
||||
all_phoneme_len,
|
||||
all_phoneme_ids.new_zeros((1, 0)) if ref_free else prompt,
|
||||
[bert.squeeze(0)],
|
||||
valid_length=1,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
early_stop_num=hz * max_sec,
|
||||
use_cuda_graph=True,
|
||||
)
|
||||
t2s_result = t2s_model_cudagraph.generate(t2s_request)
|
||||
if t2s_result.exception is not None:
|
||||
print(t2s_result.exception)
|
||||
print(t2s_result.traceback)
|
||||
raise RuntimeError("CUDA Graph T2S inference failed")
|
||||
pred_semantic = t2s_result.result[0].unsqueeze(0).unsqueeze(0)
|
||||
cache[i_text] = pred_semantic
|
||||
else:
|
||||
with torch.no_grad():
|
||||
pred_semantic, idx = t2s_model.model.infer_panel(
|
||||
all_phoneme_ids,
|
||||
all_phoneme_len,
|
||||
None if ref_free else prompt,
|
||||
bert,
|
||||
# prompt_phone_len=ph_offset,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
early_stop_num=hz * max_sec,
|
||||
)
|
||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
|
||||
cache[i_text] = pred_semantic
|
||||
t3 = ttime()
|
||||
is_v2pro = model_version in {"v2Pro", "v2ProPlus"}
|
||||
# print(23333,is_v2pro,model_version)
|
||||
@ -1284,6 +1355,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
# get_phoneme_button = gr.Button(i18n("目标文本转音素"), variant="primary")
|
||||
with gr.Row():
|
||||
inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25)
|
||||
use_cuda_graph_checkbox = gr.Checkbox(
|
||||
label="CUDA Graph " + i18n("加速"),
|
||||
value=cuda_graph_supported,
|
||||
interactive=True if torch.cuda.is_available() else False,
|
||||
show_label=True,
|
||||
scale=5,
|
||||
visible=False,
|
||||
)
|
||||
output = gr.Audio(label=i18n("输出的语音"), scale=14)
|
||||
|
||||
inference_button.click(
|
||||
@ -1305,6 +1384,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
sample_steps,
|
||||
if_sr_Checkbox,
|
||||
pause_second_slider,
|
||||
use_cuda_graph_checkbox,
|
||||
],
|
||||
[output],
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user