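"""
Stage-2 (S2) training entry point for the V3 synthesizer of GPT-SoVITS.

The script spawns one worker process per visible CUDA device (falling back to a
single CPU/MPS process), builds `SynthesizerTrnV3` together with a bucketed
distributed data loader, and optimizes the model's CFM loss with AdamW under
optional fp16 autocast. Progress is shown with rich progress bars, logs go to
loguru/TensorBoard, and checkpoints plus exportable weights are written
periodically.
"""
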
import logging
import os
import platform
import sys
import warnings
from contextlib import nullcontext
from random import randint

import torch
import torch.distributed as dist
from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from torch.multiprocessing.spawn import spawn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import GPT_SoVITS.utils as utils
from GPT_SoVITS.Accelerate import console, logger
from GPT_SoVITS.Accelerate.logger import SpeedColumnIteration
from GPT_SoVITS.module import commons
from GPT_SoVITS.module.data_utils import (
    DistributedBucketSampler,
    TextAudioSpeakerCollateV3,
    TextAudioSpeakerLoaderV3,
)
from GPT_SoVITS.module.models import SynthesizerTrnV3
from GPT_SoVITS.process_ckpt import save_ckpt

hps = utils.get_hparams(stage=2)

warnings.filterwarnings("ignore")
logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
torch.backends.cuda.matmul.allow_tf32 = True  ### fp32 is already fast on an A100 anyway, so try tf32
torch.backends.cudnn.allow_tf32 = True
torch.set_grad_enabled(True)
torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally so); it does not affect the results

global_step = 0

if torch.cuda.is_available():
    device_str = "cuda"
elif torch.mps.is_available():
    device_str = "mps"
else:
    device_str = "cpu"


multigpu = torch.cuda.device_count() > 1 if torch.cuda.is_available() else False


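# Entry point for torch.multiprocessing: one training process is launched per CUDA
# device (or a single process on CPU/MPS), with rendezvous on localhost and a random port.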
def main():
    if torch.cuda.is_available():
        n_gpus = torch.cuda.device_count()
    else:
        n_gpus = 1
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))
    if platform.system() == "Windows":
        os.environ["USE_LIBUV"] = "0"

    spawn(
        run,
        nprocs=n_gpus,
        args=(
            n_gpus,
            hps,
        ),
    )


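# Per-process worker: sets up logging/TensorBoard on rank 0, initializes the process
# group when several GPUs are used, builds the data pipeline and the model, restores
# or initializes the weights, then loops over epochs.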
def run(rank, n_gpus, hps):
    global global_step
    device = torch.device(f"{device_str}:{rank}")

    if rank == 0:
        logger.add(
            os.path.join(hps.data.exp_dir, "train.log"),
            level="INFO",
            enqueue=True,
            backtrace=True,
            diagnose=True,
            format="{time:YY-MM-DD HH:mm:ss}\t{name}\t{level}\t{message}",
        )
        console.print(hps.to_dict())
        writer: SummaryWriter | None = SummaryWriter(log_dir=hps.s2_ckpt_dir)
        writer_eval: SummaryWriter | None = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
    else:
        writer = writer_eval = None

    if multigpu:
        dist.init_process_group(
            backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
            init_method="env://",
            world_size=n_gpus,
            rank=rank,
        )

    torch.manual_seed(hps.train.seed)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

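    # The bucket sampler groups utterances of similar length (the boundaries are frame
    # counts) so that batches need little padding; buckets are re-shuffled every epoch.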
train_dataset = TextAudioSpeakerLoaderV3(hps.data)
|
||
train_sampler = DistributedBucketSampler(
|
||
train_dataset,
|
||
hps.train.batch_size,
|
||
[
|
||
32,
|
||
300,
|
||
400,
|
||
500,
|
||
600,
|
||
700,
|
||
800,
|
||
900,
|
||
1000,
|
||
],
|
||
num_replicas=n_gpus,
|
||
rank=rank,
|
||
shuffle=True,
|
||
)
|
||
collate_fn = TextAudioSpeakerCollateV3()
|
||
train_loader = DataLoader(
|
||
train_dataset,
|
||
num_workers=6,
|
||
shuffle=False,
|
||
pin_memory=True,
|
||
collate_fn=collate_fn,
|
||
batch_sampler=train_sampler,
|
||
persistent_workers=True,
|
||
prefetch_factor=4,
|
||
)
|
||
# if rank == 0:
|
||
# eval_dataset = TextAudioSpeakerLoaderV3(hps.data.validation_files, hps.data, val=True)
|
||
# eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
|
||
# batch_size=1, pin_memory=True,
|
||
# drop_last=False, collate_fn=collate_fn)
|
||
|
||
net_g = SynthesizerTrnV3(
|
||
hps.data.filter_length // 2 + 1,
|
||
hps.train.segment_size // hps.data.hop_length,
|
||
n_speakers=hps.data.n_speakers,
|
||
**hps.model,
|
||
).to(device)
|
||
|
||
    optim_g = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, net_g.parameters()),  # by default all layers share the same lr
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )

    if multigpu:
        net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)  # type: ignore
    else:
        pass

    try:  # resume automatically if a checkpoint can be loaded
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(f"{hps.data.exp_dir}/logs_s2_{hps.model.version}", "G_*.pth"),
            net_g,
            optim_g,
        )
        epoch_str += 1
        global_step = (epoch_str - 1) * len(train_loader)
    except Exception:  # nothing to resume on the first run, so load the pretrained weights instead
        epoch_str = 1
        global_step = 0
        if (
            hps.train.pretrained_s2G != ""
            and hps.train.pretrained_s2G is not None
            and os.path.exists(hps.train.pretrained_s2G)
        ):
            if rank == 0:
                logger.info(f"loaded pretrained {hps.train.pretrained_s2G}")
            console.print(
                f"loaded pretrained {hps.train.pretrained_s2G}",
                net_g.module.load_state_dict(  # type: ignore
                    torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
                    strict=False,
                )
                if multigpu
                else net_g.load_state_dict(
                    torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
                    strict=False,
                ),
            )

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)

    for _ in range(epoch_str):
        scheduler_g.step()

    scaler = GradScaler(device=device.type, enabled=hps.train.fp16_run)

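    # Only the generator is trained in this stage; the discriminator slots stay None
    # so that train_and_evaluate keeps a uniform signature.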
    net_d = optim_d = scheduler_d = None
    if rank == 0:
        logger.info(f"start training from epoch {epoch_str}")

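    # Rank 0 drives a rich epoch progress bar; the other ranks run inside a no-op context.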
    with (
        Progress(
            TextColumn("[cyan]{task.description}"),
            BarColumn(),
            TextColumn("{task.completed}/{task.total}"),
            TimeElapsedColumn(),
            console=console,
            redirect_stderr=True,
            redirect_stdout=True,
        )
        if rank == 0
        else nullcontext() as progress
    ):
        if isinstance(progress, Progress):
            epoch_task: TaskID | None = progress.add_task(
                "Epoch",
                total=int(hps.train.epochs),
                completed=int(epoch_str) - 1,
            )
        else:
            epoch_task = step_task = None

        for epoch in range(epoch_str, hps.train.epochs + 1):
            if rank == 0:
                assert epoch_task is not None
                assert progress is not None
                train_and_evaluate(
                    device,
                    epoch,
                    hps,
                    (net_g, net_d),
                    (optim_g, optim_d),
                    (scheduler_g, scheduler_d),
                    scaler,
                    (train_loader, None),
                    logger,
                    (writer, writer_eval),
                )
                progress.advance(epoch_task, 1)
            else:
                train_and_evaluate(
                    device,
                    epoch,
                    hps,
                    (net_g, net_d),
                    (optim_g, optim_d),
                    (scheduler_g, scheduler_d),
                    scaler,
                    (train_loader, None),
                    None,
                    (None, None),
                )
            scheduler_g.step()
        if rank == 0:
            assert progress
            progress.stop()
            logger.info("Training Done")
            sys.exit(0)


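# Runs a single epoch: moves each batch to the device, computes the generator's CFM
# loss under autocast, steps the optimizer through the GradScaler, logs to
# TensorBoard every `log_interval` steps, and finally writes checkpoints/weights.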
def train_and_evaluate(
    device: torch.device,
    epoch,
    hps,
    nets,
    optims,
    schedulers,
    scaler,
    loaders,
    logger,
    writers,
):
    net_g, net_d = nets
    optim_g, optim_d = optims
    # scheduler_g, scheduler_d = schedulers
    train_loader, eval_loader = loaders
    writer, writer_eval = writers

    train_loader.batch_sampler.set_epoch(epoch)
    global global_step

    net_g.train()

    with (
        Progress(
            TextColumn("[cyan]{task.description}"),
            BarColumn(),
            TextColumn("{task.completed}/{task.total}"),
            SpeedColumnIteration(show_speed=True),
            TimeRemainingColumn(elapsed_when_finished=True),
            console=console,
            redirect_stderr=True,
            redirect_stdout=True,
            transient=not (int(epoch) == int(hps.train.epochs)),
        )
        if device.index == 0
        else nullcontext() as progress
    ):
        if isinstance(progress, Progress):
            step_task: TaskID | None = progress.add_task("Steps", total=len(train_loader))
        else:
            step_task = None

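        # Each batch carries self-supervised speech features (ssl), linear and mel
        # spectrograms, text (phoneme) tokens, and the corresponding per-sample lengths.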
        for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
            train_loader
        ):
            spec, spec_lengths = spec.to(device, non_blocking=True), spec_lengths.to(device, non_blocking=True)
            mel, mel_lengths = mel.to(device, non_blocking=True), mel_lengths.to(device, non_blocking=True)
            ssl = ssl.to(device, non_blocking=True)
            ssl.requires_grad = False
            # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
            text, text_lengths = text.to(device, non_blocking=True), text_lengths.to(device, non_blocking=True)

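            # Mixed-precision step: the forward pass runs under fp16 autocast when
            # fp16_run is set, and GradScaler scales/unscales the loss so gradients
            # can be value-clipped before the optimizer step.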
            with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
                cfm_loss = net_g(
                    ssl,
                    spec,
                    mel,
                    ssl_lengths,
                    spec_lengths,
                    text,
                    text_lengths,
                    mel_lengths,
                    use_grad_ckpt=hps.train.grad_ckpt,
                )
                loss_gen_all = cfm_loss

            optim_g.zero_grad()
            scaler.scale(loss_gen_all).backward()
            scaler.unscale_(optim_g)
            grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
            scaler.step(optim_g)
            scaler.update()

            if device.index == 0 and progress is not None and step_task is not None:
                progress.advance(step_task, 1)

            if device.index == 0:
                if global_step % hps.train.log_interval == 0:
                    lr = optim_g.param_groups[0]["lr"]
                    # losses = [commit_loss, cfm_loss, mel_loss, loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
                    losses = [cfm_loss]
                    logger.info(
                        "Train Epoch: {} [{:.0f}%]".format(
                            epoch,
                            100.0 * batch_idx / len(train_loader),
                        )
                    )
                    logger.info([x.item() for x in losses] + [global_step, lr])

                    scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
                    utils.summarize(
                        writer=writer,
                        global_step=global_step,
                        # images=image_dict,
                        scalars=scalar_dict,
                    )

            global_step += 1

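    # Resume checkpoint: with if_save_latest == 0 every save keeps its own
    # G_<global_step>.pth, otherwise a single rolling file is overwritten each time.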
    if hps.train.if_save_latest == 0:
        utils.save_checkpoint(
            net_g,
            optim_g,
            hps.train.learning_rate,
            epoch,
            os.path.join(
                f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
                f"G_{global_step}.pth",
            ),
            logger,
        )
    else:
        utils.save_checkpoint(
            net_g,
            optim_g,
            hps.train.learning_rate,
            epoch,
            os.path.join(
                f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
                "G_233333333333.pth",
            ),
            logger,
        )

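    # Export inference weights through save_ckpt every `save_every_epoch` epochs,
    # on the main process only.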
    if epoch % hps.train.save_every_epoch == 0 and device.index == 0:
        if hps.train.if_save_every_weights is True:
            if hasattr(net_g, "module"):
                ckpt = net_g.module.state_dict()
            else:
                ckpt = net_g.state_dict()
            save_info = save_ckpt(
                ckpt,
                hps.name + f"_e{epoch}_s{global_step}",
                epoch,
                global_step,
                hps,
            )
            logger.info(f"saving ckpt {hps.name}_e{epoch}:{save_info}")


if __name__ == "__main__":
    main()