# 2024-01-16 22:25:15 +08:00
#
# 118 lines
# 4.4 KiB
# Python

# -*- coding: utf-8 -*-
import os
# All job parameters are read from environment variables; unset ones are None.
import ast

inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
# Set the GPU mask here, before torch is imported further down the file, so the
# CUDA runtime picks it up. Assigning None into os.environ raises TypeError, so
# only copy the value when it was actually provided.
_cuda_visible_devices = os.environ.get("_CUDA_VISIBLE_DEVICES")
if _cuda_visible_devices is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = _cuda_visible_devices
opt_dir = os.environ.get("opt_dir")
bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
# ast.literal_eval accepts the same "True"/"False" literals that eval() did, but
# will not execute arbitrary expressions taken from the environment.
is_half = ast.literal_eval(os.environ.get("is_half", "True"))
import sys,numpy as np,traceback,pdb
import os.path
from glob import glob
from tqdm import tqdm
from text.cleaner import clean_text
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
# inp_text=sys.argv[1]
# inp_wav_dir=sys.argv[2]
# exp_name=sys.argv[3]
# i_part=sys.argv[4]
# all_parts=sys.argv[5]
# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]#i_gpu
# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
# bert_pretrained_dir="/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large"
from time import time as ttime
import shutil
def my_save(fea, path):  # fix issue: torch.save doesn't support chinese path
    """Save `fea` to `path` with torch.save, going through an ASCII-named temp file.

    Works around torch.save failing on file names containing non-ASCII (e.g.
    Chinese) characters: the object is first written to
    "<timestamp><i_part>.pth" in the same directory, then moved to its final name.

    Args:
        fea: object to pass to torch.save.
        path: destination file path.
    """
    dir_name = os.path.dirname(path)
    base_name = os.path.basename(path)
    # os.path.join rather than "%s/%s": for a bare file name dirname() is "",
    # and the old formatting produced "/<tmp>.pth" at the filesystem root.
    # The module-level i_part presumably keeps parallel parts from colliding.
    tmp_path = os.path.join(dir_name, "%s%s.pth" % (ttime(), i_part))
    try:
        torch.save(fea, tmp_path)
        shutil.move(tmp_path, os.path.join(dir_name, base_name))
    finally:
        # After a successful move the temp file no longer exists; this only
        # removes a leftover from a failed save or move.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
# Skip the whole extraction when this part's output already exists, so that
# re-running the script does not redo finished parts.
if not os.path.exists(txt_path):
    bert_dir = "%s/3-bert" % (opt_dir)
    os.makedirs(opt_dir, exist_ok=True)
    os.makedirs(bert_dir, exist_ok=True)
    # The original hard-coded "cuda:0" and failed without a GPU; fall back to
    # the CPU in that case. Half precision is only used on CUDA, since CPU
    # kernels may not implement float16.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
    bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
    if is_half == True and device != "cpu":
        bert_model = bert_model.half().to(device)
    else:
        bert_model = bert_model.to(device)

    def get_bert_feature(text, word2ph):
        """Return BERT features for `text` expanded to phone level.

        Args:
            text: normalized text; must have exactly one `word2ph` entry per
                character.
            word2ph: per-character phone counts (as returned by clean_text).

        Returns:
            A CPU tensor of shape (hidden_size, sum(word2ph)). Row i of the
            third-from-last BERT hidden layer (after dropping the first and
            last token positions) is repeated word2ph[i] times.

        Raises:
            ValueError: if `word2ph` and `text` differ in length.
        """
        # Validate before running the model; an `assert` would be stripped
        # under `python -O`, silently misaligning features.
        if len(word2ph) != len(text):
            raise ValueError(
                "word2ph has %d entries but text has %d characters"
                % (len(word2ph), len(text))
            )
        with torch.no_grad():
            inputs = tokenizer(text, return_tensors="pt")
            inputs = {key: value.to(device) for key, value in inputs.items()}
            out = bert_model(**inputs, output_hidden_states=True)
            # [-3:-2] selects only the third-from-last hidden layer; [0] drops
            # the batch dimension of the single input; [1:-1] drops the first
            # and last token positions (presumably [CLS]/[SEP] — confirm for
            # the tokenizer in use).
            hidden = torch.cat(out["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
        phone_level_feature = torch.cat(
            [hidden[i].repeat(count, 1) for i, count in enumerate(word2ph)],
            dim=0,
        )
        return phone_level_feature.T

    def process(data, res):
        """Convert each (name, text, lan) item to phones and append it to `res`.

        For "zh" items whose feature file `<bert_dir>/<name>.pt` does not exist
        yet, phone-level BERT features are computed and saved there. An item
        that fails is reported on stdout and skipped, so one bad line does not
        abort the whole part.

        Each appended entry is [name, space-joined phones, word2ph, norm_text].
        """
        for name, text, lan in data:
            try:
                name = os.path.basename(name)
                # NOTE(review): the second search string was lost in the source
                # (it was ""), and str.replace("", ",") inserts "," between every
                # character. Restored to U+FFE5 FULLWIDTH YEN SIGN as in upstream
                # GPT-SoVITS — confirm.
                phones, word2ph, norm_text = clean_text(
                    text.replace("%", "-").replace("\uffe5", ","), lan
                )
                path_bert = "%s/%s.pt" % (bert_dir, name)
                if lan == "zh" and not os.path.exists(path_bert):
                    bert_feature = get_bert_feature(norm_text, word2ph)
                    if bert_feature.shape[-1] != len(phones):
                        raise ValueError(
                            "BERT feature length %d != phone count %d"
                            % (bert_feature.shape[-1], len(phones))
                        )
                    # my_save rather than torch.save: works around torch.save
                    # failing on non-ASCII paths.
                    my_save(bert_feature, path_bert)
                res.append([name, " ".join(phones), word2ph, norm_text])
            except Exception:
                # Exception, not a bare except, so Ctrl-C still stops the run
                # instead of being swallowed item by item.
                print(name, text, traceback.format_exc())

    todo = []
    res = []
    with open(inp_text, "r", encoding="utf8") as f:
        lines = f.read().strip("\n").split("\n")

    # Normalize language tags (e.g. "ZH"/"JP"/"EN") to lowercase codes;
    # unknown tags are passed through unchanged.
    language_v1_to_language_v2 = {
        "ZH": "zh",
        "zh": "zh",
        "JP": "ja",
        "jp": "ja",
        "JA": "ja",
        "ja": "ja",
        "EN": "en",
        "en": "en",
        "En": "en",
    }
    # This part handles every all_parts-th line, starting at index i_part.
    for line in lines[int(i_part)::int(all_parts)]:
        try:
            # Expected line format: wav_name|speaker|language|text
            wav_name, spk_name, language, text = line.split("|")
            todo.append(
                [wav_name, text, language_v1_to_language_v2.get(language, language)]
            )
        except Exception:
            print(line, traceback.format_exc())

    process(todo, res)
    # One tab-separated record per successfully processed item.
    opt = [
        "%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text)
        for name, phones, word2ph, norm_text in res
    ]
    with open(txt_path, "w", encoding="utf8") as f:
        f.write("\n".join(opt) + "\n")