mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-07 15:33:29 +08:00
change process ckpt
This commit is contained in:
parent
fc98f14f61
commit
9fbeeea8a0
@ -1,37 +1,44 @@
|
||||
import traceback
|
||||
from collections import OrderedDict
|
||||
from time import time as ttime
|
||||
import shutil,os
|
||||
import shutil
|
||||
import os
|
||||
import torch
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
|
||||
i18n = I18nAuto()
|
||||
|
||||
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
|
||||
dir=os.path.dirname(path)
|
||||
name=os.path.basename(path)
|
||||
tmp_path="%s.pth"%(ttime())
|
||||
torch.save(fea,tmp_path)
|
||||
shutil.move(tmp_path,"%s/%s"%(dir,name))
|
||||
|
||||
'''
|
||||
00:v1
|
||||
01:v2
|
||||
02:v3
|
||||
03:v3lora
|
||||
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
||||
dir = os.path.dirname(path)
|
||||
name = os.path.basename(path)
|
||||
tmp_path = "%s.pth" % (ttime())
|
||||
torch.save(fea, tmp_path)
|
||||
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
||||
|
||||
|
||||
'''
|
||||
from io import BytesIO
|
||||
def my_save2(fea,path):
|
||||
|
||||
model_version2byte = {
|
||||
"v3": b"03",
|
||||
"v4": b"04",
|
||||
"v2Pro": b"05",
|
||||
"v2ProPlus": b"06",
|
||||
}
|
||||
|
||||
|
||||
def my_save2(fea, path, model_version):
|
||||
bio = BytesIO()
|
||||
torch.save(fea, bio)
|
||||
bio.seek(0)
|
||||
data = bio.getvalue()
|
||||
data = b'03' + data[2:]###temp for v3lora only, todo
|
||||
with open(path, "wb") as f: f.write(data)
|
||||
byte = model_version2byte[model_version]
|
||||
data = byte + data[2:]
|
||||
with open(path, "wb") as f:
|
||||
f.write(data)
|
||||
|
||||
def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
|
||||
|
||||
def savee(ckpt, name, epoch, steps, hps, model_version=None, lora_rank=None):
|
||||
try:
|
||||
opt = OrderedDict()
|
||||
opt["weight"] = {}
|
||||
@ -42,49 +49,73 @@ def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
|
||||
opt["config"] = hps
|
||||
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
|
||||
if lora_rank:
|
||||
opt["lora_rank"]=lora_rank
|
||||
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
|
||||
opt["lora_rank"] = lora_rank
|
||||
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
|
||||
elif model_version != None and "Pro" in model_version:
|
||||
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
|
||||
else:
|
||||
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
|
||||
return "Success."
|
||||
except:
|
||||
return traceback.format_exc()
|
||||
|
||||
head2version={
|
||||
b'00':["v1","v1",False],
|
||||
b'01':["v2","v2",False],
|
||||
b'02':["v2","v3",False],
|
||||
b'03':["v2","v3",True],
|
||||
|
||||
"""
|
||||
00:v1
|
||||
01:v2
|
||||
02:v3
|
||||
03:v3lora
|
||||
04:v4lora
|
||||
05:v2Pro
|
||||
06:v2ProPlus
|
||||
"""
|
||||
head2version = {
|
||||
b"00": ["v1", "v1", False],
|
||||
b"01": ["v2", "v2", False],
|
||||
b"02": ["v2", "v3", False],
|
||||
b"03": ["v2", "v3", True],
|
||||
b"04": ["v2", "v4", True],
|
||||
b"05": ["v2", "v2Pro", False],
|
||||
b"06": ["v2", "v2ProPlus", False],
|
||||
}
|
||||
hash_pretrained_dict={
|
||||
"dc3c97e17592963677a4a1681f30c653":["v2","v2",False],#s2G488k.pth#sovits_v1_pretrained
|
||||
"43797be674a37c1c83ee81081941ed0f":["v2","v3",False],#s2Gv3.pth#sovits_v3_pretrained
|
||||
"6642b37f3dbb1f76882b69937c95a5f3":["v2","v2",False],#s2G2333K.pth#sovits_v2_pretrained
|
||||
hash_pretrained_dict = {
|
||||
"dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False], # s2G488k.pth#sovits_v1_pretrained
|
||||
"43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False], # s2Gv3.pth#sovits_v3_pretrained
|
||||
"6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False], # s2G2333K.pth#sovits_v2_pretrained
|
||||
"4f26b9476d0c5033e04162c486074374": ["v2", "v4", False], # s2Gv4.pth#sovits_v4_pretrained
|
||||
"c7e9fce2223f3db685cdfa1e6368728a": ["v2", "v2Pro", False], # s2Gv2Pro.pth#sovits_v2Pro_pretrained
|
||||
"66b313e39455b57ab1b0bc0b239c9d0a": ["v2", "v2ProPlus", False], # s2Gv2ProPlus.pth#sovits_v2ProPlus_pretrained
|
||||
}
|
||||
import hashlib
|
||||
|
||||
|
||||
def get_hash_from_file(sovits_path):
|
||||
with open(sovits_path,"rb")as f:data=f.read(8192)
|
||||
with open(sovits_path, "rb") as f:
|
||||
data = f.read(8192)
|
||||
hash_md5 = hashlib.md5()
|
||||
hash_md5.update(data)
|
||||
return hash_md5.hexdigest()
|
||||
|
||||
|
||||
def get_sovits_version_from_path_fast(sovits_path):
|
||||
###1-if it is pretrained sovits models, by hash
|
||||
hash=get_hash_from_file(sovits_path)
|
||||
hash = get_hash_from_file(sovits_path)
|
||||
if hash in hash_pretrained_dict:
|
||||
return hash_pretrained_dict[hash]
|
||||
###2-new weights or old weights, by head
|
||||
with open(sovits_path,"rb")as f:version=f.read(2)
|
||||
if version!=b"PK":
|
||||
###2-new weights, by head
|
||||
with open(sovits_path, "rb") as f:
|
||||
version = f.read(2)
|
||||
if version != b"PK":
|
||||
return head2version[version]
|
||||
###3-old weights, by file size
|
||||
if_lora_v3=False
|
||||
size=os.path.getsize(sovits_path)
|
||||
'''
|
||||
if_lora_v3 = False
|
||||
size = os.path.getsize(sovits_path)
|
||||
"""
|
||||
v1weights:about 82942KB
|
||||
half thr:82978KB
|
||||
v2weights:about 83014KB
|
||||
v3weights:about 750MB
|
||||
'''
|
||||
"""
|
||||
if size < 82978 * 1024:
|
||||
model_version = version = "v1"
|
||||
elif size < 700 * 1024 * 1024:
|
||||
@ -92,15 +123,16 @@ def get_sovits_version_from_path_fast(sovits_path):
|
||||
else:
|
||||
version = "v2"
|
||||
model_version = "v3"
|
||||
return version,model_version,if_lora_v3
|
||||
return version, model_version, if_lora_v3
|
||||
|
||||
|
||||
def load_sovits_new(sovits_path):
|
||||
f=open(sovits_path,"rb")
|
||||
meta=f.read(2)
|
||||
if meta!="PK":
|
||||
data = b'PK' + f.read()
|
||||
f = open(sovits_path, "rb")
|
||||
meta = f.read(2)
|
||||
if meta != b"PK":
|
||||
data = b"PK" + f.read()
|
||||
bio = BytesIO()
|
||||
bio.write(data)
|
||||
bio.seek(0)
|
||||
return torch.load(bio, map_location="cpu", weights_only=False)
|
||||
return torch.load(sovits_path,map_location="cpu", weights_only=False)
|
||||
return torch.load(sovits_path, map_location="cpu", weights_only=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user