change process ckpt

This commit is contained in:
samiabat 2025-07-09 03:25:57 +03:00
parent fc98f14f61
commit 9fbeeea8a0

View File

@ -1,12 +1,14 @@
import traceback import traceback
from collections import OrderedDict from collections import OrderedDict
from time import time as ttime from time import time as ttime
import shutil,os import shutil
import os
import torch import torch
from tools.i18n.i18n import I18nAuto from tools.i18n.i18n import I18nAuto
i18n = I18nAuto() i18n = I18nAuto()
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path) dir = os.path.dirname(path)
name = os.path.basename(path) name = os.path.basename(path)
@ -14,24 +16,29 @@ def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
torch.save(fea, tmp_path) torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name)) shutil.move(tmp_path, "%s/%s" % (dir, name))
'''
00:v1
01:v2
02:v3
03:v3lora
'''
from io import BytesIO from io import BytesIO
def my_save2(fea,path):
model_version2byte = {
"v3": b"03",
"v4": b"04",
"v2Pro": b"05",
"v2ProPlus": b"06",
}
def my_save2(fea, path, model_version):
bio = BytesIO() bio = BytesIO()
torch.save(fea, bio) torch.save(fea, bio)
bio.seek(0) bio.seek(0)
data = bio.getvalue() data = bio.getvalue()
data = b'03' + data[2:]###temp for v3lora only, todo byte = model_version2byte[model_version]
with open(path, "wb") as f: f.write(data) data = byte + data[2:]
with open(path, "wb") as f:
f.write(data)
def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
def savee(ckpt, name, epoch, steps, hps, model_version=None, lora_rank=None):
try: try:
opt = OrderedDict() opt = OrderedDict()
opt["weight"] = {} opt["weight"] = {}
@ -43,48 +50,72 @@ def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
opt["info"] = "%sepoch_%siteration" % (epoch, steps) opt["info"] = "%sepoch_%siteration" % (epoch, steps)
if lora_rank: if lora_rank:
opt["lora_rank"] = lora_rank opt["lora_rank"] = lora_rank
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name)) my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
elif model_version != None and "Pro" in model_version:
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
else: else:
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name)) my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
return "Success." return "Success."
except: except:
return traceback.format_exc() return traceback.format_exc()
"""
00:v1
01:v2
02:v3
03:v3lora
04:v4lora
05:v2Pro
06:v2ProPlus
"""
head2version = { head2version = {
b'00':["v1","v1",False], b"00": ["v1", "v1", False],
b'01':["v2","v2",False], b"01": ["v2", "v2", False],
b'02':["v2","v3",False], b"02": ["v2", "v3", False],
b'03':["v2","v3",True], b"03": ["v2", "v3", True],
b"04": ["v2", "v4", True],
b"05": ["v2", "v2Pro", False],
b"06": ["v2", "v2ProPlus", False],
} }
hash_pretrained_dict = { hash_pretrained_dict = {
"dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False], # s2G488k.pth#sovits_v1_pretrained "dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False], # s2G488k.pth#sovits_v1_pretrained
"43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False], # s2Gv3.pth#sovits_v3_pretrained "43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False], # s2Gv3.pth#sovits_v3_pretrained
"6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False], # s2G2333K.pth#sovits_v2_pretrained "6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False], # s2G2333K.pth#sovits_v2_pretrained
"4f26b9476d0c5033e04162c486074374": ["v2", "v4", False], # s2Gv4.pth#sovits_v4_pretrained
"c7e9fce2223f3db685cdfa1e6368728a": ["v2", "v2Pro", False], # s2Gv2Pro.pth#sovits_v2Pro_pretrained
"66b313e39455b57ab1b0bc0b239c9d0a": ["v2", "v2ProPlus", False], # s2Gv2ProPlus.pth#sovits_v2ProPlus_pretrained
} }
import hashlib import hashlib
def get_hash_from_file(sovits_path): def get_hash_from_file(sovits_path):
with open(sovits_path,"rb")as f:data=f.read(8192) with open(sovits_path, "rb") as f:
data = f.read(8192)
hash_md5 = hashlib.md5() hash_md5 = hashlib.md5()
hash_md5.update(data) hash_md5.update(data)
return hash_md5.hexdigest() return hash_md5.hexdigest()
def get_sovits_version_from_path_fast(sovits_path): def get_sovits_version_from_path_fast(sovits_path):
###1-if it is pretrained sovits models, by hash ###1-if it is pretrained sovits models, by hash
hash = get_hash_from_file(sovits_path) hash = get_hash_from_file(sovits_path)
if hash in hash_pretrained_dict: if hash in hash_pretrained_dict:
return hash_pretrained_dict[hash] return hash_pretrained_dict[hash]
###2-new weights or old weights, by head ###2-new weights, by head
with open(sovits_path,"rb")as f:version=f.read(2) with open(sovits_path, "rb") as f:
version = f.read(2)
if version != b"PK": if version != b"PK":
return head2version[version] return head2version[version]
###3-old weights, by file size ###3-old weights, by file size
if_lora_v3 = False if_lora_v3 = False
size = os.path.getsize(sovits_path) size = os.path.getsize(sovits_path)
''' """
v1weights:about 82942KB v1weights:about 82942KB
half thr:82978KB half thr:82978KB
v2weights:about 83014KB v2weights:about 83014KB
v3weights:about 750MB v3weights:about 750MB
''' """
if size < 82978 * 1024: if size < 82978 * 1024:
model_version = version = "v1" model_version = version = "v1"
elif size < 700 * 1024 * 1024: elif size < 700 * 1024 * 1024:
@ -94,11 +125,12 @@ def get_sovits_version_from_path_fast(sovits_path):
model_version = "v3" model_version = "v3"
return version, model_version, if_lora_v3 return version, model_version, if_lora_v3
def load_sovits_new(sovits_path): def load_sovits_new(sovits_path):
f = open(sovits_path, "rb") f = open(sovits_path, "rb")
meta = f.read(2) meta = f.read(2)
if meta!="PK": if meta != b"PK":
data = b'PK' + f.read() data = b"PK" + f.read()
bio = BytesIO() bio = BytesIO()
bio.write(data) bio.write(data)
bio.seek(0) bio.seek(0)