GPT-SoVITS/GPT_SoVITS/s1_train.py
huangxu1991 4f8e1660af
Add use_distributed_sampler=False in Trainer (#756)
If you define your own sampler, you must set use_distributed_sampler to False.
当使用自定义的 sampler 时,必须设置 use_distributed_sampler 为 False
2024-07-19 10:33:24 +08:00

184 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
import os
import pdb
# If the environment provides "_CUDA_VISIBLE_DEVICES" (presumably set by the
# launching process — TODO confirm), copy it into CUDA_VISIBLE_DEVICES.
# NOTE(review): this runs ahead of `import torch` below, presumably so the
# device mask is in place before CUDA is initialised — keep it above that import.
if "_CUDA_VISIBLE_DEVICES" in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
import argparse
import logging
from pathlib import Path
import torch, platform
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
# Only let WARNING-and-above records through from the numba and matplotlib loggers.
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
# Select PyTorch's "high" float32 matmul precision mode (per the PyTorch docs,
# this allows faster reduced-precision internals for float32 matmuls).
torch.set_float32_matmul_precision("high")
from AR.utils import get_newest_ckpt
from collections import OrderedDict
from time import time as ttime
import shutil
def my_save(fea, path):
    """Serialize ``fea`` with ``torch.save`` and place the result at ``path``.

    Works around ``torch.save`` not supporting Chinese paths (per the original
    note on this function): the object is first written to a timestamp-named
    ``.pth`` file in the current working directory and then moved to the
    requested destination.

    Args:
        fea: Object to serialize.
        path: Destination file path. A bare file name (no directory part) is
            placed in the current working directory.

    Raises:
        Whatever ``torch.save`` or ``shutil.move`` raise; the temporary file
        is removed before the exception propagates.
    """
    target_dir = os.path.dirname(path)
    name = os.path.basename(path)
    tmp_path = "%s.pth" % (ttime())
    try:
        torch.save(fea, tmp_path)
        # os.path.join drops an empty directory component, so a bare file name
        # is no longer turned into "/<name>" (the filesystem root).
        shutil.move(tmp_path, os.path.join(target_dir, name))
    finally:
        # After a successful move the temp file is gone; this only fires when
        # saving or moving failed part-way, so no orphan is left behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
class my_model_ckpt(ModelCheckpoint):
    """``ModelCheckpoint`` subclass with "keep only latest" and half-weight export.

    On top of the inherited checkpointing, at every ``every_n_epochs``-th
    epoch end it can:

    * delete the files that were in ``dirpath`` before the new checkpoint was
      written (``if_save_latest``), and
    * write a half-precision copy of the module's ``state_dict`` plus the
      training config to ``half_weights_save_dir`` (``if_save_every_weights``).
    """

    def __init__(
        self,
        config,
        if_save_latest,
        if_save_every_weights,
        half_weights_save_dir,
        exp_name,
        **kwargs
    ):
        """Store the export options; remaining kwargs go to ``ModelCheckpoint``.

        Args:
            config: Training configuration, stored under ``"config"`` in each
                exported half-precision weight file.
            if_save_latest: When ``True``, files present in ``dirpath`` before
                a new checkpoint is saved are deleted afterwards.
            if_save_every_weights: When ``True``, a half-precision weight file
                is exported at each saving epoch.
            half_weights_save_dir: Directory receiving the exported weights.
            exp_name: Prefix of the exported file name
                (``<exp_name>-e<epoch>.ckpt``).
            **kwargs: Forwarded unchanged to ``ModelCheckpoint.__init__``.
        """
        super().__init__(**kwargs)
        self.if_save_latest = if_save_latest
        self.if_save_every_weights = if_save_every_weights
        self.half_weights_save_dir = half_weights_save_dir
        self.exp_name = exp_name
        self.config = config

    def on_train_epoch_end(self, trainer, pl_module):
        """Save checkpoints / exported weights at the end of a training epoch.

        Args:
            trainer: The running trainer; supplies ``current_epoch`` and the
                module whose weights are exported.
            pl_module: Unused here; part of the callback signature.
        """
        # if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
        if self._should_save_on_train_epoch_end(trainer):
            monitor_candidates = self._monitor_candidates(trainer)
            epoch = trainer.current_epoch + 1
            if self._every_n_epochs >= 1 and epoch % self._every_n_epochs == 0:
                # NOTE: the `== True` comparisons are kept deliberately so that
                # non-bool config values behave exactly as before.
                # When only the latest checkpoint should be kept, remember what
                # exists now so it can be removed once the new one is written.
                to_clean = []
                if self.if_save_latest == True:  # noqa: E712
                    to_clean = os.listdir(self.dirpath)
                self._save_topk_checkpoint(trainer, monitor_candidates)
                for name in to_clean:
                    stale = os.path.join(self.dirpath, name)
                    try:
                        os.remove(stale)
                    except FileNotFoundError:
                        # Already gone (e.g. removed by another process running
                        # the same cleanup); nothing left to do.
                        pass
                    except OSError as err:
                        # Cleanup stays best-effort, but no longer hides
                        # KeyboardInterrupt/SystemExit like a bare except did.
                        logging.warning("could not remove %s: %s", stale, err)
                # Only the process with LOCAL_RANK "0" (or unset) writes the
                # export, so the half-precision copy is built there only.
                if (
                    self.if_save_every_weights == True  # noqa: E712
                    and os.environ.get("LOCAL_RANK", "0") == "0"
                ):
                    state = trainer.strategy._lightning_module.state_dict()
                    to_save_od = OrderedDict()
                    to_save_od["weight"] = OrderedDict(
                        (key, tensor.half()) for key, tensor in state.items()
                    )
                    to_save_od["config"] = self.config
                    to_save_od["info"] = "GPT-e%s" % epoch
                    my_save(
                        to_save_od,
                        "%s/%s-e%s.ckpt"
                        % (
                            self.half_weights_save_dir,
                            self.exp_name,
                            epoch,
                        ),
                    )
            self._save_last_checkpoint(trainer, monitor_candidates)
def main(args):
    """Run Text2Semantic training as described by the YAML config file.

    Loads ``args.config_file``, creates ``<output_dir>/ckpt``, builds the
    checkpoint callback, TensorBoard logger, trainer, model and data module,
    then calls ``trainer.fit``, resuming from the checkpoint chosen by
    ``get_newest_ckpt`` when one can be determined.

    Args:
        args: Parsed CLI namespace; only ``config_file`` is read.
    """
    config = load_yaml_config(args.config_file)

    # Output layout: <output_dir>/ckpt holds the Lightning checkpoints.
    output_dir = Path(config["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir = output_dir / "ckpt"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    train_cfg = config["train"]
    seed_everything(train_cfg["seed"], workers=True)

    ckpt_callback: ModelCheckpoint = my_model_ckpt(
        config=config,
        if_save_latest=train_cfg["if_save_latest"],
        if_save_every_weights=train_cfg["if_save_every_weights"],
        half_weights_save_dir=train_cfg["half_weights_save_dir"],
        exp_name=train_cfg["exp_name"],
        save_top_k=-1,
        monitor="top_3_acc",
        mode="max",
        save_on_train_epoch_end=True,
        every_n_epochs=train_cfg["save_every_n_epoch"],
        dirpath=ckpt_dir,
    )
    logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
    os.environ["MASTER_ADDR"] = "localhost"

    has_cuda = torch.cuda.is_available()
    if has_cuda:
        # NCCL everywhere except Windows, where gloo is used instead.
        backend = "gloo" if platform.system() == "Windows" else "nccl"
        strategy = DDPStrategy(process_group_backend=backend)
    else:
        strategy = "auto"

    trainer: Trainer = Trainer(
        max_epochs=train_cfg["epochs"],
        accelerator="gpu" if has_cuda else "cpu",
        # Validation is disabled: no validation batches, no sanity-check steps.
        limit_val_batches=0,
        devices=-1 if has_cuda else 1,
        benchmark=False,
        fast_dev_run=False,
        strategy=strategy,
        precision=train_cfg["precision"],
        logger=logger,
        num_sanity_val_steps=0,
        callbacks=[ckpt_callback],
        # A very small change, but it fixes inconsistent training step counts
        # when the custom bucket_sampler is used.
        use_distributed_sampler=False,
    )

    model: Text2SemanticLightningModule = Text2SemanticLightningModule(
        config, output_dir
    )
    data_module: Text2SemanticDataModule = Text2SemanticDataModule(
        config,
        train_semantic_path=config["train_semantic_path"],
        train_phoneme_path=config["train_phoneme_path"],
    )

    # Resume from the checkpoint picked by get_newest_ckpt (per the original
    # note: numeric part of the file names, sorted by value); any failure
    # (e.g. nothing to resume from) means training starts from scratch.
    try:
        ckpt_path = ckpt_dir / get_newest_ckpt(os.listdir(ckpt_dir))
    except Exception:
        ckpt_path = None
    print("ckpt_path:", ckpt_path)

    trainer.fit(model, data_module, ckpt_path=ckpt_path)
# Example (SLURM): srun --gpus-per-node=1 --ntasks-per-node=1 python s1_train.py --config_file configs/s1longer.yaml
if __name__ == "__main__":
    # Command-line entry point: the only option is the YAML config path.
    # Dataset and output locations are read from that config inside main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-c",
        "--config_file",
        default="configs/s1longer.yaml",
        type=str,
        help="path of config file",
    )
    cli_args = cli.parse_args()
    logging.info(str(cli_args))
    main(cli_args)