mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-07-05 05:31:07 +08:00
.
This commit is contained in:
parent
820195917a
commit
49667f44e8
@ -39,12 +39,12 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1
|
||||
num_replicas = dist.get_world_size() if torch.cuda.device_count() > 1 else 1
|
||||
if rank is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = dist.get_rank() if torch.cuda.is_available() else 0
|
||||
if torch.cuda.is_available():
|
||||
rank = dist.get_rank() if torch.cuda.device_count() > 1 else 0
|
||||
if torch.cuda.device_count() > 1:
|
||||
torch.cuda.set_device(rank)
|
||||
if rank >= num_replicas or rank < 0:
|
||||
raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
|
||||
|
||||
@ -48,6 +48,7 @@ with contextlib.suppress(ImportError):
|
||||
warnings.filterwarnings(
|
||||
"ignore", message="MPS: The constant padding of more than 3 dimensions is not currently supported natively."
|
||||
)
|
||||
warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*")
|
||||
|
||||
logging.getLogger("markdown_it").setLevel(logging.ERROR)
|
||||
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||
|
||||
@ -12,11 +12,11 @@ from typing import List
|
||||
import torch
|
||||
import torch.multiprocessing as tmp
|
||||
import typer
|
||||
from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
|
||||
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
|
||||
from torch.multiprocessing.spawn import spawn
|
||||
from transformers import BertForMaskedLM, BertTokenizerFast
|
||||
|
||||
from GPT_SoVITS.Accelerate.logger import console, logger
|
||||
from GPT_SoVITS.Accelerate.logger import console, logger, SpeedColumnIteration
|
||||
from GPT_SoVITS.text.cleaner import clean_text
|
||||
from tools.my_utils import clean_path
|
||||
|
||||
@ -250,7 +250,7 @@ def main(
|
||||
TextColumn("[cyan]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("{task.completed}/{task.total}"),
|
||||
TaskProgressColumn(show_speed=True),
|
||||
SpeedColumnIteration(show_speed=True),
|
||||
TimeRemainingColumn(elapsed_when_finished=True),
|
||||
console=console,
|
||||
redirect_stderr=False,
|
||||
@ -286,7 +286,7 @@ def main(
|
||||
console.print(e)
|
||||
finally:
|
||||
logger.critical(f"Worker PID {p.pid} crashed with exit code {p.exitcode}.")
|
||||
raise SystemExit(1)
|
||||
sys.exit(1)
|
||||
ctx.join()
|
||||
|
||||
with open(merged_path, "w", encoding="utf8") as fout:
|
||||
|
||||
@ -7,21 +7,24 @@ import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.multiprocessing as tmp
|
||||
import torchaudio
|
||||
import typer
|
||||
from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
|
||||
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
|
||||
from scipy.io import wavfile
|
||||
from torch.multiprocessing.spawn import spawn
|
||||
|
||||
from GPT_SoVITS.Accelerate.logger import console, logger
|
||||
from GPT_SoVITS.Accelerate.logger import console, logger, SpeedColumnIteration
|
||||
from GPT_SoVITS.eres2net.ERes2NetV2 import ERes2NetV2
|
||||
from GPT_SoVITS.feature_extractor import cnhubert as cnhubert_mod
|
||||
from tools.my_utils import clean_path, load_audio
|
||||
|
||||
warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*")
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
tmp.set_start_method("spawn", force=True)
|
||||
@ -351,7 +354,7 @@ def main(
|
||||
TextColumn("[cyan]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("{task.completed}/{task.total}"),
|
||||
TaskProgressColumn(show_speed=True),
|
||||
SpeedColumnIteration(show_speed=True),
|
||||
TimeRemainingColumn(elapsed_when_finished=True),
|
||||
console=console,
|
||||
redirect_stderr=False,
|
||||
@ -390,7 +393,7 @@ def main(
|
||||
console.print(e)
|
||||
finally:
|
||||
logger.critical(f"Worker PID {p.pid} crashed with exit code {p.exitcode}.")
|
||||
raise SystemExit(1)
|
||||
sys.exit(1)
|
||||
|
||||
ctx.join()
|
||||
|
||||
|
||||
@ -12,10 +12,10 @@ from typing import List, Tuple
|
||||
import torch
|
||||
import torch.multiprocessing as tmp
|
||||
import typer
|
||||
from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
|
||||
from rich.progress import BarColumn, Progress, TimeRemainingColumn, TextColumn
|
||||
from torch.multiprocessing.spawn import spawn
|
||||
|
||||
from GPT_SoVITS.Accelerate.logger import console, logger
|
||||
from GPT_SoVITS.Accelerate.logger import console, logger, SpeedColumnIteration
|
||||
from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3
|
||||
from GPT_SoVITS.process_ckpt import inspect_version
|
||||
from tools.my_utils import DictToAttrRecursive, clean_path
|
||||
@ -244,7 +244,7 @@ def main(
|
||||
TextColumn("[cyan]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("{task.completed}/{task.total}"),
|
||||
TaskProgressColumn(show_speed=True),
|
||||
SpeedColumnIteration(show_speed=True),
|
||||
TimeRemainingColumn(elapsed_when_finished=True),
|
||||
console=console,
|
||||
) as progress:
|
||||
@ -279,7 +279,7 @@ def main(
|
||||
console.print(e)
|
||||
finally:
|
||||
logger.critical(f"Worker PID {p.pid} crashed with exit code {p.exitcode}.")
|
||||
raise SystemExit(1)
|
||||
sys.exit(1)
|
||||
ctx.join()
|
||||
|
||||
with open(merged_path, "w", encoding="utf8") as fout:
|
||||
|
||||
@ -29,7 +29,7 @@ os.environ["MASTER_ADDR"] = "localhost"
|
||||
if platform.system() == "Windows":
|
||||
os.environ["USE_LIBUV"] = "0"
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_grad_enabled(True)
|
||||
|
||||
|
||||
class ARModelCheckpoint(ModelCheckpoint):
|
||||
@ -52,15 +52,15 @@ class ARModelCheckpoint(ModelCheckpoint):
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
if self._should_save_on_train_epoch_end(trainer):
|
||||
monitor_candidates = self._monitor_candidates(trainer)
|
||||
self._save_topk_checkpoint(trainer, monitor_candidates)
|
||||
if self.if_save_latest is True: # 如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt
|
||||
to_clean = list(os.listdir(self.dirpath))
|
||||
for name in to_clean:
|
||||
try:
|
||||
os.remove(f"{self.dirpath}/{name}")
|
||||
except Exception as _:
|
||||
pass
|
||||
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
|
||||
self._save_topk_checkpoint(trainer, monitor_candidates)
|
||||
if self.if_save_latest is True: # 如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt
|
||||
to_clean = list(os.listdir(self.dirpath))
|
||||
for name in to_clean:
|
||||
try:
|
||||
os.remove(f"{self.dirpath}/{name}")
|
||||
except Exception as _:
|
||||
pass
|
||||
if self.if_save_every_weights is True:
|
||||
to_save_od: OrderedDict[str, Any] = OrderedDict()
|
||||
to_save_od["weight"] = OrderedDict()
|
||||
@ -94,7 +94,7 @@ def main(args):
|
||||
process_group_backend="nccl" if platform.system() != "Windows" else "gloo", find_unused_parameters=False
|
||||
)
|
||||
else:
|
||||
strategy = SingleDeviceStrategy("cuda")
|
||||
strategy = SingleDeviceStrategy("cuda:0")
|
||||
else:
|
||||
strategy = SingleDeviceStrategy("cpu")
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ from random import randint
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeRemainingColumn
|
||||
from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeElapsedColumn, TimeRemainingColumn
|
||||
from torch.amp.autocast_mode import autocast
|
||||
from torch.amp.grad_scaler import GradScaler
|
||||
from torch.multiprocessing.spawn import spawn
|
||||
@ -290,9 +290,10 @@ def run(rank, n_gpus, hps):
|
||||
TextColumn("[cyan]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("{task.completed}/{task.total}"),
|
||||
SpeedColumnIteration(show_speed=True),
|
||||
TimeRemainingColumn(elapsed_when_finished=True),
|
||||
TimeElapsedColumn(),
|
||||
console=console,
|
||||
redirect_stderr=True,
|
||||
redirect_stdout=True,
|
||||
)
|
||||
if rank == 0
|
||||
else nullcontext() as progress
|
||||
@ -303,7 +304,6 @@ def run(rank, n_gpus, hps):
|
||||
total=int(hps.train.epochs),
|
||||
completed=int(epoch_str) - 1,
|
||||
)
|
||||
step_task: TaskID | None = progress.add_task("Steps", total=len(train_loader))
|
||||
else:
|
||||
epoch_task = step_task = None
|
||||
|
||||
@ -311,8 +311,7 @@ def run(rank, n_gpus, hps):
|
||||
if rank == 0:
|
||||
assert epoch_task is not None
|
||||
assert progress is not None
|
||||
assert step_task is not None
|
||||
progress.reset(step_task, total=len(train_loader))
|
||||
progress.advance(epoch_task, 1)
|
||||
train_and_evaluate(
|
||||
device,
|
||||
epoch,
|
||||
@ -325,9 +324,7 @@ def run(rank, n_gpus, hps):
|
||||
(train_loader, None),
|
||||
logger,
|
||||
(writer, writer_eval),
|
||||
(progress, step_task),
|
||||
)
|
||||
progress.advance(epoch_task, 1)
|
||||
else:
|
||||
train_and_evaluate(
|
||||
device,
|
||||
@ -340,11 +337,12 @@ def run(rank, n_gpus, hps):
|
||||
(train_loader, None),
|
||||
None,
|
||||
(None, None),
|
||||
(None, None),
|
||||
)
|
||||
scheduler_g.step()
|
||||
scheduler_d.step()
|
||||
if rank == 0:
|
||||
assert progress
|
||||
progress.stop()
|
||||
logger.info("Training Done")
|
||||
sys.exit(0)
|
||||
|
||||
@ -360,215 +358,242 @@ def train_and_evaluate(
|
||||
loaders,
|
||||
logger,
|
||||
writers,
|
||||
progress_group: tuple[Progress, TaskID] | tuple[None, None],
|
||||
):
|
||||
net_g, net_d = nets
|
||||
optim_g, optim_d = optims
|
||||
# scheduler_g, scheduler_d = schedulers
|
||||
train_loader, eval_loader = loaders
|
||||
writer, writer_eval = writers
|
||||
progress, step_task = progress_group
|
||||
|
||||
train_loader.batch_sampler.set_epoch(epoch)
|
||||
global global_step
|
||||
|
||||
for batch_idx, data in enumerate(train_loader):
|
||||
if hps.model.version in {"v2Pro", "v2ProPlus"}:
|
||||
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = data
|
||||
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = map(
|
||||
lambda x: x.to(device, non_blocking=True),
|
||||
(ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb),
|
||||
)
|
||||
else:
|
||||
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data
|
||||
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = map(
|
||||
lambda x: x.to(device, non_blocking=True),
|
||||
(ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths),
|
||||
)
|
||||
sv_emb = None
|
||||
ssl.requires_grad = False
|
||||
net_g.train()
|
||||
net_d.train()
|
||||
|
||||
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
|
||||
with (
|
||||
Progress(
|
||||
TextColumn("[cyan]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("{task.completed}/{task.total}"),
|
||||
SpeedColumnIteration(show_speed=True),
|
||||
TimeRemainingColumn(elapsed_when_finished=True),
|
||||
console=console,
|
||||
redirect_stderr=True,
|
||||
redirect_stdout=True,
|
||||
transient=not (int(epoch) == int(hps.train.epochs)),
|
||||
)
|
||||
if device.index == 0
|
||||
else nullcontext() as progress
|
||||
):
|
||||
if isinstance(progress, Progress):
|
||||
step_task: TaskID | None = progress.add_task("Steps", total=len(train_loader))
|
||||
else:
|
||||
step_task = None
|
||||
|
||||
for batch_idx, data in enumerate(train_loader):
|
||||
if hps.model.version in {"v2Pro", "v2ProPlus"}:
|
||||
(y_hat, kl_ssl, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl) = net_g(
|
||||
ssl, spec, spec_lengths, text, text_lengths, sv_emb
|
||||
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = data
|
||||
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = map(
|
||||
lambda x: x.to(device, non_blocking=True),
|
||||
(ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb),
|
||||
)
|
||||
else:
|
||||
(
|
||||
y_hat,
|
||||
kl_ssl,
|
||||
ids_slice,
|
||||
x_mask,
|
||||
z_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
stats_ssl,
|
||||
) = net_g(ssl, spec, spec_lengths, text, text_lengths)
|
||||
|
||||
mel = spec_to_mel_torch(
|
||||
spec,
|
||||
hps.data.filter_length,
|
||||
hps.data.n_mel_channels,
|
||||
hps.data.sampling_rate,
|
||||
hps.data.mel_fmin,
|
||||
hps.data.mel_fmax,
|
||||
)
|
||||
y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
|
||||
y_hat_mel = mel_spectrogram_torch(
|
||||
y_hat.squeeze(1),
|
||||
hps.data.filter_length,
|
||||
hps.data.n_mel_channels,
|
||||
hps.data.sampling_rate,
|
||||
hps.data.hop_length,
|
||||
hps.data.win_length,
|
||||
hps.data.mel_fmin,
|
||||
hps.data.mel_fmax,
|
||||
)
|
||||
|
||||
y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
|
||||
|
||||
# Discriminator
|
||||
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
|
||||
with autocast(device_type=device.type, enabled=False):
|
||||
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
||||
y_d_hat_r,
|
||||
y_d_hat_g,
|
||||
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data
|
||||
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = map(
|
||||
lambda x: x.to(device, non_blocking=True),
|
||||
(ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths),
|
||||
)
|
||||
loss_disc_all = loss_disc
|
||||
sv_emb = None
|
||||
ssl.requires_grad = False
|
||||
|
||||
optim_d.zero_grad()
|
||||
scaler.scale(loss_disc_all).backward()
|
||||
scaler.unscale_(optim_d)
|
||||
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
|
||||
scaler.step(optim_d)
|
||||
|
||||
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
|
||||
# Generator
|
||||
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
|
||||
with autocast(device_type=device.type, enabled=False):
|
||||
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
|
||||
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
|
||||
|
||||
loss_fm = feature_loss(fmap_r, fmap_g)
|
||||
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
||||
loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl
|
||||
|
||||
optim_g.zero_grad()
|
||||
scaler.scale(loss_gen_all).backward()
|
||||
scaler.unscale_(optim_g)
|
||||
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
||||
scaler.step(optim_g)
|
||||
scaler.update()
|
||||
|
||||
if device.index == 0:
|
||||
assert progress
|
||||
assert step_task
|
||||
progress.advance(step_task, 1)
|
||||
if global_step % hps.train.log_interval == 0:
|
||||
lr = optim_g.param_groups[0]["lr"]
|
||||
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
|
||||
logger.info(
|
||||
"Train Epoch: {} [{:.0f}%]".format(
|
||||
epoch,
|
||||
100.0 * batch_idx / len(train_loader),
|
||||
)
|
||||
)
|
||||
logger.info([x.item() for x in losses] + [global_step, lr])
|
||||
|
||||
scalar_dict = {
|
||||
"loss/g/total": loss_gen_all,
|
||||
"loss/d/total": loss_disc_all,
|
||||
"learning_rate": lr,
|
||||
"grad_norm_d": grad_norm_d,
|
||||
"grad_norm_g": grad_norm_g,
|
||||
}
|
||||
scalar_dict.update(
|
||||
{
|
||||
"loss/g/fm": loss_fm,
|
||||
"loss/g/mel": loss_mel,
|
||||
"loss/g/kl_ssl": kl_ssl,
|
||||
"loss/g/kl": loss_kl,
|
||||
}
|
||||
)
|
||||
|
||||
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
|
||||
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
|
||||
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
|
||||
image_dict = None
|
||||
try: # Some people installed the wrong version of matplotlib.
|
||||
image_dict = {
|
||||
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
||||
y_mel[0].data.cpu().numpy(),
|
||||
),
|
||||
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
|
||||
y_hat_mel[0].data.cpu().numpy(),
|
||||
),
|
||||
"all/mel": utils.plot_spectrogram_to_numpy(
|
||||
mel[0].data.cpu().numpy(),
|
||||
),
|
||||
"all/stats_ssl": utils.plot_spectrogram_to_numpy(
|
||||
stats_ssl[0].data.cpu().numpy(),
|
||||
),
|
||||
}
|
||||
except Exception as _:
|
||||
pass
|
||||
if image_dict:
|
||||
utils.summarize(
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
images=image_dict,
|
||||
scalars=scalar_dict,
|
||||
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
|
||||
if hps.model.version in {"v2Pro", "v2ProPlus"}:
|
||||
(y_hat, kl_ssl, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl) = net_g(
|
||||
ssl, spec, spec_lengths, text, text_lengths, sv_emb
|
||||
)
|
||||
else:
|
||||
utils.summarize(
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
scalars=scalar_dict,
|
||||
(
|
||||
y_hat,
|
||||
kl_ssl,
|
||||
ids_slice,
|
||||
x_mask,
|
||||
z_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
stats_ssl,
|
||||
) = net_g(ssl, spec, spec_lengths, text, text_lengths)
|
||||
|
||||
mel = spec_to_mel_torch(
|
||||
spec,
|
||||
hps.data.filter_length,
|
||||
hps.data.n_mel_channels,
|
||||
hps.data.sampling_rate,
|
||||
hps.data.mel_fmin,
|
||||
hps.data.mel_fmax,
|
||||
)
|
||||
y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
|
||||
y_hat_mel = mel_spectrogram_torch(
|
||||
y_hat.squeeze(1),
|
||||
hps.data.filter_length,
|
||||
hps.data.n_mel_channels,
|
||||
hps.data.sampling_rate,
|
||||
hps.data.hop_length,
|
||||
hps.data.win_length,
|
||||
hps.data.mel_fmin,
|
||||
hps.data.mel_fmax,
|
||||
)
|
||||
|
||||
y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
|
||||
|
||||
# Discriminator
|
||||
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
|
||||
with autocast(device_type=device.type, enabled=False):
|
||||
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
||||
y_d_hat_r,
|
||||
y_d_hat_g,
|
||||
)
|
||||
global_step += 1
|
||||
loss_disc_all = loss_disc
|
||||
|
||||
optim_d.zero_grad()
|
||||
scaler.scale(loss_disc_all).backward()
|
||||
scaler.unscale_(optim_d)
|
||||
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
|
||||
scaler.step(optim_d)
|
||||
|
||||
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
|
||||
# Generator
|
||||
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
|
||||
with autocast(device_type=device.type, enabled=False):
|
||||
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
|
||||
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
|
||||
|
||||
loss_fm = feature_loss(fmap_r, fmap_g)
|
||||
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
||||
loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl
|
||||
|
||||
optim_g.zero_grad()
|
||||
scaler.scale(loss_gen_all).backward()
|
||||
scaler.unscale_(optim_g)
|
||||
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
||||
scaler.step(optim_g)
|
||||
scaler.update()
|
||||
|
||||
if device.index == 0 and progress is not None and step_task is not None:
|
||||
progress.advance(step_task, 1)
|
||||
|
||||
if device.index == 0:
|
||||
if global_step % hps.train.log_interval == 0:
|
||||
lr = optim_g.param_groups[0]["lr"]
|
||||
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
|
||||
logger.info(
|
||||
"Train Epoch: {} [{:.0f}%]".format(
|
||||
epoch,
|
||||
100.0 * batch_idx / len(train_loader),
|
||||
)
|
||||
)
|
||||
logger.info([x.item() for x in losses] + [global_step, lr])
|
||||
|
||||
scalar_dict = {
|
||||
"loss/g/total": loss_gen_all,
|
||||
"loss/d/total": loss_disc_all,
|
||||
"learning_rate": lr,
|
||||
"grad_norm_d": grad_norm_d,
|
||||
"grad_norm_g": grad_norm_g,
|
||||
}
|
||||
scalar_dict.update(
|
||||
{
|
||||
"loss/g/fm": loss_fm,
|
||||
"loss/g/mel": loss_mel,
|
||||
"loss/g/kl_ssl": kl_ssl,
|
||||
"loss/g/kl": loss_kl,
|
||||
}
|
||||
)
|
||||
|
||||
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
|
||||
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
|
||||
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
|
||||
image_dict = None
|
||||
try: # Some people installed the wrong version of matplotlib.
|
||||
image_dict = {
|
||||
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
||||
y_mel[0].data.cpu().numpy(),
|
||||
),
|
||||
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
|
||||
y_hat_mel[0].data.cpu().numpy(),
|
||||
),
|
||||
"all/mel": utils.plot_spectrogram_to_numpy(
|
||||
mel[0].data.cpu().numpy(),
|
||||
),
|
||||
"all/stats_ssl": utils.plot_spectrogram_to_numpy(
|
||||
stats_ssl[0].data.cpu().numpy(),
|
||||
),
|
||||
}
|
||||
except Exception as _:
|
||||
pass
|
||||
if image_dict:
|
||||
utils.summarize(
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
images=image_dict,
|
||||
scalars=scalar_dict,
|
||||
)
|
||||
else:
|
||||
utils.summarize(
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
scalars=scalar_dict,
|
||||
)
|
||||
global_step += 1
|
||||
|
||||
if hps.train.if_save_latest == 0:
|
||||
utils.save_checkpoint(
|
||||
net_g,
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
|
||||
f"G_{global_step}.pth",
|
||||
),
|
||||
logger,
|
||||
)
|
||||
utils.save_checkpoint(
|
||||
net_d,
|
||||
optim_d,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
|
||||
"D_{global_step}.pth",
|
||||
),
|
||||
logger,
|
||||
)
|
||||
else:
|
||||
utils.save_checkpoint(
|
||||
net_g,
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
|
||||
"G_233333333333.pth",
|
||||
),
|
||||
logger,
|
||||
)
|
||||
utils.save_checkpoint(
|
||||
net_d,
|
||||
optim_d,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
|
||||
"D_233333333333.pth",
|
||||
),
|
||||
logger,
|
||||
)
|
||||
|
||||
if epoch % hps.train.save_every_epoch == 0 and device.index == 0:
|
||||
if hps.train.if_save_latest == 0:
|
||||
utils.save_checkpoint(
|
||||
net_g,
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
|
||||
f"G_{global_step}.pth",
|
||||
),
|
||||
)
|
||||
utils.save_checkpoint(
|
||||
net_d,
|
||||
optim_d,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
|
||||
"D_{global_step}.pth",
|
||||
),
|
||||
)
|
||||
else:
|
||||
utils.save_checkpoint(
|
||||
net_g,
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
|
||||
"G_233333333333.pth",
|
||||
),
|
||||
)
|
||||
utils.save_checkpoint(
|
||||
net_d,
|
||||
optim_d,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
|
||||
"D_233333333333.pth",
|
||||
),
|
||||
)
|
||||
if device.index == 0 and hps.train.if_save_every_weights is True:
|
||||
if hps.train.if_save_every_weights is True:
|
||||
if hasattr(net_g, "module"):
|
||||
ckpt = net_g.module.state_dict()
|
||||
else:
|
||||
|
||||
@ -8,7 +8,7 @@ from random import randint
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeRemainingColumn
|
||||
from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeElapsedColumn, TimeRemainingColumn
|
||||
from torch.amp.autocast_mode import autocast
|
||||
from torch.amp.grad_scaler import GradScaler
|
||||
from torch.multiprocessing.spawn import spawn
|
||||
@ -208,9 +208,10 @@ def run(rank, n_gpus, hps):
|
||||
TextColumn("[cyan]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("{task.completed}/{task.total}"),
|
||||
SpeedColumnIteration(show_speed=True),
|
||||
TimeRemainingColumn(elapsed_when_finished=True),
|
||||
TimeElapsedColumn(),
|
||||
console=console,
|
||||
redirect_stderr=True,
|
||||
redirect_stdout=True,
|
||||
)
|
||||
if rank == 0
|
||||
else nullcontext() as progress
|
||||
@ -221,7 +222,6 @@ def run(rank, n_gpus, hps):
|
||||
total=int(hps.train.epochs),
|
||||
completed=int(epoch_str) - 1,
|
||||
)
|
||||
step_task: TaskID | None = progress.add_task("Steps", total=len(train_loader))
|
||||
else:
|
||||
epoch_task = step_task = None
|
||||
|
||||
@ -229,8 +229,7 @@ def run(rank, n_gpus, hps):
|
||||
if rank == 0:
|
||||
assert epoch_task is not None
|
||||
assert progress is not None
|
||||
assert step_task is not None
|
||||
progress.reset(step_task, total=len(train_loader))
|
||||
progress.advance(epoch_task, 1)
|
||||
train_and_evaluate(
|
||||
device,
|
||||
epoch,
|
||||
@ -242,7 +241,6 @@ def run(rank, n_gpus, hps):
|
||||
(train_loader, None),
|
||||
logger,
|
||||
(writer, writer_eval),
|
||||
(progress, step_task),
|
||||
)
|
||||
progress.advance(epoch_task, 1)
|
||||
else:
|
||||
@ -257,10 +255,11 @@ def run(rank, n_gpus, hps):
|
||||
(train_loader, None),
|
||||
None,
|
||||
(None, None),
|
||||
(None, None),
|
||||
)
|
||||
scheduler_g.step()
|
||||
if rank == 0:
|
||||
assert progress
|
||||
progress.stop()
|
||||
logger.info("Training Done")
|
||||
sys.exit(0)
|
||||
|
||||
@ -276,101 +275,122 @@ def train_and_evaluate(
|
||||
loaders,
|
||||
logger,
|
||||
writers,
|
||||
progress_group: tuple[Progress, TaskID] | tuple[None, None],
|
||||
):
|
||||
net_g, net_d = nets
|
||||
optim_g, optim_d = optims
|
||||
# scheduler_g, scheduler_d = schedulers
|
||||
train_loader, eval_loader = loaders
|
||||
writer, writer_eval = writers
|
||||
progress, step_task = progress_group
|
||||
|
||||
train_loader.batch_sampler.set_epoch(epoch)
|
||||
global global_step
|
||||
|
||||
net_g.train()
|
||||
|
||||
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
|
||||
train_loader
|
||||
with (
|
||||
Progress(
|
||||
TextColumn("[cyan]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("{task.completed}/{task.total}"),
|
||||
SpeedColumnIteration(show_speed=True),
|
||||
TimeRemainingColumn(elapsed_when_finished=True),
|
||||
console=console,
|
||||
redirect_stderr=True,
|
||||
redirect_stdout=True,
|
||||
transient=not (int(epoch) == int(hps.train.epochs)),
|
||||
)
|
||||
if device.index == 0
|
||||
else nullcontext() as progress
|
||||
):
|
||||
spec, spec_lengths = spec.to(device, non_blocking=True), spec_lengths.to(device, non_blocking=True)
|
||||
mel, mel_lengths = mel.to(device, non_blocking=True), mel_lengths.to(device, non_blocking=True)
|
||||
ssl = ssl.to(device, non_blocking=True)
|
||||
ssl.requires_grad = False
|
||||
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
|
||||
text, text_lengths = text.to(device, non_blocking=True), text_lengths.to(device, non_blocking=True)
|
||||
|
||||
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
|
||||
cfm_loss = net_g(
|
||||
ssl,
|
||||
spec,
|
||||
mel,
|
||||
ssl_lengths,
|
||||
spec_lengths,
|
||||
text,
|
||||
text_lengths,
|
||||
mel_lengths,
|
||||
use_grad_ckpt=hps.train.grad_ckpt,
|
||||
)
|
||||
loss_gen_all = cfm_loss
|
||||
optim_g.zero_grad()
|
||||
scaler.scale(loss_gen_all).backward()
|
||||
scaler.unscale_(optim_g)
|
||||
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
||||
scaler.step(optim_g)
|
||||
scaler.update()
|
||||
|
||||
if device.index == 0:
|
||||
assert progress
|
||||
assert step_task
|
||||
progress.advance(step_task, 1)
|
||||
if global_step % hps.train.log_interval == 0:
|
||||
lr = optim_g.param_groups[0]["lr"]
|
||||
# losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
|
||||
losses = [cfm_loss]
|
||||
logger.info(
|
||||
"Train Epoch: {} [{:.0f}%]".format(
|
||||
epoch,
|
||||
100.0 * batch_idx / len(train_loader),
|
||||
)
|
||||
)
|
||||
logger.info([x.item() for x in losses] + [global_step, lr])
|
||||
|
||||
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
|
||||
utils.summarize(
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
# images=image_dict,
|
||||
scalars=scalar_dict,
|
||||
)
|
||||
|
||||
global_step += 1
|
||||
if epoch % hps.train.save_every_epoch == 0 and device.index == 0:
|
||||
if hps.train.if_save_latest == 0:
|
||||
utils.save_checkpoint(
|
||||
net_g,
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
|
||||
f"G_{global_step}.pth",
|
||||
),
|
||||
)
|
||||
|
||||
if isinstance(progress, Progress):
|
||||
step_task: TaskID | None = progress.add_task("Steps", total=len(train_loader))
|
||||
else:
|
||||
utils.save_checkpoint(
|
||||
net_g,
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
|
||||
"G_233333333333.pth",
|
||||
),
|
||||
)
|
||||
step_task = None
|
||||
|
||||
if device.index == 0 and hps.train.if_save_every_weights is True:
|
||||
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
|
||||
train_loader
|
||||
):
|
||||
spec, spec_lengths = spec.to(device, non_blocking=True), spec_lengths.to(device, non_blocking=True)
|
||||
mel, mel_lengths = mel.to(device, non_blocking=True), mel_lengths.to(device, non_blocking=True)
|
||||
ssl = ssl.to(device, non_blocking=True)
|
||||
ssl.requires_grad = False
|
||||
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
|
||||
text, text_lengths = text.to(device, non_blocking=True), text_lengths.to(device, non_blocking=True)
|
||||
|
||||
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
|
||||
cfm_loss = net_g(
|
||||
ssl,
|
||||
spec,
|
||||
mel,
|
||||
ssl_lengths,
|
||||
spec_lengths,
|
||||
text,
|
||||
text_lengths,
|
||||
mel_lengths,
|
||||
use_grad_ckpt=hps.train.grad_ckpt,
|
||||
)
|
||||
loss_gen_all = cfm_loss
|
||||
optim_g.zero_grad()
|
||||
scaler.scale(loss_gen_all).backward()
|
||||
scaler.unscale_(optim_g)
|
||||
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
||||
scaler.step(optim_g)
|
||||
scaler.update()
|
||||
|
||||
if device.index == 0 and progress is not None and step_task is not None:
|
||||
progress.advance(step_task, 1)
|
||||
|
||||
if device.index == 0:
|
||||
if global_step % hps.train.log_interval == 0:
|
||||
lr = optim_g.param_groups[0]["lr"]
|
||||
# losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
|
||||
losses = [cfm_loss]
|
||||
logger.info(
|
||||
"Train Epoch: {} [{:.0f}%]".format(
|
||||
epoch,
|
||||
100.0 * batch_idx / len(train_loader),
|
||||
)
|
||||
)
|
||||
logger.info([x.item() for x in losses] + [global_step, lr])
|
||||
|
||||
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
|
||||
utils.summarize(
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
# images=image_dict,
|
||||
scalars=scalar_dict,
|
||||
)
|
||||
|
||||
global_step += 1
|
||||
|
||||
if hps.train.if_save_latest == 0:
|
||||
utils.save_checkpoint(
|
||||
net_g,
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
|
||||
f"G_{global_step}.pth",
|
||||
),
|
||||
logger,
|
||||
)
|
||||
|
||||
else:
|
||||
utils.save_checkpoint(
|
||||
net_g,
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
|
||||
"G_233333333333.pth",
|
||||
),
|
||||
logger,
|
||||
)
|
||||
|
||||
if epoch % hps.train.save_every_epoch == 0 and device.index == 0:
|
||||
if hps.train.if_save_every_weights is True:
|
||||
if hasattr(net_g, "module"):
|
||||
ckpt = net_g.module.state_dict()
|
||||
else:
|
||||
|
||||
@ -12,14 +12,13 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from rich import print
|
||||
from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeRemainingColumn
|
||||
from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeElapsedColumn, TimeRemainingColumn
|
||||
from torch.amp.autocast_mode import autocast
|
||||
from torch.amp.grad_scaler import GradScaler
|
||||
from torch.multiprocessing.spawn import spawn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
import GPT_SoVITS.utils as utils
|
||||
from GPT_SoVITS.Accelerate import console, logger
|
||||
@ -234,9 +233,10 @@ def run(rank, n_gpus, hps):
|
||||
TextColumn("[cyan]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("{task.completed}/{task.total}"),
|
||||
SpeedColumnIteration(show_speed=True),
|
||||
TimeRemainingColumn(elapsed_when_finished=True),
|
||||
TimeElapsedColumn(),
|
||||
console=console,
|
||||
redirect_stderr=True,
|
||||
redirect_stdout=True,
|
||||
)
|
||||
if rank == 0
|
||||
else nullcontext() as progress
|
||||
@ -247,7 +247,6 @@ def run(rank, n_gpus, hps):
|
||||
total=int(hps.train.epochs),
|
||||
completed=int(epoch_str) - 1,
|
||||
)
|
||||
step_task: TaskID | None = progress.add_task("steps", total=len(train_loader))
|
||||
else:
|
||||
epoch_task = step_task = None
|
||||
|
||||
@ -255,8 +254,7 @@ def run(rank, n_gpus, hps):
|
||||
if rank == 0:
|
||||
assert epoch_task is not None
|
||||
assert progress is not None
|
||||
assert step_task is not None
|
||||
progress.reset(step_task, total=len(train_loader))
|
||||
progress.advance(epoch_task, 1)
|
||||
train_and_evaluate(
|
||||
device,
|
||||
epoch,
|
||||
@ -268,9 +266,7 @@ def run(rank, n_gpus, hps):
|
||||
(train_loader, None),
|
||||
logger,
|
||||
(writer, writer_eval),
|
||||
(progress, step_task),
|
||||
)
|
||||
progress.advance(epoch_task, 1)
|
||||
else:
|
||||
train_and_evaluate(
|
||||
device,
|
||||
@ -283,10 +279,11 @@ def run(rank, n_gpus, hps):
|
||||
(train_loader, None),
|
||||
None,
|
||||
(None, None),
|
||||
(None, None),
|
||||
)
|
||||
scheduler_g.step()
|
||||
if rank == 0:
|
||||
assert progress
|
||||
progress.stop()
|
||||
logger.info("Training Done")
|
||||
sys.exit(0)
|
||||
|
||||
@ -302,84 +299,97 @@ def train_and_evaluate(
|
||||
loaders,
|
||||
logger,
|
||||
writers,
|
||||
progress_group: tuple[Progress, TaskID] | tuple[None, None],
|
||||
):
|
||||
net_g, net_d = nets
|
||||
optim_g, optim_d = optims
|
||||
# scheduler_g, scheduler_d = schedulers
|
||||
train_loader, eval_loader = loaders
|
||||
writer, writer_eval = writers
|
||||
progress, step_task = progress_group
|
||||
|
||||
train_loader.batch_sampler.set_epoch(epoch)
|
||||
global global_step
|
||||
|
||||
net_g.train()
|
||||
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
|
||||
tqdm(train_loader)
|
||||
|
||||
with (
|
||||
Progress(
|
||||
TextColumn("[cyan]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("{task.completed}/{task.total}"),
|
||||
SpeedColumnIteration(show_speed=True),
|
||||
TimeRemainingColumn(elapsed_when_finished=True),
|
||||
console=console,
|
||||
redirect_stderr=True,
|
||||
redirect_stdout=True,
|
||||
transient=not (int(epoch) == int(hps.train.epochs)),
|
||||
)
|
||||
if device.index == 0
|
||||
else nullcontext() as progress
|
||||
):
|
||||
spec, spec_lengths = spec.to(device, non_blocking=True), spec_lengths.to(device, non_blocking=True)
|
||||
mel, mel_lengths = mel.to(device, non_blocking=True), mel_lengths.to(device, non_blocking=True)
|
||||
ssl = ssl.to(device, non_blocking=True)
|
||||
ssl.requires_grad = False
|
||||
text, text_lengths = text.to(device, non_blocking=True), text_lengths.to(device, non_blocking=True)
|
||||
|
||||
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
|
||||
cfm_loss = net_g(
|
||||
ssl,
|
||||
spec,
|
||||
mel,
|
||||
ssl_lengths,
|
||||
spec_lengths,
|
||||
text,
|
||||
text_lengths,
|
||||
mel_lengths,
|
||||
use_grad_ckpt=hps.train.grad_ckpt,
|
||||
)
|
||||
loss_gen_all = cfm_loss
|
||||
optim_g.zero_grad()
|
||||
scaler.scale(loss_gen_all).backward()
|
||||
scaler.unscale_(optim_g)
|
||||
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
||||
scaler.step(optim_g)
|
||||
scaler.update()
|
||||
|
||||
if device.index == 0:
|
||||
assert progress
|
||||
assert step_task
|
||||
progress.advance(step_task, 1)
|
||||
if global_step % hps.train.log_interval == 0:
|
||||
lr = optim_g.param_groups[0]["lr"]
|
||||
losses = [cfm_loss]
|
||||
logger.info("Train Epoch: {} [{:.0f}%]".format(epoch, 100.0 * batch_idx / len(train_loader)))
|
||||
logger.info([x.item() for x in losses] + [global_step, lr])
|
||||
|
||||
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
|
||||
utils.summarize(
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
scalars=scalar_dict,
|
||||
)
|
||||
|
||||
global_step += 1
|
||||
if epoch % hps.train.save_every_epoch == 0 and device.index == 0:
|
||||
if hps.train.if_save_latest == 0:
|
||||
utils.save_checkpoint(
|
||||
net_g,
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(save_root, f"G_{global_step}.pth"),
|
||||
)
|
||||
if isinstance(progress, Progress):
|
||||
step_task: TaskID | None = progress.add_task("Steps", total=len(train_loader))
|
||||
else:
|
||||
utils.save_checkpoint(
|
||||
net_g,
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(save_root, "G_233333333333.pth"),
|
||||
)
|
||||
if device.index == 0 and hps.train.if_save_every_weights is True:
|
||||
step_task = None
|
||||
|
||||
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
|
||||
train_loader
|
||||
):
|
||||
spec, spec_lengths = spec.to(device, non_blocking=True), spec_lengths.to(device, non_blocking=True)
|
||||
mel, mel_lengths = mel.to(device, non_blocking=True), mel_lengths.to(device, non_blocking=True)
|
||||
ssl = ssl.to(device, non_blocking=True)
|
||||
ssl.requires_grad = False
|
||||
text, text_lengths = text.to(device, non_blocking=True), text_lengths.to(device, non_blocking=True)
|
||||
|
||||
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
|
||||
cfm_loss = net_g(
|
||||
ssl,
|
||||
spec,
|
||||
mel,
|
||||
ssl_lengths,
|
||||
spec_lengths,
|
||||
text,
|
||||
text_lengths,
|
||||
mel_lengths,
|
||||
use_grad_ckpt=hps.train.grad_ckpt,
|
||||
)
|
||||
loss_gen_all = cfm_loss
|
||||
optim_g.zero_grad()
|
||||
scaler.scale(loss_gen_all).backward()
|
||||
scaler.unscale_(optim_g)
|
||||
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
||||
scaler.step(optim_g)
|
||||
scaler.update()
|
||||
|
||||
if device.index == 0 and progress is not None and step_task is not None:
|
||||
progress.advance(step_task, 1)
|
||||
|
||||
if device.index == 0:
|
||||
if global_step % hps.train.log_interval == 0:
|
||||
lr = optim_g.param_groups[0]["lr"]
|
||||
losses = [cfm_loss]
|
||||
logger.info("Train Epoch: {} [{:.0f}%]".format(epoch, 100.0 * batch_idx / len(train_loader)))
|
||||
logger.info([x.item() for x in losses] + [global_step, lr])
|
||||
|
||||
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
|
||||
utils.summarize(
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
scalars=scalar_dict,
|
||||
)
|
||||
|
||||
global_step += 1
|
||||
|
||||
if hps.train.if_save_latest == 0:
|
||||
utils.save_checkpoint(
|
||||
net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(save_root, f"G_{global_step}.pth"), logger
|
||||
)
|
||||
else:
|
||||
utils.save_checkpoint(
|
||||
net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(save_root, "G_233333333333.pth"), logger
|
||||
)
|
||||
|
||||
if epoch % hps.train.save_every_epoch == 0 and device.index == 0:
|
||||
if hps.train.if_save_every_weights is True:
|
||||
if hasattr(net_g, "module"):
|
||||
ckpt = net_g.module.state_dict()
|
||||
else:
|
||||
|
||||
@ -70,7 +70,7 @@ def save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
||||
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
|
||||
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path, logger):
|
||||
logger.info("Saving model and optimizer state at iteration {} to {}".format(iteration, checkpoint_path))
|
||||
if hasattr(model, "module"):
|
||||
state_dict = model.module.state_dict()
|
||||
|
||||
12
webui.py
12
webui.py
@ -605,7 +605,7 @@ def open1Bb(
|
||||
data["output_dir"] = f"{s1_dir}/logs_s1_{version}"
|
||||
|
||||
env = os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(gpu_numbers1Ba).strip("[]").replace(" ", "")
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(gpu_numbers).strip("[]").replace(" ", "")
|
||||
|
||||
tmp_config_path = f"{tmp}/tmp_s1.yaml"
|
||||
with open(tmp_config_path, "w") as f:
|
||||
@ -890,7 +890,6 @@ def open1b(
|
||||
cmd = [
|
||||
python_exec, "-s", "GPT_SoVITS/prepare_datasets/2-get-hubert-sv-wav32k.py",
|
||||
"--inp-list", inp_text,
|
||||
"--wav-dir", inp_wav_dir,
|
||||
"--opt", opt_dir,
|
||||
"--cnhubert", ssl_pretrained_dir,
|
||||
"--device", infer_device.type,
|
||||
@ -899,6 +898,9 @@ def open1b(
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
if inp_wav_dir:
|
||||
cmd.extend(["--wav-dir", inp_wav_dir])
|
||||
|
||||
if "Pro" in version:
|
||||
cmd.extend(["--sv", sv_path])
|
||||
|
||||
@ -1124,7 +1126,6 @@ def open1abc(
|
||||
cmd_2 = [
|
||||
python_exec, "-s", "GPT_SoVITS/prepare_datasets/2-get-hubert-sv-wav32k.py",
|
||||
"--inp-list", inp_text,
|
||||
"--wav-dir", inp_wav_dir,
|
||||
"--opt", opt_dir,
|
||||
"--cnhubert", ssl_pretrained_dir,
|
||||
"--device", infer_device.type,
|
||||
@ -1133,6 +1134,9 @@ def open1abc(
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
if inp_wav_dir:
|
||||
cmd_2.extend(["--wav-dir", inp_wav_dir])
|
||||
|
||||
if "Pro" in version:
|
||||
cmd_2.extend(["--sv", sv_path])
|
||||
|
||||
@ -1231,7 +1235,7 @@ def close1abc():
|
||||
kill_process(p1abc.pid, process_name_1abc)
|
||||
except Exception as _:
|
||||
traceback.print_exc()
|
||||
ps1abc = []
|
||||
ps1abc = [None] * 3
|
||||
return (
|
||||
process_info(process_name_1abc, "closed"),
|
||||
gr.update(visible=True),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user