This commit is contained in:
XXXXRT666 2025-09-05 20:50:36 +00:00
parent 820195917a
commit 49667f44e8
11 changed files with 454 additions and 391 deletions

View File

@ -39,12 +39,12 @@ class DistributedBucketSampler(Sampler[T_co]):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1
num_replicas = dist.get_world_size() if torch.cuda.device_count() > 1 else 1
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank() if torch.cuda.is_available() else 0
if torch.cuda.is_available():
rank = dist.get_rank() if torch.cuda.device_count() > 1 else 0
if torch.cuda.device_count() > 1:
torch.cuda.set_device(rank)
if rank >= num_replicas or rank < 0:
raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))

View File

@ -48,6 +48,7 @@ with contextlib.suppress(ImportError):
warnings.filterwarnings(
"ignore", message="MPS: The constant padding of more than 3 dimensions is not currently supported natively."
)
warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*")
logging.getLogger("markdown_it").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR)

View File

@ -12,11 +12,11 @@ from typing import List
import torch
import torch.multiprocessing as tmp
import typer
from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
from torch.multiprocessing.spawn import spawn
from transformers import BertForMaskedLM, BertTokenizerFast
from GPT_SoVITS.Accelerate.logger import console, logger
from GPT_SoVITS.Accelerate.logger import console, logger, SpeedColumnIteration
from GPT_SoVITS.text.cleaner import clean_text
from tools.my_utils import clean_path
@ -250,7 +250,7 @@ def main(
TextColumn("[cyan]{task.description}"),
BarColumn(),
TextColumn("{task.completed}/{task.total}"),
TaskProgressColumn(show_speed=True),
SpeedColumnIteration(show_speed=True),
TimeRemainingColumn(elapsed_when_finished=True),
console=console,
redirect_stderr=False,
@ -286,7 +286,7 @@ def main(
console.print(e)
finally:
logger.critical(f"Worker PID {p.pid} crashed with exit code {p.exitcode}.")
raise SystemExit(1)
sys.exit(1)
ctx.join()
with open(merged_path, "w", encoding="utf8") as fout:

View File

@ -7,21 +7,24 @@ import sys
import time
from pathlib import Path
from typing import List, Optional
import warnings
import numpy as np
import torch
import torch.multiprocessing as tmp
import torchaudio
import typer
from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
from scipy.io import wavfile
from torch.multiprocessing.spawn import spawn
from GPT_SoVITS.Accelerate.logger import console, logger
from GPT_SoVITS.Accelerate.logger import console, logger, SpeedColumnIteration
from GPT_SoVITS.eres2net.ERes2NetV2 import ERes2NetV2
from GPT_SoVITS.feature_extractor import cnhubert as cnhubert_mod
from tools.my_utils import clean_path, load_audio
warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*")
torch.set_grad_enabled(False)
tmp.set_start_method("spawn", force=True)
@ -351,7 +354,7 @@ def main(
TextColumn("[cyan]{task.description}"),
BarColumn(),
TextColumn("{task.completed}/{task.total}"),
TaskProgressColumn(show_speed=True),
SpeedColumnIteration(show_speed=True),
TimeRemainingColumn(elapsed_when_finished=True),
console=console,
redirect_stderr=False,
@ -390,7 +393,7 @@ def main(
console.print(e)
finally:
logger.critical(f"Worker PID {p.pid} crashed with exit code {p.exitcode}.")
raise SystemExit(1)
sys.exit(1)
ctx.join()

View File

@ -12,10 +12,10 @@ from typing import List, Tuple
import torch
import torch.multiprocessing as tmp
import typer
from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
from rich.progress import BarColumn, Progress, TimeRemainingColumn, TextColumn
from torch.multiprocessing.spawn import spawn
from GPT_SoVITS.Accelerate.logger import console, logger
from GPT_SoVITS.Accelerate.logger import console, logger, SpeedColumnIteration
from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3
from GPT_SoVITS.process_ckpt import inspect_version
from tools.my_utils import DictToAttrRecursive, clean_path
@ -244,7 +244,7 @@ def main(
TextColumn("[cyan]{task.description}"),
BarColumn(),
TextColumn("{task.completed}/{task.total}"),
TaskProgressColumn(show_speed=True),
SpeedColumnIteration(show_speed=True),
TimeRemainingColumn(elapsed_when_finished=True),
console=console,
) as progress:
@ -279,7 +279,7 @@ def main(
console.print(e)
finally:
logger.critical(f"Worker PID {p.pid} crashed with exit code {p.exitcode}.")
raise SystemExit(1)
sys.exit(1)
ctx.join()
with open(merged_path, "w", encoding="utf8") as fout:

View File

@ -29,7 +29,7 @@ os.environ["MASTER_ADDR"] = "localhost"
if platform.system() == "Windows":
os.environ["USE_LIBUV"] = "0"
torch.set_grad_enabled(False)
torch.set_grad_enabled(True)
class ARModelCheckpoint(ModelCheckpoint):
@ -52,15 +52,15 @@ class ARModelCheckpoint(ModelCheckpoint):
def on_train_epoch_end(self, trainer, pl_module):
if self._should_save_on_train_epoch_end(trainer):
monitor_candidates = self._monitor_candidates(trainer)
self._save_topk_checkpoint(trainer, monitor_candidates)
if self.if_save_latest is True: # 如果设置只保存最后一个ckpt在保存下一个ckpt后要清理掉之前的所有ckpt
to_clean = list(os.listdir(self.dirpath))
for name in to_clean:
try:
os.remove(f"{self.dirpath}/{name}")
except Exception as _:
pass
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
self._save_topk_checkpoint(trainer, monitor_candidates)
if self.if_save_latest is True: # 如果设置只保存最后一个ckpt在保存下一个ckpt后要清理掉之前的所有ckpt
to_clean = list(os.listdir(self.dirpath))
for name in to_clean:
try:
os.remove(f"{self.dirpath}/{name}")
except Exception as _:
pass
if self.if_save_every_weights is True:
to_save_od: OrderedDict[str, Any] = OrderedDict()
to_save_od["weight"] = OrderedDict()
@ -94,7 +94,7 @@ def main(args):
process_group_backend="nccl" if platform.system() != "Windows" else "gloo", find_unused_parameters=False
)
else:
strategy = SingleDeviceStrategy("cuda")
strategy = SingleDeviceStrategy("cuda:0")
else:
strategy = SingleDeviceStrategy("cpu")

View File

@ -8,7 +8,7 @@ from random import randint
import torch
import torch.distributed as dist
from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeRemainingColumn
from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from torch.multiprocessing.spawn import spawn
@ -290,9 +290,10 @@ def run(rank, n_gpus, hps):
TextColumn("[cyan]{task.description}"),
BarColumn(),
TextColumn("{task.completed}/{task.total}"),
SpeedColumnIteration(show_speed=True),
TimeRemainingColumn(elapsed_when_finished=True),
TimeElapsedColumn(),
console=console,
redirect_stderr=True,
redirect_stdout=True,
)
if rank == 0
else nullcontext() as progress
@ -303,7 +304,6 @@ def run(rank, n_gpus, hps):
total=int(hps.train.epochs),
completed=int(epoch_str) - 1,
)
step_task: TaskID | None = progress.add_task("Steps", total=len(train_loader))
else:
epoch_task = step_task = None
@ -311,8 +311,7 @@ def run(rank, n_gpus, hps):
if rank == 0:
assert epoch_task is not None
assert progress is not None
assert step_task is not None
progress.reset(step_task, total=len(train_loader))
progress.advance(epoch_task, 1)
train_and_evaluate(
device,
epoch,
@ -325,9 +324,7 @@ def run(rank, n_gpus, hps):
(train_loader, None),
logger,
(writer, writer_eval),
(progress, step_task),
)
progress.advance(epoch_task, 1)
else:
train_and_evaluate(
device,
@ -340,11 +337,12 @@ def run(rank, n_gpus, hps):
(train_loader, None),
None,
(None, None),
(None, None),
)
scheduler_g.step()
scheduler_d.step()
if rank == 0:
assert progress
progress.stop()
logger.info("Training Done")
sys.exit(0)
@ -360,215 +358,242 @@ def train_and_evaluate(
loaders,
logger,
writers,
progress_group: tuple[Progress, TaskID] | tuple[None, None],
):
net_g, net_d = nets
optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers
train_loader, eval_loader = loaders
writer, writer_eval = writers
progress, step_task = progress_group
train_loader.batch_sampler.set_epoch(epoch)
global global_step
for batch_idx, data in enumerate(train_loader):
if hps.model.version in {"v2Pro", "v2ProPlus"}:
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = data
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = map(
lambda x: x.to(device, non_blocking=True),
(ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb),
)
else:
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = map(
lambda x: x.to(device, non_blocking=True),
(ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths),
)
sv_emb = None
ssl.requires_grad = False
net_g.train()
net_d.train()
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
with (
Progress(
TextColumn("[cyan]{task.description}"),
BarColumn(),
TextColumn("{task.completed}/{task.total}"),
SpeedColumnIteration(show_speed=True),
TimeRemainingColumn(elapsed_when_finished=True),
console=console,
redirect_stderr=True,
redirect_stdout=True,
transient=not (int(epoch) == int(hps.train.epochs)),
)
if device.index == 0
else nullcontext() as progress
):
if isinstance(progress, Progress):
step_task: TaskID | None = progress.add_task("Steps", total=len(train_loader))
else:
step_task = None
for batch_idx, data in enumerate(train_loader):
if hps.model.version in {"v2Pro", "v2ProPlus"}:
(y_hat, kl_ssl, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl) = net_g(
ssl, spec, spec_lengths, text, text_lengths, sv_emb
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = data
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = map(
lambda x: x.to(device, non_blocking=True),
(ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb),
)
else:
(
y_hat,
kl_ssl,
ids_slice,
x_mask,
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
stats_ssl,
) = net_g(ssl, spec, spec_lengths, text, text_lengths)
mel = spec_to_mel_torch(
spec,
hps.data.filter_length,
hps.data.n_mel_channels,
hps.data.sampling_rate,
hps.data.mel_fmin,
hps.data.mel_fmax,
)
y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1),
hps.data.filter_length,
hps.data.n_mel_channels,
hps.data.sampling_rate,
hps.data.hop_length,
hps.data.win_length,
hps.data.mel_fmin,
hps.data.mel_fmax,
)
y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
# Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
with autocast(device_type=device.type, enabled=False):
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
y_d_hat_r,
y_d_hat_g,
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = map(
lambda x: x.to(device, non_blocking=True),
(ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths),
)
loss_disc_all = loss_disc
sv_emb = None
ssl.requires_grad = False
optim_d.zero_grad()
scaler.scale(loss_disc_all).backward()
scaler.unscale_(optim_d)
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
scaler.step(optim_d)
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
# Generator
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
with autocast(device_type=device.type, enabled=False):
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
loss_fm = feature_loss(fmap_r, fmap_g)
loss_gen, losses_gen = generator_loss(y_d_hat_g)
loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
scaler.step(optim_g)
scaler.update()
if device.index == 0:
assert progress
assert step_task
progress.advance(step_task, 1)
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]["lr"]
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
logger.info(
"Train Epoch: {} [{:.0f}%]".format(
epoch,
100.0 * batch_idx / len(train_loader),
)
)
logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {
"loss/g/total": loss_gen_all,
"loss/d/total": loss_disc_all,
"learning_rate": lr,
"grad_norm_d": grad_norm_d,
"grad_norm_g": grad_norm_g,
}
scalar_dict.update(
{
"loss/g/fm": loss_fm,
"loss/g/mel": loss_mel,
"loss/g/kl_ssl": kl_ssl,
"loss/g/kl": loss_kl,
}
)
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
image_dict = None
try: # Some people installed the wrong version of matplotlib.
image_dict = {
"slice/mel_org": utils.plot_spectrogram_to_numpy(
y_mel[0].data.cpu().numpy(),
),
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].data.cpu().numpy(),
),
"all/mel": utils.plot_spectrogram_to_numpy(
mel[0].data.cpu().numpy(),
),
"all/stats_ssl": utils.plot_spectrogram_to_numpy(
stats_ssl[0].data.cpu().numpy(),
),
}
except Exception as _:
pass
if image_dict:
utils.summarize(
writer=writer,
global_step=global_step,
images=image_dict,
scalars=scalar_dict,
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
if hps.model.version in {"v2Pro", "v2ProPlus"}:
(y_hat, kl_ssl, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl) = net_g(
ssl, spec, spec_lengths, text, text_lengths, sv_emb
)
else:
utils.summarize(
writer=writer,
global_step=global_step,
scalars=scalar_dict,
(
y_hat,
kl_ssl,
ids_slice,
x_mask,
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
stats_ssl,
) = net_g(ssl, spec, spec_lengths, text, text_lengths)
mel = spec_to_mel_torch(
spec,
hps.data.filter_length,
hps.data.n_mel_channels,
hps.data.sampling_rate,
hps.data.mel_fmin,
hps.data.mel_fmax,
)
y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1),
hps.data.filter_length,
hps.data.n_mel_channels,
hps.data.sampling_rate,
hps.data.hop_length,
hps.data.win_length,
hps.data.mel_fmin,
hps.data.mel_fmax,
)
y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
# Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
with autocast(device_type=device.type, enabled=False):
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
y_d_hat_r,
y_d_hat_g,
)
global_step += 1
loss_disc_all = loss_disc
optim_d.zero_grad()
scaler.scale(loss_disc_all).backward()
scaler.unscale_(optim_d)
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
scaler.step(optim_d)
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
# Generator
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
with autocast(device_type=device.type, enabled=False):
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
loss_fm = feature_loss(fmap_r, fmap_g)
loss_gen, losses_gen = generator_loss(y_d_hat_g)
loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
scaler.step(optim_g)
scaler.update()
if device.index == 0 and progress is not None and step_task is not None:
progress.advance(step_task, 1)
if device.index == 0:
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]["lr"]
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
logger.info(
"Train Epoch: {} [{:.0f}%]".format(
epoch,
100.0 * batch_idx / len(train_loader),
)
)
logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {
"loss/g/total": loss_gen_all,
"loss/d/total": loss_disc_all,
"learning_rate": lr,
"grad_norm_d": grad_norm_d,
"grad_norm_g": grad_norm_g,
}
scalar_dict.update(
{
"loss/g/fm": loss_fm,
"loss/g/mel": loss_mel,
"loss/g/kl_ssl": kl_ssl,
"loss/g/kl": loss_kl,
}
)
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
image_dict = None
try: # Some people installed the wrong version of matplotlib.
image_dict = {
"slice/mel_org": utils.plot_spectrogram_to_numpy(
y_mel[0].data.cpu().numpy(),
),
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].data.cpu().numpy(),
),
"all/mel": utils.plot_spectrogram_to_numpy(
mel[0].data.cpu().numpy(),
),
"all/stats_ssl": utils.plot_spectrogram_to_numpy(
stats_ssl[0].data.cpu().numpy(),
),
}
except Exception as _:
pass
if image_dict:
utils.summarize(
writer=writer,
global_step=global_step,
images=image_dict,
scalars=scalar_dict,
)
else:
utils.summarize(
writer=writer,
global_step=global_step,
scalars=scalar_dict,
)
global_step += 1
if hps.train.if_save_latest == 0:
utils.save_checkpoint(
net_g,
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
f"G_{global_step}.pth",
),
logger,
)
# Save the discriminator checkpoint alongside the generator checkpoint.
# Both path components must be f-strings: as plain literals the
# "{hps...}" / "{global_step}" placeholders were emitted verbatim, so the
# file was written under a directory literally named
# "{hps.data.exp_dir}/logs_s2_{hps.model.version}" instead of the
# experiment's log directory.
utils.save_checkpoint(
    net_d,
    optim_d,
    hps.train.learning_rate,
    epoch,
    os.path.join(
        f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
        f"D_{global_step}.pth",
    ),
    logger,
)
else:
utils.save_checkpoint(
net_g,
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
"G_233333333333.pth",
),
logger,
)
utils.save_checkpoint(
net_d,
optim_d,
hps.train.learning_rate,
epoch,
os.path.join(
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
"D_233333333333.pth",
),
logger,
)
if epoch % hps.train.save_every_epoch == 0 and device.index == 0:
if hps.train.if_save_latest == 0:
utils.save_checkpoint(
net_g,
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
f"G_{global_step}.pth",
),
)
utils.save_checkpoint(
net_d,
optim_d,
hps.train.learning_rate,
epoch,
os.path.join(
"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
"D_{global_step}.pth",
),
)
else:
utils.save_checkpoint(
net_g,
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
"G_233333333333.pth",
),
)
utils.save_checkpoint(
net_d,
optim_d,
hps.train.learning_rate,
epoch,
os.path.join(
"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
"D_233333333333.pth",
),
)
if device.index == 0 and hps.train.if_save_every_weights is True:
if hps.train.if_save_every_weights is True:
if hasattr(net_g, "module"):
ckpt = net_g.module.state_dict()
else:

View File

@ -8,7 +8,7 @@ from random import randint
import torch
import torch.distributed as dist
from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeRemainingColumn
from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from torch.multiprocessing.spawn import spawn
@ -208,9 +208,10 @@ def run(rank, n_gpus, hps):
TextColumn("[cyan]{task.description}"),
BarColumn(),
TextColumn("{task.completed}/{task.total}"),
SpeedColumnIteration(show_speed=True),
TimeRemainingColumn(elapsed_when_finished=True),
TimeElapsedColumn(),
console=console,
redirect_stderr=True,
redirect_stdout=True,
)
if rank == 0
else nullcontext() as progress
@ -221,7 +222,6 @@ def run(rank, n_gpus, hps):
total=int(hps.train.epochs),
completed=int(epoch_str) - 1,
)
step_task: TaskID | None = progress.add_task("Steps", total=len(train_loader))
else:
epoch_task = step_task = None
@ -229,8 +229,7 @@ def run(rank, n_gpus, hps):
if rank == 0:
assert epoch_task is not None
assert progress is not None
assert step_task is not None
progress.reset(step_task, total=len(train_loader))
progress.advance(epoch_task, 1)
train_and_evaluate(
device,
epoch,
@ -242,7 +241,6 @@ def run(rank, n_gpus, hps):
(train_loader, None),
logger,
(writer, writer_eval),
(progress, step_task),
)
progress.advance(epoch_task, 1)
else:
@ -257,10 +255,11 @@ def run(rank, n_gpus, hps):
(train_loader, None),
None,
(None, None),
(None, None),
)
scheduler_g.step()
if rank == 0:
assert progress
progress.stop()
logger.info("Training Done")
sys.exit(0)
@ -276,101 +275,122 @@ def train_and_evaluate(
loaders,
logger,
writers,
progress_group: tuple[Progress, TaskID] | tuple[None, None],
):
net_g, net_d = nets
optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers
train_loader, eval_loader = loaders
writer, writer_eval = writers
progress, step_task = progress_group
train_loader.batch_sampler.set_epoch(epoch)
global global_step
net_g.train()
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
train_loader
with (
Progress(
TextColumn("[cyan]{task.description}"),
BarColumn(),
TextColumn("{task.completed}/{task.total}"),
SpeedColumnIteration(show_speed=True),
TimeRemainingColumn(elapsed_when_finished=True),
console=console,
redirect_stderr=True,
redirect_stdout=True,
transient=not (int(epoch) == int(hps.train.epochs)),
)
if device.index == 0
else nullcontext() as progress
):
spec, spec_lengths = spec.to(device, non_blocking=True), spec_lengths.to(device, non_blocking=True)
mel, mel_lengths = mel.to(device, non_blocking=True), mel_lengths.to(device, non_blocking=True)
ssl = ssl.to(device, non_blocking=True)
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.to(device, non_blocking=True), text_lengths.to(device, non_blocking=True)
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
cfm_loss = net_g(
ssl,
spec,
mel,
ssl_lengths,
spec_lengths,
text,
text_lengths,
mel_lengths,
use_grad_ckpt=hps.train.grad_ckpt,
)
loss_gen_all = cfm_loss
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
scaler.step(optim_g)
scaler.update()
if device.index == 0:
assert progress
assert step_task
progress.advance(step_task, 1)
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]["lr"]
# losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
losses = [cfm_loss]
logger.info(
"Train Epoch: {} [{:.0f}%]".format(
epoch,
100.0 * batch_idx / len(train_loader),
)
)
logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
utils.summarize(
writer=writer,
global_step=global_step,
# images=image_dict,
scalars=scalar_dict,
)
global_step += 1
if epoch % hps.train.save_every_epoch == 0 and device.index == 0:
if hps.train.if_save_latest == 0:
utils.save_checkpoint(
net_g,
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
f"G_{global_step}.pth",
),
)
if isinstance(progress, Progress):
step_task: TaskID | None = progress.add_task("Steps", total=len(train_loader))
else:
utils.save_checkpoint(
net_g,
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
"G_233333333333.pth",
),
)
step_task = None
if device.index == 0 and hps.train.if_save_every_weights is True:
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
train_loader
):
spec, spec_lengths = spec.to(device, non_blocking=True), spec_lengths.to(device, non_blocking=True)
mel, mel_lengths = mel.to(device, non_blocking=True), mel_lengths.to(device, non_blocking=True)
ssl = ssl.to(device, non_blocking=True)
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.to(device, non_blocking=True), text_lengths.to(device, non_blocking=True)
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
cfm_loss = net_g(
ssl,
spec,
mel,
ssl_lengths,
spec_lengths,
text,
text_lengths,
mel_lengths,
use_grad_ckpt=hps.train.grad_ckpt,
)
loss_gen_all = cfm_loss
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
scaler.step(optim_g)
scaler.update()
if device.index == 0 and progress is not None and step_task is not None:
progress.advance(step_task, 1)
if device.index == 0:
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]["lr"]
# losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
losses = [cfm_loss]
logger.info(
"Train Epoch: {} [{:.0f}%]".format(
epoch,
100.0 * batch_idx / len(train_loader),
)
)
logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
utils.summarize(
writer=writer,
global_step=global_step,
# images=image_dict,
scalars=scalar_dict,
)
global_step += 1
if hps.train.if_save_latest == 0:
utils.save_checkpoint(
net_g,
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
f"G_{global_step}.pth",
),
logger,
)
else:
utils.save_checkpoint(
net_g,
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
"G_233333333333.pth",
),
logger,
)
if epoch % hps.train.save_every_epoch == 0 and device.index == 0:
if hps.train.if_save_every_weights is True:
if hasattr(net_g, "module"):
ckpt = net_g.module.state_dict()
else:

View File

@ -12,14 +12,13 @@ import torch
import torch.distributed as dist
from peft import LoraConfig, get_peft_model
from rich import print
from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeRemainingColumn
from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from torch.multiprocessing.spawn import spawn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import GPT_SoVITS.utils as utils
from GPT_SoVITS.Accelerate import console, logger
@ -234,9 +233,10 @@ def run(rank, n_gpus, hps):
TextColumn("[cyan]{task.description}"),
BarColumn(),
TextColumn("{task.completed}/{task.total}"),
SpeedColumnIteration(show_speed=True),
TimeRemainingColumn(elapsed_when_finished=True),
TimeElapsedColumn(),
console=console,
redirect_stderr=True,
redirect_stdout=True,
)
if rank == 0
else nullcontext() as progress
@ -247,7 +247,6 @@ def run(rank, n_gpus, hps):
total=int(hps.train.epochs),
completed=int(epoch_str) - 1,
)
step_task: TaskID | None = progress.add_task("steps", total=len(train_loader))
else:
epoch_task = step_task = None
@ -255,8 +254,7 @@ def run(rank, n_gpus, hps):
if rank == 0:
assert epoch_task is not None
assert progress is not None
assert step_task is not None
progress.reset(step_task, total=len(train_loader))
progress.advance(epoch_task, 1)
train_and_evaluate(
device,
epoch,
@ -268,9 +266,7 @@ def run(rank, n_gpus, hps):
(train_loader, None),
logger,
(writer, writer_eval),
(progress, step_task),
)
progress.advance(epoch_task, 1)
else:
train_and_evaluate(
device,
@ -283,10 +279,11 @@ def run(rank, n_gpus, hps):
(train_loader, None),
None,
(None, None),
(None, None),
)
scheduler_g.step()
if rank == 0:
assert progress
progress.stop()
logger.info("Training Done")
sys.exit(0)
@ -302,84 +299,97 @@ def train_and_evaluate(
loaders,
logger,
writers,
progress_group: tuple[Progress, TaskID] | tuple[None, None],
):
net_g, net_d = nets
optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers
train_loader, eval_loader = loaders
writer, writer_eval = writers
progress, step_task = progress_group
train_loader.batch_sampler.set_epoch(epoch)
global global_step
net_g.train()
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
tqdm(train_loader)
with (
Progress(
TextColumn("[cyan]{task.description}"),
BarColumn(),
TextColumn("{task.completed}/{task.total}"),
SpeedColumnIteration(show_speed=True),
TimeRemainingColumn(elapsed_when_finished=True),
console=console,
redirect_stderr=True,
redirect_stdout=True,
transient=not (int(epoch) == int(hps.train.epochs)),
)
if device.index == 0
else nullcontext() as progress
):
spec, spec_lengths = spec.to(device, non_blocking=True), spec_lengths.to(device, non_blocking=True)
mel, mel_lengths = mel.to(device, non_blocking=True), mel_lengths.to(device, non_blocking=True)
ssl = ssl.to(device, non_blocking=True)
ssl.requires_grad = False
text, text_lengths = text.to(device, non_blocking=True), text_lengths.to(device, non_blocking=True)
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
cfm_loss = net_g(
ssl,
spec,
mel,
ssl_lengths,
spec_lengths,
text,
text_lengths,
mel_lengths,
use_grad_ckpt=hps.train.grad_ckpt,
)
loss_gen_all = cfm_loss
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
scaler.step(optim_g)
scaler.update()
if device.index == 0:
assert progress
assert step_task
progress.advance(step_task, 1)
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]["lr"]
losses = [cfm_loss]
logger.info("Train Epoch: {} [{:.0f}%]".format(epoch, 100.0 * batch_idx / len(train_loader)))
logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
utils.summarize(
writer=writer,
global_step=global_step,
scalars=scalar_dict,
)
global_step += 1
if epoch % hps.train.save_every_epoch == 0 and device.index == 0:
if hps.train.if_save_latest == 0:
utils.save_checkpoint(
net_g,
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(save_root, f"G_{global_step}.pth"),
)
if isinstance(progress, Progress):
step_task: TaskID | None = progress.add_task("Steps", total=len(train_loader))
else:
utils.save_checkpoint(
net_g,
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(save_root, "G_233333333333.pth"),
)
if device.index == 0 and hps.train.if_save_every_weights is True:
step_task = None
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
train_loader
):
spec, spec_lengths = spec.to(device, non_blocking=True), spec_lengths.to(device, non_blocking=True)
mel, mel_lengths = mel.to(device, non_blocking=True), mel_lengths.to(device, non_blocking=True)
ssl = ssl.to(device, non_blocking=True)
ssl.requires_grad = False
text, text_lengths = text.to(device, non_blocking=True), text_lengths.to(device, non_blocking=True)
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
cfm_loss = net_g(
ssl,
spec,
mel,
ssl_lengths,
spec_lengths,
text,
text_lengths,
mel_lengths,
use_grad_ckpt=hps.train.grad_ckpt,
)
loss_gen_all = cfm_loss
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
scaler.step(optim_g)
scaler.update()
if device.index == 0 and progress is not None and step_task is not None:
progress.advance(step_task, 1)
if device.index == 0:
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]["lr"]
losses = [cfm_loss]
logger.info("Train Epoch: {} [{:.0f}%]".format(epoch, 100.0 * batch_idx / len(train_loader)))
logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
utils.summarize(
writer=writer,
global_step=global_step,
scalars=scalar_dict,
)
global_step += 1
if hps.train.if_save_latest == 0:
utils.save_checkpoint(
net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(save_root, f"G_{global_step}.pth"), logger
)
else:
utils.save_checkpoint(
net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(save_root, "G_233333333333.pth"), logger
)
if epoch % hps.train.save_every_epoch == 0 and device.index == 0:
if hps.train.if_save_every_weights is True:
if hasattr(net_g, "module"):
ckpt = net_g.module.state_dict()
else:

View File

@ -70,7 +70,7 @@ def save(fea, path): #####fix issue: torch.save doesn't support chinese path
shutil.move(tmp_path, "%s/%s" % (dir, name))
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path, logger):
logger.info("Saving model and optimizer state at iteration {} to {}".format(iteration, checkpoint_path))
if hasattr(model, "module"):
state_dict = model.module.state_dict()

View File

@ -605,7 +605,7 @@ def open1Bb(
data["output_dir"] = f"{s1_dir}/logs_s1_{version}"
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = str(gpu_numbers1Ba).strip("[]").replace(" ", "")
env["CUDA_VISIBLE_DEVICES"] = str(gpu_numbers).strip("[]").replace(" ", "")
tmp_config_path = f"{tmp}/tmp_s1.yaml"
with open(tmp_config_path, "w") as f:
@ -890,7 +890,6 @@ def open1b(
cmd = [
python_exec, "-s", "GPT_SoVITS/prepare_datasets/2-get-hubert-sv-wav32k.py",
"--inp-list", inp_text,
"--wav-dir", inp_wav_dir,
"--opt", opt_dir,
"--cnhubert", ssl_pretrained_dir,
"--device", infer_device.type,
@ -899,6 +898,9 @@ def open1b(
]
# fmt: on
if inp_wav_dir:
cmd.extend(["--wav-dir", inp_wav_dir])
if "Pro" in version:
cmd.extend(["--sv", sv_path])
@ -1124,7 +1126,6 @@ def open1abc(
cmd_2 = [
python_exec, "-s", "GPT_SoVITS/prepare_datasets/2-get-hubert-sv-wav32k.py",
"--inp-list", inp_text,
"--wav-dir", inp_wav_dir,
"--opt", opt_dir,
"--cnhubert", ssl_pretrained_dir,
"--device", infer_device.type,
@ -1133,6 +1134,9 @@ def open1abc(
]
# fmt: on
if inp_wav_dir:
cmd_2.extend(["--wav-dir", inp_wav_dir])
if "Pro" in version:
cmd_2.extend(["--sv", sv_path])
@ -1231,7 +1235,7 @@ def close1abc():
kill_process(p1abc.pid, process_name_1abc)
except Exception as _:
traceback.print_exc()
ps1abc = []
ps1abc = [None] * 3
return (
process_info(process_name_1abc, "closed"),
gr.update(visible=True),