mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 00:30:15 +08:00
refactor: centralize model loading logic
This commit is contained in:
parent
65cf7d67d5
commit
fc987e2a65
@ -3,7 +3,7 @@ import os
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
from tools.i18n.i18n import I18nAuto
|
from tools.i18n.i18n import I18nAuto
|
||||||
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
|
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav, load_models
|
||||||
|
|
||||||
i18n = I18nAuto()
|
i18n = I18nAuto()
|
||||||
|
|
||||||
@ -69,6 +69,7 @@ def main():
|
|||||||
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
load_models()
|
||||||
|
|
||||||
synthesize(
|
synthesize(
|
||||||
args.gpt_model,
|
args.gpt_model,
|
||||||
|
@ -9,7 +9,7 @@ from tools.i18n.i18n import I18nAuto
|
|||||||
|
|
||||||
i18n = I18nAuto()
|
i18n = I18nAuto()
|
||||||
|
|
||||||
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
|
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav, load_models
|
||||||
|
|
||||||
|
|
||||||
class GPTSoVITSGUI(QMainWindow):
|
class GPTSoVITSGUI(QMainWindow):
|
||||||
@ -281,6 +281,7 @@ class GPTSoVITSGUI(QMainWindow):
|
|||||||
target_text = self.target_text_input.text()
|
target_text = self.target_text_input.text()
|
||||||
output_path = self.output_input.text()
|
output_path = self.output_input.text()
|
||||||
|
|
||||||
|
load_models()
|
||||||
if GPT_model_path != self.GPT_Path:
|
if GPT_model_path != self.GPT_Path:
|
||||||
change_gpt_weights(gpt_path=GPT_model_path)
|
change_gpt_weights(gpt_path=GPT_model_path)
|
||||||
self.GPT_Path = GPT_model_path
|
self.GPT_Path = GPT_model_path
|
||||||
|
@ -510,51 +510,68 @@ bigvgan_model = hifigan_model = sv_cn_model = None
|
|||||||
resample_transform_dict = {}
|
resample_transform_dict = {}
|
||||||
|
|
||||||
|
|
||||||
def load_models( ):
|
def load_models(device_override=None):
|
||||||
"""
|
"""
|
||||||
Load all models onto GPU. Call this function when you want to initialize models.
|
Load all models onto GPU. Call this function when you want to initialize models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device_override (str, optional): Override the global device setting.
|
||||||
|
If None, uses the global device variable.
|
||||||
"""
|
"""
|
||||||
global tokenizer, bert_model, ssl_model, vq_model, hps, t2s_model, config, hz, max_sec
|
global tokenizer, bert_model, ssl_model, vq_model, hps, t2s_model, config, hz, max_sec
|
||||||
global bigvgan_model, hifigan_model, sv_cn_model
|
global bigvgan_model, hifigan_model, sv_cn_model, device
|
||||||
|
|
||||||
print("Loading models onto GPU...")
|
# Use device override if provided, otherwise use global device
|
||||||
|
target_device = device_override if device_override is not None else device
|
||||||
|
|
||||||
|
print(f"Loading models onto {target_device}...")
|
||||||
|
|
||||||
# Load BERT tokenizer and model
|
# Load BERT tokenizer and model
|
||||||
print("Loading BERT model...")
|
print("Loading BERT model...")
|
||||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
|
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
|
||||||
if is_half == True:
|
if is_half == True:
|
||||||
bert_model = bert_model.half().to(device)
|
bert_model = bert_model.half().to(target_device)
|
||||||
else:
|
else:
|
||||||
bert_model = bert_model.to(device)
|
bert_model = bert_model.to(target_device)
|
||||||
|
|
||||||
# Load SSL model
|
# Load SSL model
|
||||||
print("Loading SSL model...")
|
print("Loading SSL model...")
|
||||||
ssl_model = cnhubert.get_model()
|
ssl_model = cnhubert.get_model()
|
||||||
if is_half == True:
|
if is_half == True:
|
||||||
ssl_model = ssl_model.half().to(device)
|
ssl_model = ssl_model.half().to(target_device)
|
||||||
else:
|
else:
|
||||||
ssl_model = ssl_model.to(device)
|
ssl_model = ssl_model.to(target_device)
|
||||||
|
|
||||||
|
# Temporarily update global device if override is provided for model loading functions
|
||||||
|
original_device = device if device_override is not None else None
|
||||||
|
if device_override is not None:
|
||||||
|
device = target_device
|
||||||
|
|
||||||
# Load SoVITS model
|
|
||||||
print("Loading SoVITS model...")
|
|
||||||
try:
|
try:
|
||||||
next(change_sovits_weights(sovits_path))
|
# Load SoVITS model
|
||||||
except:
|
print("Loading SoVITS model...")
|
||||||
pass
|
try:
|
||||||
|
next(change_sovits_weights(sovits_path))
|
||||||
# Load GPT model
|
except:
|
||||||
print("Loading GPT model...")
|
pass
|
||||||
change_gpt_weights(gpt_path)
|
|
||||||
|
# Load GPT model
|
||||||
# Load appropriate vocoder model based on version
|
print("Loading GPT model...")
|
||||||
print(f"Loading vocoder model for version {model_version}...")
|
change_gpt_weights(gpt_path)
|
||||||
if model_version == "v3":
|
|
||||||
init_bigvgan()
|
# Load appropriate vocoder model based on version
|
||||||
elif model_version == "v4":
|
print(f"Loading vocoder model for version {model_version}...")
|
||||||
init_hifigan()
|
if model_version == "v3":
|
||||||
elif model_version in {"v2Pro", "v2ProPlus"}:
|
init_bigvgan()
|
||||||
init_sv_cn()
|
elif model_version == "v4":
|
||||||
|
init_hifigan()
|
||||||
|
elif model_version in {"v2Pro", "v2ProPlus"}:
|
||||||
|
init_sv_cn()
|
||||||
|
finally:
|
||||||
|
# Restore original device if it was overridden
|
||||||
|
if original_device is not None:
|
||||||
|
device = original_device
|
||||||
|
|
||||||
print("All models loaded successfully!")
|
print("All models loaded successfully!")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user