mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-03 20:40:30 +08:00
增加cuda graph支持,普通推理模式推理速度原地翻倍,效果不变。1
增加cuda graph支持,普通推理模式推理速度原地翻倍,效果不变。1
This commit is contained in:
parent
ea2d2a8166
commit
6d95b559e8
@ -137,6 +137,38 @@ if torch.cuda.is_available():
|
|||||||
else:
|
else:
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def check_cuda_graph_support():
|
||||||
|
if device != "cuda":
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
major, _ = torch.cuda.get_device_capability()
|
||||||
|
if major < 7:
|
||||||
|
print("CUDA Graph: GPU compute capability < 7.0, disabled")
|
||||||
|
return False
|
||||||
|
a = torch.randn(2, 2, device="cuda")
|
||||||
|
g = torch.cuda.CUDAGraph()
|
||||||
|
s = torch.cuda.Stream()
|
||||||
|
s.wait_stream(torch.cuda.current_stream())
|
||||||
|
with torch.cuda.stream(s):
|
||||||
|
b = a * 2
|
||||||
|
torch.cuda.current_stream().wait_stream(s)
|
||||||
|
out = torch.empty_like(b)
|
||||||
|
with torch.cuda.graph(g):
|
||||||
|
out.copy_(a * 2)
|
||||||
|
g.replay()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
del a, b, out, g, s
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
print("CUDA Graph: support check passed, auto-enabled")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"CUDA Graph: support check failed ({e}), disabled")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
cuda_graph_supported = check_cuda_graph_support()
|
||||||
|
|
||||||
dict_language_v1 = {
|
dict_language_v1 = {
|
||||||
i18n("中文"): "all_zh", # 全部按中文识别
|
i18n("中文"): "all_zh", # 全部按中文识别
|
||||||
i18n("英文"): "en", # 全部按英文识别#######不变
|
i18n("英文"): "en", # 全部按英文识别#######不变
|
||||||
@ -229,7 +261,8 @@ v3v4set = {"v3", "v4"}
|
|||||||
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
|
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
|
||||||
if "!" in sovits_path or "!" in sovits_path:
|
if "!" in sovits_path or "!" in sovits_path:
|
||||||
sovits_path = name2sovits_path[sovits_path]
|
sovits_path = name2sovits_path[sovits_path]
|
||||||
global vq_model, hps, version, model_version, dict_language, if_lora_v3
|
global vq_model, hps, version, model_version, dict_language, if_lora_v3, t2s_model_cudagraph
|
||||||
|
t2s_model_cudagraph = None
|
||||||
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
||||||
print(sovits_path, version, model_version, if_lora_v3)
|
print(sovits_path, version, model_version, if_lora_v3)
|
||||||
is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
|
is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
|
||||||
@ -373,10 +406,16 @@ except:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
t2s_model_cudagraph = None
|
||||||
|
gpt_path_global = None
|
||||||
|
|
||||||
|
|
||||||
def change_gpt_weights(gpt_path):
|
def change_gpt_weights(gpt_path):
|
||||||
if "!" in gpt_path or "!" in gpt_path:
|
if "!" in gpt_path or "!" in gpt_path:
|
||||||
gpt_path = name2gpt_path[gpt_path]
|
gpt_path = name2gpt_path[gpt_path]
|
||||||
global hz, max_sec, t2s_model, config
|
global hz, max_sec, t2s_model, config, t2s_model_cudagraph, gpt_path_global
|
||||||
|
t2s_model_cudagraph = None
|
||||||
|
gpt_path_global = gpt_path
|
||||||
hz = 50
|
hz = 50
|
||||||
dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False)
|
dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False)
|
||||||
config = dict_s1["config"]
|
config = dict_s1["config"]
|
||||||
@ -765,6 +804,7 @@ def get_tts_wav(
|
|||||||
sample_steps=8,
|
sample_steps=8,
|
||||||
if_sr=False,
|
if_sr=False,
|
||||||
pause_second=0.3,
|
pause_second=0.3,
|
||||||
|
use_cuda_graph=False,
|
||||||
):
|
):
|
||||||
global cache
|
global cache
|
||||||
if ref_wav_path:
|
if ref_wav_path:
|
||||||
@ -873,6 +913,37 @@ def get_tts_wav(
|
|||||||
# print(cache.keys(),if_freeze)
|
# print(cache.keys(),if_freeze)
|
||||||
if i_text in cache and if_freeze == True:
|
if i_text in cache and if_freeze == True:
|
||||||
pred_semantic = cache[i_text]
|
pred_semantic = cache[i_text]
|
||||||
|
else:
|
||||||
|
if use_cuda_graph and device == "cuda":
|
||||||
|
global t2s_model_cudagraph
|
||||||
|
if t2s_model_cudagraph is None:
|
||||||
|
from AR.models.t2s_model_cudagraph import CUDAGraphRunner
|
||||||
|
t2s_model_cudagraph = CUDAGraphRunner(
|
||||||
|
CUDAGraphRunner.load_decoder(gpt_path_global),
|
||||||
|
torch.device(device),
|
||||||
|
torch.float16 if is_half else torch.float32,
|
||||||
|
)
|
||||||
|
from AR.models.structs_cudagraph import T2SRequest
|
||||||
|
with torch.no_grad():
|
||||||
|
t2s_request = T2SRequest(
|
||||||
|
[all_phoneme_ids.squeeze(0)],
|
||||||
|
all_phoneme_len,
|
||||||
|
all_phoneme_ids.new_zeros((1, 0)) if ref_free else prompt,
|
||||||
|
[bert.squeeze(0)],
|
||||||
|
valid_length=1,
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
temperature=temperature,
|
||||||
|
early_stop_num=hz * max_sec,
|
||||||
|
use_cuda_graph=True,
|
||||||
|
)
|
||||||
|
t2s_result = t2s_model_cudagraph.generate(t2s_request)
|
||||||
|
if t2s_result.exception is not None:
|
||||||
|
print(t2s_result.exception)
|
||||||
|
print(t2s_result.traceback)
|
||||||
|
raise RuntimeError("CUDA Graph T2S inference failed")
|
||||||
|
pred_semantic = t2s_result.result[0].unsqueeze(0).unsqueeze(0)
|
||||||
|
cache[i_text] = pred_semantic
|
||||||
else:
|
else:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pred_semantic, idx = t2s_model.model.infer_panel(
|
pred_semantic, idx = t2s_model.model.infer_panel(
|
||||||
@ -1284,6 +1355,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
|||||||
# get_phoneme_button = gr.Button(i18n("目标文本转音素"), variant="primary")
|
# get_phoneme_button = gr.Button(i18n("目标文本转音素"), variant="primary")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25)
|
inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25)
|
||||||
|
use_cuda_graph_checkbox = gr.Checkbox(
|
||||||
|
label="CUDA Graph " + i18n("加速"),
|
||||||
|
value=cuda_graph_supported,
|
||||||
|
interactive=True if torch.cuda.is_available() else False,
|
||||||
|
show_label=True,
|
||||||
|
scale=5,
|
||||||
|
visible=False,
|
||||||
|
)
|
||||||
output = gr.Audio(label=i18n("输出的语音"), scale=14)
|
output = gr.Audio(label=i18n("输出的语音"), scale=14)
|
||||||
|
|
||||||
inference_button.click(
|
inference_button.click(
|
||||||
@ -1305,6 +1384,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
|||||||
sample_steps,
|
sample_steps,
|
||||||
if_sr_Checkbox,
|
if_sr_Checkbox,
|
||||||
pause_second_slider,
|
pause_second_slider,
|
||||||
|
use_cuda_graph_checkbox,
|
||||||
],
|
],
|
||||||
[output],
|
[output],
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user