GPT-SoVITS/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py

"""
Step 2 of data preparation: Extract HuBERT features from the audio files,
and resample the audio to 32kHz and saving it.
"""
import os
import shutil
import sys
import traceback
from time import time as ttime
import librosa
import numpy as np
import torch
from feature_extractor import cnhubert
from scipy.io import wavfile
from tqdm import tqdm
from tools.my_utils import clean_path, load_audio
inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
opt_dir = os.environ.get("opt_dir")
cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
is_half = os.environ.get("is_half", "True").lower() in ("true", "1") and torch.cuda.is_available()
now_dir = os.getcwd()
sys.path.append(now_dir)
# from config import cnhubert_base_path
# cnhubert.cnhubert_base_path=cnhubert_base_path
# inp_text=sys.argv[1]
# inp_wav_dir=sys.argv[2]
# exp_name=sys.argv[3]
# i_part=sys.argv[4]
# all_parts=sys.argv[5]
# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]
# cnhubert.cnhubert_base_path=sys.argv[7]
# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name

def my_save(fea, path):  # workaround: torch.save can fail on non-ASCII (e.g. Chinese) paths
    # Save to an ASCII-named temp file in the cwd, then move it to the real destination
    dirname = os.path.dirname(path)
    name = os.path.basename(path)
    # tmp_path = "%s/%s%s.pth" % (dirname, ttime(), i_part)
    tmp_path = f"{ttime()}{i_part}.pth"
    torch.save(fea, tmp_path)
    shutil.move(tmp_path, f"{dirname}/{name}")

hubert_dir = "%s/4-cnhubert" % (opt_dir)
wav32dir = "%s/5-wav32k" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True)
os.makedirs(hubert_dir, exist_ok=True)
os.makedirs(wav32dir, exist_ok=True)
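
# Normalization constants: maxx is the target peak after normalization; alpha blends
# the peak-normalized signal with the raw signal (see name2go below)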
maxx = 0.95
alpha = 0.5
if torch.cuda.is_available():
    device = "cuda:0"
# elif torch.backends.mps.is_available():
#     device = "mps"
else:
    device = "cpu"
model = cnhubert.get_model()
# is_half = False
if is_half:
    model = model.half().to(device)
else:
    model = model.to(device)
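
# (wav_name, wav_path) pairs whose HuBERT features came out NaN under fp16;
# they are retried in fp32 after the main pass (see the end of this script)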
nan_fails = []

def name2go(wav_name, wav_path):
    """
    Extract HuBERT features from one audio file, resample it to 32 kHz, and save both.
    """
    # Skip if the feature file already exists
    hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
    if os.path.exists(hubert_path):
        return
    # Load the audio file at a 32 kHz sampling rate
    tmp_audio = load_audio(wav_path, 32000)
    # Check the peak amplitude of the audio file
    tmp_max = np.abs(tmp_audio).max()
    # Skip if the peak is far above full scale (the clip is too loud or corrupted)
    if tmp_max > 2.2:
        print(f"{wav_name}-filtered,{tmp_max}")
        return
    # Blend a peak-normalized copy of the signal with the raw signal, weighted by alpha.
    # tmp_audio32 is scaled to the int16 range (32768) and written out as the 32 kHz wav;
    # tmp_audio32b uses a smaller scale (1145.14) and is only used for feature extraction.
    tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
    tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
    # HuBERT expects 16 kHz input, so downsample the feature copy
    tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000)  # not a resampling issue
    tensor_wav16 = torch.from_numpy(tmp_audio)
    # If half precision is enabled, convert the tensor to half precision
    if is_half:
        tensor_wav16 = tensor_wav16.half().to(device)
    else:
        tensor_wav16 = tensor_wav16.to(device)
    # Extract HuBERT features, shape (1, 768, T), e.g. torch.Size([1, 768, 215])
    ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu()
    # fp16 inference occasionally yields NaNs; record the file so it can be retried in fp32
    if np.isnan(ssl.detach().numpy()).sum() != 0:
        nan_fails.append((wav_name, wav_path))
        print(f"nan filtered:{wav_name}")
        return
    wavfile.write(
        f"{wav32dir}/{wav_name}",
        32000,
        tmp_audio32.astype("int16"),
    )
    my_save(ssl, hubert_path)

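
# Each line of inp_text is expected to look like "wav_name|speaker|language|text";
# this worker processes every all_parts-th line, starting at offset i_part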
with open(inp_text, "r", encoding="utf8")as f:
lines = f.read().strip("\n").split("\n")
for line in tqdm(lines[int(i_part)::int(all_parts)]):
    try:
        # wav_name, text = line.split("\t")
        wav_name, spk_name, language, text = line.split("|")
        wav_name = clean_path(wav_name)
        if inp_wav_dir:
            wav_name = os.path.basename(wav_name)
            wav_path = f"{inp_wav_dir}/{wav_name}"
        else:
            wav_path = wav_name
            wav_name = os.path.basename(wav_name)
        name2go(wav_name, wav_path)
    except Exception:
        print(line, traceback.format_exc())

# Files that produced NaN features under fp16: cast the model to fp32 and retry them
if len(nan_fails) > 0 and is_half:
    is_half = False
    model = model.float()
    for wav in nan_fails:
        try:
            name2go(wav[0], wav[1])
        except Exception:
            print(wav[0], traceback.format_exc())