diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 4917d32..7d79fbd 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -1,49 +1,55 @@
import os
-gpt_path=os.environ.get("gpt_path","pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
-sovits_path=os.environ.get("sovits_path","pretrained_models/s2G488k.pth")
-cnhubert_base_path=os.environ.get("cnhubert_base_path","pretrained_models/chinese-hubert-base")
-bert_path=os.environ.get("bert_path","pretrained_models/chinese-roberta-wwm-ext-large")
-infer_ttswebui=os.environ.get("infer_ttswebui",9872)
-infer_ttswebui=int(infer_ttswebui)
-if("_CUDA_VISIBLE_DEVICES"in os.environ):
- os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"]
-is_half=eval(os.environ.get("is_half","True"))
+
+gpt_path = os.environ.get(
+ "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
+)
+sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
+cnhubert_base_path = os.environ.get(
+ "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
+)
+bert_path = os.environ.get(
+ "bert_path", "pretrained_models/chinese-roberta-wwm-ext-large"
+)
+infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
+infer_ttswebui = int(infer_ttswebui)
+if "_CUDA_VISIBLE_DEVICES" in os.environ:
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+is_half = eval(os.environ.get("is_half", "True"))
import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer
-import sys,torch,numpy as np
-from pathlib import Path
-import os,pdb,utils,librosa,math,traceback,requests,argparse,torch,multiprocessing,pandas as pd,torch.multiprocessing as mp,soundfile
+import torch, numpy as np
+import os, librosa, torch
+
# torch.backends.cuda.sdp_kernel("flash")
# torch.backends.cuda.enable_flash_sdp(True)
# torch.backends.cuda.enable_mem_efficient_sdp(True) # Not avaliable if torch version is lower than 2.0
# torch.backends.cuda.enable_math_sdp(True)
-from random import shuffle
-from AR.utils import get_newest_ckpt
-from glob import glob
-from tqdm import tqdm
from feature_extractor import cnhubert
-cnhubert.cnhubert_base_path=cnhubert_base_path
-from io import BytesIO
+
+cnhubert.cnhubert_base_path = cnhubert_base_path
from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
-from AR.utils.io import load_yaml_config
from text import cleaned_text_to_sequence
-from text.cleaner import text_to_sequence, clean_text
+from text.cleaner import clean_text
from time import time as ttime
from module.mel_processing import spectrogram_torch
from my_utils import load_audio
-device="cuda"
+device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(bert_path)
-bert_model=AutoModelForMaskedLM.from_pretrained(bert_path)
-if(is_half==True):bert_model=bert_model.half().to(device)
-else:bert_model=bert_model.to(device)
+bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
+if is_half == True:
+ bert_model = bert_model.half().to(device)
+else:
+ bert_model = bert_model.to(device)
+
+
# bert_model=bert_model.to(device)
def get_bert_feature(text, word2ph):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
- inputs[i] = inputs[i].to(device)#####输入是long不用管精度问题,精度随bert_model
+ inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model
res = bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
@@ -55,9 +61,12 @@ def get_bert_feature(text, word2ph):
# if(is_half==True):phone_level_feature=phone_level_feature.half()
return phone_level_feature.T
+
n_semantic = 1024
-dict_s2=torch.load(sovits_path,map_location="cpu")
-hps=dict_s2["config"]
+dict_s2 = torch.load(sovits_path, map_location="cpu")
+hps = dict_s2["config"]
+
+
class DictToAttrRecursive:
def __init__(self, input_dict):
for key, value in input_dict.items():
@@ -67,206 +76,271 @@ class DictToAttrRecursive:
else:
setattr(self, key, value)
+
hps = DictToAttrRecursive(hps)
-hps.model.semantic_frame_rate="25hz"
-dict_s1=torch.load(gpt_path,map_location="cpu")
-config=dict_s1["config"]
-ssl_model=cnhubert.get_model()
-if(is_half==True):ssl_model=ssl_model.half().to(device)
-else:ssl_model=ssl_model.to(device)
+hps.model.semantic_frame_rate = "25hz"
+dict_s1 = torch.load(gpt_path, map_location="cpu")
+config = dict_s1["config"]
+ssl_model = cnhubert.get_model()
+if is_half == True:
+ ssl_model = ssl_model.half().to(device)
+else:
+ ssl_model = ssl_model.to(device)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
- **hps.model)
-if(is_half==True):vq_model=vq_model.half().to(device)
-else:vq_model=vq_model.to(device)
+ **hps.model
+)
+if is_half == True:
+ vq_model = vq_model.half().to(device)
+else:
+ vq_model = vq_model.to(device)
vq_model.eval()
-print(vq_model.load_state_dict(dict_s2["weight"],strict=False))
+print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
hz = 50
-max_sec = config['data']['max_sec']
+max_sec = config["data"]["max_sec"]
# t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
-t2s_model = Text2SemanticLightningModule(config,"ojbk",is_train=False)
+t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
-if(is_half==True):t2s_model=t2s_model.half()
-t2s_model=t2s_model.to(device)
+if is_half == True:
+ t2s_model = t2s_model.half()
+t2s_model = t2s_model.to(device)
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
+
+
def get_spepc(hps, filename):
- audio=load_audio(filename,int(hps.data.sampling_rate))
- audio=torch.FloatTensor(audio)
+ audio = load_audio(filename, int(hps.data.sampling_rate))
+ audio = torch.FloatTensor(audio)
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
- spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False)
+ spec = spectrogram_torch(
+ audio_norm,
+ hps.data.filter_length,
+ hps.data.sampling_rate,
+ hps.data.hop_length,
+ hps.data.win_length,
+ center=False,
+ )
return spec
-dict_language={
- "中文":"zh",
- "英文":"en",
- "日文":"ja"
-}
-def get_tts_wav(ref_wav_path,prompt_text,prompt_language,text,text_language):
+
+dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
+
+
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
t0 = ttime()
- prompt_text=prompt_text.strip("\n")
- prompt_language,text=prompt_language,text.strip("\n")
+ prompt_text = prompt_text.strip("\n")
+ prompt_language, text = prompt_language, text.strip("\n")
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
wav16k = torch.from_numpy(wav16k)
- if(is_half==True):wav16k=wav16k.half().to(device)
- else:wav16k=wav16k.to(device)
- ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float()
+ if is_half == True:
+ wav16k = wav16k.half().to(device)
+ else:
+ wav16k = wav16k.to(device)
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
+ "last_hidden_state"
+ ].transpose(
+ 1, 2
+ ) # .float()
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
- prompt_language=dict_language[prompt_language]
- text_language=dict_language[text_language]
+ prompt_language = dict_language[prompt_language]
+ text_language = dict_language[text_language]
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
- phones1=cleaned_text_to_sequence(phones1)
- texts=text.split("\n")
+ phones1 = cleaned_text_to_sequence(phones1)
+ texts = text.split("\n")
audio_opt = []
- zero_wav=np.zeros(int(hps.data.sampling_rate*0.3),dtype=np.float16 if is_half==True else np.float32)
+ zero_wav = np.zeros(
+ int(hps.data.sampling_rate * 0.3),
+ dtype=np.float16 if is_half == True else np.float32,
+ )
for text in texts:
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
phones2 = cleaned_text_to_sequence(phones2)
- if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
- else:bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device)
- if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
- else:bert2 = torch.zeros((1024, len(phones2))).to(bert1)
+ if prompt_language == "zh":
+ bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
+ else:
+ bert1 = torch.zeros(
+ (1024, len(phones1)),
+ dtype=torch.float16 if is_half == True else torch.float32,
+ ).to(device)
+ if text_language == "zh":
+ bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
+ else:
+ bert2 = torch.zeros((1024, len(phones2))).to(bert1)
bert = torch.cat([bert1, bert2], 1)
- all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
+ all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
t2 = ttime()
with torch.no_grad():
# pred_semantic = t2s_model.model.infer(
- pred_semantic,idx = t2s_model.model.infer_panel(
+ pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
prompt,
bert,
# prompt_phone_len=ph_offset,
- top_k=config['inference']['top_k'],
- early_stop_num=hz * max_sec)
+ top_k=config["inference"]["top_k"],
+ early_stop_num=hz * max_sec,
+ )
t3 = ttime()
# print(pred_semantic.shape,idx)
- pred_semantic = pred_semantic[:,-idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
- refer = get_spepc(hps, ref_wav_path)#.to(device)
- if(is_half==True):refer=refer.half().to(device)
- else:refer=refer.to(device)
+ pred_semantic = pred_semantic[:, -idx:].unsqueeze(
+ 0
+ ) # .unsqueeze(0)#mq要多unsqueeze一次
+ refer = get_spepc(hps, ref_wav_path) # .to(device)
+ if is_half == True:
+ refer = refer.half().to(device)
+ else:
+ refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
- audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]###试试重建不带上prompt部分
+ audio = (
+ vq_model.decode(
+ pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
+ )
+ .detach()
+ .cpu()
+ .numpy()[0, 0]
+ ) ###试试重建不带上prompt部分
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
- yield hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16)
+ yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
+ np.int16
+ )
+
+
+splits = {
+ ",",
+ "。",
+ "?",
+ "!",
+ ",",
+ ".",
+ "?",
+ "!",
+ "~",
+ ":",
+ ":",
+ "—",
+ "…",
+} # 不考虑省略号
-splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}#不考虑省略号
def split(todo_text):
todo_text = todo_text.replace("……", "。").replace("——", ",")
- if (todo_text[-1] not in splits): todo_text += "。"
+ if todo_text[-1] not in splits:
+ todo_text += "。"
i_split_head = i_split_tail = 0
len_text = len(todo_text)
todo_texts = []
- while (1):
- if (i_split_head >= len_text): break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
- if (todo_text[i_split_head] in splits):
+ while 1:
+ if i_split_head >= len_text:
+ break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
+ if todo_text[i_split_head] in splits:
i_split_head += 1
todo_texts.append(todo_text[i_split_tail:i_split_head])
i_split_tail = i_split_head
else:
i_split_head += 1
return todo_texts
+
+
def cut1(inp):
- inp=inp.strip("\n")
- inps=split(inp)
- split_idx=list(range(0,len(inps),5))
- split_idx[-1]=None
- if(len(split_idx)>1):
- opts=[]
- for idx in range(len(split_idx)-1):
- opts.append("".join(inps[split_idx[idx]:split_idx[idx+1]]))
+ inp = inp.strip("\n")
+ inps = split(inp)
+ split_idx = list(range(0, len(inps), 5))
+ split_idx[-1] = None
+ if len(split_idx) > 1:
+ opts = []
+ for idx in range(len(split_idx) - 1):
+ opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
else:
- opts=[inp]
+ opts = [inp]
return "\n".join(opts)
+
def cut2(inp):
- inp=inp.strip("\n")
- inps=split(inp)
- if(len(inps)<2):return [inp]
- opts=[]
- summ=0
- tmp_str=""
+ inp = inp.strip("\n")
+ inps = split(inp)
+ if len(inps) < 2:
+ return [inp]
+ opts = []
+ summ = 0
+ tmp_str = ""
for i in range(len(inps)):
- summ+=len(inps[i])
- tmp_str+=inps[i]
- if(summ>50):
- summ=0
+ summ += len(inps[i])
+ tmp_str += inps[i]
+ if summ > 50:
+ summ = 0
opts.append(tmp_str)
- tmp_str=""
- if(tmp_str!=""):opts.append(tmp_str)
- if(len(opts[-1])<50):##如果最后一个太短了,和前一个合一起
- opts[-2]=opts[-2]+opts[-1]
- opts=opts[:-1]
+ tmp_str = ""
+ if tmp_str != "":
+ opts.append(tmp_str)
+ if len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
+ opts[-2] = opts[-2] + opts[-1]
+ opts = opts[:-1]
return "\n".join(opts)
+
def cut3(inp):
- inp=inp.strip("\n")
- return "\n".join(["%s。"%item for item in inp.strip("。").split("。")])
+ inp = inp.strip("\n")
+ return "\n".join(["%s。" % item for item in inp.strip("。").split("。")])
+
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(
- value=
- "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
+ value="本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
)
# with gr.Tabs():
# with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
with gr.Group():
- gr.Markdown(
- value=
- "*请上传并填写参考信息"
- )
+ gr.Markdown(value="*请上传并填写参考信息")
with gr.Row():
inp_ref = gr.Audio(label="请上传参考音频", type="filepath")
- prompt_text= gr.Textbox(label="参考音频的文本",value="")
- prompt_language= gr.Dropdown(label="参考音频的语种",choices=["中文","英文","日文"],value="中文")
- gr.Markdown(
- value=
- "*请填写需要合成的目标文本"
- )
+ prompt_text = gr.Textbox(label="参考音频的文本", value="")
+ prompt_language = gr.Dropdown(
+ label="参考音频的语种", choices=["中文", "英文", "日文"], value="中文"
+ )
+ gr.Markdown(value="*请填写需要合成的目标文本")
with gr.Row():
- text=gr.Textbox(label="需要合成的文本",value="")
- text_language = gr.Dropdown(label="需要合成的语种", choices=["中文", "英文", "日文"],value="中文")
- inference_button=gr.Button("合成语音", variant="primary")
+ text = gr.Textbox(label="需要合成的文本", value="")
+ text_language = gr.Dropdown(
+ label="需要合成的语种", choices=["中文", "英文", "日文"], value="中文"
+ )
+ inference_button = gr.Button("合成语音", variant="primary")
output = gr.Audio(label="输出的语音")
- inference_button.click(get_tts_wav, [inp_ref, prompt_text,prompt_language, text,text_language], [output])
-
- gr.Markdown(
- value=
- "文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
+ inference_button.click(
+ get_tts_wav,
+ [inp_ref, prompt_text, prompt_language, text, text_language],
+ [output],
)
+
+ gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
with gr.Row():
- text_inp=gr.Textbox(label="需要合成的切分前文本",value="")
+ text_inp = gr.Textbox(label="需要合成的切分前文本", value="")
button1 = gr.Button("凑五句一切", variant="primary")
button2 = gr.Button("凑50字一切", variant="primary")
button3 = gr.Button("按中文句号。切", variant="primary")
text_opt = gr.Textbox(label="切分后文本", value="")
- button1.click(cut1,[text_inp],[text_opt])
- button2.click(cut2,[text_inp],[text_opt])
- button3.click(cut3,[text_inp],[text_opt])
- gr.Markdown(
- value=
- "后续将支持混合语种编码文本输入。"
- )
+ button1.click(cut1, [text_inp], [text_opt])
+ button2.click(cut2, [text_inp], [text_opt])
+ button3.click(cut3, [text_inp], [text_opt])
+ gr.Markdown(value="后续将支持混合语种编码文本输入。")
app.queue(concurrency_count=511, max_size=1022).launch(
server_name="0.0.0.0",
inbrowser=True,
server_port=infer_ttswebui,
quiet=True,
-)
\ No newline at end of file
+)
diff --git a/GPT_SoVITS/process_ckpt.py b/GPT_SoVITS/process_ckpt.py
index 170dbb3..7483337 100644
--- a/GPT_SoVITS/process_ckpt.py
+++ b/GPT_SoVITS/process_ckpt.py
@@ -1,11 +1,12 @@
-import os
-import sys
import traceback
from collections import OrderedDict
import torch
from tools.i18n.i18n import I18nAuto
+
i18n = I18nAuto()
+
+
def savee(ckpt, name, epoch, steps, hps):
try:
opt = OrderedDict()
@@ -15,8 +16,8 @@ def savee(ckpt, name, epoch, steps, hps):
continue
opt["weight"][key] = ckpt[key].half()
opt["config"] = hps
- opt["info"] = "%sepoch_%siteration" % (epoch,steps)
- torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir,name))
+ opt["info"] = "%sepoch_%siteration" % (epoch, steps)
+ torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
return "Success."
except:
return traceback.format_exc()
diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py
index 37166cb..4a77006 100644
--- a/GPT_SoVITS/s1_train.py
+++ b/GPT_SoVITS/s1_train.py
@@ -2,56 +2,84 @@
import os
import pdb
-if("_CUDA_VISIBLE_DEVICES"in os.environ):
- os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_CUDA_VISIBLE_DEVICES" in os.environ:
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
import argparse
import logging
from pathlib import Path
-import torch,platform
+import torch, platform
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers import TensorBoardLogger#WandbLogger
+from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
-logging.getLogger('numba').setLevel(logging.WARNING)
-logging.getLogger('matplotlib').setLevel(logging.WARNING)
-torch.set_float32_matmul_precision('high')
+
+logging.getLogger("numba").setLevel(logging.WARNING)
+logging.getLogger("matplotlib").setLevel(logging.WARNING)
+torch.set_float32_matmul_precision("high")
from AR.utils import get_newest_ckpt
from collections import OrderedDict
+
+
class my_model_ckpt(ModelCheckpoint):
- def __init__(self,config,if_save_latest,if_save_every_weights,half_weights_save_dir,exp_name,**kwargs):
+ def __init__(
+ self,
+ config,
+ if_save_latest,
+ if_save_every_weights,
+ half_weights_save_dir,
+ exp_name,
+ **kwargs
+ ):
super().__init__(**kwargs)
- self.if_save_latest=if_save_latest
- self.if_save_every_weights=if_save_every_weights
- self.half_weights_save_dir=half_weights_save_dir
- self.exp_name=exp_name
- self.config=config
+ self.if_save_latest = if_save_latest
+ self.if_save_every_weights = if_save_every_weights
+ self.half_weights_save_dir = half_weights_save_dir
+ self.exp_name = exp_name
+ self.config = config
def on_train_epoch_end(self, trainer, pl_module):
- if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
+ if not self._should_skip_saving_checkpoint(
+ trainer
+ ) and self._should_save_on_train_epoch_end(trainer):
monitor_candidates = self._monitor_candidates(trainer)
- if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
- if(self.if_save_latest==True):####如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt
- to_clean=list(os.listdir(self.dirpath))
+ if (
+ self._every_n_epochs >= 1
+ and (trainer.current_epoch + 1) % self._every_n_epochs == 0
+ ):
+ if (
+ self.if_save_latest == True
+ ): ####如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt
+ to_clean = list(os.listdir(self.dirpath))
self._save_topk_checkpoint(trainer, monitor_candidates)
- if (self.if_save_latest == True):
+ if self.if_save_latest == True:
for name in to_clean:
try:
- os.remove("%s/%s"%(self.dirpath,name))
- except:pass
- if(self.if_save_every_weights==True):
- to_save_od=OrderedDict()
- to_save_od["weight"]=OrderedDict()
- dictt=trainer.strategy._lightning_module.state_dict()
- for key in dictt:to_save_od["weight"][key]=dictt[key].half()
- to_save_od["config"]=self.config
- to_save_od["info"]="GPT-e%s"%(trainer.current_epoch+1)
- torch.save(to_save_od,"%s/%s-e%s.ckpt"%(self.half_weights_save_dir,self.exp_name,trainer.current_epoch+1))
+ os.remove("%s/%s" % (self.dirpath, name))
+ except:
+ pass
+ if self.if_save_every_weights == True:
+ to_save_od = OrderedDict()
+ to_save_od["weight"] = OrderedDict()
+ dictt = trainer.strategy._lightning_module.state_dict()
+ for key in dictt:
+ to_save_od["weight"][key] = dictt[key].half()
+ to_save_od["config"] = self.config
+ to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
+ torch.save(
+ to_save_od,
+ "%s/%s-e%s.ckpt"
+ % (
+ self.half_weights_save_dir,
+ self.exp_name,
+ trainer.current_epoch + 1,
+ ),
+ )
self._save_last_checkpoint(trainer, monitor_candidates)
@@ -61,41 +89,45 @@ def main(args):
output_dir = Path(config["output_dir"])
output_dir.mkdir(parents=True, exist_ok=True)
- ckpt_dir = output_dir / 'ckpt'
+ ckpt_dir = output_dir / "ckpt"
ckpt_dir.mkdir(parents=True, exist_ok=True)
-
seed_everything(config["train"]["seed"], workers=True)
ckpt_callback: ModelCheckpoint = my_model_ckpt(
config=config,
- if_save_latest=config["train"]["if_save_latest"], if_save_every_weights=config["train"]["if_save_every_weights"], half_weights_save_dir=config["train"]["half_weights_save_dir"], exp_name=config["train"]["exp_name"],
+ if_save_latest=config["train"]["if_save_latest"],
+ if_save_every_weights=config["train"]["if_save_every_weights"],
+ half_weights_save_dir=config["train"]["half_weights_save_dir"],
+ exp_name=config["train"]["exp_name"],
save_top_k=-1,
- monitor='top_3_acc',
- mode='max',
+ monitor="top_3_acc",
+ mode="max",
save_on_train_epoch_end=True,
every_n_epochs=config["train"]["save_every_n_epoch"],
dirpath=ckpt_dir,
)
- logger = TensorBoardLogger(
- name=output_dir.stem,
- save_dir=output_dir
- )
+ logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
trainer: Trainer = Trainer(
max_epochs=config["train"]["epochs"],
- accelerator='gpu',
+ accelerator="gpu",
# val_check_interval=9999999999999999999999,###不要验证
# check_val_every_n_epoch=None,
limit_val_batches=0,
devices=-1,
benchmark=False,
fast_dev_run=False,
- strategy=DDPStrategy(process_group_backend="nccl"if platform.system()!="Windows"else "gloo"),
+ strategy=DDPStrategy(
+ process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
+ ),
precision=config["train"]["precision"],
- logger=logger,num_sanity_val_steps=0,
- callbacks=[ckpt_callback])
+ logger=logger,
+ num_sanity_val_steps=0,
+ callbacks=[ckpt_callback],
+ )
model: Text2SemanticLightningModule = Text2SemanticLightningModule(
- config, output_dir)
+ config, output_dir
+ )
data_module: Text2SemanticDataModule = Text2SemanticDataModule(
config,
@@ -116,14 +148,15 @@ def main(args):
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- '-c',
- '--config_file',
+ "-c",
+ "--config_file",
type=str,
- default='configs/s1longer.yaml',
- help='path of config file')
+ default="configs/s1longer.yaml",
+ help="path of config file",
+ )
# args for dataset
# parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv')
# parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt')
diff --git a/GPT_SoVITS/s2_train.py b/GPT_SoVITS/s2_train.py
index 7a455eb..d2ec262 100644
--- a/GPT_SoVITS/s2_train.py
+++ b/GPT_SoVITS/s2_train.py
@@ -1,4 +1,5 @@
-import utils,os
+import utils, os
+
hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import torch
@@ -6,11 +7,12 @@ from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
-import torch.distributed as dist,traceback
+import torch.distributed as dist, traceback
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
-import logging,traceback
+import logging, traceback
+
logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
@@ -20,37 +22,42 @@ from module import commons
from module.data_utils import (
TextAudioSpeakerLoader,
TextAudioSpeakerCollate,
- DistributedBucketSampler
+ DistributedBucketSampler,
)
from module.models import (
SynthesizerTrn,
MultiPeriodDiscriminator,
)
-from module.losses import (
- generator_loss,
- discriminator_loss,
- feature_loss,
- kl_loss
-)
+from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from process_ckpt import savee
+
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
###反正A100fp32更快,那试试tf32吧
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
-torch.set_float32_matmul_precision('medium')#最低精度但最快(也就快一丁点),对于结果造成不了影响
+torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响
# from config import pretrained_s2G,pretrained_s2D
global_step = 0
+
+
def main():
"""Assume Single Node Multi GPUs Training Only"""
assert torch.cuda.is_available(), "CPU training is not allowed."
n_gpus = torch.cuda.device_count()
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = str(randint(20000, 55555))
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = str(randint(20000, 55555))
- mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))
+ mp.spawn(
+ run,
+ nprocs=n_gpus,
+ args=(
+ n_gpus,
+ hps,
+ ),
+ )
def run(rank, n_gpus, hps):
@@ -62,21 +69,54 @@ def run(rank, n_gpus, hps):
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
- dist.init_process_group(backend='gloo' if os.name == 'nt' else 'nccl', init_method='env://', world_size=n_gpus,rank=rank)
+ dist.init_process_group(
+ backend="gloo" if os.name == "nt" else "nccl",
+ init_method="env://",
+ world_size=n_gpus,
+ rank=rank,
+ )
torch.manual_seed(hps.train.seed)
torch.cuda.set_device(rank)
- train_dataset = TextAudioSpeakerLoader(hps.data)########
+ train_dataset = TextAudioSpeakerLoader(hps.data) ########
train_sampler = DistributedBucketSampler(
train_dataset,
hps.train.batch_size,
- [32, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900],
+ [
+ 32,
+ 300,
+ 400,
+ 500,
+ 600,
+ 700,
+ 800,
+ 900,
+ 1000,
+ 1100,
+ 1200,
+ 1300,
+ 1400,
+ 1500,
+ 1600,
+ 1700,
+ 1800,
+ 1900,
+ ],
num_replicas=n_gpus,
rank=rank,
- shuffle=True)
+ shuffle=True,
+ )
collate_fn = TextAudioSpeakerCollate()
- train_loader = DataLoader(train_dataset, num_workers=6, shuffle=False, pin_memory=True,
- collate_fn=collate_fn, batch_sampler=train_sampler,persistent_workers=True,prefetch_factor=16)
+ train_loader = DataLoader(
+ train_dataset,
+ num_workers=6,
+ shuffle=False,
+ pin_memory=True,
+ collate_fn=collate_fn,
+ batch_sampler=train_sampler,
+ persistent_workers=True,
+ prefetch_factor=16,
+ )
# if rank == 0:
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
# eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
@@ -87,17 +127,21 @@ def run(rank, n_gpus, hps):
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
- **hps.model).cuda(rank)
+ **hps.model,
+ ).cuda(rank)
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
for name, param in net_g.named_parameters():
if not param.requires_grad:
- print(name,"not requires_grad")
+ print(name, "not requires_grad")
te_p = list(map(id, net_g.enc_p.text_embedding.parameters()))
et_p = list(map(id, net_g.enc_p.encoder_text.parameters()))
mrte_p = list(map(id, net_g.enc_p.mrte.parameters()))
- base_params = filter(lambda p: id(p) not in te_p+et_p+mrte_p and p.requires_grad, net_g.parameters())
+ base_params = filter(
+ lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad,
+ net_g.parameters(),
+ )
# te_p=net_g.enc_p.text_embedding.parameters()
# et_p=net_g.enc_p.encoder_text.parameters()
@@ -106,31 +150,46 @@ def run(rank, n_gpus, hps):
optim_g = torch.optim.AdamW(
# filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致
[
- {"params":base_params,"lr":hps.train.learning_rate},
- {"params":net_g.enc_p.text_embedding.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate},
- {"params":net_g.enc_p.encoder_text.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate},
- {"params":net_g.enc_p.mrte.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate},
+ {"params": base_params, "lr": hps.train.learning_rate},
+ {
+ "params": net_g.enc_p.text_embedding.parameters(),
+ "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
+ },
+ {
+ "params": net_g.enc_p.encoder_text.parameters(),
+ "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
+ },
+ {
+ "params": net_g.enc_p.mrte.parameters(),
+ "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
+ },
],
hps.train.learning_rate,
betas=hps.train.betas,
- eps=hps.train.eps)
+ eps=hps.train.eps,
+ )
optim_d = torch.optim.AdamW(
net_d.parameters(),
hps.train.learning_rate,
betas=hps.train.betas,
- eps=hps.train.eps)
- net_g = DDP(net_g, device_ids=[rank],find_unused_parameters=True)
- net_d = DDP(net_d, device_ids=[rank],find_unused_parameters=True)
+ eps=hps.train.eps,
+ )
+ net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
+ net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
try: # 如果能加载自动resume
_, _, _, epoch_str = utils.load_checkpoint(
- utils.latest_checkpoint_path("%s/logs_s2"%hps.data.exp_dir, "D_*.pth"), net_d, optim_d
+ utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "D_*.pth"),
+ net_d,
+ optim_d,
) # D多半加载没事
if rank == 0:
logger.info("loaded D")
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint(
- utils.latest_checkpoint_path("%s/logs_s2"%hps.data.exp_dir, "G_*.pth"), net_g, optim_g
+ utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "G_*.pth"),
+ net_g,
+ optim_g,
)
global_step = (epoch_str - 1) * len(train_loader)
# epoch_str = 1
@@ -144,7 +203,8 @@ def run(rank, n_gpus, hps):
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
print(
net_g.module.load_state_dict(
- torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],strict=False
+ torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
+ strict=False,
)
) ##测试不加载优化器
if hps.train.pretrained_s2D != "":
@@ -159,8 +219,12 @@ def run(rank, n_gpus, hps):
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
- scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
- scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=-1)
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
+ optim_g, gamma=hps.train.lr_decay, last_epoch=-1
+ )
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
+ optim_d, gamma=hps.train.lr_decay, last_epoch=-1
+ )
for _ in range(epoch_str):
scheduler_g.step()
scheduler_d.step()
@@ -169,17 +233,39 @@ def run(rank, n_gpus, hps):
for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0:
- train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
- # [train_loader, eval_loader], logger, [writer, writer_eval])
- [train_loader, None], logger, [writer, writer_eval])
+ train_and_evaluate(
+ rank,
+ epoch,
+ hps,
+ [net_g, net_d],
+ [optim_g, optim_d],
+ [scheduler_g, scheduler_d],
+ scaler,
+ # [train_loader, eval_loader], logger, [writer, writer_eval])
+ [train_loader, None],
+ logger,
+ [writer, writer_eval],
+ )
else:
- train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
- [train_loader, None], None, None)
+ train_and_evaluate(
+ rank,
+ epoch,
+ hps,
+ [net_g, net_d],
+ [optim_g, optim_d],
+ [scheduler_g, scheduler_d],
+ scaler,
+ [train_loader, None],
+ None,
+ None,
+ )
scheduler_g.step()
scheduler_d.step()
-def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
+def train_and_evaluate(
+ rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
+):
net_g, net_d = nets
optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers
@@ -192,17 +278,39 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
net_g.train()
net_d.train()
- for batch_idx, (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths) in tqdm(enumerate(train_loader)):
- spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
- y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True)
+ for batch_idx, (
+ ssl,
+ ssl_lengths,
+ spec,
+ spec_lengths,
+ y,
+ y_lengths,
+ text,
+ text_lengths,
+ ) in tqdm(enumerate(train_loader)):
+ spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
+ rank, non_blocking=True
+ )
+ y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
+ rank, non_blocking=True
+ )
ssl = ssl.cuda(rank, non_blocking=True)
- ssl.requires_grad=False
+ ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
- text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(rank, non_blocking=True)
+ text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
+ rank, non_blocking=True
+ )
with autocast(enabled=hps.train.fp16_run):
- y_hat, kl_ssl, ids_slice, x_mask, z_mask, \
- (z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl = net_g(ssl, spec, spec_lengths, text, text_lengths)
+ (
+ y_hat,
+ kl_ssl,
+ ids_slice,
+ x_mask,
+ z_mask,
+ (z, z_p, m_p, logs_p, m_q, logs_q),
+ stats_ssl,
+ ) = net_g(ssl, spec, spec_lengths, text, text_lengths)
mel = spec_to_mel_torch(
spec,
@@ -210,8 +318,11 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
hps.data.n_mel_channels,
hps.data.sampling_rate,
hps.data.mel_fmin,
- hps.data.mel_fmax)
- y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
+ hps.data.mel_fmax,
+ )
+ y_mel = commons.slice_segments(
+ mel, ids_slice, hps.train.segment_size // hps.data.hop_length
+ )
y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1),
hps.data.filter_length,
@@ -220,15 +331,19 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
hps.data.hop_length,
hps.data.win_length,
hps.data.mel_fmin,
- hps.data.mel_fmax
+ hps.data.mel_fmax,
)
- y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
+ y = commons.slice_segments(
+ y, ids_slice * hps.data.hop_length, hps.train.segment_size
+ ) # slice
# Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
with autocast(enabled=False):
- loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
+ loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
+ y_d_hat_r, y_d_hat_g
+ )
loss_disc_all = loss_disc
optim_d.zero_grad()
scaler.scale(loss_disc_all).backward()
@@ -256,32 +371,54 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
if rank == 0:
if global_step % hps.train.log_interval == 0:
- lr = optim_g.param_groups[0]['lr']
+ lr = optim_g.param_groups[0]["lr"]
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
- logger.info('Train Epoch: {} [{:.0f}%]'.format(
- epoch,
- 100. * batch_idx / len(train_loader)))
+ logger.info(
+ "Train Epoch: {} [{:.0f}%]".format(
+ epoch, 100.0 * batch_idx / len(train_loader)
+ )
+ )
logger.info([x.item() for x in losses] + [global_step, lr])
- scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr,
- "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
+ scalar_dict = {
+ "loss/g/total": loss_gen_all,
+ "loss/d/total": loss_disc_all,
+ "learning_rate": lr,
+ "grad_norm_d": grad_norm_d,
+ "grad_norm_g": grad_norm_g,
+ }
scalar_dict.update(
- {"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl_ssl": kl_ssl, "loss/g/kl": loss_kl})
+ {
+ "loss/g/fm": loss_fm,
+ "loss/g/mel": loss_mel,
+ "loss/g/kl_ssl": kl_ssl,
+ "loss/g/kl": loss_kl,
+ }
+ )
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
image_dict = {
- "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
- "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
- "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
- "all/stats_ssl": utils.plot_spectrogram_to_numpy(stats_ssl[0].data.cpu().numpy()),
+ "slice/mel_org": utils.plot_spectrogram_to_numpy(
+ y_mel[0].data.cpu().numpy()
+ ),
+ "slice/mel_gen": utils.plot_spectrogram_to_numpy(
+ y_hat_mel[0].data.cpu().numpy()
+ ),
+ "all/mel": utils.plot_spectrogram_to_numpy(
+ mel[0].data.cpu().numpy()
+ ),
+ "all/stats_ssl": utils.plot_spectrogram_to_numpy(
+ stats_ssl[0].data.cpu().numpy()
+ ),
}
utils.summarize(
writer=writer,
global_step=global_step,
images=image_dict,
- scalars=scalar_dict)
+ scalars=scalar_dict,
+ )
global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if hps.train.if_save_latest == 0:
@@ -290,14 +427,18 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
optim_g,
hps.train.learning_rate,
epoch,
- os.path.join("%s/logs_s2"%hps.data.exp_dir, "G_{}.pth".format(global_step)),
+ os.path.join(
+ "%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(global_step)
+ ),
)
utils.save_checkpoint(
net_d,
optim_d,
hps.train.learning_rate,
epoch,
- os.path.join("%s/logs_s2"%hps.data.exp_dir, "D_{}.pth".format(global_step)),
+ os.path.join(
+ "%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(global_step)
+ ),
)
else:
utils.save_checkpoint(
@@ -305,14 +446,18 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
optim_g,
hps.train.learning_rate,
epoch,
- os.path.join("%s/logs_s2"%hps.data.exp_dir, "G_{}.pth".format(233333333333)),
+ os.path.join(
+ "%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(233333333333)
+ ),
)
utils.save_checkpoint(
net_d,
optim_d,
hps.train.learning_rate,
epoch,
- os.path.join("%s/logs_s2"%hps.data.exp_dir, "D_{}.pth".format(233333333333)),
+ os.path.join(
+ "%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(233333333333)
+ ),
)
if rank == 0 and hps.train.if_save_every_weights == True:
if hasattr(net_g, "module"):
@@ -334,11 +479,8 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
)
)
-
-
if rank == 0:
- logger.info('====> Epoch: {}'.format(epoch))
-
+ logger.info("====> Epoch: {}".format(epoch))
def evaluate(hps, generator, eval_loader, writer_eval):
@@ -347,15 +489,25 @@ def evaluate(hps, generator, eval_loader, writer_eval):
audio_dict = {}
print("Evaluating ...")
with torch.no_grad():
- for batch_idx, (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths) in enumerate(eval_loader):
+ for batch_idx, (
+ ssl,
+ ssl_lengths,
+ spec,
+ spec_lengths,
+ y,
+ y_lengths,
+ text,
+ text_lengths,
+ ) in enumerate(eval_loader):
print(111)
spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
y, y_lengths = y.cuda(), y_lengths.cuda()
ssl = ssl.cuda()
text, text_lengths = text.cuda(), text_lengths.cuda()
for test in [0, 1]:
-
- y_hat, mask, *_ = generator.module.infer(ssl,spec, spec_lengths,text, text_lengths, test=test)
+ y_hat, mask, *_ = generator.module.infer(
+ ssl, spec, spec_lengths, text, text_lengths, test=test
+ )
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
mel = spec_to_mel_torch(
@@ -364,7 +516,8 @@ def evaluate(hps, generator, eval_loader, writer_eval):
hps.data.n_mel_channels,
hps.data.sampling_rate,
hps.data.mel_fmin,
- hps.data.mel_fmax)
+ hps.data.mel_fmax,
+ )
y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1).float(),
hps.data.filter_length,
@@ -373,16 +526,26 @@ def evaluate(hps, generator, eval_loader, writer_eval):
hps.data.hop_length,
hps.data.win_length,
hps.data.mel_fmin,
- hps.data.mel_fmax
+ hps.data.mel_fmax,
)
- image_dict.update({
- f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
- })
- audio_dict.update({
- f"gen/audio_{batch_idx}_{test}": y_hat[0, :, :y_hat_lengths[0]]
- })
- image_dict.update({f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())})
- audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, :y_lengths[0]]})
+ image_dict.update(
+ {
+ f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
+ y_hat_mel[0].cpu().numpy()
+ )
+ }
+ )
+ audio_dict.update(
+ {f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
+ )
+ image_dict.update(
+ {
+ f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
+ mel[0].cpu().numpy()
+ )
+ }
+ )
+ audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
# y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None)
# audio_dict.update({
@@ -394,9 +557,10 @@ def evaluate(hps, generator, eval_loader, writer_eval):
global_step=global_step,
images=image_dict,
audios=audio_dict,
- audio_sampling_rate=hps.data.sampling_rate
+ audio_sampling_rate=hps.data.sampling_rate,
)
generator.train()
+
if __name__ == "__main__":
main()
diff --git a/GPT_SoVITS/utils.py b/GPT_SoVITS/utils.py
index e3ed89b..0ce03b3 100644
--- a/GPT_SoVITS/utils.py
+++ b/GPT_SoVITS/utils.py
@@ -12,8 +12,9 @@ import numpy as np
from scipy.io.wavfile import read
import torch
import logging
-logging.getLogger('numba').setLevel(logging.ERROR)
-logging.getLogger('matplotlib').setLevel(logging.ERROR)
+
+logging.getLogger("numba").setLevel(logging.ERROR)
+logging.getLogger("matplotlib").setLevel(logging.ERROR)
MATPLOTLIB_FLAG = False
@@ -23,13 +24,17 @@ logger = logging
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
assert os.path.isfile(checkpoint_path)
- checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
- iteration = checkpoint_dict['iteration']
- learning_rate = checkpoint_dict['learning_rate']
- if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
- optimizer.load_state_dict(checkpoint_dict['optimizer'])
- saved_state_dict = checkpoint_dict['model']
- if hasattr(model, 'module'):
+ checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
+ iteration = checkpoint_dict["iteration"]
+ learning_rate = checkpoint_dict["learning_rate"]
+ if (
+ optimizer is not None
+ and not skip_optimizer
+ and checkpoint_dict["optimizer"] is not None
+ ):
+ optimizer.load_state_dict(checkpoint_dict["optimizer"])
+ saved_state_dict = checkpoint_dict["model"]
+ if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
@@ -39,41 +44,63 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
# assert "quantizer" not in k
# print("load", k)
new_state_dict[k] = saved_state_dict[k]
- assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
+ assert saved_state_dict[k].shape == v.shape, (
+ saved_state_dict[k].shape,
+ v.shape,
+ )
except:
traceback.print_exc()
- print("error, %s is not in the checkpoint" % k)#shape不对也会,比如text_embedding当cleaner修改时
+ print(
+ "error, %s is not in the checkpoint" % k
+ ) # shape不对也会,比如text_embedding当cleaner修改时
new_state_dict[k] = v
- if hasattr(model, 'module'):
+ if hasattr(model, "module"):
model.module.load_state_dict(new_state_dict)
else:
model.load_state_dict(new_state_dict)
print("load ")
- logger.info("Loaded checkpoint '{}' (iteration {})".format(
- checkpoint_path, iteration))
+ logger.info(
+ "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
+ )
return model, optimizer, learning_rate, iteration
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
- logger.info("Saving model and optimizer state at iteration {} to {}".format(
- iteration, checkpoint_path))
- if hasattr(model, 'module'):
+ logger.info(
+ "Saving model and optimizer state at iteration {} to {}".format(
+ iteration, checkpoint_path
+ )
+ )
+ if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
- torch.save({'model': state_dict,
- 'iteration': iteration,
- 'optimizer': optimizer.state_dict(),
- 'learning_rate': learning_rate}, checkpoint_path)
+ torch.save(
+ {
+ "model": state_dict,
+ "iteration": iteration,
+ "optimizer": optimizer.state_dict(),
+ "learning_rate": learning_rate,
+ },
+ checkpoint_path,
+ )
-def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
+def summarize(
+ writer,
+ global_step,
+ scalars={},
+ histograms={},
+ images={},
+ audios={},
+ audio_sampling_rate=22050,
+):
for k, v in scalars.items():
writer.add_scalar(k, v, global_step)
for k, v in histograms.items():
writer.add_histogram(k, v, global_step)
for k, v in images.items():
- writer.add_image(k, v, global_step, dataformats='HWC')
+ writer.add_image(k, v, global_step, dataformats="HWC")
for k, v in audios.items():
writer.add_audio(k, v, global_step, audio_sampling_rate)
@@ -90,23 +117,23 @@ def plot_spectrogram_to_numpy(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
+
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
- mpl_logger = logging.getLogger('matplotlib')
+ mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2))
- im = ax.imshow(spectrogram, aspect="auto", origin="lower",
- interpolation='none')
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
@@ -116,26 +143,28 @@ def plot_alignment_to_numpy(alignment, info=None):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
+
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
- mpl_logger = logging.getLogger('matplotlib')
+ mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(6, 4))
- im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
- interpolation='none')
+ im = ax.imshow(
+ alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
+ )
fig.colorbar(im, ax=ax)
- xlabel = 'Decoder timestep'
+ xlabel = "Decoder timestep"
if info is not None:
- xlabel += '\n\n' + info
+ xlabel += "\n\n" + info
plt.xlabel(xlabel)
- plt.ylabel('Encoder timestep')
+ plt.ylabel("Encoder timestep")
plt.tight_layout()
fig.canvas.draw()
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
@@ -147,16 +176,31 @@ def load_wav_to_torch(full_path):
def load_filepaths_and_text(filename, split="|"):
- with open(filename, encoding='utf-8') as f:
+ with open(filename, encoding="utf-8") as f:
filepaths_and_text = [line.strip().split(split) for line in f]
return filepaths_and_text
def get_hparams(init=True, stage=1):
parser = argparse.ArgumentParser()
- parser.add_argument('-c', '--config', type=str, default="./configs/s2.json",help='JSON file for configuration')
- parser.add_argument('-p', '--pretrain', type=str, required=False,default=None,help='pretrain dir')
- parser.add_argument('-rs', '--resume_step', type=int, required=False,default=None,help='resume step')
+ parser.add_argument(
+ "-c",
+ "--config",
+ type=str,
+ default="./configs/s2.json",
+ help="JSON file for configuration",
+ )
+ parser.add_argument(
+ "-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir"
+ )
+ parser.add_argument(
+ "-rs",
+ "--resume_step",
+ type=int,
+ required=False,
+ default=None,
+ help="resume step",
+ )
# parser.add_argument('-e', '--exp_dir', type=str, required=False,default=None,help='experiment directory')
# parser.add_argument('-g', '--pretrained_s2G', type=str, required=False,default=None,help='pretrained sovits gererator weights')
# parser.add_argument('-d', '--pretrained_s2D', type=str, required=False,default=None,help='pretrained sovits discriminator weights')
@@ -172,7 +216,7 @@ def get_hparams(init=True, stage=1):
hparams.pretrain = args.pretrain
hparams.resume_step = args.resume_step
# hparams.data.exp_dir = args.exp_dir
- if stage ==1:
+ if stage == 1:
model_dir = hparams.s1_ckpt_dir
else:
model_dir = hparams.s2_ckpt_dir
@@ -186,29 +230,38 @@ def get_hparams(init=True, stage=1):
return hparams
-
-def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True):
+def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
"""Freeing up space by deleting saved ckpts
- Arguments:
- path_to_models -- Path to the model directory
- n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
- sort_by_time -- True -> chronologically delete ckpts
- False -> lexicographically delete ckpts
- """
+ Arguments:
+ path_to_models -- Path to the model directory
+ n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
+ sort_by_time -- True -> chronologically delete ckpts
+ False -> lexicographically delete ckpts
+ """
import re
- ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
- name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1)))
- time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
+
+ ckpts_files = [
+ f
+ for f in os.listdir(path_to_models)
+ if os.path.isfile(os.path.join(path_to_models, f))
+ ]
+ name_key = lambda _f: int(re.compile("._(\d+)\.pth").match(_f).group(1))
+ time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))
sort_key = time_key if sort_by_time else name_key
- x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')],
- key=sort_key)
- to_del = [os.path.join(path_to_models, fn) for fn in
- (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
+ x_sorted = lambda _x: sorted(
+ [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
+ key=sort_key,
+ )
+ to_del = [
+ os.path.join(path_to_models, fn)
+ for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
+ ]
del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
del_routine = lambda x: [os.remove(x), del_info(x)]
rs = [del_routine(fn) for fn in to_del]
+
def get_hparams_from_dir(model_dir):
config_save_path = os.path.join(model_dir, "config.json")
with open(config_save_path, "r") as f:
@@ -228,12 +281,15 @@ def get_hparams_from_file(config_path):
hparams = HParams(**config)
return hparams
+
def check_git_hash(model_dir):
source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")):
- logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
- source_dir
- ))
+ logger.warn(
+ "{} is not a git repository, therefore hash value comparison will be ignored.".format(
+ source_dir
+ )
+ )
return
cur_hash = subprocess.getoutput("git rev-parse HEAD")
@@ -242,8 +298,11 @@ def check_git_hash(model_dir):
if os.path.exists(path):
saved_hash = open(path).read()
if saved_hash != cur_hash:
- logger.warn("git hash values are different. {}(saved) != {}(current)".format(
- saved_hash[:8], cur_hash[:8]))
+ logger.warn(
+ "git hash values are different. {}(saved) != {}(current)".format(
+ saved_hash[:8], cur_hash[:8]
+ )
+ )
else:
open(path, "w").write(cur_hash)
@@ -263,7 +322,7 @@ def get_logger(model_dir, filename="train.log"):
return logger
-class HParams():
+class HParams:
def __init__(self, **kwargs):
for k, v in kwargs.items():
if type(v) == dict:
@@ -294,5 +353,10 @@ class HParams():
def __repr__(self):
return self.__dict__.__repr__()
-if __name__ == '__main__':
- print(load_wav_to_torch('/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac'))
\ No newline at end of file
+
+if __name__ == "__main__":
+ print(
+ load_wav_to_torch(
+ "/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac"
+ )
+ )