mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 23:48:48 +08:00
144 lines
4.6 KiB
Python
144 lines
4.6 KiB
Python
import argparse
|
|
import random
|
|
from functools import partial
|
|
from pathlib import Path
|
|
|
|
import soundfile
|
|
import torch
|
|
from deepspeed import DeepSpeedConfig
|
|
from torch import Tensor
|
|
from tqdm import tqdm
|
|
|
|
from ..data import create_dataloaders, mix_fg_bg
|
|
from ..utils import Engine, TrainLoop, save_mels, setup_logging, tree_map
|
|
from ..utils.distributed import is_local_leader
|
|
from .enhancer import Enhancer
|
|
from .hparams import HParams
|
|
from .univnet.discriminator import Discriminator
|
|
|
|
|
|
def load_G(run_dir: Path, hp: HParams | None = None, training=True):
|
|
if hp is None:
|
|
hp = HParams.load(run_dir)
|
|
assert isinstance(hp, HParams)
|
|
model = Enhancer(hp)
|
|
engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "G")
|
|
if training:
|
|
engine.load_checkpoint()
|
|
else:
|
|
engine.load_checkpoint(load_optimizer_states=False, load_lr_scheduler_states=False)
|
|
return engine
|
|
|
|
|
|
def load_D(run_dir: Path, hp: HParams):
|
|
if hp is None:
|
|
hp = HParams.load(run_dir)
|
|
assert isinstance(hp, HParams)
|
|
model = Discriminator(hp)
|
|
engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "D")
|
|
engine.load_checkpoint()
|
|
return engine
|
|
|
|
|
|
def save_wav(path: Path, wav: Tensor, rate: int):
|
|
wav = wav.detach().cpu().numpy()
|
|
soundfile.write(path, wav, samplerate=rate)
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("run_dir", type=Path)
|
|
parser.add_argument("--yaml", type=Path, default=None)
|
|
parser.add_argument("--device", type=str, default="cuda")
|
|
args = parser.parse_args()
|
|
|
|
setup_logging(args.run_dir)
|
|
hp = HParams.load(args.run_dir, yaml=args.yaml)
|
|
|
|
if is_local_leader():
|
|
hp.save_if_not_exists(args.run_dir)
|
|
hp.print()
|
|
|
|
train_dl, val_dl = create_dataloaders(hp, mode="enhancer")
|
|
|
|
def feed_G(engine: Engine, batch: dict[str, Tensor]):
|
|
if hp.lcfm_training_mode == "ae":
|
|
pred = engine(batch["fg_wavs"], batch["fg_wavs"])
|
|
elif hp.lcfm_training_mode == "cfm":
|
|
alpha_fn = lambda: random.uniform(*hp.mix_alpha_range)
|
|
mx_dwavs = mix_fg_bg(batch["fg_dwavs"], batch["bg_dwavs"], alpha=alpha_fn)
|
|
pred = engine(mx_dwavs, batch["fg_wavs"], batch["fg_dwavs"])
|
|
else:
|
|
raise ValueError(f"Unknown training mode: {hp.lcfm_training_mode}")
|
|
losses = engine.gather_attribute("losses")
|
|
return pred, losses
|
|
|
|
def feed_D(engine: Engine, batch: dict | None, fake: Tensor):
|
|
if batch is None:
|
|
losses = engine(fake=fake)
|
|
else:
|
|
losses = engine(fake=fake, real=batch["fg_wavs"])
|
|
return losses
|
|
|
|
@torch.no_grad()
|
|
def eval_fn(engine: Engine, eval_dir, n_saved=10):
|
|
assert isinstance(hp, HParams)
|
|
|
|
model = engine.module
|
|
model.eval()
|
|
|
|
step = engine.global_step
|
|
|
|
for i, batch in enumerate(tqdm(val_dl), 1):
|
|
batch = tree_map(lambda x: x.to(args.device) if isinstance(x, Tensor) else x, batch)
|
|
|
|
fg_wavs = batch["fg_wavs"] # 1 t
|
|
|
|
if hp.lcfm_training_mode == "ae":
|
|
in_dwavs = fg_wavs
|
|
elif hp.lcfm_training_mode == "cfm":
|
|
in_dwavs = mix_fg_bg(fg_wavs, batch["bg_dwavs"])
|
|
else:
|
|
raise ValueError(f"Unknown training mode: {hp.lcfm_training_mode}")
|
|
|
|
pred_fg_wavs = model(in_dwavs) # 1 t
|
|
|
|
in_mels = model.to_mel(in_dwavs) # 1 c t
|
|
fg_mels = model.to_mel(fg_wavs) # 1 c t
|
|
pred_fg_mels = model.to_mel(pred_fg_wavs) # 1 c t
|
|
|
|
rate = model.hp.wav_rate
|
|
get_path = lambda suffix: eval_dir / f"step_{step:08}_{i:03}{suffix}"
|
|
|
|
save_wav(get_path("_input.wav"), in_dwavs[0], rate=rate)
|
|
save_wav(get_path("_predict.wav"), pred_fg_wavs[0], rate=rate)
|
|
save_wav(get_path("_target.wav"), fg_wavs[0], rate=rate)
|
|
|
|
save_mels(
|
|
get_path(".png"),
|
|
cond_mel=in_mels[0].cpu().numpy(),
|
|
pred_mel=pred_fg_mels[0].cpu().numpy(),
|
|
targ_mel=fg_mels[0].cpu().numpy(),
|
|
)
|
|
|
|
if i >= n_saved:
|
|
break
|
|
|
|
train_loop = TrainLoop(
|
|
run_dir=args.run_dir,
|
|
train_dl=train_dl,
|
|
load_G=partial(load_G, hp=hp),
|
|
load_D=partial(load_D, hp=hp),
|
|
device=args.device,
|
|
feed_G=feed_G,
|
|
feed_D=feed_D,
|
|
eval_fn=eval_fn,
|
|
gan_training_start_step=hp.gan_training_start_step,
|
|
)
|
|
|
|
train_loop.run(max_steps=hp.max_steps)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|