diff --git a/GPT_SoVITS/inference_cli.py b/GPT_SoVITS/inference_cli.py index 459a3d36..b6736d34 100644 --- a/GPT_SoVITS/inference_cli.py +++ b/GPT_SoVITS/inference_cli.py @@ -3,7 +3,7 @@ import os import soundfile as sf from tools.i18n.i18n import I18nAuto -from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav +from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav, load_models i18n = I18nAuto() @@ -69,6 +69,7 @@ def main(): parser.add_argument("--output_path", required=True, help="Path to the output directory") args = parser.parse_args() + load_models() synthesize( args.gpt_model, diff --git a/GPT_SoVITS/inference_gui.py b/GPT_SoVITS/inference_gui.py index 379f7fa8..4efdbd0b 100644 --- a/GPT_SoVITS/inference_gui.py +++ b/GPT_SoVITS/inference_gui.py @@ -9,7 +9,7 @@ from tools.i18n.i18n import I18nAuto i18n = I18nAuto() -from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav +from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav, load_models class GPTSoVITSGUI(QMainWindow): @@ -281,6 +281,7 @@ class GPTSoVITSGUI(QMainWindow): target_text = self.target_text_input.text() output_path = self.output_input.text() + load_models() if GPT_model_path != self.GPT_Path: change_gpt_weights(gpt_path=GPT_model_path) self.GPT_Path = GPT_model_path diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 0e68a9c3..33bbb26a 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -510,51 +510,68 @@ bigvgan_model = hifigan_model = sv_cn_model = None resample_transform_dict = {} -def load_models( ): +def load_models(device_override=None): """ Load all models onto GPU. Call this function when you want to initialize models. + + Args: + device_override (str, optional): Override the global device setting. + If None, uses the global device variable. """ global tokenizer, bert_model, ssl_model, vq_model, hps, t2s_model, config, hz, max_sec - global bigvgan_model, hifigan_model, sv_cn_model + global bigvgan_model, hifigan_model, sv_cn_model, device - print("Loading models onto GPU...") + # Use device override if provided, otherwise use global device + target_device = device_override if device_override is not None else device + + print(f"Loading models onto {target_device}...") # Load BERT tokenizer and model print("Loading BERT model...") tokenizer = AutoTokenizer.from_pretrained(bert_path) bert_model = AutoModelForMaskedLM.from_pretrained(bert_path) if is_half == True: - bert_model = bert_model.half().to(device) + bert_model = bert_model.half().to(target_device) else: - bert_model = bert_model.to(device) + bert_model = bert_model.to(target_device) # Load SSL model print("Loading SSL model...") ssl_model = cnhubert.get_model() if is_half == True: - ssl_model = ssl_model.half().to(device) + ssl_model = ssl_model.half().to(target_device) else: - ssl_model = ssl_model.to(device) + ssl_model = ssl_model.to(target_device) + + # Temporarily update global device if override is provided for model loading functions + original_device = device if device_override is not None else None + if device_override is not None: + device = target_device - # Load SoVITS model - print("Loading SoVITS model...") try: - next(change_sovits_weights(sovits_path)) - except: - pass - - # Load GPT model - print("Loading GPT model...") - change_gpt_weights(gpt_path) - - # Load appropriate vocoder model based on version - print(f"Loading vocoder model for version {model_version}...") - if model_version == "v3": - init_bigvgan() - elif model_version == "v4": - init_hifigan() - elif model_version in {"v2Pro", "v2ProPlus"}: - init_sv_cn() + # Load SoVITS model + print("Loading SoVITS model...") + try: + next(change_sovits_weights(sovits_path)) + except: + pass + + # Load GPT model + print("Loading GPT model...") + change_gpt_weights(gpt_path) + + # Load appropriate vocoder model based on version + print(f"Loading vocoder model for version {model_version}...") + if model_version == "v3": + init_bigvgan() + elif model_version == "v4": + init_hifigan() + elif model_version in {"v2Pro", "v2ProPlus"}: + init_sv_cn() + finally: + # Restore original device if it was overridden + if original_device is not None: + device = original_device print("All models loaded successfully!")