mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
添加模型记忆功能,不用二次选择模型
添加模型记忆功能,不用二次选择模型
This commit is contained in:
parent
813cf96e50
commit
0bcdf0155c
@ -6,10 +6,25 @@ logging.getLogger("httpx").setLevel(logging.ERROR)
|
||||
logging.getLogger("asyncio").setLevel(logging.ERROR)
|
||||
import pdb
|
||||
|
||||
gpt_path = os.environ.get(
|
||||
"gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
||||
)
|
||||
sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
|
||||
if os.path.exists("./gweight.txt"):
|
||||
with open("./gweight.txt", 'r',encoding="utf-8") as file:
|
||||
gweight_data = file.read()
|
||||
gpt_path = os.environ.get(
|
||||
"gpt_path", gweight_data)
|
||||
else:
|
||||
gpt_path = os.environ.get(
|
||||
"gpt_path", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
|
||||
|
||||
if os.path.exists("./sweight.txt"):
|
||||
with open("./sweight.txt", 'r',encoding="utf-8") as file:
|
||||
sweight_data = file.read()
|
||||
sovits_path = os.environ.get("sovits_path", sweight_data)
|
||||
else:
|
||||
sovits_path = os.environ.get("sovits_path", "GPT_SoVITS/pretrained_models/s2G488k.pth")
|
||||
# gpt_path = os.environ.get(
|
||||
# "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
||||
# )
|
||||
# sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
|
||||
cnhubert_base_path = os.environ.get(
|
||||
"cnhubert_base_path", "pretrained_models/chinese-hubert-base"
|
||||
)
|
||||
@ -124,6 +139,7 @@ def change_sovits_weights(sovits_path):
|
||||
vq_model = vq_model.to(device)
|
||||
vq_model.eval()
|
||||
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
||||
with open("./sweight.txt","w",encoding="utf-8")as f:f.write(sovits_path)
|
||||
change_sovits_weights(sovits_path)
|
||||
|
||||
def change_gpt_weights(gpt_path):
|
||||
@ -140,6 +156,7 @@ def change_gpt_weights(gpt_path):
|
||||
t2s_model.eval()
|
||||
total = sum([param.nelement() for param in t2s_model.parameters()])
|
||||
print("Number of parameter: %.2fM" % (total / 1e6))
|
||||
with open("./gweight.txt","w",encoding="utf-8")as f:f.write(gpt_path)
|
||||
change_gpt_weights(gpt_path)
|
||||
|
||||
def get_spepc(hps, filename):
|
||||
|
Loading…
x
Reference in New Issue
Block a user