refactor: separate model loading logic into a function instead of loading at import time

Jacky He 2025-09-02 17:50:34 +08:00
parent 611ff1e8c0
commit 65cf7d67d5

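In short: every module-level model initialization is replaced with a None placeholder, and the actual construction moves into an explicit load_models() call, so importing the module no longer pulls model weights onto the GPU. A minimal self-contained sketch of the pattern (illustrative only; build_model is a hypothetical stand-in for the real loaders in the diff, such as AutoTokenizer.from_pretrained and cnhubert.get_model):

def build_model():
    # Hypothetical stand-in for the real loaders in this file.
    return lambda x: x

model = None  # placeholder at import time; nothing touches the GPU yet

def load_models():
    # Explicit one-time initialization; call before any inference entry point.
    global model
    model = build_model()

def infer(x):
    # Consumers guard against use-before-load instead of relying on import-time setup.
    if model is None:
        raise RuntimeError("Models not loaded. Please call load_models() first.")
    return model(x)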

@@ -160,15 +160,14 @@ dict_language_v2 = {
 }
 dict_language = dict_language_v1 if version == "v1" else dict_language_v2
 
-tokenizer = AutoTokenizer.from_pretrained(bert_path)
-bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
-if is_half == True:
-    bert_model = bert_model.half().to(device)
-else:
-    bert_model = bert_model.to(device)
+# Initialize model variables as None - they will be loaded by load_models() function
+tokenizer = None
+bert_model = None
 
 
 def get_bert_feature(text, word2ph):
+    if tokenizer is None or bert_model is None:
+        raise RuntimeError("Models not loaded. Please call load_models() first.")
     with torch.no_grad():
         inputs = tokenizer(text, return_tensors="pt")
         for i in inputs:
@@ -180,6 +179,8 @@ def get_bert_feature(text, word2ph):
     for i in range(len(word2ph)):
         repeat_feature = res[i].repeat(word2ph[i], 1)
         phone_level_feature.append(repeat_feature)
+    if len(phone_level_feature) == 0:
+        return torch.empty((res.shape[1], 0), dtype=res.dtype, device=res.device)
     phone_level_feature = torch.cat(phone_level_feature, dim=0)
     return phone_level_feature.T
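The early return added above handles an empty word2ph: torch.cat would otherwise raise on an empty list, and the guard instead returns the same (hidden_dim, 0) shape the normal path produces. A self-contained illustration (the hidden size 1024 is an assumed example value; res stands in for the word-level BERT features):

import torch

res = torch.randn(0, 1024)  # zero words -> zero rows of word-level features
word2ph = []                # no phonemes to expand

phone_level_feature = [res[i].repeat(word2ph[i], 1) for i in range(len(word2ph))]
if len(phone_level_feature) == 0:
    # torch.cat([]) raises "expected a non-empty list of Tensors", hence the guard.
    out = torch.empty((res.shape[1], 0), dtype=res.dtype, device=res.device)
else:
    out = torch.cat(phone_level_feature, dim=0).T
print(out.shape)  # torch.Size([1024, 0])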
@@ -212,11 +213,8 @@ class DictToAttrRecursive(dict):
             raise AttributeError(f"Attribute {item} not found")
 
 
-ssl_model = cnhubert.get_model()
-if is_half == True:
-    ssl_model = ssl_model.half().to(device)
-else:
-    ssl_model = ssl_model.to(device)
+# Initialize SSL model as None - it will be loaded by load_models() function
+ssl_model = None
 
 
 ###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt
@@ -367,11 +365,13 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
         f.write(json.dumps(data))
 
 
-try:
-    next(change_sovits_weights(sovits_path))
-except:
-    pass
+# Initialize global model variables as None - they will be loaded by load_models() function
+vq_model = None
+hps = None
+t2s_model = None
+config = None
+hz = None
+max_sec = None
 
 
 def change_gpt_weights(gpt_path):
     if "！" in gpt_path or "!" in gpt_path:
@@ -385,7 +385,8 @@ def change_gpt_weights(gpt_path):
     t2s_model.load_state_dict(dict_s1["weight"])
     if is_half == True:
         t2s_model = t2s_model.half()
-    t2s_model = t2s_model.to(device)
+    else:
+        t2s_model = t2s_model.to(device)
     t2s_model.eval()
     # total = sum([param.nelement() for param in t2s_model.parameters()])
     # print("Number of parameter: %.2fM" % (total / 1e6))
@@ -397,7 +398,8 @@ def change_gpt_weights(gpt_path):
         f.write(json.dumps(data))
 
 
-change_gpt_weights(gpt_path)
+# Remove the automatic loading of GPT weights at import time
+# change_gpt_weights(gpt_path)
 
 
 os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
 import torch
@@ -495,17 +497,68 @@ def init_sv_cn():
     clean_hifigan_model()
 
 
+# Initialize vocoder model variables as None - they will be loaded by load_models() function
 bigvgan_model = hifigan_model = sv_cn_model = None
-if model_version == "v3":
-    init_bigvgan()
-if model_version == "v4":
-    init_hifigan()
-if model_version in {"v2Pro", "v2ProPlus"}:
-    init_sv_cn()
+# Remove automatic vocoder loading at import time
+# if model_version == "v3":
+#     init_bigvgan()
+# if model_version == "v4":
+#     init_hifigan()
+# if model_version in {"v2Pro", "v2ProPlus"}:
+#     init_sv_cn()
 
 resample_transform_dict = {}
 
 
+def load_models():
+    """
+    Load all models onto GPU. Call this function when you want to initialize models.
+    """
+    global tokenizer, bert_model, ssl_model, vq_model, hps, t2s_model, config, hz, max_sec
+    global bigvgan_model, hifigan_model, sv_cn_model
+
+    print("Loading models onto GPU...")
+
+    # Load BERT tokenizer and model
+    print("Loading BERT model...")
+    tokenizer = AutoTokenizer.from_pretrained(bert_path)
+    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
+    if is_half == True:
+        bert_model = bert_model.half().to(device)
+    else:
+        bert_model = bert_model.to(device)
+
+    # Load SSL model
+    print("Loading SSL model...")
+    ssl_model = cnhubert.get_model()
+    if is_half == True:
+        ssl_model = ssl_model.half().to(device)
+    else:
+        ssl_model = ssl_model.to(device)
+
+    # Load SoVITS model
+    print("Loading SoVITS model...")
+    try:
+        next(change_sovits_weights(sovits_path))
+    except:
+        pass
+
+    # Load GPT model
+    print("Loading GPT model...")
+    change_gpt_weights(gpt_path)
+
+    # Load appropriate vocoder model based on version
+    print(f"Loading vocoder model for version {model_version}...")
+    if model_version == "v3":
+        init_bigvgan()
+    elif model_version == "v4":
+        init_hifigan()
+    elif model_version in {"v2Pro", "v2ProPlus"}:
+        init_sv_cn()
+
+    print("All models loaded successfully!")
+
+
 def resample(audio_tensor, sr0, sr1, device):
     global resample_transform_dict
     key = "%s-%s-%s" % (sr0, sr1, str(device))
@@ -1353,6 +1406,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
     # gr.Markdown(html_center(i18n("后续将支持转音素、手工修改音素、语音合成分步执行。")))
 
 if __name__ == "__main__":
+    load_models()
     app.queue().launch(  # concurrency_count=511, max_size=1022
         server_name="0.0.0.0",
         inbrowser=True,
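After this refactor, the WebUI entry point loads everything via the load_models() call added under __main__; any code that imports this module programmatically must do the same before synthesis. A usage sketch (the module name inference_webui is an assumption based on the file this diff touches):

import inference_webui as iw  # assumed import path; adjust to the actual module

iw.load_models()  # must run once before any model-dependent call

# Skipping the call above now fails fast instead of crashing mid-inference:
# RuntimeError: Models not loaded. Please call load_models() first.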