From 611ff1e8c02d584c97b053ec793696ceb9f5ed3f Mon Sep 17 00:00:00 2001
From: Jacky He
Date: Tue, 2 Sep 2025 17:48:44 +0800
Subject: [PATCH 1/3] feat: make GPU selectable in get_tts_wav

---
 GPT_SoVITS/inference_webui.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index a361ed58..9e4a4407 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -765,8 +765,17 @@ def get_tts_wav(
     sample_steps=8,
     if_sr=False,
     pause_second=0.3,
+    device_override=None,
 ):
     global cache
+    global device
+    if device_override:
+        device = device_override
+
+    # Check if models are loaded
+    if (tokenizer is None or bert_model is None or ssl_model is None or
+            vq_model is None or t2s_model is None):
+        raise RuntimeError("Models not loaded. Please call load_models() first.")
     if ref_wav_path:
         pass
     else:
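With this patch a single synthesis request can be routed to a chosen device. A minimal usage sketch at this point in the series, where models are still loaded at import time (the wav path, texts, and the "cuda:1" id are placeholders; the i18n labels follow the convention used by inference_cli):

    import soundfile as sf
    from GPT_SoVITS.inference_webui import get_tts_wav
    from tools.i18n.i18n import I18nAuto

    i18n = I18nAuto()

    # get_tts_wav is a generator; inference_cli takes the last yielded
    # (sampling_rate, audio) pair as the finished waveform.
    result = get_tts_wav(
        ref_wav_path="my_ref.wav",
        prompt_text="reference transcript",
        prompt_language=i18n("英文"),
        text="Hello from the second GPU.",
        text_language=i18n("英文"),
        device_override="cuda:1",  # new: overrides the module-level device
    )
    sampling_rate, audio = list(result)[-1]
    sf.write("output.wav", audio, sampling_rate)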
From 65cf7d67d57a1534675a543c4c672766f78abbc5 Mon Sep 17 00:00:00 2001
From: Jacky He
Date: Tue, 2 Sep 2025 17:50:34 +0800
Subject: [PATCH 2/3] refactor: move model loading into a load_models() function instead of loading at import time

---
 GPT_SoVITS/inference_webui.py | 104 +++++++++++++++++++++++++++--------
 1 file changed, 79 insertions(+), 25 deletions(-)

diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 9e4a4407..0e68a9c3 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -160,15 +160,14 @@ dict_language_v2 = {
 }
 dict_language = dict_language_v1 if version == "v1" else dict_language_v2
 
-tokenizer = AutoTokenizer.from_pretrained(bert_path)
-bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
-if is_half == True:
-    bert_model = bert_model.half().to(device)
-else:
-    bert_model = bert_model.to(device)
+# Initialize model variables as None - they will be loaded by load_models()
+tokenizer = None
+bert_model = None
 
 
 def get_bert_feature(text, word2ph):
+    if tokenizer is None or bert_model is None:
+        raise RuntimeError("Models not loaded. Please call load_models() first.")
     with torch.no_grad():
         inputs = tokenizer(text, return_tensors="pt")
         for i in inputs:
@@ -180,6 +179,8 @@ def get_bert_feature(text, word2ph):
     for i in range(len(word2ph)):
         repeat_feature = res[i].repeat(word2ph[i], 1)
         phone_level_feature.append(repeat_feature)
+    if len(phone_level_feature) == 0:
+        return torch.empty((res.shape[1], 0), dtype=res.dtype, device=res.device)
     phone_level_feature = torch.cat(phone_level_feature, dim=0)
     return phone_level_feature.T
 
@@ -212,11 +213,8 @@ class DictToAttrRecursive(dict):
         raise AttributeError(f"Attribute {item} not found")
 
 
-ssl_model = cnhubert.get_model()
-if is_half == True:
-    ssl_model = ssl_model.half().to(device)
-else:
-    ssl_model = ssl_model.to(device)
+# Initialize SSL model as None - it will be loaded by load_models()
+ssl_model = None
 
 
 ###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt
@@ -367,11 +365,13 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
         f.write(json.dumps(data))
 
 
-try:
-    next(change_sovits_weights(sovits_path))
-except:
-    pass
-
+# Initialize global model variables as None - they will be loaded by load_models()
+vq_model = None
+hps = None
+t2s_model = None
+config = None
+hz = None
+max_sec = None
 
 def change_gpt_weights(gpt_path):
     if "!" in gpt_path or "！" in gpt_path:
@@ -385,7 +385,8 @@ def change_gpt_weights(gpt_path):
     t2s_model.load_state_dict(dict_s1["weight"])
     if is_half == True:
-        t2s_model = t2s_model.half()
-    t2s_model = t2s_model.to(device)
+        t2s_model = t2s_model.half().to(device)
+    else:
+        t2s_model = t2s_model.to(device)
     t2s_model.eval()
     # total = sum([param.nelement() for param in t2s_model.parameters()])
     # print("Number of parameter: %.2fM" % (total / 1e6))
@@ -397,7 +398,8 @@ def change_gpt_weights(gpt_path):
         f.write(json.dumps(data))
 
 
-change_gpt_weights(gpt_path)
+# GPT weights are now loaded by load_models() instead of at import time
+# change_gpt_weights(gpt_path)
 os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
 import torch
@@ -495,17 +497,68 @@ def init_sv_cn():
     clean_hifigan_model()
 
 
+# Initialize vocoder model variables as None - they will be loaded by load_models()
 bigvgan_model = hifigan_model = sv_cn_model = None
-if model_version == "v3":
-    init_bigvgan()
-if model_version == "v4":
-    init_hifigan()
-if model_version in {"v2Pro", "v2ProPlus"}:
-    init_sv_cn()
+# Vocoder models are now loaded by load_models() instead of at import time
+# if model_version == "v3":
+#     init_bigvgan()
+# if model_version == "v4":
+#     init_hifigan()
+# if model_version in {"v2Pro", "v2ProPlus"}:
+#     init_sv_cn()
 
 resample_transform_dict = {}
 
 
+def load_models():
+    """
+    Load all models onto GPU. Call this function when you want to initialize models.
+    """
+    global tokenizer, bert_model, ssl_model, vq_model, hps, t2s_model, config, hz, max_sec
+    global bigvgan_model, hifigan_model, sv_cn_model
+
+    print("Loading models onto GPU...")
+
+    # Load BERT tokenizer and model
+    print("Loading BERT model...")
+    tokenizer = AutoTokenizer.from_pretrained(bert_path)
+    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
+    if is_half == True:
+        bert_model = bert_model.half().to(device)
+    else:
+        bert_model = bert_model.to(device)
+
+    # Load SSL model
+    print("Loading SSL model...")
+    ssl_model = cnhubert.get_model()
+    if is_half == True:
+        ssl_model = ssl_model.half().to(device)
+    else:
+        ssl_model = ssl_model.to(device)
+
+    # Load SoVITS model
+    print("Loading SoVITS model...")
+    try:
+        next(change_sovits_weights(sovits_path))
+    except:
+        pass
+
+    # Load GPT model
+    print("Loading GPT model...")
+    change_gpt_weights(gpt_path)
+
+    # Load appropriate vocoder model based on version
+    print(f"Loading vocoder model for version {model_version}...")
+    if model_version == "v3":
+        init_bigvgan()
+    elif model_version == "v4":
+        init_hifigan()
+    elif model_version in {"v2Pro", "v2ProPlus"}:
+        init_sv_cn()
+
+    print("All models loaded successfully!")
+
+
 def resample(audio_tensor, sr0, sr1, device):
     global resample_transform_dict
     key = "%s-%s-%s" % (sr0, sr1, str(device))
@@ -1353,6 +1406,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
 # gr.Markdown(html_center(i18n("后续将支持转音素、手工修改音素、语音合成分步执行。")))
 
 if __name__ == "__main__":
+    load_models()
     app.queue().launch(
         # concurrency_count=511, max_size=1022
         server_name="0.0.0.0",
         inbrowser=True,
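After this patch, importing the module no longer touches any weights; initialization is an explicit step. A sketch of the new contract (package import path as used by inference_cli):

    # Importing is now side-effect free with respect to model weights.
    from GPT_SoVITS import inference_webui as webui

    webui.load_models()  # one-time: BERT, SSL, SoVITS, GPT, and the vocoder

    # Before load_models() runs, get_tts_wav() and get_bert_feature() raise
    # RuntimeError("Models not loaded. Please call load_models() first.")
    assert webui.bert_model is not None and webui.t2s_model is not None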
From fc987e2a6512ae82ba27ec925c221d54cc54a3ae Mon Sep 17 00:00:00 2001
From: Jacky He
Date: Thu, 4 Sep 2025 09:38:04 +0800
Subject: [PATCH 3/3] refactor: centralize model loading logic

---
 GPT_SoVITS/inference_cli.py   |  3 +-
 GPT_SoVITS/inference_gui.py   |  3 +-
 GPT_SoVITS/inference_webui.py | 67 ++++++++++++++++++++++-------------
 3 files changed, 46 insertions(+), 27 deletions(-)

diff --git a/GPT_SoVITS/inference_cli.py b/GPT_SoVITS/inference_cli.py
index 459a3d36..b6736d34 100644
--- a/GPT_SoVITS/inference_cli.py
+++ b/GPT_SoVITS/inference_cli.py
@@ -3,7 +3,7 @@
 import os
 import soundfile as sf
 from tools.i18n.i18n import I18nAuto
-from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
+from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav, load_models
 
 i18n = I18nAuto()
 
@@ -69,6 +69,7 @@ def main():
     parser.add_argument("--output_path", required=True, help="Path to the output directory")
     args = parser.parse_args()
 
+    load_models()
     synthesize(
         args.gpt_model,

diff --git a/GPT_SoVITS/inference_gui.py b/GPT_SoVITS/inference_gui.py
index 379f7fa8..4efdbd0b 100644
--- a/GPT_SoVITS/inference_gui.py
+++ b/GPT_SoVITS/inference_gui.py
@@ -9,7 +9,7 @@ from tools.i18n.i18n import I18nAuto
 
 i18n = I18nAuto()
 
-from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
+from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav, load_models
 
 
 class GPTSoVITSGUI(QMainWindow):
@@ -281,6 +281,7 @@ class GPTSoVITSGUI(QMainWindow):
         target_text = self.target_text_input.text()
         output_path = self.output_input.text()
 
+        load_models()
         if GPT_model_path != self.GPT_Path:
             change_gpt_weights(gpt_path=GPT_model_path)
             self.GPT_Path = GPT_model_path
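One consequence of the GUI hook above: load_models() runs on every synthesis click and reloads every model each time. A caller that wants load-once semantics can guard the call; a possible wrapper (hypothetical helper, not part of this series):

    import GPT_SoVITS.inference_webui as webui

    def ensure_models_loaded():
        # The module's globals stay None until load_models() has run once.
        if webui.t2s_model is None or webui.vq_model is None:
            webui.load_models()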
""" global tokenizer, bert_model, ssl_model, vq_model, hps, t2s_model, config, hz, max_sec - global bigvgan_model, hifigan_model, sv_cn_model + global bigvgan_model, hifigan_model, sv_cn_model, device - print("Loading models onto GPU...") + # Use device override if provided, otherwise use global device + target_device = device_override if device_override is not None else device + + print(f"Loading models onto {target_device}...") # Load BERT tokenizer and model print("Loading BERT model...") tokenizer = AutoTokenizer.from_pretrained(bert_path) bert_model = AutoModelForMaskedLM.from_pretrained(bert_path) if is_half == True: - bert_model = bert_model.half().to(device) + bert_model = bert_model.half().to(target_device) else: - bert_model = bert_model.to(device) + bert_model = bert_model.to(target_device) # Load SSL model print("Loading SSL model...") ssl_model = cnhubert.get_model() if is_half == True: - ssl_model = ssl_model.half().to(device) + ssl_model = ssl_model.half().to(target_device) else: - ssl_model = ssl_model.to(device) + ssl_model = ssl_model.to(target_device) + + # Temporarily update global device if override is provided for model loading functions + original_device = device if device_override is not None else None + if device_override is not None: + device = target_device - # Load SoVITS model - print("Loading SoVITS model...") try: - next(change_sovits_weights(sovits_path)) - except: - pass - - # Load GPT model - print("Loading GPT model...") - change_gpt_weights(gpt_path) - - # Load appropriate vocoder model based on version - print(f"Loading vocoder model for version {model_version}...") - if model_version == "v3": - init_bigvgan() - elif model_version == "v4": - init_hifigan() - elif model_version in {"v2Pro", "v2ProPlus"}: - init_sv_cn() + # Load SoVITS model + print("Loading SoVITS model...") + try: + next(change_sovits_weights(sovits_path)) + except: + pass + + # Load GPT model + print("Loading GPT model...") + change_gpt_weights(gpt_path) + + # Load appropriate vocoder model based on version + print(f"Loading vocoder model for version {model_version}...") + if model_version == "v3": + init_bigvgan() + elif model_version == "v4": + init_hifigan() + elif model_version in {"v2Pro", "v2ProPlus"}: + init_sv_cn() + finally: + # Restore original device if it was overridden + if original_device is not None: + device = original_device print("All models loaded successfully!")