Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git, synced 2025-09-29 00:30:15 +08:00
refactor: centralize model loading logic
parent 65cf7d67d5
commit fc987e2a65
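
This commit moves model initialization into a single load_models() helper in inference_webui and has both the CLI and the GUI entry points call it once before synthesis. As a rough sketch of the calling pattern the diff introduces (the module path follows the import in the first hunk; the snippet is illustrative and not part of the diff itself):

    # Hypothetical caller, based on the imports shown in the hunks below.
    from GPT_SoVITS.inference_webui import load_models, get_tts_wav

    load_models()  # loads BERT, SSL (cnhubert), SoVITS, GPT and the vocoder once, up front
    # ...subsequent get_tts_wav(...) calls then reuse the already-loaded models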
@@ -3,7 +3,7 @@ import os
 import soundfile as sf

 from tools.i18n.i18n import I18nAuto
-from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
+from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav, load_models

 i18n = I18nAuto()

@@ -69,6 +69,7 @@ def main():
     parser.add_argument("--output_path", required=True, help="Path to the output directory")

     args = parser.parse_args()
+    load_models()

     synthesize(
         args.gpt_model,
@@ -9,7 +9,7 @@ from tools.i18n.i18n import I18nAuto

 i18n = I18nAuto()

-from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
+from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav, load_models


 class GPTSoVITSGUI(QMainWindow):
@@ -281,6 +281,7 @@ class GPTSoVITSGUI(QMainWindow):
         target_text = self.target_text_input.text()
         output_path = self.output_input.text()

+        load_models()
         if GPT_model_path != self.GPT_Path:
             change_gpt_weights(gpt_path=GPT_model_path)
             self.GPT_Path = GPT_model_path
@@ -510,51 +510,68 @@ bigvgan_model = hifigan_model = sv_cn_model = None
 resample_transform_dict = {}


-def load_models():
+def load_models(device_override=None):
     """
     Load all models onto GPU. Call this function when you want to initialize models.
+
+    Args:
+        device_override (str, optional): Override the global device setting.
+            If None, uses the global device variable.
     """
     global tokenizer, bert_model, ssl_model, vq_model, hps, t2s_model, config, hz, max_sec
-    global bigvgan_model, hifigan_model, sv_cn_model
+    global bigvgan_model, hifigan_model, sv_cn_model, device

-    print("Loading models onto GPU...")
+    # Use device override if provided, otherwise use global device
+    target_device = device_override if device_override is not None else device
+
+    print(f"Loading models onto {target_device}...")

     # Load BERT tokenizer and model
     print("Loading BERT model...")
     tokenizer = AutoTokenizer.from_pretrained(bert_path)
     bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
     if is_half == True:
-        bert_model = bert_model.half().to(device)
+        bert_model = bert_model.half().to(target_device)
     else:
-        bert_model = bert_model.to(device)
+        bert_model = bert_model.to(target_device)

     # Load SSL model
     print("Loading SSL model...")
     ssl_model = cnhubert.get_model()
     if is_half == True:
-        ssl_model = ssl_model.half().to(device)
+        ssl_model = ssl_model.half().to(target_device)
     else:
-        ssl_model = ssl_model.to(device)
+        ssl_model = ssl_model.to(target_device)
+
+    # Temporarily update global device if override is provided for model loading functions
+    original_device = device if device_override is not None else None
+    if device_override is not None:
+        device = target_device

-    # Load SoVITS model
-    print("Loading SoVITS model...")
-    try:
-        next(change_sovits_weights(sovits_path))
-    except:
-        pass
-
-    # Load GPT model
-    print("Loading GPT model...")
-    change_gpt_weights(gpt_path)
-
-    # Load appropriate vocoder model based on version
-    print(f"Loading vocoder model for version {model_version}...")
-    if model_version == "v3":
-        init_bigvgan()
-    elif model_version == "v4":
-        init_hifigan()
-    elif model_version in {"v2Pro", "v2ProPlus"}:
-        init_sv_cn()
+    try:
+        # Load SoVITS model
+        print("Loading SoVITS model...")
+        try:
+            next(change_sovits_weights(sovits_path))
+        except:
+            pass
+
+        # Load GPT model
+        print("Loading GPT model...")
+        change_gpt_weights(gpt_path)
+
+        # Load appropriate vocoder model based on version
+        print(f"Loading vocoder model for version {model_version}...")
+        if model_version == "v3":
+            init_bigvgan()
+        elif model_version == "v4":
+            init_hifigan()
+        elif model_version in {"v2Pro", "v2ProPlus"}:
+            init_sv_cn()
+    finally:
+        # Restore original device if it was overridden
+        if original_device is not None:
+            device = original_device

     print("All models loaded successfully!")