mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-08-23 21:19:47 +08:00)

support gpt-sovits v4

commit c6cb6b45f3 (parent e0c452f007)
File: GPT_SoVITS/inference_webui.py (path inferred from content)

@@ -39,21 +39,25 @@ except:
     ...
 version = model_version = os.environ.get("version", "v2")
 path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
+path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
 is_exist_s2gv3 = os.path.exists(path_sovits_v3)
+is_exist_s2gv4 = os.path.exists(path_sovits_v4)
 pretrained_sovits_name = [
     "GPT_SoVITS/pretrained_models/s2G488k.pth",
     "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
-    path_sovits_v3,
+    "GPT_SoVITS/pretrained_models/s2Gv3.pth",
+    "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
 ]
 pretrained_gpt_name = [
     "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
     "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
     "GPT_SoVITS/pretrained_models/s1v3.ckpt",
+    "GPT_SoVITS/pretrained_models/s1v3.ckpt",
 ]


 _ = [[], []]
-for i in range(3):
+for i in range(4):
     if os.path.exists(pretrained_gpt_name[i]):
         _[0].append(pretrained_gpt_name[i])
     if os.path.exists(pretrained_sovits_name[i]):
@@ -102,7 +106,7 @@ cnhubert.cnhubert_base_path = cnhubert_base_path

 import random

-from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3
+from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3, Generator


 def set_seed(seed):
@@ -222,23 +226,25 @@ else:
 resample_transform_dict = {}


-def resample(audio_tensor, sr0):
+def resample(audio_tensor, sr0, sr1):
     global resample_transform_dict
-    if sr0 not in resample_transform_dict:
-        resample_transform_dict[sr0] = torchaudio.transforms.Resample(sr0, 24000).to(device)
-    return resample_transform_dict[sr0](audio_tensor)
+    key = "%s-%s" % (sr0, sr1)
+    if key not in resample_transform_dict:
+        resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
+    return resample_transform_dict[key](audio_tensor)


 ###todo: put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt
 # symbol_version-model_version-if_lora_v3
 from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new


+v3v4set = {"v3", "v4"}
 def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
     global vq_model, hps, version, model_version, dict_language, if_lora_v3
     version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
-    # print(sovits_path, version, model_version, if_lora_v3)
-    if if_lora_v3 == True and is_exist_s2gv3 == False:
+    print(sovits_path, version, model_version, if_lora_v3)
+    is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
+    if if_lora_v3 == True and is_exist == False:
         info = "GPT_SoVITS/pretrained_models/s2Gv3.pth" + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")  # "SoVITS V3 base model is missing; the corresponding LoRA weights cannot be loaded"
         gr.Warning(info)
         raise FileExistsError(info)
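Note: the resample cache is re-keyed because v3 consumes 24 kHz reference audio while v4 consumes 32 kHz, so a cache keyed only on the input rate (with a hard-coded 24000 target) would be wrong for v4. A minimal editor's sketch in Python of the new behavior (not part of the commit; the `device` value here is a stand-in):

    import torch
    import torchaudio

    device = "cpu"  # stand-in; the script uses its configured CUDA/CPU device
    resample_transform_dict = {}

    def resample(audio_tensor, sr0, sr1):
        # one cached Resample per (input_sr, output_sr) pair, so the v3 (->24000)
        # and v4 (->32000) targets no longer collide in the cache
        key = "%s-%s" % (sr0, sr1)
        if key not in resample_transform_dict:
            resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
        return resample_transform_dict[key](audio_tensor)

    wav = torch.randn(1, 48000)               # 1 s of audio at 48 kHz
    print(resample(wav, 48000, 24000).shape)  # torch.Size([1, 24000]) -- v3 target
    print(resample(wav, 48000, 32000).shape)  # torch.Size([1, 32000]) -- v4 target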
@@ -257,7 +263,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
     else:
         text_update = {"__type__": "update", "value": ""}
         text_language_update = {"__type__": "update", "value": i18n("中文")}  # "Chinese"
-    if model_version == "v3":
+    if model_version in v3v4set:
         visible_sample_steps = True
         visible_inp_refs = False
     else:
@@ -270,10 +276,10 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
         prompt_language_update,
         text_update,
         text_language_update,
-        {"__type__": "update", "visible": visible_sample_steps, "value": 32},
+        {"__type__": "update", "visible": visible_sample_steps, "value": 32 if model_version == "v3" else 8, "choices": [4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32]},
         {"__type__": "update", "visible": visible_inp_refs},
-        {"__type__": "update", "value": False, "interactive": True if model_version != "v3" else False},
-        {"__type__": "update", "visible": True if model_version == "v3" else False},
+        {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False},
+        {"__type__": "update", "visible": True if model_version =="v3" else False},
         {"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},  # "Model is loading, please wait"
     )
@@ -289,7 +295,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
         hps.model.version = "v2"
     version = hps.model.version
     # print("sovits版本:", hps.model.version)
-    if model_version != "v3":
+    if model_version not in v3v4set:
         vq_model = SynthesizerTrn(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
@@ -317,9 +323,10 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
     if if_lora_v3 == False:
         print("loading sovits_%s" % model_version, vq_model.load_state_dict(dict_s2["weight"], strict=False))
     else:
+        path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
         print(
-            "loading sovits_v3pretrained_G",
-            vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False),
+            "loading sovits_%spretrained_G" % model_version,
+            vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False),
         )
         lora_rank = dict_s2["lora_rank"]
         lora_config = LoraConfig(
@@ -329,7 +336,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
             init_lora_weights=True,
         )
         vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
-        print("loading sovits_v3_lora%s" % (lora_rank))
+        print("loading sovits_%s_lora%s" % (model_version, lora_rank))
         vq_model.load_state_dict(dict_s2["weight"], strict=False)
         vq_model.cfm = vq_model.cfm.merge_and_unload()
         # torch.save(vq_model.state_dict(), "merge_win.pth")
@@ -342,10 +349,10 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
         prompt_language_update,
         text_update,
         text_language_update,
-        {"__type__": "update", "visible": visible_sample_steps, "value": 32},
+        {"__type__": "update", "visible": visible_sample_steps, "value": 32 if model_version == "v3" else 8, "choices": [4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32]},
         {"__type__": "update", "visible": visible_inp_refs},
-        {"__type__": "update", "value": False, "interactive": True if model_version != "v3" else False},
-        {"__type__": "update", "visible": True if model_version == "v3" else False},
+        {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False},
+        {"__type__": "update", "visible": True if model_version =="v3" else False},
         {"__type__": "update", "value": i18n("合成语音"), "interactive": True},  # "Synthesize speech"
     )
     with open("./weight.json") as f:
@@ -392,7 +399,7 @@ now_dir = os.getcwd()


 def init_bigvgan():
-    global bigvgan_model
+    global bigvgan_model, hifigan_model
     from BigVGAN import bigvgan

     bigvgan_model = bigvgan.BigVGAN.from_pretrained(
@@ -402,16 +409,47 @@ def init_bigvgan():
     # remove weight norm in the model and set to eval mode
     bigvgan_model.remove_weight_norm()
     bigvgan_model = bigvgan_model.eval()
+    if hifigan_model:
+        hifigan_model = hifigan_model.cpu()
+        hifigan_model = None
+        try: torch.cuda.empty_cache()
+        except: pass
     if is_half == True:
         bigvgan_model = bigvgan_model.half().to(device)
     else:
         bigvgan_model = bigvgan_model.to(device)


+def init_hifigan():
+    global hifigan_model, bigvgan_model
+    hifigan_model = Generator(
+        initial_channel=100,
+        resblock="1",
+        resblock_kernel_sizes=[3, 7, 11],
+        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        upsample_rates=[10, 6, 2, 2, 2],
+        upsample_initial_channel=512,
+        upsample_kernel_sizes=[20, 12, 4, 4, 4],
+        gin_channels=0, is_bias=True,
+    )
+    hifigan_model.eval()
+    hifigan_model.remove_weight_norm()
+    state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
+    print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
+    if bigvgan_model:
+        bigvgan_model = bigvgan_model.cpu()
+        bigvgan_model = None
+        try: torch.cuda.empty_cache()
+        except: pass
+    if is_half == True:
+        hifigan_model = hifigan_model.half().to(device)
+    else:
+        hifigan_model = hifigan_model.to(device)


-if model_version != "v3":
-    bigvgan_model = None
-else:
+bigvgan_model = hifigan_model = None
+if model_version == "v3":
     init_bigvgan()
+if model_version == "v4":
+    init_hifigan()


 def get_spepc(hps, filename):
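Note: the v4 output rate follows from the Generator config above. Its upsample_rates multiply to 480 output samples per mel frame, and the v4 mel front end (mel_fn_v4, next hunk) produces frames at a 320-sample hop of 32 kHz audio, i.e. 100 frames per second, which lands exactly on the 48 kHz opt_sr the commit sets for v4 later. A quick editor's check in Python (not part of the commit):

    from math import prod

    upsample_rates = [10, 6, 2, 2, 2]     # from the Generator config above
    hop_size, sampling_rate = 320, 32000  # from mel_fn_v4 in the next hunk
    frames_per_second = sampling_rate / hop_size  # 100 mel frames per second
    samples_per_frame = prod(upsample_rates)      # 480 output samples per mel frame
    print(frames_per_second * samples_per_frame)  # 48000.0 -> v4's 48 kHz output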
@@ -576,6 +614,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
         "center": False,
     },
 )
+mel_fn_v4 = lambda x: mel_spectrogram_torch(
+    x,
+    **{
+        "n_fft": 1280,
+        "win_size": 1280,
+        "hop_size": 320,
+        "num_mels": 100,
+        "sampling_rate": 32000,
+        "fmin": 0,
+        "fmax": None,
+        "center": False,
+    },
+)


 def merge_short_text_in_array(texts, threshold):
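Note: mel_fn_v4 computes a 100-band mel at 32 kHz with a 1280-sample window and a 320-sample hop (100 frames per second). An approximate editor's stand-in using torchaudio (not part of the commit; the repo's mel_spectrogram_torch differs in scaling and normalization, so this is illustrative only):

    import torch
    import torchaudio

    mel_v4 = torchaudio.transforms.MelSpectrogram(
        sample_rate=32000, n_fft=1280, win_length=1280, hop_length=320,
        f_min=0.0, n_mels=100, center=False,
    )
    wav = torch.randn(1, 32000)   # 1 s at 32 kHz
    print(mel_v4(wav).shape)      # torch.Size([1, 100, 97]): 100 bands, ~100 frames/s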
@@ -647,7 +698,7 @@ def get_tts_wav(
     t = []
     if prompt_text is None or len(prompt_text) == 0:
         ref_free = True
-    if model_version == "v3":
+    if model_version in v3v4set:
         ref_free = False  # s2v3暂不支持ref_free (s2v3 does not support ref_free yet)
     else:
         if_sr = False
@@ -755,7 +806,7 @@ def get_tts_wav(
                 cache[i_text] = pred_semantic
             t3 = ttime()
             ###v3不存在以下逻辑和inp_refs (the following logic and inp_refs do not exist for v3)
-            if model_version != "v3":
+            if model_version not in v3v4set:
                 refers = []
                 if inp_refs:
                     for path in inp_refs:
@@ -779,25 +830,24 @@ def get_tts_wav(
                 ref_audio = ref_audio.to(device).float()
                 if ref_audio.shape[0] == 2:
                     ref_audio = ref_audio.mean(0).unsqueeze(0)
-                if sr != 24000:
-                    ref_audio = resample(ref_audio, sr)
+                tgt_sr = 24000 if model_version == "v3" else 32000
+                if sr != tgt_sr:
+                    ref_audio = resample(ref_audio, sr, tgt_sr)
                 # print("ref_audio", ref_audio.abs().mean())
-                mel2 = mel_fn(ref_audio)
+                mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio)
                 mel2 = norm_spec(mel2)
                 T_min = min(mel2.shape[2], fea_ref.shape[2])
                 mel2 = mel2[:, :, :T_min]
                 fea_ref = fea_ref[:, :, :T_min]
-                if T_min > 468:
-                    mel2 = mel2[:, :, -468:]
-                    fea_ref = fea_ref[:, :, -468:]
-                    T_min = 468
-                chunk_len = 934 - T_min
-                # print("fea_ref", fea_ref, fea_ref.shape)
-                # print("mel2", mel2)
+                Tref = 468 if model_version == "v3" else 500
+                Tchunk = 934 if model_version == "v3" else 1000
+                if T_min > Tref:
+                    mel2 = mel2[:, :, -Tref:]
+                    fea_ref = fea_ref[:, :, -Tref:]
+                    T_min = Tref
+                chunk_len = Tchunk - T_min
                 mel2 = mel2.to(dtype)
                 fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
                 # print("fea_todo", fea_todo)
                 # print("ge", ge.abs().mean())
                 cfm_resss = []
                 idx = 0
                 while 1:
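Note: v3 and v4 use different streaming windows for the chunked CFM decode: the kept reference mel is capped at Tref frames (468 for v3, 500 for v4) and each inference call fills a Tchunk-frame window (934/1000), so chunk_len = Tchunk - T_min new frames are produced per call. A small editor's sketch of that arithmetic in Python (not part of the commit):

    def plan_chunks(ref_frames, model_version):
        # window sizes per version, as set in the hunk above
        Tref = 468 if model_version == "v3" else 500
        Tchunk = 934 if model_version == "v3" else 1000
        T_min = min(ref_frames, Tref)   # reference mel frames actually kept
        chunk_len = Tchunk - T_min      # new frames generated per CFM call
        return T_min, chunk_len

    print(plan_chunks(800, "v3"))  # (468, 466)
    print(plan_chunks(800, "v4"))  # (500, 500)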
@@ -806,22 +856,24 @@ def get_tts_wav(
                         break
                     idx += chunk_len
                     fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
                     # set_seed(123)
                     cfm_res = vq_model.cfm.inference(
                         fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
                     )
                     cfm_res = cfm_res[:, :, mel2.shape[2] :]
                     mel2 = cfm_res[:, :, -T_min:]
                     # print("fea", fea)
                     # print("mel2in", mel2)
                     fea_ref = fea_todo_chunk[:, :, -T_min:]
                     cfm_resss.append(cfm_res)
-                cmf_res = torch.cat(cfm_resss, 2)
-                cmf_res = denorm_spec(cmf_res)
-                if bigvgan_model == None:
-                    init_bigvgan()
+                cfm_res = torch.cat(cfm_resss, 2)
+                cfm_res = denorm_spec(cfm_res)
+                if model_version == "v3":
+                    if bigvgan_model == None:
+                        init_bigvgan()
+                else:  # v4
+                    if hifigan_model == None:
+                        init_hifigan()
+                vocoder_model = bigvgan_model if model_version == "v3" else hifigan_model
                 with torch.inference_mode():
-                    wav_gen = bigvgan_model(cmf_res)
+                    wav_gen = vocoder_model(cfm_res)
                     audio = wav_gen[0][0]  # .cpu().detach().numpy()
                     max_audio = torch.abs(audio).max()  # 简单防止16bit爆音 (simple guard against 16-bit clipping)
                     if max_audio > 1:
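Note: besides renaming the misspelled cmf_res to cfm_res, this hunk makes the vocoder lazily initialized per version: BigVGAN decodes v3's 24 kHz mels, the HiFi-GAN Generator decodes v4's. An editor's stand-in sketch of the dispatch in Python (not part of the commit; strings replace the real loaders):

    bigvgan_model = hifigan_model = None

    def init_bigvgan():  # stand-in for the real loader defined earlier in the file
        global bigvgan_model
        bigvgan_model = "BigVGAN (v3)"

    def init_hifigan():  # stand-in for the real loader defined earlier in the file
        global hifigan_model
        hifigan_model = "HiFi-GAN Generator (v4)"

    def pick_vocoder(model_version):
        # lazy init: only the vocoder the current model needs is ever loaded
        if model_version == "v3":
            if bigvgan_model is None:
                init_bigvgan()
        else:  # v4
            if hifigan_model is None:
                init_hifigan()
        return bigvgan_model if model_version == "v3" else hifigan_model

    print(pick_vocoder("v4"))  # HiFi-GAN Generator (v4)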
@@ -833,16 +885,18 @@ def get_tts_wav(
         t1 = ttime()
         print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
     audio_opt = torch.cat(audio_opt, 0)  # np.concatenate
-    sr = hps.data.sampling_rate if model_version != "v3" else 24000
-    if if_sr == True and sr == 24000:
+    if model_version in {"v1", "v2"}: opt_sr = 32000
+    elif model_version == "v3": opt_sr = 24000
+    else: opt_sr = 48000  # v4
+    if if_sr == True and opt_sr == 24000:
         print(i18n("音频超分中"))  # "Running audio super-resolution"
-        audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
+        audio_opt, opt_sr = audio_sr(audio_opt.unsqueeze(0), opt_sr)
         max_audio = np.abs(audio_opt).max()
         if max_audio > 1:
             audio_opt /= max_audio
     else:
         audio_opt = audio_opt.cpu().detach().numpy()
-    yield sr, (audio_opt * 32767).astype(np.int16)
+    yield opt_sr, (audio_opt * 32767).astype(np.int16)


 def split(todo_text):
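Note: the old sampling_rate-or-24000 branch becomes an explicit per-version output rate, and the optional super-resolution pass still applies only to 24 kHz v3 output. An editor's sketch in Python (not part of the commit):

    def output_sr(model_version):
        # mirrors the branch above: v1/v2 -> 32 kHz, v3 -> 24 kHz, v4 -> 48 kHz
        if model_version in {"v1", "v2"}:
            return 32000
        elif model_version == "v3":
            return 24000
        return 48000  # v4

    for v in ("v1", "v2", "v3", "v4"):
        print(v, output_sr(v))  # only the 24 kHz v3 path is eligible for audio_sr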
@@ -971,8 +1025,8 @@ def change_choices():
     }


-SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3"]
-GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3"]
+SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3", "SoVITS_weights_v4"]
+GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3", "GPT_weights_v4"]
 for path in SoVITS_weight_root + GPT_weight_root:
     os.makedirs(path, exist_ok=True)
@ -1039,7 +1093,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。")
|
||||
+ i18n("v3暂不支持该模式,使用了会报错。"),
|
||||
value=False,
|
||||
interactive=True if model_version != "v3" else False,
|
||||
interactive=True if model_version not in v3v4set else False,
|
||||
show_label=True,
|
||||
scale=1,
|
||||
)
|
||||
@ -1064,7 +1118,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
),
|
||||
file_count="multiple",
|
||||
)
|
||||
if model_version != "v3"
|
||||
if model_version not in v3v4set
|
||||
else gr.File(
|
||||
label=i18n(
|
||||
"可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"
|
||||
@ -1076,16 +1130,16 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
sample_steps = (
|
||||
gr.Radio(
|
||||
label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
|
||||
value=32,
|
||||
choices=[4, 8, 16, 32],
|
||||
value=32 if model_version=="v3"else 8,
|
||||
choices=[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32],
|
||||
visible=True,
|
||||
)
|
||||
if model_version == "v3"
|
||||
if model_version in v3v4set
|
||||
else gr.Radio(
|
||||
label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
|
||||
choices=[4, 8, 16, 32],
|
||||
choices=[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32],
|
||||
visible=False,
|
||||
value=32,
|
||||
value=32 if model_version=="v3"else 8,
|
||||
)
|
||||
)
|
||||
if_sr_Checkbox = gr.Checkbox(
|
||||
@ -1093,7 +1147,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
value=False,
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
visible=False if model_version != "v3" else True,
|
||||
visible=False if model_version !="v3" else True,
|
||||
)
|
||||
gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"), "h3"))
|
||||
with gr.Row():
|
||||
|
File: GPT_SoVITS/process_ckpt.py (path inferred from content)

@@ -22,23 +22,24 @@ def my_save(fea, path):  #####fix issue: torch.save doesn't support chinese path
 01:v2
 02:v3
 03:v3lora
+04:v4lora
 """
 from io import BytesIO


-def my_save2(fea, path):
+def my_save2(fea, path, cfm_version):
     bio = BytesIO()
     torch.save(fea, bio)
     bio.seek(0)
     data = bio.getvalue()
-    data = b"03" + data[2:]  ###temp for v3lora only, todo
+    byte = b"03" if cfm_version == "v3" else b"04"
+    data = byte + data[2:]
     with open(path, "wb") as f:
         f.write(data)


-def savee(ckpt, name, epoch, steps, hps, lora_rank=None):
+def savee(ckpt, name, epoch, steps, hps, cfm_version=None, lora_rank=None):
     try:
         opt = OrderedDict()
         opt["weight"] = {}
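Note: my_save2 stamps a 2-byte version head over the start of the torch.save zip payload (which normally begins with the b"PK" zip magic): b"03" marks a v3 LoRA checkpoint, b"04" a v4 one, and get_sovits_version_from_path_fast reads it back below. A runnable editor's sketch in Python (not part of the commit; file path is illustrative):

    import io
    import torch

    def save_with_head(obj, path, cfm_version):
        # serialize normally, then overwrite the first two bytes of the zip magic
        bio = io.BytesIO()
        torch.save(obj, bio)
        data = bio.getvalue()
        head = b"03" if cfm_version == "v3" else b"04"
        with open(path, "wb") as f:
            f.write(head + data[2:])

    save_with_head({"demo": torch.zeros(1)}, "/tmp/demo.pth", "v4")
    print(open("/tmp/demo.pth", "rb").read(2))  # b'04'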
@@ -50,7 +51,7 @@ def savee(ckpt, name, epoch, steps, hps, cfm_version=None, lora_rank=None):
         opt["info"] = "%sepoch_%siteration" % (epoch, steps)
         if lora_rank:
             opt["lora_rank"] = lora_rank
-            my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
+            my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), cfm_version)
         else:
             my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
         return "Success."
@@ -63,11 +64,13 @@ head2version = {
     b"01": ["v2", "v2", False],
     b"02": ["v2", "v3", False],
     b"03": ["v2", "v3", True],
+    b"04": ["v2", "v4", True],
 }
 hash_pretrained_dict = {
     "dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False],  # s2G488k.pth#sovits_v1_pretrained
     "43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False],  # s2Gv3.pth#sovits_v3_pretrained
     "6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False],  # s2G2333K.pth#sovits_v2_pretrained
+    "4f26b9476d0c5033e04162c486074374": ["v2", "v4", False],  # s2Gv4.pth#sovits_v4_pretrained
 }
 import hashlib
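Note: the new b"04" entry lets the loader map a stamped head back to (symbol_version, model_version, if_lora). An editor's sketch of the lookup in Python (not part of the commit; tuples stand in for the lists above):

    head2version = {
        b"01": ("v2", "v2", False),
        b"02": ("v2", "v3", False),
        b"03": ("v2", "v3", True),
        b"04": ("v2", "v4", True),
    }

    def sniff_version(path):
        with open(path, "rb") as f:
            head = f.read(2)
        # b"PK" means an unstamped plain torch zip checkpoint (handled separately)
        return head2version.get(head, ("unstamped", None, None))

    print(sniff_version("/tmp/demo.pth"))  # ('v2', 'v4', True) for the file saved above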
@@ -85,7 +88,7 @@ def get_sovits_version_from_path_fast(sovits_path):
     hash = get_hash_from_file(sovits_path)
     if hash in hash_pretrained_dict:
         return hash_pretrained_dict[hash]
-    ###2-new weights or old weights, by head
+    ###2-new weights, by head
     with open(sovits_path, "rb") as f:
         version = f.read(2)
     if version != b"PK":
File: GPT_SoVITS/s2_train_v3_lora.py (path inferred from content)

@@ -27,12 +27,11 @@ from random import randint
 from module import commons
 from module.data_utils import (
     DistributedBucketSampler,
-)
-from module.data_utils import (
-    TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
-)
-from module.data_utils import (
-    TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
+    TextAudioSpeakerCollateV3,
+    TextAudioSpeakerLoaderV3,
+    TextAudioSpeakerCollateV4,
+    TextAudioSpeakerLoaderV4,
 )
 from module.models import (
     SynthesizerTrnV3 as SynthesizerTrn,
@@ -89,6 +88,8 @@ def run(rank, n_gpus, hps):
     if torch.cuda.is_available():
         torch.cuda.set_device(rank)

+    TextAudioSpeakerLoader = TextAudioSpeakerLoaderV3 if hps.model.version == "v3" else TextAudioSpeakerLoaderV4
+    TextAudioSpeakerCollate = TextAudioSpeakerCollateV3 if hps.model.version == "v3" else TextAudioSpeakerCollateV4
     train_dataset = TextAudioSpeakerLoader(hps.data)  ########
     train_sampler = DistributedBucketSampler(
         train_dataset,
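Note: instead of aliasing the V3 classes at import time, the training script now picks the loader/collator pair at runtime from hps.model.version, so one script serves both v3 and v4 LoRA training. An editor's stand-in sketch in Python (not part of the commit; empty classes replace the real data_utils ones):

    class TextAudioSpeakerLoaderV3: ...
    class TextAudioSpeakerLoaderV4: ...
    class TextAudioSpeakerCollateV3: ...
    class TextAudioSpeakerCollateV4: ...

    def pick_dataset_classes(version):
        # hps.model.version decides which dataset pipeline to build
        Loader = TextAudioSpeakerLoaderV3 if version == "v3" else TextAudioSpeakerLoaderV4
        Collate = TextAudioSpeakerCollateV3 if version == "v3" else TextAudioSpeakerCollateV4
        return Loader, Collate

    print([c.__name__ for c in pick_dataset_classes("v4")])
    # ['TextAudioSpeakerLoaderV4', 'TextAudioSpeakerCollateV4']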
@@ -364,7 +365,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
                         hps.name + "_e%s_s%s_l%s" % (epoch, global_step, lora_rank),
                         epoch,
                         global_step,
-                        hps,
+                        hps, cfm_version=hps.model.version,
                         lora_rank=lora_rank,
                     ),
                 )