fix: resolve path issue for inference_gui on Windows

This commit is contained in:
aoguai 2024-06-28 14:13:31 +08:00
parent d1d43661a9
commit ffb14ff13e
3 changed files with 79 additions and 45 deletions

View File

@ -1,11 +1,13 @@
import os
import sys
import soundfile as sf
from PyQt5.QtCore import QEvent
from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QLineEdit, QPushButton, QTextEdit
from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QStatusBar, QComboBox
import soundfile as sf
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
@ -61,11 +63,11 @@ class GPTSoVITSGUI(QMainWindow):
border: 1px solid #45a049;
box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.1);
}
""")
""")
license_text = (
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
license_label = QLabel(license_text)
license_label.setWordWrap(True)
@ -284,17 +286,17 @@ class GPTSoVITSGUI(QMainWindow):
change_sovits_weights(sovits_path=SoVITS_model_path)
self.SoVITS_Path = SoVITS_model_path
synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
prompt_text=ref_text,
prompt_language=language_combobox,
text=target_text,
synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
prompt_text=ref_text,
prompt_language=language_combobox,
text=target_text,
text_language=target_language_combobox)
result_list = list(synthesis_result)
if result_list:
last_sampling_rate, last_audio_data = result_list[-1]
output_wav_path = os.path.join(output_path, "output.wav")
output_wav_path = os.path.join(output_path, "output.wav")
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
result = "Audio saved to " + output_wav_path
@ -307,4 +309,4 @@ if __name__ == '__main__':
app = QApplication(sys.argv)
mainWin = GPTSoVITSGUI()
mainWin.show()
sys.exit(app.exec_())
sys.exit(app.exec_())

View File

@ -18,31 +18,50 @@ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
import pdb
import torch
if os.path.exists("./gweight.txt"):
with open("./gweight.txt", 'r', encoding="utf-8") as file:
gweight_data = file.read()
gpt_path = os.environ.get(
"gpt_path", gweight_data)
else:
gpt_path = os.environ.get(
"gpt_path", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
# 获取项目根目录的绝对路径
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if os.path.exists("./sweight.txt"):
with open("./sweight.txt", 'r', encoding="utf-8") as file:
def get_absolute_path(relative_path):
# 如果路径已经是绝对路径,直接返回
if os.path.isabs(relative_path):
return relative_path
return os.path.join(PROJECT_ROOT, relative_path)
# 检查和加载 sovits_path
sweight_file_path = get_absolute_path("sweight.txt")
if os.path.exists(sweight_file_path):
with open(sweight_file_path, 'r', encoding="utf-8") as file:
sovits_path = file.read().strip()
sovits_path = get_absolute_path(sovits_path)
else:
sovits_path = get_absolute_path("GPT_SoVITS/pretrained_models/s2G488k.pth")
# 检查和加载 gpt_path
gweight_file_path = get_absolute_path("gweight.txt")
if os.path.exists(gweight_file_path):
with open(gweight_file_path, 'r', encoding="utf-8") as file:
gpt_path = file.read().strip()
gpt_path = get_absolute_path(gpt_path)
else:
gpt_path = get_absolute_path("GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
if os.path.exists(gweight_file_path):
with open(gweight_file_path, 'r', encoding="utf-8") as file:
gweight_data = file.read()
gpt_path = os.environ.get("gpt_path", gweight_data)
else:
gpt_path = os.environ.get("gpt_path", get_absolute_path("GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"))
if os.path.exists(sweight_file_path):
with open(sweight_file_path, 'r', encoding="utf-8") as file:
sweight_data = file.read()
sovits_path = os.environ.get("sovits_path", sweight_data)
else:
sovits_path = os.environ.get("sovits_path", "GPT_SoVITS/pretrained_models/s2G488k.pth")
# gpt_path = os.environ.get(
# "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
# )
# sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
cnhubert_base_path = os.environ.get(
"cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base"
)
bert_path = os.environ.get(
"bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
)
sovits_path = os.environ.get("sovits_path", get_absolute_path("GPT_SoVITS/pretrained_models/s2G488k.pth"))
cnhubert_base_path = os.environ.get("cnhubert_base_path", get_absolute_path("GPT_SoVITS/pretrained_models/chinese-hubert-base"))
bert_path = os.environ.get("bert_path", get_absolute_path("GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"))
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
infer_ttswebui = int(infer_ttswebui)
is_share = os.environ.get("is_share", "False")
@ -138,6 +157,7 @@ else:
def change_sovits_weights(sovits_path):
global vq_model, hps
sovits_path = get_absolute_path(sovits_path)
dict_s2 = torch.load(sovits_path, map_location="cpu")
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
@ -148,38 +168,43 @@ def change_sovits_weights(sovits_path):
n_speakers=hps.data.n_speakers,
**hps.model
)
if ("pretrained" not in sovits_path):
if "pretrained" not in sovits_path:
del vq_model.enc_q
if is_half == True:
if is_half:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
vq_model.eval()
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
with open("./sweight.txt", "w", encoding="utf-8") as f:
sweight_file_path = get_absolute_path("sweight.txt")
with open(sweight_file_path, "w", encoding="utf-8") as f:
f.write(sovits_path)
change_sovits_weights(sovits_path)
def change_gpt_weights(gpt_path):
global hz, max_sec, t2s_model, config
gpt_path = get_absolute_path(gpt_path)
hz = 50
dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"]
max_sec = config["data"]["max_sec"]
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
if is_half == True:
if is_half:
t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device)
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)
gweight_file_path = get_absolute_path("gweight.txt")
with open(gweight_file_path, "w", encoding="utf-8") as f:
f.write(gpt_path)
# 更新权重路径
change_sovits_weights(sovits_path)
change_gpt_weights(gpt_path)
@ -554,10 +579,11 @@ def change_choices():
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
pretrained_sovits_name = "GPT_SoVITS/pretrained_models/s2G488k.pth"
pretrained_gpt_name = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
SoVITS_weight_root = "SoVITS_weights"
GPT_weight_root = "GPT_weights"
pretrained_sovits_name = get_absolute_path("GPT_SoVITS/pretrained_models/s2G488k.pth")
pretrained_gpt_name = get_absolute_path("GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
SoVITS_weight_root = get_absolute_path("SoVITS_weights")
GPT_weight_root = get_absolute_path("GPT_weights")
os.makedirs(SoVITS_weight_root, exist_ok=True)
os.makedirs(GPT_weight_root, exist_ok=True)

View File

@ -2,9 +2,14 @@ import json
import locale
import os
# 获取项目根目录的绝对路径
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
def load_language_list(language):
with open(f"./i18n/locale/{language}.json", "r", encoding="utf-8") as f:
# 使用从项目根目录开始的绝对路径
language_file_path = os.path.join(PROJECT_ROOT, f"i18n/locale/{language}.json")
with open(language_file_path, "r", encoding="utf-8") as f:
language_list = json.load(f)
return language_list
@ -15,7 +20,8 @@ class I18nAuto:
language = locale.getdefaultlocale()[
0
] # getlocale can't identify the system's language ((None, None))
if not os.path.exists(f"./i18n/locale/{language}.json"):
language_file_path = os.path.join(PROJECT_ROOT, f"i18n/locale/{language}.json")
if not os.path.exists(language_file_path):
language = "en_US"
self.language = language
self.language_map = load_language_list(language)