From 49667f44e827705d0b04c64e2bdf8c3217367887 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Fri, 5 Sep 2025 20:50:36 +0000 Subject: [PATCH] . --- GPT_SoVITS/AR/data/bucket_sampler.py | 6 +- GPT_SoVITS/inference_webui.py | 1 + GPT_SoVITS/prepare_datasets/1-get-text.py | 8 +- .../2-get-hubert-sv-wav32k.py | 11 +- GPT_SoVITS/prepare_datasets/3-get-semantic.py | 8 +- GPT_SoVITS/s1_train.py | 20 +- GPT_SoVITS/s2_train.py | 421 ++++++++++-------- GPT_SoVITS/s2_train_v3.py | 196 ++++---- GPT_SoVITS/s2_train_v3_lora.py | 160 +++---- GPT_SoVITS/utils.py | 2 +- webui.py | 12 +- 11 files changed, 454 insertions(+), 391 deletions(-) diff --git a/GPT_SoVITS/AR/data/bucket_sampler.py b/GPT_SoVITS/AR/data/bucket_sampler.py index d8457334..4d2ed0b8 100644 --- a/GPT_SoVITS/AR/data/bucket_sampler.py +++ b/GPT_SoVITS/AR/data/bucket_sampler.py @@ -39,12 +39,12 @@ class DistributedBucketSampler(Sampler[T_co]): if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") - num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1 + num_replicas = dist.get_world_size() if torch.cuda.device_count() > 1 else 1 if rank is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") - rank = dist.get_rank() if torch.cuda.is_available() else 0 - if torch.cuda.is_available(): + rank = dist.get_rank() if torch.cuda.device_count() > 1 else 0 + if torch.cuda.device_count() > 1: torch.cuda.set_device(rank) if rank >= num_replicas or rank < 0: raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1)) diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index bb21fcba..363046ee 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -48,6 +48,7 @@ with contextlib.suppress(ImportError): warnings.filterwarnings( "ignore", message="MPS: The constant padding of more than 3 dimensions is not currently supported natively." ) +warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*") logging.getLogger("markdown_it").setLevel(logging.ERROR) logging.getLogger("urllib3").setLevel(logging.ERROR) diff --git a/GPT_SoVITS/prepare_datasets/1-get-text.py b/GPT_SoVITS/prepare_datasets/1-get-text.py index d9fd513d..9d4ac605 100644 --- a/GPT_SoVITS/prepare_datasets/1-get-text.py +++ b/GPT_SoVITS/prepare_datasets/1-get-text.py @@ -12,11 +12,11 @@ from typing import List import torch import torch.multiprocessing as tmp import typer -from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn +from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn from torch.multiprocessing.spawn import spawn from transformers import BertForMaskedLM, BertTokenizerFast -from GPT_SoVITS.Accelerate.logger import console, logger +from GPT_SoVITS.Accelerate.logger import console, logger, SpeedColumnIteration from GPT_SoVITS.text.cleaner import clean_text from tools.my_utils import clean_path @@ -250,7 +250,7 @@ def main( TextColumn("[cyan]{task.description}"), BarColumn(), TextColumn("{task.completed}/{task.total}"), - TaskProgressColumn(show_speed=True), + SpeedColumnIteration(show_speed=True), TimeRemainingColumn(elapsed_when_finished=True), console=console, redirect_stderr=False, @@ -286,7 +286,7 @@ def main( console.print(e) finally: logger.critical(f"Worker PID {p.pid} crashed with exit code {p.exitcode}.") - raise SystemExit(1) + sys.exit(1) ctx.join() with open(merged_path, "w", encoding="utf8") as fout: diff --git a/GPT_SoVITS/prepare_datasets/2-get-hubert-sv-wav32k.py b/GPT_SoVITS/prepare_datasets/2-get-hubert-sv-wav32k.py index ae6f4853..01bca00e 100644 --- a/GPT_SoVITS/prepare_datasets/2-get-hubert-sv-wav32k.py +++ b/GPT_SoVITS/prepare_datasets/2-get-hubert-sv-wav32k.py @@ -7,21 +7,24 @@ import sys import time from pathlib import Path from typing import List, Optional +import warnings import numpy as np import torch import torch.multiprocessing as tmp import torchaudio import typer -from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn +from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn from scipy.io import wavfile from torch.multiprocessing.spawn import spawn -from GPT_SoVITS.Accelerate.logger import console, logger +from GPT_SoVITS.Accelerate.logger import console, logger, SpeedColumnIteration from GPT_SoVITS.eres2net.ERes2NetV2 import ERes2NetV2 from GPT_SoVITS.feature_extractor import cnhubert as cnhubert_mod from tools.my_utils import clean_path, load_audio +warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*") + torch.set_grad_enabled(False) tmp.set_start_method("spawn", force=True) @@ -351,7 +354,7 @@ def main( TextColumn("[cyan]{task.description}"), BarColumn(), TextColumn("{task.completed}/{task.total}"), - TaskProgressColumn(show_speed=True), + SpeedColumnIteration(show_speed=True), TimeRemainingColumn(elapsed_when_finished=True), console=console, redirect_stderr=False, @@ -390,7 +393,7 @@ def main( console.print(e) finally: logger.critical(f"Worker PID {p.pid} crashed with exit code {p.exitcode}.") - raise SystemExit(1) + sys.exit(1) ctx.join() diff --git a/GPT_SoVITS/prepare_datasets/3-get-semantic.py b/GPT_SoVITS/prepare_datasets/3-get-semantic.py index 2cbcf5a5..27ca5f6d 100644 --- a/GPT_SoVITS/prepare_datasets/3-get-semantic.py +++ b/GPT_SoVITS/prepare_datasets/3-get-semantic.py @@ -12,10 +12,10 @@ from typing import List, Tuple import torch import torch.multiprocessing as tmp import typer -from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn +from rich.progress import BarColumn, Progress, TimeRemainingColumn, TextColumn from torch.multiprocessing.spawn import spawn -from GPT_SoVITS.Accelerate.logger import console, logger +from GPT_SoVITS.Accelerate.logger import console, logger, SpeedColumnIteration from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3 from GPT_SoVITS.process_ckpt import inspect_version from tools.my_utils import DictToAttrRecursive, clean_path @@ -244,7 +244,7 @@ def main( TextColumn("[cyan]{task.description}"), BarColumn(), TextColumn("{task.completed}/{task.total}"), - TaskProgressColumn(show_speed=True), + SpeedColumnIteration(show_speed=True), TimeRemainingColumn(elapsed_when_finished=True), console=console, ) as progress: @@ -279,7 +279,7 @@ def main( console.print(e) finally: logger.critical(f"Worker PID {p.pid} crashed with exit code {p.exitcode}.") - raise SystemExit(1) + sys.exit(1) ctx.join() with open(merged_path, "w", encoding="utf8") as fout: diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py index 727ad8a8..c28c7cfc 100644 --- a/GPT_SoVITS/s1_train.py +++ b/GPT_SoVITS/s1_train.py @@ -29,7 +29,7 @@ os.environ["MASTER_ADDR"] = "localhost" if platform.system() == "Windows": os.environ["USE_LIBUV"] = "0" -torch.set_grad_enabled(False) +torch.set_grad_enabled(True) class ARModelCheckpoint(ModelCheckpoint): @@ -52,15 +52,15 @@ class ARModelCheckpoint(ModelCheckpoint): def on_train_epoch_end(self, trainer, pl_module): if self._should_save_on_train_epoch_end(trainer): monitor_candidates = self._monitor_candidates(trainer) + self._save_topk_checkpoint(trainer, monitor_candidates) + if self.if_save_latest is True: # 如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt + to_clean = list(os.listdir(self.dirpath)) + for name in to_clean: + try: + os.remove(f"{self.dirpath}/{name}") + except Exception as _: + pass if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: - self._save_topk_checkpoint(trainer, monitor_candidates) - if self.if_save_latest is True: # 如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt - to_clean = list(os.listdir(self.dirpath)) - for name in to_clean: - try: - os.remove(f"{self.dirpath}/{name}") - except Exception as _: - pass if self.if_save_every_weights is True: to_save_od: OrderedDict[str, Any] = OrderedDict() to_save_od["weight"] = OrderedDict() @@ -94,7 +94,7 @@ def main(args): process_group_backend="nccl" if platform.system() != "Windows" else "gloo", find_unused_parameters=False ) else: - strategy = SingleDeviceStrategy("cuda") + strategy = SingleDeviceStrategy("cuda:0") else: strategy = SingleDeviceStrategy("cpu") diff --git a/GPT_SoVITS/s2_train.py b/GPT_SoVITS/s2_train.py index 02381577..27cb93bf 100644 --- a/GPT_SoVITS/s2_train.py +++ b/GPT_SoVITS/s2_train.py @@ -8,7 +8,7 @@ from random import randint import torch import torch.distributed as dist -from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeRemainingColumn +from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeElapsedColumn, TimeRemainingColumn from torch.amp.autocast_mode import autocast from torch.amp.grad_scaler import GradScaler from torch.multiprocessing.spawn import spawn @@ -290,9 +290,10 @@ def run(rank, n_gpus, hps): TextColumn("[cyan]{task.description}"), BarColumn(), TextColumn("{task.completed}/{task.total}"), - SpeedColumnIteration(show_speed=True), - TimeRemainingColumn(elapsed_when_finished=True), + TimeElapsedColumn(), console=console, + redirect_stderr=True, + redirect_stdout=True, ) if rank == 0 else nullcontext() as progress @@ -303,7 +304,6 @@ def run(rank, n_gpus, hps): total=int(hps.train.epochs), completed=int(epoch_str) - 1, ) - step_task: TaskID | None = progress.add_task("Steps", total=len(train_loader)) else: epoch_task = step_task = None @@ -311,8 +311,7 @@ def run(rank, n_gpus, hps): if rank == 0: assert epoch_task is not None assert progress is not None - assert step_task is not None - progress.reset(step_task, total=len(train_loader)) + progress.advance(epoch_task, 1) train_and_evaluate( device, epoch, @@ -325,9 +324,7 @@ def run(rank, n_gpus, hps): (train_loader, None), logger, (writer, writer_eval), - (progress, step_task), ) - progress.advance(epoch_task, 1) else: train_and_evaluate( device, @@ -340,11 +337,12 @@ def run(rank, n_gpus, hps): (train_loader, None), None, (None, None), - (None, None), ) scheduler_g.step() scheduler_d.step() if rank == 0: + assert progress + progress.stop() logger.info("Training Done") sys.exit(0) @@ -360,215 +358,242 @@ def train_and_evaluate( loaders, logger, writers, - progress_group: tuple[Progress, TaskID] | tuple[None, None], ): net_g, net_d = nets optim_g, optim_d = optims # scheduler_g, scheduler_d = schedulers train_loader, eval_loader = loaders writer, writer_eval = writers - progress, step_task = progress_group train_loader.batch_sampler.set_epoch(epoch) global global_step - for batch_idx, data in enumerate(train_loader): - if hps.model.version in {"v2Pro", "v2ProPlus"}: - ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = data - ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = map( - lambda x: x.to(device, non_blocking=True), - (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb), - ) - else: - ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data - ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = map( - lambda x: x.to(device, non_blocking=True), - (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths), - ) - sv_emb = None - ssl.requires_grad = False + net_g.train() + net_d.train() - with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run): + with ( + Progress( + TextColumn("[cyan]{task.description}"), + BarColumn(), + TextColumn("{task.completed}/{task.total}"), + SpeedColumnIteration(show_speed=True), + TimeRemainingColumn(elapsed_when_finished=True), + console=console, + redirect_stderr=True, + redirect_stdout=True, + transient=not (int(epoch) == int(hps.train.epochs)), + ) + if device.index == 0 + else nullcontext() as progress + ): + if isinstance(progress, Progress): + step_task: TaskID | None = progress.add_task("Steps", total=len(train_loader)) + else: + step_task = None + + for batch_idx, data in enumerate(train_loader): if hps.model.version in {"v2Pro", "v2ProPlus"}: - (y_hat, kl_ssl, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl) = net_g( - ssl, spec, spec_lengths, text, text_lengths, sv_emb + ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = data + ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = map( + lambda x: x.to(device, non_blocking=True), + (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb), ) else: - ( - y_hat, - kl_ssl, - ids_slice, - x_mask, - z_mask, - (z, z_p, m_p, logs_p, m_q, logs_q), - stats_ssl, - ) = net_g(ssl, spec, spec_lengths, text, text_lengths) - - mel = spec_to_mel_torch( - spec, - hps.data.filter_length, - hps.data.n_mel_channels, - hps.data.sampling_rate, - hps.data.mel_fmin, - hps.data.mel_fmax, - ) - y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) - y_hat_mel = mel_spectrogram_torch( - y_hat.squeeze(1), - hps.data.filter_length, - hps.data.n_mel_channels, - hps.data.sampling_rate, - hps.data.hop_length, - hps.data.win_length, - hps.data.mel_fmin, - hps.data.mel_fmax, - ) - - y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice - - # Discriminator - y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) - with autocast(device_type=device.type, enabled=False): - loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( - y_d_hat_r, - y_d_hat_g, + ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data + ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = map( + lambda x: x.to(device, non_blocking=True), + (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths), ) - loss_disc_all = loss_disc + sv_emb = None + ssl.requires_grad = False - optim_d.zero_grad() - scaler.scale(loss_disc_all).backward() - scaler.unscale_(optim_d) - grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) - scaler.step(optim_d) - - with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run): - # Generator - y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) - with autocast(device_type=device.type, enabled=False): - loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel - loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl - - loss_fm = feature_loss(fmap_r, fmap_g) - loss_gen, losses_gen = generator_loss(y_d_hat_g) - loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl - - optim_g.zero_grad() - scaler.scale(loss_gen_all).backward() - scaler.unscale_(optim_g) - grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) - scaler.step(optim_g) - scaler.update() - - if device.index == 0: - assert progress - assert step_task - progress.advance(step_task, 1) - if global_step % hps.train.log_interval == 0: - lr = optim_g.param_groups[0]["lr"] - losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl] - logger.info( - "Train Epoch: {} [{:.0f}%]".format( - epoch, - 100.0 * batch_idx / len(train_loader), - ) - ) - logger.info([x.item() for x in losses] + [global_step, lr]) - - scalar_dict = { - "loss/g/total": loss_gen_all, - "loss/d/total": loss_disc_all, - "learning_rate": lr, - "grad_norm_d": grad_norm_d, - "grad_norm_g": grad_norm_g, - } - scalar_dict.update( - { - "loss/g/fm": loss_fm, - "loss/g/mel": loss_mel, - "loss/g/kl_ssl": kl_ssl, - "loss/g/kl": loss_kl, - } - ) - - # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) - # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) - # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) - image_dict = None - try: # Some people installed the wrong version of matplotlib. - image_dict = { - "slice/mel_org": utils.plot_spectrogram_to_numpy( - y_mel[0].data.cpu().numpy(), - ), - "slice/mel_gen": utils.plot_spectrogram_to_numpy( - y_hat_mel[0].data.cpu().numpy(), - ), - "all/mel": utils.plot_spectrogram_to_numpy( - mel[0].data.cpu().numpy(), - ), - "all/stats_ssl": utils.plot_spectrogram_to_numpy( - stats_ssl[0].data.cpu().numpy(), - ), - } - except Exception as _: - pass - if image_dict: - utils.summarize( - writer=writer, - global_step=global_step, - images=image_dict, - scalars=scalar_dict, + with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run): + if hps.model.version in {"v2Pro", "v2ProPlus"}: + (y_hat, kl_ssl, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl) = net_g( + ssl, spec, spec_lengths, text, text_lengths, sv_emb ) else: - utils.summarize( - writer=writer, - global_step=global_step, - scalars=scalar_dict, + ( + y_hat, + kl_ssl, + ids_slice, + x_mask, + z_mask, + (z, z_p, m_p, logs_p, m_q, logs_q), + stats_ssl, + ) = net_g(ssl, spec, spec_lengths, text, text_lengths) + + mel = spec_to_mel_torch( + spec, + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1), + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + + y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice + + # Discriminator + y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) + with autocast(device_type=device.type, enabled=False): + loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( + y_d_hat_r, + y_d_hat_g, ) - global_step += 1 + loss_disc_all = loss_disc + + optim_d.zero_grad() + scaler.scale(loss_disc_all).backward() + scaler.unscale_(optim_d) + grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) + scaler.step(optim_d) + + with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run): + # Generator + y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) + with autocast(device_type=device.type, enabled=False): + loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel + loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl + + loss_fm = feature_loss(fmap_r, fmap_g) + loss_gen, losses_gen = generator_loss(y_d_hat_g) + loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl + + optim_g.zero_grad() + scaler.scale(loss_gen_all).backward() + scaler.unscale_(optim_g) + grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + + if device.index == 0 and progress is not None and step_task is not None: + progress.advance(step_task, 1) + + if device.index == 0: + if global_step % hps.train.log_interval == 0: + lr = optim_g.param_groups[0]["lr"] + losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl] + logger.info( + "Train Epoch: {} [{:.0f}%]".format( + epoch, + 100.0 * batch_idx / len(train_loader), + ) + ) + logger.info([x.item() for x in losses] + [global_step, lr]) + + scalar_dict = { + "loss/g/total": loss_gen_all, + "loss/d/total": loss_disc_all, + "learning_rate": lr, + "grad_norm_d": grad_norm_d, + "grad_norm_g": grad_norm_g, + } + scalar_dict.update( + { + "loss/g/fm": loss_fm, + "loss/g/mel": loss_mel, + "loss/g/kl_ssl": kl_ssl, + "loss/g/kl": loss_kl, + } + ) + + # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) + # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) + # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) + image_dict = None + try: # Some people installed the wrong version of matplotlib. + image_dict = { + "slice/mel_org": utils.plot_spectrogram_to_numpy( + y_mel[0].data.cpu().numpy(), + ), + "slice/mel_gen": utils.plot_spectrogram_to_numpy( + y_hat_mel[0].data.cpu().numpy(), + ), + "all/mel": utils.plot_spectrogram_to_numpy( + mel[0].data.cpu().numpy(), + ), + "all/stats_ssl": utils.plot_spectrogram_to_numpy( + stats_ssl[0].data.cpu().numpy(), + ), + } + except Exception as _: + pass + if image_dict: + utils.summarize( + writer=writer, + global_step=global_step, + images=image_dict, + scalars=scalar_dict, + ) + else: + utils.summarize( + writer=writer, + global_step=global_step, + scalars=scalar_dict, + ) + global_step += 1 + + if hps.train.if_save_latest == 0: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join( + f"{hps.data.exp_dir}/logs_s2_{hps.model.version}", + f"G_{global_step}.pth", + ), + logger, + ) + utils.save_checkpoint( + net_d, + optim_d, + hps.train.learning_rate, + epoch, + os.path.join( + "{hps.data.exp_dir}/logs_s2_{hps.model.version}", + "D_{global_step}.pth", + ), + logger, + ) + else: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join( + f"{hps.data.exp_dir}/logs_s2_{hps.model.version}", + "G_233333333333.pth", + ), + logger, + ) + utils.save_checkpoint( + net_d, + optim_d, + hps.train.learning_rate, + epoch, + os.path.join( + f"{hps.data.exp_dir}/logs_s2_{hps.model.version}", + "D_233333333333.pth", + ), + logger, + ) + if epoch % hps.train.save_every_epoch == 0 and device.index == 0: - if hps.train.if_save_latest == 0: - utils.save_checkpoint( - net_g, - optim_g, - hps.train.learning_rate, - epoch, - os.path.join( - f"{hps.data.exp_dir}/logs_s2_{hps.model.version}", - f"G_{global_step}.pth", - ), - ) - utils.save_checkpoint( - net_d, - optim_d, - hps.train.learning_rate, - epoch, - os.path.join( - "{hps.data.exp_dir}/logs_s2_{hps.model.version}", - "D_{global_step}.pth", - ), - ) - else: - utils.save_checkpoint( - net_g, - optim_g, - hps.train.learning_rate, - epoch, - os.path.join( - f"{hps.data.exp_dir}/logs_s2_{hps.model.version}", - "G_233333333333.pth", - ), - ) - utils.save_checkpoint( - net_d, - optim_d, - hps.train.learning_rate, - epoch, - os.path.join( - "{hps.data.exp_dir}/logs_s2_{hps.model.version}", - "D_233333333333.pth", - ), - ) - if device.index == 0 and hps.train.if_save_every_weights is True: + if hps.train.if_save_every_weights is True: if hasattr(net_g, "module"): ckpt = net_g.module.state_dict() else: diff --git a/GPT_SoVITS/s2_train_v3.py b/GPT_SoVITS/s2_train_v3.py index fcb9b1ae..82897190 100644 --- a/GPT_SoVITS/s2_train_v3.py +++ b/GPT_SoVITS/s2_train_v3.py @@ -8,7 +8,7 @@ from random import randint import torch import torch.distributed as dist -from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeRemainingColumn +from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeElapsedColumn, TimeRemainingColumn from torch.amp.autocast_mode import autocast from torch.amp.grad_scaler import GradScaler from torch.multiprocessing.spawn import spawn @@ -208,9 +208,10 @@ def run(rank, n_gpus, hps): TextColumn("[cyan]{task.description}"), BarColumn(), TextColumn("{task.completed}/{task.total}"), - SpeedColumnIteration(show_speed=True), - TimeRemainingColumn(elapsed_when_finished=True), + TimeElapsedColumn(), console=console, + redirect_stderr=True, + redirect_stdout=True, ) if rank == 0 else nullcontext() as progress @@ -221,7 +222,6 @@ def run(rank, n_gpus, hps): total=int(hps.train.epochs), completed=int(epoch_str) - 1, ) - step_task: TaskID | None = progress.add_task("Steps", total=len(train_loader)) else: epoch_task = step_task = None @@ -229,8 +229,7 @@ def run(rank, n_gpus, hps): if rank == 0: assert epoch_task is not None assert progress is not None - assert step_task is not None - progress.reset(step_task, total=len(train_loader)) + progress.advance(epoch_task, 1) train_and_evaluate( device, epoch, @@ -242,7 +241,6 @@ def run(rank, n_gpus, hps): (train_loader, None), logger, (writer, writer_eval), - (progress, step_task), ) progress.advance(epoch_task, 1) else: @@ -257,10 +255,11 @@ def run(rank, n_gpus, hps): (train_loader, None), None, (None, None), - (None, None), ) scheduler_g.step() if rank == 0: + assert progress + progress.stop() logger.info("Training Done") sys.exit(0) @@ -276,101 +275,122 @@ def train_and_evaluate( loaders, logger, writers, - progress_group: tuple[Progress, TaskID] | tuple[None, None], ): net_g, net_d = nets optim_g, optim_d = optims # scheduler_g, scheduler_d = schedulers train_loader, eval_loader = loaders writer, writer_eval = writers - progress, step_task = progress_group train_loader.batch_sampler.set_epoch(epoch) global global_step net_g.train() - for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate( - train_loader + with ( + Progress( + TextColumn("[cyan]{task.description}"), + BarColumn(), + TextColumn("{task.completed}/{task.total}"), + SpeedColumnIteration(show_speed=True), + TimeRemainingColumn(elapsed_when_finished=True), + console=console, + redirect_stderr=True, + redirect_stdout=True, + transient=not (int(epoch) == int(hps.train.epochs)), + ) + if device.index == 0 + else nullcontext() as progress ): - spec, spec_lengths = spec.to(device, non_blocking=True), spec_lengths.to(device, non_blocking=True) - mel, mel_lengths = mel.to(device, non_blocking=True), mel_lengths.to(device, non_blocking=True) - ssl = ssl.to(device, non_blocking=True) - ssl.requires_grad = False - # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) - text, text_lengths = text.to(device, non_blocking=True), text_lengths.to(device, non_blocking=True) - - with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run): - cfm_loss = net_g( - ssl, - spec, - mel, - ssl_lengths, - spec_lengths, - text, - text_lengths, - mel_lengths, - use_grad_ckpt=hps.train.grad_ckpt, - ) - loss_gen_all = cfm_loss - optim_g.zero_grad() - scaler.scale(loss_gen_all).backward() - scaler.unscale_(optim_g) - grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) - scaler.step(optim_g) - scaler.update() - - if device.index == 0: - assert progress - assert step_task - progress.advance(step_task, 1) - if global_step % hps.train.log_interval == 0: - lr = optim_g.param_groups[0]["lr"] - # losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl] - losses = [cfm_loss] - logger.info( - "Train Epoch: {} [{:.0f}%]".format( - epoch, - 100.0 * batch_idx / len(train_loader), - ) - ) - logger.info([x.item() for x in losses] + [global_step, lr]) - - scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g} - utils.summarize( - writer=writer, - global_step=global_step, - # images=image_dict, - scalars=scalar_dict, - ) - - global_step += 1 - if epoch % hps.train.save_every_epoch == 0 and device.index == 0: - if hps.train.if_save_latest == 0: - utils.save_checkpoint( - net_g, - optim_g, - hps.train.learning_rate, - epoch, - os.path.join( - f"{hps.data.exp_dir}/logs_s2_{hps.model.version}", - f"G_{global_step}.pth", - ), - ) - + if isinstance(progress, Progress): + step_task: TaskID | None = progress.add_task("Steps", total=len(train_loader)) else: - utils.save_checkpoint( - net_g, - optim_g, - hps.train.learning_rate, - epoch, - os.path.join( - f"{hps.data.exp_dir}/logs_s2_{hps.model.version}", - "G_233333333333.pth", - ), - ) + step_task = None - if device.index == 0 and hps.train.if_save_every_weights is True: + for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate( + train_loader + ): + spec, spec_lengths = spec.to(device, non_blocking=True), spec_lengths.to(device, non_blocking=True) + mel, mel_lengths = mel.to(device, non_blocking=True), mel_lengths.to(device, non_blocking=True) + ssl = ssl.to(device, non_blocking=True) + ssl.requires_grad = False + # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) + text, text_lengths = text.to(device, non_blocking=True), text_lengths.to(device, non_blocking=True) + + with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run): + cfm_loss = net_g( + ssl, + spec, + mel, + ssl_lengths, + spec_lengths, + text, + text_lengths, + mel_lengths, + use_grad_ckpt=hps.train.grad_ckpt, + ) + loss_gen_all = cfm_loss + optim_g.zero_grad() + scaler.scale(loss_gen_all).backward() + scaler.unscale_(optim_g) + grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + + if device.index == 0 and progress is not None and step_task is not None: + progress.advance(step_task, 1) + + if device.index == 0: + if global_step % hps.train.log_interval == 0: + lr = optim_g.param_groups[0]["lr"] + # losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl] + losses = [cfm_loss] + logger.info( + "Train Epoch: {} [{:.0f}%]".format( + epoch, + 100.0 * batch_idx / len(train_loader), + ) + ) + logger.info([x.item() for x in losses] + [global_step, lr]) + + scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g} + utils.summarize( + writer=writer, + global_step=global_step, + # images=image_dict, + scalars=scalar_dict, + ) + + global_step += 1 + + if hps.train.if_save_latest == 0: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join( + f"{hps.data.exp_dir}/logs_s2_{hps.model.version}", + f"G_{global_step}.pth", + ), + logger, + ) + + else: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join( + f"{hps.data.exp_dir}/logs_s2_{hps.model.version}", + "G_233333333333.pth", + ), + logger, + ) + + if epoch % hps.train.save_every_epoch == 0 and device.index == 0: + if hps.train.if_save_every_weights is True: if hasattr(net_g, "module"): ckpt = net_g.module.state_dict() else: diff --git a/GPT_SoVITS/s2_train_v3_lora.py b/GPT_SoVITS/s2_train_v3_lora.py index 70a8b60d..69cd1c4d 100644 --- a/GPT_SoVITS/s2_train_v3_lora.py +++ b/GPT_SoVITS/s2_train_v3_lora.py @@ -12,14 +12,13 @@ import torch import torch.distributed as dist from peft import LoraConfig, get_peft_model from rich import print -from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeRemainingColumn +from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeElapsedColumn, TimeRemainingColumn from torch.amp.autocast_mode import autocast from torch.amp.grad_scaler import GradScaler from torch.multiprocessing.spawn import spawn from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm import GPT_SoVITS.utils as utils from GPT_SoVITS.Accelerate import console, logger @@ -234,9 +233,10 @@ def run(rank, n_gpus, hps): TextColumn("[cyan]{task.description}"), BarColumn(), TextColumn("{task.completed}/{task.total}"), - SpeedColumnIteration(show_speed=True), - TimeRemainingColumn(elapsed_when_finished=True), + TimeElapsedColumn(), console=console, + redirect_stderr=True, + redirect_stdout=True, ) if rank == 0 else nullcontext() as progress @@ -247,7 +247,6 @@ def run(rank, n_gpus, hps): total=int(hps.train.epochs), completed=int(epoch_str) - 1, ) - step_task: TaskID | None = progress.add_task("steps", total=len(train_loader)) else: epoch_task = step_task = None @@ -255,8 +254,7 @@ def run(rank, n_gpus, hps): if rank == 0: assert epoch_task is not None assert progress is not None - assert step_task is not None - progress.reset(step_task, total=len(train_loader)) + progress.advance(epoch_task, 1) train_and_evaluate( device, epoch, @@ -268,9 +266,7 @@ def run(rank, n_gpus, hps): (train_loader, None), logger, (writer, writer_eval), - (progress, step_task), ) - progress.advance(epoch_task, 1) else: train_and_evaluate( device, @@ -283,10 +279,11 @@ def run(rank, n_gpus, hps): (train_loader, None), None, (None, None), - (None, None), ) scheduler_g.step() if rank == 0: + assert progress + progress.stop() logger.info("Training Done") sys.exit(0) @@ -302,84 +299,97 @@ def train_and_evaluate( loaders, logger, writers, - progress_group: tuple[Progress, TaskID] | tuple[None, None], ): net_g, net_d = nets optim_g, optim_d = optims # scheduler_g, scheduler_d = schedulers train_loader, eval_loader = loaders writer, writer_eval = writers - progress, step_task = progress_group train_loader.batch_sampler.set_epoch(epoch) global global_step net_g.train() - for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate( - tqdm(train_loader) + + with ( + Progress( + TextColumn("[cyan]{task.description}"), + BarColumn(), + TextColumn("{task.completed}/{task.total}"), + SpeedColumnIteration(show_speed=True), + TimeRemainingColumn(elapsed_when_finished=True), + console=console, + redirect_stderr=True, + redirect_stdout=True, + transient=not (int(epoch) == int(hps.train.epochs)), + ) + if device.index == 0 + else nullcontext() as progress ): - spec, spec_lengths = spec.to(device, non_blocking=True), spec_lengths.to(device, non_blocking=True) - mel, mel_lengths = mel.to(device, non_blocking=True), mel_lengths.to(device, non_blocking=True) - ssl = ssl.to(device, non_blocking=True) - ssl.requires_grad = False - text, text_lengths = text.to(device, non_blocking=True), text_lengths.to(device, non_blocking=True) - - with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run): - cfm_loss = net_g( - ssl, - spec, - mel, - ssl_lengths, - spec_lengths, - text, - text_lengths, - mel_lengths, - use_grad_ckpt=hps.train.grad_ckpt, - ) - loss_gen_all = cfm_loss - optim_g.zero_grad() - scaler.scale(loss_gen_all).backward() - scaler.unscale_(optim_g) - grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) - scaler.step(optim_g) - scaler.update() - - if device.index == 0: - assert progress - assert step_task - progress.advance(step_task, 1) - if global_step % hps.train.log_interval == 0: - lr = optim_g.param_groups[0]["lr"] - losses = [cfm_loss] - logger.info("Train Epoch: {} [{:.0f}%]".format(epoch, 100.0 * batch_idx / len(train_loader))) - logger.info([x.item() for x in losses] + [global_step, lr]) - - scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g} - utils.summarize( - writer=writer, - global_step=global_step, - scalars=scalar_dict, - ) - - global_step += 1 - if epoch % hps.train.save_every_epoch == 0 and device.index == 0: - if hps.train.if_save_latest == 0: - utils.save_checkpoint( - net_g, - optim_g, - hps.train.learning_rate, - epoch, - os.path.join(save_root, f"G_{global_step}.pth"), - ) + if isinstance(progress, Progress): + step_task: TaskID | None = progress.add_task("Steps", total=len(train_loader)) else: - utils.save_checkpoint( - net_g, - optim_g, - hps.train.learning_rate, - epoch, - os.path.join(save_root, "G_233333333333.pth"), - ) - if device.index == 0 and hps.train.if_save_every_weights is True: + step_task = None + + for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate( + train_loader + ): + spec, spec_lengths = spec.to(device, non_blocking=True), spec_lengths.to(device, non_blocking=True) + mel, mel_lengths = mel.to(device, non_blocking=True), mel_lengths.to(device, non_blocking=True) + ssl = ssl.to(device, non_blocking=True) + ssl.requires_grad = False + text, text_lengths = text.to(device, non_blocking=True), text_lengths.to(device, non_blocking=True) + + with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run): + cfm_loss = net_g( + ssl, + spec, + mel, + ssl_lengths, + spec_lengths, + text, + text_lengths, + mel_lengths, + use_grad_ckpt=hps.train.grad_ckpt, + ) + loss_gen_all = cfm_loss + optim_g.zero_grad() + scaler.scale(loss_gen_all).backward() + scaler.unscale_(optim_g) + grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + + if device.index == 0 and progress is not None and step_task is not None: + progress.advance(step_task, 1) + + if device.index == 0: + if global_step % hps.train.log_interval == 0: + lr = optim_g.param_groups[0]["lr"] + losses = [cfm_loss] + logger.info("Train Epoch: {} [{:.0f}%]".format(epoch, 100.0 * batch_idx / len(train_loader))) + logger.info([x.item() for x in losses] + [global_step, lr]) + + scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g} + utils.summarize( + writer=writer, + global_step=global_step, + scalars=scalar_dict, + ) + + global_step += 1 + + if hps.train.if_save_latest == 0: + utils.save_checkpoint( + net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(save_root, f"G_{global_step}.pth"), logger + ) + else: + utils.save_checkpoint( + net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(save_root, "G_233333333333.pth"), logger + ) + + if epoch % hps.train.save_every_epoch == 0 and device.index == 0: + if hps.train.if_save_every_weights is True: if hasattr(net_g, "module"): ckpt = net_g.module.state_dict() else: diff --git a/GPT_SoVITS/utils.py b/GPT_SoVITS/utils.py index 95c3cbb0..1467c318 100644 --- a/GPT_SoVITS/utils.py +++ b/GPT_SoVITS/utils.py @@ -70,7 +70,7 @@ def save(fea, path): #####fix issue: torch.save doesn't support chinese path shutil.move(tmp_path, "%s/%s" % (dir, name)) -def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): +def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path, logger): logger.info("Saving model and optimizer state at iteration {} to {}".format(iteration, checkpoint_path)) if hasattr(model, "module"): state_dict = model.module.state_dict() diff --git a/webui.py b/webui.py index 1e38fa40..f81ac802 100644 --- a/webui.py +++ b/webui.py @@ -605,7 +605,7 @@ def open1Bb( data["output_dir"] = f"{s1_dir}/logs_s1_{version}" env = os.environ.copy() - env["CUDA_VISIBLE_DEVICES"] = str(gpu_numbers1Ba).strip("[]").replace(" ", "") + env["CUDA_VISIBLE_DEVICES"] = str(gpu_numbers).strip("[]").replace(" ", "") tmp_config_path = f"{tmp}/tmp_s1.yaml" with open(tmp_config_path, "w") as f: @@ -890,7 +890,6 @@ def open1b( cmd = [ python_exec, "-s", "GPT_SoVITS/prepare_datasets/2-get-hubert-sv-wav32k.py", "--inp-list", inp_text, - "--wav-dir", inp_wav_dir, "--opt", opt_dir, "--cnhubert", ssl_pretrained_dir, "--device", infer_device.type, @@ -899,6 +898,9 @@ def open1b( ] # fmt: on + if inp_wav_dir: + cmd.extend(["--wav-dir", inp_wav_dir]) + if "Pro" in version: cmd.extend(["--sv", sv_path]) @@ -1124,7 +1126,6 @@ def open1abc( cmd_2 = [ python_exec, "-s", "GPT_SoVITS/prepare_datasets/2-get-hubert-sv-wav32k.py", "--inp-list", inp_text, - "--wav-dir", inp_wav_dir, "--opt", opt_dir, "--cnhubert", ssl_pretrained_dir, "--device", infer_device.type, @@ -1133,6 +1134,9 @@ def open1abc( ] # fmt: on + if inp_wav_dir: + cmd_2.extend(["--wav-dir", inp_wav_dir]) + if "Pro" in version: cmd_2.extend(["--sv", sv_path]) @@ -1231,7 +1235,7 @@ def close1abc(): kill_process(p1abc.pid, process_name_1abc) except Exception as _: traceback.print_exc() - ps1abc = [] + ps1abc = [None] * 3 return ( process_info(process_name_1abc, "closed"), gr.update(visible=True),