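"""Extract semantic token codes for GPT-SoVITS dataset preparation.

For every wav listed in the input list file, the cached cnhubert feature
``<opt>/4-cnhubert/<name>.pt`` is passed through the quantizer of a pretrained
s2G (SoVITS) checkpoint, and the resulting code sequence is written to
``<opt>/6-name2semantic.tsv``. Work is spread across one or more worker
processes per device via ``torch.multiprocessing`` (spawn start method, since
the workers may use CUDA).

Illustrative invocation (paths are placeholders):

    python -s -m GPT_SoVITS.prepare_datasets.<script> --inp-list data/train.list \
        --opt logs/exp --pretrained-s2g pretrained/s2G.pth --device cuda --device-id 0,1 --nproc 2
"""
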
import enum
import gc
import os
import os.path as osp
import queue
import sys
import time
from pathlib import Path
from typing import List, Tuple

import torch
import torch.multiprocessing as tmp
import typer
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
from torch.multiprocessing.spawn import spawn

from GPT_SoVITS.Accelerate.logger import SpeedColumnIteration, console, logger
from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3
from GPT_SoVITS.process_ckpt import inspect_version
from tools.my_utils import DictToAttrRecursive, clean_path

torch.set_grad_enabled(False)

tmp.set_start_method("spawn", force=True)


class Device(str, enum.Enum):
    cpu = "cpu"
    cuda = "cuda"
    mps = "mps"


app = typer.Typer(
    context_settings={"help_option_names": ["-h", "--help"]},
    add_completion=False,
)


def parse_inp_text_line(line: str) -> str:
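    """Return the wav path field from a ``wav|spk|lang|text`` list line."""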
    wav_name, _, __, ___ = line.split("|", 3)
    return wav_name


def build_device_strings(device_type: str, device_ids: List[int], procs_per_device: int) -> List[str]:
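    """Expand device ids into one device string per worker process.

    Illustrative example: ``build_device_strings("cuda", [0, 1], 2)`` returns
    ``["cuda:0", "cuda:0", "cuda:1", "cuda:1"]``.
    """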
    devices: List[str] = []
    for device_id in device_ids:
        dstr = f"{device_type}:{device_id}" if device_type in {"cuda", "mps"} else "cpu"
        devices.extend([dstr] * procs_per_device)
    return devices


def worker_entry(
    rank: int,
    device_strs: List[str],
    tasks_q: "tmp.Queue[Tuple[int, str] | None]",
    results_q: "tmp.Queue[Tuple[int, str]]",
    pretrained_s2g: str,
    opt_dir: str,
    fp16: bool,
):
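    """Per-process worker launched via ``torch.multiprocessing.spawn``.

    Loads the s2G checkpoint, keeps only its ``quantizer`` and ``ssl_proj``
    submodules, then consumes ``(index, wav_name)`` tasks from ``tasks_q``
    until it receives a ``None`` sentinel, pushing ``(index, line)`` results
    onto ``results_q``; an empty line marks a missing cnhubert feature.
    """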
    device_str = device_strs[rank]
    device = torch.device(device_str)

    if device.type == "cuda":
        assert torch.cuda.is_available()
        torch.cuda.set_device(device.index)
    elif device.type == "mps":
        assert torch.mps.is_available()
    elif device.type == "xpu":
        assert torch.xpu.is_available()

    hubert_dir = osp.join(opt_dir, "4-cnhubert")
    os.makedirs(opt_dir, exist_ok=True)

    if not osp.exists(hubert_dir):
        raise FileNotFoundError(hubert_dir)

    version, _, _, hps_dict, dict_s2 = inspect_version(pretrained_s2g)
    hps = DictToAttrRecursive(hps_dict)

    if version in {"v3", "v4"}:
        vq_model: SynthesizerTrn | SynthesizerTrnV3 = SynthesizerTrnV3(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            version=version,
            **hps.model,
        )
    else:
        vq_model = SynthesizerTrn(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            version=version,
            **hps.model,
        )

    load_result = vq_model.load_state_dict(dict_s2["weight"])
    if rank == 0:
        console.print("")
        console.print(load_result)

    # Only the quantizer and ssl_proj submodules are needed for latent
    # extraction; drop the rest to save memory.
    for name in list(vq_model._modules.keys()):
        if name not in ["quantizer", "ssl_proj"]:
            del vq_model._modules[name]
    del dict_s2

    if fp16:
        vq_model = vq_model.to(device).half()
    else:
        vq_model.to(device)

    # Match on the device *type* ("cuda"/"mps"/"xpu"); device.index is the
    # numeric ordinal and would never equal these strings.
    match device.type:
        case "cuda":
            torch.cuda.empty_cache()
        case "mps":
            torch.mps.empty_cache()
        case "xpu":
            torch.xpu.empty_cache()
    gc.collect()

    def extract_semantic_from_hubert_pt(wav_basename: str) -> str | None:
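        """Return the semantic codes for one wav as a space-separated string.

        Reads the cached cnhubert feature ``<hubert_dir>/<wav_basename>.pt``;
        returns ``None`` if that file does not exist.
        """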
        hubert_path = osp.join(hubert_dir, f"{wav_basename}.pt")
        if not osp.exists(hubert_path):
            return None

        ssl_content: torch.Tensor = torch.load(hubert_path, map_location="cpu")
        if fp16:
            ssl_content = ssl_content.half().to(device)
        else:
            ssl_content = ssl_content.to(device)

        codes = vq_model.extract_latent(ssl_content)
        vec = codes[0, 0, :].tolist()

        return " ".join(str(i) for i in vec)

    i = 0
    while True:
        item = tasks_q.get()
        if item is None:
            break

        idx, wav_name = item

        i += 1
        if i % 10 == 0:
            match device.type:  # as above: match the device type, not its index
                case "cuda":
                    torch.cuda.empty_cache()
                case "mps":
                    torch.mps.empty_cache()
                case "xpu":
                    torch.xpu.empty_cache()

        try:
            name = clean_path(osp.basename(wav_name))
            semantic = extract_semantic_from_hubert_pt(name)
            if semantic is None:
                results_q.put((idx, ""))
            else:
                results_q.put((idx, f"{name}\t{semantic}"))
        except Exception as e:
            # Drop references to the model and checkpoint metadata before
            # re-raising, then let the spawn wrapper propagate the failure.
            del device_str, vq_model, hubert_dir, _, version, hps_dict, hps, item, idx, i, load_result
            logger.exception(f"[W{rank}] Failed on: {wav_name}")
            raise e

    sys.exit(0)


@app.command()
def main(
    inp_list: Path = typer.Option(
        ...,
        "--inp-list",
        file_okay=True,
        dir_okay=False,
        exists=True,
        readable=True,
        show_default=False,
        help="list file: wav|spk|lang|text",
    ),
    opt: Path = typer.Option(
        ..., "--opt", file_okay=False, dir_okay=True, writable=True, show_default=False, help="Output Directory"
    ),
    pretrained_s2g: Path = typer.Option(
        ...,
        "--pretrained-s2g",
        file_okay=True,
        dir_okay=False,
        exists=True,
        readable=True,
        show_default=False,
        help="Path to pretrained s2G checkpoint",
    ),
    device: Device = typer.Option(Device.cpu, "--device", help="Compute device"),
    device_id: str = typer.Option("0", "--device-id", help="CUDA_VISIBLE_DEVICES style, e.g. '0,1'"),
    nproc: int = typer.Option(1, "--nproc", min=1, help="Processes per device"),
    fp16: bool = typer.Option(True, is_flag=True, flag_value=True, help="Use FP16 on CUDA"),
):
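    """Distribute semantic-code extraction over devices and merge the results.

    Tasks are pushed onto a shared queue, one worker process is spawned per
    entry of the expanded device list, and the collected lines are written in
    input order to ``6-name2semantic.tsv`` under ``--opt``.
    """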
    device_ids = [int(x) for x in device_id.split(",") if x.strip() != ""]
    if device in {"cpu", "mps"} and device_ids != [0]:
        raise ValueError(f"Invalid Device ID {device_ids}")
    if nproc < 1:
        raise ValueError(f"Invalid Num Process {nproc}")

    os.makedirs(opt, exist_ok=True)
    merged_path = osp.join(opt, "6-name2semantic.tsv")

    with open(inp_list, "r", encoding="utf8") as f:
        raw_lines = [ln for ln in f.read().splitlines() if ln.strip()]

    tasks_all: List[Tuple[int, str]] = []
    for idx, line in enumerate(raw_lines):
        try:
            wav_name = parse_inp_text_line(line)
            tasks_all.append((idx, wav_name))
        except Exception:
            logger.exception(f"Skip line {idx}: {line}")

    n_tasks = len(tasks_all)
    if n_tasks == 0:
        logger.warning("Empty list")
        with open(merged_path, "w", encoding="utf8") as fout:
            pass
        return

    # Pass the enum's plain string value so the f-string in build_device_strings
    # yields "cuda:0" rather than "Device.cuda:0" on newer Python versions.
    device_strs = build_device_strings(device.value, device_ids, nproc)
    world_size = len(device_strs)

    tasks_q: "tmp.Queue[Tuple[int, str] | None]" = tmp.Queue()
    results_q: "tmp.Queue[Tuple[int, str]]" = tmp.Queue()

    for task in tasks_all:
        tasks_q.put(task)
    # One None sentinel per worker so every process eventually sees end-of-queue.
    for _ in range(world_size):
        tasks_q.put(None)

    ordered: List[str] = [""] * n_tasks
    completed = 0

    with Progress(
        TextColumn("[cyan]{task.description}"),
        BarColumn(),
        TextColumn("{task.completed}/{task.total}"),
        SpeedColumnIteration(show_speed=True),
        TimeRemainingColumn(elapsed_when_finished=True),
        console=console,
    ) as progress:
        progress_task = progress.add_task("Extract Semantic Codes", total=n_tasks)
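
        # One worker per device string; join=False so this process can keep
        # draining the results queue below while the workers run.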
        ctx = spawn(
            worker_entry,
            args=(device_strs, tasks_q, results_q, pretrained_s2g, opt, fp16),
            nprocs=world_size,
            join=False,
            daemon=False,
        )
        assert ctx

        while completed < n_tasks:
            # Drain results as they arrive; an empty line means the worker found
            # no cached cnhubert feature for that wav, but it still counts
            # toward completion.
            try:
                idx, line = results_q.get(timeout=0.05)
                if line:
                    ordered[idx] = line
                completed += 1
                progress.update(progress_task, advance=1)
            except queue.Empty:
                pass

            # Abort if a worker crashed, so this loop cannot wait forever. A
            # worker that already exited cleanly (code 0) after draining the
            # task queue is not treated as a crash.
            for p in ctx.processes:
                assert p
                if p.exitcode is not None and p.exitcode != 0:
                    progress.live.stop()
                    try:
                        ctx.join()
                    except Exception as e:
                        console.print(e)
                    finally:
                        logger.critical(f"Worker PID {p.pid} crashed with exit code {p.exitcode}.")
                        sys.exit(1)

        ctx.join()

    with open(merged_path, "w", encoding="utf8") as fout:
        for line in ordered:
            if line:
                fout.write(line + "\n")

    logger.info(f"Done: {merged_path}")


def is_powershell_env(env: dict) -> bool:
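    """Heuristically detect a PowerShell host from common environment variables."""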
    return any(k in env for k in ("PSHOME", "POWERSHELL_DISTRIBUTION_CHANNEL", "PSModulePath"))


def get_prog_name() -> str:
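    """Build the ``python -s -m GPT_SoVITS.prepare_datasets.<script>`` string used as the Typer prog name."""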
    # osp.splitext removes the ".py" extension; str.strip(".py") would strip
    # characters from both ends rather than the suffix.
    script_rel = ".".join(["GPT_SoVITS", "prepare_datasets", osp.splitext(osp.basename(__file__))[0]])
    return f"python -s -m {script_rel}"


if __name__ == "__main__":
    t = time.perf_counter()
    app(prog_name=get_prog_name())
    logger.info(f"Exec Time: {time.perf_counter() - t:.2f} secs")