refactor: centralize model loading logic

This commit is contained in:
Jacky He 2025-09-04 09:38:04 +08:00
parent 65cf7d67d5
commit fc987e2a65
3 changed files with 46 additions and 27 deletions

View File

@ -3,7 +3,7 @@ import os
import soundfile as sf import soundfile as sf
from tools.i18n.i18n import I18nAuto from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav, load_models
i18n = I18nAuto() i18n = I18nAuto()
@ -69,6 +69,7 @@ def main():
parser.add_argument("--output_path", required=True, help="Path to the output directory") parser.add_argument("--output_path", required=True, help="Path to the output directory")
args = parser.parse_args() args = parser.parse_args()
load_models()
synthesize( synthesize(
args.gpt_model, args.gpt_model,

View File

@ -9,7 +9,7 @@ from tools.i18n.i18n import I18nAuto
i18n = I18nAuto() i18n = I18nAuto()
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav, load_models
class GPTSoVITSGUI(QMainWindow): class GPTSoVITSGUI(QMainWindow):
@ -281,6 +281,7 @@ class GPTSoVITSGUI(QMainWindow):
target_text = self.target_text_input.text() target_text = self.target_text_input.text()
output_path = self.output_input.text() output_path = self.output_input.text()
load_models()
if GPT_model_path != self.GPT_Path: if GPT_model_path != self.GPT_Path:
change_gpt_weights(gpt_path=GPT_model_path) change_gpt_weights(gpt_path=GPT_model_path)
self.GPT_Path = GPT_model_path self.GPT_Path = GPT_model_path

View File

@ -510,51 +510,68 @@ bigvgan_model = hifigan_model = sv_cn_model = None
resample_transform_dict = {} resample_transform_dict = {}
def load_models( ): def load_models(device_override=None):
""" """
Load all models onto GPU. Call this function when you want to initialize models. Load all models onto GPU. Call this function when you want to initialize models.
Args:
device_override (str, optional): Override the global device setting.
If None, uses the global device variable.
""" """
global tokenizer, bert_model, ssl_model, vq_model, hps, t2s_model, config, hz, max_sec global tokenizer, bert_model, ssl_model, vq_model, hps, t2s_model, config, hz, max_sec
global bigvgan_model, hifigan_model, sv_cn_model global bigvgan_model, hifigan_model, sv_cn_model, device
print("Loading models onto GPU...") # Use device override if provided, otherwise use global device
target_device = device_override if device_override is not None else device
print(f"Loading models onto {target_device}...")
# Load BERT tokenizer and model # Load BERT tokenizer and model
print("Loading BERT model...") print("Loading BERT model...")
tokenizer = AutoTokenizer.from_pretrained(bert_path) tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path) bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half == True: if is_half == True:
bert_model = bert_model.half().to(device) bert_model = bert_model.half().to(target_device)
else: else:
bert_model = bert_model.to(device) bert_model = bert_model.to(target_device)
# Load SSL model # Load SSL model
print("Loading SSL model...") print("Loading SSL model...")
ssl_model = cnhubert.get_model() ssl_model = cnhubert.get_model()
if is_half == True: if is_half == True:
ssl_model = ssl_model.half().to(device) ssl_model = ssl_model.half().to(target_device)
else: else:
ssl_model = ssl_model.to(device) ssl_model = ssl_model.to(target_device)
# Temporarily update global device if override is provided for model loading functions
original_device = device if device_override is not None else None
if device_override is not None:
device = target_device
# Load SoVITS model
print("Loading SoVITS model...")
try: try:
next(change_sovits_weights(sovits_path)) # Load SoVITS model
except: print("Loading SoVITS model...")
pass try:
next(change_sovits_weights(sovits_path))
# Load GPT model except:
print("Loading GPT model...") pass
change_gpt_weights(gpt_path)
# Load GPT model
# Load appropriate vocoder model based on version print("Loading GPT model...")
print(f"Loading vocoder model for version {model_version}...") change_gpt_weights(gpt_path)
if model_version == "v3":
init_bigvgan() # Load appropriate vocoder model based on version
elif model_version == "v4": print(f"Loading vocoder model for version {model_version}...")
init_hifigan() if model_version == "v3":
elif model_version in {"v2Pro", "v2ProPlus"}: init_bigvgan()
init_sv_cn() elif model_version == "v4":
init_hifigan()
elif model_version in {"v2Pro", "v2ProPlus"}:
init_sv_cn()
finally:
# Restore original device if it was overridden
if original_device is not None:
device = original_device
print("All models loaded successfully!") print("All models loaded successfully!")