模型实验名可设置为中文。

fix https://github.com/RVC-Boss/GPT-SoVITS/issues/500
This commit is contained in:
RVC-Boss 2024-02-17 16:45:31 +08:00 committed by GitHub
parent f6c9803909
commit e97cc3346a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 30 additions and 4 deletions

View File

@ -1,11 +1,18 @@
import traceback import traceback
from collections import OrderedDict from collections import OrderedDict
from time import time as ttime
import shutil,os
import torch import torch
from tools.i18n.i18n import I18nAuto from tools.i18n.i18n import I18nAuto
i18n = I18nAuto() i18n = I18nAuto()
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s.pth"%(ttime())
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
def savee(ckpt, name, epoch, steps, hps): def savee(ckpt, name, epoch, steps, hps):
try: try:
@ -17,7 +24,8 @@ def savee(ckpt, name, epoch, steps, hps):
opt["weight"][key] = ckpt[key].half() opt["weight"][key] = ckpt[key].half()
opt["config"] = hps opt["config"] = hps
opt["info"] = "%sepoch_%siteration" % (epoch, steps) opt["info"] = "%sepoch_%siteration" % (epoch, steps)
torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name)) # torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
return "Success." return "Success."
except: except:
return traceback.format_exc() return traceback.format_exc()

View File

@ -24,6 +24,14 @@ torch.set_float32_matmul_precision("high")
from AR.utils import get_newest_ckpt from AR.utils import get_newest_ckpt
from collections import OrderedDict from collections import OrderedDict
from time import time as ttime
import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s.pth"%(ttime())
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
class my_model_ckpt(ModelCheckpoint): class my_model_ckpt(ModelCheckpoint):
@ -70,7 +78,8 @@ class my_model_ckpt(ModelCheckpoint):
to_save_od["weight"][key] = dictt[key].half() to_save_od["weight"][key] = dictt[key].half()
to_save_od["config"] = self.config to_save_od["config"] = self.config
to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1) to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
torch.save( # torch.save(
my_save(
to_save_od, to_save_od,
"%s/%s-e%s.ckpt" "%s/%s-e%s.ckpt"
% ( % (

View File

@ -64,6 +64,14 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
) )
return model, optimizer, learning_rate, iteration return model, optimizer, learning_rate, iteration
from time import time as ttime
import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s.pth"%(ttime())
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
logger.info( logger.info(
@ -75,7 +83,8 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
state_dict = model.module.state_dict() state_dict = model.module.state_dict()
else: else:
state_dict = model.state_dict() state_dict = model.state_dict()
torch.save( # torch.save(
my_save(
{ {
"model": state_dict, "model": state_dict,
"iteration": iteration, "iteration": iteration,