support sovits v3 lora training, 8G GPU memory is enough

support sovits v3 lora training, 8G GPU memory is enough
This commit is contained in:
RVC-Boss 2025-02-23 00:37:14 +08:00 committed by GitHub
parent 56509a17c9
commit e937b625e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 488 additions and 55 deletions

View File

@ -28,10 +28,13 @@ try:
analytics.version_check = lambda:None analytics.version_check = lambda:None
except:... except:...
version=model_version=os.environ.get("version","v2") version=model_version=os.environ.get("version","v2")
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth","GPT_SoVITS/pretrained_models/s2Gv3.pth"] path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
is_exist_s2gv3=os.path.exists(path_sovits_v3)
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",path_sovits_v3]
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"] pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"]
_ =[[],[]] _ =[[],[]]
for i in range(3): for i in range(3):
if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i]) if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i])
@ -73,6 +76,7 @@ is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ: if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
# is_half=False
punctuation = set(['!', '?', '', ',', '.', '-'," "]) punctuation = set(['!', '?', '', ',', '.', '-'," "])
import gradio as gr import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer from transformers import AutoModelForMaskedLM, AutoTokenizer
@ -83,13 +87,26 @@ from feature_extractor import cnhubert
cnhubert.cnhubert_base_path = cnhubert_base_path cnhubert.cnhubert_base_path = cnhubert_base_path
from GPT_SoVITS.module.models import SynthesizerTrn,SynthesizerTrnV3 from GPT_SoVITS.module.models import SynthesizerTrn,SynthesizerTrnV3
import numpy as np
import random
def set_seed(seed):
if seed == -1:
seed = random.randint(0, 1000000)
seed = int(seed)
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# set_seed(42)
from AR.models.t2s_lightning_module import Text2SemanticLightningModule from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence from text import cleaned_text_to_sequence
from text.cleaner import clean_text from text.cleaner import clean_text
from time import time as ttime from time import time as ttime
from module.mel_processing import spectrogram_torch
from tools.my_utils import load_audio from tools.my_utils import load_audio
from tools.i18n.i18n import I18nAuto, scan_language_list from tools.i18n.i18n import I18nAuto, scan_language_list
from peft import LoraConfig, PeftModel, get_peft_model
language=os.environ.get("language","Auto") language=os.environ.get("language","Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
@ -192,38 +209,17 @@ def resample(audio_tensor, sr0):
).to(device) ).to(device)
return resample_transform_dict[sr0](audio_tensor) return resample_transform_dict[sr0](audio_tensor)
###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt
#symbol_version-model_version-if_lora_v3
from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None): def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
global vq_model, hps, version, model_version, dict_language global vq_model, hps, version, model_version, dict_language,if_lora_v3
''' version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
v1:about 82942KB # print(sovits_path,version, model_version, if_lora_v3)
half thr:82978KB if if_lora_v3==True and is_exist_s2gv3==False:
v2:about 83014KB info=i18n("GPT_SoVITS/pretrained_models/s2Gv3.pth v3sovits的底模没下载对识别为v3sovits的lora没法加载")
half thr:100MB gr.Warning(info)
v1base:103490KB raise FileExistsError(info)
half thr:103520KB
v2base:103551KB
v3:about 750MB
~82978K~100M~103420~700M
v1-v2-v1base-v2base-v3
version:
symbols version and timebre_embedding version
model_version:
sovits is v1/2 (VITS) or v3 (shortcut CFM DiT)
'''
size=os.path.getsize(sovits_path)
if size<82978*1024:
model_version=version="v1"
elif size<100*1024*1024:
model_version=version="v2"
elif size<103520*1024:
model_version=version="v1"
elif size<700*1024*1024:
model_version = version = "v2"
else:
version = "v2"
model_version="v3"
dict_language = dict_language_v1 if version =='v1' else dict_language_v2 dict_language = dict_language_v1 if version =='v1' else dict_language_v2
if prompt_language is not None and text_language is not None: if prompt_language is not None and text_language is not None:
if prompt_language in list(dict_language.keys()): if prompt_language in list(dict_language.keys()):
@ -244,11 +240,13 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
visible_inp_refs=True visible_inp_refs=True
yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False} yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False}
dict_s2 = torch.load(sovits_path, map_location="cpu", weights_only=False) dict_s2 = load_sovits_new(sovits_path)
hps = dict_s2["config"] hps = dict_s2["config"]
hps = DictToAttrRecursive(hps) hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz" hps.model.semantic_frame_rate = "25hz"
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: if 'enc_p.text_embedding.weight'not in dict_s2['weight']:
hps.model.version = "v2"#v3model,v2sybomls
elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
hps.model.version = "v1" hps.model.version = "v1"
else: else:
hps.model.version = "v2" hps.model.version = "v2"
@ -278,7 +276,23 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
else: else:
vq_model = vq_model.to(device) vq_model = vq_model.to(device)
vq_model.eval() vq_model.eval()
if if_lora_v3==False:
print("loading sovits_%s"%model_version,vq_model.load_state_dict(dict_s2["weight"], strict=False)) print("loading sovits_%s"%model_version,vq_model.load_state_dict(dict_s2["weight"], strict=False))
else:
print("loading sovits_v3pretrained_G", vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False))
lora_rank=dict_s2["lora_rank"]
lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
r=lora_rank,
lora_alpha=lora_rank,
init_lora_weights=True,
)
vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
print("loading sovits_v3_lora%s"%(lora_rank),vq_model.load_state_dict(dict_s2["weight"], strict=False))
vq_model.cfm = vq_model.cfm.merge_and_unload()
# torch.save(vq_model.state_dict(),"merge_win.pth")
vq_model.eval()
with open("./weight.json")as f: with open("./weight.json")as f:
data=f.read() data=f.read()
data=json.loads(data) data=json.loads(data)
@ -333,7 +347,8 @@ else:init_bigvgan()
def get_spepc(hps, filename): def get_spepc(hps, filename):
audio = load_audio(filename, int(hps.data.sampling_rate)) # audio = load_audio(filename, int(hps.data.sampling_rate))
audio, sampling_rate = librosa.load(filename, sr=int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio) audio = torch.FloatTensor(audio)
maxx=audio.abs().max() maxx=audio.abs().max()
if(maxx>1):audio/=min(2,maxx) if(maxx>1):audio/=min(2,maxx)
@ -443,11 +458,7 @@ def get_phones_and_bert(text,language,version,final=False):
return phones,bert.to(dtype),norm_text return phones,bert.to(dtype),norm_text
from module.mel_processing import spectrogram_torch,spec_to_mel_torch from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
spec=spectrogram_torch(y,n_fft,sampling_rate,hop_size,win_size,center)
mel=spec_to_mel_torch(spec,n_fft,num_mels,sampling_rate,fmin,fmax)
return mel
mel_fn_args = { mel_fn_args = {
"n_fft": 1024, "n_fft": 1024,
"win_size": 1024, "win_size": 1024,
@ -465,7 +476,7 @@ def norm_spec(x):
return (x - spec_min) / (spec_max - spec_min) * 2 - 1 return (x - spec_min) / (spec_max - spec_min) * 2 - 1
def denorm_spec(x): def denorm_spec(x):
return (x + 1) / 2 * (spec_max - spec_min) + spec_min return (x + 1) / 2 * (spec_max - spec_min) + spec_min
mel_fn=lambda x: mel_spectrogram(x, **mel_fn_args) mel_fn=lambda x: mel_spectrogram_torch(x, **mel_fn_args)
def merge_short_text_in_array(texts, threshold): def merge_short_text_in_array(texts, threshold):
@ -617,6 +628,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)#######这里要重采样切到32k,因为src是24k的没有单独的32k的src所以不能改成2个路径 refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)#######这里要重采样切到32k,因为src是24k的没有单独的32k的src所以不能改成2个路径
phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0) phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0)
phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0) phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0)
# print(11111111, phoneme_ids0, phoneme_ids1)
fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer) fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
ref_audio, sr = torchaudio.load(ref_wav_path) ref_audio, sr = torchaudio.load(ref_wav_path)
ref_audio=ref_audio.to(device).float() ref_audio=ref_audio.to(device).float()
@ -624,7 +636,8 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
ref_audio = ref_audio.mean(0).unsqueeze(0) ref_audio = ref_audio.mean(0).unsqueeze(0)
if sr!=24000: if sr!=24000:
ref_audio=resample(ref_audio,sr) ref_audio=resample(ref_audio,sr)
mel2 = mel_fn(ref_audio.to(dtype)) # print("ref_audio",ref_audio.abs().mean())
mel2 = mel_fn(ref_audio)
mel2 = norm_spec(mel2) mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2]) T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min] mel2 = mel2[:, :, :T_min]
@ -634,7 +647,12 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
fea_ref = fea_ref[:, :, -468:] fea_ref = fea_ref[:, :, -468:]
T_min = 468 T_min = 468
chunk_len = 934 - T_min chunk_len = 934 - T_min
# print("fea_ref",fea_ref,fea_ref.shape)
# print("mel2",mel2)
mel2=mel2.to(dtype)
fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge) fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge)
# print("fea_todo",fea_todo)
# print("ge",ge.abs().mean())
cfm_resss = [] cfm_resss = []
idx = 0 idx = 0
while (1): while (1):
@ -642,9 +660,12 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
if (fea_todo_chunk.shape[-1] == 0): break if (fea_todo_chunk.shape[-1] == 0): break
idx += chunk_len idx += chunk_len
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
# set_seed(123)
cfm_res = vq_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0) cfm_res = vq_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
cfm_res = cfm_res[:, :, mel2.shape[2]:] cfm_res = cfm_res[:, :, mel2.shape[2]:]
mel2 = cfm_res[:, :, -T_min:] mel2 = cfm_res[:, :, -T_min:]
# print("fea", fea)
# print("mel2in", mel2)
fea_ref = fea_todo_chunk[:, :, -T_min:] fea_ref = fea_todo_chunk[:, :, -T_min:]
cfm_resss.append(cfm_res) cfm_resss.append(cfm_res)
cmf_res = torch.cat(cfm_resss, 2) cmf_res = torch.cat(cfm_resss, 2)

View File

@ -14,7 +14,24 @@ def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
torch.save(fea,tmp_path) torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name)) shutil.move(tmp_path,"%s/%s"%(dir,name))
def savee(ckpt, name, epoch, steps, hps): '''
00:v1
01:v2
02:v3
03:v3lora
'''
from io import BytesIO
def my_save2(fea,path):
bio = BytesIO()
torch.save(fea, bio)
bio.seek(0)
data = bio.getvalue()
data = b'03' + data[2:]###temp for v3lora only, todo
with open(path, "wb") as f: f.write(data)
def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
try: try:
opt = OrderedDict() opt = OrderedDict()
opt["weight"] = {} opt["weight"] = {}
@ -24,8 +41,66 @@ def savee(ckpt, name, epoch, steps, hps):
opt["weight"][key] = ckpt[key].half() opt["weight"][key] = ckpt[key].half()
opt["config"] = hps opt["config"] = hps
opt["info"] = "%sepoch_%siteration" % (epoch, steps) opt["info"] = "%sepoch_%siteration" % (epoch, steps)
# torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name)) if lora_rank:
opt["lora_rank"]=lora_rank
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
else:
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name)) my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
return "Success." return "Success."
except: except:
return traceback.format_exc() return traceback.format_exc()
head2version={
b'00':["v1","v1",False],
b'01':["v2","v2",False],
b'02':["v2","v3",False],
b'03':["v2","v3",True],
}
hash_pretrained_dict={
"dc3c97e17592963677a4a1681f30c653":["v2","v2",False],#s2G488k.pth#sovits_v1_pretrained
"43797be674a37c1c83ee81081941ed0f":["v2","v3",False],#s2Gv3.pth#sovits_v3_pretrained
"6642b37f3dbb1f76882b69937c95a5f3":["v2","v2",False],#s2G2333K.pth#sovits_v2_pretrained
}
import hashlib
def get_hash_from_file(sovits_path):
with open(sovits_path,"rb")as f:data=f.read(8192)
hash_md5 = hashlib.md5()
hash_md5.update(data)
return hash_md5.hexdigest()
def get_sovits_version_from_path_fast(sovits_path):
###1-if it is pretrained sovits models, by hash
hash=get_hash_from_file(sovits_path)
if hash in hash_pretrained_dict:
return hash_pretrained_dict[hash]
###2-new weights or old weights, by head
with open(sovits_path,"rb")as f:version=f.read(2)
if version!=b"PK":
return head2version[version]
###3-old weights, by file size
if_lora_v3=False
size=os.path.getsize(sovits_path)
'''
v1weights:about 82942KB
half thr:82978KB
v2weights:about 83014KB
v3weights:about 750MB
'''
if size < 82978 * 1024:
model_version = version = "v1"
elif size < 700 * 1024 * 1024:
model_version = version = "v2"
else:
version = "v2"
model_version = "v3"
return version,model_version,if_lora_v3
def load_sovits_new(sovits_path):
f=open(sovits_path,"rb")
meta=f.read(2)
if meta!="PK":
data = b'PK' + f.read()
bio = BytesIO()
bio.write(data)
bio.seek(0)
return torch.load(bio, map_location="cpu", weights_only=False)
return torch.load(sovits_path,map_location="cpu", weights_only=False)

View File

@ -26,12 +26,7 @@ from AR.utils import get_newest_ckpt
from collections import OrderedDict from collections import OrderedDict
from time import time as ttime from time import time as ttime
import shutil import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path from process_ckpt import my_save
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s.pth"%(ttime())
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
class my_model_ckpt(ModelCheckpoint): class my_model_ckpt(ModelCheckpoint):

View File

@ -0,0 +1,342 @@
import warnings
warnings.filterwarnings("ignore")
import utils, os
hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist, traceback
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import logging, traceback
logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
from random import randint
from module import commons
from peft import LoraConfig, PeftModel, get_peft_model
from module.data_utils import (
TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
DistributedBucketSampler,
)
from module.models import (
SynthesizerTrnV3 as SynthesizerTrn,
MultiPeriodDiscriminator,
)
from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from process_ckpt import savee
from collections import OrderedDict as od
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
###反正A100fp32更快那试试tf32吧
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响
# from config import pretrained_s2G,pretrained_s2D
global_step = 0
device = "cpu" # cuda以外的设备等mps优化后加入
def main():
if torch.cuda.is_available():
n_gpus = torch.cuda.device_count()
else:
n_gpus = 1
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(randint(20000, 55555))
mp.spawn(
run,
nprocs=n_gpus,
args=(
n_gpus,
hps,
),
)
def run(rank, n_gpus, hps):
global global_step,no_grad_names,save_root,lora_rank
if rank == 0:
logger = utils.get_logger(hps.data.exp_dir)
logger.info(hps)
# utils.check_git_hash(hps.s2_ckpt_dir)
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
dist.init_process_group(
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False",
world_size=n_gpus,
rank=rank,
)
torch.manual_seed(hps.train.seed)
if torch.cuda.is_available():
torch.cuda.set_device(rank)
train_dataset = TextAudioSpeakerLoader(hps.data) ########
train_sampler = DistributedBucketSampler(
train_dataset,
hps.train.batch_size,
[
32,
300,
400,
500,
600,
700,
800,
900,
1000,
# 1100,
# 1200,
# 1300,
# 1400,
# 1500,
# 1600,
# 1700,
# 1800,
# 1900,
],
num_replicas=n_gpus,
rank=rank,
shuffle=True,
)
collate_fn = TextAudioSpeakerCollate()
train_loader = DataLoader(
train_dataset,
num_workers=6,
shuffle=False,
pin_memory=True,
collate_fn=collate_fn,
batch_sampler=train_sampler,
persistent_workers=True,
prefetch_factor=4,
)
save_root="%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir,hps.model.version,hps.train.lora_rank)
os.makedirs(save_root,exist_ok=True)
lora_rank=int(hps.train.lora_rank)
lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
r=lora_rank,
lora_alpha=lora_rank,
init_lora_weights=True,
)
def get_model(hps):return SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
)
def get_optim(net_g):
return torch.optim.AdamW(
filter(lambda p: p.requires_grad, net_g.parameters()), ###默认所有层lr一致
hps.train.learning_rate,
betas=hps.train.betas,
eps=hps.train.eps,
)
def model2cuda(net_g,rank):
if torch.cuda.is_available():
net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
else:
net_g = net_g.to(device)
return net_g
try:# 如果能加载自动resume
net_g = get_model(hps)
net_g.cfm = get_peft_model(net_g.cfm, lora_config)
net_g=model2cuda(net_g,rank)
optim_g=get_optim(net_g)
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path(save_root, "G_*.pth"),
net_g,
optim_g,
)
global_step = (epoch_str - 1) * len(train_loader)
except: # 如果首次不能加载加载pretrain
# traceback.print_exc()
epoch_str = 1
global_step = 0
net_g = get_model(hps)
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
print(
net_g.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
)
)
net_g.cfm = get_peft_model(net_g.cfm, lora_config)
net_g=model2cuda(net_g,rank)
optim_g = get_optim(net_g)
no_grad_names=set()
for name, param in net_g.named_parameters():
if not param.requires_grad:
no_grad_names.add(name.replace("module.",""))
# print(name, "not requires_grad")
# print(no_grad_names)
# os._exit(233333)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
)
for _ in range(epoch_str):
scheduler_g.step()
scaler = GradScaler(enabled=hps.train.fp16_run)
net_d=optim_d=scheduler_d=None
for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0:
train_and_evaluate(
rank,
epoch,
hps,
[net_g, net_d],
[optim_g, optim_d],
[scheduler_g, scheduler_d],
scaler,
# [train_loader, eval_loader], logger, [writer, writer_eval])
[train_loader, None],
logger,
[writer, writer_eval],
)
else:
train_and_evaluate(
rank,
epoch,
hps,
[net_g, net_d],
[optim_g, optim_d],
[scheduler_g, scheduler_d],
scaler,
[train_loader, None],
None,
None,
)
scheduler_g.step()
def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
net_g, net_d = nets
optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers
train_loader, eval_loader = loaders
if writers is not None:
writer, writer_eval = writers
train_loader.batch_sampler.set_epoch(epoch)
global global_step
net_g.train()
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)):
if torch.cuda.is_available():
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
rank, non_blocking=True
)
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(
rank, non_blocking=True
)
ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
rank, non_blocking=True
)
else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
mel, mel_lengths = mel.to(device), mel_lengths.to(device)
ssl = ssl.to(device)
ssl.requires_grad = False
text, text_lengths = text.to(device), text_lengths.to(device)
with autocast(enabled=hps.train.fp16_run):
cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
loss_gen_all=cfm_loss
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
scaler.step(optim_g)
scaler.update()
if rank == 0:
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]['lr']
losses = [cfm_loss]
logger.info('Train Epoch: {} [{:.0f}%]'.format(
epoch,
100. * batch_idx / len(train_loader)))
logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
utils.summarize(
writer=writer,
global_step=global_step,
scalars=scalar_dict)
global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if hps.train.if_save_latest == 0:
utils.save_checkpoint(
net_g,
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(
save_root, "G_{}.pth".format(global_step)
),
)
else:
utils.save_checkpoint(
net_g,
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(
save_root, "G_{}.pth".format(233333333333)
),
)
if rank == 0 and hps.train.if_save_every_weights == True:
if hasattr(net_g, "module"):
ckpt = net_g.module.state_dict()
else:
ckpt = net_g.state_dict()
sim_ckpt=od()
for key in ckpt:
# if "cfm"not in key:
# print(key)
if key not in no_grad_names:
sim_ckpt[key]=ckpt[key].half().cpu()
logger.info(
"saving ckpt %s_e%s:%s"
% (
hps.name,
epoch,
savee(
sim_ckpt,
hps.name + "_e%s_s%s_l%s" % (epoch, global_step,lora_rank),
epoch,
global_step,
hps,lora_rank=lora_rank
),
)
)
if rank == 0:
logger.info("====> Epoch: {}".format(epoch))
if __name__ == "__main__":
main()