Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git, synced 2025-09-29 08:49:59 +08:00

Merge fc987e2a6512ae82ba27ec925c221d54cc54a3ae into 11aa78bd9bda8b53047cfcae03abf7ca94d27391

This commit is contained in: commit 711209c612
@@ -3,7 +3,7 @@ import os
 import soundfile as sf

 from tools.i18n.i18n import I18nAuto
-from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
+from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav, load_models

 i18n = I18nAuto()
@@ -69,6 +69,7 @@ def main():
     parser.add_argument("--output_path", required=True, help="Path to the output directory")

     args = parser.parse_args()
+    load_models()

     synthesize(
         args.gpt_model,
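The ordering above is the point of the change: models are no longer loaded when inference_webui is imported, so the CLI must call load_models() explicitly after argument parsing and before synthesis. A minimal, self-contained sketch of that flow (load_models and synthesize here are stand-ins, not the real implementations):

    import argparse

    def load_models():
        # stand-in for GPT_SoVITS.inference_webui.load_models
        print("loading models...")

    def synthesize(gpt_model, output_path):
        # stand-in for the CLI's real synthesize()
        print(f"synthesizing with {gpt_model} -> {output_path}")

    def main():
        parser = argparse.ArgumentParser()
        parser.add_argument("--gpt_model", required=True)
        parser.add_argument("--output_path", required=True, help="Path to the output directory")
        args = parser.parse_args()
        load_models()  # must run before any synthesis call
        synthesize(args.gpt_model, args.output_path)

    if __name__ == "__main__":
        main()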
@@ -9,7 +9,7 @@ from tools.i18n.i18n import I18nAuto

 i18n = I18nAuto()

-from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
+from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav, load_models


 class GPTSoVITSGUI(QMainWindow):
@@ -281,6 +281,7 @@ class GPTSoVITSGUI(QMainWindow):
         target_text = self.target_text_input.text()
         output_path = self.output_input.text()

+        load_models()
         if GPT_model_path != self.GPT_Path:
             change_gpt_weights(gpt_path=GPT_model_path)
             self.GPT_Path = GPT_model_path
@@ -160,15 +160,14 @@ dict_language_v2 = {
 }
 dict_language = dict_language_v1 if version == "v1" else dict_language_v2

-tokenizer = AutoTokenizer.from_pretrained(bert_path)
-bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
-if is_half == True:
-    bert_model = bert_model.half().to(device)
-else:
-    bert_model = bert_model.to(device)
+# Initialize model variables as None - they will be loaded by load_models() function
+tokenizer = None
+bert_model = None


 def get_bert_feature(text, word2ph):
+    if tokenizer is None or bert_model is None:
+        raise RuntimeError("Models not loaded. Please call load_models() first.")
     with torch.no_grad():
         inputs = tokenizer(text, return_tensors="pt")
         for i in inputs:
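Without the guard, calling get_bert_feature() before load_models() would fail at the tokenizer call with an opaque TypeError ("'NoneType' object is not callable"); the explicit RuntimeError names the fix instead. The pattern in miniature (self-contained; names are illustrative):

    tokenizer = None  # populated later by load_models()

    def get_feature(text):
        if tokenizer is None:
            # without this guard: TypeError: 'NoneType' object is not callable
            raise RuntimeError("Models not loaded. Please call load_models() first.")
        return tokenizer(text)

    try:
        get_feature("hello")
    except RuntimeError as e:
        print(e)  # Models not loaded. Please call load_models() first.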
@@ -180,6 +179,8 @@ def get_bert_feature(text, word2ph):
     for i in range(len(word2ph)):
         repeat_feature = res[i].repeat(word2ph[i], 1)
         phone_level_feature.append(repeat_feature)
+    if len(phone_level_feature) == 0:
+        return torch.empty((res.shape[1], 0), dtype=res.dtype, device=res.device)
     phone_level_feature = torch.cat(phone_level_feature, dim=0)
     return phone_level_feature.T

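torch.cat() refuses an empty list, so an empty word2ph previously crashed here; the guard instead returns an empty (hidden_dim, 0) matrix with the same dtype and device, matching the shape the .T on the success path would produce. A runnable illustration (hidden_dim is an illustrative value):

    import torch

    feats = []  # e.g., word2ph was empty, so nothing was appended
    hidden_dim = 1024
    # torch.cat(feats, dim=0) would raise:
    #   RuntimeError: torch.cat(): expected a non-empty list of Tensors
    if len(feats) == 0:
        out = torch.empty((hidden_dim, 0))
    else:
        out = torch.cat(feats, dim=0).T
    print(out.shape)  # torch.Size([1024, 0])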
@@ -212,11 +213,8 @@ class DictToAttrRecursive(dict):
         raise AttributeError(f"Attribute {item} not found")


-ssl_model = cnhubert.get_model()
-if is_half == True:
-    ssl_model = ssl_model.half().to(device)
-else:
-    ssl_model = ssl_model.to(device)
+# Initialize SSL model as None - it will be loaded by load_models() function
+ssl_model = None


 ###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt
@@ -367,11 +365,13 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
        f.write(json.dumps(data))


-try:
-    next(change_sovits_weights(sovits_path))
-except:
-    pass
-
+# Initialize global model variables as None - they will be loaded by load_models() function
+vq_model = None
+hps = None
+t2s_model = None
+config = None
+hz = None
+max_sec = None

 def change_gpt_weights(gpt_path):
     if "！" in gpt_path or "!" in gpt_path:
@@ -385,7 +385,8 @@ def change_gpt_weights(gpt_path):
     t2s_model.load_state_dict(dict_s1["weight"])
     if is_half == True:
         t2s_model = t2s_model.half()
-    t2s_model = t2s_model.to(device)
+        t2s_model = t2s_model.to(device)
+    else:
+        t2s_model = t2s_model.to(device)
     t2s_model.eval()
     # total = sum([param.nelement() for param in t2s_model.parameters()])
     # print("Number of parameter: %.2fM" % (total / 1e6))
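Both branches now end with an explicit .to(device); under fp16 the dtype conversion and the device move can also be chained, which is behaviorally equivalent. A small runnable check (illustrative model and device, not the diff's code; fp16 op support varies by device):

    import torch

    model = torch.nn.Linear(4, 4)
    is_half, device = True, "cpu"  # illustrative; the real code reads module globals

    if is_half:
        model = model.half().to(device)  # same effect as .half() then .to(device)
    else:
        model = model.to(device)

    print(next(model.parameters()).dtype)  # torch.float16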
@@ -397,7 +398,8 @@ def change_gpt_weights(gpt_path):
        f.write(json.dumps(data))


-change_gpt_weights(gpt_path)
+# Remove the automatic loading of GPT weights at import time
+# change_gpt_weights(gpt_path)
 os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
 import torch

@@ -495,17 +497,85 @@ def init_sv_cn():
     clean_hifigan_model()


+# Initialize vocoder model variables as None - they will be loaded by load_models() function
 bigvgan_model = hifigan_model = sv_cn_model = None
-if model_version == "v3":
-    init_bigvgan()
-if model_version == "v4":
-    init_hifigan()
-if model_version in {"v2Pro", "v2ProPlus"}:
-    init_sv_cn()
+# Remove automatic vocoder loading at import time
+# if model_version == "v3":
+#     init_bigvgan()
+# if model_version == "v4":
+#     init_hifigan()
+# if model_version in {"v2Pro", "v2ProPlus"}:
+#     init_sv_cn()

 resample_transform_dict = {}


+def load_models(device_override=None):
+    """
+    Load all models onto GPU. Call this function when you want to initialize models.
+
+    Args:
+        device_override (str, optional): Override the global device setting.
+            If None, uses the global device variable.
+    """
+    global tokenizer, bert_model, ssl_model, vq_model, hps, t2s_model, config, hz, max_sec
+    global bigvgan_model, hifigan_model, sv_cn_model, device
+
+    # Use device override if provided, otherwise use global device
+    target_device = device_override if device_override is not None else device
+
+    print(f"Loading models onto {target_device}...")
+
+    # Load BERT tokenizer and model
+    print("Loading BERT model...")
+    tokenizer = AutoTokenizer.from_pretrained(bert_path)
+    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
+    if is_half == True:
+        bert_model = bert_model.half().to(target_device)
+    else:
+        bert_model = bert_model.to(target_device)
+
+    # Load SSL model
+    print("Loading SSL model...")
+    ssl_model = cnhubert.get_model()
+    if is_half == True:
+        ssl_model = ssl_model.half().to(target_device)
+    else:
+        ssl_model = ssl_model.to(target_device)
+
+    # Temporarily update global device if override is provided for model loading functions
+    original_device = device if device_override is not None else None
+    if device_override is not None:
+        device = target_device
+
+    try:
+        # Load SoVITS model
+        print("Loading SoVITS model...")
+        try:
+            next(change_sovits_weights(sovits_path))
+        except:
+            pass
+
+        # Load GPT model
+        print("Loading GPT model...")
+        change_gpt_weights(gpt_path)
+
+        # Load appropriate vocoder model based on version
+        print(f"Loading vocoder model for version {model_version}...")
+        if model_version == "v3":
+            init_bigvgan()
+        elif model_version == "v4":
+            init_hifigan()
+        elif model_version in {"v2Pro", "v2ProPlus"}:
+            init_sv_cn()
+    finally:
+        # Restore original device if it was overridden
+        if original_device is not None:
+            device = original_device
+
+    print("All models loaded successfully!")
+
+
 def resample(audio_tensor, sr0, sr1, device):
     global resample_transform_dict
     key = "%s-%s-%s" % (sr0, sr1, str(device))
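load_models() temporarily rebinds the module-global device so that change_sovits_weights()/change_gpt_weights(), which read that global, pick up the override, and the finally block guarantees restoration even if loading fails partway. The same pattern in miniature (self-contained; names are illustrative):

    device = "cuda"  # module-level global, read by helper functions

    def helper():
        print(f"loading on {device}")

    def load_with_override(device_override=None):
        global device
        original = device if device_override is not None else None
        if device_override is not None:
            device = device_override
        try:
            helper()  # sees the overridden global
        finally:
            if original is not None:
                device = original  # restored even if helper() raises

    load_with_override("cpu")  # prints: loading on cpu
    print(device)              # prints: cuda (restored)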
@@ -765,8 +835,17 @@ def get_tts_wav(
     sample_steps=8,
     if_sr=False,
     pause_second=0.3,
+    device_override=None
 ):
     global cache
+    global device
+    if device_override:
+        device = device_override
+
+    # Check if models are loaded
+    if (tokenizer is None or bert_model is None or ssl_model is None or
+        vq_model is None or t2s_model is None):
+        raise RuntimeError("Models not loaded. Please call load_models() first.")
     if ref_wav_path:
         pass
     else:
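Note the asymmetry with load_models(): here device_override rebinds the module-global device without restoring it, so an override persists across subsequent calls. A hypothetical call shape, assuming a GPT-SoVITS checkout on sys.path (paths, texts, and language labels are placeholders, and the real signature has more parameters than shown):

    from GPT_SoVITS.inference_webui import get_tts_wav, load_models

    load_models()  # required once before any synthesis
    result = get_tts_wav(
        ref_wav_path="reference.wav",          # placeholder path
        prompt_text="reference transcript",    # placeholder text
        prompt_language="English",             # the WebUI passes i18n language labels
        text="text to synthesize",
        text_language="English",
        device_override="cuda:1",  # rebinds the module-global device; later calls inherit it
    )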
@@ -1344,6 +1423,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
     # gr.Markdown(html_center(i18n("后续将支持转音素、手工修改音素、语音合成分步执行。")))

 if __name__ == "__main__":
+    load_models()
     app.queue().launch(  # concurrency_count=511, max_size=1022
         server_name="0.0.0.0",
         inbrowser=True,