# Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
# Synced 2026-06-03 20:40:30 +08:00 (138 lines, 4.1 KiB, Python)
import os
|
|
import sys
|
|
|
|
|
|
# Job parameters are passed in through environment variables; each one is
# None when the corresponding variable is not set.
inp_text, inp_wav_dir, exp_name, i_part, all_parts = map(
    os.environ.get,
    ("inp_text", "inp_wav_dir", "exp_name", "i_part", "all_parts"),
)

# Copy the caller-supplied "_CUDA_VISIBLE_DEVICES" override (if any) into the
# real variable. This runs before the cnhubert / torch imports further down.
try:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
except KeyError:
    pass
|
|
from feature_extractor import cnhubert
|
|
|
|
|
|
# Root output directory for this job (None when "opt_dir" is unset).
opt_dir = os.getenv("opt_dir")
# Tell the cnhubert module where its base model lives.
cnhubert.cnhubert_base_path = os.getenv("cnhubert_base_dir")
|
|
import torch
|
|
|
|
|
|
# Half precision is used only when requested (default "True") AND CUDA is
# available.
# NOTE(review): eval() executes whatever text the "is_half" environment
# variable holds. That is fine while only the launcher sets it to
# "True"/"False"; consider ast.literal_eval if the value can ever come from an
# untrusted source.
is_half = eval(os.getenv("is_half", "True")) and torch.cuda.is_available()
|
|
|
|
import traceback
|
|
|
|
import librosa
|
|
import numpy as np
|
|
from scipy.io import wavfile
|
|
|
|
|
|
# Make the current working directory importable by adding it to the end of
# the module search path.
now_dir = os.getcwd()
sys.path += [now_dir]
|
|
import shutil
|
|
|
|
# Earlier, commented-out configuration (config module / positional
# command-line arguments); the values are now read from environment variables.
# from config import cnhubert_base_path
# cnhubert.cnhubert_base_path=cnhubert_base_path
# inp_text=sys.argv[1]
# inp_wav_dir=sys.argv[2]
# exp_name=sys.argv[3]
# i_part=sys.argv[4]
# all_parts=sys.argv[5]
# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]
# cnhubert.cnhubert_base_path=sys.argv[7]
# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
|
|
from time import time as ttime
|
|
|
|
from gsv_tools.my_utils import clean_path, load_audio
|
|
|
|
|
|
def my_save(fea, path):
    """Save ``fea`` to ``path`` through a temporary file in the working directory.

    Original author's note: this works around ``torch.save`` not supporting
    Chinese paths. The object is first written to a temporary ``.pth`` file
    whose name is built from the current time and the module-level ``i_part``,
    and that file is then moved to the requested destination.

    Args:
        fea: Object handed to ``torch.save``.
        path: Destination file path. A bare file name (no directory part) is
            written relative to the current working directory.

    Raises:
        Whatever ``torch.save`` or ``shutil.move`` raise; the temporary file
        is removed before the exception propagates.
    """
    target_dir = os.path.dirname(path)
    name = os.path.basename(path)
    # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
    tmp_path = f"{ttime()}{i_part}.pth"
    try:
        torch.save(fea, tmp_path)
        # os.path.join keeps a directory-less path relative; plain
        # "dir/name" formatting would turn it into "/name" (filesystem root).
        shutil.move(tmp_path, os.path.join(target_dir, name))
    finally:
        # After a successful move the temporary file no longer exists; this
        # only fires when saving or moving failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
|
|
|
|
|
# Output locations below opt_dir: extracted features and the 32 kHz wav copies.
hubert_dir = f"{opt_dir}/4-cnhubert"
wav32dir = f"{opt_dir}/5-wav32k"
os.makedirs(opt_dir, exist_ok=True)
os.makedirs(hubert_dir, exist_ok=True)
os.makedirs(wav32dir, exist_ok=True)

# Constants used by name2go when rescaling the audio.
maxx, alpha = 0.95, 0.5

# Use the first CUDA device when available, otherwise the CPU.
# (A disabled alternative existed here:
#  elif torch.backends.mps.is_available(): device = "mps")
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = cnhubert.get_model()
# is_half=False
model = model.half().to(device) if is_half else model.to(device)

# (wav_name, wav_path) pairs whose model output contained NaN.
nan_fails = []
|
|
|
|
|
|
def name2go(wav_name, wav_path):
    """Extract and store the model features and a 32 kHz copy for one audio file.

    The file is skipped (nothing is written) when:
      * ``{hubert_dir}/{wav_name}.pt`` already exists;
      * its peak absolute amplitude is above 2.2;
      * its peak is zero or not a number -- dividing by such a peak would fill
        the audio with NaN/inf and the file could never pass the NaN check;
      * the model output contains NaN, in which case ``(wav_name, wav_path)``
        is also appended to the module-level ``nan_fails`` list.

    Otherwise the rescaled audio is written as int16 at 32000 Hz to
    ``{wav32dir}/{wav_name}`` and the transposed ``last_hidden_state`` of the
    model is saved to ``{hubert_dir}/{wav_name}.pt`` via ``my_save``.

    Args:
        wav_name: File name used for both outputs.
        wav_path: Path handed to ``load_audio``.
    """
    hubert_path = f"{hubert_dir}/{wav_name}.pt"
    if os.path.exists(hubert_path):
        return

    tmp_audio = load_audio(wav_path, 32000)
    tmp_max = np.abs(tmp_audio).max()
    # The peak of absolute values is never negative, so this rejects exactly:
    # peak == 0 (silence), peak > 2.2, and a NaN peak (all comparisons False).
    if not (0 < tmp_max <= 2.2):
        print(f"{wav_name}-filtered,{tmp_max}")
        return

    # Blend peak-normalised audio (weight alpha, ceiling maxx) with the raw
    # audio (weight 1 - alpha). The 32768-scaled version is written to disk;
    # the 1145.14-scaled version is what gets fed to the model.
    normalized = tmp_audio / tmp_max
    tmp_audio32 = normalized * (maxx * alpha * 32768) + ((1 - alpha) * 32768) * tmp_audio
    tmp_audio32b = normalized * (maxx * alpha * 1145.14) + ((1 - alpha) * 1145.14) * tmp_audio
    tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000)  # not a resampling problem

    tensor_wav16 = torch.from_numpy(tmp_audio)
    if is_half:
        tensor_wav16 = tensor_wav16.half()
    tensor_wav16 = tensor_wav16.to(device)

    # Example shape noted by the original author: torch.Size([1, 768, 215])
    ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu()
    if np.isnan(ssl.detach().numpy()).any():
        nan_fails.append((wav_name, wav_path))
        print(f"nan filtered:{wav_name}")
        return

    wavfile.write(
        f"{wav32dir}/{wav_name}",
        32000,
        tmp_audio32.astype("int16"),
    )
    my_save(ssl, hubert_path)
|
|
|
|
|
|
# Read the transcript list; each non-empty line is expected to look like
# "wav_name|spk_name|language|text".
with open(inp_text, encoding="utf8") as f:
    lines = f.read().strip("\n").split("\n")

# This worker handles every all_parts-th line, starting at index i_part.
for line in lines[int(i_part) :: int(all_parts)]:
    if not line.strip():
        # Blank lines (e.g. an empty list file) carry no entry to process.
        continue
    try:
        # wav_name,text=line.split("\t")
        wav_name, _spk_name, _language, _text = line.split("|")
        wav_name = clean_path(wav_name)
        if inp_wav_dir:
            # An audio directory was given: only the file name from the list
            # is used and it is looked up inside that directory.
            wav_name = os.path.basename(wav_name)
            wav_path = f"{inp_wav_dir}/{wav_name}"
        else:
            # No audio directory: the list entry itself is the path.
            wav_path = wav_name
            wav_name = os.path.basename(wav_name)
        name2go(wav_name, wav_path)
    except Exception:
        # Best effort: report the failing line and carry on with the rest.
        # Catching Exception (not a bare except) lets Ctrl-C / SystemExit
        # stop the run.
        print(line, traceback.format_exc())
|
|
|
|
# Files whose half-precision output contained NaN get one more attempt in
# full precision.
if nan_fails and is_half:
    is_half = False
    model = model.float()
    # name2go appends to nan_fails when the output is NaN again. Iterating the
    # live list would therefore never finish for a file that stays NaN, so
    # take a snapshot and start the list afresh; afterwards nan_fails holds
    # only the files that also failed in full precision.
    retry_items = list(nan_fails)
    nan_fails.clear()
    for failed_name, failed_path in retry_items:
        try:
            name2go(failed_name, failed_path)
        except Exception:
            # Report the item that actually failed, not a leftover name from
            # the main loop.
            print(failed_name, traceback.format_exc())
|