tqdm progress bar

Author: XXXXRT666
Date:   2024-08-26 02:06:35 +08:00
Parent: fc2161b484
Commit: be200d182a
4 changed files with 16 additions and 15 deletions
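
The pattern is the same across the four files below: the sharded per-part loops are wrapped in tqdm(..., position=int(i_part)) so each worker process draws its own progress bar, and print() calls inside those loops become tqdm.write() so messages land above the bars instead of tearing them. A minimal, self-contained sketch of that pattern (the work items and the failure are invented; i_part/all_parts mirror the scripts' environment-driven sharding):

import os
import time
import traceback

from tqdm import tqdm

i_part = int(os.environ.get("i_part", "0"))        # which shard this process handles
all_parts = int(os.environ.get("all_parts", "1"))  # total number of shards

lines = ["sample_%03d" % n for n in range(100)]    # made-up work items

for line in tqdm(lines[i_part::all_parts], position=i_part, delay=0.5):
    try:
        time.sleep(0.01)                           # stand-in for real per-item work
        if line.endswith("7"):
            raise ValueError("demo failure")
    except Exception:
        # tqdm.write() prints above the active bars instead of through them;
        # it takes one string, so format the pieces into a single message.
        tqdm.write("%s failed:\n%s" % (line, traceback.format_exc()))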

File 1 of 4

@@ -84,11 +84,10 @@ if os.path.exists(txt_path) == False:
         return phone_level_feature.T

     def process(data, res):
-        for name, text, lan in data:
+        for name, text, lan in tqdm(data,position=int(i_part),delay=0.5):
             try:
                 name=clean_path(name)
                 name = os.path.basename(name)
-                print(name)
                 phones, word2ph, norm_text = clean_text(
                     text.replace("%", "-").replace("￥", ","), lan, version
                 )
@@ -102,7 +101,7 @@ if os.path.exists(txt_path) == False:
                 # res.append([name,phones])
                 res.append([name, phones, word2ph, norm_text])
             except:
-                print(name, text, traceback.format_exc())
+                tqdm.write(name, text, traceback.format_exc())

     todo = []
     res = []
@@ -135,9 +134,9 @@ if os.path.exists(txt_path) == False:
                     [wav_name, text, language_v1_to_language_v2.get(language, language)]
                 )
             else:
-                print(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m")
+                tqdm.write(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m")
         except:
-            print(line, traceback.format_exc())
+            tqdm.write(line, traceback.format_exc())

     process(todo, res)
     opt = []
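
One detail worth keeping in mind for the calls above: unlike print(), tqdm.write() expects a single pre-formatted string (its remaining parameters are file, end and nolock), so multi-value calls such as tqdm.write(name, text, traceback.format_exc()) need their arguments joined into one message. A hypothetical wrapper that preserves print-style call sites might look like this:

from tqdm import tqdm

def tqdm_print(*values, sep=" ", end="\n"):
    # Join the arguments exactly like print() would, then hand one string to tqdm.write().
    tqdm.write(sep.join(str(v) for v in values), end=end)

# e.g. tqdm_print(name, text, traceback.format_exc())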

File 2 of 4

@@ -20,6 +20,7 @@ import librosa
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 from tools.my_utils import load_audio,clean_path
+from tqdm import tqdm

 # from config import cnhubert_base_path
 # cnhubert.cnhubert_base_path=cnhubert_base_path
@@ -70,7 +71,7 @@ def name2go(wav_name,wav_path):
     tmp_audio = load_audio(wav_path, 32000)
     tmp_max = np.abs(tmp_audio).max()
     if tmp_max > 2.2:
-        print("%s-filtered,%s" % (wav_name, tmp_max))
+        tqdm.write("%s-filtered,%s" % (wav_name, tmp_max))
         return
     tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio
     tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha*1145.14)) + ((1 - alpha)*1145.14) * tmp_audio
@@ -85,7 +86,7 @@ def name2go(wav_name,wav_path):
     ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
     if np.isnan(ssl.detach().numpy()).sum()!= 0:
         nan_fails.append((wav_name,wav_path))
-        print("nan filtered:%s"%wav_name)
+        tqdm.write("nan filtered:%s"%wav_name)
         return
     wavfile.write(
         "%s/%s"%(wav32dir,wav_name),
@@ -97,7 +98,7 @@ def name2go(wav_name,wav_path):
 with open(inp_text,"r",encoding="utf8")as f:
     lines=f.read().strip("\n").split("\n")

-for line in lines[int(i_part)::int(all_parts)]:
+for line in tqdm(lines[int(i_part)::int(all_parts)],position=int(i_part)):
     try:
         # wav_name,text=line.split("\t")
         wav_name, spk_name, language, text = line.split("|")
@@ -111,13 +112,13 @@ for line in lines[int(i_part)::int(all_parts)]:
             wav_name = os.path.basename(wav_name)
         name2go(wav_name,wav_path)
     except:
-        print(line,traceback.format_exc())
+        tqdm.write(line,traceback.format_exc())

 if(len(nan_fails)>0 and is_half==True):
     is_half=False
     model=model.float()
-    for wav in nan_fails:
+    for wav in tqdm(nan_fails,position=int(i_part)):
         try:
             name2go(wav[0],wav[1])
         except:
-            print(wav_name,traceback.format_exc())
+            tqdm.write(wav_name,traceback.format_exc())
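
For context on position=int(i_part): the preparation scripts run as several sharded processes sharing one terminal, and position pins each process's bar to its own row. An illustrative sketch of that layout, using a Pool and dummy work as stand-ins for the real per-part launches and following the multi-process bar pattern from the tqdm documentation:

import time
from multiprocessing import Pool, RLock, freeze_support

from tqdm import tqdm

ALL_PARTS = 4  # number of sharded worker processes

def run_part(i_part: int) -> None:
    items = range(200)[i_part::ALL_PARTS]                 # this part's shard of the work
    for _ in tqdm(items, desc=f"part {i_part}", position=i_part):
        time.sleep(0.01)                                  # stand-in for per-file processing

if __name__ == "__main__":
    freeze_support()
    tqdm.set_lock(RLock())                                # shared lock so the bars do not clobber each other
    with Pool(ALL_PARTS, initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),)) as pool:
        pool.map(run_part, range(ALL_PARTS))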

File 3 of 4

@@ -25,6 +25,7 @@ from tqdm import tqdm
 import logging, librosa, utils
 from module.models import SynthesizerTrn
 from tools.my_utils import clean_path
+from tqdm import tqdm

 logging.getLogger("numba").setLevel(logging.WARNING)
 # from config import pretrained_s2G
@@ -87,8 +88,8 @@ if os.path.exists(semantic_path) == False:
         lines = f.read().strip("\n").split("\n")

     lines1 = []
-    for line in lines[int(i_part) :: int(all_parts)]:
-        # print(line)
+    for line in tqdm(lines[int(i_part) :: int(all_parts)],position=int(i_part)):
+        # tqdm.write(line)
         try:
             # wav_name,text=line.split("\t")
             wav_name, spk_name, language, text = line.split("|")
@@ -97,6 +98,6 @@ if os.path.exists(semantic_path) == False:
             # name2go(name,lines1)
             name2go(wav_name, lines1)
         except:
-            print(line, traceback.format_exc())
+            tqdm.write(line, traceback.format_exc())
     with open(semantic_path, "w", encoding="utf8") as f:
         f.write("\n".join(lines1))

File 4 of 4

@@ -306,7 +306,7 @@ def train_and_evaluate(
         y_lengths,
         text,
         text_lengths,
-    ) in enumerate(tqdm(train_loader)):
+    ) in enumerate(tqdm(train_loader,position=rank+1,leave=(epoch==hps.train.epochs),postfix=f'epoch:{epoch}')):
         if torch.cuda.is_available():
             spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
                 rank, non_blocking=True
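
The kwargs added to the training loop above can be read individually: position=rank+1 gives each rank its own bar row, leave=(epoch==hps.train.epochs) keeps only the final epoch's bar on screen once it finishes, and postfix appends a fixed label after the bar's statistics. A small sketch with stand-in values for rank, the epoch count, and the loader:

import time

from tqdm import tqdm

rank = 0            # stand-in for the DDP rank
epochs = 3          # stand-in for hps.train.epochs
loader = range(50)  # stand-in for train_loader

for epoch in range(1, epochs + 1):
    for _batch in tqdm(
        loader,
        position=rank + 1,              # one bar row per rank, leaving row 0 free
        leave=(epoch == epochs),        # only the final epoch's bar stays on screen
        postfix=f"epoch:{epoch}",       # static label appended after the bar's stats
    ):
        time.sleep(0.01)                # stand-in for a training step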