Merge branch 'RVC-Boss:main' into main

commit f1aa95d8cf
Jesse Cheng, 2024-02-19 01:58:19 +11:00, committed by GitHub
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
18 changed files with 177 additions and 90 deletions

.gitignore (vendored, 5 changes)

@@ -7,5 +7,8 @@ runtime
 output
 logs
 reference
-SoVITS_weights
 GPT_weights
+SoVITS_weights
+TEMP

@@ -41,7 +41,8 @@ class Text2SemanticDataModule(LightningDataModule):
             # pad_val=self.config['data']['pad_val'])

     def train_dataloader(self):
-        batch_size = max(min(self.config["train"]["batch_size"],len(self._train_dataset)//4),1)  # guard against it never saving
+        batch_size = self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
+        batch_size = max(min(batch_size,len(self._train_dataset)//4),1)  # guard against it never saving
         sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
         return DataLoader(
             self._train_dataset,

@@ -11,7 +11,6 @@ from AR.models.t2s_model import Text2SemanticDecoder
 from AR.modules.lr_schedulers import WarmupCosineLRSchedule
 from AR.modules.optim import ScaledAdam
-

 class Text2SemanticLightningModule(LightningModule):
     def __init__(self, config, output_dir, is_train=True):
         super().__init__()

@@ -35,7 +34,8 @@ class Text2SemanticLightningModule(LightningModule):
     def training_step(self, batch: Dict, batch_idx: int):
         opt = self.optimizers()
         scheduler = self.lr_schedulers()
-        loss, acc = self.model.forward(
+        forward = self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old
+        loss, acc = forward(
             batch["phoneme_ids"],
             batch["phoneme_ids_len"],
             batch["semantic_ids"],

@@ -337,7 +337,7 @@ class Text2SemanticDecoder(nn.Module):
         # AR Decoder
         y = prompts
-        prefix_len = y.shape[1]
+
         x_len = x.shape[1]
         x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
         stop = False
@@ -353,23 +353,24 @@ class Text2SemanticDecoder(nn.Module):
             "first_infer": 1,
             "stage": 0,
         }
-        for idx in tqdm(range(1500)):
-            if cache["first_infer"] == 1:
-                y_emb = self.ar_audio_embedding(y)
-            else:
-                y_emb = torch.cat(
-                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
-                )
-            cache["y_emb"] = y_emb
-            y_pos = self.ar_audio_position(y_emb)
-            # x is fed to the model together with the gradually growing y
-            if cache["first_infer"] == 1:
-                xy_pos = torch.concat([x, y_pos], dim=1)
-            else:
-                xy_pos = y_pos[:, -1:]
-            y_len = y_pos.shape[1]
-            ### the following 3 are not cached
-            if cache["first_infer"] == 1:
+        ################### first step ##########################
+        if y is not None:
+            y_emb = self.ar_audio_embedding(y)
+            y_len = y_emb.shape[1]
+            prefix_len = y.shape[1]
+            y_pos = self.ar_audio_position(y_emb)
+            xy_pos = torch.concat([x, y_pos], dim=1)
+            cache["y_emb"] = y_emb
+            ref_free = False
+        else:
+            y_emb = None
+            y_len = 0
+            prefix_len = 0
+            y_pos = None
+            xy_pos = x
+            y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
+            ref_free = True
+
         x_attn_mask_pad = F.pad(
             x_attn_mask,
             (0, y_len),  ### extend the all-zero xx to all-zero xx plus all-one xy (x, x+y)
@@ -381,19 +382,12 @@ class Text2SemanticDecoder(nn.Module):
             value=False,
         )
         xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
-            y.device
+            x.device
         )
-            else:
-                ### rightmost column (this is wrong)
-                # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
-                # xy_attn_mask[:,-1]=False
-                ### bottom row (this is right)
-                xy_attn_mask = torch.zeros(
-                    (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
-                )
-            # pdb.set_trace()
-            ### the cache does the heavy lifting here
-            # print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len)
+
+        for idx in tqdm(range(1500)):
+
             xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
             logits = self.ar_predict_layer(
                 xy_dec[:, -1]
@@ -404,6 +398,10 @@ class Text2SemanticDecoder(nn.Module):
             samples = sample(
                 logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
             )[0].unsqueeze(0)
+            # the semantic_ids generated this step and the previous y form the new y
+            # print(samples.shape)#[1,1]# the first 1 is bs
+            y = torch.concat([y, samples], dim=1)
+
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 print("use early stop num:", early_stop_num)
                 stop = True
@@ -412,13 +410,38 @@ class Text2SemanticDecoder(nn.Module):
                 # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
                 stop = True
             if stop:
-                if prompts.shape[1] == y.shape[1]:
+                # if prompts.shape[1] == y.shape[1]:
+                #     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
+                #     print("bad zero prediction")
+                if y.shape[1]==0:
                     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                     print("bad zero prediction")
                 print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
                 break
-            # the semantic_ids generated this step and the previous y form the new y
-            # print(samples.shape)#[1,1]# the first 1 is bs
-            y = torch.concat([y, samples], dim=1)
+
+            ####################### update next step ###################################
             cache["first_infer"] = 0
-        return y, idx
+            if cache["y_emb"] is not None:
+                y_emb = torch.cat(
+                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1
+                )
+                cache["y_emb"] = y_emb
+                y_pos = self.ar_audio_position(y_emb)
+                xy_pos = y_pos[:, -1:]
+            else:
+                y_emb = self.ar_audio_embedding(y[:, -1:])
+                cache["y_emb"] = y_emb
+                y_pos = self.ar_audio_position(y_emb)
+                xy_pos = y_pos
+            y_len = y_pos.shape[1]
+            ### rightmost column (this is wrong)
+            # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
+            # xy_attn_mask[:,-1]=False
+            ### bottom row (this is right)
+            xy_attn_mask = torch.zeros(
+                (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
+            )
+        if ref_free:
+            return y[:, :-1], 0
+        return y[:, :-1], idx-1
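The net effect of the restructuring: the expensive prompt handling happens once before the loop ("first step"), the loop body only appends one sampled token at a time, and `infer_panel` now accepts `prompts=None` for the no-reference-text mode. A minimal, runnable sketch of just the first-step branching, with a stub embedding standing in for `ar_audio_embedding` and `ar_audio_position`:

```python
from typing import Optional
import torch

embed = torch.nn.Embedding(1024, 512)  # stand-in for ar_audio_embedding (+ position)

def first_step(x: torch.Tensor, y: Optional[torch.Tensor]):
    # Sketch of the "first step" above: with a semantic prompt y, x and the
    # embedded y are concatenated; without one (ref_free), the model sees only
    # x, and y starts empty, to be grown one sampled token per loop iteration.
    if y is not None:
        prefix_len = y.shape[1]
        xy_pos = torch.cat([x, embed(y)], dim=1)
        ref_free = False
    else:
        prefix_len = 0
        y = torch.zeros(x.shape[0], 0, dtype=torch.long, device=x.device)
        xy_pos = x
        ref_free = True
    return xy_pos, y, prefix_len, ref_free

x = torch.randn(1, 12, 512)  # fake phoneme hidden states: (bs, x_len, dim)
print(first_step(x, None)[0].shape)                                  # torch.Size([1, 12, 512])
print(first_step(x, torch.zeros(1, 5, dtype=torch.long))[0].shape)   # torch.Size([1, 17, 512])
```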

@@ -114,6 +114,7 @@ def logits_to_probs(
     top_p: Optional[int] = None,
     repetition_penalty: float = 1.0,
 ):
-    previous_tokens = previous_tokens.squeeze()
+    if previous_tokens is not None:
+        previous_tokens = previous_tokens.squeeze()
     # print(logits.shape,previous_tokens.shape)
     # pdb.set_trace()

@@ -5,8 +5,8 @@ from torch.nn.functional import (
     _none_or_dtype,
     _in_projection_packed,
 )
-from torch.nn import functional as F
-# import torch
+import torch
+
 # Tensor = torch.Tensor
 # from typing import Callable, List, Optional, Tuple, Union

@@ -448,9 +448,11 @@ def multi_head_attention_forward_patched(
     k = k.view(bsz, num_heads, src_len, head_dim)
     v = v.view(bsz, num_heads, src_len, head_dim)

+    # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
     attn_output = scaled_dot_product_attention(
         q, k, v, attn_mask, dropout_p, is_causal
     )
+
     attn_output = (
         attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
     )

@@ -248,6 +248,10 @@ def clean_text_inf(text, language):
     formattext = ""
     language = language.replace("all_","")
     for tmp in LangSegment.getTexts(text):
+        if language == "ja":
+            if tmp["lang"] == language or tmp["lang"] == "zh":
+                formattext += tmp["text"] + " "
+            continue
         if tmp["lang"] == language:
             formattext += tmp["text"] + " "
     while "  " in formattext:
@@ -279,8 +283,6 @@ def nonen_clean_text_inf(text, language):
     for tmp in LangSegment.getTexts(text):
         langlist.append(tmp["lang"])
         textlist.append(tmp["text"])
-    print(textlist)
-    print(langlist)
     phones_list = []
     word2ph_list = []
     norm_text_list = []
@@ -365,15 +367,19 @@ def merge_short_text_in_array(texts, threshold):
             result[len(result) - 1] += text
     return result

-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False):
+    if prompt_text is None or len(prompt_text) == 0:
+        ref_free = True
     t0 = ttime()
     prompt_language = dict_language[prompt_language]
     text_language = dict_language[text_language]
-    prompt_text = prompt_text.strip("\n")
-    if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
+    if not ref_free:
+        prompt_text = prompt_text.strip("\n")
+        if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
+        print(i18n("实际输入的参考文本:"), prompt_text)
     text = text.strip("\n")
     if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
-    print(i18n("实际输入的参考文本:"), prompt_text)
     print(i18n("实际输入的目标文本:"), text)
     zero_wav = np.zeros(
         int(hps.data.sampling_rate * 0.3),
@@ -398,11 +404,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 1, 2
             )  # .float()
             codes = vq_model.extract_latent(ssl_content)
             prompt_semantic = codes[0, 0]
     t1 = ttime()
-    phones1, word2ph1, norm_text1=get_cleaned_text_final(prompt_text, prompt_language)
     if (how_to_cut == i18n("凑四句一切")):
         text = cut1(text)
     elif (how_to_cut == i18n("凑50字一切")):
@@ -419,6 +424,8 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     texts = text.split("\n")
     texts = merge_short_text_in_array(texts, 5)
     audio_opt = []
-    bert1=get_bert_final(phones1, word2ph1, norm_text1,prompt_language,device).to(dtype)
+    if not ref_free:
+        phones1, word2ph1, norm_text1=get_cleaned_text_final(prompt_text, prompt_language)
+        bert1=get_bert_final(phones1, word2ph1, norm_text1,prompt_language,device).to(dtype)

     for text in texts:
@@ -429,9 +436,13 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         print(i18n("实际输入的目标文本(每句):"), text)
         phones2, word2ph2, norm_text2 = get_cleaned_text_final(text, text_language)
         bert2 = get_bert_final(phones2, word2ph2, norm_text2, text_language, device).to(dtype)
-        bert = torch.cat([bert1, bert2], 1)
-        all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
+        if not ref_free:
+            bert = torch.cat([bert1, bert2], 1)
+            all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
+        else:
+            bert = bert2
+            all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
+
         bert = bert.to(device).unsqueeze(0)
         all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
         prompt = prompt_semantic.unsqueeze(0).to(device)
@@ -441,7 +452,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         pred_semantic, idx = t2s_model.model.infer_panel(
             all_phoneme_ids,
             all_phoneme_len,
-            prompt,
+            None if ref_free else prompt,
             bert,
             # prompt_phone_len=ph_offset,
             top_k=top_k,
@@ -607,6 +618,9 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
     gr.Markdown(value=i18n("*请上传并填写参考信息"))
     with gr.Row():
         inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频,超过会报错!"), type="filepath")
+        with gr.Column():
+            ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
+            gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT"))
         prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="")
         prompt_language = gr.Dropdown(
             label=i18n("参考音频的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
@@ -624,6 +638,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
             interactive=True,
         )
     with gr.Row():
+        gr.Markdown("gpt采样参数(无参考文本时不要太低)")
         top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
         top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
         temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
@@ -632,7 +647,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
     inference_button.click(
         get_tts_wav,
-        [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut,top_k,top_p,temperature],
+        [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free],
         [output],
     )
@@ -650,7 +665,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
     button3.click(cut3, [text_inp], [text_opt])
     button4.click(cut4, [text_inp], [text_opt])
     button5.click(cut5, [text_inp], [text_opt])
-    gr.Markdown(value=i18n("后续将支持混合语种编码文本输入"))
+    gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行"))

 app.queue(concurrency_count=511, max_size=1022).launch(
     server_name="0.0.0.0",

@@ -228,6 +228,7 @@ class TextEncoder(nn.Module):
         )
         y = self.ssl_proj(y * y_mask) * y_mask
         y = self.encoder_ssl(y * y_mask, y_mask)
+
         text_mask = torch.unsqueeze(

@@ -958,6 +959,8 @@ class SynthesizerTrn(nn.Module):
     @torch.no_grad()
     def decode(self, codes, text, refer, noise_scale=0.5):
+        ge = None
+        if refer is not None:
             refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
             refer_mask = torch.unsqueeze(
                 commons.sequence_mask(refer_lengths, refer.size(2)), 1

@@ -36,12 +36,12 @@ import shutil

 def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
     dir=os.path.dirname(path)
     name=os.path.basename(path)
-    tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
+    # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
+    tmp_path="%s%s.pth"%(ttime(),i_part)
     torch.save(fea,tmp_path)
     shutil.move(tmp_path,"%s/%s"%(dir,name))

 txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
 if os.path.exists(txt_path) == False:
     bert_dir = "%s/3-bert" % (opt_dir)

@@ -35,7 +35,8 @@ import shutil

 def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
     dir=os.path.dirname(path)
     name=os.path.basename(path)
-    tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
+    # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
+    tmp_path="%s%s.pth"%(ttime(),i_part)
     torch.save(fea,tmp_path)
     shutil.move(tmp_path,"%s/%s"%(dir,name))

@@ -1,11 +1,18 @@
 import traceback
 from collections import OrderedDict
+from time import time as ttime
+import shutil,os
 import torch
 from tools.i18n.i18n import I18nAuto

 i18n = I18nAuto()

+def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
+    dir=os.path.dirname(path)
+    name=os.path.basename(path)
+    tmp_path="%s.pth"%(ttime())
+    torch.save(fea,tmp_path)
+    shutil.move(tmp_path,"%s/%s"%(dir,name))

 def savee(ckpt, name, epoch, steps, hps):
     try:

@@ -17,7 +24,8 @@ def savee(ckpt, name, epoch, steps, hps):
             opt["weight"][key] = ckpt[key].half()
         opt["config"] = hps
         opt["info"] = "%sepoch_%siteration" % (epoch, steps)
-        torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
+        # torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
+        my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
         return "Success."
     except:
         return traceback.format_exc()

@@ -24,6 +24,14 @@ torch.set_float32_matmul_precision("high")
 from AR.utils import get_newest_ckpt
 from collections import OrderedDict
+from time import time as ttime
+import shutil
+def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
+    dir=os.path.dirname(path)
+    name=os.path.basename(path)
+    tmp_path="%s.pth"%(ttime())
+    torch.save(fea,tmp_path)
+    shutil.move(tmp_path,"%s/%s"%(dir,name))

 class my_model_ckpt(ModelCheckpoint):

@@ -70,7 +78,8 @@ class my_model_ckpt(ModelCheckpoint):
                     to_save_od["weight"][key] = dictt[key].half()
                 to_save_od["config"] = self.config
                 to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
-                torch.save(
+                # torch.save(
+                my_save(
                     to_save_od,
                     "%s/%s-e%s.ckpt"
                     % (

@@ -169,7 +169,7 @@ def read_dict_new():
                 line = line.strip()
                 word_split = line.split(" ")
                 word = word_split[0]
-                if word not in g2p_dict:
-                    g2p_dict[word] = []
-                    g2p_dict[word].append(word_split[1:])
+                #if word not in g2p_dict:
+                g2p_dict[word] = []
+                g2p_dict[word].append(word_split[1:])
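Commenting out the guard changes the dictionary-merge policy: a word that appears twice in the pronunciation dict used to keep its first entry, and now the last one wins, letting entries appended later in the dict override earlier ones (part of the frontend fix referenced in issues/475). A toy before/after sketch:

```python
# Toy before/after sketch; the entries are illustrative pinyin rows, not the
# real dictionary contents.
entries = [("重", ["chong2"]), ("重", ["zhong4"])]  # later row overrides the earlier one

old_dict, new_dict = {}, {}
for word, prons in entries:
    if word not in old_dict:        # old behavior: first entry wins
        old_dict[word] = []
        old_dict[word].append(prons)
    new_dict[word] = []             # new behavior: guard removed, last entry wins
    new_dict[word].append(prons)

print(old_dict["重"])  # [['chong2']]
print(new_dict["重"])  # [['zhong4']]
```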

@@ -672,6 +672,7 @@ class ToneSandhi:
                 and i + 1 < len(seg)
                 and seg[i - 1][0] == seg[i + 1][0]
                 and seg[i - 1][1] == "v"
+                and seg[i + 1][1] == "v"
             ):
                 new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0]
             else:
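The added condition tightens the "V一V" reduplication merge (e.g. 看一看, "take a look"): previously only the left neighbor of 一 had to be tagged as a verb, so a homograph with a different part of speech on the right could be merged incorrectly; now both neighbors must be verbs. A toy check of the condition, where seg holds (word, POS) pairs in jieba style:

```python
# Toy check of the merge condition above; seg is a list of (word, pos) pairs
# in jieba style and i is the index of "一".
def should_merge(seg, i):
    return (
        i - 1 >= 0
        and i + 1 < len(seg)
        and seg[i - 1][0] == seg[i + 1][0]
        and seg[i - 1][1] == "v"
        and seg[i + 1][1] == "v"   # the newly added check
    )

print(should_merge([("看", "v"), ("一", "m"), ("看", "v")], 1))  # True -> merged as 看一看
print(should_merge([("看", "v"), ("一", "m"), ("看", "n")], 1))  # False now; True before the fix
```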

@@ -64,6 +64,14 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
     )
     return model, optimizer, learning_rate, iteration

+from time import time as ttime
+import shutil
+def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
+    dir=os.path.dirname(path)
+    name=os.path.basename(path)
+    tmp_path="%s.pth"%(ttime())
+    torch.save(fea,tmp_path)
+    shutil.move(tmp_path,"%s/%s"%(dir,name))

 def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
     logger.info(

@@ -75,7 +83,8 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
         state_dict = model.module.state_dict()
     else:
         state_dict = model.state_dict()
-    torch.save(
+    # torch.save(
+    my_save(
         {
             "model": state_dict,
             "iteration": iteration,

@@ -82,7 +82,7 @@
     "source": [
         "# @title launch WebUI 启动WebUI\n",
         "!/usr/local/bin/pip install ipykernel\n",
-        "!sed -i '9s/False/True/' /content/GPT-SoVITS/config.py\n",
+        "!sed -i '10s/False/True/' /content/GPT-SoVITS/config.py\n",
         "%cd /content/GPT-SoVITS/\n",
         "!/usr/local/bin/python webui.py"
     ],

@@ -113,12 +113,21 @@
 2-DPO Loss experimental training option enabled: mitigates GPT repetition and dropped characters by training on constructed negative samples. The inference UI exposes several new inference parameters. https://github.com/RVC-Boss/GPT-SoVITS/pull/457

+### 20240214 update
+
+1-Training supports Chinese experiment names (previously this raised an error)
+
+2-DPO training is now an optional checkbox rather than mandatory; if checked, the batch size is automatically halved. Fixed the inference UI not passing its new parameters.
+
+### 20240216 update
+
+1-Support input with no reference text
+
+2-Fixed Chinese text-frontend bugs https://github.com/RVC-Boss/GPT-SoVITS/issues/475
+
 todolist

 1-Optimize inference for Chinese polyphonic characters
-2-Training supports Chinese experiment names (previously this raised an error)

@@ -266,7 +266,7 @@ def close1Ba():
     return "已终止SoVITS训练",{"__type__":"update","visible":True},{"__type__":"update","visible":False}

 p_train_GPT=None
-def open1Bb(batch_size,total_epoch,exp_name,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers,pretrained_s1):
+def open1Bb(batch_size,total_epoch,exp_name,if_dpo,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers,pretrained_s1):
     global p_train_GPT
     if(p_train_GPT==None):
         with open("GPT_SoVITS/configs/s1longer.yaml")as f:

@@ -283,6 +283,7 @@ def open1Bb(batch_size,total_epoch,exp_name,if_dpo,if_save_latest,if_save_every_weights
         data["train"]["save_every_n_epoch"]=save_every_epoch
         data["train"]["if_save_every_weights"]=if_save_every_weights
         data["train"]["if_save_latest"]=if_save_latest
+        data["train"]["if_dpo"]=if_dpo
         data["train"]["half_weights_save_dir"]=GPT_weight_root
         data["train"]["exp_name"]=exp_name
         data["train_semantic_path"]="%s/6-name2semantic.tsv"%s1_dir

@@ -807,6 +808,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                     with gr.Row():
                         batch_size1Bb = gr.Slider(minimum=1,maximum=40,step=1,label=i18n("每张显卡的batch_size"),value=default_batch_size,interactive=True)
                         total_epoch1Bb = gr.Slider(minimum=2,maximum=50,step=1,label=i18n("总训练轮数total_epoch"),value=15,interactive=True)
+                        if_dpo = gr.Checkbox(label=i18n("是否开启dpo训练选项(实验性)"), value=False, interactive=True, show_label=True)
                         if_save_latest1Bb = gr.Checkbox(label=i18n("是否仅保存最新的ckpt文件以节省硬盘空间"), value=True, interactive=True, show_label=True)
                         if_save_every_weights1Bb = gr.Checkbox(label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"), value=True, interactive=True, show_label=True)
                         save_every_epoch1Bb = gr.Slider(minimum=1,maximum=50,step=1,label=i18n("保存频率save_every_epoch"),value=5,interactive=True)

@@ -817,7 +819,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
             info1Bb=gr.Textbox(label=i18n("GPT训练进程输出信息"))
             button1Ba_open.click(open1Ba, [batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D], [info1Ba,button1Ba_open,button1Ba_close])
             button1Ba_close.click(close1Ba, [], [info1Ba,button1Ba_open,button1Ba_close])
-            button1Bb_open.click(open1Bb, [batch_size1Bb,total_epoch1Bb,exp_name,if_save_latest1Bb,if_save_every_weights1Bb,save_every_epoch1Bb,gpu_numbers1Bb,pretrained_s1], [info1Bb,button1Bb_open,button1Bb_close])
+            button1Bb_open.click(open1Bb, [batch_size1Bb,total_epoch1Bb,exp_name,if_dpo,if_save_latest1Bb,if_save_every_weights1Bb,save_every_epoch1Bb,gpu_numbers1Bb,pretrained_s1], [info1Bb,button1Bb_open,button1Bb_close])
             button1Bb_close.click(close1Bb, [], [info1Bb,button1Bb_open,button1Bb_close])
         with gr.TabItem(i18n("1C-推理")):
             gr.Markdown(value=i18n("选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。"))