模型实验名可设置为中文。

fix https://github.com/RVC-Boss/GPT-SoVITS/issues/500
This commit is contained in:
RVC-Boss 2024-02-17 16:45:31 +08:00 committed by GitHub
parent f6c9803909
commit e97cc3346a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 30 additions and 4 deletions

View File

@ -1,11 +1,18 @@
import traceback
from collections import OrderedDict
from time import time as ttime
import shutil,os
import torch
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s.pth"%(ttime())
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
def savee(ckpt, name, epoch, steps, hps):
try:
@ -17,7 +24,8 @@ def savee(ckpt, name, epoch, steps, hps):
opt["weight"][key] = ckpt[key].half()
opt["config"] = hps
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
# torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
return "Success."
except:
return traceback.format_exc()

View File

@ -24,6 +24,14 @@ torch.set_float32_matmul_precision("high")
from AR.utils import get_newest_ckpt
from collections import OrderedDict
from time import time as ttime
import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s.pth"%(ttime())
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
class my_model_ckpt(ModelCheckpoint):
@ -70,7 +78,8 @@ class my_model_ckpt(ModelCheckpoint):
to_save_od["weight"][key] = dictt[key].half()
to_save_od["config"] = self.config
to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
torch.save(
# torch.save(
my_save(
to_save_od,
"%s/%s-e%s.ckpt"
% (

View File

@ -64,6 +64,14 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
)
return model, optimizer, learning_rate, iteration
from time import time as ttime
import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s.pth"%(ttime())
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
logger.info(
@ -75,7 +83,8 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(
# torch.save(
my_save(
{
"model": state_dict,
"iteration": iteration,