refactor: centralize model loading logic

This commit is contained in:
Jacky He 2025-09-04 09:38:04 +08:00
parent 65cf7d67d5
commit fc987e2a65
3 changed files with 46 additions and 27 deletions

View File

@ -3,7 +3,7 @@ import os
import soundfile as sf
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav, load_models
i18n = I18nAuto()
@ -69,6 +69,7 @@ def main():
parser.add_argument("--output_path", required=True, help="Path to the output directory")
args = parser.parse_args()
load_models()
synthesize(
args.gpt_model,

View File

@ -9,7 +9,7 @@ from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav, load_models
class GPTSoVITSGUI(QMainWindow):
@ -281,6 +281,7 @@ class GPTSoVITSGUI(QMainWindow):
target_text = self.target_text_input.text()
output_path = self.output_input.text()
load_models()
if GPT_model_path != self.GPT_Path:
change_gpt_weights(gpt_path=GPT_model_path)
self.GPT_Path = GPT_model_path

View File

@ -510,51 +510,68 @@ bigvgan_model = hifigan_model = sv_cn_model = None
resample_transform_dict = {}
def load_models( ):
def load_models(device_override=None):
"""
Load all models onto GPU. Call this function when you want to initialize models.
Args:
device_override (str, optional): Override the global device setting.
If None, uses the global device variable.
"""
global tokenizer, bert_model, ssl_model, vq_model, hps, t2s_model, config, hz, max_sec
global bigvgan_model, hifigan_model, sv_cn_model
global bigvgan_model, hifigan_model, sv_cn_model, device
print("Loading models onto GPU...")
# Use device override if provided, otherwise use global device
target_device = device_override if device_override is not None else device
print(f"Loading models onto {target_device}...")
# Load BERT tokenizer and model
print("Loading BERT model...")
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half == True:
bert_model = bert_model.half().to(device)
bert_model = bert_model.half().to(target_device)
else:
bert_model = bert_model.to(device)
bert_model = bert_model.to(target_device)
# Load SSL model
print("Loading SSL model...")
ssl_model = cnhubert.get_model()
if is_half == True:
ssl_model = ssl_model.half().to(device)
ssl_model = ssl_model.half().to(target_device)
else:
ssl_model = ssl_model.to(device)
ssl_model = ssl_model.to(target_device)
# Temporarily update global device if override is provided for model loading functions
original_device = device if device_override is not None else None
if device_override is not None:
device = target_device
# Load SoVITS model
print("Loading SoVITS model...")
try:
next(change_sovits_weights(sovits_path))
except:
pass
# Load GPT model
print("Loading GPT model...")
change_gpt_weights(gpt_path)
# Load appropriate vocoder model based on version
print(f"Loading vocoder model for version {model_version}...")
if model_version == "v3":
init_bigvgan()
elif model_version == "v4":
init_hifigan()
elif model_version in {"v2Pro", "v2ProPlus"}:
init_sv_cn()
# Load SoVITS model
print("Loading SoVITS model...")
try:
next(change_sovits_weights(sovits_path))
except:
pass
# Load GPT model
print("Loading GPT model...")
change_gpt_weights(gpt_path)
# Load appropriate vocoder model based on version
print(f"Loading vocoder model for version {model_version}...")
if model_version == "v3":
init_bigvgan()
elif model_version == "v4":
init_hifigan()
elif model_version in {"v2Pro", "v2ProPlus"}:
init_sv_cn()
finally:
# Restore original device if it was overridden
if original_device is not None:
device = original_device
print("All models loaded successfully!")