mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-30 01:25:58 +08:00
321 lines
10 KiB
Python
321 lines
10 KiB
Python
import enum
|
|
import os
|
|
import os.path as osp
|
|
import platform
|
|
import queue
|
|
import sys
|
|
import time
|
|
import warnings
|
|
from pathlib import Path
|
|
from typing import List
|
|
|
|
import torch
|
|
import torch.multiprocessing as tmp
|
|
import typer
|
|
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
|
|
from torch.multiprocessing.spawn import spawn
|
|
from transformers import BertForMaskedLM, BertTokenizerFast
|
|
|
|
from GPT_SoVITS.Accelerate.logger import console, logger, SpeedColumnIteration
|
|
from GPT_SoVITS.text.cleaner import clean_text
|
|
from tools.my_utils import clean_path
|
|
|
|
# Feature extraction only: turn off autograd for this process.
torch.set_grad_enabled(False)

# Force the "spawn" start method (overriding any earlier choice) for worker processes;
# PyTorch requires spawn/forkserver when child processes use CUDA.
tmp.set_start_method("spawn", force=True)

# Silence UserWarnings raised from the jieba_fast._compat module.
warnings.filterwarnings(action="ignore", category=UserWarning, module="jieba_fast._compat")
|
|
|
|
|
|
class Device(str, enum.Enum):
    """Compute backends selectable through the ``--device`` CLI option.

    The ``str`` mix-in makes each member compare equal to its plain string
    value (e.g. ``Device.cuda == "cuda"``).
    """

    cpu = "cpu"
    cuda = "cuda"
    mps = "mps"
|
|
|
|
|
|
# Typer CLI application: "-h" is accepted alongside "--help", and the
# shell-completion installation options are disabled.
app = typer.Typer(
    add_completion=False,
    context_settings=dict(help_option_names=["-h", "--help"]),
)
|
|
|
|
|
|
# Lower-cased language tag -> language code understood by ``clean_text``.
_LANG_ALIASES = {
    "zh": "zh",
    "jp": "ja",
    "ja": "ja",
    "en": "en",
    "ko": "ko",
    "yue": "yue",
}


def lang_map(lang: str) -> str:
    """Normalize a language tag from the input list to a ``clean_text`` language code.

    Matching ignores case and surrounding whitespace, so ``"ZH"``, ``"Zh"`` and
    ``" zh "`` all map to ``"zh"``, and ``"JP"``/``"JA"`` map to ``"ja"``.

    Args:
        lang: Raw language field from a ``wav|spk|lang|text`` line.

    Returns:
        The normalized code, or ``""`` for an unsupported tag (the caller treats
        the empty string as "skip this line").
    """
    return _LANG_ALIASES.get(lang.strip().lower(), "")
|
|
|
|
|
|
def parse_inp_text_line(line: str) -> tuple[str, str, str]:
    """Split one ``wav|speaker|language|text`` record into ``(wav, language, text)``.

    Only the first three ``|`` characters separate fields, so the text itself may
    contain ``|``. The speaker field is discarded. A line with fewer than four
    fields raises ``ValueError`` from the tuple unpacking.
    """
    wav_path, _speaker, lang_tag, transcript = line.split("|", maxsplit=3)
    return wav_path, lang_tag, transcript
|
|
|
|
|
|
def build_device_strings(device_type: str, device_ids: list[int], procs_per_device: int) -> list[str]:
    """Expand device ids into one ``"<type>:<id>"`` string per worker process.

    Each id is repeated ``procs_per_device`` times, grouped by id in the order
    given, e.g. ``("cuda", [0, 1], 2) -> ["cuda:0", "cuda:0", "cuda:1", "cuda:1"]``.

    Args:
        device_type: Backend name, either a plain string or a ``(str, Enum)``
            member such as ``Device.cuda``.
        device_ids: Device ordinals to use.
        procs_per_device: Worker processes per device; values < 1 yield nothing.

    Returns:
        The list of device strings, one per worker rank.
    """
    # Since Python 3.11, formatting a (str, Enum) member gives "Device.cuda"
    # instead of its value, which torch.device() rejects; unwrap it explicitly.
    kind = device_type.value if isinstance(device_type, enum.Enum) else device_type
    return [f"{kind}:{device_id}" for device_id in device_ids for _ in range(procs_per_device)]
|
|
|
|
|
|
def worker_entry(
|
|
rank: int,
|
|
device_strs: List[str],
|
|
tasks_q: "tmp.Queue[tuple[int, str, str, str] | None]",
|
|
results_q: "tmp.Queue[tuple[int, tuple[str, str, list[int] | None, str]]]",
|
|
bert_pretrained_dir: str,
|
|
opt_dir: str,
|
|
fp16: bool,
|
|
version: str | None,
|
|
):
|
|
device_str = device_strs[rank]
|
|
device = torch.device(device_str)
|
|
|
|
if device.type == "cuda":
|
|
assert torch.cuda.is_available()
|
|
torch.cuda.set_device(device.index)
|
|
elif device.type == "mps":
|
|
assert torch.mps.is_available()
|
|
elif device.type == "xpu":
|
|
assert torch.xpu.is_available()
|
|
|
|
bert_dir = osp.join(opt_dir, "3-bert")
|
|
os.makedirs(bert_dir, exist_ok=True)
|
|
|
|
if not osp.exists(bert_pretrained_dir):
|
|
raise FileNotFoundError(bert_pretrained_dir)
|
|
|
|
tokenizer = BertTokenizerFast.from_pretrained(bert_pretrained_dir)
|
|
bert_model = BertForMaskedLM.from_pretrained(bert_pretrained_dir, device_map=device)
|
|
|
|
if fp16:
|
|
bert_model = bert_model.half()
|
|
|
|
def get_bert_feature(text: str, word2ph: list[int]) -> torch.Tensor:
|
|
inputs = tokenizer(text, return_tensors="pt")
|
|
for k in inputs:
|
|
inputs[k] = inputs[k].to(device)
|
|
out: torch.Tensor = bert_model(**inputs, output_hidden_states=True).hidden_states # type: ignore
|
|
layer = out[-3][0].cpu()[1:-1] # [seq-2, hid]
|
|
assert len(word2ph) == len(text)
|
|
phone_level_feature = []
|
|
for i in range(len(word2ph)):
|
|
phone_level_feature.append(layer[i].repeat(word2ph[i], 1))
|
|
feats = torch.cat(phone_level_feature, dim=0) # [phones, hid]
|
|
return feats.T # [hid, phones]
|
|
|
|
i = 0
|
|
while True:
|
|
item = tasks_q.get()
|
|
if item is None:
|
|
break
|
|
|
|
idx, wav_name, language, text = item
|
|
|
|
i += 1
|
|
if i % 10 == 0:
|
|
match device.index:
|
|
case "cuda":
|
|
torch.cuda.empty_cache()
|
|
case "mps":
|
|
torch.mps.empty_cache()
|
|
case "xpu":
|
|
torch.xpu.empty_cache()
|
|
|
|
try:
|
|
name = clean_path(osp.basename(wav_name))
|
|
mapped_lang = lang_map(language)
|
|
if not mapped_lang:
|
|
logger.warning(f"[W{rank}] Unsupported language: {language} of {wav_name}")
|
|
results_q.put((idx, ("", "", [], "")))
|
|
continue
|
|
|
|
phones, word2ph, norm_text = clean_text(
|
|
text.replace("%", "-").replace("¥", ","),
|
|
mapped_lang,
|
|
version,
|
|
)
|
|
|
|
if mapped_lang == "zh":
|
|
path_bert = osp.join(bert_dir, f"{name}.pt")
|
|
if not osp.exists(path_bert):
|
|
assert word2ph
|
|
bert_feature = get_bert_feature(norm_text, word2ph)
|
|
assert bert_feature.shape[-1] == len(phones)
|
|
torch.save(bert_feature, path_bert)
|
|
|
|
phones_str = " ".join(phones)
|
|
results_q.put((idx, (name, phones_str, word2ph, norm_text)))
|
|
except Exception as e:
|
|
del (
|
|
device_str,
|
|
tokenizer,
|
|
bert_model,
|
|
bert_dir,
|
|
bert_pretrained_dir,
|
|
tasks_q,
|
|
results_q,
|
|
opt_dir,
|
|
item,
|
|
idx,
|
|
i,
|
|
)
|
|
logger.exception(f"[W{rank}] Failed: {wav_name} | {text}")
|
|
raise e
|
|
|
|
sys.exit(0)
|
|
|
|
|
|
def _load_tasks(inp_list: Path) -> list[tuple[int, str, str, str]]:
    """Read *inp_list* and parse every non-blank line into a task tuple.

    Returns ``(task_index, wav_name, language, text)`` tuples. ``task_index`` is
    dense (``0 .. len-1``) even when malformed lines are skipped, so it can index
    the result buffer directly. Malformed lines are logged with their 1-based
    line number and skipped.
    """
    tasks: list[tuple[int, str, str, str]] = []
    with open(inp_list, "r", encoding="utf8") as f:
        for lineno, line in enumerate(f.read().splitlines(), start=1):
            if not line.strip():
                continue
            try:
                wav_name, language, text = parse_inp_text_line(line)
            except ValueError:
                logger.exception(f"Skip line {lineno}: {line}")
                continue
            tasks.append((len(tasks), wav_name, language, text))
    return tasks


def _abort_workers(progress: Progress, ctx, message: str):
    """Stop the progress display, reap all workers, log *message* and ``sys.exit(1)``.

    ``ctx`` is the ``ProcessContext`` returned by ``spawn(..., join=False)``. Its
    ``join()`` terminates the remaining workers and raises when one of them
    failed; that error is printed before exiting. Never returns.
    """
    progress.live.stop()
    try:
        # join() returns False while workers remain, so loop until all are reaped.
        while not ctx.join():
            pass
    except Exception as e:
        console.print(e)
    finally:
        logger.critical(message)
        sys.exit(1)


@app.command()
def main(
    inp_list: Path = typer.Option(
        ...,
        "--inp-list",
        file_okay=True,
        dir_okay=False,
        exists=True,
        readable=True,
        show_default=False,
        help="list File: wav|spk|lang|text",
    ),
    opt: Path = typer.Option(
        ..., "--opt", file_okay=False, dir_okay=True, writable=True, show_default=False, help="Output Directory"
    ),
    bert: Path = typer.Option(
        ..., "--bert", exists=True, readable=True, show_default=False, help="Path to Bert Pretrained Models"
    ),
    version: str = typer.Option("v2", "--version", help="SoVITS Language Version"),
    device: Device = typer.Option(Device.cpu, "--device", help="Compute device"),
    device_id: str = typer.Option("0", "--device-id", help="CUDA_VISIBLE_DEVICE, Such as '0,1,2'"),
    nproc: int = typer.Option(1, "--nproc", min=1, help="Number of processes per GPU"),
    fp16: bool = typer.Option(False, is_flag=True, flag_value=True, help="Use FP16"),
):
    """Run G2P over the input list and extract BERT features for Chinese lines."""
    # Work with the plain backend name; formatting the enum member itself gives
    # "Device.cuda" on Python 3.11+.
    device_type = Device(device).value
    device_ids = [int(x) for x in device_id.split(",") if x.strip() != ""]
    if not device_ids or (device_type in {"cpu", "mps"} and device_ids != [0]):
        raise ValueError(f"Invalid Device ID {device_ids}")
    if nproc < 1:
        raise ValueError(f"Invalid Num Process {nproc}")

    os.makedirs(opt, exist_ok=True)
    merged_path = osp.join(opt, "2-name2text.txt")

    tasks_all = _load_tasks(inp_list)
    n_tasks = len(tasks_all)
    if n_tasks == 0:
        logger.warning("Empty list")
        with open(merged_path, "w", encoding="utf8"):
            pass  # still produce an (empty) output file
        return

    device_strs = build_device_strings(device_type, device_ids, nproc)
    world_size = len(device_strs)

    tasks_q: "tmp.Queue[tuple[int, str, str, str] | None]" = tmp.Queue()
    results_q: "tmp.Queue[tuple[int, tuple[str, str, list[int] | None, str]]]" = tmp.Queue()

    for task in tasks_all:
        tasks_q.put(task)
    for _ in range(world_size):
        tasks_q.put(None)  # one stop sentinel per worker

    ordered: list[tuple[str, str, list[int] | None, str]] = [("", "", [], "")] * n_tasks
    completed = 0

    with Progress(
        TextColumn("[cyan]{task.description}"),
        BarColumn(),
        TextColumn("{task.completed}/{task.total}"),
        SpeedColumnIteration(show_speed=True),
        TimeRemainingColumn(elapsed_when_finished=True),
        console=console,
        redirect_stderr=False,
        redirect_stdout=False,
    ) as progress:
        progress_task = progress.add_task("G2P & Extract Bert", total=n_tasks)

        ctx = spawn(
            worker_entry,
            args=(device_strs, tasks_q, results_q, bert, opt, fp16, version),
            nprocs=world_size,
            join=False,
            daemon=False,
        )
        assert ctx  # spawn(join=False) returns a ProcessContext

        while completed < n_tasks:
            for p in ctx.processes:
                # exitcode is None while running; only a non-zero code is a crash.
                # A worker that exits 0 after its sentinel is normal, not a failure.
                if p.exitcode is not None and p.exitcode != 0:
                    _abort_workers(progress, ctx, f"Worker PID {p.pid} crashed with exit code {p.exitcode}.")
            # A process that put items on a queue does not terminate until they are
            # flushed to the pipe, so once every worker has exited cleanly an empty
            # (longer) poll means nothing more is coming.
            workers_done = all(p.exitcode == 0 for p in ctx.processes)
            try:
                idx, tup = results_q.get(timeout=1.0 if workers_done else 0.01)
            except queue.Empty:
                if workers_done:
                    break
                continue
            ordered[idx] = tup
            completed += 1
            progress.update(progress_task, advance=1)

        if completed < n_tasks:
            _abort_workers(
                progress, ctx, f"All workers exited but only {completed}/{n_tasks} results were received."
            )

        # ProcessContext.join() reaps one process per call; loop until all are done.
        while not ctx.join():
            pass

    with open(merged_path, "w", encoding="utf8") as fout:
        for name, phones_str, word2ph, norm_text in ordered:
            if name:
                fout.write(f"{name}\t{phones_str}\t{word2ph}\t{norm_text}\n")

    logger.info(f"Done: {merged_path}")
|
|
|
|
|
|
def is_powershell_env(env: dict) -> bool:
    """Heuristically report whether *env* looks like a PowerShell session environment.

    Returns True as soon as any of ``PSHOME``, ``POWERSHELL_DISTRIBUTION_CHANNEL``
    or ``PSModulePath`` is present as a key of *env*; otherwise False.
    """
    # NOTE(review): PSModulePath may also be defined system-wide on Windows, in which
    # case cmd.exe sessions would be reported as PowerShell too — confirm.
    markers = ("PSHOME", "POWERSHELL_DISTRIBUTION_CHANNEL", "PSModulePath")
    for marker in markers:
        if marker in env:
            return True
    return False
|
|
|
|
|
|
def get_prog_name() -> str:
    """Build the program name shown in ``--help`` as a ready-to-run shell command.

    The command sets ``PYTHONPATH=.`` using the syntax of the current shell:
    POSIX-style assignment off Windows, ``$env:`` in PowerShell, and ``set ... &&``
    in other Windows shells. It then runs this script by its repo-relative path.
    """
    script_rel = osp.join("GPT_SoVITS", "prepare_datasets", osp.basename(__file__))
    if platform.system() != "Windows":
        return f"PYTHONPATH=. python -s {script_rel}"
    if is_powershell_env(os.environ.copy()):
        return f"$env:PYTHONPATH='.'; python -s {script_rel}"
    return f"set PYTHONPATH=. && python -s {script_rel}"
|
|
|
|
|
|
if __name__ == "__main__":
    start = time.perf_counter()
    try:
        # Typer/Click run in standalone mode and always finish with SystemExit,
        # so the timing log must live in ``finally`` to be emitted at all.
        app(prog_name=get_prog_name())
    finally:
        logger.info(f"Exec Time: {time.perf_counter() - start:.2f} secs")
|