mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
添加模型记忆功能,不用二次选择模型
添加模型记忆功能,不用二次选择模型
This commit is contained in:
parent
813cf96e50
commit
0bcdf0155c
@ -6,10 +6,25 @@ logging.getLogger("httpx").setLevel(logging.ERROR)
|
|||||||
logging.getLogger("asyncio").setLevel(logging.ERROR)
|
logging.getLogger("asyncio").setLevel(logging.ERROR)
|
||||||
import pdb
|
import pdb
|
||||||
|
|
||||||
gpt_path = os.environ.get(
|
if os.path.exists("./gweight.txt"):
|
||||||
"gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
with open("./gweight.txt", 'r',encoding="utf-8") as file:
|
||||||
)
|
gweight_data = file.read()
|
||||||
sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
|
gpt_path = os.environ.get(
|
||||||
|
"gpt_path", gweight_data)
|
||||||
|
else:
|
||||||
|
gpt_path = os.environ.get(
|
||||||
|
"gpt_path", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
|
||||||
|
|
||||||
|
if os.path.exists("./sweight.txt"):
|
||||||
|
with open("./sweight.txt", 'r',encoding="utf-8") as file:
|
||||||
|
sweight_data = file.read()
|
||||||
|
sovits_path = os.environ.get("sovits_path", sweight_data)
|
||||||
|
else:
|
||||||
|
sovits_path = os.environ.get("sovits_path", "GPT_SoVITS/pretrained_models/s2G488k.pth")
|
||||||
|
# gpt_path = os.environ.get(
|
||||||
|
# "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
||||||
|
# )
|
||||||
|
# sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
|
||||||
cnhubert_base_path = os.environ.get(
|
cnhubert_base_path = os.environ.get(
|
||||||
"cnhubert_base_path", "pretrained_models/chinese-hubert-base"
|
"cnhubert_base_path", "pretrained_models/chinese-hubert-base"
|
||||||
)
|
)
|
||||||
@ -124,6 +139,7 @@ def change_sovits_weights(sovits_path):
|
|||||||
vq_model = vq_model.to(device)
|
vq_model = vq_model.to(device)
|
||||||
vq_model.eval()
|
vq_model.eval()
|
||||||
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
||||||
|
with open("./sweight.txt","w",encoding="utf-8")as f:f.write(sovits_path)
|
||||||
change_sovits_weights(sovits_path)
|
change_sovits_weights(sovits_path)
|
||||||
|
|
||||||
def change_gpt_weights(gpt_path):
|
def change_gpt_weights(gpt_path):
|
||||||
@ -140,6 +156,7 @@ def change_gpt_weights(gpt_path):
|
|||||||
t2s_model.eval()
|
t2s_model.eval()
|
||||||
total = sum([param.nelement() for param in t2s_model.parameters()])
|
total = sum([param.nelement() for param in t2s_model.parameters()])
|
||||||
print("Number of parameter: %.2fM" % (total / 1e6))
|
print("Number of parameter: %.2fM" % (total / 1e6))
|
||||||
|
with open("./gweight.txt","w",encoding="utf-8")as f:f.write(gpt_path)
|
||||||
change_gpt_weights(gpt_path)
|
change_gpt_weights(gpt_path)
|
||||||
|
|
||||||
def get_spepc(hps, filename):
|
def get_spepc(hps, filename):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user