diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 606c1a8..ee2ec1e 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -304,7 +304,7 @@ class TTS:
def init_vits_weights(self, weights_path: str):
print(f"Loading VITS weights from {weights_path}")
self.configs.vits_weights_path = weights_path
- dict_s2 = torch.load(weights_path, map_location=self.configs.device)
+ dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
self.configs.update_version("v1")
@@ -1031,4 +1031,4 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
# 将管道输出解码为 NumPy 数组
processed_audio = np.frombuffer(out, np.int16)
- return processed_audio
\ No newline at end of file
+ return processed_audio
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index f903b73..e5a8a60 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -28,10 +28,13 @@ try:
analytics.version_check = lambda:None
except:...
version=model_version=os.environ.get("version","v2")
-pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth","GPT_SoVITS/pretrained_models/s2Gv3.pth"]
+path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
+is_exist_s2gv3=os.path.exists(path_sovits_v3)
+pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",path_sovits_v3]
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"]
+
_ =[[],[]]
for i in range(3):
if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i])
@@ -73,6 +76,7 @@ is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+# is_half=False
punctuation = set(['!', '?', '…', ',', '.', '-'," "])
import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer
@@ -83,13 +87,26 @@ from feature_extractor import cnhubert
cnhubert.cnhubert_base_path = cnhubert_base_path
from GPT_SoVITS.module.models import SynthesizerTrn,SynthesizerTrnV3
+import numpy as np
+import random
+def set_seed(seed):
+ if seed == -1:
+ seed = random.randint(0, 1000000)
+ seed = int(seed)
+ random.seed(seed)
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+# set_seed(42)
+
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
from time import time as ttime
-from module.mel_processing import spectrogram_torch
from tools.my_utils import load_audio
from tools.i18n.i18n import I18nAuto, scan_language_list
+from peft import LoraConfig, PeftModel, get_peft_model
language=os.environ.get("language","Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
@@ -192,42 +209,21 @@ def resample(audio_tensor, sr0):
).to(device)
return resample_transform_dict[sr0](audio_tensor)
+###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt
+#symbol_version-model_version-if_lora_v3
+from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
- global vq_model, hps, version, model_version, dict_language
- '''
- v1:about 82942KB
- half thr:82978KB
- v2:about 83014KB
- half thr:100MB
- v1base:103490KB
- half thr:103520KB
- v2base:103551KB
- v3:about 750MB
-
- ~82978K~100M~103420~700M
- v1-v2-v1base-v2base-v3
- version:
- symbols version and timebre_embedding version
- model_version:
- sovits is v1/2 (VITS) or v3 (shortcut CFM DiT)
- '''
- size=os.path.getsize(sovits_path)
- if size<82978*1024:
- model_version=version="v1"
- elif size<100*1024*1024:
- model_version=version="v2"
- elif size<103520*1024:
- model_version=version="v1"
- elif size<700*1024*1024:
- model_version = version = "v2"
- else:
- version = "v2"
- model_version="v3"
-
+ global vq_model, hps, version, model_version, dict_language,if_lora_v3
+ version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
+ # print(sovits_path,version, model_version, if_lora_v3)
+ if if_lora_v3==True and is_exist_s2gv3==False:
+ info= "GPT_SoVITS/pretrained_models/s2Gv3.pth" + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
+ gr.Warning(info)
+ raise FileExistsError(info)
dict_language = dict_language_v1 if version =='v1' else dict_language_v2
if prompt_language is not None and text_language is not None:
if prompt_language in list(dict_language.keys()):
- prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language}
+ prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language}
else:
prompt_text_update = {'__type__':'update', 'value':''}
prompt_language_update = {'__type__':'update', 'value':i18n("中文")}
@@ -242,13 +238,15 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
else:
visible_sample_steps=False
visible_inp_refs=True
- yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False}
+ yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False},{"__type__": "update", "visible":True if model_version=="v3"else False}
- dict_s2 = torch.load(sovits_path, map_location="cpu", weights_only=False)
+ dict_s2 = load_sovits_new(sovits_path)
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
- if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
+ if 'enc_p.text_embedding.weight'not in dict_s2['weight']:
+ hps.model.version = "v2"#v3model,v2sybomls
+ elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
hps.model.version = "v1"
else:
hps.model.version = "v2"
@@ -278,7 +276,24 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
else:
vq_model = vq_model.to(device)
vq_model.eval()
- print("loading sovits_%s"%model_version,vq_model.load_state_dict(dict_s2["weight"], strict=False))
+ if if_lora_v3==False:
+ print("loading sovits_%s"%model_version,vq_model.load_state_dict(dict_s2["weight"], strict=False))
+ else:
+ print("loading sovits_v3pretrained_G", vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False))
+ lora_rank=dict_s2["lora_rank"]
+ lora_config = LoraConfig(
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ r=lora_rank,
+ lora_alpha=lora_rank,
+ init_lora_weights=True,
+ )
+ vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
+ print("loading sovits_v3_lora%s"%(lora_rank))
+ vq_model.load_state_dict(dict_s2["weight"], strict=False)
+ vq_model.cfm = vq_model.cfm.merge_and_unload()
+ # torch.save(vq_model.state_dict(),"merge_win.pth")
+ vq_model.eval()
+
with open("./weight.json")as f:
data=f.read()
data=json.loads(data)
@@ -317,23 +332,24 @@ now_dir = os.getcwd()
import soundfile
def init_bigvgan():
- global model
+ global bigvgan_model
from BigVGAN import bigvgan
- model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions
+ bigvgan_model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions
# remove weight norm in the model and set to eval mode
- model.remove_weight_norm()
- model = model.eval()
+ bigvgan_model.remove_weight_norm()
+ bigvgan_model = bigvgan_model.eval()
if is_half == True:
- model = model.half().to(device)
+ bigvgan_model = bigvgan_model.half().to(device)
else:
- model = model.to(device)
+ bigvgan_model = bigvgan_model.to(device)
-if model_version!="v3":model=None
+if model_version!="v3":bigvgan_model=None
else:init_bigvgan()
def get_spepc(hps, filename):
- audio = load_audio(filename, int(hps.data.sampling_rate))
+ # audio = load_audio(filename, int(hps.data.sampling_rate))
+ audio, sampling_rate = librosa.load(filename, sr=int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio)
maxx=audio.abs().max()
if(maxx>1):audio/=min(2,maxx)
@@ -350,6 +366,7 @@ def get_spepc(hps, filename):
return spec
def clean_text_inf(text, language, version):
+ language = language.replace("all_","")
phones, word2ph, norm_text = clean_text(text, language, version)
phones = cleaned_text_to_sequence(phones, version)
return phones, word2ph, norm_text
@@ -379,11 +396,10 @@ def get_first(text):
from text import chinese
def get_phones_and_bert(text,language,version,final=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
- language = language.replace("all_","")
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
- if language == "zh":
+ if language == "all_zh":
if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
@@ -391,7 +407,7 @@ def get_phones_and_bert(text,language,version,final=False):
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = get_bert_feature(norm_text, word2ph).to(device)
- elif language == "yue" and re.search(r'[A-Za-z]', formattext):
+ elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return get_phones_and_bert(formattext,"yue",version)
@@ -443,12 +459,14 @@ def get_phones_and_bert(text,language,version,final=False):
return phones,bert.to(dtype),norm_text
-from module.mel_processing import spectrogram_torch,spec_to_mel_torch
-def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
- spec=spectrogram_torch(y,n_fft,sampling_rate,hop_size,win_size,center)
- mel=spec_to_mel_torch(spec,n_fft,num_mels,sampling_rate,fmin,fmax)
- return mel
-mel_fn_args = {
+from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
+spec_min = -12
+spec_max = 2
+def norm_spec(x):
+ return (x - spec_min) / (spec_max - spec_min) * 2 - 1
+def denorm_spec(x):
+ return (x + 1) / 2 * (spec_max - spec_min) + spec_min
+mel_fn=lambda x: mel_spectrogram_torch(x, **{
"n_fft": 1024,
"win_size": 1024,
"hop_size": 256,
@@ -457,16 +475,7 @@ mel_fn_args = {
"fmin": 0,
"fmax": None,
"center": False
-}
-
-spec_min = -12
-spec_max = 2
-def norm_spec(x):
- return (x - spec_min) / (spec_max - spec_min) * 2 - 1
-def denorm_spec(x):
- return (x + 1) / 2 * (spec_max - spec_min) + spec_min
-mel_fn=lambda x: mel_spectrogram(x, **mel_fn_args)
-
+})
def merge_short_text_in_array(texts, threshold):
if (len(texts)) < 2:
@@ -485,10 +494,23 @@ def merge_short_text_in_array(texts, threshold):
result[len(result) - 1] += text
return result
+sr_model=None
+def audio_sr(audio,sr):
+ global sr_model
+ if sr_model==None:
+ from tools.audio_sr import AP_BWE
+ try:
+ sr_model=AP_BWE(device,DictToAttrRecursive)
+ except FileNotFoundError:
+ gr.Warning(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
+ return audio.cpu().detach().numpy(),sr
+ return sr_model(audio,sr)
+
+
##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
# cache_tokens={}#暂未实现清理机制
cache= {}
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False,speed=1,if_freeze=False,inp_refs=None,sample_steps=8):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False,speed=1,if_freeze=False,inp_refs=None,sample_steps=8,if_sr=False,pause_second=0.3):
global cache
if ref_wav_path:pass
else:gr.Warning(i18n('请上传参考音频'))
@@ -497,7 +519,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
t = []
if prompt_text is None or len(prompt_text) == 0:
ref_free = True
- if model_version=="v3":ref_free=False#s2v3暂不支持ref_free
+ if model_version=="v3":
+ ref_free=False#s2v3暂不支持ref_free
+ else:
+ if_sr=False
t0 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
@@ -509,12 +534,17 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
print(i18n("实际输入的参考文本:"), prompt_text)
text = text.strip("\n")
# if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
-
+
print(i18n("实际输入的目标文本:"), text)
zero_wav = np.zeros(
- int(hps.data.sampling_rate * 0.3),
+ int(hps.data.sampling_rate * pause_second),
dtype=np.float16 if is_half == True else np.float32,
)
+ zero_wav_torch = torch.from_numpy(zero_wav)
+ if is_half == True:
+ zero_wav_torch = zero_wav_torch.half().to(device)
+ else:
+ zero_wav_torch = zero_wav_torch.to(device)
if not ref_free:
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
@@ -522,13 +552,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
wav16k = torch.from_numpy(wav16k)
- zero_wav_torch = torch.from_numpy(zero_wav)
if is_half == True:
wav16k = wav16k.half().to(device)
- zero_wav_torch = zero_wav_torch.half().to(device)
else:
wav16k = wav16k.to(device)
- zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
@@ -612,11 +639,12 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
except:
traceback.print_exc()
if(len(refers)==0):refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
- audio = (vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers,speed=speed).detach().cpu().numpy()[0, 0])
+ audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers,speed=speed)[0][0]#.cpu().detach().numpy()
else:
- refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)#######这里要重采样切到32k,因为src是24k的,没有单独的32k的src,所以不能改成2个路径
+ refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0)
phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0)
+ # print(11111111, phoneme_ids0, phoneme_ids1)
fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
ref_audio, sr = torchaudio.load(ref_wav_path)
ref_audio=ref_audio.to(device).float()
@@ -624,7 +652,8 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
ref_audio = ref_audio.mean(0).unsqueeze(0)
if sr!=24000:
ref_audio=resample(ref_audio,sr)
- mel2 = mel_fn(ref_audio.to(dtype))
+ # print("ref_audio",ref_audio.abs().mean())
+ mel2 = mel_fn(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
@@ -634,7 +663,12 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
fea_ref = fea_ref[:, :, -468:]
T_min = 468
chunk_len = 934 - T_min
- fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge)
+ # print("fea_ref",fea_ref,fea_ref.shape)
+ # print("mel2",mel2)
+ mel2=mel2.to(dtype)
+ fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge,speed)
+ # print("fea_todo",fea_todo)
+ # print("ge",ge.abs().mean())
cfm_resss = []
idx = 0
while (1):
@@ -642,29 +676,38 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
if (fea_todo_chunk.shape[-1] == 0): break
idx += chunk_len
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
+ # set_seed(123)
cfm_res = vq_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
cfm_res = cfm_res[:, :, mel2.shape[2]:]
mel2 = cfm_res[:, :, -T_min:]
+ # print("fea", fea)
+ # print("mel2in", mel2)
fea_ref = fea_todo_chunk[:, :, -T_min:]
cfm_resss.append(cfm_res)
cmf_res = torch.cat(cfm_resss, 2)
cmf_res = denorm_spec(cmf_res)
- if model==None:init_bigvgan()
+ if bigvgan_model==None:init_bigvgan()
with torch.inference_mode():
- wav_gen = model(cmf_res)
- audio=wav_gen[0][0].cpu().detach().numpy()
- max_audio=np.abs(audio).max()#简单防止16bit爆音
- if max_audio>1:audio/=max_audio
+ wav_gen = bigvgan_model(cmf_res)
+ audio=wav_gen[0][0]#.cpu().detach().numpy()
+ max_audio=torch.abs(audio).max()#简单防止16bit爆音
+ if max_audio>1:audio/=max_audio
audio_opt.append(audio)
- audio_opt.append(zero_wav)
+ audio_opt.append(zero_wav_torch)#zero_wav
t4 = ttime()
t.extend([t2 - t1,t3 - t2, t4 - t3])
t1 = ttime()
- print("%.3f\t%.3f\t%.3f\t%.3f" %
- (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3]))
- )
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
+ audio_opt=torch.cat(audio_opt, 0)#np.concatenate
sr=hps.data.sampling_rate if model_version!="v3"else 24000
- yield sr, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
+ if if_sr==True and sr==24000:
+ print(i18n("音频超分中"))
+ audio_opt,sr=audio_sr(audio_opt.unsqueeze(0),sr)
+ max_audio=np.abs(audio_opt).max()
+ if max_audio > 1: audio_opt /= max_audio
+ else:
+ audio_opt=audio_opt.cpu().detach().numpy()
+ yield sr, (audio_opt * 32767).astype(np.int16)
def split(todo_text):
@@ -821,7 +864,7 @@ def html_left(text, label='p'):
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(
- value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.<br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
+ value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + "<br>" + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
)
with gr.Group():
gr.Markdown(html_center(i18n("模型切换"),'h3'))
@@ -834,15 +877,16 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
with gr.Row():
inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频,超过会报错!"), type="filepath", scale=13)
with gr.Column(scale=13):
- ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。v3暂不支持该模式,使用了会报错。"), value=False, interactive=True, show_label=True,scale=1)
- gr.Markdown(html_left(i18n("使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。<br>开启后无视填写的参考文本。")))
+ ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。")+i18n("v3暂不支持该模式,使用了会报错。"), value=False, interactive=True, show_label=True,scale=1)
+ gr.Markdown(html_left(i18n("使用无参考文本模式时建议使用微调的GPT")+"<br>"+i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")))
prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="", lines=5, max_lines=5,scale=1)
with gr.Column(scale=14):
prompt_language = gr.Dropdown(
label=i18n("参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文"),
)
inp_refs = gr.File(label=i18n("可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"),file_count="multiple")if model_version!="v3"else gr.File(label=i18n("可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"),file_count="multiple",visible=False)
- sample_steps = gr.Radio(label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),value=32,choices=[4,8,16,32],visible=True)if model_version=="v3"else gr.Radio(label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),value=8,choices=[4,8,16,32],visible=False)
+ sample_steps = gr.Radio(label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),value=32,choices=[4,8,16,32],visible=True)if model_version=="v3"else gr.Radio(label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),choices=[4,8,16,32],visible=False,value=32)
+ if_sr_Checkbox=gr.Checkbox(label=i18n("v3输出如果觉得闷可以试试开超分"), value=False, interactive=True, show_label=True,visible=False if model_version!="v3"else True)
gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"),'h3'))
with gr.Row():
with gr.Column(scale=13):
@@ -859,11 +903,13 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
)
gr.Markdown(value=html_center(i18n("语速调整,高为更快")))
if_freeze=gr.Checkbox(label=i18n("是否直接对上次合成结果调整语速和音色。防止随机性。"), value=False, interactive=True,show_label=True, scale=1)
- speed = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label=i18n("语速"),value=1,interactive=True, scale=1)
+ with gr.Row():
+ speed = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label=i18n("语速"),value=1,interactive=True, scale=1)
+ pause_second_slider = gr.Slider(minimum=0.1,maximum=0.5,step=0.01,label=i18n("句间停顿秒数"),value=0.3,interactive=True, scale=1)
gr.Markdown(html_center(i18n("GPT采样参数(无参考文本时不要太低。不懂就用默认):")))
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=15,interactive=True, scale=1)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True, scale=1)
- temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True, scale=1)
+ temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True, scale=1)
# with gr.Column():
# gr.Markdown(value=i18n("手工调整音素。当音素框不为空时使用手工音素输入推理,无视目标文本框。"))
# phoneme=gr.Textbox(label=i18n("音素框"), value="")
@@ -874,10 +920,10 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
inference_button.click(
get_tts_wav,
- [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free,speed,if_freeze,inp_refs,sample_steps],
+ [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free,speed,if_freeze,inp_refs,sample_steps,if_sr_Checkbox,pause_second_slider],
[output],
)
- SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language,sample_steps,inp_refs,ref_text_free])
+ SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language,sample_steps,inp_refs,ref_text_free,if_sr_Checkbox])
GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
# gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
diff --git a/GPT_SoVITS/inference_webui_fast.py b/GPT_SoVITS/inference_webui_fast.py
index dcc2bcf..5a6910d 100644
--- a/GPT_SoVITS/inference_webui_fast.py
+++ b/GPT_SoVITS/inference_webui_fast.py
@@ -42,7 +42,7 @@ sovits_path = os.environ.get("sovits_path", None)
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
bert_path = os.environ.get("bert_path", None)
version=os.environ.get("version","v2")
-
+
import gradio as gr
from TTS_infer_pack.TTS import TTS, TTS_Config
from TTS_infer_pack.text_segmentation_method import get_method
@@ -61,7 +61,7 @@ if torch.cuda.is_available():
# device = "mps"
else:
device = "cpu"
-
+
dict_language_v1 = {
i18n("中文"): "all_zh",#全部按中文识别
i18n("英文"): "en",#全部按英文识别#######不变
@@ -106,20 +106,20 @@ if cnhubert_base_path is not None:
tts_config.cnhuhbert_base_path = cnhubert_base_path
if bert_path is not None:
tts_config.bert_base_path = bert_path
-
+
print(tts_config)
tts_pipeline = TTS(tts_config)
gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path
version = tts_config.version
-def inference(text, text_lang,
- ref_audio_path,
+def inference(text, text_lang,
+ ref_audio_path,
aux_ref_audio_paths,
- prompt_text,
- prompt_lang, top_k,
- top_p, temperature,
- text_split_method, batch_size,
+ prompt_text,
+ prompt_lang, top_k,
+ top_p, temperature,
+ text_split_method, batch_size,
speed_factor, ref_text_free,
split_bucket,fragment_interval,
seed, keep_random, parallel_infer,
@@ -150,7 +150,7 @@ def inference(text, text_lang,
}
for item in tts_pipeline.run(inputs):
yield item, actual_seed
-
+
def custom_sort_key(s):
# 使用正则表达式提取字符串中的数字部分和非数字部分
parts = re.split('(\d+)', s)
@@ -201,7 +201,7 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2
if prompt_language is not None and text_language is not None:
if prompt_language in list(dict_language.keys()):
- prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language}
+ prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language}
else:
prompt_text_update = {'__type__':'update', 'value':''}
prompt_language_update = {'__type__':'update', 'value':i18n("中文")}
@@ -216,9 +216,9 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(
- value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.<br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
+ value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + "<br>" + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
)
-
+
with gr.Column():
# with gr.Group():
gr.Markdown(value=i18n("模型切换"))
@@ -228,7 +228,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
-
+
with gr.Row():
with gr.Column():
gr.Markdown(value=i18n("*请上传并填写参考信息"))
@@ -242,8 +242,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
)
with gr.Column():
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
- gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开,开启后无视填写的参考文本。"))
-
+ gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT")+"<br>"+i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。"))
+
with gr.Column():
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=20, max_lines=20)
@@ -251,7 +251,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文")
)
-
+
with gr.Group():
gr.Markdown(value=i18n("推理设置"))
with gr.Row():
@@ -274,8 +274,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
)
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)
-
- with gr.Row():
+
+ with gr.Row():
seed = gr.Number(label=i18n("随机种子"),value=-1)
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
@@ -283,15 +283,15 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
with gr.Row():
inference_button = gr.Button(i18n("合成语音"), variant="primary")
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
-
-
+
+
inference_button.click(
inference,
[
text,text_language, inp_ref, inp_refs,
- prompt_text, prompt_language,
- top_k, top_p, temperature,
- how_to_cut, batch_size,
+ prompt_text, prompt_language,
+ top_k, top_p, temperature,
+ how_to_cut, batch_size,
speed_factor, ref_text_free,
split_bucket,fragment_interval,
seed, keep_random, parallel_infer,
@@ -315,13 +315,13 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
interactive=True,
)
cut_text= gr.Button(i18n("切分"), variant="primary")
-
+
def to_cut(text_inp, how_to_cut):
if len(text_inp.strip()) == 0 or text_inp==[]:
return ""
method = get_method(cut_method[how_to_cut])
return method(text_inp)
-
+
text_opt = gr.Textbox(label=i18n("切分后文本"), value="", lines=4)
cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
diff --git a/GPT_SoVITS/module/data_utils.py b/GPT_SoVITS/module/data_utils.py
index 323bf1b..6ceca20 100644
--- a/GPT_SoVITS/module/data_utils.py
+++ b/GPT_SoVITS/module/data_utils.py
@@ -456,6 +456,231 @@ class TextAudioSpeakerCollateV3():
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths
class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
    """
    1) loads audio, speaker_id, text pairs
    2) normalizes text and converts them to sequences of integers
    3) computes spectrograms from audio files.

    V3b variant: each item additionally carries the raw waveform and a
    normalized 24 kHz mel spectrogram (CFM decoder target) next to the
    32 kHz linear spectrogram.
    """

    def __init__(self, hparams, val=False):
        # Expected experiment layout under exp_dir:
        #   2-name2text.txt : tab-separated "name\tphonemes\t..." (4 fields)
        #   4-cnhubert/     : per-utterance SSL features saved as <name>.pt
        #   5-wav32k/       : 32 kHz wav files named <name>
        exp_dir = hparams.exp_dir
        self.path2 = "%s/2-name2text.txt" % exp_dir
        self.path4 = "%s/4-cnhubert" % exp_dir
        self.path5 = "%s/5-wav32k" % exp_dir
        assert os.path.exists(self.path2)
        assert os.path.exists(self.path4)
        assert os.path.exists(self.path5)
        names4 = set([name[:-3] for name in list(os.listdir(self.path4))])  # strip the ".pt" suffix
        names5 = set(os.listdir(self.path5))
        # name -> [phoneme string]; lines without exactly 4 fields are skipped
        self.phoneme_data = {}
        with open(self.path2, "r", encoding="utf8") as f:
            lines = f.read().strip("\n").split("\n")

        for line in lines:
            tmp = line.split("\t")
            if (len(tmp) != 4):
                continue
            self.phoneme_data[tmp[0]] = [tmp[1]]

        # keep only utterances that have text, SSL features AND audio
        self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
        tmp = self.audiopaths_sid_text
        leng = len(tmp)
        min_num = 100
        # tiny datasets are repeated so one epoch holds at least ~min_num items
        if (leng < min_num):
            self.audiopaths_sid_text = []
            for _ in range(max(2, int(min_num / leng))):
                self.audiopaths_sid_text += tmp
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length
        self.sampling_rate = hparams.sampling_rate
        self.val = val

        # fixed seed so every DDP rank shuffles identically
        random.seed(1234)
        random.shuffle(self.audiopaths_sid_text)

        print("phoneme_data_len:", len(self.phoneme_data.keys()))
        print("wav_data_len:", len(self.audiopaths_sid_text))

        audiopaths_sid_text_new = []
        lengths = []
        skipped_phone = 0
        skipped_dur = 0
        for audiopath in tqdm(self.audiopaths_sid_text):
            try:
                phoneme = self.phoneme_data[audiopath][0]
                phoneme = phoneme.split(' ')
                phoneme_ids = cleaned_text_to_sequence(phoneme, version)
            except Exception:
                print(f"{audiopath} not in self.phoneme_data !")
                skipped_phone += 1
                continue

            # duration estimated from file size assuming 16-bit samples
            # (bytes / rate / 2); header bytes are ignored — TODO confirm
            size = os.path.getsize("%s/%s" % (self.path5, audiopath))
            duration = size / self.sampling_rate / 2

            if duration == 0:
                print(f"Zero duration for {audiopath}, skipping...")
                skipped_dur += 1
                continue

            # training keeps 0.6 s .. 54 s clips; validation keeps everything
            if 54 > duration > 0.6 or self.val:
                audiopaths_sid_text_new.append([audiopath, phoneme_ids])
                lengths.append(size // (2 * self.hop_length))  # length in spec frames
            else:
                skipped_dur += 1
                continue

        print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
        print("total left: ", len(audiopaths_sid_text_new))
        assert len(audiopaths_sid_text_new) > 1  # at least enough to fill a batch; todo
        self.audiopaths_sid_text = audiopaths_sid_text_new
        self.lengths = lengths
        # mel dynamic range used by norm_spec below
        self.spec_min=-12
        self.spec_max=2

        # 24 kHz mel settings for the CFM decoder target
        self.filter_length_mel=self.win_length_mel=1024
        self.hop_length_mel=256
        self.n_mel_channels=100
        self.sampling_rate_mel=24000
        self.mel_fmin=0
        self.mel_fmax=None
    def norm_spec(self, x):
        # map [spec_min, spec_max] linearly onto [-1, 1]
        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

    def get_audio_text_speaker_pair(self, audiopath_sid_text):
        audiopath, phoneme_ids = audiopath_sid_text
        # NOTE(review): phoneme ids are stored as a FloatTensor here but copied
        # into a LongTensor by the collate fn (values truncated to int there) —
        # confirm this is intentional
        text = torch.FloatTensor(phoneme_ids)
        try:
            spec, mel,wav = self.get_audio("%s/%s" % (self.path5, audiopath))
            with torch.no_grad():
                ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
                if (ssl.shape[-1] != spec.shape[-1]):
                    typee = ssl.dtype
                    # pad SSL by one frame so it lines up with the spectrogram
                    ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
                ssl.requires_grad = False
        except:
            # corrupt item: substitute silent dummy tensors so training continues
            traceback.print_exc()
            mel = torch.zeros(100, 180)
            wav = torch.zeros(1, 96 * self.hop_length)
            spec = torch.zeros(1025, 96)
            ssl = torch.zeros(1, 768, 96)
            text = text[-1:]
            print("load audio or ssl error!!!!!!", audiopath)
        return (ssl, spec, wav, mel, text)

    def get_audio(self, filename):
        # load_audio already normalizes to -1..1, so no extra /32768 is needed
        audio_array = load_audio(filename,self.sampling_rate)
        audio=torch.FloatTensor(audio_array)
        audio_norm = audio
        audio_norm = audio_norm.unsqueeze(0)
        # second decode at 24 kHz for the mel target (could be GPU-resampled for speed)
        audio_array24 = load_audio(filename,24000)
        audio24=torch.FloatTensor(audio_array24)
        audio_norm24 = audio24
        audio_norm24 = audio_norm24.unsqueeze(0)

        spec = spectrogram_torch(audio_norm, self.filter_length,
                                 self.sampling_rate, self.hop_length, self.win_length,
                                 center=False)
        spec = torch.squeeze(spec, 0)


        spec1 = spectrogram_torch(audio_norm24, self.filter_length_mel,self.sampling_rate_mel, self.hop_length_mel, self.win_length_mel,center=False)
        mel = spec_to_mel_torch(spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax)
        mel = torch.squeeze(mel, 0)
        mel=self.norm_spec(mel)
        # print(1111111,spec.shape,mel.shape)
        return spec, mel,audio_norm

    def get_sid(self, sid):
        # speaker id as a 1-element LongTensor (kept for API parity; unused above)
        sid = torch.LongTensor([int(sid)])
        return sid

    def __getitem__(self, index):
        # with torch.no_grad():
        return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])

    def __len__(self):
        return len(self.audiopaths_sid_text)
class TextAudioSpeakerCollateV3b():
    """ Zero-pads model inputs and targets
    """

    def __init__(self, return_ids=False):
        # return_ids kept for API parity with the other collate classes; unused here
        self.return_ids = return_ids

    def __call__(self, batch):
        """Collate's training batch from normalized text, audio and speaker identities
        PARAMS
        ------
        batch: [text_normalized, spec_normalized, wav_normalized, sid]
        """
        # Each item is a (ssl, spec, wav, mel, text) tuple produced by
        # TextAudioSpeakerLoaderV3b.get_audio_text_speaker_pair.
        # Right zero-pad all one-hot text sequences to max input length
        # indices sorted by spec length, longest first; rows are copied in that order
        _, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([x[1].size(1) for x in batch]),
            dim=0, descending=True)
        # (ssl, spec, mel, text)
        max_ssl_len = max([x[0].size(2) for x in batch])

        # round SSL length up: to a multiple of 8 (used to size the mel buffer)
        # and separately to a multiple of 2 (actual SSL padding)
        max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
        max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))

        # max_ssl_len = int(8 * ((max_ssl_len // 8) + 1))
        # max_ssl_len1=max_ssl_len

        max_spec_len = max([x[1].size(1) for x in batch])
        max_spec_len = int(2 * ((max_spec_len // 2) + 1))
        max_wav_len = max([x[2].size(1) for x in batch])
        max_text_len = max([x[4].size(0) for x in batch])
        max_mel_len=int(max_ssl_len1*1.25*1.5)###24000/256,32000/640=16000/320

        # per-item unpadded lengths (allocated uninitialized, filled in the loop)
        ssl_lengths = torch.LongTensor(len(batch))
        spec_lengths = torch.LongTensor(len(batch))
        text_lengths = torch.LongTensor(len(batch))
        wav_lengths = torch.LongTensor(len(batch))
        mel_lengths = torch.LongTensor(len(batch))

        spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
        mel_padded = torch.FloatTensor(len(batch), batch[0][3].size(0), max_mel_len)
        ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
        text_padded = torch.LongTensor(len(batch), max_text_len)
        wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)

        # buffers above are uninitialized memory — zero them before copying rows
        spec_padded.zero_()
        mel_padded.zero_()
        ssl_padded.zero_()
        text_padded.zero_()
        wav_padded.zero_()

        for i in range(len(ids_sorted_decreasing)):
            row = batch[ids_sorted_decreasing[i]]
            # ssl, spec, wav,mel, text
            ssl = row[0]
            ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
            ssl_lengths[i] = ssl.size(2)

            spec = row[1]
            spec_padded[i, :, :spec.size(1)] = spec
            spec_lengths[i] = spec.size(1)

            wav = row[2]
            wav_padded[i, :, :wav.size(1)] = wav
            wav_lengths[i] = wav.size(1)

            mel = row[3]
            mel_padded[i, :, :mel.size(1)] = mel
            mel_lengths[i] = mel.size(1)

            text = row[4]
            text_padded[i, :text.size(0)] = text
            text_lengths[i] = text.size(0)

        # V3b also returns the padded waveforms (the V3 collate omits them)
        return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
        # return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
"""
diff --git a/GPT_SoVITS/module/mel_processing.py b/GPT_SoVITS/module/mel_processing.py
index 503825e..d94b045 100644
--- a/GPT_SoVITS/module/mel_processing.py
+++ b/GPT_SoVITS/module/mel_processing.py
@@ -145,7 +145,7 @@ def mel_spectrogram_torch(
return_complex=False,
)
- spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec)
diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py
index d546fcd..623da80 100644
--- a/GPT_SoVITS/module/models.py
+++ b/GPT_SoVITS/module/models.py
@@ -1099,17 +1099,15 @@ class CFM(torch.nn.Module):
return x
def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt):
b, _, t = x1.shape
-
- # random timestep
t = torch.rand([b], device=mu.device, dtype=x1.dtype)
x0 = torch.randn_like(x1,device=mu.device)
vt = x1 - x0
xt = x0 + t[:, None, None] * vt
dt = torch.zeros_like(t,device=mu.device)
prompt = torch.zeros_like(x1)
- for bib in range(b):
- prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]]
- xt[bib, :, :prompt_lens[bib]] = 0
+ for i in range(b):
+ prompt[i, :, :prompt_lens[i]] = x1[i, :, :prompt_lens[i]]
+ xt[i, :, :prompt_lens[i]] = 0
gailv=0.3# if ttime()>1736250488 else 0.1
if random.random() < gailv:
base = torch.randint(2, 8, (t.shape[0],), device=mu.device)
@@ -1128,14 +1126,15 @@ class CFM(torch.nn.Module):
vt_pred = self.estimator(xt, prompt, x_lens, t,dt, mu, use_grad_ckpt).transpose(2,1)
loss = 0
-
- # print(45555555,estimator_out.shape,u.shape,x_lens,prompt_lens)#45555555 torch.Size([7, 465, 100]) torch.Size([7, 100, 465]) tensor([461, 461, 451, 451, 442, 442, 442], device='cuda:0') tensor([ 96, 93, 185, 59, 244, 262, 294], device='cuda:0')
- for bib in range(b):
- loss += self.criterion(vt_pred[bib, :, prompt_lens[bib]:x_lens[bib]], vt[bib, :, prompt_lens[bib]:x_lens[bib]])
+ for i in range(b):
+ loss += self.criterion(vt_pred[i, :, prompt_lens[i]:x_lens[i]], vt[i, :, prompt_lens[i]:x_lens[i]])
loss /= b
- return loss#, estimator_out + (1 - self.sigma_min) * z
+ return loss
def set_no_grad(net_g):
    """Freeze every parameter of *net_g* by disabling gradient tracking.

    Used to lock the SSL projection / quantizer / text encoder when
    freeze_quantizer is set. Iterates parameters() directly — the parameter
    name from named_parameters() was never used.
    """
    for param in net_g.parameters():
        param.requires_grad = False
class SynthesizerTrnV3(nn.Module):
"""
@@ -1210,7 +1209,6 @@ class SynthesizerTrnV3(nn.Module):
bins=1024
)
self.freeze_quantizer=freeze_quantizer
-
inter_channels2=512
self.bridge=nn.Sequential(
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
@@ -1219,6 +1217,10 @@ class SynthesizerTrnV3(nn.Module):
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
+ if self.freeze_quantizer==True:
+ set_no_grad(self.ssl_proj)
+ set_no_grad(self.quantizer)
+ set_no_grad(self.enc_p)
def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths, use_grad_ckpt):#ssl_lengths no need now
with autocast(enabled=False):
@@ -1229,13 +1231,13 @@ class SynthesizerTrnV3(nn.Module):
if self.freeze_quantizer:
self.ssl_proj.eval()#
self.quantizer.eval()
+ self.enc_p.eval()
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
- with maybe_no_grad:
- quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
- x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
+ quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
+ x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
fea, y_mask_ = self.wns1(fea, mel_lengths, ge)##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
@@ -1248,6 +1250,155 @@ class SynthesizerTrnV3(nn.Module):
cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
return cfm_loss
    @torch.no_grad()
    def decode_encp(self, codes,text, refer,ge=None,speed=1):
        """Turn semantic codes + phoneme text into CFM conditioning features.

        codes: quantizer code indices; text: phoneme ids; refer: reference
        spectrogram used for the timbre embedding when *ge* is not given.
        speed > 1 shortens the predicted mel length accordingly.
        Returns (fea, ge) so the caller can reuse the timbre embedding.
        """
        # print(2333333,refer.shape)
        # ge=None
        if(ge==None):
            refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
            refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
            # timbre embedding from the first 704 spectrogram bins of the reference
            ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask)
        y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device)
        # target mel length: 2.5*1.5 mel frames per code frame — presumably the
        # 24k/256 vs 32k/640 frame-rate ratio; confirm against training setup
        if speed==1:
            sizee=int(codes.size(2)*2.5*1.5)
        else:
            sizee=int(codes.size(2)*2.5*1.5/speed)+1
        y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

        quantized = self.quantizer.decode(codes)
        if self.semantic_frame_rate == '25hz':
            # upsample 25 Hz codes 2x to the frame rate enc_p expects
            quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge,speed)
        fea=self.bridge(x)
        fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
        ####more wn paramter to learn mel
        fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
        return fea,ge
+
+ def extract_latent(self, x):
+ ssl = self.ssl_proj(x)
+ quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
+ return codes.transpose(0,1)
+
+class SynthesizerTrnV3b(nn.Module):
+ """
+ Synthesizer for Training
+ """
+
    def __init__(self,
                 spec_channels,
                 segment_size,
                 inter_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout,
                 resblock,
                 resblock_kernel_sizes,
                 resblock_dilation_sizes,
                 upsample_rates,
                 upsample_initial_channel,
                 upsample_kernel_sizes,
                 n_speakers=0,
                 gin_channels=0,
                 use_sdp=True,
                 semantic_frame_rate=None,
                 freeze_quantizer=None,
                 **kwargs):
        """Build the V3b synthesizer: the V2 VITS stack (enc_q/flow/dec) kept
        alongside the V3 CFM mel-decoder branch (bridge/wns1/linear_mel/cfm)."""

        super().__init__()
        # hyper-parameters stored verbatim for checkpointing / introspection
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels

        self.model_dim=512
        self.use_sdp = use_sdp
        self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
        # self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
        # timbre encoder reads only the first 704 spectrogram bins (see forward)
        self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback
        self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
                             upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
        self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
                                      gin_channels=gin_channels)
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)


        ssl_dim = 768
        assert semantic_frame_rate in ['25hz', "50hz"]
        self.semantic_frame_rate = semantic_frame_rate
        if semantic_frame_rate == '25hz':
            # stride-2 conv halves the 50 Hz SSL features down to 25 Hz
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
        else:
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)

        self.quantizer = ResidualVectorQuantizer(
            dimension=ssl_dim,
            n_q=1,
            bins=1024
        )
        # NOTE(review): unlike SynthesizerTrnV3, freezing here relies only on
        # eval()+no_grad in forward; set_no_grad is not applied — confirm intended
        self.freeze_quantizer=freeze_quantizer

        inter_channels2=512
        # bridge maps the text-encoder output into the CFM conditioning space
        self.bridge=nn.Sequential(
            nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
            nn.LeakyReLU()
        )
        self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
        self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
        self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
+
+
    def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths):#ssl_lengths no need now
        """Joint V2+V3 training pass: returns the VITS losses' ingredients
        (slice, masks, flow stats) plus commit/CFM/mel regression losses."""
        with autocast(enabled=False):
            y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
            # timbre/global embedding from the first 704 spectrogram bins
            ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
            # ge = self.ref_enc(y * y_mask, y_mask)#change back, new spec setting is whole 24k
            # ge=None
        maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
        with maybe_no_grad:
            if self.freeze_quantizer:
                self.ssl_proj.eval()
                self.quantizer.eval()
            ssl = self.ssl_proj(ssl)
            quantized, codes, commit_loss, quantized_list = self.quantizer(
                ssl, layers=[0]
            )
        quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
        z_p = self.flow(z, y_mask, g=ge)
        z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
        o = self.dec(z_slice, g=ge)
        fea=self.bridge(x)
        fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
        fea, y_mask_ = self.wns1(fea, mel_lengths, ge)
        learned_mel = self.linear_mel(fea)
        B=ssl.shape[0]
        # random prompt prefix up to 2/3 of each mel length
        prompt_len_max = mel_lengths*2/3
        prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)#
        minn=min(mel.shape[-1],fea.shape[-1])
        mel=mel[:,:,:minn]
        fea=fea[:,:,:minn]
        # NOTE(review): CFM.forward in this file is (x1, x_lens, prompt_lens, mu,
        # use_grad_ckpt); this call omits use_grad_ckpt and would raise a TypeError
        # if this V3b path is exercised — confirm whether V3b is still used.
        cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea)#fea==cond,y_lengths==target_mel_lengths#ge not need
        return commit_loss,cfm_loss,F.mse_loss(learned_mel, mel),o, ids_slice, y_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), quantized
+
@torch.no_grad()
def decode_encp(self, codes,text, refer,ge=None):
# print(2333333,refer.shape)
diff --git a/GPT_SoVITS/process_ckpt.py b/GPT_SoVITS/process_ckpt.py
index 3a436f1..36ef434 100644
--- a/GPT_SoVITS/process_ckpt.py
+++ b/GPT_SoVITS/process_ckpt.py
@@ -14,7 +14,24 @@ def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
-def savee(ckpt, name, epoch, steps, hps):
'''
Version head written into the first two bytes of a saved checkpoint
(see my_save2 / head2version below):
00: v1
01: v2
02: v3
03: v3 lora
'''
+from io import BytesIO
def my_save2(fea, path):
    """Save *fea* with torch.save, stamping the archive's first two bytes
    with the b'03' version head (temporary marker for v3-lora only, todo).

    torch.load cannot read the result directly; restore the leading b'PK'
    zip magic first (see load_sovits_new).
    """
    buffer = BytesIO()
    torch.save(fea, buffer)
    tagged = b'03' + buffer.getvalue()[2:]
    with open(path, "wb") as f:
        f.write(tagged)
+
+def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
try:
opt = OrderedDict()
opt["weight"] = {}
@@ -24,8 +41,66 @@ def savee(ckpt, name, epoch, steps, hps):
opt["weight"][key] = ckpt[key].half()
opt["config"] = hps
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
- # torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
- my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
+ if lora_rank:
+ opt["lora_rank"]=lora_rank
+ my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
+ else:
+ my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
return "Success."
except:
return traceback.format_exc()
+
# 2-byte file head -> [gpt_version, model_version, is_lora]
head2version={
    b'00':["v1","v1",False],
    b'01':["v2","v2",False],
    b'02':["v2","v3",False],
    b'03':["v2","v3",True],
}
# md5 of the first 8 KiB -> [gpt_version, model_version, is_lora]
# for the known pretrained SoVITS weights (see get_hash_from_file)
# NOTE(review): s2G488k.pth is labelled sovits_v1_pretrained in its comment but
# mapped to ["v2","v2"] here — confirm that mapping is intentional
hash_pretrained_dict={
    "dc3c97e17592963677a4a1681f30c653":["v2","v2",False],#s2G488k.pth#sovits_v1_pretrained
    "43797be674a37c1c83ee81081941ed0f":["v2","v3",False],#s2Gv3.pth#sovits_v3_pretrained
    "6642b37f3dbb1f76882b69937c95a5f3":["v2","v2",False],#s2G2333K.pth#sovits_v2_pretrained
}
+import hashlib
def get_hash_from_file(sovits_path):
    """Return the md5 hex digest of the first 8 KiB of *sovits_path*.

    A cheap fingerprint used to recognize the known pretrained checkpoints
    without reading the whole (possibly multi-hundred-MB) file.
    """
    with open(sovits_path, "rb") as f:
        head = f.read(8192)
    return hashlib.md5(head).hexdigest()
def get_sovits_version_from_path_fast(sovits_path):
    """Cheaply infer (gpt_version, model_version, if_lora_v3) for a SoVITS
    checkpoint without deserializing it.

    Resolution order:
      1. md5 of the first 8 KiB matches a known pretrained checkpoint;
      2. a 2-byte version head (files written by my_save2) instead of the
         zip "PK" magic;
      3. legacy zip checkpoints, told apart by on-disk size.
    """
    ### 1 — known pretrained models, identified by hash
    digest = get_hash_from_file(sovits_path)
    if digest in hash_pretrained_dict:
        return hash_pretrained_dict[digest]
    ### 2 — new-format weights carry an explicit head instead of the zip magic
    with open(sovits_path, "rb") as f:
        head = f.read(2)
    if head != b"PK":
        return head2version[head]
    ### 3 — old-format weights, classified by file size
    # v1 weights: about 82942 KB; half threshold: 82978 KB;
    # v2 weights: about 83014 KB; v3 weights: about 750 MB
    if_lora_v3 = False
    size = os.path.getsize(sovits_path)
    if size < 82978 * 1024:
        model_version = version = "v1"
    elif size < 700 * 1024 * 1024:
        model_version = version = "v2"
    else:
        version = "v2"
        model_version = "v3"
    return version, model_version, if_lora_v3
+
def load_sovits_new(sovits_path):
    """Load a SoVITS checkpoint saved either as a plain torch zip archive or
    as a head-tagged file whose first two bytes were replaced by a version
    marker (see my_save2). For tagged files the zip "PK" magic is restored
    in memory before deserializing.

    Returns whatever torch.save stored (weights_only=False: trusted files only).
    """
    # bugfix: f.read(2) returns bytes, so the head must be compared against
    # b"PK" — the old str comparison ("PK") was always unequal, forcing every
    # file through the in-memory path and leaking the file handle.
    with open(sovits_path, "rb") as f:
        meta = f.read(2)
        if meta != b"PK":
            bio = BytesIO(b"PK" + f.read())
            return torch.load(bio, map_location="cpu", weights_only=False)
    return torch.load(sovits_path, map_location="cpu", weights_only=False)
\ No newline at end of file
diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py
index 4b510d9..4311db9 100644
--- a/GPT_SoVITS/s1_train.py
+++ b/GPT_SoVITS/s1_train.py
@@ -26,12 +26,7 @@ from AR.utils import get_newest_ckpt
from collections import OrderedDict
from time import time as ttime
import shutil
-def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
- dir=os.path.dirname(path)
- name=os.path.basename(path)
- tmp_path="%s.pth"%(ttime())
- torch.save(fea,tmp_path)
- shutil.move(tmp_path,"%s/%s"%(dir,name))
+from process_ckpt import my_save
class my_model_ckpt(ModelCheckpoint):
diff --git a/GPT_SoVITS/s2_train.py b/GPT_SoVITS/s2_train.py
index 5be43c9..4d88ee8 100644
--- a/GPT_SoVITS/s2_train.py
+++ b/GPT_SoVITS/s2_train.py
@@ -205,6 +205,7 @@ def run(rank, n_gpus, hps):
net_g,
optim_g,
)
+ epoch_str+=1
global_step = (epoch_str - 1) * len(train_loader)
# epoch_str = 1
# global_step = 0
@@ -215,7 +216,7 @@ def run(rank, n_gpus, hps):
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
- print(
+ print("loaded pretrained %s" % hps.train.pretrained_s2G,
net_g.module.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
@@ -227,7 +228,7 @@ def run(rank, n_gpus, hps):
if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D):
if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
- print(
+ print("loaded pretrained %s" % hps.train.pretrained_s2D,
net_d.module.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
) if torch.cuda.is_available() else net_d.load_state_dict(
@@ -250,6 +251,7 @@ def run(rank, n_gpus, hps):
scaler = GradScaler(enabled=hps.train.fp16_run)
+ print("start training from epoch %s" % epoch_str)
for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0:
train_and_evaluate(
@@ -280,6 +282,7 @@ def run(rank, n_gpus, hps):
)
scheduler_g.step()
scheduler_d.step()
+ print("training done")
def train_and_evaluate(
diff --git a/GPT_SoVITS/s2_train_v3.py b/GPT_SoVITS/s2_train_v3.py
index a5f7da7..9933dee 100644
--- a/GPT_SoVITS/s2_train_v3.py
+++ b/GPT_SoVITS/s2_train_v3.py
@@ -178,6 +178,7 @@ def run(rank, n_gpus, hps):
net_g,
optim_g,
)
+ epoch_str+=1
global_step = (epoch_str - 1) * len(train_loader)
# epoch_str = 1
# global_step = 0
@@ -188,7 +189,7 @@ def run(rank, n_gpus, hps):
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
- print(
+ print("loaded pretrained %s" % hps.train.pretrained_s2G,
net_g.module.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
@@ -224,6 +225,7 @@ def run(rank, n_gpus, hps):
scaler = GradScaler(enabled=hps.train.fp16_run)
net_d=optim_d=scheduler_d=None
+ print("start training from epoch %s" % epoch_str)
for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0:
train_and_evaluate(
@@ -254,6 +256,7 @@ def run(rank, n_gpus, hps):
)
scheduler_g.step()
# scheduler_d.step()
+ print("training done")
def train_and_evaluate(
diff --git a/GPT_SoVITS/s2_train_v3_lora.py b/GPT_SoVITS/s2_train_v3_lora.py
new file mode 100644
index 0000000..75b3415
--- /dev/null
+++ b/GPT_SoVITS/s2_train_v3_lora.py
@@ -0,0 +1,345 @@
+import warnings
+warnings.filterwarnings("ignore")
+import utils, os
+hps = utils.get_hparams(stage=2)
+os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
+import torch
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+import torch.multiprocessing as mp
+import torch.distributed as dist, traceback
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.cuda.amp import autocast, GradScaler
+from tqdm import tqdm
+import logging, traceback
+
+logging.getLogger("matplotlib").setLevel(logging.INFO)
+logging.getLogger("h5py").setLevel(logging.INFO)
+logging.getLogger("numba").setLevel(logging.INFO)
+from random import randint
+from module import commons
+from peft import LoraConfig, PeftModel, get_peft_model
+from module.data_utils import (
+ TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
+ TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
+ DistributedBucketSampler,
+)
+from module.models import (
+ SynthesizerTrnV3 as SynthesizerTrn,
+ MultiPeriodDiscriminator,
+)
+from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
+from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
+from process_ckpt import savee
+from collections import OrderedDict as od
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
### fp32 is faster on A100 anyway, so give tf32 a try
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally so); does not affect results
# from config import pretrained_s2G,pretrained_s2D
global_step = 0

device = "cpu"  # fallback for non-CUDA devices; proper MPS support to be added once optimized
+
+
def main():
    """Launch one distributed training worker per CUDA device
    (a single worker when CUDA is unavailable)."""
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
    # torch.distributed rendezvous; a random port avoids clashes between runs
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))

    mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps))
+
+
def run(rank, n_gpus, hps):
    """Per-process training entry: builds the data pipeline, wraps the CFM
    decoder of SynthesizerTrnV3 with LoRA adapters, resumes or loads the
    pretrained weights, then runs the epoch loop."""
    global global_step,no_grad_names,save_root,lora_rank
    if rank == 0:
        logger = utils.get_logger(hps.data.exp_dir)
        logger.info(hps)
        # utils.check_git_hash(hps.s2_ckpt_dir)
        writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

    dist.init_process_group(
        backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
        init_method="env://?use_libuv=False",
        world_size=n_gpus,
        rank=rank,
    )
    torch.manual_seed(hps.train.seed)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

    train_dataset = TextAudioSpeakerLoader(hps.data) ########
    # bucket boundaries are spectrogram-frame lengths
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size,
        [
            32,
            300,
            400,
            500,
            600,
            700,
            800,
            900,
            1000,
            # 1100,
            # 1200,
            # 1300,
            # 1400,
            # 1500,
            # 1600,
            # 1700,
            # 1800,
            # 1900,
        ],
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    collate_fn = TextAudioSpeakerCollate()
    train_loader = DataLoader(
        train_dataset,
        num_workers=6,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=True,
        prefetch_factor=4,
    )
    save_root="%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir,hps.model.version,hps.train.lora_rank)
    os.makedirs(save_root,exist_ok=True)
    lora_rank=int(hps.train.lora_rank)
    # LoRA is applied only to the DiT attention projections inside net_g.cfm
    lora_config = LoraConfig(
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        r=lora_rank,
        lora_alpha=lora_rank,
        init_lora_weights=True,
    )
    def get_model(hps):return SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    def get_optim(net_g):
        # only trainable (i.e. LoRA) parameters reach the optimizer
        return torch.optim.AdamW(
            filter(lambda p: p.requires_grad, net_g.parameters()),  ### same lr for all layers by default
            hps.train.learning_rate,
            betas=hps.train.betas,
            eps=hps.train.eps,
        )
    def model2cuda(net_g,rank):
        # wrap in DDP on GPU; plain device move otherwise
        if torch.cuda.is_available():
            net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
        else:
            net_g = net_g.to(device)
        return net_g
    try:  # auto-resume if a checkpoint already exists in save_root
        net_g = get_model(hps)
        net_g.cfm = get_peft_model(net_g.cfm, lora_config)
        net_g=model2cuda(net_g,rank)
        optim_g=get_optim(net_g)
        # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(save_root, "G_*.pth"),
            net_g,
            optim_g,
        )
        epoch_str+=1
        global_step = (epoch_str - 1) * len(train_loader)
    except:  # nothing to resume (first run): fall back to the pretrained weights
        # traceback.print_exc()
        epoch_str = 1
        global_step = 0
        net_g = get_model(hps)
        if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
            if rank == 0:
                logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
            # pretrained weights are loaded BEFORE the LoRA wrap, strict=False
            # because the base checkpoint has no adapter keys
            print("loaded pretrained %s" % hps.train.pretrained_s2G,
                net_g.load_state_dict(
                    torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                    strict=False,
                )
            )
        net_g.cfm = get_peft_model(net_g.cfm, lora_config)
        net_g=model2cuda(net_g,rank)
        optim_g = get_optim(net_g)

    # remember frozen-parameter names so checkpoint export can skip them
    no_grad_names=set()
    for name, param in net_g.named_parameters():
        if not param.requires_grad:
            no_grad_names.add(name.replace("module.",""))
            # print(name, "not requires_grad")
    # print(no_grad_names)
    # os._exit(233333)

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=hps.train.lr_decay, last_epoch=-1
    )
    # fast-forward the scheduler to the resumed epoch
    for _ in range(epoch_str):
        scheduler_g.step()

    scaler = GradScaler(enabled=hps.train.fp16_run)

    # no discriminator in the v3-lora stage
    net_d=optim_d=scheduler_d=None
    print("start training from epoch %s"%epoch_str)
    for epoch in range(epoch_str, hps.train.epochs + 1):
        if rank == 0:
            train_and_evaluate(
                rank,
                epoch,
                hps,
                [net_g, net_d],
                [optim_g, optim_d],
                [scheduler_g, scheduler_d],
                scaler,
                # [train_loader, eval_loader], logger, [writer, writer_eval])
                [train_loader, None],
                logger,
                [writer, writer_eval],
            )
        else:
            train_and_evaluate(
                rank,
                epoch,
                hps,
                [net_g, net_d],
                [optim_g, optim_d],
                [scheduler_g, scheduler_d],
                scaler,
                [train_loader, None],
                None,
                None,
            )
        scheduler_g.step()
    print("training done")
+
def train_and_evaluate(
    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
    """One training epoch of the v3 LoRA stage: only the CFM loss is
    optimized; checkpoints and exported weights are written on rank 0."""
    net_g, net_d = nets          # net_d is always None in this stage
    optim_g, optim_d = optims
    # scheduler_g, scheduler_d = schedulers
    train_loader, eval_loader = loaders
    if writers is not None:
        writer, writer_eval = writers

    # reshuffle the bucket sampler for this epoch
    train_loader.batch_sampler.set_epoch(epoch)
    global global_step

    net_g.train()
    for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)):
        if torch.cuda.is_available():
            spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
                rank, non_blocking=True
            )
            mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(
                rank, non_blocking=True
            )
            ssl = ssl.cuda(rank, non_blocking=True)
            ssl.requires_grad = False
            text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
                rank, non_blocking=True
            )
        else:
            spec, spec_lengths = spec.to(device), spec_lengths.to(device)
            mel, mel_lengths = mel.to(device), mel_lengths.to(device)
            ssl = ssl.to(device)
            ssl.requires_grad = False
            text, text_lengths = text.to(device), text_lengths.to(device)

        with autocast(enabled=hps.train.fp16_run):
            cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
            loss_gen_all=cfm_loss
        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()

        if rank == 0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]['lr']
                losses = [cfm_loss]
                logger.info('Train Epoch: {} [{:.0f}%]'.format(
                    epoch,
                    100. * batch_idx / len(train_loader)))
                logger.info([x.item() for x in losses] + [global_step, lr])

                scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    scalars=scalar_dict)

        global_step += 1
    if epoch % hps.train.save_every_epoch == 0 and rank == 0:
        if hps.train.if_save_latest == 0:
            utils.save_checkpoint(
                net_g,
                optim_g,
                hps.train.learning_rate,
                epoch,
                os.path.join(
                    save_root, "G_{}.pth".format(global_step)
                ),
            )
        else:
            # fixed dummy step in the name => only the latest checkpoint is kept
            utils.save_checkpoint(
                net_g,
                optim_g,
                hps.train.learning_rate,
                epoch,
                os.path.join(
                    save_root, "G_{}.pth".format(233333333333)
                ),
            )
        if rank == 0 and hps.train.if_save_every_weights == True:
            if hasattr(net_g, "module"):
                ckpt = net_g.module.state_dict()
            else:
                ckpt = net_g.state_dict()
            # export only trainable (LoRA + unfrozen) weights, fp16 on CPU
            sim_ckpt=od()
            for key in ckpt:
                # if "cfm"not in key:
                #     print(key)
                if key not in no_grad_names:
                    sim_ckpt[key]=ckpt[key].half().cpu()
            logger.info(
                "saving ckpt %s_e%s:%s"
                % (
                    hps.name,
                    epoch,
                    savee(
                        sim_ckpt,
                        hps.name + "_e%s_s%s_l%s" % (epoch, global_step,lora_rank),
                        epoch,
                        global_step,
                        hps,lora_rank=lora_rank
                    ),
                )
            )

    if rank == 0:
        logger.info("====> Epoch: {}".format(epoch))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/GPT_SoVITS/text/LangSegmenter/langsegmenter.py b/GPT_SoVITS/text/LangSegmenter/langsegmenter.py
index 6859ddb..cca5bf2 100644
--- a/GPT_SoVITS/text/LangSegmenter/langsegmenter.py
+++ b/GPT_SoVITS/text/LangSegmenter/langsegmenter.py
@@ -1,14 +1,74 @@
import logging
-import jieba
import re
+
+# jieba静音
+import jieba
jieba.setLogLevel(logging.CRITICAL)
# 更改fast_langdetect大模型位置
from pathlib import Path
import fast_langdetect
fast_langdetect.ft_detect.infer.CACHE_DIRECTORY = Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"
-import sys
-sys.modules["fast_langdetect"] = fast_langdetect
+
+# 防止win下无法读取模型
+import os
+from typing import Optional
+def load_fasttext_model(
+ model_path: Path,
+ download_url: Optional[str] = None,
+ proxy: Optional[str] = None,
+):
+ """
+ Load a FastText model, downloading it if necessary.
+ :param model_path: Path to the FastText model file
+ :param download_url: URL to download the model from
+ :param proxy: Proxy URL for downloading the model
+ :return: FastText model
+ :raises DetectError: If model loading fails
+ """
+ if all([
+ fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL,
+ model_path.exists(),
+ model_path.name == fast_langdetect.ft_detect.infer.FASTTEXT_LARGE_MODEL_NAME,
+ ]):
+ if not fast_langdetect.ft_detect.infer.verify_md5(model_path, fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL):
+ fast_langdetect.ft_detect.infer.logger.warning(
+ f"fast-langdetect: MD5 hash verification failed for {model_path}, "
+ f"please check the integrity of the downloaded file from {fast_langdetect.ft_detect.infer.FASTTEXT_LARGE_MODEL_URL}. "
+ "\n This may seriously reduce the prediction accuracy. "
+ "If you want to ignore this, please set `fast_langdetect.ft_detect.infer.VERIFY_FASTTEXT_LARGE_MODEL = None` "
+ )
+ if not model_path.exists():
+ if download_url:
+ fast_langdetect.ft_detect.infer.download_model(download_url, model_path, proxy)
+ if not model_path.exists():
+ raise fast_langdetect.ft_detect.infer.DetectError(f"FastText model file not found at {model_path}")
+
+ try:
+ # Load FastText model
+ if (re.match(r'^[A-Za-z0-9_/\\:.]*$', str(model_path))):
+ model = fast_langdetect.ft_detect.infer.fasttext.load_model(str(model_path))
+ else:
+ python_dir = os.getcwd()
+ if (str(model_path)[:len(python_dir)].upper() == python_dir.upper()):
+ model = fast_langdetect.ft_detect.infer.fasttext.load_model(os.path.relpath(model_path, python_dir))
+ else:
+ import tempfile
+ import shutil
+ with tempfile.NamedTemporaryFile(delete=False) as tmpfile:
+ shutil.copyfile(model_path, tmpfile.name)
+
+ model = fast_langdetect.ft_detect.infer.fasttext.load_model(tmpfile.name)
+ os.unlink(tmpfile.name)
+ return model
+
+ except Exception as e:
+ fast_langdetect.ft_detect.infer.logger.warning(f"fast-langdetect:Failed to load FastText model from {model_path}: {e}")
+ raise fast_langdetect.ft_detect.infer.DetectError(f"Failed to load FastText model: {e}")
+
+if os.name == 'nt':
+ fast_langdetect.ft_detect.infer.load_fasttext_model = load_fasttext_model
+
from split_lang import LangSplitter
@@ -18,6 +78,32 @@ def full_en(text):
return bool(re.match(pattern, text))
+def full_cjk(text):
+ # 来自wiki
+ cjk_ranges = [
+ (0x4E00, 0x9FFF), # CJK Unified Ideographs
+ (0x3400, 0x4DB5), # CJK Extension A
+ (0x20000, 0x2A6DD), # CJK Extension B
+ (0x2A700, 0x2B73F), # CJK Extension C
+ (0x2B740, 0x2B81F), # CJK Extension D
+ (0x2B820, 0x2CEAF), # CJK Extension E
+ (0x2CEB0, 0x2EBEF), # CJK Extension F
+ (0x30000, 0x3134A), # CJK Extension G
+ (0x31350, 0x323AF), # CJK Extension H
+ (0x2EBF0, 0x2EE5D), # CJK Extension H
+ ]
+
+ pattern = r'[0-9、-〜。!?.!?… ]+$'
+
+ cjk_text = ""
+ for char in text:
+ code_point = ord(char)
+ in_cjk = any(start <= code_point <= end for start, end in cjk_ranges)
+ if in_cjk or re.match(pattern, char):
+ cjk_text += char
+ return cjk_text
+
+
def split_jako(tag_lang,item):
if tag_lang == "ja":
pattern = r"([\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]+(?:[0-9、-〜。!?.!?… ]+[\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]*)*)"
@@ -98,8 +184,12 @@ class LangSegmenter():
# 未存在非日韩文夹日韩文
if len(temp_list) == 1:
- # 跳过未知语言
+ # 未知语言检查是否为CJK
if dict_item['lang'] == 'x':
+ cjk_text = full_cjk(dict_item['text'])
+ if cjk_text:
+ dict_item = {'lang':'zh','text':cjk_text}
+ lang_list = merge_lang(lang_list,dict_item)
continue
else:
lang_list = merge_lang(lang_list,dict_item)
@@ -107,12 +197,14 @@ class LangSegmenter():
# 存在非日韩文夹日韩文
for _, temp_item in enumerate(temp_list):
- # 待观察是否会出现带英文或语言为x的中日英韩文
+ # 未知语言检查是否为CJK
if temp_item['lang'] == 'x':
- continue
-
- lang_list = merge_lang(lang_list,temp_item)
-
+ cjk_text = full_cjk(dict_item['text'])
+ if cjk_text:
+ dict_item = {'lang':'zh','text':cjk_text}
+ lang_list = merge_lang(lang_list,dict_item)
+ else:
+ lang_list = merge_lang(lang_list,temp_item)
return lang_list
@@ -120,5 +212,6 @@ if __name__ == "__main__":
text = "MyGO?,你也喜欢まいご吗?"
print(LangSegmenter.getTexts(text))
-
+ text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。"
+ print(LangSegmenter.getTexts(text))
diff --git a/GPT_SoVITS/text/japanese.py b/GPT_SoVITS/text/japanese.py
index 440062a..d815ef4 100644
--- a/GPT_SoVITS/text/japanese.py
+++ b/GPT_SoVITS/text/japanese.py
@@ -5,6 +5,40 @@ import hashlib
try:
import pyopenjtalk
current_file_path = os.path.dirname(__file__)
+
+ # 防止win下无法读取模型
+ if os.name == 'nt':
+ python_dir = os.getcwd()
+ OPEN_JTALK_DICT_DIR = pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8")
+ if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', OPEN_JTALK_DICT_DIR)):
+ if (OPEN_JTALK_DICT_DIR[:len(python_dir)].upper() == python_dir.upper()):
+ OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR,python_dir))
+ else:
+ import shutil
+ if not os.path.exists('TEMP'):
+ os.mkdir('TEMP')
+ if not os.path.exists(os.path.join("TEMP", "ja")):
+ os.mkdir(os.path.join("TEMP", "ja"))
+ if os.path.exists(os.path.join("TEMP", "ja", "open_jtalk_dic")):
+ shutil.rmtree(os.path.join("TEMP", "ja", "open_jtalk_dic"))
+ shutil.copytree(pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8"), os.path.join("TEMP", "ja", "open_jtalk_dic"), )
+ OPEN_JTALK_DICT_DIR = os.path.join("TEMP", "ja", "open_jtalk_dic")
+ pyopenjtalk.OPEN_JTALK_DICT_DIR = OPEN_JTALK_DICT_DIR.encode("utf-8")
+
+ if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', current_file_path)):
+ if (current_file_path[:len(python_dir)].upper() == python_dir.upper()):
+ current_file_path = os.path.join(os.path.relpath(current_file_path,python_dir))
+ else:
+ if not os.path.exists('TEMP'):
+ os.mkdir('TEMP')
+ if not os.path.exists(os.path.join("TEMP", "ja")):
+ os.mkdir(os.path.join("TEMP", "ja"))
+ if not os.path.exists(os.path.join("TEMP", "ja", "ja_userdic")):
+ os.mkdir(os.path.join("TEMP", "ja", "ja_userdic"))
+ shutil.copyfile(os.path.join(current_file_path, "ja_userdic", "userdict.csv"),os.path.join("TEMP", "ja", "ja_userdic", "userdict.csv"))
+ current_file_path = os.path.join("TEMP", "ja")
+
+
def get_hash(fp: str) -> str:
hash_md5 = hashlib.md5()
with open(fp, "rb") as f:
diff --git a/GPT_SoVITS/text/korean.py b/GPT_SoVITS/text/korean.py
index 8f28136..79d89af 100644
--- a/GPT_SoVITS/text/korean.py
+++ b/GPT_SoVITS/text/korean.py
@@ -5,6 +5,53 @@ from jamo import h2j, j2hcj
import ko_pron
from g2pk2 import G2p
+import importlib
+import os
+
+# 防止win下无法读取模型
+if os.name == 'nt':
+ class win_G2p(G2p):
+ def check_mecab(self):
+ super().check_mecab()
+ spam_spec = importlib.util.find_spec("eunjeon")
+ non_found = spam_spec is None
+ if non_found:
+ print(f'you have to install eunjeon. install it...')
+ else:
+ installpath = spam_spec.submodule_search_locations[0]
+ if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', installpath)):
+
+ import sys
+ from eunjeon import Mecab as _Mecab
+ class Mecab(_Mecab):
+ def get_dicpath(installpath):
+ if not (re.match(r'^[A-Za-z0-9_/\\:.]*$', installpath)):
+ import shutil
+ python_dir = os.getcwd()
+ if (installpath[:len(python_dir)].upper() == python_dir.upper()):
+ dicpath = os.path.join(os.path.relpath(installpath,python_dir),'data','mecabrc')
+ else:
+ if not os.path.exists('TEMP'):
+ os.mkdir('TEMP')
+ if not os.path.exists(os.path.join('TEMP', 'ko')):
+ os.mkdir(os.path.join('TEMP', 'ko'))
+ if os.path.exists(os.path.join('TEMP', 'ko', 'ko_dict')):
+ shutil.rmtree(os.path.join('TEMP', 'ko', 'ko_dict'))
+
+ shutil.copytree(os.path.join(installpath, 'data'), os.path.join('TEMP', 'ko', 'ko_dict'))
+ dicpath = os.path.join('TEMP', 'ko', 'ko_dict', 'mecabrc')
+ else:
+ dicpath=os.path.abspath(os.path.join(installpath, 'data/mecabrc'))
+ return dicpath
+
+ def __init__(self, dicpath=get_dicpath(installpath)):
+ super().__init__(dicpath=dicpath)
+
+ sys.modules["eunjeon"].Mecab = Mecab
+
+ G2p = win_G2p
+
+
from text.symbols2 import symbols
# This is a list of Korean classifiers preceded by pure Korean numerals.
@@ -263,3 +310,8 @@ def g2p(text):
# text = "".join([post_replace_ph(i) for i in text])
text = [post_replace_ph(i) for i in text]
return text
+
+
+if __name__ == "__main__":
+ text = "안녕하세요"
+ print(g2p(text))
\ No newline at end of file
diff --git a/README.md b/README.md
index a3f5966..adc1344 100644
--- a/README.md
+++ b/README.md
@@ -121,9 +121,7 @@ pip install -r requirements.txt
0. Regarding image tags: Due to rapid updates in the codebase and the slow process of packaging and testing images, please check [Docker Hub](https://hub.docker.com/r/breakstring/gpt-sovits) for the currently packaged latest images and select as per your situation, or alternatively, build locally using a Dockerfile according to your own needs.
1. Environment Variables:
-
-- is_half: Controls half-precision/double-precision. This is typically the cause if the content under the directories 4-cnhubert/5-wav32k is not generated correctly during the "SSL extracting" step. Adjust to True or False based on your actual situation.
-
+ - is_half: Controls half-precision/double-precision. This is typically the cause if the content under the directories 4-cnhubert/5-wav32k is not generated correctly during the "SSL extracting" step. Adjust to True or False based on your actual situation.
2. Volumes Configuration,The application's root directory inside the container is set to /workspace. The default docker-compose.yaml lists some practical examples for uploading/downloading content.
3. shm_size: The default available memory for Docker Desktop on Windows is too small, which can cause abnormal operations. Adjust according to your own situation.
4. Under the deploy section, GPU-related settings should be adjusted cautiously according to your system and actual circumstances.
@@ -152,9 +150,13 @@ docker run --rm -it --gpus=all --env=is_half=False --volume=G:\GPT-SoVITS-Docker
3. For UVR5 (Vocals/Accompaniment Separation & Reverberation Removal, additionally), download models from [UVR5 Weights](https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main/uvr5_weights) and place them in `tools/uvr5/uvr5_weights`.
+ - If you want to use `bs_roformer` or `mel_band_roformer` models for UVR5, you can manually download the model and corresponding configuration file, and put them in `tools/uvr5/uvr5_weights`. **Rename the model file and configuration file, ensure that the model and configuration files have the same and corresponding names except for the suffix**. In addition, the model and configuration file names **must include `roformer`** in order to be recognized as models of the roformer class.
+
+  - The suggestion is to **directly specify the model type** in the model name and configuration file name, such as `mel_band_roformer`, `bs_roformer`. If not specified, the features will be compared from the configuration file to determine which type of model it is. For example, the model `bs_roformer_ep_368_sdr_12.9628.ckpt` and its corresponding configuration file `bs_roformer_ep_368_sdr_12.9628.yaml` are a pair, `kim_mel_band_roformer.ckpt` and `kim_mel_band_roformer.yaml` are also a pair.
+
4. For Chinese ASR (additionally), download models from [Damo ASR Model](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files), [Damo VAD Model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/files), and [Damo Punc Model](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/files) and place them in `tools/asr/models`.
-5. For English or Japanese ASR (additionally), download models from [Faster Whisper Large V3](https://huggingface.co/Systran/faster-whisper-large-v3) and place them in `tools/asr/models`. Also, [other models](https://huggingface.co/Systran) may have the similar effect with smaller disk footprint.
+5. For English or Japanese ASR (additionally), download models from [Faster Whisper Large V3](https://huggingface.co/Systran/faster-whisper-large-v3) and place them in `tools/asr/models`. Also, [other models](https://huggingface.co/Systran) may have the similar effect with smaller disk footprint.
## Dataset Format
@@ -171,7 +173,7 @@ Language dictionary:
- 'en': English
- 'ko': Korean
- 'yue': Cantonese
-
+
Example:
```
@@ -180,61 +182,56 @@ D:\GPT-SoVITS\xxx/xxx.wav|xxx|en|I like playing Genshin.
## Finetune and inference
- ### Open WebUI
+### Open WebUI
- #### Integrated Package Users
+#### Integrated Package Users
- Double-click `go-webui.bat`or use `go-webui.ps1`
- if you want to switch to V1,then double-click`go-webui-v1.bat` or use `go-webui-v1.ps1`
+Double-click `go-webui.bat` or use `go-webui.ps1`
+if you want to switch to V1, then double-click `go-webui-v1.bat` or use `go-webui-v1.ps1`
- #### Others
+#### Others
- ```bash
- python webui.py
- ```
+```bash
+python webui.py
+```
- if you want to switch to V1,then
+if you want to switch to V1,then
- ```bash
- python webui.py v1
- ```
+```bash
+python webui.py v1
+```
 Or manually switch version in WebUI
- ### Finetune
+### Finetune
- #### Path Auto-filling is now supported
+#### Path Auto-filling is now supported
- 1.Fill in the audio path
+ 1. Fill in the audio path
+ 2. Slice the audio into small chunks
+    3. Denoise (optional)
+ 4. ASR
+ 5. Proofreading ASR transcriptions
+ 6. Go to the next Tab, then finetune the model
- 2.Slice the audio into small chunks
+### Open Inference WebUI
- 3.Denoise(optinal)
+#### Integrated Package Users
- 4.ASR
+Double-click `go-webui-v2.bat` or use `go-webui-v2.ps1`, then open the inference webui at `1-GPT-SoVITS-TTS/1C-inference`
- 5.Proofreading ASR transcriptions
+#### Others
- 6.Go to the next Tab, then finetune the model
+```bash
+python GPT_SoVITS/inference_webui.py
+```
+OR
- ### Open Inference WebUI
-
- #### Integrated Package Users
-
- Double-click `go-webui-v2.bat` or use `go-webui-v2.ps1` ,then open the inference webui at `1-GPT-SoVITS-TTS/1C-inference`
-
- #### Others
-
- ```bash
- python GPT_SoVITS/inference_webui.py
- ```
- OR
-
- ```bash
- python webui.py
- ```
+```bash
+python webui.py
+```
then open the inference webui at `1-GPT-SoVITS-TTS/1C-inference`
- ## V2 Release Notes
+## V2 Release Notes
New Features:
@@ -244,11 +241,11 @@ New Features:
3. Pre-trained model extended from 2k hours to 5k hours
-4. Improved synthesis quality for low-quality reference audio
+4. Improved synthesis quality for low-quality reference audio
- [more details](https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90v2%E2%80%90features-(%E6%96%B0%E7%89%B9%E6%80%A7) )
+ [more details](https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90v2%E2%80%90features-(%E6%96%B0%E7%89%B9%E6%80%A7))
-Use v2 from v1 environment:
+Use v2 from v1 environment:
1. `pip install -r requirements.txt` to update some packages
@@ -257,7 +254,28 @@ Use v2 from v1 environment:
3. Download v2 pretrained models from [huggingface](https://huggingface.co/lj1995/GPT-SoVITS/tree/main/gsv-v2final-pretrained) and put them into `GPT_SoVITS\pretrained_models\gsv-v2final-pretrained`.
Chinese v2 additional: [G2PWModel_1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip)(Download G2PW models, unzip and rename to `G2PWModel`, and then place them in `GPT_SoVITS/text`.
-
+
+## V3 Release Notes
+
+New Features:
+
+1. The timbre similarity is higher, requiring less training data to approximate the target speaker (the timbre similarity is significantly improved using the base model directly without fine-tuning).
+
+2. GPT model is more stable, with fewer repetitions and omissions, and it is easier to generate speech with richer emotional expression.
+
+ [more details](https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90v3%E2%80%90features-(%E6%96%B0%E7%89%B9%E6%80%A7))
+
+Use v3 from v2 environment:
+
+1. `pip install -r requirements.txt` to update some packages
+
+2. Clone the latest codes from github.
+
+3. Download v3 pretrained models (s1v3.ckpt, s2Gv3.pth and models--nvidia--bigvgan_v2_24khz_100band_256x folder) from [huggingface](https://huggingface.co/lj1995/GPT-SoVITS/tree/main) and put them into `GPT_SoVITS\pretrained_models`.
+
+ additional: for Audio Super Resolution model, you can read [how to download](./tools/AP_BWE_main/24kto48k/readme.txt)
+
+
## Todo List
- [x] **High Priority:**
@@ -285,7 +303,7 @@ python tools/uvr5/webui.py ""
```
This is how the audio segmentation of the dataset is done using the command line
```
@@ -294,7 +312,7 @@ python audio_slicer.py \
--output_root "" \
--threshold \
--min_length \
- --min_interval
+ --min_interval
--hop_size
```
This is how dataset ASR processing is done using the command line(Only Chinese)
@@ -341,6 +359,7 @@ Special thanks to the following projects and contributors:
- [gradio](https://github.com/gradio-app/gradio)
- [faster-whisper](https://github.com/SYSTRAN/faster-whisper)
- [FunASR](https://github.com/alibaba-damo-academy/FunASR)
+- [AP-BWE](https://github.com/yxlu-0102/AP-BWE)
Thankful to @Naozumi520 for providing the Cantonese training set and for the guidance on Cantonese-related knowledge.
diff --git a/docs/cn/Changelog_CN.md b/docs/cn/Changelog_CN.md
index d60d213..4666f6e 100644
--- a/docs/cn/Changelog_CN.md
+++ b/docs/cn/Changelog_CN.md
@@ -246,4 +246,43 @@
### 20250211
-1-增加gpt-sovits-v3模型
+增加gpt-sovits-v3模型,需要14G显存可以微调
+
+### 20250212
+
+sovits-v3微调支持开启梯度检查点,需要12G显存可以微调https://github.com/RVC-Boss/GPT-SoVITS/pull/2040
+
+### 20250214
+
+优化多语种混合文本切分策略a https://github.com/RVC-Boss/GPT-SoVITS/pull/2047
+
+### 20250217
+
+优化文本里的数字和英文处理逻辑https://github.com/RVC-Boss/GPT-SoVITS/pull/2062
+
+### 20250218
+
+优化多语种混合文本切分策略b https://github.com/RVC-Boss/GPT-SoVITS/pull/2073
+
+### 20250223
+
+1-sovits-v3微调支持lora训练,需要8G显存可以微调,效果比全参微调更好
+
+2-人声背景音分离增加mel band roformer模型支持https://github.com/RVC-Boss/GPT-SoVITS/pull/2078
+
+### 20250226
+
+https://github.com/RVC-Boss/GPT-SoVITS/pull/2112 https://github.com/RVC-Boss/GPT-SoVITS/pull/2114
+
+修复中文路径下mecab的报错(具体表现为日文韩文、文本混合语种切分可能会遇到的报错)
+
+### 20250227
+
+针对v3生成24k音频感觉闷的问题https://github.com/RVC-Boss/GPT-SoVITS/issues/2085 https://github.com/RVC-Boss/GPT-SoVITS/issues/2117 ,支持使用24k to 48k的音频超分模型缓解。
+
+
+### 20250228
+
+修复短文本语种选择出错 https://github.com/RVC-Boss/GPT-SoVITS/pull/2122
+
+修复v3sovits未传参以支持调节语速
diff --git a/docs/cn/README.md b/docs/cn/README.md
index a063128..6196099 100644
--- a/docs/cn/README.md
+++ b/docs/cn/README.md
@@ -5,11 +5,13 @@
[](https://github.com/RVC-Boss/GPT-SoVITS)
-
+
+
+
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/colab_webui.ipynb)
[](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
-[](https://huggingface.co/lj1995/GPT-SoVITS/tree/main)
+[](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2)
[](https://discord.gg/dnrgs5GHfG)
[**English**](../../README.md) | **中文简体** | [**日本語**](../ja/README.md) | [**한국어**](../ko/README.md) | [**Türkçe**](../tr/README.md)
@@ -149,6 +151,11 @@ docker run --rm -it --gpus=all --env=is_half=False --volume=G:\GPT-SoVITS-Docker
3. 对于 UVR5(人声/伴奏分离和混响移除,额外功能),从 [UVR5 Weights](https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main/uvr5_weights) 下载模型,并将其放置在 `tools/uvr5/uvr5_weights` 目录中。
+   - 如果你在 UVR5 中使用 `bs_roformer` 或 `mel_band_roformer`模型,你可以手动下载模型和相应的配置文件,并将它们放在 `tools/uvr5/uvr5_weights` 中。**重命名模型文件和配置文件,确保除后缀外**,模型和配置文件具有相同且对应的名称。此外,模型和配置文件名**必须包含“roformer”**,才能被识别为 roformer 类的模型。
+
+   - 建议在模型名称和配置文件名中**直接指定模型类型**,例如`mel_band_roformer`、`bs_roformer`。如果未指定,将从配置文件中比对特征,以确定它是哪种类型的模型。例如,模型`bs_roformer_ep_368_sdr_12.9628.ckpt` 和对应的配置文件`bs_roformer_ep_368_sdr_12.9628.yaml` 是一对。`kim_mel_band_roformer.ckpt` 和 `kim_mel_band_roformer.yaml` 也是一对。
+
+
4. 对于中文 ASR(额外功能),从 [Damo ASR Model](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files)、[Damo VAD Model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/files) 和 [Damo Punc Model](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/files) 下载模型,并将它们放置在 `tools/asr/models` 目录中。
5. 对于英语或日语 ASR(额外功能),从 [Faster Whisper Large V3](https://huggingface.co/Systran/faster-whisper-large-v3) 下载模型,并将其放置在 `tools/asr/models` 目录中。此外,[其他模型](https://huggingface.co/Systran) 可能具有类似效果且占用更少的磁盘空间。
@@ -201,17 +208,12 @@ python webui.py v1
#### 现已支持自动填充路径
- 1.填入训练音频路径
-
- 2.切割音频
-
- 3.进行降噪(可选)
-
- 4.进行ASR
-
- 5.校对标注
-
- 6.前往下一个窗口,点击训练
+ 1. 填入训练音频路径
+ 2. 切割音频
+ 3. 进行降噪(可选)
+ 4. 进行ASR
+ 5. 校对标注
+ 6. 前往下一个窗口,点击训练
### 打开推理WebUI
@@ -255,6 +257,27 @@ python webui.py
中文额外需要下载[G2PWModel_1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip)(下载G2PW模型,解压并重命名为`G2PWModel`,将其放到`GPT_SoVITS/text`目录下)
+## V3更新说明
+
+新模型特点:
+
+1. 音色相似度更像,需要更少训练集来逼近本人(不训练直接使用底模模式下音色相似性提升更大)
+
+2. GPT合成更稳定,重复漏字更少,也更容易跑出丰富情感
+
+ 详见[wiki](https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90v2%E2%80%90features-(%E6%96%B0%E7%89%B9%E6%80%A7))
+
+从v2环境迁移至v3
+
+1. 需要pip安装requirements.txt更新环境
+
+2. 需要克隆github上的最新代码
+
+3. 从[huggingface](https://huggingface.co/lj1995/GPT-SoVITS/tree/main)下载这些v3新增预训练模型 (s1v3.ckpt, s2Gv3.pth and models--nvidia--bigvgan_v2_24khz_100band_256x folder)将他们放到`GPT_SoVITS\pretrained_models`目录下
+
+ 如果想用音频超分功能缓解v3模型生成24k音频觉得闷的问题,需要下载额外的模型参数,参考[how to download](../../tools/AP_BWE_main/24kto48k/readme.txt)
+
+
## 待办事项清单
- [x] **高优先级:**
@@ -271,7 +294,7 @@ python webui.py
- [x] 改进英语和日语文本前端。
- [ ] 开发体积小和更大的 TTS 模型。
- [x] Colab 脚本。
- - [ ] 扩展训练数据集(从 2k 小时到 10k 小时)。
+ - [x] 扩展训练数据集(从 2k 小时到 10k 小时)。
- [x] 更好的 sovits 基础模型(增强的音频质量)。
- [ ] 模型混合。
@@ -282,7 +305,7 @@ python tools/uvr5/webui.py ""
````
这是使用命令行完成数据集的音频切分的方式
````
@@ -291,7 +314,7 @@ python audio_slicer.py \
--output_root "" \
--threshold \
--min_length \
- --min_interval
+ --min_interval
--hop_size
````
这是使用命令行完成数据集ASR处理的方式(仅限中文)
@@ -318,12 +341,15 @@ python ./tools/asr/fasterwhisper_asr.py -i -o